Merge branch 'main' into main

This commit is contained in:
ageerle
2025-10-23 14:15:54 +08:00
committed by GitHub
197 changed files with 12263 additions and 343 deletions

View File

@@ -150,6 +150,13 @@
<strong>QQ技术交流群</strong><br> <strong>QQ技术交流群</strong><br>
<em>技术讨论</em> <em>技术讨论</em>
</td> </td>
<td align="center">
<img width="200" height="200" alt="95e8b1b3baeadbd24650bfb974ca5a58" src="https://github.com/user-attachments/assets/2a346218-6388-484d-aa75-6e98942193f7" /><br>
<strong>微信技术交流群</strong><br>
<em>技术讨论</em>
</td>
</tr> </tr>
</table> </table>

View File

@@ -0,0 +1,432 @@
# Ruoyi-AI 工作流模块详细说明文档
## 概述
Ruoyi-AI 工作流模块是一个基于 LangGraph4j 的智能工作流引擎支持可视化工作流设计、AI 模型集成、条件分支、人机交互等高级功能。该模块采用微服务架构,提供完整的 RESTful API 和流式响应支持。
## 模块架构
### 1. 模块结构
```
ruoyi-ai/
├── ruoyi-modules/
│ └── ruoyi-workflow/ # 工作流核心模块
│ ├── pom.xml
│ └── src/main/java/org/ruoyi/workflow/
│ └── controller/ # 控制器层
│ ├── WorkflowController.java
│ ├── WorkflowRuntimeController.java
│ └── admin/ # 管理端控制器
│ ├── AdminWorkflowController.java
│ └── AdminWorkflowComponentController.java
└── ruoyi-modules-api/
└── ruoyi-workflow-api/ # 工作流API模块
├── pom.xml
└── src/main/java/org/ruoyi/workflow/
├── entity/ # 实体类
├── dto/ # 数据传输对象
├── service/ # 服务接口
├── mapper/ # 数据访问层
├── workflow/ # 工作流核心逻辑
├── enums/ # 枚举类
├── util/ # 工具类
└── exception/ # 异常处理
```
### 2. 核心依赖
- **LangGraph4j**: 1.5.3 - 工作流图执行引擎
- **LangChain4j**: 1.2.0 - AI 模型集成框架
- **Spring Boot**: 3.x - 应用框架
- **MyBatis Plus**: 数据访问层
- **Redis**: 缓存和状态管理
- **Swagger/OpenAPI**: API 文档
## 核心功能
### 1. 工作流管理
#### 1.1 工作流定义
- **创建工作流**: 支持自定义标题、描述、公开性设置
- **编辑工作流**: 可视化节点编辑、连接线配置
- **版本控制**: 支持工作流的版本管理和回滚
- **权限管理**: 支持公开/私有工作流设置
#### 1.2 工作流执行
- **流式执行**: 基于 SSE 的实时流式响应
- **状态管理**: 完整的执行状态跟踪
- **错误处理**: 详细的错误信息和异常处理
- **中断恢复**: 支持工作流中断和恢复执行
### 2. 节点类型
#### 2.1 基础节点
- **Start**: 开始节点,定义工作流入口
- **End**: 结束节点,定义工作流出口
#### 2.2 AI 模型节点
- **Answer**: 大语言模型问答节点
- **Dalle3**: DALL-E 3 图像生成
- **Tongyiwanx**: 通义万相图像生成
- **Classifier**: 内容分类节点
#### 2.3 数据处理节点
- **DocumentExtractor**: 文档信息提取
- **KeywordExtractor**: 关键词提取
- **FaqExtractor**: 常见问题提取
- **KnowledgeRetrieval**: 知识库检索
#### 2.4 控制流节点
- **Switcher**: 条件分支节点
- **HumanFeedback**: 人机交互节点
#### 2.5 外部集成节点
- **Google**: Google 搜索集成
- **MailSend**: 邮件发送
- **HttpRequest**: HTTP 请求
- **Template**: 模板转换
### 3. 数据流管理
#### 3.1 输入输出定义
```java
// 节点输入输出数据结构
public class NodeIOData {
private String name; // 参数名称
private NodeIODataContent content; // 参数内容
}
// 支持的数据类型
public enum WfIODataTypeEnum {
TEXT, // 文本
NUMBER, // 数字
BOOLEAN, // 布尔值
FILES, // 文件
OPTIONS // 选项
}
```
#### 3.2 参数引用
- **节点间引用**: 支持上游节点输出作为下游节点输入
- **参数映射**: 自动处理参数名称映射
- **类型转换**: 自动进行数据类型转换
## 数据库设计
### 1. 核心表结构
#### 1.1 工作流定义表 (t_workflow)
```sql
CREATE TABLE t_workflow (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
uuid VARCHAR(32) NOT NULL DEFAULT '',
title VARCHAR(100) NOT NULL DEFAULT '',
remark TEXT NOT NULL DEFAULT '',
user_id BIGINT NOT NULL DEFAULT 0,
is_public TINYINT(1) NOT NULL DEFAULT 0,
is_enable TINYINT(1) NOT NULL DEFAULT 1,
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
is_deleted TINYINT(1) NOT NULL DEFAULT 0
);
```
#### 1.2 工作流节点表 (t_workflow_node)
```sql
CREATE TABLE t_workflow_node (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
uuid VARCHAR(32) NOT NULL DEFAULT '',
workflow_id BIGINT NOT NULL DEFAULT 0,
workflow_component_id BIGINT NOT NULL DEFAULT 0,
user_id BIGINT NOT NULL DEFAULT 0,
title VARCHAR(100) NOT NULL DEFAULT '',
remark VARCHAR(500) NOT NULL DEFAULT '',
input_config JSON NOT NULL DEFAULT ('{}'),
node_config JSON NOT NULL DEFAULT ('{}'),
position_x DOUBLE NOT NULL DEFAULT 0,
position_y DOUBLE NOT NULL DEFAULT 0,
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
is_deleted TINYINT(1) NOT NULL DEFAULT 0
);
```
#### 1.3 工作流边表 (t_workflow_edge)
```sql
CREATE TABLE t_workflow_edge (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
uuid VARCHAR(32) NOT NULL DEFAULT '',
workflow_id BIGINT NOT NULL DEFAULT 0,
source_node_uuid VARCHAR(32) NOT NULL DEFAULT '',
source_handle VARCHAR(32) NOT NULL DEFAULT '',
target_node_uuid VARCHAR(32) NOT NULL DEFAULT '',
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
is_deleted TINYINT(1) NOT NULL DEFAULT 0
);
```
#### 1.4 工作流运行时表 (t_workflow_runtime)
```sql
CREATE TABLE t_workflow_runtime (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
uuid VARCHAR(32) NOT NULL DEFAULT '',
user_id BIGINT NOT NULL DEFAULT 0,
workflow_id BIGINT NOT NULL DEFAULT 0,
input JSON NOT NULL DEFAULT ('{}'),
output JSON NOT NULL DEFAULT ('{}'),
status SMALLINT NOT NULL DEFAULT 1,
status_remark VARCHAR(250) NOT NULL DEFAULT '',
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
is_deleted TINYINT(1) NOT NULL DEFAULT 0
);
```
#### 1.5 工作流组件表 (t_workflow_component)
```sql
CREATE TABLE t_workflow_component (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
uuid VARCHAR(32) DEFAULT '' NOT NULL,
name VARCHAR(32) DEFAULT '' NOT NULL,
title VARCHAR(100) DEFAULT '' NOT NULL,
remark TEXT NOT NULL,
display_order INT DEFAULT 0 NOT NULL,
is_enable TINYINT(1) DEFAULT 0 NOT NULL,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted TINYINT(1) DEFAULT 0 NOT NULL
);
```
## API 接口
### 1. 工作流管理接口
#### 1.1 基础操作
```http
#
POST /workflow/add
Content-Type: application/json
{
"title": "",
"remark": "",
"isPublic": false
}
#
POST /workflow/update
Content-Type: application/json
{
"uuid": "UUID",
"title": "",
"remark": ""
}
#
POST /workflow/del/{uuid}
# /
POST /workflow/enable/{uuid}?enable=true
```
#### 1.2 搜索和查询
```http
#
GET /workflow/mine/search?keyword=&isPublic=true&currentPage=1&pageSize=10
#
GET /workflow/public/search?keyword=&currentPage=1&pageSize=10
#
GET /workflow/public/component/list
```
### 2. 工作流执行接口
#### 2.1 流式执行
```http
#
POST /workflow/run
Content-Type: application/json
Accept: text/event-stream
{
"uuid": "UUID",
"inputs": [
{
"name": "input",
"content": {
"type": 1,
"textContent": ""
}
}
]
}
```
#### 2.2 运行时管理
```http
#
POST /workflow/runtime/resume/{runtimeUuid}
Content-Type: application/json
{
"feedbackContent": ""
}
#
GET /workflow/runtime/page?wfUuid=UUID&currentPage=1&pageSize=10
#
GET /workflow/runtime/nodes/{runtimeUuid}
#
POST /workflow/runtime/clear?wfUuid=UUID
```
### 3. 管理端接口
#### 3.1 工作流管理
```http
#
POST /admin/workflow/search
Content-Type: application/json
{
"title": "",
"isPublic": true,
"isEnable": true
}
# /
POST /admin/workflow/enable?uuid=UUID&isEnable=true
```
## 核心实现
### 1. 工作流引擎 (WorkflowEngine)
工作流引擎是整个模块的核心,负责:
- 工作流图的构建和编译
- 节点执行调度
- 状态管理和持久化
- 流式输出处理
```java
public class WorkflowEngine {
// 核心执行方法
public void run(User user, List<ObjectNode> userInputs, SseEmitter sseEmitter) {
// 1. 验证工作流状态
// 2. 创建运行时实例
// 3. 构建状态图
// 4. 执行工作流
// 5. 处理流式输出
}
// 恢复执行方法
public void resume(String userInput) {
// 1. 更新状态
// 2. 继续执行
}
}
```
### 2. 节点工厂 (WfNodeFactory)
节点工厂负责根据组件类型创建对应的节点实例:
```java
public class WfNodeFactory {
public static AbstractWfNode create(WorkflowComponent component,
WorkflowNode node,
WfState wfState,
WfNodeState nodeState) {
// 根据组件类型创建对应的节点实例
switch (component.getName()) {
case "Answer":
return new LLMAnswerNode(component, node, wfState, nodeState);
case "Switcher":
return new SwitcherNode(component, node, wfState, nodeState);
// ... 其他节点类型
}
}
}
```
### 3. 图构建器 (WorkflowGraphBuilder)
图构建器负责将工作流定义转换为可执行的状态图:
```java
public class WorkflowGraphBuilder {
public StateGraph<WfNodeState> build(WorkflowNode startNode) {
// 1. 构建编译节点树
// 2. 转换为状态图
// 3. 添加节点和边
// 4. 处理条件分支
// 5. 处理并行执行
}
}
```
## 流式响应机制
### 1. SSE 事件类型
工作流执行过程中会发送多种类型的 SSE 事件:
```javascript
// 节点开始执行
[NODE_RUN_节点UUID] - 节点执行开始事件
// 节点输入数据
[NODE_INPUT_节点UUID] - 节点输入数据事件
// 节点输出数据
[NODE_OUTPUT_节点UUID] - 节点输出数据事件
// 流式内容块
[NODE_CHUNK_节点UUID] - 流式内容块事件
// 等待用户输入
[NODE_WAIT_FEEDBACK_BY_节点UUID] - 等待用户输入事件
```
### 2. 流式处理流程
1. **初始化**: 创建工作流运行时实例
2. **节点执行**: 逐个执行工作流节点
3. **实时输出**: 通过 SSE 实时推送执行结果
4. **状态更新**: 实时更新节点和工作流状态
5. **错误处理**: 捕获并处理执行过程中的错误
## 扩展开发
### 1. 自定义节点开发
要开发自定义工作流节点,需要:
1. **创建节点类**:继承 `AbstractWfNode`
2. **实现处理逻辑**:重写 `onProcess()` 方法
3. **定义配置类**:创建节点配置类
4. **注册组件**:在组件表中注册新组件
```java
public class CustomNode extends AbstractWfNode {
@Override
protected NodeProcessResult onProcess() {
// 实现自定义处理逻辑
List<NodeIOData> outputs = new ArrayList<>();
// ... 处理逻辑
return NodeProcessResult.success(outputs);
}
}
```
### 2. 自定义组件注册
```sql
-- 在 t_workflow_component 表中添加新组件
INSERT INTO t_workflow_component (uuid, name, title, remark, is_enable)
VALUES (REPLACE(UUID(), '-', ''), 'CustomNode', '自定义节点', '自定义节点描述', true);
```

