mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-12 19:17:20 +00:00
Compare commits
10 Commits
f63ccbe7bd
...
3f1506b34b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f1506b34b | ||
|
|
deea428bfa | ||
|
|
66f133f089 | ||
|
|
e5ed5d0ef6 | ||
|
|
f04842ae12 | ||
|
|
77aeabb4be | ||
|
|
995507e757 | ||
|
|
f155bc284d | ||
|
|
c22c5eac7f | ||
|
|
9df246321e |
@@ -47,6 +47,39 @@
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- Hutool 工具类 -->
|
||||
<dependency>
|
||||
<groupId>cn.hutool</groupId>
|
||||
<artifactId>hutool-all</artifactId>
|
||||
<version>5.8.25</version>
|
||||
</dependency>
|
||||
|
||||
<!-- OkHttp for HTTP requests -->
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>okhttp</artifactId>
|
||||
<version>4.12.0</version>
|
||||
</dependency>
|
||||
|
||||
<!-- PlantUML -->
|
||||
<dependency>
|
||||
<groupId>net.sourceforge.plantuml</groupId>
|
||||
<artifactId>plantuml</artifactId>
|
||||
<version>1.2024.3</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Spring AI Tika for document parsing -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-tika-document-reader</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Lombok -->
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
<dependencyManagement>
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package org.ruoyi.mcpserve;
|
||||
|
||||
import org.ruoyi.mcpserve.service.ToolService;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
/**
|
||||
* MCP Server 应用启动类
|
||||
* 工具通过 DynamicToolCallbackProvider 动态加载
|
||||
*
|
||||
* @author ageer
|
||||
*/
|
||||
@SpringBootApplication
|
||||
@@ -17,9 +16,4 @@ public class RuoyiMcpServeApplication {
|
||||
SpringApplication.run(RuoyiMcpServeApplication.class, args);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ToolCallbackProvider systemTools(ToolService toolService) {
|
||||
return MethodToolCallbackProvider.builder().toolObjects(toolService).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
package org.ruoyi.mcpserve.config;
|
||||
|
||||
import org.ruoyi.mcpserve.tools.McpTool;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 动态工具回调提供者
|
||||
* 根据配置动态加载启用的MCP工具
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class DynamicToolCallbackProvider implements ToolCallbackProvider {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(DynamicToolCallbackProvider.class);
|
||||
|
||||
private final McpToolsConfig mcpToolsConfig;
|
||||
private final List<McpTool> allTools;
|
||||
private volatile ToolCallback[] cachedCallbacks;
|
||||
|
||||
public DynamicToolCallbackProvider(McpToolsConfig mcpToolsConfig, List<McpTool> allTools) {
|
||||
this.mcpToolsConfig = mcpToolsConfig;
|
||||
this.allTools = allTools;
|
||||
log.info("发现 {} 个MCP工具", allTools.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ToolCallback[] getToolCallbacks() {
|
||||
if (cachedCallbacks == null) {
|
||||
synchronized (this) {
|
||||
if (cachedCallbacks == null) {
|
||||
cachedCallbacks = buildToolCallbacks();
|
||||
}
|
||||
}
|
||||
}
|
||||
return cachedCallbacks;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建工具回调数组
|
||||
*/
|
||||
private ToolCallback[] buildToolCallbacks() {
|
||||
List<Object> enabledTools = allTools.stream()
|
||||
.filter(tool -> {
|
||||
boolean enabled = mcpToolsConfig.isToolEnabled(tool.getToolName());
|
||||
if (enabled) {
|
||||
log.info("启用工具: {}", tool.getToolName());
|
||||
} else {
|
||||
log.info("禁用工具: {}", tool.getToolName());
|
||||
}
|
||||
return enabled;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (enabledTools.isEmpty()) {
|
||||
log.warn("没有启用任何MCP工具");
|
||||
return new ToolCallback[0];
|
||||
}
|
||||
|
||||
// 使用 MethodToolCallbackProvider 构建工具回调
|
||||
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
|
||||
.toolObjects(enabledTools.toArray())
|
||||
.build();
|
||||
|
||||
return provider.getToolCallbacks();
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新工具缓存,用于配置变更后重新加载
|
||||
*/
|
||||
public void refreshTools() {
|
||||
synchronized (this) {
|
||||
cachedCallbacks = null;
|
||||
log.info("工具缓存已清除,将在下次调用时重新加载");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已注册的工具名称
|
||||
*/
|
||||
public List<String> getRegisteredToolNames() {
|
||||
return allTools.stream()
|
||||
.map(McpTool::getToolName)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已启用的工具名称
|
||||
*/
|
||||
public List<String> getEnabledToolNames() {
|
||||
return allTools.stream()
|
||||
.filter(tool -> mcpToolsConfig.isToolEnabled(tool.getToolName()))
|
||||
.map(McpTool::getToolName)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package org.ruoyi.mcpserve.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* MCP工具动态配置类
|
||||
* 支持通过配置文件启用/禁用各个工具
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Data
|
||||
@Component
|
||||
@ConfigurationProperties(prefix = "mcp.tools")
|
||||
public class McpToolsConfig {
|
||||
|
||||
/**
|
||||
* 工具启用配置
|
||||
* key: 工具名称
|
||||
* value: 是否启用
|
||||
*/
|
||||
private Map<String, Boolean> enabled = new HashMap<>();
|
||||
|
||||
/**
|
||||
* 检查工具是否启用
|
||||
* 默认情况下,如果未配置则启用
|
||||
*
|
||||
* @param toolName 工具名称
|
||||
* @return 是否启用
|
||||
*/
|
||||
public boolean isToolEnabled(String toolName) {
|
||||
return enabled.getOrDefault(toolName, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* 动态启用工具
|
||||
*
|
||||
* @param toolName 工具名称
|
||||
*/
|
||||
public void enableTool(String toolName) {
|
||||
enabled.put(toolName, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* 动态禁用工具
|
||||
*
|
||||
* @param toolName 工具名称
|
||||
*/
|
||||
public void disableTool(String toolName) {
|
||||
enabled.put(toolName, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* 动态设置工具启用状态
|
||||
*
|
||||
* @param toolName 工具名称
|
||||
* @param enable 是否启用
|
||||
*/
|
||||
public void setToolEnabled(String toolName, boolean enable) {
|
||||
enabled.put(toolName, enable);
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量设置工具启用状态
|
||||
*
|
||||
* @param toolStates 工具状态映射
|
||||
*/
|
||||
public void setToolsEnabled(Map<String, Boolean> toolStates) {
|
||||
enabled.putAll(toolStates);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package org.ruoyi.mcpserve.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 工具配置属性类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Data
|
||||
@Component
|
||||
@ConfigurationProperties(prefix = "tools")
|
||||
public class ToolsProperties {
|
||||
|
||||
/**
|
||||
* Pexels图片搜索配置
|
||||
*/
|
||||
private Pexels pexels = new Pexels();
|
||||
|
||||
/**
|
||||
* Tavily搜索配置
|
||||
*/
|
||||
private Tavily tavily = new Tavily();
|
||||
|
||||
/**
|
||||
* 文件操作配置
|
||||
*/
|
||||
private FileConfig file = new FileConfig();
|
||||
|
||||
@Data
|
||||
public static class Pexels {
|
||||
/**
|
||||
* Pexels API密钥
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* API地址
|
||||
*/
|
||||
private String apiUrl;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Tavily {
|
||||
/**
|
||||
* Tavily API密钥
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* API地址
|
||||
*/
|
||||
private String baseUrl;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class FileConfig {
|
||||
/**
|
||||
* 文件保存目录
|
||||
*/
|
||||
private String saveDir;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package org.ruoyi.mcpserve.controller;
|
||||
import org.ruoyi.mcpserve.config.DynamicToolCallbackProvider;
|
||||
import org.ruoyi.mcpserve.config.McpToolsConfig;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* MCP工具测试Controller
|
||||
* 用于查看已加载的工具信息
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/tools")
|
||||
public class ToolsController {
|
||||
|
||||
private final DynamicToolCallbackProvider toolCallbackProvider;
|
||||
private final McpToolsConfig mcpToolsConfig;
|
||||
|
||||
public ToolsController(DynamicToolCallbackProvider toolCallbackProvider, McpToolsConfig mcpToolsConfig) {
|
||||
this.toolCallbackProvider = toolCallbackProvider;
|
||||
this.mcpToolsConfig = mcpToolsConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有工具信息
|
||||
*/
|
||||
@GetMapping
|
||||
public Map<String, Object> getToolsInfo() {
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
|
||||
// 所有已注册的工具
|
||||
result.put("registered", toolCallbackProvider.getRegisteredToolNames());
|
||||
|
||||
// 已加载的工具回调详情
|
||||
List<Map<String, String>> callbacks = Stream.of(toolCallbackProvider.getToolCallbacks())
|
||||
.map(callback -> {
|
||||
Map<String, String> info = new HashMap<>();
|
||||
info.put("name", callback.getToolDefinition().name());
|
||||
info.put("description", callback.getToolDefinition().description());
|
||||
return info;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
result.put("callbacks", callbacks);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新工具缓存
|
||||
*/
|
||||
@PostMapping("/refresh")
|
||||
public Map<String, String> refreshTools() {
|
||||
toolCallbackProvider.refreshTools();
|
||||
Map<String, String> result = new HashMap<>();
|
||||
result.put("message", "工具缓存已刷新");
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 启用指定工具
|
||||
*/
|
||||
@PostMapping("/enable/{toolName}")
|
||||
public Map<String, Object> enableTool(@PathVariable String toolName) {
|
||||
mcpToolsConfig.enableTool(toolName);
|
||||
toolCallbackProvider.refreshTools();
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("toolName", toolName);
|
||||
result.put("enabled", true);
|
||||
result.put("message", "工具已启用");
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 禁用指定工具
|
||||
*/
|
||||
@PostMapping("/disable/{toolName}")
|
||||
public Map<String, Object> disableTool(@PathVariable String toolName) {
|
||||
mcpToolsConfig.disableTool(toolName);
|
||||
toolCallbackProvider.refreshTools();
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("toolName", toolName);
|
||||
result.put("enabled", false);
|
||||
result.put("message", "工具已禁用");
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量设置工具状态
|
||||
* 请求体示例: {"basic": true, "terminal": false, "plantuml": true}
|
||||
*/
|
||||
@PostMapping("/batch")
|
||||
public Map<String, Object> batchSetTools(@RequestBody Map<String, Boolean> toolStates) {
|
||||
mcpToolsConfig.setToolsEnabled(toolStates);
|
||||
toolCallbackProvider.refreshTools();
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("updated", toolStates);
|
||||
result.put("enabled", toolCallbackProvider.getEnabledToolNames());
|
||||
result.put("message", "工具状态已更新");
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有工具的启用状态
|
||||
*/
|
||||
@GetMapping("/status")
|
||||
public Map<String, Object> getToolsStatus() {
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
List<String> registered = toolCallbackProvider.getRegisteredToolNames();
|
||||
|
||||
Map<String, Boolean> status = new HashMap<>();
|
||||
for (String toolName : registered) {
|
||||
status.put(toolName, mcpToolsConfig.isToolEnabled(toolName));
|
||||
}
|
||||
|
||||
result.put("status", status);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -1,28 +1,33 @@
|
||||
package org.ruoyi.mcpserve.service;
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.UUID;
|
||||
|
||||
|
||||
/**
|
||||
* @author ageer
|
||||
* 基础工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Service
|
||||
public class ToolService {
|
||||
@Component
|
||||
public class BasicTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "basic";
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "获取一个指定前缀的随机数")
|
||||
public String add(@ToolParam(description = "字符前缀") String prefix) {
|
||||
// 定义日期格式
|
||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyMMdd");
|
||||
//根据当前时间获取yyMMdd格式的时间字符串
|
||||
String format = LocalDate.now().format(formatter);
|
||||
//生成随机数
|
||||
String replace = prefix + UUID.randomUUID().toString().replace("-", "");
|
||||
return format + replace;
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.core.io.UrlResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 文档解析工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class DocumentTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "document";
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "从URL解析文档内容,支持PDF、Word、Excel等格式")
|
||||
public String parseDocumentFromUrl(
|
||||
@ToolParam(description = "要解析的文档URL地址") String fileUrl) {
|
||||
try {
|
||||
TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new UrlResource(fileUrl));
|
||||
List<Document> documents = tikaDocumentReader.read();
|
||||
if (documents.isEmpty()) {
|
||||
return "No content found in the document.";
|
||||
}
|
||||
return documents.get(0).getText();
|
||||
} catch (Exception e) {
|
||||
return "Error parsing document: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import cn.hutool.core.io.FileUtil;
|
||||
import org.ruoyi.mcpserve.config.ToolsProperties;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 文件操作工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class FileTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "file";
|
||||
|
||||
private final ToolsProperties toolsProperties;
|
||||
|
||||
public FileTools(ToolsProperties toolsProperties) {
|
||||
this.toolsProperties = toolsProperties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "读取文件内容")
|
||||
public String readFile(@ToolParam(description = "要读取的文件名") String fileName) {
|
||||
String fileDir = toolsProperties.getFile().getSaveDir() + "/file";
|
||||
String filePath = fileDir + "/" + fileName;
|
||||
try {
|
||||
return FileUtil.readUtf8String(filePath);
|
||||
} catch (Exception e) {
|
||||
return "Error reading file: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
|
||||
@Tool(description = "写入内容到文件")
|
||||
public String writeFile(
|
||||
@ToolParam(description = "要写入的文件名") String fileName,
|
||||
@ToolParam(description = "要写入的内容") String content) {
|
||||
String fileDir = toolsProperties.getFile().getSaveDir() + "/file";
|
||||
String filePath = fileDir + "/" + fileName;
|
||||
try {
|
||||
FileUtil.mkdir(fileDir);
|
||||
FileUtil.writeUtf8String(content, filePath);
|
||||
return "File written successfully to: " + filePath;
|
||||
} catch (Exception e) {
|
||||
return "Error writing to file: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import org.ruoyi.mcpserve.config.ToolsProperties;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 图片搜索工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class ImageSearchTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "image-search";
|
||||
|
||||
private final ToolsProperties toolsProperties;
|
||||
|
||||
public ImageSearchTools(ToolsProperties toolsProperties) {
|
||||
this.toolsProperties = toolsProperties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "从Pexels搜索图片")
|
||||
public String searchImage(@ToolParam(description = "图片搜索关键词") String query) {
|
||||
try {
|
||||
String apiKey = toolsProperties.getPexels().getApiKey();
|
||||
String apiUrl = toolsProperties.getPexels().getApiUrl();
|
||||
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Authorization", apiKey);
|
||||
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("query", query);
|
||||
|
||||
String response = HttpUtil.createGet(apiUrl)
|
||||
.addHeaders(headers)
|
||||
.form(params)
|
||||
.execute()
|
||||
.body();
|
||||
|
||||
List<String> images = JSONUtil.parseObj(response)
|
||||
.getJSONArray("photos")
|
||||
.stream()
|
||||
.map(photoObj -> (JSONObject) photoObj)
|
||||
.map(photoObj -> photoObj.getJSONObject("src"))
|
||||
.map(photo -> photo.getStr("medium"))
|
||||
.filter(StrUtil::isNotBlank)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return String.join(",", images);
|
||||
} catch (Exception e) {
|
||||
return "Error search image: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
/**
|
||||
* MCP工具标记接口
|
||||
* 所有MCP工具类都需要实现此接口,以便动态加载器识别
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
public interface McpTool {
|
||||
|
||||
/**
|
||||
* 获取工具名称,用于配置文件中的启用/禁用控制
|
||||
*
|
||||
* @return 工具名称
|
||||
*/
|
||||
String getToolName();
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import net.sourceforge.plantuml.FileFormat;
|
||||
import net.sourceforge.plantuml.FileFormatOption;
|
||||
import net.sourceforge.plantuml.SourceStringReader;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
/**
|
||||
* PlantUML工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class PlantUmlTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "plantuml";
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "生成PlantUML图表并返回SVG代码")
|
||||
public String generatePlantUmlSvg(
|
||||
@ToolParam(description = "UML图表源代码") String umlCode) {
|
||||
try {
|
||||
if (umlCode == null || umlCode.trim().isEmpty()) {
|
||||
return "Error: UML代码不能为空";
|
||||
}
|
||||
|
||||
System.setProperty("PLANTUML_LIMIT_SIZE", "32768");
|
||||
System.setProperty("java.awt.headless", "true");
|
||||
|
||||
String normalizedUmlCode = normalizeUmlCode(umlCode);
|
||||
|
||||
SourceStringReader reader = new SourceStringReader(normalizedUmlCode);
|
||||
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
|
||||
reader.generateImage(outputStream, new FileFormatOption(FileFormat.SVG));
|
||||
|
||||
byte[] svgBytes = outputStream.toByteArray();
|
||||
if (svgBytes.length == 0) {
|
||||
return "Error: 生成的SVG内容为空,请检查UML语法是否正确";
|
||||
}
|
||||
|
||||
return new String(svgBytes, StandardCharsets.UTF_8);
|
||||
} catch (Exception e) {
|
||||
return "Error generating PlantUML: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
|
||||
private String normalizeUmlCode(String umlCode) {
|
||||
umlCode = umlCode.trim();
|
||||
if (umlCode.contains("@startuml")) {
|
||||
int startIndex = umlCode.indexOf("@startuml");
|
||||
int endIndex = umlCode.lastIndexOf("@enduml");
|
||||
if (endIndex > startIndex) {
|
||||
String startPart = umlCode.substring(startIndex);
|
||||
int firstNewLine = startPart.indexOf('\n');
|
||||
String content = firstNewLine > 0 ? startPart.substring(firstNewLine + 1) : "";
|
||||
if (content.contains("@enduml")) {
|
||||
content = content.substring(0, content.lastIndexOf("@enduml")).trim();
|
||||
}
|
||||
umlCode = content;
|
||||
}
|
||||
}
|
||||
|
||||
StringBuilder normalizedCode = new StringBuilder();
|
||||
normalizedCode.append("@startuml\n");
|
||||
normalizedCode.append("!pragma layout smetana\n");
|
||||
normalizedCode.append("skinparam charset UTF-8\n");
|
||||
normalizedCode.append("skinparam defaultFontName SimHei\n");
|
||||
normalizedCode.append("skinparam defaultFontSize 12\n");
|
||||
normalizedCode.append("skinparam dpi 150\n");
|
||||
normalizedCode.append("\n");
|
||||
normalizedCode.append(umlCode);
|
||||
if (!umlCode.endsWith("\n")) {
|
||||
normalizedCode.append("\n");
|
||||
}
|
||||
normalizedCode.append("@enduml");
|
||||
return normalizedCode.toString();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import org.ruoyi.mcpserve.config.ToolsProperties;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
/**
|
||||
* 终端命令工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class TerminalTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "terminal";
|
||||
|
||||
private final ToolsProperties toolsProperties;
|
||||
|
||||
public TerminalTools(ToolsProperties toolsProperties) {
|
||||
this.toolsProperties = toolsProperties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "在终端中执行命令")
|
||||
public String executeTerminalCommand(
|
||||
@ToolParam(description = "要执行的终端命令") String command) {
|
||||
StringBuilder output = new StringBuilder();
|
||||
try {
|
||||
String projectRoot = System.getProperty("user.dir");
|
||||
String fileDir = toolsProperties.getFile().getSaveDir() + "/file";
|
||||
File workingDir = new File(projectRoot, fileDir);
|
||||
|
||||
if (!workingDir.exists()) {
|
||||
workingDir.mkdirs();
|
||||
}
|
||||
|
||||
ProcessBuilder processBuilder;
|
||||
String os = System.getProperty("os.name").toLowerCase();
|
||||
if (os.contains("win")) {
|
||||
processBuilder = new ProcessBuilder("cmd.exe", "/c", command);
|
||||
} else {
|
||||
processBuilder = new ProcessBuilder("/bin/sh", "-c", command);
|
||||
}
|
||||
processBuilder.directory(workingDir);
|
||||
Process process = processBuilder.start();
|
||||
|
||||
try (BufferedReader reader = new BufferedReader(
|
||||
new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
output.append(line).append("\n");
|
||||
}
|
||||
}
|
||||
|
||||
int exitCode = process.waitFor();
|
||||
if (exitCode != 0) {
|
||||
output.append("Command execution failed with exit code: ").append(exitCode);
|
||||
}
|
||||
} catch (IOException | InterruptedException e) {
|
||||
output.append("Error executing command: ").append(e.getMessage());
|
||||
}
|
||||
return output.toString();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* 网页内容加载工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class WebPageTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "web-page";
|
||||
|
||||
private final OkHttpClient httpClient;
|
||||
|
||||
public WebPageTools() {
|
||||
this.httpClient = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.readTimeout(30, TimeUnit.SECONDS)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "加载网页并提取文本内容")
|
||||
public String loadWebPage(@ToolParam(description = "要加载的网页URL地址") String url) {
|
||||
if (url == null || url.trim().isEmpty()) {
|
||||
return "Error: URL is empty. Please provide a valid URL.";
|
||||
}
|
||||
|
||||
try {
|
||||
Request request = new Request.Builder()
|
||||
.url(url)
|
||||
.header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
||||
.build();
|
||||
|
||||
try (Response response = httpClient.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) {
|
||||
return "Error: Failed to load web page, status: " + response.code();
|
||||
}
|
||||
|
||||
String html = response.body().string();
|
||||
// 简单的HTML文本提取
|
||||
String text = html.replaceAll("<script[^>]*>[\\s\\S]*?</script>", "")
|
||||
.replaceAll("<style[^>]*>[\\s\\S]*?</style>", "")
|
||||
.replaceAll("<[^>]+>", " ")
|
||||
.replaceAll("\\s+", " ")
|
||||
.trim();
|
||||
|
||||
return text;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
return "Error loading web page: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package org.ruoyi.mcpserve.tools;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import okhttp3.*;
|
||||
import org.ruoyi.mcpserve.config.ToolsProperties;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.annotation.ToolParam;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* 网页搜索工具类
|
||||
*
|
||||
* @author OpenX
|
||||
*/
|
||||
@Component
|
||||
public class WebSearchTools implements McpTool {
|
||||
|
||||
public static final String TOOL_NAME = "web-search";
|
||||
|
||||
private final ToolsProperties toolsProperties;
|
||||
private final OkHttpClient httpClient;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
public WebSearchTools(ToolsProperties toolsProperties) {
|
||||
this.toolsProperties = toolsProperties;
|
||||
this.httpClient = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.readTimeout(30, TimeUnit.SECONDS)
|
||||
.build();
|
||||
this.objectMapper = new ObjectMapper();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getToolName() {
|
||||
return TOOL_NAME;
|
||||
}
|
||||
|
||||
@Tool(description = "从网络搜索引擎搜索信息")
|
||||
public String webSearch(
|
||||
@ToolParam(description = "搜索查询文本") String query,
|
||||
@ToolParam(description = "最大返回结果数量") int maxResults) {
|
||||
List<Map<String, String>> results = new ArrayList<>();
|
||||
try {
|
||||
String apiKey = toolsProperties.getTavily().getApiKey();
|
||||
String baseUrl = toolsProperties.getTavily().getBaseUrl();
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("query", query);
|
||||
requestBody.put("max_results", maxResults);
|
||||
|
||||
Request request = new Request.Builder()
|
||||
.url(baseUrl)
|
||||
.post(RequestBody.create(MediaType.parse("application/json"),
|
||||
objectMapper.writeValueAsString(requestBody)))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", "Bearer " + apiKey)
|
||||
.build();
|
||||
|
||||
try (Response response = httpClient.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) {
|
||||
return "搜索请求失败: " + response;
|
||||
}
|
||||
|
||||
JsonNode jsonNode = objectMapper.readTree(response.body().string()).get("results");
|
||||
if (jsonNode != null && !jsonNode.isEmpty()) {
|
||||
jsonNode.forEach(data -> {
|
||||
Map<String, String> processedResult = new HashMap<>();
|
||||
processedResult.put("title", data.get("title").asText());
|
||||
processedResult.put("url", data.get("url").asText());
|
||||
processedResult.put("content", data.get("content").asText());
|
||||
results.add(processedResult);
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
return "搜索时发生错误: " + e.getMessage();
|
||||
}
|
||||
return JSONUtil.toJsonStr(results);
|
||||
}
|
||||
}
|
||||
@@ -7,4 +7,38 @@ spring:
|
||||
name: ruoyi-mcp-serve
|
||||
version: 1.0.0
|
||||
|
||||
# 工具配置
|
||||
tools:
|
||||
pexels:
|
||||
api-key: your-pexels-api-key #key获取地址: https://www.pexels.com/zh-cn/api/key
|
||||
api-url: https://api.pexels.com/v1/search
|
||||
tavily:
|
||||
api-key: your-tavily-api-key #key获取地址: https://app.tavily.com/home
|
||||
base-url: https://api.tavily.com/search
|
||||
file:
|
||||
save-dir: ./tmp
|
||||
|
||||
# MCP工具初始化配置
|
||||
mcp:
|
||||
tools:
|
||||
enabled:
|
||||
# 基础工具(随机数、当前时间)
|
||||
basic: true
|
||||
# 文件操作工具(读写文件)
|
||||
file: true
|
||||
# 图片搜索工具(Pexels)
|
||||
image-search: true
|
||||
# PlantUML图表生成工具
|
||||
plantuml: true
|
||||
# 网页搜索工具(Tavily)
|
||||
web-search: true
|
||||
# 终端命令执行工具
|
||||
terminal: true
|
||||
# 文档解析工具
|
||||
document: true
|
||||
# 网页内容加载工具
|
||||
web-page: true
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -31,22 +31,23 @@ import java.util.stream.IntStream;
|
||||
@Component
|
||||
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
|
||||
private final Integer DIMENSION = 2048;
|
||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||
private final Map<String, EmbeddingStore<TextSegment>> storeCache = new ConcurrentHashMap<>();
|
||||
|
||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
|
||||
super(vectorStoreProperties, embeddingModelFactory);
|
||||
}
|
||||
|
||||
private EmbeddingStore<TextSegment> getMilvusStore(String collectionName, boolean autoFlushOnInsert) {
|
||||
String key = collectionName + "|" + autoFlushOnInsert;
|
||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||
private final Map<String, EmbeddingStore<TextSegment>> storeCache = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* 获取 Milvus Store,支持动态维度
|
||||
*/
|
||||
private EmbeddingStore<TextSegment> getMilvusStore(String collectionName, int dimension, boolean autoFlushOnInsert) {
|
||||
String key = collectionName + "|" + dimension + "|" + autoFlushOnInsert;
|
||||
return storeCache.computeIfAbsent(key, k ->
|
||||
MilvusEmbeddingStore.builder()
|
||||
.uri(vectorStoreProperties.getMilvus().getUrl())
|
||||
.collectionName(collectionName)
|
||||
.dimension(DIMENSION)
|
||||
.dimension(dimension)
|
||||
.indexType(IndexType.IVF_FLAT)
|
||||
.metricType(MetricType.L2)
|
||||
.autoFlushOnInsert(autoFlushOnInsert)
|
||||
@@ -57,18 +58,37 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 embedding 模型的实际维度
|
||||
*/
|
||||
private int getModelDimension(String modelName) {
|
||||
try {
|
||||
EmbeddingModel model = getEmbeddingModel(modelName, null);
|
||||
// 使用一个测试文本获取向量维度
|
||||
Embedding testEmbedding = model.embed("test").content();
|
||||
int dimension = testEmbedding.dimension();
|
||||
log.info("Detected embedding model dimension: {} for model: {}", dimension, modelName);
|
||||
return dimension;
|
||||
} catch (Exception e) {
|
||||
log.warn("Failed to detect model dimension for: {}, using default 1024", modelName, e);
|
||||
return 1024; // 默认使用 1024 (bge-m3 的维度)
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String modelName) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
int dimension = getModelDimension(modelName);
|
||||
// 使用缓存获取连接以确保只初始化一次
|
||||
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, true);
|
||||
log.info("Milvus集合初始化完成: {}", collectionName);
|
||||
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, dimension, true);
|
||||
log.info("Milvus集合初始化完成: {}, dimension: {}", collectionName, dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
|
||||
int dimension = getModelDimension(storeEmbeddingBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), dimension);
|
||||
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
@@ -80,7 +100,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
// 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, false);
|
||||
|
||||
IntStream.range(0, chunkList.size()).forEach(i -> {
|
||||
String text = chunkList.get(i);
|
||||
@@ -100,13 +120,14 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
|
||||
int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), dimension);
|
||||
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
|
||||
|
||||
// 查询复用连接,autoFlush 对查询无影响,此处保持 true
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, true);
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, true);
|
||||
|
||||
List<String> resultList = new ArrayList<>();
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
@@ -127,14 +148,16 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
@SneakyThrows
|
||||
public void removeById(String id, String modelName) {
|
||||
// 注意:此处原逻辑使用 collectionname + id,保持现状
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false);
|
||||
int dimension = getModelDimension(modelName);
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, dimension, false);
|
||||
embeddingStore.remove(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String docId, String kid) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
|
||||
// 使用默认维度,因为删除操作不需要精确的维度信息
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, 1024, false);
|
||||
Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId);
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Milvus成功删除 docId={} 的所有向量数据", docId);
|
||||
@@ -143,7 +166,8 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
@Override
|
||||
public void removeByFid(String fid, String kid) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
|
||||
// 使用默认维度,因为删除操作不需要精确的维度信息
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, 1024, false);
|
||||
Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid);
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Milvus成功删除 fid={} 的所有向量数据", fid);
|
||||
|
||||
@@ -5,7 +5,9 @@ import org.ruoyi.workflow.entity.WorkflowNode;
|
||||
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
|
||||
import org.ruoyi.workflow.workflow.node.EndNode;
|
||||
import org.ruoyi.workflow.workflow.node.answer.LLMAnswerNode;
|
||||
import org.ruoyi.workflow.workflow.node.httpRequest.HttpRequestNode;
|
||||
import org.ruoyi.workflow.workflow.node.keywordExtractor.KeywordExtractorNode;
|
||||
import org.ruoyi.workflow.workflow.node.knowledgeRetrieval.KnowledgeRetrievalNode;
|
||||
import org.ruoyi.workflow.workflow.node.mailSend.MailSendNode;
|
||||
import org.ruoyi.workflow.workflow.node.start.StartNode;
|
||||
import org.ruoyi.workflow.workflow.node.switcher.SwitcherNode;
|
||||
@@ -17,10 +19,11 @@ public class WfNodeFactory {
|
||||
switch (WfComponentNameEnum.getByName(wfComponent.getName())) {
|
||||
case START -> wfNode = new StartNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case LLM_ANSWER -> wfNode = new LLMAnswerNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case KEYWORD_EXTRACTOR ->
|
||||
wfNode = new KeywordExtractorNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case KEYWORD_EXTRACTOR -> wfNode = new KeywordExtractorNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case KNOWLEDGE_RETRIEVER -> wfNode = new KnowledgeRetrievalNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case END -> wfNode = new EndNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case MAIL_SEND -> wfNode = new MailSendNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case HTTP_REQUEST -> wfNode = new HttpRequestNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
case SWITCHER -> wfNode = new SwitcherNode(wfComponent, nodeDefinition, wfState, nodeState);
|
||||
default -> {
|
||||
}
|
||||
|
||||
@@ -0,0 +1,437 @@
|
||||
package org.ruoyi.workflow.workflow.node.httpRequest;
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.jsoup.Jsoup;
|
||||
import org.ruoyi.workflow.entity.WorkflowComponent;
|
||||
import org.ruoyi.workflow.entity.WorkflowNode;
|
||||
import org.ruoyi.workflow.workflow.NodeProcessResult;
|
||||
import org.ruoyi.workflow.workflow.WfNodeState;
|
||||
import org.ruoyi.workflow.workflow.WfState;
|
||||
import org.ruoyi.workflow.workflow.data.NodeIOData;
|
||||
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* HTTP 请求节点
|
||||
*/
|
||||
@Slf4j
|
||||
public class HttpRequestNode extends AbstractWfNode {
|
||||
|
||||
public HttpRequestNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) {
|
||||
super(wfComponent, nodeDef, wfState, nodeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodeProcessResult onProcess() {
|
||||
try {
|
||||
HttpRequestNodeConfig config = checkAndGetConfig(HttpRequestNodeConfig.class);
|
||||
List<NodeIOData> inputs = state.getInputs();
|
||||
|
||||
// 渲染 URL(支持变量替换)
|
||||
String url = renderTemplate(config.getUrl(), inputs);
|
||||
if (StringUtils.isBlank(url)) {
|
||||
throw new IllegalArgumentException("请求 URL 不能为空");
|
||||
}
|
||||
|
||||
// 添加 Query 参数
|
||||
url = buildUrlWithParams(url, config.getParams(), inputs);
|
||||
|
||||
// 构建请求头
|
||||
HttpHeaders headers = buildHeaders(config.getHeaders(), inputs);
|
||||
|
||||
// 构建请求体
|
||||
Object requestBody = buildRequestBody(config, inputs);
|
||||
|
||||
// 执行 HTTP 请求(支持重试)
|
||||
String response = executeHttpRequest(url, config.getMethod(), headers, requestBody, config);
|
||||
|
||||
// 清除 HTML 标签(如果需要)
|
||||
if (Boolean.TRUE.equals(config.getClearHtml()) && StringUtils.isNotBlank(response)) {
|
||||
response = Jsoup.parse(response).text();
|
||||
}
|
||||
|
||||
// 构造输出
|
||||
List<NodeIOData> outputs = new ArrayList<>();
|
||||
outputs.add(NodeIOData.createByText("output", "HTTP响应", response));
|
||||
|
||||
return NodeProcessResult.builder().content(outputs).build();
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("HTTP 请求失败 in node: {}", node.getId(), e);
|
||||
|
||||
// 异常时返回错误信息
|
||||
List<NodeIOData> errorOutputs = new ArrayList<>();
|
||||
errorOutputs.add(NodeIOData.createByText("output", "错误", ""));
|
||||
errorOutputs.add(NodeIOData.createByText("error", "HTTP请求错误", e.getMessage()));
|
||||
|
||||
return NodeProcessResult.builder().content(errorOutputs).build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 渲染模板(支持变量替换)
|
||||
* 支持格式:
|
||||
* 1. {var_01} - 直接替换整个变量值
|
||||
* 2. {var_01.name} - 从 JSON 中提取 name 字段
|
||||
* 3. {var_01.user.email} - 支持嵌套路径
|
||||
*/
|
||||
private String renderTemplate(String template, List<NodeIOData> inputs) {
|
||||
if (StringUtils.isBlank(template)) {
|
||||
return "";
|
||||
}
|
||||
return renderTemplateWithJsonPath(template, inputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* 增强的模板渲染,支持 JSON 路径提取
|
||||
*/
|
||||
private String renderTemplateWithJsonPath(String template, List<NodeIOData> inputs) {
|
||||
String result = template;
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
for (NodeIOData input : inputs) {
|
||||
if (input == null || input.getName() == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String varName = input.getName();
|
||||
String varValue = input.valueToString();
|
||||
|
||||
// 1. 处理简单变量替换 {var_01}
|
||||
result = result.replace("{" + varName + "}", varValue != null ? varValue : "");
|
||||
|
||||
// 2. 处理 JSON 路径提取 {var_01.field} 或 {var_01.user.name}
|
||||
// 尝试解析为 JSON
|
||||
Map<String, Object> jsonMap = tryParseJson(varValue, mapper);
|
||||
if (jsonMap != null) {
|
||||
result = replaceJsonPaths(result, varName, jsonMap);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 尝试将字符串解析为 JSON Map
|
||||
*/
|
||||
private Map<String, Object> tryParseJson(String value, ObjectMapper mapper) {
|
||||
if (StringUtils.isBlank(value)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
value = value.trim();
|
||||
if (!value.startsWith("{") && !value.startsWith("[")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
return mapper.readValue(value, new TypeReference<Map<String, Object>>() {});
|
||||
} catch (Exception e) {
|
||||
log.debug("无法解析为 JSON: {}", value);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 替换 JSON 路径变量,如 {var_01.name} 或 {var_01.user.email}
|
||||
*/
|
||||
private String replaceJsonPaths(String template, String varName, Map<String, Object> jsonMap) {
|
||||
String result = template;
|
||||
|
||||
// 查找所有 {varName.xxx} 格式的占位符
|
||||
String pattern = "\\{" + varName + "\\.([\\.\\w]+)\\}";
|
||||
java.util.regex.Pattern p = java.util.regex.Pattern.compile(pattern);
|
||||
java.util.regex.Matcher m = p.matcher(template);
|
||||
|
||||
while (m.find()) {
|
||||
String fullMatch = m.group(0); // 如 {var_01.name}
|
||||
String jsonPath = m.group(1); // 如 name 或 user.email
|
||||
|
||||
Object value = extractJsonValue(jsonMap, jsonPath);
|
||||
String replacement = value != null ? value.toString() : "";
|
||||
|
||||
result = result.replace(fullMatch, replacement);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 JSON Map 中提取嵌套路径的值
|
||||
* 例如:path = "user.email" 会提取 map.get("user").get("email")
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
private Object extractJsonValue(Map<String, Object> map, String path) {
|
||||
String[] parts = path.split("\\.");
|
||||
Object current = map;
|
||||
|
||||
for (String part : parts) {
|
||||
if (current instanceof Map) {
|
||||
current = ((Map<String, Object>) current).get(part);
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return current;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建带参数的 URL
|
||||
*/
|
||||
private String buildUrlWithParams(String baseUrl, List<HttpRequestNodeConfig.ParamItem> params, List<NodeIOData> inputs) {
|
||||
if (params == null || params.isEmpty()) {
|
||||
return baseUrl;
|
||||
}
|
||||
|
||||
StringBuilder urlBuilder = new StringBuilder(baseUrl);
|
||||
boolean hasQuery = baseUrl.contains("?");
|
||||
|
||||
for (HttpRequestNodeConfig.ParamItem param : params) {
|
||||
if (StringUtils.isBlank(param.getName())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String name = renderTemplate(param.getName(), inputs);
|
||||
String value = renderTemplate(param.getValue(), inputs);
|
||||
|
||||
if (hasQuery) {
|
||||
urlBuilder.append("&");
|
||||
} else {
|
||||
urlBuilder.append("?");
|
||||
hasQuery = true;
|
||||
}
|
||||
|
||||
urlBuilder.append(name).append("=").append(value);
|
||||
}
|
||||
|
||||
return urlBuilder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建请求头
|
||||
*/
|
||||
private HttpHeaders buildHeaders(List<HttpRequestNodeConfig.HeaderItem> headerItems, List<NodeIOData> inputs) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
|
||||
if (headerItems != null) {
|
||||
for (HttpRequestNodeConfig.HeaderItem item : headerItems) {
|
||||
if (StringUtils.isNotBlank(item.getName())) {
|
||||
String name = renderTemplate(item.getName(), inputs);
|
||||
String value = renderTemplate(item.getValue(), inputs);
|
||||
headers.add(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建请求体
|
||||
*/
|
||||
private Object buildRequestBody(HttpRequestNodeConfig config, List<NodeIOData> inputs) {
|
||||
String method = config.getMethod();
|
||||
if ("GET".equalsIgnoreCase(method) || "DELETE".equalsIgnoreCase(method) || "HEAD".equalsIgnoreCase(method)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String contentType = config.getContentType();
|
||||
|
||||
// JSON Body
|
||||
if ("application/json".equalsIgnoreCase(contentType)) {
|
||||
if (config.getJsonBody() != null && !config.getJsonBody().isEmpty()) {
|
||||
return renderJsonBody(config.getJsonBody(), inputs);
|
||||
}
|
||||
}
|
||||
|
||||
// Form Data
|
||||
if ("multipart/form-data".equalsIgnoreCase(contentType)) {
|
||||
return buildFormData(config.getFormDataBody(), inputs);
|
||||
}
|
||||
|
||||
// Form URL Encoded
|
||||
if ("application/x-www-form-urlencoded".equalsIgnoreCase(contentType)) {
|
||||
return buildFormUrlEncoded(config.getFormUrlencodedBody(), inputs);
|
||||
}
|
||||
|
||||
// Text Body
|
||||
if (StringUtils.isNotBlank(config.getTextBody())) {
|
||||
return renderTemplate(config.getTextBody(), inputs);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 渲染 JSON 请求体
|
||||
* 支持三种模式:
|
||||
* 1. 普通字段替换:{"name": "{var_01.name}"}
|
||||
* 2. 整体 JSON 合并:{"$merge": "{var_01}"} - 将整个 JSON 对象合并进来
|
||||
* 3. 智能合并:如果值是 {var_01} 且是有效 JSON,自动展开合并
|
||||
*/
|
||||
private Map<String, Object> renderJsonBody(Map<String, Object> jsonBody, List<NodeIOData> inputs) {
|
||||
Map<String, Object> rendered = new HashMap<>();
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
for (Map.Entry<String, Object> entry : jsonBody.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
Object value = entry.getValue();
|
||||
|
||||
// 处理特殊的 $merge 指令
|
||||
if ("$merge".equals(key) && value instanceof String) {
|
||||
String varRef = (String) value;
|
||||
Map<String, Object> mergeData = resolveVariableAsJson(varRef, inputs, mapper);
|
||||
if (mergeData != null) {
|
||||
rendered.putAll(mergeData);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (value instanceof String) {
|
||||
String strValue = (String) value;
|
||||
|
||||
// 检查是否是单纯的变量引用(如 {var_01})
|
||||
if (strValue.matches("^\\{\\w+\\}$")) {
|
||||
// 尝试解析为 JSON 对象
|
||||
Map<String, Object> jsonValue = resolveVariableAsJson(strValue, inputs, mapper);
|
||||
if (jsonValue != null) {
|
||||
// 如果是 JSON 对象,合并所有字段
|
||||
rendered.putAll(jsonValue);
|
||||
} else {
|
||||
// 否则作为普通字符串处理
|
||||
rendered.put(key, renderTemplate(strValue, inputs));
|
||||
}
|
||||
} else {
|
||||
// 普通字符串或包含多个变量的模板
|
||||
rendered.put(key, renderTemplate(strValue, inputs));
|
||||
}
|
||||
} else if (value instanceof Map) {
|
||||
// 递归处理嵌套的 Map
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> nestedMap = (Map<String, Object>) value;
|
||||
rendered.put(key, renderJsonBody(nestedMap, inputs));
|
||||
} else {
|
||||
// 其他类型直接保留
|
||||
rendered.put(key, value);
|
||||
}
|
||||
}
|
||||
return rendered;
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析变量引用为 JSON 对象
|
||||
* 例如:{var_01} -> 尝试解析 var_01 的值为 JSON Map
|
||||
*/
|
||||
private Map<String, Object> resolveVariableAsJson(String varRef, List<NodeIOData> inputs, ObjectMapper mapper) {
|
||||
// 提取变量名(去掉 {})
|
||||
String varName = varRef.replaceAll("[{}]", "");
|
||||
|
||||
// 查找对应的输入变量
|
||||
for (NodeIOData input : inputs) {
|
||||
if (input != null && varName.equals(input.getName())) {
|
||||
String varValue = input.valueToString();
|
||||
return tryParseJson(varValue, mapper);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Form Data
|
||||
*/
|
||||
private MultiValueMap<String, String> buildFormData(List<HttpRequestNodeConfig.FormItem> formItems, List<NodeIOData> inputs) {
|
||||
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
|
||||
if (formItems != null) {
|
||||
for (HttpRequestNodeConfig.FormItem item : formItems) {
|
||||
if (StringUtils.isNotBlank(item.getName())) {
|
||||
String name = renderTemplate(item.getName(), inputs);
|
||||
String value = renderTemplate(item.getValue(), inputs);
|
||||
formData.add(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return formData;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Form URL Encoded
|
||||
*/
|
||||
private MultiValueMap<String, String> buildFormUrlEncoded(List<HttpRequestNodeConfig.FormItem> formItems, List<NodeIOData> inputs) {
|
||||
return buildFormData(formItems, inputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 HTTP 请求(支持重试)
|
||||
*/
|
||||
private String executeHttpRequest(String url, String method, HttpHeaders headers, Object body, HttpRequestNodeConfig config) {
|
||||
RestTemplate restTemplate = createRestTemplate(config.getTimeout());
|
||||
|
||||
int maxRetries = config.getRetryTimes() != null ? config.getRetryTimes() : 0;
|
||||
int attempt = 0;
|
||||
Exception lastException = null;
|
||||
|
||||
while (attempt <= maxRetries) {
|
||||
try {
|
||||
// 设置 Content-Type
|
||||
if (StringUtils.isNotBlank(config.getContentType())) {
|
||||
headers.setContentType(MediaType.parseMediaType(config.getContentType()));
|
||||
}
|
||||
|
||||
HttpEntity<?> requestEntity = new HttpEntity<>(body, headers);
|
||||
HttpMethod httpMethod = HttpMethod.valueOf(method.toUpperCase());
|
||||
|
||||
ResponseEntity<String> response = restTemplate.exchange(url, httpMethod, requestEntity, String.class);
|
||||
|
||||
if (response.getStatusCode().is2xxSuccessful()) {
|
||||
return response.getBody();
|
||||
} else {
|
||||
throw new RuntimeException("HTTP 请求失败,状态码: " + response.getStatusCode());
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
lastException = e;
|
||||
attempt++;
|
||||
|
||||
if (attempt <= maxRetries) {
|
||||
log.warn("HTTP 请求失败,正在重试 ({}/{}): {}", attempt, maxRetries, e.getMessage());
|
||||
try {
|
||||
TimeUnit.SECONDS.sleep(1); // 重试前等待 1 秒
|
||||
} catch (InterruptedException ie) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException("重试等待被中断", ie);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw new RuntimeException("HTTP 请求失败,已重试 " + maxRetries + " 次", lastException);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 RestTemplate(设置超时)
|
||||
*/
|
||||
private RestTemplate createRestTemplate(Integer timeoutSeconds) {
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
// 设置超时时间
|
||||
int timeout = (timeoutSeconds != null ? timeoutSeconds : 10) * 1000;
|
||||
org.springframework.http.client.SimpleClientHttpRequestFactory requestFactory =
|
||||
new org.springframework.http.client.SimpleClientHttpRequestFactory();
|
||||
requestFactory.setConnectTimeout(timeout);
|
||||
requestFactory.setReadTimeout(timeout);
|
||||
restTemplate.setRequestFactory(requestFactory);
|
||||
|
||||
return restTemplate;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package org.ruoyi.workflow.workflow.node.httpRequest;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* HTTP 请求节点配置
|
||||
*/
|
||||
@Data
|
||||
public class HttpRequestNodeConfig {
|
||||
|
||||
/**
|
||||
* HTTP 请求方法
|
||||
*/
|
||||
private String method = "GET";
|
||||
|
||||
/**
|
||||
* 请求 URL
|
||||
*/
|
||||
private String url;
|
||||
|
||||
/**
|
||||
* Content-Type
|
||||
*/
|
||||
@JsonProperty("content_type")
|
||||
private String contentType = "text/plain";
|
||||
|
||||
/**
|
||||
* 请求头列表
|
||||
*/
|
||||
private List<HeaderItem> headers;
|
||||
|
||||
/**
|
||||
* Query 参数列表
|
||||
*/
|
||||
private List<ParamItem> params;
|
||||
|
||||
/**
|
||||
* 纯文本请求体
|
||||
*/
|
||||
@JsonProperty("text_body")
|
||||
private String textBody;
|
||||
|
||||
/**
|
||||
* JSON 请求体
|
||||
*/
|
||||
@JsonProperty("json_body")
|
||||
private Map<String, Object> jsonBody;
|
||||
|
||||
/**
|
||||
* Form Data 请求体
|
||||
*/
|
||||
@JsonProperty("form_data_body")
|
||||
private List<FormItem> formDataBody;
|
||||
|
||||
/**
|
||||
* Form URL Encoded 请求体
|
||||
*/
|
||||
@JsonProperty("form_urlencoded_body")
|
||||
private List<FormItem> formUrlencodedBody;
|
||||
|
||||
/**
|
||||
* 请求体(通用)
|
||||
*/
|
||||
private Map<String, Object> body;
|
||||
|
||||
/**
|
||||
* 超时时间(秒)
|
||||
*/
|
||||
private Integer timeout = 10;
|
||||
|
||||
/**
|
||||
* 重试次数
|
||||
*/
|
||||
@JsonProperty("retry_times")
|
||||
private Integer retryTimes = 0;
|
||||
|
||||
/**
|
||||
* 是否清除 HTML 标签
|
||||
*/
|
||||
@JsonProperty("clear_html")
|
||||
private Boolean clearHtml = false;
|
||||
|
||||
/**
|
||||
* 请求头项
|
||||
*/
|
||||
@Data
|
||||
public static class HeaderItem {
|
||||
private String name;
|
||||
private String value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Query 参数项
|
||||
*/
|
||||
@Data
|
||||
public static class ParamItem {
|
||||
private String name;
|
||||
private String value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Form 表单项
|
||||
*/
|
||||
@Data
|
||||
public static class FormItem {
|
||||
private String name;
|
||||
private String value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package org.ruoyi.workflow.workflow.node.knowledgeRetrieval;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.ruoyi.workflow.entity.WorkflowComponent;
|
||||
import org.ruoyi.workflow.entity.WorkflowNode;
|
||||
import org.ruoyi.workflow.util.SpringUtil;
|
||||
import org.ruoyi.workflow.workflow.NodeProcessResult;
|
||||
import org.ruoyi.workflow.workflow.WfNodeState;
|
||||
import org.ruoyi.workflow.workflow.WfState;
|
||||
import org.ruoyi.workflow.workflow.WorkflowUtil;
|
||||
import org.ruoyi.workflow.workflow.data.NodeIOData;
|
||||
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME;
|
||||
|
||||
/**
|
||||
* 【节点】知识库检索节点
|
||||
* 从知识库中检索相关内容
|
||||
*/
|
||||
@Slf4j
|
||||
public class KnowledgeRetrievalNode extends AbstractWfNode {
|
||||
|
||||
public KnowledgeRetrievalNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) {
|
||||
super(wfComponent, nodeDef, wfState, nodeState);
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理知识库检索
|
||||
* nodeConfig 格式:
|
||||
* {
|
||||
* "knowledge_id": "kb_123",
|
||||
* "top_k": 5,
|
||||
* "similarity_threshold": 0.7,
|
||||
* "retrieval_mode": "vector",
|
||||
* "embedding_model": "text-embedding-3-small",
|
||||
* "return_source": true,
|
||||
* "prompt": "额外的查询改写提示词"
|
||||
* }
|
||||
*
|
||||
* @return 检索结果
|
||||
*/
|
||||
@Override
|
||||
public NodeProcessResult onProcess() {
|
||||
KnowledgeRetrievalNodeConfig config = checkAndGetConfig(KnowledgeRetrievalNodeConfig.class);
|
||||
|
||||
// 验证知识库ID
|
||||
if (StringUtils.isBlank(config.getKnowledgeId())) {
|
||||
log.error("Knowledge base ID is required but not provided");
|
||||
List<NodeIOData> outputs = new ArrayList<>();
|
||||
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", "错误:未配置知识库ID"));
|
||||
return NodeProcessResult.builder().content(outputs).build();
|
||||
}
|
||||
|
||||
// 获取查询文本
|
||||
String queryText = getFirstInputText();
|
||||
if (StringUtils.isBlank(queryText)) {
|
||||
log.warn("Knowledge retrieval node has no input query, node: {}", state.getUuid());
|
||||
// 返回空结果
|
||||
List<NodeIOData> outputs = new ArrayList<>();
|
||||
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", ""));
|
||||
return NodeProcessResult.builder().content(outputs).build();
|
||||
}
|
||||
|
||||
log.info("Knowledge retrieval node config: {}", config);
|
||||
log.info("Query text: {}", queryText);
|
||||
|
||||
// 如果有自定义提示词,对查询进行改写
|
||||
String finalQuery = queryText;
|
||||
if (StringUtils.isNotBlank(config.getPrompt())) {
|
||||
finalQuery = rewriteQuery(config, queryText);
|
||||
log.info("Rewritten query: {}", finalQuery);
|
||||
}
|
||||
|
||||
// 根据检索模式执行不同的检索策略
|
||||
String retrievalResult;
|
||||
String mode = config.getRetrievalMode() != null ? config.getRetrievalMode().toLowerCase() : "vector";
|
||||
|
||||
// 目前只支持向量检索,图谱检索需要依赖graph模块
|
||||
if ("graph".equals(mode) || "hybrid".equals(mode)) {
|
||||
log.warn("Graph retrieval mode is not supported in workflow-api module, falling back to vector retrieval");
|
||||
}
|
||||
|
||||
retrievalResult = retrieveFromVector(config, finalQuery);
|
||||
|
||||
log.info("Retrieval result length: {}", retrievalResult.length());
|
||||
|
||||
// 构建输出
|
||||
List<NodeIOData> outputs = new ArrayList<>();
|
||||
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", retrievalResult));
|
||||
|
||||
// 如果需要返回原始查询
|
||||
outputs.add(NodeIOData.createByText("query", "", finalQuery));
|
||||
|
||||
return NodeProcessResult.builder().content(outputs).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用LLM改写查询
|
||||
*/
|
||||
private String rewriteQuery(KnowledgeRetrievalNodeConfig config, String originalQuery) {
|
||||
try {
|
||||
// 构建改写提示词
|
||||
String prompt = WorkflowUtil.renderTemplate(config.getPrompt(), state.getInputs());
|
||||
prompt = prompt.replace("{query}", originalQuery);
|
||||
|
||||
log.info("Query rewrite prompt: {}", prompt);
|
||||
|
||||
// 调用LLM进行查询改写
|
||||
String rewrittenQuery = invokeLLMSync(config, prompt);
|
||||
|
||||
if (StringUtils.isNotBlank(rewrittenQuery)) {
|
||||
log.info("Query rewritten from '{}' to '{}'", originalQuery, rewrittenQuery);
|
||||
return rewrittenQuery.trim();
|
||||
}
|
||||
|
||||
// 如果改写失败,返回原查询
|
||||
return originalQuery;
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to rewrite query, using original query", e);
|
||||
return originalQuery;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 同步调用LLM
|
||||
* 使用一个临时的流式处理器来收集完整响应
|
||||
*/
|
||||
private String invokeLLMSync(KnowledgeRetrievalNodeConfig config, String prompt) {
|
||||
try {
|
||||
// 创建一个StringBuilder来收集LLM响应
|
||||
StringBuilder responseBuilder = new StringBuilder();
|
||||
Object lock = new Object();
|
||||
boolean[] completed = {false};
|
||||
|
||||
// 创建临时节点状态用于LLM调用
|
||||
WfNodeState tempState = new WfNodeState();
|
||||
tempState.setUuid(state.getUuid() + "_rewrite");
|
||||
List<NodeIOData> tempInputs = new ArrayList<>();
|
||||
tempInputs.add(NodeIOData.createByText("input", "", prompt));
|
||||
tempState.setInputs(tempInputs);
|
||||
|
||||
// 创建临时工作流节点定义
|
||||
WorkflowNode tempNode = new WorkflowNode();
|
||||
tempNode.setUuid(tempState.getUuid());
|
||||
tempNode.setInputConfig(node.getInputConfig());
|
||||
|
||||
// 使用WorkflowUtil调用LLM(流式)
|
||||
WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class);
|
||||
List<dev.langchain4j.data.message.UserMessage> systemMessage =
|
||||
List.of(dev.langchain4j.data.message.UserMessage.from(prompt));
|
||||
|
||||
// 调用流式LLM
|
||||
String category = StringUtils.isNotBlank(config.getCategory()) ? config.getCategory() : "llm";
|
||||
String modelName = StringUtils.isNotBlank(config.getModelName()) ? config.getModelName() : "deepseek-chat";
|
||||
|
||||
workflowUtil.streamingInvokeLLM(
|
||||
wfState,
|
||||
tempState,
|
||||
tempNode,
|
||||
category,
|
||||
modelName,
|
||||
systemMessage
|
||||
);
|
||||
|
||||
// 等待LLM响应完成(最多等待30秒)
|
||||
long startTime = System.currentTimeMillis();
|
||||
long timeout = 30000; // 30秒超时
|
||||
|
||||
while (!completed[0] && (System.currentTimeMillis() - startTime) < timeout) {
|
||||
synchronized (lock) {
|
||||
// 检查是否有输出
|
||||
if (!tempState.getOutputs().isEmpty()) {
|
||||
for (NodeIOData output : tempState.getOutputs()) {
|
||||
if ("output".equals(output.getName())) {
|
||||
String text = output.valueToString();
|
||||
if (StringUtils.isNotBlank(text)) {
|
||||
responseBuilder.append(text);
|
||||
completed[0] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!completed[0]) {
|
||||
Thread.sleep(100); // 等待100ms后重试
|
||||
}
|
||||
}
|
||||
|
||||
String result = responseBuilder.toString().trim();
|
||||
if (StringUtils.isBlank(result)) {
|
||||
log.warn("LLM sync call returned empty response");
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to invoke LLM synchronously", e);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从向量库检索
|
||||
*/
|
||||
private String retrieveFromVector(KnowledgeRetrievalNodeConfig config, String query) {
|
||||
try {
|
||||
VectorStoreService vectorStoreService = SpringUtil.getBean(VectorStoreService.class);
|
||||
IKnowledgeInfoService knowledgeInfoService = SpringUtil.getBean(IKnowledgeInfoService.class);
|
||||
|
||||
// 获取知识库信息以获取embedding模型配置
|
||||
Long knowledgeId = Long.parseLong(config.getKnowledgeId());
|
||||
KnowledgeInfoVo knowledgeInfo = knowledgeInfoService.queryById(knowledgeId);
|
||||
|
||||
if (knowledgeInfo == null) {
|
||||
log.error("Knowledge base not found: {}", config.getKnowledgeId());
|
||||
return "错误:知识库不存在";
|
||||
}
|
||||
|
||||
// 构建查询参数
|
||||
QueryVectorBo queryBo = new QueryVectorBo();
|
||||
queryBo.setKid(config.getKnowledgeId());
|
||||
queryBo.setQuery(query);
|
||||
queryBo.setMaxResults(config.getTopK());
|
||||
|
||||
// 优先使用配置中的embedding模型,否则使用知识库的默认模型
|
||||
String embeddingModel = StringUtils.isNotBlank(config.getEmbeddingModel())
|
||||
? config.getEmbeddingModel()
|
||||
: knowledgeInfo.getEmbeddingModelName();
|
||||
|
||||
// 验证embedding模型配置
|
||||
if (StringUtils.isBlank(embeddingModel)) {
|
||||
log.error("Embedding model not configured for knowledge base: {}", config.getKnowledgeId());
|
||||
return "错误:知识库未配置向量化模型";
|
||||
}
|
||||
|
||||
queryBo.setEmbeddingModelName(embeddingModel);
|
||||
|
||||
log.info("Querying knowledge base: kid={}, query='{}', embedding model: {}, topK: {}, threshold: {}",
|
||||
config.getKnowledgeId(), query, embeddingModel, config.getTopK(), config.getSimilarityThreshold());
|
||||
|
||||
// 执行检索
|
||||
List<String> results = vectorStoreService.getQueryVector(queryBo);
|
||||
|
||||
log.info("Vector store query completed, results count: {}", results != null ? results.size() : 0);
|
||||
|
||||
if (results == null || results.isEmpty()) {
|
||||
log.warn("No results found from vector store for knowledge: {}, query: '{}'", config.getKnowledgeId(), query);
|
||||
return "";
|
||||
}
|
||||
|
||||
// 合并结果
|
||||
String mergedResult = String.join("\n\n---\n\n", results);
|
||||
log.info("Retrieved {} documents from vector store", results.size());
|
||||
|
||||
return mergedResult;
|
||||
} catch (NumberFormatException e) {
|
||||
log.error("Invalid knowledge base ID format: {}", config.getKnowledgeId(), e);
|
||||
return "错误:知识库ID格式无效";
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to retrieve from vector store", e);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package org.ruoyi.workflow.workflow.node.knowledgeRetrieval;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import jakarta.validation.constraints.Max;
|
||||
import jakarta.validation.constraints.Min;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
/**
|
||||
* 知识库检索节点配置
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
public class KnowledgeRetrievalNodeConfig {
|
||||
|
||||
/**
|
||||
* 知识库UUID(主要字段)
|
||||
*/
|
||||
@JsonProperty("knowledge_base_uuid")
|
||||
private String knowledgeBaseUuid;
|
||||
|
||||
/**
|
||||
* 知识库ID(兼容字段)
|
||||
*/
|
||||
@JsonProperty("knowledge_id")
|
||||
private String knowledgeId;
|
||||
|
||||
/**
|
||||
* 获取知识库ID(优先使用knowledgeBaseUuid)
|
||||
*/
|
||||
public String getKnowledgeId() {
|
||||
return knowledgeBaseUuid != null ? knowledgeBaseUuid : knowledgeId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检索的最大结果数
|
||||
*/
|
||||
@Min(1)
|
||||
@Max(100)
|
||||
@JsonProperty("top_k")
|
||||
private Integer topK = 5;
|
||||
|
||||
/**
|
||||
* 检索的最大结果数(兼容字段,前端使用top_n)
|
||||
*/
|
||||
@JsonProperty("top_n")
|
||||
private Integer topN;
|
||||
|
||||
/**
|
||||
* 获取topK值(优先使用topN)
|
||||
*/
|
||||
public Integer getTopK() {
|
||||
return topN != null ? topN : topK;
|
||||
}
|
||||
|
||||
/**
|
||||
* 相似度阈值(0-1之间)
|
||||
*/
|
||||
@Min(0)
|
||||
@Max(1)
|
||||
@JsonProperty("similarity_threshold")
|
||||
private Double similarityThreshold = 0.7;
|
||||
|
||||
/**
|
||||
* 相似度阈值(兼容字段,前端使用score)
|
||||
*/
|
||||
@JsonProperty("score")
|
||||
private Double score;
|
||||
|
||||
/**
|
||||
* 获取相似度阈值(优先使用score)
|
||||
*/
|
||||
public Double getSimilarityThreshold() {
|
||||
return score != null ? score : similarityThreshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检索模式:vector(向量检索)、graph(图谱检索)、hybrid(混合检索)
|
||||
*/
|
||||
@JsonProperty("retrieval_mode")
|
||||
private String retrievalMode = "vector";
|
||||
|
||||
/**
|
||||
* 模型分类(用于LLM查询改写)
|
||||
*/
|
||||
private String category;
|
||||
|
||||
/**
|
||||
* LLM模型名称(用于查询改写)
|
||||
*/
|
||||
@JsonProperty("model_name")
|
||||
private String modelName;
|
||||
|
||||
/**
|
||||
* Embedding模型名称(用于向量检索)
|
||||
*/
|
||||
@JsonProperty("embedding_model")
|
||||
private String embeddingModel;
|
||||
|
||||
/**
|
||||
* 是否返回原文
|
||||
*/
|
||||
@JsonProperty("return_source")
|
||||
private Boolean returnSource = true;
|
||||
|
||||
/**
|
||||
* 自定义查询提示词(可选)
|
||||
* 用于对查询进行预处理或改写
|
||||
*/
|
||||
private String prompt;
|
||||
}
|
||||
@@ -233,16 +233,22 @@ public class SwitcherNode extends AbstractWfNode {
|
||||
*/
|
||||
private String getValueFromInputs(String nodeUuid, String paramName, List<NodeIOData> inputs) {
|
||||
log.debug("从节点UUID '{}' 搜索参数 '{}'", nodeUuid, paramName);
|
||||
|
||||
String result = null;
|
||||
|
||||
// 首先尝试从当前输入中查找
|
||||
log.debug("检查当前输入 (数量: {})", inputs.size());
|
||||
for (NodeIOData input : inputs) {
|
||||
log.debug(" - 输入: 名称='{}', 值='{}'", input.getName(), input.valueToString());
|
||||
if (paramName.equals(input.getName())) {
|
||||
log.info("在当前输入中找到参数 '{}': '{}'", paramName, input.valueToString());
|
||||
return input.valueToString();
|
||||
result = input.valueToString();
|
||||
}
|
||||
}
|
||||
|
||||
if (result != null) {
|
||||
log.info("在当前输入中找到参数 '{}': '{}'", paramName, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
// 如果当前输入中没有,尝试从工作流状态中查找指定节点的输出
|
||||
if (StringUtils.isNotBlank(nodeUuid)) {
|
||||
@@ -251,14 +257,27 @@ public class SwitcherNode extends AbstractWfNode {
|
||||
for (NodeIOData output : nodeOutputs) {
|
||||
log.debug(" - 输出: 名称='{}', 值='{}'", output.getName(), output.valueToString());
|
||||
if (paramName.equals(output.getName())) {
|
||||
log.info("在节点 '{}' 的输出中找到参数 '{}': '{}'", nodeUuid, paramName, output.valueToString());
|
||||
return output.valueToString();
|
||||
result = output.valueToString();
|
||||
}
|
||||
}
|
||||
|
||||
if (result != null) {
|
||||
log.info("在节点 '{}' 的输出中找到参数 '{}': '{}'", nodeUuid, paramName, result);
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
log.debug("节点UUID为空,跳过工作流状态搜索");
|
||||
}
|
||||
|
||||
// 特殊处理:如果找的是 'output' 但没找到,尝试找 'input'
|
||||
if ("output".equals(paramName)) {
|
||||
log.debug("未找到参数 'output',尝试查找 'input'");
|
||||
String inputValue = getValueFromInputs(nodeUuid, "input", inputs);
|
||||
if (inputValue != null) {
|
||||
return inputValue;
|
||||
}
|
||||
}
|
||||
|
||||
log.warn("在输入或节点 '{}' 的输出中未找到参数 '{}'", nodeUuid, paramName);
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
import dev.langchain4j.model.chat.StreamingChatModel;
|
||||
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
@@ -80,6 +83,30 @@ public class OllamaServiceImpl implements IChatService {
|
||||
return emitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* 工作流场景:支持 langchain4j handler
|
||||
*/
|
||||
@Override
|
||||
public void chat(ChatRequest request, StreamingChatResponseHandler handler) {
|
||||
log.info("workflow chat, model: {}", request.getModel());
|
||||
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(request.getModel());
|
||||
|
||||
StreamingChatModel model = OllamaStreamingChatModel.builder()
|
||||
.baseUrl(chatModelVo.getApiHost() != null ? chatModelVo.getApiHost() : "http://localhost:11434")
|
||||
.modelName(chatModelVo.getModelName())
|
||||
.build();
|
||||
|
||||
try {
|
||||
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
|
||||
dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
|
||||
model.chat(chatRequest, handler);
|
||||
} catch (Exception e) {
|
||||
log.error("workflow ollama请求失败:{}", e.getMessage(), e);
|
||||
throw new RuntimeException("ollama workflow chat failed: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getCategory() {
|
||||
return ChatModeType.OLLAMA.getCode();
|
||||
|
||||
Reference in New Issue
Block a user