feat:基于sse模式 启动mcp服务器 (未测试)

This commit is contained in:
酒亦
2025-08-12 14:00:18 +08:00
parent bc2eb8fdb9
commit 43dc0f419f
3 changed files with 559 additions and 0 deletions

View File

@@ -0,0 +1,287 @@
package org.ruoyi.mcp.config;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.info.ProcessInfo;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.Disposable;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@Component
public class McpProcessSSEManager {
private final Map<String, Process> runningProcesses = new ConcurrentHashMap<>();
private final Map<String, ProcessInfo> processInfos = new ConcurrentHashMap<>();
private final Map<String, WebClient> sseClients = new ConcurrentHashMap<>();
private final Map<String, Disposable> sseSubscriptions = new ConcurrentHashMap<>();
private final ObjectMapper objectMapper = new ObjectMapper();
@Autowired
private McpSSEToolInvoker mcpToolInvoker;
/**
* 启动 MCP 服务器进程SSE 模式)
*/
public boolean startMcpServer(String serverName, String command, List<String> args, Map<String, String> env) {
try {
System.out.println("准备启动 MCP 服务器 (SSE 模式): " + serverName);
// 如果已经运行,先停止
if (isMcpServerRunning(serverName)) {
stopMcpServer(serverName);
}
// 构建命令
List<String> commandList = buildCommandList(command, args);
// 创建 ProcessBuilder
ProcessBuilder processBuilder = new ProcessBuilder(commandList);
processBuilder.redirectErrorStream(true);
// 设置工作目录
String workingDir = System.getProperty("user.dir");
processBuilder.directory(new File(workingDir));
// 打印调试信息
System.out.println("=== ProcessBuilder 调试信息 ===");
System.out.println("完整命令列表: " + commandList);
System.out.println("================================");
// 执行命令
Process process = processBuilder.start();
runningProcesses.put(serverName, process);
ProcessInfo processInfo = new ProcessInfo();
processInfo.setStartTime(System.currentTimeMillis());
processInfo.setPid(getProcessId(process));
processInfos.put(serverName, processInfo);
// 启动日志读取线程
ExecutorService executorService = Executors.newCachedThreadPool();
executorService.submit(() -> readProcessOutput(serverName, process));
// 等待进程启动
Thread.sleep(3000);
boolean isAlive = process.isAlive();
if (isAlive) {
System.out.println("✓ MCP 服务器 [" + serverName + "] 启动成功");
// 初始化 SSE 连接
initializeSseConnection(serverName);
} else {
System.err.println("✗ MCP 服务器 [" + serverName + "] 启动失败");
readErrorOutput(process);
}
return isAlive;
} catch (Exception e) {
System.err.println("✗ 启动 MCP 服务器 [" + serverName + "] 失败: " + e.getMessage());
e.printStackTrace();
return false;
}
}
private String getProcessId(Process process) {
try {
return String.valueOf(process.pid());
} catch (Exception e) {
return "unknown";
}
}
private void readProcessOutput(String serverName, Process process) {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(process.getInputStream()))) {
String line;
while ((line = reader.readLine()) != null && process.isAlive()) {
System.out.println("[" + serverName + "] " + line);
}
} catch (IOException e) {
System.err.println("Error reading output from " + serverName + ": " + e.getMessage());
}
}
/**
* 读取错误输出
*/
private void readErrorOutput(Process process) {
try {
InputStream errorStream = process.getErrorStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(errorStream));
String line;
while ((line = reader.readLine()) != null) {
System.err.println("ERROR: " + line);
}
} catch (Exception e) {
System.err.println("Failed to read error output: " + e.getMessage());
}
}
/**
* 初始化 SSE 连接
*/
private void initializeSseConnection(String serverName) {
try {
// 创建 WebClient 用于 SSE 连接
WebClient webClient = WebClient.builder()
.baseUrl("http://localhost:3000") // 假设默认端口 3000
.build();
sseClients.put(serverName, webClient);
// 建立 SSE 连接
String sseUrl = "/sse/" + serverName; // SSE 端点
Disposable subscription = webClient.get()
.uri(sseUrl)
.accept(MediaType.TEXT_EVENT_STREAM)
.retrieve()
.bodyToFlux(String.class)
.subscribe(
event -> handleSseEvent(serverName, event),
error -> System.err.println("SSE 连接错误 [" + serverName + "]: " + error.getMessage()),
() -> System.out.println("SSE 连接完成 [" + serverName + "]")
);
sseSubscriptions.put(serverName, subscription);
System.out.println("✓ SSE 连接建立成功 [" + serverName + "]");
} catch (Exception e) {
System.err.println("✗ 建立 SSE 连接失败 [" + serverName + "]: " + e.getMessage());
}
}
/**
* 处理 SSE 事件
*/
private void handleSseEvent(String serverName, String event) {
try {
System.out.println("收到来自 [" + serverName + "] 的 SSE 事件: " + event);
// 解析 SSE 事件
if (event.startsWith("data: ")) {
String jsonData = event.substring(6); // 移除 "data: " 前缀
Map<String, Object> message = objectMapper.readValue(jsonData, Map.class);
// 处理不同类型的事件
String type = (String) message.get("type");
if ("tool_response".equals(type)) {
mcpToolInvoker.handleSseResponse(serverName, message);
} else if ("tool_error".equals(type)) {
mcpToolInvoker.handleSseError(serverName, message);
} else if ("progress".equals(type)) {
handleProgressEvent(serverName, message);
} else {
System.out.println("[" + serverName + "] 未知事件类型: " + type);
}
}
} catch (Exception e) {
System.err.println("处理 SSE 事件失败 [" + serverName + "]: " + e.getMessage());
}
}
/**
* 处理进度事件
*/
private void handleProgressEvent(String serverName, Map<String, Object> message) {
Object progress = message.get("progress");
Object messageText = message.get("message");
System.out.println("[" + serverName + "] 进度: " + progress + " - " + messageText);
}
/**
* 构建命令列表
*/
private List<String> buildCommandList(String command, List<String> args) {
List<String> commandList = new ArrayList<>();
if (isWindows() && "npx".equalsIgnoreCase(command)) {
commandList.add("cmd.exe");
commandList.add("/c");
commandList.add("npx");
commandList.addAll(args);
} else {
commandList.add(command);
commandList.addAll(args);
}
return commandList;
}
/**
* 检查是否为 Windows 系统
*/
private boolean isWindows() {
return System.getProperty("os.name").toLowerCase().contains("windows");
}
/**
* 停止 MCP 服务器进程
*/
public boolean stopMcpServer(String serverName) {
// 停止 SSE 连接
Disposable subscription = sseSubscriptions.remove(serverName);
if (subscription != null && !subscription.isDisposed()) {
subscription.dispose();
}
sseClients.remove(serverName);
// 停止进程
Process process = runningProcesses.remove(serverName);
ProcessInfo processInfo = processInfos.remove(serverName);
if (process != null && process.isAlive()) {
process.destroy();
try {
if (!process.waitFor(10, TimeUnit.SECONDS)) {
process.destroyForcibly();
process.waitFor(2, TimeUnit.SECONDS);
}
System.out.println("MCP 服务器 [" + serverName + "] 已停止");
return true;
} catch (InterruptedException e) {
process.destroyForcibly();
Thread.currentThread().interrupt();
return false;
}
}
return false;
}
/**
* 检查 MCP 服务器是否运行
*/
public boolean isMcpServerRunning(String serverName) {
Process process = runningProcesses.get(serverName);
return process != null && process.isAlive();
}
/**
* 进程信息类
*/
public static class ProcessInfo {
private String pid;
private long startTime;
public String getPid() { return pid; }
public void setPid(String pid) { this.pid = pid; }
public long getStartTime() { return startTime; }
public void setStartTime(long startTime) { this.startTime = startTime; }
public long getUptime() {
return System.currentTimeMillis() - startTime;
}
}
}