24
pom.xml
View File

@@ -270,13 +270,6 @@
<version>${lock4j.version}</version> <version>${lock4j.version}</version>
</dependency> </dependency>
<!-- xxl-job-core -->
<dependency>
<groupId>com.xuxueli</groupId>
<artifactId>xxl-job-core</artifactId>
<version>${xxl-job.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.alibaba</groupId> <groupId>com.alibaba</groupId>
<artifactId>transmittable-thread-local</artifactId> <artifactId>transmittable-thread-local</artifactId>
@@ -373,6 +366,23 @@
<artifactId>langchain4j-community-neo4j</artifactId> <artifactId>langchain4j-community-neo4j</artifactId>
<version>${langchain4j-neo4j.version}</version> <version>${langchain4j-neo4j.version}</version>
</dependency> </dependency>
<artifactId>ruoyi-aihuman</artifactId>
<version>${revision}</version>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-workflow</artifactId>
<version>${revision}</version>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-workflow-api</artifactId>
<version>${revision}</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>

View File

@@ -61,6 +61,14 @@
<dependency> <dependency>
<groupId>org.ruoyi</groupId> <groupId>org.ruoyi</groupId>
<artifactId>ruoyi-graph</artifactId> <artifactId>ruoyi-graph</artifactId>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-workflow</artifactId>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-aihuman</artifactId>
</dependency> </dependency>
</dependencies> </dependencies>

View File

@@ -11,7 +11,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
* *
* @author Lion Li * @author Lion Li
*/ */
@SpringBootApplication @SpringBootApplication(scanBasePackages = {"org.ruoyi", "org.ruoyi.aihuman"})
@EnableScheduling @EnableScheduling
@EnableAsync @EnableAsync
public class RuoYiAIApplication { public class RuoYiAIApplication {
@@ -22,4 +22,4 @@ public class RuoYiAIApplication {
application.run(args); application.run(args);
System.out.println("(♥◠‿◠)ノ゙ RuoYiAI启动成功 ლ(´ڡ`ლ)゙"); System.out.println("(♥◠‿◠)ノ゙ RuoYiAI启动成功 ლ(´ڡ`ლ)゙");
} }
} }

View File

@@ -37,6 +37,8 @@ spring:
connectionTestQuery: SELECT 1 connectionTestQuery: SELECT 1
# 多久检查一次连接的活性 # 多久检查一次连接的活性
keepaliveTime: 30000 keepaliveTime: 30000
mail:
username: xx
--- # redis 单机配置(单机与集群只能开启一个另一个需要注释掉) --- # redis 单机配置(单机与集群只能开启一个另一个需要注释掉)
spring.data: spring.data:
@@ -102,7 +104,15 @@ pdf:
#百炼模型配置 #百炼模型配置
dashscope: dashscope:
key: sk-xxxx key: sk-xxxx
model: qvq-max
local:
images: xx
files: xx
--- # Neo4j 知识图谱配置 --- # Neo4j 知识图谱配置
neo4j: neo4j:

View File

@@ -156,6 +156,8 @@ security:
# actuator 监控配置 # actuator 监控配置
- /actuator - /actuator
- /actuator/** - /actuator/**
- /workflow/**
- /admin/workflow/**
# 多租户配置 # 多租户配置
tenant: tenant:
# 是否开启 # 是否开启
@@ -328,3 +330,19 @@ spring:
servers-configuration: classpath:mcp-server.json servers-configuration: classpath:mcp-server.json
request-timeout: 300s request-timeout: 300s
# 向量库配置
vector-store:
# 向量存储类型 可选(weaviate/milvus)
# 如需修改向量库类型,请修改此配置值!
type: weaviate
# Weaviate配置
weaviate:
protocol: http
host: 127.0.0.1:6038
classname: LocalKnowledge
# Milvus配置
milvus:
url: http://localhost:19530
collectionname: LocalKnowledge

View File

@@ -26,11 +26,18 @@ public class ChatRequest {
*/ */
private String prompt; private String prompt;
/** /**
* 系统提示词 * 系统提示词
*/ */
private String sysPrompt; private String sysPrompt;
/**
* 消息id
*/
private Long messageId;
/** /**
* 是否开启流式对话 * 是否开启流式对话
*/ */
@@ -72,6 +79,11 @@ public class ChatRequest {
*/ */
private Boolean hasAttachment; private Boolean hasAttachment;
/**
* 是否启用深度思考
*/
private Boolean enableThinking;
/** /**
* 是否自动切换模型 * 是否自动切换模型
*/ */
@@ -82,9 +94,4 @@ public class ChatRequest {
*/ */
private String token; private String token;
/**
* 消息ID保存消息成功后设置用于后续扣费更新
*/
private Long messageId;
} }

View File

@@ -0,0 +1,62 @@
package org.ruoyi.common.core.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
* 向量库配置属性
*
* @author ageer
*/
@Data
@Component
@ConfigurationProperties(prefix = "vector-store")
public class VectorStoreProperties {
/**
* 向量库类型
*/
private String type;
/**
* Weaviate配置
*/
private Weaviate weaviate = new Weaviate();
/**
* Milvus配置
*/
private Milvus milvus = new Milvus();
@Data
public static class Weaviate {
/**
* 协议
*/
private String protocol;
/**
* 主机地址
*/
private String host;
/**
* 类名
*/
private String classname;
}
@Data
public static class Milvus {
/**
* 连接URL
*/
private String url;
/**
* 集合名称
*/
private String collectionname;
}
}

View File

@@ -17,6 +17,7 @@
<module>ruoyi-chat-api</module> <module>ruoyi-chat-api</module>
<module>ruoyi-knowledge-api</module> <module>ruoyi-knowledge-api</module>
<module>ruoyi-system-api</module> <module>ruoyi-system-api</module>
<module>ruoyi-workflow-api</module>
</modules> </modules>
<properties> <properties>

View File

@@ -16,7 +16,7 @@
<maven.compiler.source>17</maven.compiler.source> <maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target> <maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<spring-ai.version>1.0.0</spring-ai.version> <spring-ai.version>1.0.0-M7</spring-ai.version>
</properties> </properties>
<dependencyManagement> <dependencyManagement>

View File

