From 43dc0f419f919ac83ce97227f536efd05e75d14f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=85=92=E4=BA=A6?= Date: Tue, 12 Aug 2025 14:00:18 +0800 Subject: [PATCH] =?UTF-8?q?=20feat:=E5=9F=BA=E4=BA=8Esse=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=20=E5=90=AF=E5=8A=A8mcp=E6=9C=8D=E5=8A=A1=E5=99=A8=20=EF=BC=88?= =?UTF-8?q?=E6=9C=AA=E6=B5=8B=E8=AF=95=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mcp/config/McpProcessSSEManager.java | 287 ++++++++++++++++++ .../ruoyi/mcp/config/McpSSEToolInvoker.java | 206 +++++++++++++ .../mcp/controller/MCPSseController.java | 66 ++++ 3 files changed, 559 insertions(+) create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpProcessSSEManager.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpSSEToolInvoker.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/controller/MCPSseController.java diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpProcessSSEManager.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpProcessSSEManager.java new file mode 100644 index 00000000..4ac1fab5 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpProcessSSEManager.java @@ -0,0 +1,287 @@ +package org.ruoyi.mcp.config; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.info.ProcessInfo; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.Disposable; + +import java.io.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + + + +@Component +public class McpProcessSSEManager { + + private final Map runningProcesses = new ConcurrentHashMap<>(); + private final Map processInfos = new ConcurrentHashMap<>(); + private final Map sseClients = new ConcurrentHashMap<>(); + private final Map sseSubscriptions = new ConcurrentHashMap<>(); + private final ObjectMapper objectMapper = new ObjectMapper(); + + @Autowired + private McpSSEToolInvoker mcpToolInvoker; + + /** + * 启动 MCP 服务器进程(SSE 模式) + */ + public boolean startMcpServer(String serverName, String command, List args, Map env) { + try { + System.out.println("准备启动 MCP 服务器 (SSE 模式): " + serverName); + + // 如果已经运行,先停止 + if (isMcpServerRunning(serverName)) { + stopMcpServer(serverName); + } + + // 构建命令 + List commandList = buildCommandList(command, args); + + // 创建 ProcessBuilder + ProcessBuilder processBuilder = new ProcessBuilder(commandList); + processBuilder.redirectErrorStream(true); + + // 设置工作目录 + String workingDir = System.getProperty("user.dir"); + processBuilder.directory(new File(workingDir)); + + // 打印调试信息 + System.out.println("=== ProcessBuilder 调试信息 ==="); + System.out.println("完整命令列表: " + commandList); + System.out.println("================================"); + + // 执行命令 + Process process = processBuilder.start(); + runningProcesses.put(serverName, process); + + ProcessInfo processInfo = new ProcessInfo(); + processInfo.setStartTime(System.currentTimeMillis()); + processInfo.setPid(getProcessId(process)); + processInfos.put(serverName, processInfo); + + // 启动日志读取线程 + ExecutorService executorService = Executors.newCachedThreadPool(); + executorService.submit(() -> readProcessOutput(serverName, process)); + + // 等待进程启动 + Thread.sleep(3000); + boolean isAlive = process.isAlive(); + + if (isAlive) { + System.out.println("✓ MCP 服务器 [" + serverName + "] 启动成功"); + // 初始化 SSE 连接 + initializeSseConnection(serverName); + } else { + System.err.println("✗ MCP 服务器 [" + serverName + "] 启动失败"); + readErrorOutput(process); + } + + return isAlive; + + } catch (Exception e) { + System.err.println("✗ 启动 MCP 服务器 [" + serverName + "] 失败: " + e.getMessage()); + e.printStackTrace(); + return false; + } + } + private String getProcessId(Process process) { + try { + return String.valueOf(process.pid()); + } catch (Exception e) { + return "unknown"; + } + } + private void readProcessOutput(String serverName, Process process) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(process.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null && process.isAlive()) { + System.out.println("[" + serverName + "] " + line); + } + } catch (IOException e) { + System.err.println("Error reading output from " + serverName + ": " + e.getMessage()); + } + } + /** + * 读取错误输出 + */ + private void readErrorOutput(Process process) { + try { + InputStream errorStream = process.getErrorStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(errorStream)); + String line; + while ((line = reader.readLine()) != null) { + System.err.println("ERROR: " + line); + } + } catch (Exception e) { + System.err.println("Failed to read error output: " + e.getMessage()); + } + } + /** + * 初始化 SSE 连接 + */ + private void initializeSseConnection(String serverName) { + try { + // 创建 WebClient 用于 SSE 连接 + WebClient webClient = WebClient.builder() + .baseUrl("http://localhost:3000") // 假设默认端口 3000 + .build(); + + sseClients.put(serverName, webClient); + + // 建立 SSE 连接 + String sseUrl = "/sse/" + serverName; // SSE 端点 + + Disposable subscription = webClient.get() + .uri(sseUrl) + .accept(MediaType.TEXT_EVENT_STREAM) + .retrieve() + .bodyToFlux(String.class) + .subscribe( + event -> handleSseEvent(serverName, event), + error -> System.err.println("SSE 连接错误 [" + serverName + "]: " + error.getMessage()), + () -> System.out.println("SSE 连接完成 [" + serverName + "]") + ); + + sseSubscriptions.put(serverName, subscription); + System.out.println("✓ SSE 连接建立成功 [" + serverName + "]"); + + } catch (Exception e) { + System.err.println("✗ 建立 SSE 连接失败 [" + serverName + "]: " + e.getMessage()); + } + } + + /** + * 处理 SSE 事件 + */ + private void handleSseEvent(String serverName, String event) { + try { + System.out.println("收到来自 [" + serverName + "] 的 SSE 事件: " + event); + + // 解析 SSE 事件 + if (event.startsWith("data: ")) { + String jsonData = event.substring(6); // 移除 "data: " 前缀 + Map message = objectMapper.readValue(jsonData, Map.class); + + // 处理不同类型的事件 + String type = (String) message.get("type"); + if ("tool_response".equals(type)) { + mcpToolInvoker.handleSseResponse(serverName, message); + } else if ("tool_error".equals(type)) { + mcpToolInvoker.handleSseError(serverName, message); + } else if ("progress".equals(type)) { + handleProgressEvent(serverName, message); + } else { + System.out.println("[" + serverName + "] 未知事件类型: " + type); + } + } + + } catch (Exception e) { + System.err.println("处理 SSE 事件失败 [" + serverName + "]: " + e.getMessage()); + } + } + + /** + * 处理进度事件 + */ + private void handleProgressEvent(String serverName, Map message) { + Object progress = message.get("progress"); + Object messageText = message.get("message"); + System.out.println("[" + serverName + "] 进度: " + progress + " - " + messageText); + } + + + + /** + * 构建命令列表 + */ + private List buildCommandList(String command, List args) { + List commandList = new ArrayList<>(); + + if (isWindows() && "npx".equalsIgnoreCase(command)) { + commandList.add("cmd.exe"); + commandList.add("/c"); + commandList.add("npx"); + commandList.addAll(args); + } else { + commandList.add(command); + commandList.addAll(args); + } + + return commandList; + } + /** + * 检查是否为 Windows 系统 + */ + private boolean isWindows() { + return System.getProperty("os.name").toLowerCase().contains("windows"); + } + + /** + * 停止 MCP 服务器进程 + */ + public boolean stopMcpServer(String serverName) { + // 停止 SSE 连接 + Disposable subscription = sseSubscriptions.remove(serverName); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); + } + + sseClients.remove(serverName); + + // 停止进程 + Process process = runningProcesses.remove(serverName); + ProcessInfo processInfo = processInfos.remove(serverName); + + if (process != null && process.isAlive()) { + process.destroy(); + try { + if (!process.waitFor(10, TimeUnit.SECONDS)) { + process.destroyForcibly(); + process.waitFor(2, TimeUnit.SECONDS); + } + System.out.println("MCP 服务器 [" + serverName + "] 已停止"); + return true; + } catch (InterruptedException e) { + process.destroyForcibly(); + Thread.currentThread().interrupt(); + return false; + } + } + return false; + } + /** + * 检查 MCP 服务器是否运行 + */ + public boolean isMcpServerRunning(String serverName) { + Process process = runningProcesses.get(serverName); + return process != null && process.isAlive(); + } + /** + * 进程信息类 + */ + public static class ProcessInfo { + private String pid; + private long startTime; + + public String getPid() { return pid; } + public void setPid(String pid) { this.pid = pid; } + + public long getStartTime() { return startTime; } + public void setStartTime(long startTime) { this.startTime = startTime; } + + public long getUptime() { + return System.currentTimeMillis() - startTime; + } + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpSSEToolInvoker.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpSSEToolInvoker.java new file mode 100644 index 00000000..682f74b7 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/config/McpSSEToolInvoker.java @@ -0,0 +1,206 @@ +package org.ruoyi.mcp.config; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +@Component +public class McpSSEToolInvoker { + + + private final Map> pendingRequests = new ConcurrentHashMap<>(); + private final AtomicLong requestIdCounter = new AtomicLong(0); + + /** + * 调用 MCP 工具(SSE 模式) + */ + public Object invokeTool(String serverName, Object parameters) { + try { + // 生成请求ID + String requestId = "req_" + requestIdCounter.incrementAndGet(); + + // 创建 CompletableFuture 等待响应 + CompletableFuture future = new CompletableFuture<>(); + pendingRequests.put(requestId, future); + + // 构造 MCP 调用请求 + Map callRequest = new HashMap<>(); + callRequest.put("requestId", requestId); + callRequest.put("serverName", serverName); + callRequest.put("parameters", convertToMap(parameters)); + callRequest.put("timestamp", System.currentTimeMillis()); + + System.out.println("通过 SSE 调用 MCP 工具 [" + serverName + "] 参数: " + parameters); + + // 发送请求到 MCP 服务器(通过 HTTP POST) + sendSseToolCall(serverName, callRequest); + + // 等待响应(超时 30 秒) + Object result = future.get(30, TimeUnit.SECONDS); + + System.out.println("MCP 工具 [" + serverName + "] 调用成功,响应: " + result); + + return result; + + } catch (Exception e) { + System.err.println("调用 MCP 服务器 [" + serverName + "] 失败: " + e.getMessage()); + e.printStackTrace(); + + return Map.of( + "serverName", serverName, + "status", "failed", + "message", "Tool invocation failed: " + e.getMessage(), + "parameters", parameters + ); + } + } + + /** + * 发送 SSE 工具调用请求 + */ + private void sendSseToolCall(String serverName, Map callRequest) { + try { + // 通过 HTTP POST 发送工具调用请求 + WebClient webClient = WebClient.builder() + .baseUrl("http://localhost:3000") + .build(); + + String toolCallUrl = "/tool/" + serverName; + + webClient.post() + .uri(toolCallUrl) + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(callRequest) + .retrieve() + .bodyToMono(String.class) + .timeout(Duration.ofSeconds(5)) + .subscribe( + response -> System.out.println("工具调用请求发送成功: " + response), + error -> System.err.println("工具调用请求发送失败: " + error.getMessage()) + ); + + } catch (Exception e) { + System.err.println("发送 SSE 工具调用请求失败: " + e.getMessage()); + } + } + + /** + * 处理 SSE 响应 + */ + public void handleSseResponse(String serverName, Map message) { + String requestId = (String) message.get("requestId"); + if (requestId != null) { + CompletableFuture future = pendingRequests.remove(requestId); + if (future != null) { + Object data = message.get("data"); + future.complete(data != null ? data : message); + } + } + } + + /** + * 处理 SSE 错误 + */ + public void handleSseError(String serverName, Map message) { + String requestId = (String) message.get("requestId"); + if (requestId != null) { + CompletableFuture future = pendingRequests.remove(requestId); + if (future != null) { + String errorMessage = (String) message.get("message"); + future.completeExceptionally(new RuntimeException(errorMessage)); + } + } + } + + /** + * 流式调用 MCP 工具(支持实时进度) + */ + public Flux invokeToolStream(String serverName, Object parameters) { + return Flux.create(emitter -> { + try { + // 生成请求ID + String requestId = "req_" + requestIdCounter.incrementAndGet(); + + // 构造 MCP 调用请求 + Map callRequest = new HashMap<>(); + callRequest.put("requestId", requestId); + callRequest.put("serverName", serverName); + callRequest.put("parameters", convertToMap(parameters)); + callRequest.put("stream", true); // 标记为流式调用 + callRequest.put("timestamp", System.currentTimeMillis()); + + // 创建流式处理器 + StreamHandler streamHandler = new StreamHandler(emitter); + pendingRequests.put(requestId + "_stream", null); // 占位符 + + // 发送流式调用请求 + sendSseToolCall(serverName, callRequest); + + // 注册流式处理器 + registerStreamHandler(requestId, streamHandler); + + emitter.onDispose(() -> { + // 清理资源 + pendingRequests.remove(requestId + "_stream"); + }); + + } catch (Exception e) { + emitter.error(e); + } + }); + } + + /** + * 流式处理器 + */ + private static class StreamHandler { + private final FluxSink emitter; + + public StreamHandler(FluxSink emitter) { + this.emitter = emitter; + } + + public void onNext(Object data) { + emitter.next(data); + } + + public void onComplete() { + emitter.complete(); + } + + public void onError(Throwable error) { + emitter.error(error); + } + } + + @SuppressWarnings("unchecked") + private Map convertToMap(Object parameters) { + if (parameters instanceof Map) { + Map result = new HashMap<>(); + Map paramMap = (Map) parameters; + for (Map.Entry entry : paramMap.entrySet()) { + if (entry.getKey() instanceof String) { + result.put((String) entry.getKey(), entry.getValue()); + } + } + return result; + } + return new HashMap<>(); + } + + private void registerStreamHandler(String requestId, StreamHandler streamHandler) { + // 实现流式处理器注册逻辑 + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/controller/MCPSseController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/controller/MCPSseController.java new file mode 100644 index 00000000..0ecb2dbb --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mcp/controller/MCPSseController.java @@ -0,0 +1,66 @@ +package org.ruoyi.mcp.controller; + +import org.ruoyi.mcp.config.McpSSEToolInvoker; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.Map; + +@RestController +@RequestMapping("/api/sse") +public class MCPSseController { + + @Autowired + private McpSSEToolInvoker mcpToolInvoker; + + /** + * SSE 流式响应端点 + */ + @GetMapping(value = "/{serverName}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter streamMcpResponse(@PathVariable String serverName) { + SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); + + try { + // 发送连接建立消息 + emitter.send(SseEmitter.event() + .name("connected") + .data(Map.of("serverName", serverName, "status", "connected"))); + + } catch (Exception e) { + emitter.completeWithError(e); + } + + return emitter; + } + + /** + * 调用 MCP 工具(流式) + */ + @PostMapping("/tool/{serverName}") + public ResponseEntity callMcpTool( + @PathVariable String serverName, + @RequestBody Map request) { + + try { + boolean isStream = (Boolean) request.getOrDefault("stream", false); + Object parameters = request.get("parameters"); + + if (isStream) { + // 流式调用 + return ResponseEntity.ok(Map.of("status", "streaming_started")); + } else { + // 普通调用 + Object result = mcpToolInvoker.invokeTool(serverName, parameters); + return ResponseEntity.ok(result); + } + + } catch (Exception e) { + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(Map.of("error", e.getMessage())); + } + } +}