View File

@@ -0,0 +1,206 @@
package org.ruoyi.mcp.config;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
@Component
public class McpSSEToolInvoker {
private final Map<String, CompletableFuture<Object>> pendingRequests = new ConcurrentHashMap<>();
private final AtomicLong requestIdCounter = new AtomicLong(0);
/**
* 调用 MCP 工具SSE 模式)
*/
public Object invokeTool(String serverName, Object parameters) {
try {
// 生成请求ID
String requestId = "req_" + requestIdCounter.incrementAndGet();
// 创建 CompletableFuture 等待响应
CompletableFuture<Object> future = new CompletableFuture<>();
pendingRequests.put(requestId, future);
// 构造 MCP 调用请求
Map<String, Object> callRequest = new HashMap<>();
callRequest.put("requestId", requestId);
callRequest.put("serverName", serverName);
callRequest.put("parameters", convertToMap(parameters));
callRequest.put("timestamp", System.currentTimeMillis());
System.out.println("通过 SSE 调用 MCP 工具 [" + serverName + "] 参数: " + parameters);
// 发送请求到 MCP 服务器(通过 HTTP POST
sendSseToolCall(serverName, callRequest);
// 等待响应(超时 30 秒)
Object result = future.get(30, TimeUnit.SECONDS);
System.out.println("MCP 工具 [" + serverName + "] 调用成功,响应: " + result);
return result;
} catch (Exception e) {
System.err.println("调用 MCP 服务器 [" + serverName + "] 失败: " + e.getMessage());
e.printStackTrace();
return Map.of(
"serverName", serverName,
"status", "failed",
"message", "Tool invocation failed: " + e.getMessage(),
"parameters", parameters
);
}
}
/**
* 发送 SSE 工具调用请求
*/
private void sendSseToolCall(String serverName, Map<String, Object> callRequest) {
try {
// 通过 HTTP POST 发送工具调用请求
WebClient webClient = WebClient.builder()
.baseUrl("http://localhost:3000")
.build();
String toolCallUrl = "/tool/" + serverName;
webClient.post()
.uri(toolCallUrl)
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(callRequest)
.retrieve()
.bodyToMono(String.class)
.timeout(Duration.ofSeconds(5))
.subscribe(
response -> System.out.println("工具调用请求发送成功: " + response),
error -> System.err.println("工具调用请求发送失败: " + error.getMessage())
);
} catch (Exception e) {
System.err.println("发送 SSE 工具调用请求失败: " + e.getMessage());
}
}
/**
* 处理 SSE 响应
*/
public void handleSseResponse(String serverName, Map<String, Object> message) {
String requestId = (String) message.get("requestId");
if (requestId != null) {
CompletableFuture<Object> future = pendingRequests.remove(requestId);
if (future != null) {
Object data = message.get("data");
future.complete(data != null ? data : message);
}
}
}
/**
* 处理 SSE 错误
*/
public void handleSseError(String serverName, Map<String, Object> message) {
String requestId = (String) message.get("requestId");
if (requestId != null) {
CompletableFuture<Object> future = pendingRequests.remove(requestId);
if (future != null) {
String errorMessage = (String) message.get("message");
future.completeExceptionally(new RuntimeException(errorMessage));
}
}
}
/**
* 流式调用 MCP 工具(支持实时进度)
*/
public Flux<Object> invokeToolStream(String serverName, Object parameters) {
return Flux.create(emitter -> {
try {
// 生成请求ID
String requestId = "req_" + requestIdCounter.incrementAndGet();
// 构造 MCP 调用请求
Map<String, Object> callRequest = new HashMap<>();
callRequest.put("requestId", requestId);
callRequest.put("serverName", serverName);
callRequest.put("parameters", convertToMap(parameters));
callRequest.put("stream", true); // 标记为流式调用
callRequest.put("timestamp", System.currentTimeMillis());
// 创建流式处理器
StreamHandler streamHandler = new StreamHandler(emitter);
pendingRequests.put(requestId + "_stream", null); // 占位符
// 发送流式调用请求
sendSseToolCall(serverName, callRequest);
// 注册流式处理器
registerStreamHandler(requestId, streamHandler);
emitter.onDispose(() -> {
// 清理资源
pendingRequests.remove(requestId + "_stream");
});
} catch (Exception e) {
emitter.error(e);
}
});
}
/**
* 流式处理器
*/
private static class StreamHandler {
private final FluxSink<Object> emitter;
public StreamHandler(FluxSink<Object> emitter) {
this.emitter = emitter;
}
public void onNext(Object data) {
emitter.next(data);
}
public void onComplete() {
emitter.complete();
}
public void onError(Throwable error) {
emitter.error(error);
}
}
@SuppressWarnings("unchecked")
private Map<String, Object> convertToMap(Object parameters) {
if (parameters instanceof Map) {
Map<String, Object> result = new HashMap<>();
Map<?, ?> paramMap = (Map<?, ?>) parameters;
for (Map.Entry<?, ?> entry : paramMap.entrySet()) {
if (entry.getKey() instanceof String) {
result.put((String) entry.getKey(), entry.getValue());
}
}
return result;
}
return new HashMap<>();
}
private void registerStreamHandler(String requestId, StreamHandler streamHandler) {
// 实现流式处理器注册逻辑
}
}