@@ -1,6 +1,7 @@
package org.ruoyi.domain; package org.ruoyi.domain;
import com.alibaba.excel.annotation.ExcelProperty;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data; import lombok.Data;
@@ -81,6 +82,11 @@ public class ChatModel extends BaseEntity {
*/ */
private Integer priority; private Integer priority;
/**
* 模型供应商
*/
private String ProviderName;
/** /**
* 备注 * 备注
*/ */

View File

@@ -0,0 +1,67 @@
package org.ruoyi.domain;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.annotation.DataColumn;
import org.ruoyi.core.domain.BaseEntity;
import java.util.Date;
/**
* MCP对象 mcp_info
*
* @author ageerle
* @date Sat Aug 09 16:50:58 CST 2025
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("mcp_info")
public class McpInfo extends BaseEntity {
/**
* id
*/
@TableId(value = "mcp_id", type = IdType.AUTO)
private Integer mcpId;
/**
* 服务器名称
*/
private String serverName;
/**
* 链接方式
*/
private String transportType;
/**
* Command
*/
private String command;
/**
* Args
*/
private String arguments;
private String description;
/**
* Env
*/
private String env;
/**
* 是否启用
*/
private Boolean status;
}

View File

@@ -1,5 +1,6 @@
package org.ruoyi.domain.bo; package org.ruoyi.domain.bo;
import com.alibaba.excel.annotation.ExcelProperty;
import io.github.linpeilie.annotations.AutoMapper; import io.github.linpeilie.annotations.AutoMapper;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
@@ -85,6 +86,10 @@ public class ChatModelBo extends BaseEntity {
@NotBlank(message = "密钥不能为空", groups = { AddGroup.class, EditGroup.class }) @NotBlank(message = "密钥不能为空", groups = { AddGroup.class, EditGroup.class })
private String apiKey; private String apiKey;
/**
* 模型供应商
*/
private String ProviderName;
/** /**
* 备注 * 备注

View File

@@ -0,0 +1,59 @@
package org.ruoyi.domain.bo;
import io.github.linpeilie.annotations.AutoMapper;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import org.ruoyi.domain.McpInfo;
import java.io.Serializable;
/**
* MCP业务对象 mcp_info
*
* @author ageerle
* @date Sat Aug 09 16:50:58 CST 2025
*/
@Data
@AutoMapper(target = McpInfo.class, reverseConvertGenerate = false)
public class McpInfoBo implements Serializable {
/**
* id
*/
@NotNull(message = "id不能为空" )
private Integer mcpId;
/**
* 服务器名称
*/
private String serverName;
/**
* 链接方式
*/
private String transportType;
/**
* Command
*/
private String command;
/**
* Args
*/
private String arguments;
private String description;
/**
* Env
*/
private String env;
/**
* 是否启用
*/
private Boolean status;
}

View File

@@ -70,6 +70,11 @@ public class ChatModelVo implements Serializable {
@ExcelProperty(value = "是否显示") @ExcelProperty(value = "是否显示")
private String modelShow; private String modelShow;
/**
* 模型维度
*/
private Integer dimension;
/** /**
* 系统提示词 * 系统提示词
*/ */
@@ -95,6 +100,12 @@ public class ChatModelVo implements Serializable {
@ExcelProperty(value = "优先级") @ExcelProperty(value = "优先级")
private Integer priority; private Integer priority;
/**
* 模型供应商
*/
@ExcelProperty(value = "模型供应商")
private String ProviderName;
/** /**
* 备注 * 备注
*/ */

View File

@@ -0,0 +1,65 @@
package org.ruoyi.domain.vo;
import com.alibaba.excel.annotation.ExcelIgnoreUnannotated;
import com.alibaba.excel.annotation.ExcelProperty;
import io.github.linpeilie.annotations.AutoMapper;
import lombok.Data;
import org.ruoyi.common.excel.annotation.ExcelDictFormat;
import org.ruoyi.common.excel.convert.ExcelDictConvert;
import org.ruoyi.domain.McpInfo;
import java.io.Serializable;
/**
* MCP视图对象 mcp_info
*
* @author jiyi
* @date Sat Aug 09 16:50:58 CST 2025
*/
@Data
@ExcelIgnoreUnannotated
@AutoMapper(target = McpInfo.class)
public class McpInfoVo implements Serializable {
private Integer mcpId;
/**
* 服务器名称
*/
@ExcelProperty(value = "服务器名称")
private String serverName;
/**
* 链接方式
*/
@ExcelProperty(value = "链接方式", converter = ExcelDictConvert.class)
@ExcelDictFormat(dictType = "mcp_transport_type")
private String transportType;
/**
* Command
*/
@ExcelProperty(value = "Command")
private String command;
/**
* Args
*/
@ExcelProperty(value = "Args")
private String arguments;
@ExcelProperty(value = "Description")
private String description;
/**
* Env
*/
@ExcelProperty(value = "Env")
private String env;
/**
* 是否启用
*/
@ExcelProperty(value = "是否启用")
private Boolean status;
}

View File

@@ -0,0 +1,33 @@
package org.ruoyi.mapper;
import org.apache.ibatis.annotations.*;
import org.ruoyi.core.mapper.BaseMapperPlus;
import org.ruoyi.domain.McpInfo;
import org.ruoyi.domain.vo.McpInfoVo;
import java.util.List;
/**
* MCPMapper接口
*
* @author jiuyi
* @date Sat Aug 09 16:50:58 CST 2025
*/
@Mapper
public interface McpInfoMapper extends BaseMapperPlus<McpInfo, McpInfoVo> {
@Select("SELECT * FROM mcp_info WHERE server_name = #{serverName}")
McpInfo selectByServerName(@Param("serverName") String serverName);
@Select("SELECT * FROM mcp_info WHERE status = 1")
List<McpInfo> selectActiveServers();
@Select("SELECT server_name FROM mcp_info WHERE status = 1")
List<String> selectActiveServerNames();
@Update("UPDATE mcp_info SET status = #{status} WHERE server_name = #{serverName}")
int updateActiveStatus(@Param("serverName") String serverName, @Param("status") Boolean status);
@Delete("DELETE FROM mcp_info WHERE server_name = #{serverName}")
int deleteByServerName(@Param("serverName") String serverName);
}

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="org.ruoyi.mapper.McpInfoMapper">
</mapper>

View File

@@ -74,6 +74,18 @@
<version>1.19.6</version> <version>1.19.6</version>
</dependency> </dependency>
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.6.4</version>
</dependency>
<!-- LangChain4j Milvus Embedding Store -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-milvus</artifactId>
</dependency>
<dependency> <dependency>
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId> <artifactId>langchain4j-open-ai</artifactId>
@@ -101,11 +113,10 @@
<artifactId>commons-compress</artifactId> <artifactId>commons-compress</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.alibaba</groupId> <groupId>org.ruoyi</groupId>
<artifactId>dashscope-sdk-java</artifactId> <artifactId>ruoyi-chat-api</artifactId>
<version>2.19.0</version> </dependency>
</dependency>
</dependencies> </dependencies>

View File

@@ -83,6 +83,11 @@ public class KnowledgeInfo extends BaseEntity {
*/ */
private String vectorModelName; private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/** /**
* 向量化模型名称 * 向量化模型名称
*/ */

View File

@@ -92,7 +92,11 @@ public class KnowledgeInfoBo extends BaseEntity {
/** /**
* 向量化模型名称 * 向量化模型名称
*/ */
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class }) private Long embeddingModelId;
/**
* 向量化模型名称
*/
private String embeddingModelName; private String embeddingModelName;

View File

@@ -31,7 +31,12 @@ public class QueryVectorBo {
private String vectorModelName; private String vectorModelName;
/** /**
* 向量化模型名称 * 向量化模型ID
*/
private Long embeddingModelId;
/**
* 向量化模型ID
*/ */
private String embeddingModelName; private String embeddingModelName;

View File

@@ -32,9 +32,14 @@ public class StoreEmbeddingBo {
private List<String> fids; private List<String> fids;
/** /**
* 向量库模型名称 * 向量库名称
*/ */
private String vectorModelName; private String vectorStoreName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/** /**
* 向量化模型名称 * 向量化模型名称

View File

@@ -101,6 +101,11 @@ public class KnowledgeInfoVo implements Serializable {
*/ */
private String vectorModelName; private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/** /**
* 向量化模型名称 * 向量化模型名称
*/ */

View File

@@ -0,0 +1,26 @@
package org.ruoyi.embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.embedding.model.ModalityType;
import java.util.Set;
/**
* BaseEmbedModelService 接口,扩展了 EmbeddingModel 接口
* 该接口定义了嵌入模型服务的基本配置和功能方法
*/
public interface BaseEmbedModelService extends EmbeddingModel {
/**
* 根据配置信息配置嵌入模型
* @param config 包含模型配置信息的 ChatModelVo 对象
*/
void configure(ChatModelVo config);
/**
* 获取当前嵌入模型支持的所有模态类型
* @return 返回支持的模态类型集合
*/
Set<ModalityType> getSupportedModalities();
}

View File

@@ -0,0 +1,120 @@
package org.ruoyi.embedding;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* 嵌入模型工厂服务类
* 负责创建和管理各种嵌入模型实例
*/
@Service
@RequiredArgsConstructor
@Slf4j
public class EmbeddingModelFactory {
private final ApplicationContext applicationContext;
private final IChatModelService chatModelService;
// 模型缓存使用ConcurrentHashMap保证线程安全
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
/**
* 创建嵌入模型实例
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
*
* @param embeddingModelName 嵌入模型名称
* @param dimension 模型维度大小
*/
public BaseEmbedModelService createModel(String embeddingModelName, Integer dimension) {
return modelCache.computeIfAbsent(embeddingModelName, name -> {
ChatModelVo modelConfig = chatModelService.selectModelByName(embeddingModelName);
if (modelConfig == null) {
throw new IllegalArgumentException("未找到模型配置name=" + name);
}
if (modelConfig.getDimension() != null) {
modelConfig.setDimension(dimension);
}
return createModelInstance(modelConfig.getProviderName(), modelConfig);
});
}
/**
* 检查模型是否支持多模态
*
* @param embeddingModelName 嵌入模型名称
* @return boolean 如果模型支持多模态则返回true否则返回false
*/
public boolean isMultimodalModel(String embeddingModelName) {
return createModel(embeddingModelName, null) instanceof MultiModalEmbedModelService;
}
/**
* 创建多模态嵌入模型实例
*
* @param embeddingModelName 嵌入模型名称
* @return MultiModalEmbedModelService 多模态嵌入模型服务实例
* @throws IllegalArgumentException 当模型不支持多模态时抛出
*/
public MultiModalEmbedModelService createMultimodalModel(String embeddingModelName) {
BaseEmbedModelService model = createModel(embeddingModelName, null);
if (model instanceof MultiModalEmbedModelService) {
return (MultiModalEmbedModelService) model;
}
throw new IllegalArgumentException("该模型不支持多模态");
}
/**
* 刷新模型缓存
* 根据给定的嵌入模型ID从缓存中移除对应的模型
*
* @param embeddingModelId 嵌入模型的唯一标识ID
*/
public void refreshModel(Long embeddingModelId) {
// 从模型缓存中移除指定ID的模型
modelCache.remove(embeddingModelId);
}
/**
* 获取所有支持模型工厂的列表
*
* @return List<String> 支持的模型工厂名称列表
*/
public List<String> getSupportedFactories() {
return new ArrayList<>(applicationContext.getBeansOfType(BaseEmbedModelService.class)
.keySet());
}
/**
* 创建具体的模型实例
* 根据提供的工厂名称和配置信息创建并配置模型实例
*
* @param factory 工厂名称,用于标识模型类型
* @param config 模型配置信息
* @return BaseEmbedModelService 配置好的模型实例
* @throws IllegalArgumentException 当无法获取指定的模型实例时抛出
*/
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
try {
// 从Spring上下文中获取模型实例
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class);
// 配置模型参数
model.configure(config);
log.info("成功创建嵌入模型: factory={}, modelId={}", config.getProviderName(), config.getId());
return model;
} catch (NoSuchBeanDefinitionException e) {
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
}
}
}

View File

@@ -0,0 +1,35 @@
package org.ruoyi.embedding;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import org.ruoyi.embedding.model.MultiModalInput;
/**
* 多模态嵌入模型服务接口,继承自基础嵌入模型服务
* 该接口提供了处理图像、视频以及多模态数据并转换为嵌入向量的功能
*/
public interface MultiModalEmbedModelService extends BaseEmbedModelService {
/**
* 将图像数据转换为嵌入向量
* @param imageDataUrl 图像的地址必须是公开可访问的URL
* @return 包含嵌入向量的响应对象,可能包含状态信息和嵌入结果
*/
Response<Embedding> embedImage(String imageDataUrl);
/**
* 将视频数据转换为嵌入向量
* @param videoDataUrl 视频的地址必须是公开可访问的URL
* @return 包含嵌入向量的响应对象,可能包含状态信息和嵌入结果
*/
Response<Embedding> embedVideo(String videoDataUrl);
/**
* 处理多模态输入并返回嵌入向量的方法
*
* @param input 包含多种模态信息(如图像、文本等)的输入对象
* @return Response<Embedding> 包含嵌入向量的响应对象Embedding通常表示输入数据的向量表示
*/
Response<Embedding> embedMultiModal(MultiModalInput input);
}

View File

@@ -0,0 +1,14 @@
package org.ruoyi.embedding.impl;
import org.springframework.stereotype.Component;
/**
* @Author: Robust_H
* @Date: 2025-09-30-下午3:00
* @Description: 阿里百炼基础嵌入模型兼容openai
*/
@Component("alibailian")
public class AliBaiLianBaseEmbedProvider extends OpenAiEmbeddingProvider{
}

View File

@@ -0,0 +1,281 @@
package org.ruoyi.embedding.impl;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.embedding.MultiModalEmbedModelService;
import org.ruoyi.embedding.model.AliyunMultiModalEmbedRequest;
import org.ruoyi.embedding.model.AliyunMultiModalEmbedResponse;
import org.ruoyi.embedding.model.ModalityType;
import org.ruoyi.embedding.model.MultiModalInput;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
* 阿里云百炼多模态嵌入模型服务实现类
* 实现了MultiModalEmbedModelService接口提供文本、图像和视频的嵌入向量生成服务
*/
@Component("bailianMultiModel")
@Slf4j
public class AliBaiLianMultiEmbeddingProvider implements MultiModalEmbedModelService {
private ChatModelVo chatModelVo;
private final OkHttpClient okHttpClient;
/**
* 构造函数初始化HTTP客户端
* 设置连接超时、读取超时和写入超时时间
*/
public AliBaiLianMultiEmbeddingProvider() {
this.okHttpClient = new OkHttpClient.Builder()
.connectTimeout(30, TimeUnit.SECONDS)
.readTimeout(60, TimeUnit.SECONDS)
.writeTimeout(30, TimeUnit.SECONDS)
.build();
}
/**
* 图像嵌入向量生成
* @param imageDataUrl 图像数据的URL
* @return 包含图像嵌入向量的Response对象
*/
@Override
public Response<Embedding> embedImage(String imageDataUrl) {
return embedSingleModality("image", imageDataUrl);
}
/**
* 视频嵌入向量生成
* @param videoDataUrl 视频数据的URL
* @return 包含视频嵌入向量的Response对象
*/
@Override
public Response<Embedding> embedVideo(String videoDataUrl) {
return embedSingleModality("video", videoDataUrl);
}
/**
* 多模态嵌入向量生成
* 支持同时处理文本、图像和视频等多种模态的数据
* @param input 包含多种模态输入的对象
* @return 包含多模态嵌入向量的Response对象
*/
@Override
public Response<Embedding> embedMultiModal(MultiModalInput input) {
try {
// 构建请求内容
List<Map<String, Object>> contents = buildContentMap(input);
if (contents.isEmpty()) {
throw new IllegalArgumentException("至少提供一种模态的内容");
}
// 构建请求
AliyunMultiModalEmbedRequest request = buildRequest(contents, chatModelVo);
AliyunMultiModalEmbedResponse resp = executeRequest(request, chatModelVo);
// 转换为 embeddings
Response<List<Embedding>> response = toEmbeddings(resp);
List<Embedding> embeddings = response.content();
if (embeddings.isEmpty()) {
log.warn("阿里云混合模态嵌入返回为空");
return Response.from(Embedding.from(new float[0]), response.tokenUsage());
}
// 多模态通常取第一个向量作为代表,也可以根据业务场景返回多个
return Response.from(embeddings.get(0), response.tokenUsage());
} catch (Exception e) {
log.error("阿里云混合模态嵌入失败", e);
throw new IllegalArgumentException("阿里云混合模态嵌入失败", e);
}
}
/**
* 配置模型参数
* @param config 模型配置信息
*/
@Override
public void configure(ChatModelVo config) {
this.chatModelVo = config;
}
/**
* 获取支持的模态类型
* @return 支持的模态类型集合
*/
@Override
public Set<ModalityType> getSupportedModalities() {
return Set.of(ModalityType.TEXT, ModalityType.VIDEO, ModalityType.IMAGE);
}
/**
* 批量文本嵌入向量生成
* @param textSegments 文本段列表
* @return 包含所有文本嵌入向量的Response对象
*/
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
if (textSegments.isEmpty()) return Response.from(Collections.emptyList());
try {
List<Map<String, Object>> contents = new ArrayList<>();
for (TextSegment segment : textSegments) {
contents.add(Map.of("text", segment.text()));
}
AliyunMultiModalEmbedRequest request = buildRequest(contents, chatModelVo);
AliyunMultiModalEmbedResponse resp = executeRequest(request, chatModelVo);
return toEmbeddings(resp);
} catch (Exception e) {
log.error("阿里云文本嵌入失败", e);
throw new IllegalArgumentException("阿里云文本嵌入失败", e);
}
}
/**
* 单模态嵌入(图片/视频/单条文本)复用方法
* @param key 模态类型image/video/text
* @param dataUrl 数据URL
* @return 包含嵌入向量的Response对象
*/
public Response<Embedding> embedSingleModality(String key, String dataUrl) {
try {
AliyunMultiModalEmbedRequest request = buildRequest(List.of(Map.of(key, dataUrl)), chatModelVo);
AliyunMultiModalEmbedResponse resp = executeRequest(request, chatModelVo);
Response<List<Embedding>> response = toEmbeddings(resp);
List<Embedding> embeddings = response.content();
if (embeddings.isEmpty()) {
log.warn("阿里云 {} 嵌入返回为空", key);
return Response.from(Embedding.from(new float[0]), response.tokenUsage());
}
return Response.from(embeddings.get(0), response.tokenUsage());
} catch (Exception e) {
log.error("阿里云 {} 嵌入失败", key, e);
throw new IllegalArgumentException("阿里云 " + key + " 嵌入失败", e);
}
}
/**
* 构建请求对象
* @param contents 请求内容列表
* @param chatModelVo 模型配置信息
* @return 构建好的请求对象
*/
private AliyunMultiModalEmbedRequest buildRequest(List<Map<String, Object>> contents, ChatModelVo chatModelVo) {
if (contents.isEmpty()) throw new IllegalArgumentException("请求内容不能为空");
return AliyunMultiModalEmbedRequest.create(chatModelVo.getModelName(), contents);
}
/**
* 执行 HTTP 请求并解析响应
* @param request 请求对象
* @param chatModelVo 模型配置信息
* @return API响应对象
* @throws IOException IO异常
*/
private AliyunMultiModalEmbedResponse executeRequest(AliyunMultiModalEmbedRequest request, ChatModelVo chatModelVo) throws IOException {
String jsonBody = request.toJson();
RequestBody body = RequestBody.create(jsonBody, MediaType.get("application/json"));
Request httpRequest = new Request.Builder()
.url(chatModelVo.getApiHost())
.addHeader("Authorization", "Bearer " + chatModelVo.getApiKey())
.post(body)
.build();
try (okhttp3.Response response = okHttpClient.newCall(httpRequest).execute()) {
if (!response.isSuccessful()) {
String err = response.body() != null ? response.body().string() : "无错误信息";
throw new IllegalArgumentException("API调用失败: " + response.code() + " - " + err, null);
}
ResponseBody responseBody = response.body();
if (responseBody == null) throw new IllegalArgumentException("响应体为空", null);
return parseEmbeddingsFromResponse(responseBody.string());
}
}
/**
* 解析嵌入向量列表
* @param responseBody API响应的JSON字符串
* @return 嵌入向量响应对象
* @throws IOException IO异常
*/
private AliyunMultiModalEmbedResponse parseEmbeddingsFromResponse(String responseBody) throws IOException {
ObjectMapper objectMapper1 = new ObjectMapper();
return objectMapper1.readValue(responseBody, AliyunMultiModalEmbedResponse.class);
}
/**
* 构建 API 请求内容 Map
* @param input 多模态输入对象
* @return 包含各种模态内容的Map列表
*/
private List<Map<String, Object>> buildContentMap(MultiModalInput input) {
List<Map<String, Object>> contents = new ArrayList<>();
if (input.getText() != null && !input.getText().isBlank()) {
contents.add(Map.of("text", input.getText()));
}
if (input.getImageUrl() != null && !input.getImageUrl().isBlank()) {
contents.add(Map.of("image", input.getImageUrl()));
}
if (input.getVideoUrl() != null && !input.getVideoUrl().isBlank()) {
contents.add(Map.of("video", input.getVideoUrl()));
}
if (input.getMultiImageUrls() != null && input.getMultiImageUrls().length > 0) {
contents.add(Map.of("multi_images", Arrays.asList(input.getMultiImageUrls())));
}
return contents;
}
/**
* 将 API 原始响应解析为 LangChain4j 的 Response<Embedding>
* @param resp API原始响应对象
* @return 包含嵌入向量和token使用情况的Response对象
*/
private Response<List<Embedding>> toEmbeddings(AliyunMultiModalEmbedResponse resp) {
if (resp == null || resp.output() == null || resp.output().embeddings() == null) {
return Response.from(Collections.emptyList());
}
// 转换 double -> float
List<Embedding> embeddings = resp.output().embeddings().stream()
.map(item -> {
float[] vector = new float[item.embedding().size()];
for (int i = 0; i < item.embedding().size(); i++) {
vector[i] = item.embedding().get(i).floatValue();
}
return Embedding.from(vector);
})
.toList();
// 构建 TokenUsage
TokenUsage tokenUsage = null;
if (resp.usage() != null) {
tokenUsage = new TokenUsage(
resp.usage().input_tokens(),
resp.usage().image_tokens(),
resp.usage().input_tokens() +resp.usage().image_tokens()
);
}
return Response.from(embeddings, tokenUsage);
}
}

View File

@@ -0,0 +1,42 @@
package org.ruoyi.embedding.impl;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.output.Response;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.model.ModalityType;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Set;
/**
* @Author: Robust_H
* @Date: 2025-09-30-下午3:00
* @Description: Ollama嵌入模型
*/
@Component("ollama")
public class OllamaEmbeddingProvider implements BaseEmbedModelService {
private ChatModelVo chatModelVo;
@Override
public void configure(ChatModelVo config) {
this.chatModelVo = config;
}
@Override
public Set<ModalityType> getSupportedModalities() {
return Set.of(ModalityType.TEXT);
}
// ollama不能设置embedding维度使用milvus时请注意创建向量表时需要先设定维度大小
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
return OllamaEmbeddingModel.builder()
.baseUrl(chatModelVo.getApiHost())
.modelName(chatModelVo.getModelName())
.build()
.embedAll(textSegments);
}
}

View File

@@ -0,0 +1,44 @@
package org.ruoyi.embedding.impl;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.output.Response;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.model.ModalityType;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Set;
/**
* @Author: Robust_H
* @Date: 2025-09-30-下午3:59
* @Description: OpenAi嵌入模型
*/
@Component("openai")
public class OpenAiEmbeddingProvider implements BaseEmbedModelService {
protected ChatModelVo chatModelVo;
@Override
public void configure(ChatModelVo config) {
this.chatModelVo = config;
}
@Override
public Set<ModalityType> getSupportedModalities() {
return Set.of(ModalityType.TEXT);
}
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
return OpenAiEmbeddingModel.builder()
.baseUrl(chatModelVo.getApiHost())
.apiKey(chatModelVo.getApiKey())
.modelName(chatModelVo.getModelName())
.dimensions(chatModelVo.getDimension())
.build()
.embedAll(textSegments);
}
}

