mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-17 13:53:41 +00:00
Merge branch 'main' into main
This commit is contained in:
@@ -150,6 +150,13 @@
|
||||
<strong>QQ技术交流群</strong><br>
|
||||
<em>技术讨论</em>
|
||||
</td>
|
||||
|
||||
<td align="center">
|
||||
<img width="200" height="200" alt="95e8b1b3baeadbd24650bfb974ca5a58" src="https://github.com/user-attachments/assets/2a346218-6388-484d-aa75-6e98942193f7" /><br>
|
||||
<strong>微信技术交流群</strong><br>
|
||||
<em>技术讨论</em>
|
||||
</td>
|
||||
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
432
docs/工作流模块说明.md
Normal file
432
docs/工作流模块说明.md
Normal file
@@ -0,0 +1,432 @@
|
||||
# Ruoyi-AI 工作流模块详细说明文档
|
||||
|
||||
## 概述
|
||||
|
||||
Ruoyi-AI 工作流模块是一个基于 LangGraph4j 的智能工作流引擎,支持可视化工作流设计、AI 模型集成、条件分支、人机交互等高级功能。该模块采用微服务架构,提供完整的 RESTful API 和流式响应支持。
|
||||
|
||||
## 模块架构
|
||||
|
||||
### 1. 模块结构
|
||||
|
||||
```
|
||||
ruoyi-ai/
|
||||
├── ruoyi-modules/
|
||||
│ └── ruoyi-workflow/ # 工作流核心模块
|
||||
│ ├── pom.xml
|
||||
│ └── src/main/java/org/ruoyi/workflow/
|
||||
│ └── controller/ # 控制器层
|
||||
│ ├── WorkflowController.java
|
||||
│ ├── WorkflowRuntimeController.java
|
||||
│ └── admin/ # 管理端控制器
|
||||
│ ├── AdminWorkflowController.java
|
||||
│ └── AdminWorkflowComponentController.java
|
||||
└── ruoyi-modules-api/
|
||||
└── ruoyi-workflow-api/ # 工作流API模块
|
||||
├── pom.xml
|
||||
└── src/main/java/org/ruoyi/workflow/
|
||||
├── entity/ # 实体类
|
||||
├── dto/ # 数据传输对象
|
||||
├── service/ # 服务接口
|
||||
├── mapper/ # 数据访问层
|
||||
├── workflow/ # 工作流核心逻辑
|
||||
├── enums/ # 枚举类
|
||||
├── util/ # 工具类
|
||||
└── exception/ # 异常处理
|
||||
```
|
||||
|
||||
### 2. 核心依赖
|
||||
|
||||
- **LangGraph4j**: 1.5.3 - 工作流图执行引擎
|
||||
- **LangChain4j**: 1.2.0 - AI 模型集成框架
|
||||
- **Spring Boot**: 3.x - 应用框架
|
||||
- **MyBatis Plus**: 数据访问层
|
||||
- **Redis**: 缓存和状态管理
|
||||
- **Swagger/OpenAPI**: API 文档
|
||||
|
||||
## 核心功能
|
||||
|
||||
### 1. 工作流管理
|
||||
|
||||
#### 1.1 工作流定义
|
||||
- **创建工作流**: 支持自定义标题、描述、公开性设置
|
||||
- **编辑工作流**: 可视化节点编辑、连接线配置
|
||||
- **版本控制**: 支持工作流的版本管理和回滚
|
||||
- **权限管理**: 支持公开/私有工作流设置
|
||||
|
||||
#### 1.2 工作流执行
|
||||
- **流式执行**: 基于 SSE 的实时流式响应
|
||||
- **状态管理**: 完整的执行状态跟踪
|
||||
- **错误处理**: 详细的错误信息和异常处理
|
||||
- **中断恢复**: 支持工作流中断和恢复执行
|
||||
|
||||
### 2. 节点类型
|
||||
|
||||
#### 2.1 基础节点
|
||||
- **Start**: 开始节点,定义工作流入口
|
||||
- **End**: 结束节点,定义工作流出口
|
||||
|
||||
#### 2.2 AI 模型节点
|
||||
- **Answer**: 大语言模型问答节点
|
||||
- **Dalle3**: DALL-E 3 图像生成
|
||||
- **Tongyiwanx**: 通义万相图像生成
|
||||
- **Classifier**: 内容分类节点
|
||||
|
||||
#### 2.3 数据处理节点
|
||||
- **DocumentExtractor**: 文档信息提取
|
||||
- **KeywordExtractor**: 关键词提取
|
||||
- **FaqExtractor**: 常见问题提取
|
||||
- **KnowledgeRetrieval**: 知识库检索
|
||||
|
||||
#### 2.4 控制流节点
|
||||
- **Switcher**: 条件分支节点
|
||||
- **HumanFeedback**: 人机交互节点
|
||||
|
||||
#### 2.5 外部集成节点
|
||||
- **Google**: Google 搜索集成
|
||||
- **MailSend**: 邮件发送
|
||||
- **HttpRequest**: HTTP 请求
|
||||
- **Template**: 模板转换
|
||||
|
||||
### 3. 数据流管理
|
||||
|
||||
#### 3.1 输入输出定义
|
||||
```java
|
||||
// 节点输入输出数据结构
|
||||
public class NodeIOData {
|
||||
private String name; // 参数名称
|
||||
private NodeIODataContent content; // 参数内容
|
||||
}
|
||||
|
||||
// 支持的数据类型
|
||||
public enum WfIODataTypeEnum {
|
||||
TEXT, // 文本
|
||||
NUMBER, // 数字
|
||||
BOOLEAN, // 布尔值
|
||||
FILES, // 文件
|
||||
OPTIONS // 选项
|
||||
}
|
||||
```
|
||||
|
||||
#### 3.2 参数引用
|
||||
- **节点间引用**: 支持上游节点输出作为下游节点输入
|
||||
- **参数映射**: 自动处理参数名称映射
|
||||
- **类型转换**: 自动进行数据类型转换
|
||||
|
||||
## 数据库设计
|
||||
|
||||
### 1. 核心表结构
|
||||
|
||||
#### 1.1 工作流定义表 (t_workflow)
|
||||
```sql
|
||||
CREATE TABLE t_workflow (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
uuid VARCHAR(32) NOT NULL DEFAULT '',
|
||||
title VARCHAR(100) NOT NULL DEFAULT '',
|
||||
remark TEXT NOT NULL DEFAULT '',
|
||||
user_id BIGINT NOT NULL DEFAULT 0,
|
||||
is_public TINYINT(1) NOT NULL DEFAULT 0,
|
||||
is_enable TINYINT(1) NOT NULL DEFAULT 1,
|
||||
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
is_deleted TINYINT(1) NOT NULL DEFAULT 0
|
||||
);
|
||||
```
|
||||
|
||||
#### 1.2 工作流节点表 (t_workflow_node)
|
||||
```sql
|
||||
CREATE TABLE t_workflow_node (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
uuid VARCHAR(32) NOT NULL DEFAULT '',
|
||||
workflow_id BIGINT NOT NULL DEFAULT 0,
|
||||
workflow_component_id BIGINT NOT NULL DEFAULT 0,
|
||||
user_id BIGINT NOT NULL DEFAULT 0,
|
||||
title VARCHAR(100) NOT NULL DEFAULT '',
|
||||
remark VARCHAR(500) NOT NULL DEFAULT '',
|
||||
input_config JSON NOT NULL DEFAULT ('{}'),
|
||||
node_config JSON NOT NULL DEFAULT ('{}'),
|
||||
position_x DOUBLE NOT NULL DEFAULT 0,
|
||||
position_y DOUBLE NOT NULL DEFAULT 0,
|
||||
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
is_deleted TINYINT(1) NOT NULL DEFAULT 0
|
||||
);
|
||||
```
|
||||
|
||||
#### 1.3 工作流边表 (t_workflow_edge)
|
||||
```sql
|
||||
CREATE TABLE t_workflow_edge (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
uuid VARCHAR(32) NOT NULL DEFAULT '',
|
||||
workflow_id BIGINT NOT NULL DEFAULT 0,
|
||||
source_node_uuid VARCHAR(32) NOT NULL DEFAULT '',
|
||||
source_handle VARCHAR(32) NOT NULL DEFAULT '',
|
||||
target_node_uuid VARCHAR(32) NOT NULL DEFAULT '',
|
||||
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
is_deleted TINYINT(1) NOT NULL DEFAULT 0
|
||||
);
|
||||
```
|
||||
|
||||
#### 1.4 工作流运行时表 (t_workflow_runtime)
|
||||
```sql
|
||||
CREATE TABLE t_workflow_runtime (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
uuid VARCHAR(32) NOT NULL DEFAULT '',
|
||||
user_id BIGINT NOT NULL DEFAULT 0,
|
||||
workflow_id BIGINT NOT NULL DEFAULT 0,
|
||||
input JSON NOT NULL DEFAULT ('{}'),
|
||||
output JSON NOT NULL DEFAULT ('{}'),
|
||||
status SMALLINT NOT NULL DEFAULT 1,
|
||||
status_remark VARCHAR(250) NOT NULL DEFAULT '',
|
||||
create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
is_deleted TINYINT(1) NOT NULL DEFAULT 0
|
||||
);
|
||||
```
|
||||
|
||||
#### 1.5 工作流组件表 (t_workflow_component)
|
||||
```sql
|
||||
CREATE TABLE t_workflow_component (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
uuid VARCHAR(32) DEFAULT '' NOT NULL,
|
||||
name VARCHAR(32) DEFAULT '' NOT NULL,
|
||||
title VARCHAR(100) DEFAULT '' NOT NULL,
|
||||
remark TEXT NOT NULL,
|
||||
display_order INT DEFAULT 0 NOT NULL,
|
||||
is_enable TINYINT(1) DEFAULT 0 NOT NULL,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
is_deleted TINYINT(1) DEFAULT 0 NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
## API 接口
|
||||
|
||||
### 1. 工作流管理接口
|
||||
|
||||
#### 1.1 基础操作
|
||||
```http
|
||||
# 创建工作流
|
||||
POST /workflow/add
|
||||
Content-Type: application/json
|
||||
{
|
||||
"title": "工作流标题",
|
||||
"remark": "工作流描述",
|
||||
"isPublic": false
|
||||
}
|
||||
|
||||
# 更新工作流
|
||||
POST /workflow/update
|
||||
Content-Type: application/json
|
||||
{
|
||||
"uuid": "工作流UUID",
|
||||
"title": "新标题",
|
||||
"remark": "新描述"
|
||||
}
|
||||
|
||||
# 删除工作流
|
||||
POST /workflow/del/{uuid}
|
||||
|
||||
# 启用/禁用工作流
|
||||
POST /workflow/enable/{uuid}?enable=true
|
||||
```
|
||||
|
||||
#### 1.2 搜索和查询
|
||||
```http
|
||||
# 搜索我的工作流
|
||||
GET /workflow/mine/search?keyword=关键词&isPublic=true¤tPage=1&pageSize=10
|
||||
|
||||
# 搜索公开工作流
|
||||
GET /workflow/public/search?keyword=关键词¤tPage=1&pageSize=10
|
||||
|
||||
# 获取工作流组件列表
|
||||
GET /workflow/public/component/list
|
||||
```
|
||||
|
||||
### 2. 工作流执行接口
|
||||
|
||||
#### 2.1 流式执行
|
||||
```http
|
||||
# 流式执行工作流
|
||||
POST /workflow/run
|
||||
Content-Type: application/json
|
||||
Accept: text/event-stream
|
||||
{
|
||||
"uuid": "工作流UUID",
|
||||
"inputs": [
|
||||
{
|
||||
"name": "input",
|
||||
"content": {
|
||||
"type": 1,
|
||||
"textContent": "用户输入内容"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### 2.2 运行时管理
|
||||
```http
|
||||
# 恢复中断的工作流
|
||||
POST /workflow/runtime/resume/{runtimeUuid}
|
||||
Content-Type: application/json
|
||||
{
|
||||
"feedbackContent": "用户反馈内容"
|
||||
}
|
||||
|
||||
# 查询工作流执行历史
|
||||
GET /workflow/runtime/page?wfUuid=工作流UUID¤tPage=1&pageSize=10
|
||||
|
||||
# 查询运行时节点详情
|
||||
GET /workflow/runtime/nodes/{runtimeUuid}
|
||||
|
||||
# 清理运行时数据
|
||||
POST /workflow/runtime/clear?wfUuid=工作流UUID
|
||||
```
|
||||
|
||||
### 3. 管理端接口
|
||||
|
||||
#### 3.1 工作流管理
|
||||
```http
|
||||
# 搜索所有工作流
|
||||
POST /admin/workflow/search
|
||||
Content-Type: application/json
|
||||
{
|
||||
"title": "搜索关键词",
|
||||
"isPublic": true,
|
||||
"isEnable": true
|
||||
}
|
||||
|
||||
# 启用/禁用工作流
|
||||
POST /admin/workflow/enable?uuid=工作流UUID&isEnable=true
|
||||
```
|
||||
|
||||
## 核心实现
|
||||
|
||||
### 1. 工作流引擎 (WorkflowEngine)
|
||||
|
||||
工作流引擎是整个模块的核心,负责:
|
||||
- 工作流图的构建和编译
|
||||
- 节点执行调度
|
||||
- 状态管理和持久化
|
||||
- 流式输出处理
|
||||
|
||||
```java
|
||||
public class WorkflowEngine {
|
||||
// 核心执行方法
|
||||
public void run(User user, List<ObjectNode> userInputs, SseEmitter sseEmitter) {
|
||||
// 1. 验证工作流状态
|
||||
// 2. 创建运行时实例
|
||||
// 3. 构建状态图
|
||||
// 4. 执行工作流
|
||||
// 5. 处理流式输出
|
||||
}
|
||||
|
||||
// 恢复执行方法
|
||||
public void resume(String userInput) {
|
||||
// 1. 更新状态
|
||||
// 2. 继续执行
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 节点工厂 (WfNodeFactory)
|
||||
|
||||
节点工厂负责根据组件类型创建对应的节点实例:
|
||||
|
||||
```java
|
||||
public class WfNodeFactory {
|
||||
public static AbstractWfNode create(WorkflowComponent component,
|
||||
WorkflowNode node,
|
||||
WfState wfState,
|
||||
WfNodeState nodeState) {
|
||||
// 根据组件类型创建对应的节点实例
|
||||
switch (component.getName()) {
|
||||
case "Answer":
|
||||
return new LLMAnswerNode(component, node, wfState, nodeState);
|
||||
case "Switcher":
|
||||
return new SwitcherNode(component, node, wfState, nodeState);
|
||||
// ... 其他节点类型
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 图构建器 (WorkflowGraphBuilder)
|
||||
|
||||
图构建器负责将工作流定义转换为可执行的状态图:
|
||||
|
||||
```java
|
||||
public class WorkflowGraphBuilder {
|
||||
public StateGraph<WfNodeState> build(WorkflowNode startNode) {
|
||||
// 1. 构建编译节点树
|
||||
// 2. 转换为状态图
|
||||
// 3. 添加节点和边
|
||||
// 4. 处理条件分支
|
||||
// 5. 处理并行执行
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 流式响应机制
|
||||
|
||||
### 1. SSE 事件类型
|
||||
|
||||
工作流执行过程中会发送多种类型的 SSE 事件:
|
||||
|
||||
```javascript
|
||||
// 节点开始执行
|
||||
[NODE_RUN_节点UUID] - 节点执行开始事件
|
||||
|
||||
// 节点输入数据
|
||||
[NODE_INPUT_节点UUID] - 节点输入数据事件
|
||||
|
||||
// 节点输出数据
|
||||
[NODE_OUTPUT_节点UUID] - 节点输出数据事件
|
||||
|
||||
// 流式内容块
|
||||
[NODE_CHUNK_节点UUID] - 流式内容块事件
|
||||
|
||||
// 等待用户输入
|
||||
[NODE_WAIT_FEEDBACK_BY_节点UUID] - 等待用户输入事件
|
||||
```
|
||||
|
||||
### 2. 流式处理流程
|
||||
|
||||
1. **初始化**: 创建工作流运行时实例
|
||||
2. **节点执行**: 逐个执行工作流节点
|
||||
3. **实时输出**: 通过 SSE 实时推送执行结果
|
||||
4. **状态更新**: 实时更新节点和工作流状态
|
||||
5. **错误处理**: 捕获并处理执行过程中的错误
|
||||
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 1. 自定义节点开发
|
||||
|
||||
要开发自定义工作流节点,需要:
|
||||
|
||||
1. **创建节点类**:继承 `AbstractWfNode`
|
||||
2. **实现处理逻辑**:重写 `onProcess()` 方法
|
||||
3. **定义配置类**:创建节点配置类
|
||||
4. **注册组件**:在组件表中注册新组件
|
||||
|
||||
```java
|
||||
public class CustomNode extends AbstractWfNode {
|
||||
@Override
|
||||
protected NodeProcessResult onProcess() {
|
||||
// 实现自定义处理逻辑
|
||||
List<NodeIOData> outputs = new ArrayList<>();
|
||||
// ... 处理逻辑
|
||||
return NodeProcessResult.success(outputs);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 自定义组件注册
|
||||
|
||||
```sql
|
||||
-- 在 t_workflow_component 表中添加新组件
|
||||
INSERT INTO t_workflow_component (uuid, name, title, remark, is_enable)
|
||||
VALUES (REPLACE(UUID(), '-', ''), 'CustomNode', '自定义节点', '自定义节点描述', true);
|
||||
```
|
||||
24
pom.xml
24
pom.xml
@@ -270,13 +270,6 @@
|
||||
<version>${lock4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- xxl-job-core -->
|
||||
<dependency>
|
||||
<groupId>com.xuxueli</groupId>
|
||||
<artifactId>xxl-job-core</artifactId>
|
||||
<version>${xxl-job.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>transmittable-thread-local</artifactId>
|
||||
@@ -373,6 +366,23 @@
|
||||
<artifactId>langchain4j-community-neo4j</artifactId>
|
||||
<version>${langchain4j-neo4j.version}</version>
|
||||
</dependency>
|
||||
<artifactId>ruoyi-aihuman</artifactId>
|
||||
<version>${revision}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-workflow</artifactId>
|
||||
<version>${revision}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-workflow-api</artifactId>
|
||||
<version>${revision}</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
@@ -61,6 +61,14 @@
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-graph</artifactId>
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-workflow</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-aihuman</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
@@ -11,7 +11,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
|
||||
*
|
||||
* @author Lion Li
|
||||
*/
|
||||
@SpringBootApplication
|
||||
@SpringBootApplication(scanBasePackages = {"org.ruoyi", "org.ruoyi.aihuman"})
|
||||
@EnableScheduling
|
||||
@EnableAsync
|
||||
public class RuoYiAIApplication {
|
||||
@@ -22,4 +22,4 @@ public class RuoYiAIApplication {
|
||||
application.run(args);
|
||||
System.out.println("(♥◠‿◠)ノ゙ RuoYiAI启动成功 ლ(´ڡ`ლ)゙");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,8 @@ spring:
|
||||
connectionTestQuery: SELECT 1
|
||||
# 多久检查一次连接的活性
|
||||
keepaliveTime: 30000
|
||||
mail:
|
||||
username: xx
|
||||
|
||||
--- # redis 单机配置(单机与集群只能开启一个另一个需要注释掉)
|
||||
spring.data:
|
||||
@@ -102,7 +104,15 @@ pdf:
|
||||
#百炼模型配置
|
||||
dashscope:
|
||||
key: sk-xxxx
|
||||
model: qvq-max
|
||||
|
||||
local:
|
||||
images: xx
|
||||
|
||||
|
||||
|
||||
files: xx
|
||||
|
||||
|
||||
|
||||
--- # Neo4j 知识图谱配置
|
||||
neo4j:
|
||||
|
||||
@@ -156,6 +156,8 @@ security:
|
||||
# actuator 监控配置
|
||||
- /actuator
|
||||
- /actuator/**
|
||||
- /workflow/**
|
||||
- /admin/workflow/**
|
||||
# 多租户配置
|
||||
tenant:
|
||||
# 是否开启
|
||||
@@ -328,3 +330,19 @@ spring:
|
||||
servers-configuration: classpath:mcp-server.json
|
||||
request-timeout: 300s
|
||||
|
||||
# 向量库配置
|
||||
vector-store:
|
||||
# 向量存储类型 可选(weaviate/milvus)
|
||||
# 如需修改向量库类型,请修改此配置值!
|
||||
type: weaviate
|
||||
|
||||
# Weaviate配置
|
||||
weaviate:
|
||||
protocol: http
|
||||
host: 127.0.0.1:6038
|
||||
classname: LocalKnowledge
|
||||
# Milvus配置
|
||||
milvus:
|
||||
url: http://localhost:19530
|
||||
collectionname: LocalKnowledge
|
||||
|
||||
|
||||
@@ -26,11 +26,18 @@ public class ChatRequest {
|
||||
*/
|
||||
private String prompt;
|
||||
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
*/
|
||||
private String sysPrompt;
|
||||
|
||||
|
||||
/**
|
||||
* 消息id
|
||||
*/
|
||||
private Long messageId;
|
||||
|
||||
/**
|
||||
* 是否开启流式对话
|
||||
*/
|
||||
@@ -72,6 +79,11 @@ public class ChatRequest {
|
||||
*/
|
||||
private Boolean hasAttachment;
|
||||
|
||||
/**
|
||||
* 是否启用深度思考
|
||||
*/
|
||||
private Boolean enableThinking;
|
||||
|
||||
/**
|
||||
* 是否自动切换模型
|
||||
*/
|
||||
@@ -82,9 +94,4 @@ public class ChatRequest {
|
||||
*/
|
||||
private String token;
|
||||
|
||||
/**
|
||||
* 消息ID(保存消息成功后设置,用于后续扣费更新)
|
||||
*/
|
||||
private Long messageId;
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package org.ruoyi.common.core.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 向量库配置属性
|
||||
*
|
||||
* @author ageer
|
||||
*/
|
||||
@Data
|
||||
@Component
|
||||
@ConfigurationProperties(prefix = "vector-store")
|
||||
public class VectorStoreProperties {
|
||||
|
||||
/**
|
||||
* 向量库类型
|
||||
*/
|
||||
private String type;
|
||||
|
||||
/**
|
||||
* Weaviate配置
|
||||
*/
|
||||
private Weaviate weaviate = new Weaviate();
|
||||
|
||||
/**
|
||||
* Milvus配置
|
||||
*/
|
||||
private Milvus milvus = new Milvus();
|
||||
|
||||
@Data
|
||||
public static class Weaviate {
|
||||
/**
|
||||
* 协议
|
||||
*/
|
||||
private String protocol;
|
||||
|
||||
/**
|
||||
* 主机地址
|
||||
*/
|
||||
private String host;
|
||||
|
||||
/**
|
||||
* 类名
|
||||
*/
|
||||
private String classname;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Milvus {
|
||||
/**
|
||||
* 连接URL
|
||||
*/
|
||||
private String url;
|
||||
|
||||
/**
|
||||
* 集合名称
|
||||
*/
|
||||
private String collectionname;
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@
|
||||
<module>ruoyi-chat-api</module>
|
||||
<module>ruoyi-knowledge-api</module>
|
||||
<module>ruoyi-system-api</module>
|
||||
<module>ruoyi-workflow-api</module>
|
||||
</modules>
|
||||
|
||||
<properties>
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
<maven.compiler.target>17</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<spring-ai.version>1.0.0</spring-ai.version>
|
||||
<spring-ai.version>1.0.0-M7</spring-ai.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ruoyi.domain;
|
||||
|
||||
|
||||
import com.alibaba.excel.annotation.ExcelProperty;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
@@ -81,6 +82,11 @@ public class ChatModel extends BaseEntity {
|
||||
*/
|
||||
private Integer priority;
|
||||
|
||||
/**
|
||||
* 模型供应商
|
||||
*/
|
||||
private String ProviderName;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package org.ruoyi.domain;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.ruoyi.annotation.DataColumn;
|
||||
import org.ruoyi.core.domain.BaseEntity;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
/**
|
||||
* MCP对象 mcp_info
|
||||
*
|
||||
* @author ageerle
|
||||
* @date Sat Aug 09 16:50:58 CST 2025
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@TableName("mcp_info")
|
||||
public class McpInfo extends BaseEntity {
|
||||
|
||||
|
||||
/**
|
||||
* id
|
||||
*/
|
||||
@TableId(value = "mcp_id", type = IdType.AUTO)
|
||||
private Integer mcpId;
|
||||
|
||||
/**
|
||||
* 服务器名称
|
||||
*/
|
||||
private String serverName;
|
||||
|
||||
/**
|
||||
* 链接方式
|
||||
*/
|
||||
|
||||
private String transportType;
|
||||
|
||||
/**
|
||||
* Command
|
||||
*/
|
||||
private String command;
|
||||
|
||||
/**
|
||||
* Args
|
||||
*/
|
||||
private String arguments;
|
||||
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* Env
|
||||
*/
|
||||
private String env;
|
||||
|
||||
/**
|
||||
* 是否启用
|
||||
*/
|
||||
private Boolean status;
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package org.ruoyi.domain.bo;
|
||||
|
||||
import com.alibaba.excel.annotation.ExcelProperty;
|
||||
import io.github.linpeilie.annotations.AutoMapper;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
@@ -85,6 +86,10 @@ public class ChatModelBo extends BaseEntity {
|
||||
@NotBlank(message = "密钥不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* 模型供应商
|
||||
*/
|
||||
private String ProviderName;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package org.ruoyi.domain.bo;
|
||||
|
||||
import io.github.linpeilie.annotations.AutoMapper;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
import org.ruoyi.domain.McpInfo;
|
||||
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* MCP业务对象 mcp_info
|
||||
*
|
||||
* @author ageerle
|
||||
* @date Sat Aug 09 16:50:58 CST 2025
|
||||
*/
|
||||
@Data
|
||||
|
||||
@AutoMapper(target = McpInfo.class, reverseConvertGenerate = false)
|
||||
public class McpInfoBo implements Serializable {
|
||||
|
||||
/**
|
||||
* id
|
||||
*/
|
||||
@NotNull(message = "id不能为空" )
|
||||
private Integer mcpId;
|
||||
|
||||
/**
|
||||
* 服务器名称
|
||||
*/
|
||||
private String serverName;
|
||||
|
||||
/**
|
||||
* 链接方式
|
||||
*/
|
||||
private String transportType;
|
||||
|
||||
/**
|
||||
* Command
|
||||
*/
|
||||
private String command;
|
||||
|
||||
/**
|
||||
* Args
|
||||
*/
|
||||
private String arguments;
|
||||
private String description;
|
||||
/**
|
||||
* Env
|
||||
*/
|
||||
private String env;
|
||||
|
||||
/**
|
||||
* 是否启用
|
||||
*/
|
||||
private Boolean status;
|
||||
|
||||
|
||||
}
|
||||
@@ -70,6 +70,11 @@ public class ChatModelVo implements Serializable {
|
||||
@ExcelProperty(value = "是否显示")
|
||||
private String modelShow;
|
||||
|
||||
/**
|
||||
* 模型维度
|
||||
*/
|
||||
private Integer dimension;
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
*/
|
||||
@@ -95,6 +100,12 @@ public class ChatModelVo implements Serializable {
|
||||
@ExcelProperty(value = "优先级")
|
||||
private Integer priority;
|
||||
|
||||
/**
|
||||
* 模型供应商
|
||||
*/
|
||||
@ExcelProperty(value = "模型供应商")
|
||||
private String ProviderName;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package org.ruoyi.domain.vo;
|
||||
|
||||
import com.alibaba.excel.annotation.ExcelIgnoreUnannotated;
|
||||
import com.alibaba.excel.annotation.ExcelProperty;
|
||||
import io.github.linpeilie.annotations.AutoMapper;
|
||||
import lombok.Data;
|
||||
import org.ruoyi.common.excel.annotation.ExcelDictFormat;
|
||||
import org.ruoyi.common.excel.convert.ExcelDictConvert;
|
||||
import org.ruoyi.domain.McpInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
|
||||
/**
|
||||
* MCP视图对象 mcp_info
|
||||
*
|
||||
* @author jiyi
|
||||
* @date Sat Aug 09 16:50:58 CST 2025
|
||||
*/
|
||||
@Data
|
||||
@ExcelIgnoreUnannotated
|
||||
@AutoMapper(target = McpInfo.class)
|
||||
public class McpInfoVo implements Serializable {
|
||||
private Integer mcpId;
|
||||
|
||||
/**
|
||||
* 服务器名称
|
||||
*/
|
||||
@ExcelProperty(value = "服务器名称")
|
||||
private String serverName;
|
||||
|
||||
/**
|
||||
* 链接方式
|
||||
*/
|
||||
@ExcelProperty(value = "链接方式", converter = ExcelDictConvert.class)
|
||||
@ExcelDictFormat(dictType = "mcp_transport_type")
|
||||
private String transportType;
|
||||
|
||||
/**
|
||||
* Command
|
||||
*/
|
||||
@ExcelProperty(value = "Command")
|
||||
private String command;
|
||||
|
||||
/**
|
||||
* Args
|
||||
*/
|
||||
@ExcelProperty(value = "Args")
|
||||
private String arguments;
|
||||
@ExcelProperty(value = "Description")
|
||||
private String description;
|
||||
/**
|
||||
* Env
|
||||
*/
|
||||
@ExcelProperty(value = "Env")
|
||||
private String env;
|
||||
|
||||
/**
|
||||
* 是否启用
|
||||
*/
|
||||
@ExcelProperty(value = "是否启用")
|
||||
private Boolean status;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package org.ruoyi.mapper;
|
||||
|
||||
|
||||
import org.apache.ibatis.annotations.*;
|
||||
import org.ruoyi.core.mapper.BaseMapperPlus;
|
||||
import org.ruoyi.domain.McpInfo;
|
||||
import org.ruoyi.domain.vo.McpInfoVo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* MCPMapper接口
|
||||
*
|
||||
* @author jiuyi
|
||||
* @date Sat Aug 09 16:50:58 CST 2025
|
||||
*/
|
||||
@Mapper
|
||||
public interface McpInfoMapper extends BaseMapperPlus<McpInfo, McpInfoVo> {
|
||||
@Select("SELECT * FROM mcp_info WHERE server_name = #{serverName}")
|
||||
McpInfo selectByServerName(@Param("serverName") String serverName);
|
||||
|
||||
@Select("SELECT * FROM mcp_info WHERE status = 1")
|
||||
List<McpInfo> selectActiveServers();
|
||||
|
||||
@Select("SELECT server_name FROM mcp_info WHERE status = 1")
|
||||
List<String> selectActiveServerNames();
|
||||
|
||||
@Update("UPDATE mcp_info SET status = #{status} WHERE server_name = #{serverName}")
|
||||
int updateActiveStatus(@Param("serverName") String serverName, @Param("status") Boolean status);
|
||||
|
||||
@Delete("DELETE FROM mcp_info WHERE server_name = #{serverName}")
|
||||
int deleteByServerName(@Param("serverName") String serverName);
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<!DOCTYPE mapper
|
||||
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="org.ruoyi.mapper.McpInfoMapper">
|
||||
|
||||
</mapper>
|
||||
@@ -74,6 +74,18 @@
|
||||
<version>1.19.6</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.milvus</groupId>
|
||||
<artifactId>milvus-sdk-java</artifactId>
|
||||
<version>2.6.4</version>
|
||||
</dependency>
|
||||
|
||||
<!-- LangChain4j Milvus Embedding Store -->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-milvus</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
@@ -101,11 +113,10 @@
|
||||
<artifactId>commons-compress</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>dashscope-sdk-java</artifactId>
|
||||
<version>2.19.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-chat-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
|
||||
@@ -83,6 +83,11 @@ public class KnowledgeInfo extends BaseEntity {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -92,7 +92,11 @@ public class KnowledgeInfoBo extends BaseEntity {
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
private String embeddingModelName;
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,12 @@ public class QueryVectorBo {
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
* 向量化模型ID
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型ID
|
||||
*/
|
||||
private String embeddingModelName;
|
||||
|
||||
|
||||
@@ -32,9 +32,14 @@ public class StoreEmbeddingBo {
|
||||
private List<String> fids;
|
||||
|
||||
/**
|
||||
* 向量库模型名称
|
||||
* 向量库名称
|
||||
*/
|
||||
private String vectorModelName;
|
||||
private String vectorStoreName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
|
||||
@@ -101,6 +101,11 @@ public class KnowledgeInfoVo implements Serializable {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package org.ruoyi.embedding;
|
||||
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.embedding.model.ModalityType;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* BaseEmbedModelService 接口,扩展了 EmbeddingModel 接口
|
||||
* 该接口定义了嵌入模型服务的基本配置和功能方法
|
||||
*/
|
||||
public interface BaseEmbedModelService extends EmbeddingModel {
|
||||
/**
|
||||
* 根据配置信息配置嵌入模型
|
||||
* @param config 包含模型配置信息的 ChatModelVo 对象
|
||||
*/
|
||||
void configure(ChatModelVo config);
|
||||
|
||||
/**
|
||||
* 获取当前嵌入模型支持的所有模态类型
|
||||
* @return 返回支持的模态类型集合
|
||||
*/
|
||||
Set<ModalityType> getSupportedModalities();
|
||||
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package org.ruoyi.embedding;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* 嵌入模型工厂服务类
|
||||
* 负责创建和管理各种嵌入模型实例
|
||||
*/
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class EmbeddingModelFactory {
|
||||
|
||||
private final ApplicationContext applicationContext;
|
||||
|
||||
private final IChatModelService chatModelService;
|
||||
|
||||
// 模型缓存,使用ConcurrentHashMap保证线程安全
|
||||
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* 创建嵌入模型实例
|
||||
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
|
||||
*
|
||||
* @param embeddingModelName 嵌入模型名称
|
||||
* @param dimension 模型维度大小
|
||||
*/
|
||||
public BaseEmbedModelService createModel(String embeddingModelName, Integer dimension) {
|
||||
return modelCache.computeIfAbsent(embeddingModelName, name -> {
|
||||
ChatModelVo modelConfig = chatModelService.selectModelByName(embeddingModelName);
|
||||
if (modelConfig == null) {
|
||||
throw new IllegalArgumentException("未找到模型配置,name=" + name);
|
||||
}
|
||||
if (modelConfig.getDimension() != null) {
|
||||
modelConfig.setDimension(dimension);
|
||||
}
|
||||
return createModelInstance(modelConfig.getProviderName(), modelConfig);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否支持多模态
|
||||
*
|
||||
* @param embeddingModelName 嵌入模型名称
|
||||
* @return boolean 如果模型支持多模态则返回true,否则返回false
|
||||
*/
|
||||
public boolean isMultimodalModel(String embeddingModelName) {
|
||||
return createModel(embeddingModelName, null) instanceof MultiModalEmbedModelService;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建多模态嵌入模型实例
|
||||
*
|
||||
* @param embeddingModelName 嵌入模型名称
|
||||
* @return MultiModalEmbedModelService 多模态嵌入模型服务实例
|
||||
* @throws IllegalArgumentException 当模型不支持多模态时抛出
|
||||
*/
|
||||
public MultiModalEmbedModelService createMultimodalModel(String embeddingModelName) {
|
||||
BaseEmbedModelService model = createModel(embeddingModelName, null);
|
||||
if (model instanceof MultiModalEmbedModelService) {
|
||||
return (MultiModalEmbedModelService) model;
|
||||
}
|
||||
throw new IllegalArgumentException("该模型不支持多模态");
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新模型缓存
|
||||
* 根据给定的嵌入模型ID从缓存中移除对应的模型
|
||||
*
|
||||
* @param embeddingModelId 嵌入模型的唯一标识ID
|
||||
*/
|
||||
public void refreshModel(Long embeddingModelId) {
|
||||
// 从模型缓存中移除指定ID的模型
|
||||
modelCache.remove(embeddingModelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有支持模型工厂的列表
|
||||
*
|
||||
* @return List<String> 支持的模型工厂名称列表
|
||||
*/
|
||||
public List<String> getSupportedFactories() {
|
||||
return new ArrayList<>(applicationContext.getBeansOfType(BaseEmbedModelService.class)
|
||||
.keySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建具体的模型实例
|
||||
* 根据提供的工厂名称和配置信息创建并配置模型实例
|
||||
*
|
||||
* @param factory 工厂名称,用于标识模型类型
|
||||
* @param config 模型配置信息
|
||||
* @return BaseEmbedModelService 配置好的模型实例
|
||||
* @throws IllegalArgumentException 当无法获取指定的模型实例时抛出
|
||||
*/
|
||||
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
|
||||
try {
|
||||
// 从Spring上下文中获取模型实例
|
||||
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class);
|
||||
// 配置模型参数
|
||||
model.configure(config);
|
||||
log.info("成功创建嵌入模型: factory={}, modelId={}", config.getProviderName(), config.getId());
|
||||
|
||||
return model;
|
||||
} catch (NoSuchBeanDefinitionException e) {
|
||||
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package org.ruoyi.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.ruoyi.embedding.model.MultiModalInput;
|
||||
|
||||
|
||||
/**
|
||||
* 多模态嵌入模型服务接口,继承自基础嵌入模型服务
|
||||
* 该接口提供了处理图像、视频以及多模态数据并转换为嵌入向量的功能
|
||||
*/
|
||||
public interface MultiModalEmbedModelService extends BaseEmbedModelService {
|
||||
/**
|
||||
* 将图像数据转换为嵌入向量
|
||||
* @param imageDataUrl 图像的地址,必须是公开可访问的URL
|
||||
* @return 包含嵌入向量的响应对象,可能包含状态信息和嵌入结果
|
||||
*/
|
||||
Response<Embedding> embedImage(String imageDataUrl);
|
||||
|
||||
/**
|
||||
* 将视频数据转换为嵌入向量
|
||||
* @param videoDataUrl 视频的地址,必须是公开可访问的URL
|
||||
* @return 包含嵌入向量的响应对象,可能包含状态信息和嵌入结果
|
||||
*/
|
||||
Response<Embedding> embedVideo(String videoDataUrl);
|
||||
|
||||
|
||||
/**
|
||||
* 处理多模态输入并返回嵌入向量的方法
|
||||
*
|
||||
* @param input 包含多种模态信息(如图像、文本等)的输入对象
|
||||
* @return Response<Embedding> 包含嵌入向量的响应对象,Embedding通常表示输入数据的向量表示
|
||||
*/
|
||||
Response<Embedding> embedMultiModal(MultiModalInput input);
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package org.ruoyi.embedding.impl;
|
||||
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-09-30-下午3:00
|
||||
* @Description: 阿里百炼基础嵌入模型(兼容openai)
|
||||
*/
|
||||
@Component("alibailian")
|
||||
public class AliBaiLianBaseEmbedProvider extends OpenAiEmbeddingProvider{
|
||||
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
package org.ruoyi.embedding.impl;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.embedding.MultiModalEmbedModelService;
|
||||
import org.ruoyi.embedding.model.AliyunMultiModalEmbedRequest;
|
||||
import org.ruoyi.embedding.model.AliyunMultiModalEmbedResponse;
|
||||
import org.ruoyi.embedding.model.ModalityType;
|
||||
import org.ruoyi.embedding.model.MultiModalInput;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* 阿里云百炼多模态嵌入模型服务实现类
|
||||
* 实现了MultiModalEmbedModelService接口,提供文本、图像和视频的嵌入向量生成服务
|
||||
*/
|
||||
@Component("bailianMultiModel")
|
||||
@Slf4j
|
||||
public class AliBaiLianMultiEmbeddingProvider implements MultiModalEmbedModelService {
|
||||
private ChatModelVo chatModelVo;
|
||||
|
||||
private final OkHttpClient okHttpClient;
|
||||
|
||||
/**
|
||||
* 构造函数,初始化HTTP客户端
|
||||
* 设置连接超时、读取超时和写入超时时间
|
||||
*/
|
||||
public AliBaiLianMultiEmbeddingProvider() {
|
||||
this.okHttpClient = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.readTimeout(60, TimeUnit.SECONDS)
|
||||
.writeTimeout(30, TimeUnit.SECONDS)
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 图像嵌入向量生成
|
||||
* @param imageDataUrl 图像数据的URL
|
||||
* @return 包含图像嵌入向量的Response对象
|
||||
*/
|
||||
@Override
|
||||
public Response<Embedding> embedImage(String imageDataUrl) {
|
||||
return embedSingleModality("image", imageDataUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* 视频嵌入向量生成
|
||||
* @param videoDataUrl 视频数据的URL
|
||||
* @return 包含视频嵌入向量的Response对象
|
||||
*/
|
||||
@Override
|
||||
public Response<Embedding> embedVideo(String videoDataUrl) {
|
||||
return embedSingleModality("video", videoDataUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* 多模态嵌入向量生成
|
||||
* 支持同时处理文本、图像和视频等多种模态的数据
|
||||
* @param input 包含多种模态输入的对象
|
||||
* @return 包含多模态嵌入向量的Response对象
|
||||
*/
|
||||
@Override
|
||||
public Response<Embedding> embedMultiModal(MultiModalInput input) {
|
||||
try {
|
||||
// 构建请求内容
|
||||
List<Map<String, Object>> contents = buildContentMap(input);
|
||||
if (contents.isEmpty()) {
|
||||
throw new IllegalArgumentException("至少提供一种模态的内容");
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
AliyunMultiModalEmbedRequest request = buildRequest(contents, chatModelVo);
|
||||
AliyunMultiModalEmbedResponse resp = executeRequest(request, chatModelVo);
|
||||
|
||||
// 转换为 embeddings
|
||||
Response<List<Embedding>> response = toEmbeddings(resp);
|
||||
List<Embedding> embeddings = response.content();
|
||||
|
||||
if (embeddings.isEmpty()) {
|
||||
log.warn("阿里云混合模态嵌入返回为空");
|
||||
return Response.from(Embedding.from(new float[0]), response.tokenUsage());
|
||||
}
|
||||
|
||||
// 多模态通常取第一个向量作为代表,也可以根据业务场景返回多个
|
||||
return Response.from(embeddings.get(0), response.tokenUsage());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("阿里云混合模态嵌入失败", e);
|
||||
throw new IllegalArgumentException("阿里云混合模态嵌入失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 配置模型参数
|
||||
* @param config 模型配置信息
|
||||
*/
|
||||
@Override
|
||||
public void configure(ChatModelVo config) {
|
||||
this.chatModelVo = config;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取支持的模态类型
|
||||
* @return 支持的模态类型集合
|
||||
*/
|
||||
@Override
|
||||
public Set<ModalityType> getSupportedModalities() {
|
||||
return Set.of(ModalityType.TEXT, ModalityType.VIDEO, ModalityType.IMAGE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量文本嵌入向量生成
|
||||
* @param textSegments 文本段列表
|
||||
* @return 包含所有文本嵌入向量的Response对象
|
||||
*/
|
||||
@Override
|
||||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
|
||||
if (textSegments.isEmpty()) return Response.from(Collections.emptyList());
|
||||
|
||||
try {
|
||||
List<Map<String, Object>> contents = new ArrayList<>();
|
||||
for (TextSegment segment : textSegments) {
|
||||
contents.add(Map.of("text", segment.text()));
|
||||
}
|
||||
|
||||
AliyunMultiModalEmbedRequest request = buildRequest(contents, chatModelVo);
|
||||
AliyunMultiModalEmbedResponse resp = executeRequest(request, chatModelVo);
|
||||
|
||||
return toEmbeddings(resp);
|
||||
} catch (Exception e) {
|
||||
log.error("阿里云文本嵌入失败", e);
|
||||
throw new IllegalArgumentException("阿里云文本嵌入失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 单模态嵌入(图片/视频/单条文本)复用方法
|
||||
* @param key 模态类型(image/video/text)
|
||||
* @param dataUrl 数据URL
|
||||
* @return 包含嵌入向量的Response对象
|
||||
*/
|
||||
|
||||
public Response<Embedding> embedSingleModality(String key, String dataUrl) {
|
||||
try {
|
||||
AliyunMultiModalEmbedRequest request = buildRequest(List.of(Map.of(key, dataUrl)), chatModelVo);
|
||||
AliyunMultiModalEmbedResponse resp = executeRequest(request, chatModelVo);
|
||||
|
||||
Response<List<Embedding>> response = toEmbeddings(resp);
|
||||
List<Embedding> embeddings = response.content();
|
||||
|
||||
if (embeddings.isEmpty()) {
|
||||
log.warn("阿里云 {} 嵌入返回为空", key);
|
||||
return Response.from(Embedding.from(new float[0]), response.tokenUsage());
|
||||
}
|
||||
|
||||
return Response.from(embeddings.get(0), response.tokenUsage());
|
||||
} catch (Exception e) {
|
||||
log.error("阿里云 {} 嵌入失败", key, e);
|
||||
throw new IllegalArgumentException("阿里云 " + key + " 嵌入失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建请求对象
|
||||
* @param contents 请求内容列表
|
||||
* @param chatModelVo 模型配置信息
|
||||
* @return 构建好的请求对象
|
||||
*/
|
||||
private AliyunMultiModalEmbedRequest buildRequest(List<Map<String, Object>> contents, ChatModelVo chatModelVo) {
|
||||
if (contents.isEmpty()) throw new IllegalArgumentException("请求内容不能为空");
|
||||
return AliyunMultiModalEmbedRequest.create(chatModelVo.getModelName(), contents);
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 HTTP 请求并解析响应
|
||||
* @param request 请求对象
|
||||
* @param chatModelVo 模型配置信息
|
||||
* @return API响应对象
|
||||
* @throws IOException IO异常
|
||||
*/
|
||||
private AliyunMultiModalEmbedResponse executeRequest(AliyunMultiModalEmbedRequest request, ChatModelVo chatModelVo) throws IOException {
|
||||
String jsonBody = request.toJson();
|
||||
RequestBody body = RequestBody.create(jsonBody, MediaType.get("application/json"));
|
||||
|
||||
Request httpRequest = new Request.Builder()
|
||||
.url(chatModelVo.getApiHost())
|
||||
.addHeader("Authorization", "Bearer " + chatModelVo.getApiKey())
|
||||
.post(body)
|
||||
.build();
|
||||
|
||||
try (okhttp3.Response response = okHttpClient.newCall(httpRequest).execute()) {
|
||||
if (!response.isSuccessful()) {
|
||||
String err = response.body() != null ? response.body().string() : "无错误信息";
|
||||
throw new IllegalArgumentException("API调用失败: " + response.code() + " - " + err, null);
|
||||
}
|
||||
|
||||
ResponseBody responseBody = response.body();
|
||||
if (responseBody == null) throw new IllegalArgumentException("响应体为空", null);
|
||||
|
||||
return parseEmbeddingsFromResponse(responseBody.string());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析嵌入向量列表
|
||||
* @param responseBody API响应的JSON字符串
|
||||
* @return 嵌入向量响应对象
|
||||
* @throws IOException IO异常
|
||||
*/
|
||||
private AliyunMultiModalEmbedResponse parseEmbeddingsFromResponse(String responseBody) throws IOException {
|
||||
ObjectMapper objectMapper1 = new ObjectMapper();
|
||||
return objectMapper1.readValue(responseBody, AliyunMultiModalEmbedResponse.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 API 请求内容 Map
|
||||
* @param input 多模态输入对象
|
||||
* @return 包含各种模态内容的Map列表
|
||||
*/
|
||||
private List<Map<String, Object>> buildContentMap(MultiModalInput input) {
|
||||
List<Map<String, Object>> contents = new ArrayList<>();
|
||||
|
||||
if (input.getText() != null && !input.getText().isBlank()) {
|
||||
contents.add(Map.of("text", input.getText()));
|
||||
}
|
||||
if (input.getImageUrl() != null && !input.getImageUrl().isBlank()) {
|
||||
contents.add(Map.of("image", input.getImageUrl()));
|
||||
}
|
||||
if (input.getVideoUrl() != null && !input.getVideoUrl().isBlank()) {
|
||||
contents.add(Map.of("video", input.getVideoUrl()));
|
||||
}
|
||||
if (input.getMultiImageUrls() != null && input.getMultiImageUrls().length > 0) {
|
||||
contents.add(Map.of("multi_images", Arrays.asList(input.getMultiImageUrls())));
|
||||
}
|
||||
|
||||
return contents;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 API 原始响应解析为 LangChain4j 的 Response<Embedding>
|
||||
* @param resp API原始响应对象
|
||||
* @return 包含嵌入向量和token使用情况的Response对象
|
||||
*/
|
||||
private Response<List<Embedding>> toEmbeddings(AliyunMultiModalEmbedResponse resp) {
|
||||
if (resp == null || resp.output() == null || resp.output().embeddings() == null) {
|
||||
return Response.from(Collections.emptyList());
|
||||
}
|
||||
|
||||
// 转换 double -> float
|
||||
List<Embedding> embeddings = resp.output().embeddings().stream()
|
||||
.map(item -> {
|
||||
float[] vector = new float[item.embedding().size()];
|
||||
for (int i = 0; i < item.embedding().size(); i++) {
|
||||
vector[i] = item.embedding().get(i).floatValue();
|
||||
}
|
||||
return Embedding.from(vector);
|
||||
})
|
||||
.toList();
|
||||
|
||||
// 构建 TokenUsage
|
||||
TokenUsage tokenUsage = null;
|
||||
if (resp.usage() != null) {
|
||||
tokenUsage = new TokenUsage(
|
||||
resp.usage().input_tokens(),
|
||||
resp.usage().image_tokens(),
|
||||
resp.usage().input_tokens() +resp.usage().image_tokens()
|
||||
);
|
||||
}
|
||||
|
||||
return Response.from(embeddings, tokenUsage);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package org.ruoyi.embedding.impl;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.embedding.BaseEmbedModelService;
|
||||
import org.ruoyi.embedding.model.ModalityType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-09-30-下午3:00
|
||||
* @Description: Ollama嵌入模型
|
||||
*/
|
||||
@Component("ollama")
|
||||
public class OllamaEmbeddingProvider implements BaseEmbedModelService {
|
||||
private ChatModelVo chatModelVo;
|
||||
|
||||
@Override
|
||||
public void configure(ChatModelVo config) {
|
||||
this.chatModelVo = config;
|
||||
}
|
||||
@Override
|
||||
public Set<ModalityType> getSupportedModalities() {
|
||||
return Set.of(ModalityType.TEXT);
|
||||
}
|
||||
|
||||
// ollama不能设置embedding维度,使用milvus时请注意!!创建向量表时需要先设定维度大小
|
||||
@Override
|
||||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
|
||||
return OllamaEmbeddingModel.builder()
|
||||
.baseUrl(chatModelVo.getApiHost())
|
||||
.modelName(chatModelVo.getModelName())
|
||||
.build()
|
||||
.embedAll(textSegments);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package org.ruoyi.embedding.impl;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.embedding.BaseEmbedModelService;
|
||||
import org.ruoyi.embedding.model.ModalityType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-09-30-下午3:59
|
||||
* @Description: OpenAi嵌入模型
|
||||
*/
|
||||
@Component("openai")
|
||||
public class OpenAiEmbeddingProvider implements BaseEmbedModelService {
|
||||
protected ChatModelVo chatModelVo;
|
||||
|
||||
@Override
|
||||
public void configure(ChatModelVo config) {
|
||||
this.chatModelVo = config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<ModalityType> getSupportedModalities() {
|
||||
return Set.of(ModalityType.TEXT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
|
||||
return OpenAiEmbeddingModel.builder()
|
||||
.baseUrl(chatModelVo.getApiHost())
|
||||
.apiKey(chatModelVo.getApiKey())
|
||||
.modelName(chatModelVo.getModelName())
|
||||
.dimensions(chatModelVo.getDimension())
|
||||
.build()
|
||||
.embedAll(textSegments);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package org.ruoyi.embedding.impl;
|
||||
|
||||
|
||||
import org.ruoyi.embedding.BaseEmbedModelService;
|
||||
import org.ruoyi.embedding.model.ModalityType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-09-30-下午3:59
|
||||
* @Description: 硅基流动(兼容 OpenAi)
|
||||
*/
|
||||
@Component("siliconflow")
|
||||
public class SiliconFlowEmbeddingProvider extends OpenAiEmbeddingProvider {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package org.ruoyi.embedding.impl;
|
||||
|
||||
import dev.langchain4j.community.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.embedding.BaseEmbedModelService;
|
||||
import org.ruoyi.embedding.model.ModalityType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-09-30-下午4:02
|
||||
* @Description: 智谱AI
|
||||
*/
|
||||
@Component("zhipu")
|
||||
public class ZhiPuAiEmbeddingProvider implements BaseEmbedModelService {
|
||||
private ChatModelVo chatModelVo;
|
||||
|
||||
@Override
|
||||
public void configure(ChatModelVo config) {
|
||||
this.chatModelVo = config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<ModalityType> getSupportedModalities() {
|
||||
return Set.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
|
||||
return ZhipuAiEmbeddingModel.builder()
|
||||
.baseUrl(chatModelVo.getApiHost())
|
||||
.apiKey(chatModelVo.getApiKey())
|
||||
.model(chatModelVo.getModelName())
|
||||
.dimensions(chatModelVo.getDimension())
|
||||
.build()
|
||||
.embedAll(textSegments);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package org.ruoyi.embedding.model;
|
||||
|
||||
import org.ruoyi.common.json.utils.JsonUtils;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-10-1-上午10:00
|
||||
* @Description: 阿里云多模态嵌入请求
|
||||
*/
|
||||
@Data
|
||||
public class AliyunMultiModalEmbedRequest {
|
||||
private String model;
|
||||
private Input input;
|
||||
|
||||
/**
|
||||
* 表示输入数据的记录类(Record)
|
||||
* 该类用于封装一个包含多个映射关系列表的输入数据结构
|
||||
*
|
||||
* @param contents 包含多个Map的列表,每个Map中存储String类型的键和Object类型的值
|
||||
*/
|
||||
public record Input(List<Map<String, Object>> contents) { }
|
||||
|
||||
/**
|
||||
* 创建请求对象
|
||||
*/
|
||||
public static AliyunMultiModalEmbedRequest create(String modelName, List<Map<String, Object>> contents) {
|
||||
AliyunMultiModalEmbedRequest request = new AliyunMultiModalEmbedRequest();
|
||||
request.setModel(modelName);
|
||||
Input input = new Input(contents);
|
||||
request.setInput(input);
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换为JSON字符串
|
||||
*/
|
||||
public String toJson() {
|
||||
return JsonUtils.toJsonString(this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package org.ruoyi.embedding.model;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 阿里云多模态嵌入 API 响应数据模型
|
||||
*/
|
||||
public record AliyunMultiModalEmbedResponse(
|
||||
Output output, // 输出结果对象
|
||||
String request_id, // 请求唯一标识
|
||||
String code, // 错误码
|
||||
String message, // 错误消息
|
||||
Usage usage // 用量信息
|
||||
) {
|
||||
|
||||
/**
|
||||
* 输出对象,包含嵌入向量结果
|
||||
*/
|
||||
public record Output(
|
||||
List<EmbeddingItem> embeddings // 嵌入向量列表
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 单个嵌入向量条目
|
||||
*/
|
||||
public record EmbeddingItem(
|
||||
int index, // 输入内容的索引
|
||||
List<Double> embedding, // 生成的 1024 维向量
|
||||
String type // 输入的类型 (text/image/video/multi_images)
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 用量统计信息
|
||||
*/
|
||||
public record Usage(
|
||||
int input_tokens, // 本次请求输入的 Token 数量
|
||||
int image_tokens, // 本次请求输入的图像 Token 数量
|
||||
int image_count, // 本次请求输入的图像数量
|
||||
int duration // 本次请求输入的视频时长(秒)
|
||||
) {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package org.ruoyi.embedding.model;
|
||||
|
||||
/**
|
||||
* 模态类型
|
||||
*/
|
||||
public enum ModalityType {
|
||||
TEXT, IMAGE, AUDIO, VIDEO, MULTI
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package org.ruoyi.embedding.model;
|
||||
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* @Author: Robust_H
|
||||
* @Date: 2025-09-30-下午2:13
|
||||
* @Description: 多模态输入
|
||||
*/
|
||||
@Data
|
||||
@Builder
|
||||
public class MultiModalInput {
|
||||
private String text;
|
||||
private byte[] imageData;
|
||||
private byte[] videoData;
|
||||
private String imageMimeType;
|
||||
private String videoMimeType;
|
||||
private String[] multiImageUrls;
|
||||
private String imageUrl;
|
||||
private String videoUrl;
|
||||
|
||||
/**
|
||||
* 检查是否有文本内容
|
||||
*/
|
||||
public boolean hasText() {
|
||||
return StrUtil.isNotBlank(text);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有图片内容
|
||||
*/
|
||||
public boolean hasImage() {
|
||||
return ArrayUtil.isNotEmpty(imageData) || StrUtil.isNotBlank(imageUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有视频内容
|
||||
*/
|
||||
public boolean hasVideo() {
|
||||
return ArrayUtil.isNotEmpty(videoData) || StrUtil.isNotBlank(videoUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有多图片
|
||||
*/
|
||||
public boolean hasMultiImages() {
|
||||
return ArrayUtil.isNotEmpty(multiImageUrls);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有任何内容
|
||||
*/
|
||||
public boolean hasAnyContent() {
|
||||
return hasText() || hasImage() || hasVideo() || hasMultiImages();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取内容的数量
|
||||
*/
|
||||
public int getContentCount() {
|
||||
int count = 0;
|
||||
if (hasText()) count++;
|
||||
if (hasImage()) count++;
|
||||
if (hasVideo()) count++;
|
||||
if (hasMultiImages()) count++;
|
||||
return count;
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package org.ruoyi.service;
|
||||
|
||||
import org.ruoyi.common.core.exception.ServiceException;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
|
||||
@@ -11,15 +12,15 @@ import java.util.List;
|
||||
*/
|
||||
public interface VectorStoreService {
|
||||
|
||||
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo);
|
||||
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) throws ServiceException;
|
||||
|
||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||
|
||||
void createSchema(String kid,String modelName);
|
||||
void createSchema(String kid, String embeddingModelName);
|
||||
|
||||
void removeById(String id,String modelName);
|
||||
void removeById(String id,String modelName) throws ServiceException;
|
||||
|
||||
void removeByDocId(String docId, String kid);
|
||||
void removeByDocId(String docId, String kid) throws ServiceException;
|
||||
|
||||
void removeByFid(String fid, String kid);
|
||||
void removeByFid(String fid, String kid) throws ServiceException;
|
||||
}
|
||||
|
||||
@@ -1,37 +1,14 @@
|
||||
package org.ruoyi.service.impl;
|
||||
|
||||
import cn.hutool.json.JSONObject;
|
||||
import com.google.protobuf.ServiceException;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
||||
import io.weaviate.client.Config;
|
||||
import io.weaviate.client.WeaviateClient;
|
||||
import io.weaviate.client.base.Result;
|
||||
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
|
||||
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
|
||||
import io.weaviate.client.v1.filters.Operator;
|
||||
import io.weaviate.client.v1.filters.WhereFilter;
|
||||
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||
import io.weaviate.client.v1.schema.model.Property;
|
||||
import io.weaviate.client.v1.schema.model.Schema;
|
||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.service.strategy.VectorStoreStrategyFactory;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 向量库管理
|
||||
@@ -39,235 +16,61 @@ import java.util.stream.Collectors;
|
||||
* @author ageer
|
||||
*/
|
||||
@Service
|
||||
@Primary
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
|
||||
private final ConfigService configService;
|
||||
private final VectorStoreStrategyFactory strategyFactory;
|
||||
|
||||
// private EmbeddingStore<TextSegment> embeddingStore;
|
||||
private WeaviateClient client;
|
||||
|
||||
/**
|
||||
* 获取当前配置的向量库策略
|
||||
*/
|
||||
private VectorStoreService getCurrentStrategy() {
|
||||
return strategyFactory.getStrategy();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String modelName) {
|
||||
String protocol = configService.getConfigValue("weaviate", "protocol");
|
||||
String host = configService.getConfigValue("weaviate", "host");
|
||||
String className = configService.getConfigValue("weaviate", "classname")+kid;
|
||||
// 创建 Weaviate 客户端
|
||||
client= new WeaviateClient(new Config(protocol, host));
|
||||
// 检查类是否存在,如果不存在就创建 schema
|
||||
Result<Schema> schemaResult = client.schema().getter().run();
|
||||
Schema schema = schemaResult.getResult();
|
||||
boolean classExists = false;
|
||||
for (WeaviateClass weaviateClass : schema.getClasses()) {
|
||||
if (weaviateClass.getClassName().equals(className)) {
|
||||
classExists = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!classExists) {
|
||||
// 类不存在,创建 schema
|
||||
WeaviateClass build = WeaviateClass.builder()
|
||||
.className(className)
|
||||
.vectorizer("none")
|
||||
.properties(
|
||||
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
|
||||
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
|
||||
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
|
||||
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
|
||||
)
|
||||
.build();
|
||||
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
|
||||
if (createResult.hasErrors()) {
|
||||
log.error("Schema 创建失败: {}", createResult.getError());
|
||||
} else {
|
||||
log.info("Schema 创建成功: {}", className);
|
||||
}
|
||||
}
|
||||
// embeddingStore = WeaviateEmbeddingStore.builder()
|
||||
// .scheme(protocol)
|
||||
// .host(host)
|
||||
// .objectClass(className)
|
||||
// .scheme(protocol)
|
||||
// .avoidDups(true)
|
||||
// .consistencyLevel("ALL")
|
||||
// .build();
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
strategy.createSchema(kid, modelName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
String docId = storeEmbeddingBo.getDocId();
|
||||
log.info("向量存储条数记录: " + chunkList.size());
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < chunkList.size(); i++) {
|
||||
String text = chunkList.get(i);
|
||||
String fid = fidList.get(i);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
Map<String, Object> properties = Map.of(
|
||||
"text", text,
|
||||
"fid",fid,
|
||||
"kid", kid,
|
||||
"docId", docId
|
||||
);
|
||||
Float[] vector = toObjectArray(embedding.vector());
|
||||
client.data().creator()
|
||||
.withClassName("LocalKnowledge" + kid) // 注意替换成实际类名
|
||||
.withProperties(properties)
|
||||
.withVector(vector)
|
||||
.run();
|
||||
}
|
||||
long endTime = System.currentTimeMillis();
|
||||
log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"秒");
|
||||
log.info("存储向量数据: kid={}, docId={}, 数据条数={}",
|
||||
storeEmbeddingBo.getKid(), storeEmbeddingBo.getDocId(), storeEmbeddingBo.getChunkList().size());
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
strategy.storeEmbeddings(storeEmbeddingBo);
|
||||
}
|
||||
|
||||
private static Float[] toObjectArray(float[] primitive) {
|
||||
Float[] result = new Float[primitive.length];
|
||||
for (int i = 0; i < primitive.length; i++) {
|
||||
result[i] = primitive[i]; // 自动装箱
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
for (float v : vector) {
|
||||
vectorStrings.add(String.valueOf(v));
|
||||
}
|
||||
String vectorStr = String.join(",", vectorStrings);
|
||||
String className = configService.getConfigValue("weaviate", "classname") ;
|
||||
// 构建 GraphQL 查询
|
||||
String graphQLQuery = String.format(
|
||||
"{\n" +
|
||||
" Get {\n" +
|
||||
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
|
||||
" text\n" +
|
||||
" fid\n" +
|
||||
" kid\n" +
|
||||
" docId\n" +
|
||||
" _additional {\n" +
|
||||
" distance\n" +
|
||||
" id\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}",
|
||||
className+ queryVectorBo.getKid(),
|
||||
vectorStr,
|
||||
queryVectorBo.getMaxResults()
|
||||
);
|
||||
|
||||
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
|
||||
List<String> resultList = new ArrayList<>();
|
||||
if (result != null && !result.hasErrors()) {
|
||||
Object data = result.getResult().getData();
|
||||
JSONObject entries = new JSONObject(data);
|
||||
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
|
||||
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
|
||||
if(objects.isEmpty()){
|
||||
return resultList;
|
||||
}
|
||||
for (Object object : objects) {
|
||||
Map<String, String> map = (Map<String, String>) object;
|
||||
String content = map.get("text");
|
||||
resultList.add( content);
|
||||
}
|
||||
return resultList;
|
||||
} else {
|
||||
log.error("GraphQL 查询失败: {}", result.getError());
|
||||
return resultList;
|
||||
}
|
||||
log.info("查询向量数据: kid={}, query={}, maxResults={}",
|
||||
queryVectorBo.getKid(), queryVectorBo.getQuery(), queryVectorBo.getMaxResults());
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
return strategy.getQueryVector(queryVectorBo);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public void removeById(String id, String modelName) {
|
||||
String protocol = configService.getConfigValue("weaviate", "protocol");
|
||||
String host = configService.getConfigValue("weaviate", "host");
|
||||
String className = configService.getConfigValue("weaviate", "classname");
|
||||
String finalClassName = className + id;
|
||||
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
|
||||
Result<Boolean> result = client.schema().classDeleter().withClassName(finalClassName).run();
|
||||
if (result.hasErrors()) {
|
||||
log.error("失败删除向量: " + result.getError());
|
||||
throw new ServiceException("失败删除向量数据!");
|
||||
} else {
|
||||
log.info("成功删除向量数据: " + result.getResult());
|
||||
}
|
||||
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
strategy.removeById(id, modelName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String docId, String kid) {
|
||||
String className = configService.getConfigValue("weaviate", "classname") + kid;
|
||||
// 构建 Where 条件
|
||||
WhereFilter whereFilter = WhereFilter.builder()
|
||||
.path("docId")
|
||||
.operator(Operator.Equal)
|
||||
.valueText(docId)
|
||||
.build();
|
||||
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
|
||||
Result<BatchDeleteResponse> result = deleter.withClassName(className)
|
||||
.withWhere(whereFilter)
|
||||
.run();
|
||||
if (result != null && !result.hasErrors()) {
|
||||
log.info("成功删除 docId={} 的所有向量数据", docId);
|
||||
} else {
|
||||
log.error("删除失败: {}", result.getError());
|
||||
}
|
||||
log.info("根据docId删除向量数据: docId={}, kid={}", docId, kid);
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
strategy.removeByDocId(docId, kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByFid(String fid, String kid) {
|
||||
String className = configService.getConfigValue("weaviate", "classname") + kid;
|
||||
// 构建 Where 条件
|
||||
WhereFilter whereFilter = WhereFilter.builder()
|
||||
.path("fid")
|
||||
.operator(Operator.Equal)
|
||||
.valueText(fid)
|
||||
.build();
|
||||
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
|
||||
Result<BatchDeleteResponse> result = deleter.withClassName(className)
|
||||
.withWhere(whereFilter)
|
||||
.run();
|
||||
if (result != null && !result.hasErrors()) {
|
||||
log.info("成功删除 fid={} 的所有向量数据", fid);
|
||||
} else {
|
||||
log.error("删除失败: {}", result.getError());
|
||||
}
|
||||
log.info("根据fid删除向量数据: fid={}, kid={}", fid, kid);
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
strategy.removeByFid(fid, kid);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
@SneakyThrows
|
||||
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
|
||||
EmbeddingModel embeddingModel;
|
||||
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
|
||||
embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else if ("baai/bge-m3".equals(modelName)) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else {
|
||||
throw new ServiceException("未找到对应向量化模型!");
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package org.ruoyi.service.strategy;
|
||||
|
||||
import org.ruoyi.common.core.exception.ServiceException;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
|
||||
/**
|
||||
* 向量库策略抽象基类
|
||||
* 提供公共的方法实现,如embedding模型获取等
|
||||
*
|
||||
* @author Yzm
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public abstract class AbstractVectorStoreStrategy implements VectorStoreService {
|
||||
|
||||
protected final VectorStoreProperties vectorStoreProperties;
|
||||
|
||||
private final EmbeddingModelFactory embeddingModelFactory;
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
@SneakyThrows
|
||||
protected EmbeddingModel getEmbeddingModel(String modelName, Integer dimension) {
|
||||
return embeddingModelFactory.createModel(modelName, dimension);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将float数组转换为Float对象数组
|
||||
*/
|
||||
protected static Float[] toObjectArray(float[] primitive) {
|
||||
Float[] result = new Float[primitive.length];
|
||||
for (int i = 0; i < primitive.length; i++) {
|
||||
result[i] = primitive[i]; // 自动装箱
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量库类型标识
|
||||
*/
|
||||
public abstract String getVectorStoreType();
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package org.ruoyi.service.strategy;
|
||||
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.service.strategy.impl.MilvusVectorStoreStrategy;
|
||||
import org.ruoyi.service.strategy.impl.WeaviateVectorStoreStrategy;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 向量库策略工厂
|
||||
* 根据配置动态选择向量库实现
|
||||
*
|
||||
* @author Yzm
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class VectorStoreStrategyFactory {
|
||||
|
||||
private final VectorStoreProperties vectorStoreProperties;
|
||||
private final WeaviateVectorStoreStrategy weaviateStrategy;
|
||||
private final MilvusVectorStoreStrategy milvusStrategy;
|
||||
|
||||
private Map<String, VectorStoreService> strategies;
|
||||
|
||||
@PostConstruct
|
||||
public void init() {
|
||||
strategies = new HashMap<>();
|
||||
strategies.put("weaviate", weaviateStrategy);
|
||||
strategies.put("milvus", milvusStrategy);
|
||||
log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前配置的向量库策略
|
||||
*/
|
||||
public VectorStoreService getStrategy() {
|
||||
String vectorStoreType = vectorStoreProperties.getType();
|
||||
if (vectorStoreType == null || vectorStoreType.trim().isEmpty()) {
|
||||
vectorStoreType = "weaviate"; // 默认使用weaviate
|
||||
}
|
||||
VectorStoreService strategy = strategies.get(vectorStoreType.toLowerCase());
|
||||
if (strategy == null) {
|
||||
log.warn("未找到向量库策略: {}, 使用默认策略: weaviate", vectorStoreType);
|
||||
strategy = strategies.get("weaviate");
|
||||
}
|
||||
log.debug("使用向量库策略: {}", vectorStoreType);
|
||||
return strategy;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package org.ruoyi.service.strategy.impl;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||
import io.milvus.param.IndexType;
|
||||
import io.milvus.param.MetricType;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
|
||||
private final Integer DIMENSION = 2048;
|
||||
|
||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
|
||||
super(vectorStoreProperties, embeddingModelFactory);
|
||||
}
|
||||
|
||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||
private final Map<String, EmbeddingStore<TextSegment>> storeCache = new ConcurrentHashMap<>();
|
||||
|
||||
private EmbeddingStore<TextSegment> getMilvusStore(String collectionName, boolean autoFlushOnInsert) {
|
||||
String key = collectionName + "|" + autoFlushOnInsert;
|
||||
return storeCache.computeIfAbsent(key, k ->
|
||||
MilvusEmbeddingStore.builder()
|
||||
.uri(vectorStoreProperties.getMilvus().getUrl())
|
||||
.collectionName(collectionName)
|
||||
.dimension(DIMENSION)
|
||||
.indexType(IndexType.IVF_FLAT)
|
||||
.metricType(MetricType.L2)
|
||||
.autoFlushOnInsert(autoFlushOnInsert)
|
||||
.idFieldName("id")
|
||||
.textFieldName("text")
|
||||
.metadataFieldName("metadata")
|
||||
.vectorFieldName("vector")
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String modelName) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
// 使用缓存获取连接以确保只初始化一次
|
||||
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, true);
|
||||
log.info("Milvus集合初始化完成: {}", collectionName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
|
||||
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
String docId = storeEmbeddingBo.getDocId();
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
|
||||
log.info("Milvus向量存储条数记录: {}", chunkList.size());
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
// 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
|
||||
|
||||
IntStream.range(0, chunkList.size()).forEach(i -> {
|
||||
String text = chunkList.get(i);
|
||||
String fid = fidList.get(i);
|
||||
Metadata metadata = new Metadata();
|
||||
metadata.put("fid", fid);
|
||||
metadata.put("kid", kid);
|
||||
metadata.put("docId", docId);
|
||||
|
||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
});
|
||||
long endTime = System.currentTimeMillis();
|
||||
log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
|
||||
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
|
||||
|
||||
// 查询复用连接,autoFlush 对查询无影响,此处保持 true
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, true);
|
||||
|
||||
List<String> resultList = new ArrayList<>();
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(queryEmbedding)
|
||||
.maxResults(queryVectorBo.getMaxResults())
|
||||
.build();
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();
|
||||
for (EmbeddingMatch<TextSegment> match : matches) {
|
||||
TextSegment segment = match.embedded();
|
||||
if (segment != null) {
|
||||
resultList.add(segment.text());
|
||||
}
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public void removeById(String id, String modelName) {
|
||||
// 注意:此处原逻辑使用 collectionname + id,保持现状
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false);
|
||||
embeddingStore.remove(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String docId, String kid) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
|
||||
Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId);
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Milvus成功删除 docId={} 的所有向量数据", docId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByFid(String fid, String kid) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
|
||||
Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid);
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Milvus成功删除 fid={} 的所有向量数据", fid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVectorStoreType() {
|
||||
return "milvus";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package org.ruoyi.service.strategy.impl;
|
||||
|
||||
import cn.hutool.json.JSONObject;
|
||||
import org.ruoyi.common.core.exception.ServiceException;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import io.weaviate.client.Config;
|
||||
import io.weaviate.client.WeaviateClient;
|
||||
import io.weaviate.client.base.Result;
|
||||
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
|
||||
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
|
||||
import io.weaviate.client.v1.filters.Operator;
|
||||
import io.weaviate.client.v1.filters.WhereFilter;
|
||||
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||
import io.weaviate.client.v1.schema.model.Property;
|
||||
import io.weaviate.client.v1.schema.model.Schema;
|
||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
|
||||
import org.springframework.stereotype.Component;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Weaviate向量库策略实现
|
||||
*
|
||||
* @author Yzm
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
private WeaviateClient client;
|
||||
|
||||
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
|
||||
super(vectorStoreProperties, embeddingModelFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVectorStoreType() {
|
||||
return "weaviate";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String embeddingModelName) {
|
||||
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
|
||||
String host = vectorStoreProperties.getWeaviate().getHost();
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
|
||||
// 创建 Weaviate 客户端
|
||||
client = new WeaviateClient(new Config(protocol, host));
|
||||
// 检查类是否存在,如果不存在就创建 schema
|
||||
Result<Schema> schemaResult = client.schema().getter().run();
|
||||
Schema schema = schemaResult.getResult();
|
||||
boolean classExists = false;
|
||||
for (WeaviateClass weaviateClass : schema.getClasses()) {
|
||||
if (weaviateClass.getClassName().equals(className)) {
|
||||
classExists = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!classExists) {
|
||||
// 类不存在,创建 schema
|
||||
WeaviateClass build = WeaviateClass.builder()
|
||||
.className(className)
|
||||
.vectorizer("none")
|
||||
.properties(
|
||||
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
|
||||
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
|
||||
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
|
||||
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
|
||||
)
|
||||
.build();
|
||||
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
|
||||
if (createResult.hasErrors()) {
|
||||
log.error("Schema 创建失败: {}", createResult.getError());
|
||||
} else {
|
||||
log.info("Schema 创建成功: {}", className);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), null);
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
String docId = storeEmbeddingBo.getDocId();
|
||||
log.info("向量存储条数记录: " + chunkList.size());
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < chunkList.size(); i++) {
|
||||
String text = chunkList.get(i);
|
||||
String fid = fidList.get(i);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
Map<String, Object> properties = Map.of(
|
||||
"text", text,
|
||||
"fid", fid,
|
||||
"kid", kid,
|
||||
"docId", docId
|
||||
);
|
||||
Float[] vector = toObjectArray(embedding.vector());
|
||||
client.data().creator()
|
||||
.withClassName("LocalKnowledge" + kid)
|
||||
.withProperties(properties)
|
||||
.withVector(vector)
|
||||
.run();
|
||||
}
|
||||
long endTime = System.currentTimeMillis();
|
||||
log.info("向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "秒");
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
createSchema(queryVectorBo.getKid(),queryVectorBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),null);
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
for (float v : vector) {
|
||||
vectorStrings.add(String.valueOf(v));
|
||||
}
|
||||
String vectorStr = String.join(",", vectorStrings);
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname();
|
||||
|
||||
// 构建 GraphQL 查询
|
||||
String graphQLQuery = String.format(
|
||||
"{\n" +
|
||||
" Get {\n" +
|
||||
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
|
||||
" text\n" +
|
||||
" fid\n" +
|
||||
" kid\n" +
|
||||
" docId\n" +
|
||||
" _additional {\n" +
|
||||
" distance\n" +
|
||||
" id\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}",
|
||||
className + queryVectorBo.getKid(),
|
||||
vectorStr,
|
||||
queryVectorBo.getMaxResults()
|
||||
);
|
||||
|
||||
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
|
||||
List<String> resultList = new ArrayList<>();
|
||||
if (result != null && !result.hasErrors()) {
|
||||
Object data = result.getResult().getData();
|
||||
JSONObject entries = new JSONObject(data);
|
||||
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
|
||||
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
|
||||
if (objects.isEmpty()) {
|
||||
return resultList;
|
||||
}
|
||||
for (Object object : objects) {
|
||||
Map<String, String> map = (Map<String, String>) object;
|
||||
String content = map.get("text");
|
||||
resultList.add(content);
|
||||
}
|
||||
return resultList;
|
||||
} else {
|
||||
log.error("GraphQL 查询失败: {}", result.getError());
|
||||
return resultList;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public void removeById(String id, String modelName) {
|
||||
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
|
||||
String host = vectorStoreProperties.getWeaviate().getHost();
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname();
|
||||
String finalClassName = className + id;
|
||||
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
|
||||
Result<Boolean> result = client.schema().classDeleter().withClassName(finalClassName).run();
|
||||
if (result.hasErrors()) {
|
||||
log.error("失败删除向量: " + result.getError());
|
||||
throw new ServiceException("失败删除向量数据!");
|
||||
} else {
|
||||
log.info("成功删除向量数据: " + result.getResult());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String docId, String kid) {
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
|
||||
// 构建 Where 条件
|
||||
WhereFilter whereFilter = WhereFilter.builder()
|
||||
.path("docId")
|
||||
.operator(Operator.Equal)
|
||||
.valueText(docId)
|
||||
.build();
|
||||
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
|
||||
Result<BatchDeleteResponse> result = deleter.withClassName(className)
|
||||
.withWhere(whereFilter)
|
||||
.run();
|
||||
if (result != null && !result.hasErrors()) {
|
||||
log.info("成功删除 docId={} 的所有向量数据", docId);
|
||||
} else {
|
||||
log.error("删除失败: {}", result.getError());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByFid(String fid, String kid) {
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
|
||||
// 构建 Where 条件
|
||||
WhereFilter whereFilter = WhereFilter.builder()
|
||||
.path("fid")
|
||||
.operator(Operator.Equal)
|
||||
.valueText(fid)
|
||||
.build();
|
||||
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
|
||||
Result<BatchDeleteResponse> result = deleter.withClassName(className)
|
||||
.withWhere(whereFilter)
|
||||
.run();
|
||||
if (result != null && !result.hasErrors()) {
|
||||
log.info("成功删除 fid={} 的所有向量数据", fid);
|
||||
} else {
|
||||
log.error("删除失败: {}", result.getError());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,7 +18,7 @@ import java.io.Serializable;
|
||||
*/
|
||||
@Data
|
||||
@ExcelIgnoreUnannotated
|
||||
@AutoMapper(target = ChatConfig.class)
|
||||
@AutoMapper(target = ChatConfig.class)
|
||||
public class ChatConfigVo implements Serializable {
|
||||
|
||||
@Serial
|
||||
|
||||
133
ruoyi-modules-api/ruoyi-workflow-api/pom.xml
Normal file
133
ruoyi-modules-api/ruoyi-workflow-api/pom.xml
Normal file
@@ -0,0 +1,133 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-modules-api</artifactId>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>ruoyi-workflow-api</artifactId>
|
||||
|
||||
<description>
|
||||
工作流API模块
|
||||
</description>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
<maven.compiler.target>17</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-web</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-system-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-common-satoken</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-common-mail</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ruoyi</groupId>
|
||||
<artifactId>ruoyi-chat</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>1.2.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>cn.hutool</groupId>
|
||||
<artifactId>hutool-all</artifactId>
|
||||
<version>5.8.12</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.bsc.langgraph4j</groupId>
|
||||
<artifactId>langgraph4j-core</artifactId>
|
||||
<version>1.5.3</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.bsc.langgraph4j</groupId>
|
||||
<artifactId>langgraph4j-langchain4j</artifactId>
|
||||
<version>1.5.3</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.swagger.core.v3</groupId>
|
||||
<artifactId>swagger-annotations</artifactId>
|
||||
<version>2.2.8</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<version>1.2.0</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-community-dashscope</artifactId>
|
||||
<version>1.2.0-beta8</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.baomidou</groupId>
|
||||
<artifactId>mybatis-plus-generator</artifactId>
|
||||
<version>3.5.3.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-http-client-jdk</artifactId>
|
||||
<version>1.2.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-document-parser-apache-poi</artifactId>
|
||||
<version>1.2.0-beta8</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.api-client</groupId>
|
||||
<artifactId>google-api-client</artifactId>
|
||||
<version>2.6.0</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -0,0 +1,43 @@
|
||||
package org.ruoyi.workflow;
|
||||
|
||||
import com.baomidou.mybatisplus.generator.FastAutoGenerator;
|
||||
import com.baomidou.mybatisplus.generator.config.OutputFile;
|
||||
import com.baomidou.mybatisplus.generator.config.rules.DbColumnType;
|
||||
|
||||
import java.sql.Types;
|
||||
import java.util.Collections;
|
||||
|
||||
public class CodeGenerator {
|
||||
public static void main(String[] args) {
|
||||
FastAutoGenerator.create("jdbc:postgres://172.17.30.40:5432/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true", "postgres", "postgres")
|
||||
.globalConfig(builder -> {
|
||||
builder.author("moyz") // 设置作者
|
||||
.enableSwagger() // 开启 swagger 模式
|
||||
.fileOverride() // 覆盖已生成文件
|
||||
.outputDir("D://"); // 指定输出目录
|
||||
})
|
||||
.dataSourceConfig(builder -> builder.typeConvertHandler((globalConfig, typeRegistry, metaInfo) -> {
|
||||
int typeCode = metaInfo.getJdbcType().TYPE_CODE;
|
||||
if (typeCode == Types.SMALLINT) {
|
||||
// 自定义类型转换
|
||||
return DbColumnType.INTEGER;
|
||||
}
|
||||
return typeRegistry.getColumnType(metaInfo);
|
||||
|
||||
}))
|
||||
.packageConfig(builder -> {
|
||||
builder.mapper("com.adi.common.mapper")
|
||||
.parent("")
|
||||
.moduleName("")
|
||||
.entity("po")
|
||||
.serviceImpl("service.impl")
|
||||
.pathInfo(Collections.singletonMap(OutputFile.xml, "D://mybatisplus-generatorcode")); // 设置mapperXml生成路径
|
||||
})
|
||||
.strategyConfig(builder -> {
|
||||
builder.addInclude("adi_knowledge_base_qa_record") // 设置需要生成的表名
|
||||
.addTablePrefix("adi_");
|
||||
builder.mapperBuilder().enableBaseResultMap().enableMapperAnnotation().build();
|
||||
})
|
||||
.execute();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package org.ruoyi.workflow.base;
|
||||
|
||||
import lombok.Data;
|
||||
import org.ruoyi.workflow.enums.ErrorEnum;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
public class BaseResponse<T> implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
/**
|
||||
* 是否成功
|
||||
*/
|
||||
private boolean success;
|
||||
/**
|
||||
* 状态码
|
||||
*/
|
||||
private String code;
|
||||
/**
|
||||
* 提示
|
||||
*/
|
||||
private String message;
|
||||
/**
|
||||
* 数据
|
||||
*/
|
||||
private T data;
|
||||
|
||||
public BaseResponse() {
|
||||
}
|
||||
|
||||
public BaseResponse(boolean success) {
|
||||
this.success = success;
|
||||
}
|
||||
|
||||
public BaseResponse(boolean success, T data) {
|
||||
this.data = data;
|
||||
this.success = success;
|
||||
}
|
||||
|
||||
public BaseResponse(String code, String message, T data) {
|
||||
this.code = code;
|
||||
this.success = false;
|
||||
this.message = message;
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public static BaseResponse success(String message) {
|
||||
return new BaseResponse(ErrorEnum.SUCCESS.getCode(), message, "");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package org.ruoyi.workflow.base;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.ibatis.type.BaseTypeHandler;
|
||||
import org.apache.ibatis.type.JdbcType;
|
||||
import org.apache.ibatis.type.MappedJdbcTypes;
|
||||
import org.apache.ibatis.type.MappedTypes;
|
||||
import org.ruoyi.workflow.enums.WfIODataTypeEnum;
|
||||
import org.ruoyi.workflow.util.JsonUtil;
|
||||
import org.ruoyi.workflow.workflow.WfNodeInputConfig;
|
||||
import org.ruoyi.workflow.workflow.def.WfNodeIO;
|
||||
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
|
||||
|
||||
import java.sql.CallableStatement;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.ruoyi.workflow.workflow.WfNodeIODataUtil.INPUT_TYPE_TO_NODE_IO_DEF;
|
||||
|
||||
@Slf4j
|
||||
@MappedJdbcTypes({JdbcType.JAVA_OBJECT})
|
||||
@MappedTypes({WfNodeInputConfig.class})
|
||||
public class NodeInputConfigTypeHandler extends BaseTypeHandler<WfNodeInputConfig> {
|
||||
|
||||
public static WfNodeInputConfig fillNodeInputConfig(String jsonSource) {
|
||||
ObjectNode jsonNode = (ObjectNode) JsonUtil.toJsonNode(jsonSource);
|
||||
return createNodeInputConfig(jsonNode);
|
||||
}
|
||||
|
||||
public static WfNodeInputConfig createNodeInputConfig(ObjectNode jsonNode) {
|
||||
List<WfNodeIO> userInputs = new ArrayList<>();
|
||||
WfNodeInputConfig result = new WfNodeInputConfig();
|
||||
result.setUserInputs(userInputs);
|
||||
result.setRefInputs(new ArrayList<>());
|
||||
if (null == jsonNode) {
|
||||
return result;
|
||||
}
|
||||
ArrayNode userInputsJson = jsonNode.withArray("user_inputs");
|
||||
ArrayNode refInputs = jsonNode.withArray("ref_inputs");
|
||||
if (!userInputsJson.isEmpty()) {
|
||||
for (JsonNode userInput : userInputsJson) {
|
||||
if (userInput instanceof ObjectNode objectNode) {
|
||||
int type = objectNode.get("type").asInt();
|
||||
Class<? extends WfNodeIO> nodeIOClass = INPUT_TYPE_TO_NODE_IO_DEF.get(WfIODataTypeEnum.getByValue(type));
|
||||
WfNodeIO wfNodeIO = JsonUtil.fromJson(objectNode, nodeIOClass);
|
||||
if (null != wfNodeIO) {
|
||||
userInputs.add(wfNodeIO);
|
||||
} else {
|
||||
log.warn("用户输入格式不正确:{}", userInput);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!refInputs.isEmpty()) {
|
||||
List<WfNodeParamRef> list = JsonUtil.fromArrayNode(refInputs, WfNodeParamRef.class);
|
||||
if (CollectionUtils.isNotEmpty(list)) {
|
||||
result.setRefInputs(list);
|
||||
} else {
|
||||
log.warn("引用输入格式不正确:{}", refInputs);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNonNullParameter(PreparedStatement ps, int i, WfNodeInputConfig parameter, JdbcType jdbcType) {
|
||||
// PGobject jsonObject = new PGobject();
|
||||
// jsonObject.setType("jsonb");
|
||||
// try {
|
||||
// jsonObject.setValue(JsonUtil.toJson(parameter));
|
||||
// ps.setObject(i, jsonObject);
|
||||
// } catch (Exception e) {
|
||||
// throw new RuntimeException(e);
|
||||
// }
|
||||
}
|
||||
|
||||
@Override
|
||||
public WfNodeInputConfig getNullableResult(ResultSet rs, String columnName) throws SQLException {
|
||||
String jsonSource = rs.getString(columnName);
|
||||
if (jsonSource != null) {
|
||||
try {
|
||||
return fillNodeInputConfig(jsonSource);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public WfNodeInputConfig getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
|
||||
String jsonSource = rs.getString(columnIndex);
|
||||
if (jsonSource != null) {
|
||||
return fillNodeInputConfig(jsonSource);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public WfNodeInputConfig getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
|
||||
String jsonSource = cs.getString(columnIndex);
|
||||
if (jsonSource != null) {
|
||||
try {
|
||||
return fillNodeInputConfig(jsonSource);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package org.ruoyi.workflow.base;
|
||||
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||
import org.ruoyi.common.core.exception.base.BaseException;
|
||||
import org.ruoyi.common.satoken.utils.LoginHelper;
|
||||
import org.ruoyi.workflow.entity.User;
|
||||
import org.ruoyi.workflow.enums.UserStatusEnum;
|
||||
|
||||
import static org.ruoyi.workflow.enums.ErrorEnum.A_USER_NOT_FOUND;
|
||||
|
||||
/**
|
||||
* 线程上下文适配器,统一接入 Sa-Token 登录态。
|
||||
*/
|
||||
public class ThreadContext {
|
||||
|
||||
private static final ThreadLocal<User> CURRENT_USER = new ThreadLocal<>();
|
||||
private static final ThreadLocal<String> CURRENT_TOKEN = new ThreadLocal<>();
|
||||
|
||||
private ThreadContext() {
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前登录的工作流用户。
|
||||
*/
|
||||
public static User getCurrentUser() {
|
||||
User cached = CURRENT_USER.get();
|
||||
if (cached != null) {
|
||||
return cached;
|
||||
}
|
||||
LoginUser loginUser = LoginHelper.getLoginUser();
|
||||
if (loginUser == null) {
|
||||
throw new BaseException(A_USER_NOT_FOUND.getInfo());
|
||||
}
|
||||
User mapped = mapToWorkflowUser(loginUser);
|
||||
CURRENT_USER.set(mapped);
|
||||
return mapped;
|
||||
}
|
||||
|
||||
/**
|
||||
* 允许在测试或特殊场景下显式设置当前用户。
|
||||
*/
|
||||
public static void setCurrentUser(User user) {
|
||||
if (user == null) {
|
||||
CURRENT_USER.remove();
|
||||
} else {
|
||||
CURRENT_USER.set(user);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前登录用户 ID。
|
||||
*/
|
||||
public static Long getCurrentUserId() {
|
||||
Long userId = LoginHelper.getUserId();
|
||||
if (userId != null) {
|
||||
return userId;
|
||||
}
|
||||
return getCurrentUser().getId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前访问 token。
|
||||
*/
|
||||
public static String getToken() {
|
||||
String token = CURRENT_TOKEN.get();
|
||||
if (StringUtils.isNotBlank(token)) {
|
||||
return token;
|
||||
}
|
||||
try {
|
||||
token = StpUtil.getTokenValue();
|
||||
} catch (Exception ignore) {
|
||||
token = null;
|
||||
}
|
||||
if (StringUtils.isNotBlank(token)) {
|
||||
CURRENT_TOKEN.set(token);
|
||||
}
|
||||
return token;
|
||||
}
|
||||
|
||||
public static void setToken(String token) {
|
||||
if (StringUtils.isBlank(token)) {
|
||||
CURRENT_TOKEN.remove();
|
||||
} else {
|
||||
CURRENT_TOKEN.set(token);
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean isLogin() {
|
||||
return LoginHelper.isLogin();
|
||||
}
|
||||
|
||||
public static User getExistCurrentUser() {
|
||||
return getCurrentUser();
|
||||
}
|
||||
|
||||
public static void unload() {
|
||||
CURRENT_USER.remove();
|
||||
CURRENT_TOKEN.remove();
|
||||
}
|
||||
|
||||
private static User mapToWorkflowUser(LoginUser loginUser) {
|
||||
User user = new User();
|
||||
user.setId(loginUser.getUserId());
|
||||
String nickname = loginUser.getNickName();
|
||||
user.setName(StringUtils.defaultIfBlank(nickname, loginUser.getUsername()));
|
||||
user.setEmail(loginUser.getUsername());
|
||||
user.setUuid(String.valueOf(loginUser.getUserId()));
|
||||
user.setUserStatus(UserStatusEnum.NORMAL);
|
||||
user.setIsAdmin(LoginHelper.isSuperAdmin(loginUser.getUserId()));
|
||||
user.setUnderstandContextMsgPairNum(0);
|
||||
user.setQuotaByTokenDaily(0);
|
||||
user.setQuotaByTokenMonthly(0);
|
||||
user.setQuotaByRequestDaily(0);
|
||||
user.setQuotaByRequestMonthly(0);
|
||||
user.setQuotaByImageDaily(0);
|
||||
user.setQuotaByImageMonthly(0);
|
||||
user.setIsDeleted(false);
|
||||
return user;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user