View File

@@ -0,0 +1,66 @@
package org.ruoyi.mcp.controller;
import org.ruoyi.mcp.config.McpSSEToolInvoker;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Map;
@RestController
@RequestMapping("/api/sse")
public class MCPSseController {
@Autowired
private McpSSEToolInvoker mcpToolInvoker;
/**
* SSE 流式响应端点
*/
@GetMapping(value = "/{serverName}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter streamMcpResponse(@PathVariable String serverName) {
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
try {
// 发送连接建立消息
emitter.send(SseEmitter.event()
.name("connected")
.data(Map.of("serverName", serverName, "status", "connected")));
} catch (Exception e) {
emitter.completeWithError(e);
}
return emitter;
}
/**
* 调用 MCP 工具(流式)
*/
@PostMapping("/tool/{serverName}")
public ResponseEntity<?> callMcpTool(
@PathVariable String serverName,
@RequestBody Map<String, Object> request) {
try {
boolean isStream = (Boolean) request.getOrDefault("stream", false);
Object parameters = request.get("parameters");
if (isStream) {
// 流式调用
return ResponseEntity.ok(Map.of("status", "streaming_started"));
} else {
// 普通调用
Object result = mcpToolInvoker.invokeTool(serverName, parameters);
return ResponseEntity.ok(result);
}
} catch (Exception e) {
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
.body(Map.of("error", e.getMessage()));
}
}
}