View File

@@ -0,0 +1,18 @@
package org.ruoyi.embedding.impl;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.model.ModalityType;
import org.springframework.stereotype.Component;
import java.util.Set;
/**
* @Author: Robust_H
* @Date: 2025-09-30-下午3:59
* @Description: 硅基流动(兼容 OpenAi
*/
@Component("siliconflow")
public class SiliconFlowEmbeddingProvider extends OpenAiEmbeddingProvider {
}

View File

@@ -0,0 +1,44 @@
package org.ruoyi.embedding.impl;
import dev.langchain4j.community.model.zhipu.ZhipuAiEmbeddingModel;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.model.ModalityType;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Set;
/**
* @Author: Robust_H
* @Date: 2025-09-30-下午4:02
* @Description: 智谱AI
*/
@Component("zhipu")
public class ZhiPuAiEmbeddingProvider implements BaseEmbedModelService {
private ChatModelVo chatModelVo;
@Override
public void configure(ChatModelVo config) {
this.chatModelVo = config;
}
@Override
public Set<ModalityType> getSupportedModalities() {
return Set.of();
}
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
return ZhipuAiEmbeddingModel.builder()
.baseUrl(chatModelVo.getApiHost())
.apiKey(chatModelVo.getApiKey())
.model(chatModelVo.getModelName())
.dimensions(chatModelVo.getDimension())
.build()
.embedAll(textSegments);
}
}

