mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 21:33:40 +00:00
Merge branch 'main' of github.com:PeinYu/ruoyi-ai
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
package org.ruoyi.knowledge.chain.vectorizer;
|
||||
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.config.ChatConfig;
|
||||
import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class LocalModelsVectorization {
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Resource
|
||||
private LocalModelsofitClient localModelsofitClient;
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ChatConfig chatConfig;
|
||||
|
||||
/**
|
||||
* 批量向量化
|
||||
*
|
||||
* @param chunkList 文本块列表
|
||||
* @param kid 知识 ID
|
||||
* @return 向量化结果
|
||||
*/
|
||||
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
logVectorizationRequest(kid, chunkList); // 在向量化开始前记录日志
|
||||
openAiStreamClient = chatConfig.getOpenAiStreamClient(); // 获取 OpenAi 客户端
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // 查询知识信息
|
||||
// 调用 localModelsofitClient 获取 Top K 嵌入向量
|
||||
try {
|
||||
return localModelsofitClient.getTopKEmbeddings(
|
||||
chunkList,
|
||||
knowledgeInfoVo.getVector(),
|
||||
knowledgeInfoVo.getKnowledgeSeparator(),
|
||||
knowledgeInfoVo.getRetrieveLimit(),
|
||||
knowledgeInfoVo.getTextBlockSize(),
|
||||
knowledgeInfoVo.getOverlapChar()
|
||||
);
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
|
||||
throw new RuntimeException("Batch vectorization failed", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 单一文本块向量化
|
||||
*
|
||||
* @param chunk 单一文本块
|
||||
* @param kid 知识 ID
|
||||
* @return 向量化结果
|
||||
*/
|
||||
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
chunkList.add(chunk);
|
||||
|
||||
// 调用批量向量化方法
|
||||
List<List<Double>> vectorList = batchVectorization(chunkList, kid);
|
||||
|
||||
if (vectorList.isEmpty()) {
|
||||
log.warn("Vectorization returned empty list for chunk: {}", chunk);
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
return vectorList.get(0); // 返回第一个向量
|
||||
}
|
||||
|
||||
/**
|
||||
* 提供更简洁的日志记录方法
|
||||
*
|
||||
* @param kid 知识 ID
|
||||
* @param chunkList 文本块列表
|
||||
*/
|
||||
private void logVectorizationRequest(String kid, List<String> chunkList) {
|
||||
log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import org.springframework.stereotype.Component;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -27,6 +28,9 @@ public class OpenAiVectorization implements Vectorization {
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
@Lazy
|
||||
@Resource
|
||||
private LocalModelsVectorization localModelsVectorization;
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
@@ -35,25 +39,63 @@ public class OpenAiVectorization implements Vectorization {
|
||||
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
Embedding embedding = Embedding.builder()
|
||||
.input(chunkList)
|
||||
.model(knowledgeInfoVo.getVectorModel())
|
||||
.build();
|
||||
EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
embeddings.getData().forEach(data -> {
|
||||
List<BigDecimal> vector = data.getEmbedding();
|
||||
List<Double> doubleVector = new ArrayList<>();
|
||||
for (BigDecimal bd : vector) {
|
||||
doubleVector.add(bd.doubleValue());
|
||||
}
|
||||
vectorList.add(doubleVector);
|
||||
});
|
||||
|
||||
// 获取知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
|
||||
// 如果使用本地模型
|
||||
try {
|
||||
return localModelsVectorization.batchVectorization(chunkList, kid);
|
||||
} catch (Exception e) {
|
||||
log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
|
||||
}
|
||||
|
||||
// 如果本地模型失败,则调用 OpenAI 服务进行向量化
|
||||
Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
|
||||
EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
|
||||
|
||||
// 处理 OpenAI 返回的嵌入数据
|
||||
vectorList = processOpenAiEmbeddings(embeddings);
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Embedding 对象
|
||||
*/
|
||||
private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
|
||||
return Embedding.builder()
|
||||
.input(chunkList)
|
||||
.model(knowledgeInfoVo.getVectorModel())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 OpenAI 返回的嵌入数据
|
||||
*/
|
||||
private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
|
||||
embeddings.getData().forEach(data -> {
|
||||
List<BigDecimal> vector = data.getEmbedding();
|
||||
List<Double> doubleVector = convertToDoubleList(vector);
|
||||
vectorList.add(doubleVector);
|
||||
});
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 BigDecimal 转换为 Double 列表
|
||||
*/
|
||||
private List<Double> convertToDoubleList(List<BigDecimal> vector) {
|
||||
return vector.stream()
|
||||
.map(BigDecimal::doubleValue)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package org.ruoyi.knowledge.chain.vectorizer;
|
||||
|
||||
public enum VectorizationType {
|
||||
OPENAI, // OpenAI 向量化
|
||||
LOCAL; // 本地模型向量化
|
||||
|
||||
public static VectorizationType fromString(String type) {
|
||||
for (VectorizationType v : values()) {
|
||||
if (v.name().equalsIgnoreCase(type)) {
|
||||
return v;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown VectorizationType: " + type);
|
||||
}
|
||||
}
|
||||
@@ -25,12 +25,14 @@ import org.ruoyi.system.domain.vo.SysUserVo;
|
||||
import org.ruoyi.system.service.ISysModelService;
|
||||
import org.ruoyi.system.service.ISysPackagePlanService;
|
||||
import org.ruoyi.system.service.ISysUserService;
|
||||
import org.ruoyi.system.util.DesensitizationUtil;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 系统模型
|
||||
@@ -50,7 +52,6 @@ public class SysModelController extends BaseController {
|
||||
|
||||
private final ISysUserService userService;
|
||||
|
||||
|
||||
/**
|
||||
* 查询系统模型列表 - 全部
|
||||
*/
|
||||
@@ -82,6 +83,14 @@ public class SysModelController extends BaseController {
|
||||
List<String> array = new ArrayList<>(Arrays.asList(sysPackagePlanVo.getPlanDetail().split(",")));
|
||||
sysModelVos.removeIf(model -> !array.contains(model.getModelName()));
|
||||
}
|
||||
sysModelVos.stream().map(vo -> {
|
||||
String maskedApiHost = DesensitizationUtil.maskData(vo.getApiHost());
|
||||
String maskedApiKey = DesensitizationUtil.maskData(vo.getApiKey());
|
||||
vo.setApiHost(maskedApiHost);
|
||||
vo.setApiKey(maskedApiKey);
|
||||
return vo;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
return R.ok(sysModelVos);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ruoyi.system.service;
|
||||
|
||||
import cn.binarywang.wx.miniapp.api.WxMaService;
|
||||
import cn.binarywang.wx.miniapp.api.WxMaUserService;
|
||||
import cn.binarywang.wx.miniapp.bean.WxMaJscode2SessionResult;
|
||||
import cn.binarywang.wx.miniapp.util.WxMaConfigHolder;
|
||||
import cn.dev33.satoken.exception.NotLoginException;
|
||||
@@ -68,6 +69,19 @@ public class SysLoginService {
|
||||
@Value("${user.password.lockTime}")
|
||||
private Integer lockTime;
|
||||
|
||||
/**
|
||||
* 获取微信
|
||||
* @param xcxCode 获取xcxCode
|
||||
*/
|
||||
public String getOpenidFromCode(String xcxCode) {
|
||||
try {
|
||||
WxMaJscode2SessionResult sessionInfo = wxMaService.getUserService().getSessionInfo(xcxCode);
|
||||
return sessionInfo.getOpenid();
|
||||
} catch (WxErrorException e) {
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* 登录验证
|
||||
*
|
||||
@@ -135,15 +149,14 @@ public class SysLoginService {
|
||||
public void visitorLogin(VisitorLoginBody loginBody) {
|
||||
String openid = "";
|
||||
// PC端游客登录
|
||||
if(LoginUserType.PC.getCode().equals(loginBody.getType())){
|
||||
if (LoginUserType.PC.getCode().equals(loginBody.getType())) {
|
||||
openid = loginBody.getCode();
|
||||
}else {
|
||||
} else {
|
||||
// 小程序匿名登录
|
||||
try {
|
||||
WxMaJscode2SessionResult session = wxMaService.getUserService().getSessionInfo(loginBody.getCode());
|
||||
openid = session.getOpenid();
|
||||
} catch (
|
||||
WxErrorException e) {
|
||||
} catch (WxErrorException e) {
|
||||
log.error(e.getMessage(), e);
|
||||
} finally {
|
||||
// 清理ThreadLocal
|
||||
@@ -159,7 +172,8 @@ public class SysLoginService {
|
||||
if (ObjectUtil.isNull(user)) {
|
||||
SysUserBo sysUser = new SysUserBo();
|
||||
// 改为自增
|
||||
String name = "用户" + UUIDShortUtil.generateShortUuid();;
|
||||
String name = "用户" + UUIDShortUtil.generateShortUuid();
|
||||
;
|
||||
// 设置默认用户名
|
||||
sysUser.setUserName(name);
|
||||
// 设置默认昵称
|
||||
@@ -170,7 +184,7 @@ public class SysLoginService {
|
||||
sysUser.setOpenId(openid);
|
||||
String configValue = configService.getConfigValue("mail", "amount");
|
||||
// 设置默认余额
|
||||
sysUser.setUserBalance(NumberUtils.toDouble(configValue,1));
|
||||
sysUser.setUserBalance(NumberUtils.toDouble(configValue, 1));
|
||||
// 注册用户,设置默认租户为0
|
||||
SysUser registerUser = userService.registerUser(sysUser, "0");
|
||||
|
||||
@@ -284,10 +298,7 @@ public class SysLoginService {
|
||||
|
||||
private SysUserVo loadUserByUsername(String tenantId, String username) {
|
||||
|
||||
SysUser user = userMapper.selectOne(new LambdaQueryWrapper<SysUser>()
|
||||
.select(SysUser::getUserName, SysUser::getStatus)
|
||||
.eq(TenantHelper.isEnable(), SysUser::getTenantId, tenantId)
|
||||
.eq(SysUser::getUserName, username));
|
||||
SysUser user = userMapper.selectOne(new LambdaQueryWrapper<SysUser>().select(SysUser::getUserName, SysUser::getStatus).eq(TenantHelper.isEnable(), SysUser::getTenantId, tenantId).eq(SysUser::getUserName, username));
|
||||
if (ObjectUtil.isNull(user)) {
|
||||
log.info("登录用户:{} 不存在.", username);
|
||||
throw new UserException("user.not.exists", username);
|
||||
@@ -302,10 +313,7 @@ public class SysLoginService {
|
||||
}
|
||||
|
||||
private SysUserVo loadUserByPhonenumber(String tenantId, String phonenumber) {
|
||||
SysUser user = userMapper.selectOne(new LambdaQueryWrapper<SysUser>()
|
||||
.select(SysUser::getPhonenumber, SysUser::getStatus)
|
||||
.eq(TenantHelper.isEnable(), SysUser::getTenantId, tenantId)
|
||||
.eq(SysUser::getPhonenumber, phonenumber));
|
||||
SysUser user = userMapper.selectOne(new LambdaQueryWrapper<SysUser>().select(SysUser::getPhonenumber, SysUser::getStatus).eq(TenantHelper.isEnable(), SysUser::getTenantId, tenantId).eq(SysUser::getPhonenumber, phonenumber));
|
||||
if (ObjectUtil.isNull(user)) {
|
||||
log.info("登录用户:{} 不存在.", phonenumber);
|
||||
throw new UserException("user.not.exists", phonenumber);
|
||||
@@ -320,10 +328,7 @@ public class SysLoginService {
|
||||
}
|
||||
|
||||
private SysUserVo loadUserByEmail(String tenantId, String email) {
|
||||
SysUser user = userMapper.selectOne(new LambdaQueryWrapper<SysUser>()
|
||||
.select(SysUser::getPhonenumber, SysUser::getStatus)
|
||||
.eq(TenantHelper.isEnable(), SysUser::getTenantId, tenantId)
|
||||
.eq(SysUser::getEmail, email));
|
||||
SysUser user = userMapper.selectOne(new LambdaQueryWrapper<SysUser>().select(SysUser::getPhonenumber, SysUser::getStatus).eq(TenantHelper.isEnable(), SysUser::getTenantId, tenantId).eq(SysUser::getEmail, email));
|
||||
if (ObjectUtil.isNull(user)) {
|
||||
log.info("登录用户:{} 不存在.", email);
|
||||
throw new UserException("user.not.exists", email);
|
||||
@@ -419,8 +424,7 @@ public class SysLoginService {
|
||||
} else if (TenantStatus.DISABLE.getCode().equals(tenant.getStatus())) {
|
||||
log.info("登录租户:{} 已被停用.", tenantId);
|
||||
throw new TenantException("tenant.blocked");
|
||||
} else if (ObjectUtil.isNotNull(tenant.getExpireTime())
|
||||
&& new Date().after(tenant.getExpireTime())) {
|
||||
} else if (ObjectUtil.isNotNull(tenant.getExpireTime()) && new Date().after(tenant.getExpireTime())) {
|
||||
log.info("登录租户:{} 已超过有效期.", tenantId);
|
||||
throw new TenantException("tenant.expired");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package org.ruoyi.system.util;
|
||||
|
||||
|
||||
public class DesensitizationUtil {
|
||||
public static String maskData(String data) {
|
||||
if (data == null || data.length() <= 4) {
|
||||
return data;
|
||||
}
|
||||
int start = 2;
|
||||
int end = data.length() - 2;
|
||||
StringBuilder masked = new StringBuilder();
|
||||
masked.append(data, 0, start);
|
||||
for (int i = start; i < end; i++) {
|
||||
masked.append('*');
|
||||
}
|
||||
masked.append(data.substring(end));
|
||||
return masked.toString();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user