View File

@@ -0,0 +1,44 @@
package org.ruoyi.embedding.model;
import org.ruoyi.common.json.utils.JsonUtils;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
* @Author: Robust_H
* @Date: 2025-10-1-上午10:00
* @Description: 阿里云多模态嵌入请求
*/
@Data
public class AliyunMultiModalEmbedRequest {
private String model;
private Input input;
/**
* 表示输入数据的记录类(Record)
* 该类用于封装一个包含多个映射关系列表的输入数据结构
*
* @param contents 包含多个Map的列表每个Map中存储String类型的键和Object类型的值
*/
public record Input(List<Map<String, Object>> contents) { }
/**
* 创建请求对象
*/
public static AliyunMultiModalEmbedRequest create(String modelName, List<Map<String, Object>> contents) {
AliyunMultiModalEmbedRequest request = new AliyunMultiModalEmbedRequest();
request.setModel(modelName);
Input input = new Input(contents);
request.setInput(input);
return request;
}
/**
* 转换为JSON字符串
*/
public String toJson() {
return JsonUtils.toJsonString(this);
}
}

View File

@@ -0,0 +1,44 @@
package org.ruoyi.embedding.model;
import java.util.List;
/**
* 阿里云多模态嵌入 API 响应数据模型
*/
public record AliyunMultiModalEmbedResponse(
Output output, // 输出结果对象
String request_id, // 请求唯一标识
String code, // 错误码
String message, // 错误消息
Usage usage // 用量信息
) {
/**
* 输出对象,包含嵌入向量结果
*/
public record Output(
List<EmbeddingItem> embeddings // 嵌入向量列表
) {
}
/**
* 单个嵌入向量条目
*/
public record EmbeddingItem(
int index, // 输入内容的索引
List<Double> embedding, // 生成的 1024 维向量
String type // 输入的类型 (text/image/video/multi_images)
) {
}
/**
* 用量统计信息
*/
public record Usage(
int input_tokens, // 本次请求输入的 Token 数量
int image_tokens, // 本次请求输入的图像 Token 数量
int image_count, // 本次请求输入的图像数量
int duration // 本次请求输入的视频时长(秒)
) {
}
}

View File

@@ -0,0 +1,8 @@
package org.ruoyi.embedding.model;
/**
* 模态类型
*/
public enum ModalityType {
TEXT, IMAGE, AUDIO, VIDEO, MULTI
}

View File

@@ -0,0 +1,71 @@
package org.ruoyi.embedding.model;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import lombok.Builder;
import lombok.Data;
/**
* @Author: Robust_H
* @Date: 2025-09-30-下午2:13
* @Description: 多模态输入
*/
@Data
@Builder
public class MultiModalInput {
private String text;
private byte[] imageData;
private byte[] videoData;
private String imageMimeType;
private String videoMimeType;
private String[] multiImageUrls;
private String imageUrl;
private String videoUrl;
/**
* 检查是否有文本内容
*/
public boolean hasText() {
return StrUtil.isNotBlank(text);
}
/**
* 检查是否有图片内容
*/
public boolean hasImage() {
return ArrayUtil.isNotEmpty(imageData) || StrUtil.isNotBlank(imageUrl);
}
/**
* 检查是否有视频内容
*/
public boolean hasVideo() {
return ArrayUtil.isNotEmpty(videoData) || StrUtil.isNotBlank(videoUrl);
}
/**
* 检查是否有多图片
*/
public boolean hasMultiImages() {
return ArrayUtil.isNotEmpty(multiImageUrls);
}
/**
* 检查是否有任何内容
*/
public boolean hasAnyContent() {
return hasText() || hasImage() || hasVideo() || hasMultiImages();
}
/**
* 获取内容的数量
*/
public int getContentCount() {
int count = 0;
if (hasText()) count++;
if (hasImage()) count++;
if (hasVideo()) count++;
if (hasMultiImages()) count++;
return count;
}
}

View File

@@ -1,5 +1,6 @@
package org.ruoyi.service; package org.ruoyi.service;
import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.domain.bo.QueryVectorBo; import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo; import org.ruoyi.domain.bo.StoreEmbeddingBo;
@@ -11,15 +12,15 @@ import java.util.List;
*/ */
public interface VectorStoreService { public interface VectorStoreService {
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo); void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) throws ServiceException;
List<String> getQueryVector(QueryVectorBo queryVectorBo); List<String> getQueryVector(QueryVectorBo queryVectorBo);
void createSchema(String kid,String modelName); void createSchema(String kid, String embeddingModelName);
void removeById(String id,String modelName); void removeById(String id,String modelName) throws ServiceException;
void removeByDocId(String docId, String kid); void removeByDocId(String docId, String kid) throws ServiceException;
void removeByFid(String fid, String kid); void removeByFid(String fid, String kid) throws ServiceException;
} }

View File

@@ -1,37 +1,14 @@
package org.ruoyi.service.impl; package org.ruoyi.service.impl;
import cn.hutool.json.JSONObject;
import com.google.protobuf.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.bo.QueryVectorBo; import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo; import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.service.VectorStoreService; import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.strategy.VectorStoreStrategyFactory;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
/** /**
* 向量库管理 * 向量库管理
@@ -39,235 +16,61 @@ import java.util.stream.Collectors;
* @author ageer * @author ageer
*/ */
@Service @Service
@Primary
@Slf4j @Slf4j
@RequiredArgsConstructor @RequiredArgsConstructor
public class VectorStoreServiceImpl implements VectorStoreService { public class VectorStoreServiceImpl implements VectorStoreService {
private final ConfigService configService; private final VectorStoreStrategyFactory strategyFactory;
// private EmbeddingStore<TextSegment> embeddingStore;
private WeaviateClient client;
/**
* 获取当前配置的向量库策略
*/
private VectorStoreService getCurrentStrategy() {
return strategyFactory.getStrategy();
}
@Override @Override
public void createSchema(String kid, String modelName) { public void createSchema(String kid, String modelName) {
String protocol = configService.getConfigValue("weaviate", "protocol"); VectorStoreService strategy = getCurrentStrategy();
String host = configService.getConfigValue("weaviate", "host"); strategy.createSchema(kid, modelName);
String className = configService.getConfigValue("weaviate", "classname")+kid;
// 创建 Weaviate 客户端
client= new WeaviateClient(new Config(protocol, host));
// 检查类是否存在,如果不存在就创建 schema
Result<Schema> schemaResult = client.schema().getter().run();
Schema schema = schemaResult.getResult();
boolean classExists = false;
for (WeaviateClass weaviateClass : schema.getClasses()) {
if (weaviateClass.getClassName().equals(className)) {
classExists = true;
break;
}
}
if (!classExists) {
// 类不存在,创建 schema
WeaviateClass build = WeaviateClass.builder()
.className(className)
.vectorizer("none")
.properties(
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
)
.build();
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
if (createResult.hasErrors()) {
log.error("Schema 创建失败: {}", createResult.getError());
} else {
log.info("Schema 创建成功: {}", className);
}
}
// embeddingStore = WeaviateEmbeddingStore.builder()
// .scheme(protocol)
// .host(host)
// .objectClass(className)
// .scheme(protocol)
// .avoidDups(true)
// .consistencyLevel("ALL")
// .build();
} }
@Override @Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); log.info("存储向量数据: kid={}, docId={}, 数据条数={}",
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getDocId(), storeEmbeddingBo.getChunkList().size());
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); VectorStoreService strategy = getCurrentStrategy();
List<String> chunkList = storeEmbeddingBo.getChunkList(); strategy.storeEmbeddings(storeEmbeddingBo);
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
log.info("向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid",fid,
"kid", kid,
"docId", docId
);
Float[] vector = toObjectArray(embedding.vector());
client.data().creator()
.withClassName("LocalKnowledge" + kid) // 注意替换成实际类名
.withProperties(properties)
.withVector(vector)
.run();
}
long endTime = System.currentTimeMillis();
log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"");
} }
private static Float[] toObjectArray(float[] primitive) {
Float[] result = new Float[primitive.length];
for (int i = 0; i < primitive.length; i++) {
result[i] = primitive[i]; // 自动装箱
}
return result;
}
@Override @Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) { public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); log.info("查询向量数据: kid={}, query={}, maxResults={}",
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), queryVectorBo.getKid(), queryVectorBo.getQuery(), queryVectorBo.getMaxResults());
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); VectorStoreService strategy = getCurrentStrategy();
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); return strategy.getQueryVector(queryVectorBo);
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
vectorStrings.add(String.valueOf(v));
}
String vectorStr = String.join(",", vectorStrings);
String className = configService.getConfigValue("weaviate", "classname") ;
// 构建 GraphQL 查询
String graphQLQuery = String.format(
"{\n" +
" Get {\n" +
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
" text\n" +
" fid\n" +
" kid\n" +
" docId\n" +
" _additional {\n" +
" distance\n" +
" id\n" +
" }\n" +
" }\n" +
" }\n" +
"}",
className+ queryVectorBo.getKid(),
vectorStr,
queryVectorBo.getMaxResults()
);
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
List<String> resultList = new ArrayList<>();
if (result != null && !result.hasErrors()) {
Object data = result.getResult().getData();
JSONObject entries = new JSONObject(data);
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
if(objects.isEmpty()){
return resultList;
}
for (Object object : objects) {
Map<String, String> map = (Map<String, String>) object;
String content = map.get("text");
resultList.add( content);
}
return resultList;
} else {
log.error("GraphQL 查询失败: {}", result.getError());
return resultList;
}
} }
@Override @Override
@SneakyThrows
public void removeById(String id, String modelName) { public void removeById(String id, String modelName) {
String protocol = configService.getConfigValue("weaviate", "protocol"); log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
String host = configService.getConfigValue("weaviate", "host"); VectorStoreService strategy = getCurrentStrategy();
String className = configService.getConfigValue("weaviate", "classname"); strategy.removeById(id, modelName);
String finalClassName = className + id;
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
Result<Boolean> result = client.schema().classDeleter().withClassName(finalClassName).run();
if (result.hasErrors()) {
log.error("失败删除向量: " + result.getError());
throw new ServiceException("失败删除向量数据!");
} else {
log.info("成功删除向量数据: " + result.getResult());
}
} }
@Override @Override
public void removeByDocId(String docId, String kid) { public void removeByDocId(String docId, String kid) {
String className = configService.getConfigValue("weaviate", "classname") + kid; log.info("根据docId删除向量数据: docId={}, kid={}", docId, kid);
// 构建 Where 条件 VectorStoreService strategy = getCurrentStrategy();
WhereFilter whereFilter = WhereFilter.builder() strategy.removeByDocId(docId, kid);
.path("docId")
.operator(Operator.Equal)
.valueText(docId)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 docId={} 的所有向量数据", docId);
} else {
log.error("删除失败: {}", result.getError());
}
} }
@Override @Override
public void removeByFid(String fid, String kid) { public void removeByFid(String fid, String kid) {
String className = configService.getConfigValue("weaviate", "classname") + kid; log.info("根据fid删除向量数据: fid={}, kid={}", fid, kid);
// 构建 Where 条件 VectorStoreService strategy = getCurrentStrategy();
WhereFilter whereFilter = WhereFilter.builder() strategy.removeByFid(fid, kid);
.path("fid")
.operator(Operator.Equal)
.valueText(fid)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 fid={} 的所有向量数据", fid);
} else {
log.error("删除失败: {}", result.getError());
}
} }
/**
* 获取向量模型
*/
@SneakyThrows
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
EmbeddingModel embeddingModel;
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
embeddingModel = OllamaEmbeddingModel.builder()
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else if ("baai/bge-m3".equals(modelName)) {
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else {
throw new ServiceException("未找到对应向量化模型!");
}
return embeddingModel;
}
} }

View File

@@ -0,0 +1,52 @@
package org.ruoyi.service.strategy;
import org.ruoyi.common.core.exception.ServiceException;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.embedding.EmbeddingModelFactory;
/**
* 向量库策略抽象基类
* 提供公共的方法实现如embedding模型获取等
*
* @author Yzm
*/
@Slf4j
@RequiredArgsConstructor
public abstract class AbstractVectorStoreStrategy implements VectorStoreService {
protected final VectorStoreProperties vectorStoreProperties;
private final EmbeddingModelFactory embeddingModelFactory;
/**
* 获取向量模型
*/
@SneakyThrows
protected EmbeddingModel getEmbeddingModel(String modelName, Integer dimension) {
return embeddingModelFactory.createModel(modelName, dimension);
}
/**
* 将float数组转换为Float对象数组
*/
protected static Float[] toObjectArray(float[] primitive) {
Float[] result = new Float[primitive.length];
for (int i = 0; i < primitive.length; i++) {
result[i] = primitive[i]; // 自动装箱
}
return result;
}
/**
* 获取向量库类型标识
*/
public abstract String getVectorStoreType();
}

View File

@@ -0,0 +1,57 @@
package org.ruoyi.service.strategy;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.service.strategy.impl.MilvusVectorStoreStrategy;
import org.ruoyi.service.strategy.impl.WeaviateVectorStoreStrategy;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
/**
* 向量库策略工厂
* 根据配置动态选择向量库实现
*
* @author Yzm
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class VectorStoreStrategyFactory {
private final VectorStoreProperties vectorStoreProperties;
private final WeaviateVectorStoreStrategy weaviateStrategy;
private final MilvusVectorStoreStrategy milvusStrategy;
private Map<String, VectorStoreService> strategies;
@PostConstruct
public void init() {
strategies = new HashMap<>();
strategies.put("weaviate", weaviateStrategy);
strategies.put("milvus", milvusStrategy);
log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet());
}
/**
* 获取当前配置的向量库策略
*/
public VectorStoreService getStrategy() {
String vectorStoreType = vectorStoreProperties.getType();
if (vectorStoreType == null || vectorStoreType.trim().isEmpty()) {
vectorStoreType = "weaviate"; // 默认使用weaviate
}
VectorStoreService strategy = strategies.get(vectorStoreType.toLowerCase());
if (strategy == null) {
log.warn("未找到向量库策略: {}, 使用默认策略: weaviate", vectorStoreType);
strategy = strategies.get("weaviate");
}
log.debug("使用向量库策略: {}", vectorStoreType);
return strategy;
}
}

View File

@@ -0,0 +1,157 @@
package org.ruoyi.service.strategy.impl;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;
@Slf4j
@Component
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
private final Integer DIMENSION = 2048;
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
super(vectorStoreProperties, embeddingModelFactory);
}
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
private final Map<String, EmbeddingStore<TextSegment>> storeCache = new ConcurrentHashMap<>();
private EmbeddingStore<TextSegment> getMilvusStore(String collectionName, boolean autoFlushOnInsert) {
String key = collectionName + "|" + autoFlushOnInsert;
return storeCache.computeIfAbsent(key, k ->
MilvusEmbeddingStore.builder()
.uri(vectorStoreProperties.getMilvus().getUrl())
.collectionName(collectionName)
.dimension(DIMENSION)
.indexType(IndexType.IVF_FLAT)
.metricType(MetricType.L2)
.autoFlushOnInsert(autoFlushOnInsert)
.idFieldName("id")
.textFieldName("text")
.metadataFieldName("metadata")
.vectorFieldName("vector")
.build()
);
}
@Override
public void createSchema(String kid, String modelName) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
// 使用缓存获取连接以确保只初始化一次
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, true);
log.info("Milvus集合初始化完成: {}", collectionName);
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
log.info("Milvus向量存储条数记录: {}", chunkList.size());
long startTime = System.currentTimeMillis();
// 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
IntStream.range(0, chunkList.size()).forEach(i -> {
String text = chunkList.get(i);
String fid = fidList.get(i);
Metadata metadata = new Metadata();
metadata.put("fid", fid);
metadata.put("kid", kid);
metadata.put("docId", docId);
TextSegment textSegment = TextSegment.from(text, metadata);
Embedding embedding = embeddingModel.embed(text).content();
embeddingStore.add(embedding, textSegment);
});
long endTime = System.currentTimeMillis();
log.info("Milvus向量存储完成消耗时间{}秒", (endTime - startTime) / 1000);
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
// 查询复用连接autoFlush 对查询无影响,此处保持 true
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, true);
List<String> resultList = new ArrayList<>();
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(queryVectorBo.getMaxResults())
.build();
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();
for (EmbeddingMatch<TextSegment> match : matches) {
TextSegment segment = match.embedded();
if (segment != null) {
resultList.add(segment.text());
}
}
return resultList;
}
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
// 注意:此处原逻辑使用 collectionname + id保持现状
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false);
embeddingStore.remove(id);
}
@Override
public void removeByDocId(String docId, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId);
embeddingStore.removeAll(filter);
log.info("Milvus成功删除 docId={} 的所有向量数据", docId);
}
@Override
public void removeByFid(String fid, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid);
embeddingStore.removeAll(filter);
log.info("Milvus成功删除 fid={} 的所有向量数据", fid);
}
@Override
public String getVectorStoreType() {
return "milvus";
}
}

View File

@@ -0,0 +1,232 @@
package org.ruoyi.service.strategy.impl;
import cn.hutool.json.JSONObject;
import org.ruoyi.common.core.exception.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.*;
/**
* Weaviate向量库策略实现
*
* @author Yzm
*/
@Slf4j
@Component
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
private WeaviateClient client;
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
super(vectorStoreProperties, embeddingModelFactory);
}
@Override
public String getVectorStoreType() {
return "weaviate";
}
@Override
public void createSchema(String kid, String embeddingModelName) {
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
String host = vectorStoreProperties.getWeaviate().getHost();
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
// 创建 Weaviate 客户端
client = new WeaviateClient(new Config(protocol, host));
// 检查类是否存在,如果不存在就创建 schema
Result<Schema> schemaResult = client.schema().getter().run();
Schema schema = schemaResult.getResult();
boolean classExists = false;
for (WeaviateClass weaviateClass : schema.getClasses()) {
if (weaviateClass.getClassName().equals(className)) {
classExists = true;
break;
}
}
if (!classExists) {
// 类不存在,创建 schema
WeaviateClass build = WeaviateClass.builder()
.className(className)
.vectorizer("none")
.properties(
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
)
.build();
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
if (createResult.hasErrors()) {
log.error("Schema 创建失败: {}", createResult.getError());
} else {
log.info("Schema 创建成功: {}", className);
}
}
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getEmbeddingModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), null);
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
log.info("向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid", fid,
"kid", kid,
"docId", docId
);
Float[] vector = toObjectArray(embedding.vector());
client.data().creator()
.withClassName("LocalKnowledge" + kid)
.withProperties(properties)
.withVector(vector)
.run();
}
long endTime = System.currentTimeMillis();
log.info("向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "");
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getKid(),queryVectorBo.getEmbeddingModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),null);
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
vectorStrings.add(String.valueOf(v));
}
String vectorStr = String.join(",", vectorStrings);
String className = vectorStoreProperties.getWeaviate().getClassname();
// 构建 GraphQL 查询
String graphQLQuery = String.format(
"{\n" +
" Get {\n" +
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
" text\n" +
" fid\n" +
" kid\n" +
" docId\n" +
" _additional {\n" +
" distance\n" +
" id\n" +
" }\n" +
" }\n" +
" }\n" +
"}",
className + queryVectorBo.getKid(),
vectorStr,
queryVectorBo.getMaxResults()
);
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
List<String> resultList = new ArrayList<>();
if (result != null && !result.hasErrors()) {
Object data = result.getResult().getData();
JSONObject entries = new JSONObject(data);
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
if (objects.isEmpty()) {
return resultList;
}
for (Object object : objects) {
Map<String, String> map = (Map<String, String>) object;
String content = map.get("text");
resultList.add(content);
}
return resultList;
} else {
log.error("GraphQL 查询失败: {}", result.getError());
return resultList;
}
}
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
String host = vectorStoreProperties.getWeaviate().getHost();
String className = vectorStoreProperties.getWeaviate().getClassname();
String finalClassName = className + id;
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
Result<Boolean> result = client.schema().classDeleter().withClassName(finalClassName).run();
if (result.hasErrors()) {
log.error("失败删除向量: " + result.getError());
throw new ServiceException("失败删除向量数据!");
} else {
log.info("成功删除向量数据: " + result.getResult());
}
}
@Override
public void removeByDocId(String docId, String kid) {
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
// 构建 Where 条件
WhereFilter whereFilter = WhereFilter.builder()
.path("docId")
.operator(Operator.Equal)
.valueText(docId)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 docId={} 的所有向量数据", docId);
} else {
log.error("删除失败: {}", result.getError());
}
}
@Override
public void removeByFid(String fid, String kid) {
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
// 构建 Where 条件
WhereFilter whereFilter = WhereFilter.builder()
.path("fid")
.operator(Operator.Equal)
.valueText(fid)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 fid={} 的所有向量数据", fid);
} else {
log.error("删除失败: {}", result.getError());
}
}
}

View File

@@ -18,7 +18,7 @@ import java.io.Serializable;
*/ */
@Data @Data
@ExcelIgnoreUnannotated @ExcelIgnoreUnannotated
@AutoMapper(target = ChatConfig.class) @AutoMapper(target = ChatConfig.class)
public class ChatConfigVo implements Serializable { public class ChatConfigVo implements Serializable {
@Serial @Serial

View File

@@ -0,0 +1,133 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-modules-api</artifactId>
<version>${revision}</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>ruoyi-workflow-api</artifactId>
<description>
工作流API模块
</description>
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-system-api</artifactId>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-common-satoken</artifactId>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-common-mail</artifactId>
</dependency>
<dependency>
<groupId>org.ruoyi</groupId>
<artifactId>ruoyi-chat</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>1.2.0</version>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.12</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.bsc.langgraph4j</groupId>
<artifactId>langgraph4j-core</artifactId>
<version>1.5.3</version>
</dependency>
<dependency>
<groupId>org.bsc.langgraph4j</groupId>
<artifactId>langgraph4j-langchain4j</artifactId>
<version>1.5.3</version>
</dependency>
<dependency>
<groupId>io.swagger.core.v3</groupId>
<artifactId>swagger-annotations</artifactId>
<version>2.2.8</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>1.2.0</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-community-dashscope</artifactId>
<version>1.2.0-beta8</version>
</dependency>
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-generator</artifactId>
<version>3.5.3.1</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-http-client-jdk</artifactId>
<version>1.2.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-document-parser-apache-poi</artifactId>
<version>1.2.0-beta8</version>
</dependency>
<dependency>
<groupId>com.google.api-client</groupId>
<artifactId>google-api-client</artifactId>
<version>2.6.0</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,43 @@
package org.ruoyi.workflow;
import com.baomidou.mybatisplus.generator.FastAutoGenerator;
import com.baomidou.mybatisplus.generator.config.OutputFile;
import com.baomidou.mybatisplus.generator.config.rules.DbColumnType;
import java.sql.Types;
import java.util.Collections;
public class CodeGenerator {
public static void main(String[] args) {
FastAutoGenerator.create("jdbc:postgres://172.17.30.40:5432/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true", "postgres", "postgres")
.globalConfig(builder -> {
builder.author("moyz") // 设置作者
.enableSwagger() // 开启 swagger 模式
.fileOverride() // 覆盖已生成文件
.outputDir("D://"); // 指定输出目录
})
.dataSourceConfig(builder -> builder.typeConvertHandler((globalConfig, typeRegistry, metaInfo) -> {
int typeCode = metaInfo.getJdbcType().TYPE_CODE;
if (typeCode == Types.SMALLINT) {
// 自定义类型转换
return DbColumnType.INTEGER;
}
return typeRegistry.getColumnType(metaInfo);
}))
.packageConfig(builder -> {
builder.mapper("com.adi.common.mapper")
.parent("")
.moduleName("")
.entity("po")
.serviceImpl("service.impl")
.pathInfo(Collections.singletonMap(OutputFile.xml, "D://mybatisplus-generatorcode")); // 设置mapperXml生成路径
})
.strategyConfig(builder -> {
builder.addInclude("adi_knowledge_base_qa_record") // 设置需要生成的表名
.addTablePrefix("adi_");
builder.mapperBuilder().enableBaseResultMap().enableMapperAnnotation().build();
})
.execute();
}
}

View File

@@ -0,0 +1,51 @@
package org.ruoyi.workflow.base;
import lombok.Data;
import org.ruoyi.workflow.enums.ErrorEnum;
import java.io.Serializable;
@Data
public class BaseResponse<T> implements Serializable {
private static final long serialVersionUID = 1L;
/**
* 是否成功
*/
private boolean success;
/**
* 状态码
*/
private String code;
/**
* 提示
*/
private String message;
/**
* 数据
*/
private T data;
public BaseResponse() {
}
public BaseResponse(boolean success) {
this.success = success;
}
public BaseResponse(boolean success, T data) {
this.data = data;
this.success = success;
}
public BaseResponse(String code, String message, T data) {
this.code = code;
this.success = false;
this.message = message;
this.data = data;
}
public static BaseResponse success(String message) {
return new BaseResponse(ErrorEnum.SUCCESS.getCode(), message, "");
}
}

View File

@@ -0,0 +1,118 @@
package org.ruoyi.workflow.base;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;
import org.ruoyi.workflow.enums.WfIODataTypeEnum;
import org.ruoyi.workflow.util.JsonUtil;
import org.ruoyi.workflow.workflow.WfNodeInputConfig;
import org.ruoyi.workflow.workflow.def.WfNodeIO;
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import static org.ruoyi.workflow.workflow.WfNodeIODataUtil.INPUT_TYPE_TO_NODE_IO_DEF;
@Slf4j
@MappedJdbcTypes({JdbcType.JAVA_OBJECT})
@MappedTypes({WfNodeInputConfig.class})
public class NodeInputConfigTypeHandler extends BaseTypeHandler<WfNodeInputConfig> {
public static WfNodeInputConfig fillNodeInputConfig(String jsonSource) {
ObjectNode jsonNode = (ObjectNode) JsonUtil.toJsonNode(jsonSource);
return createNodeInputConfig(jsonNode);
}
public static WfNodeInputConfig createNodeInputConfig(ObjectNode jsonNode) {
List<WfNodeIO> userInputs = new ArrayList<>();
WfNodeInputConfig result = new WfNodeInputConfig();
result.setUserInputs(userInputs);
result.setRefInputs(new ArrayList<>());
if (null == jsonNode) {
return result;
}
ArrayNode userInputsJson = jsonNode.withArray("user_inputs");
ArrayNode refInputs = jsonNode.withArray("ref_inputs");
if (!userInputsJson.isEmpty()) {
for (JsonNode userInput : userInputsJson) {
if (userInput instanceof ObjectNode objectNode) {
int type = objectNode.get("type").asInt();
Class<? extends WfNodeIO> nodeIOClass = INPUT_TYPE_TO_NODE_IO_DEF.get(WfIODataTypeEnum.getByValue(type));
WfNodeIO wfNodeIO = JsonUtil.fromJson(objectNode, nodeIOClass);
if (null != wfNodeIO) {
userInputs.add(wfNodeIO);
} else {
log.warn("用户输入格式不正确:{}", userInput);
}
}
}
}
if (!refInputs.isEmpty()) {
List<WfNodeParamRef> list = JsonUtil.fromArrayNode(refInputs, WfNodeParamRef.class);
if (CollectionUtils.isNotEmpty(list)) {
result.setRefInputs(list);
} else {
log.warn("引用输入格式不正确:{}", refInputs);
}
}
return result;
}
@Override
public void setNonNullParameter(PreparedStatement ps, int i, WfNodeInputConfig parameter, JdbcType jdbcType) {
// PGobject jsonObject = new PGobject();
// jsonObject.setType("jsonb");
// try {
// jsonObject.setValue(JsonUtil.toJson(parameter));
// ps.setObject(i, jsonObject);
// } catch (Exception e) {
// throw new RuntimeException(e);
// }
}
@Override
public WfNodeInputConfig getNullableResult(ResultSet rs, String columnName) throws SQLException {
String jsonSource = rs.getString(columnName);
if (jsonSource != null) {
try {
return fillNodeInputConfig(jsonSource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
@Override
public WfNodeInputConfig getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
String jsonSource = rs.getString(columnIndex);
if (jsonSource != null) {
return fillNodeInputConfig(jsonSource);
}
return null;
}
@Override
public WfNodeInputConfig getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
String jsonSource = cs.getString(columnIndex);
if (jsonSource != null) {
try {
return fillNodeInputConfig(jsonSource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
}

View File

@@ -0,0 +1,122 @@
package org.ruoyi.workflow.base;
import cn.dev33.satoken.stp.StpUtil;
import org.apache.commons.lang3.StringUtils;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.workflow.entity.User;
import org.ruoyi.workflow.enums.UserStatusEnum;
import static org.ruoyi.workflow.enums.ErrorEnum.A_USER_NOT_FOUND;
/**
* 线程上下文适配器,统一接入 Sa-Token 登录态。
*/
public class ThreadContext {
private static final ThreadLocal<User> CURRENT_USER = new ThreadLocal<>();
private static final ThreadLocal<String> CURRENT_TOKEN = new ThreadLocal<>();
private ThreadContext() {
}
/**
* 获取当前登录的工作流用户。
*/
public static User getCurrentUser() {
User cached = CURRENT_USER.get();
if (cached != null) {
return cached;
}
LoginUser loginUser = LoginHelper.getLoginUser();
if (loginUser == null) {
throw new BaseException(A_USER_NOT_FOUND.getInfo());
}
User mapped = mapToWorkflowUser(loginUser);
CURRENT_USER.set(mapped);
return mapped;
}
/**
* 允许在测试或特殊场景下显式设置当前用户。
*/
public static void setCurrentUser(User user) {
if (user == null) {
CURRENT_USER.remove();
} else {
CURRENT_USER.set(user);
}
}
/**
* 获取当前登录用户 ID。
*/
public static Long getCurrentUserId() {
Long userId = LoginHelper.getUserId();
if (userId != null) {
return userId;
}
return getCurrentUser().getId();
}
/**
* 获取当前访问 token。
*/
public static String getToken() {
String token = CURRENT_TOKEN.get();
if (StringUtils.isNotBlank(token)) {
return token;
}
try {
token = StpUtil.getTokenValue();
} catch (Exception ignore) {
token = null;
}
if (StringUtils.isNotBlank(token)) {
CURRENT_TOKEN.set(token);
}
return token;
}
public static void setToken(String token) {
if (StringUtils.isBlank(token)) {
CURRENT_TOKEN.remove();
} else {
CURRENT_TOKEN.set(token);
}
}
public static boolean isLogin() {
return LoginHelper.isLogin();
}
public static User getExistCurrentUser() {
return getCurrentUser();
}
public static void unload() {
CURRENT_USER.remove();
CURRENT_TOKEN.remove();
}
private static User mapToWorkflowUser(LoginUser loginUser) {
User user = new User();
user.setId(loginUser.getUserId());
String nickname = loginUser.getNickName();
user.setName(StringUtils.defaultIfBlank(nickname, loginUser.getUsername()));
user.setEmail(loginUser.getUsername());
user.setUuid(String.valueOf(loginUser.getUserId()));
user.setUserStatus(UserStatusEnum.NORMAL);
user.setIsAdmin(LoginHelper.isSuperAdmin(loginUser.getUserId()));
user.setUnderstandContextMsgPairNum(0);
user.setQuotaByTokenDaily(0);
user.setQuotaByTokenMonthly(0);
user.setQuotaByRequestDaily(0);
user.setQuotaByRequestMonthly(0);
user.setQuotaByImageDaily(0);
user.setQuotaByImageMonthly(0);
user.setIsDeleted(false);
return user;
}
}

Some files were not shown because too many files have changed in this diff Show More