4 Commits

Author SHA1 Message Date
ageerle
63ec00cd71 Merge pull request #136 from Administratos-User/main
修复token不扣费和扣费异常的问题
2025-07-16 15:53:56 +08:00
ageerle
7eebd87cc8 Merge pull request #141 from lixiang-cell/lixiang
存储向量库同时存储元数据的fid和docId,提供根据fid docId删除
2025-07-16 15:52:19 +08:00
lixiang
285aa2ae62 存储向量库同时存储元数据的fid和docId,提供根据fid docId删除 2025-07-14 17:15:29 +08:00
Administratos-User
99114d3301 修复token不扣费和扣费异常的问题
1.本次提交的token数+未付费token数 判断大于100token的时候就进行扣费,当然这里还可以改成更多,我觉得100合适。
2.不需要进行扣费的地方屏蔽了相关代码。
3.SessionId传递异常 建议前端传递uuid,也就是每次会话的id。
2025-07-11 11:59:20 +08:00
5 changed files with 252 additions and 41 deletions

View File

@@ -61,4 +61,10 @@ public class ChatRequest {
*/
private String role;
/**
* 对话id(每个聊天窗口都不一样)
*/
private Long uuid;
}

View File

@@ -19,5 +19,7 @@ public interface VectorStoreService {
void removeById(String id,String modelName);
void removeByDocId(String docId, String kid);
void removeByFid(String fid, String kid);
}

View File

@@ -1,18 +1,25 @@
package org.ruoyi.service.impl;
import cn.hutool.json.JSONObject;
import com.google.protobuf.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
@@ -21,9 +28,7 @@ import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.*;
/**
* 向量库管理
@@ -38,17 +43,49 @@ public class VectorStoreServiceImpl implements VectorStoreService {
private final ConfigService configService;
private EmbeddingStore<TextSegment> embeddingStore;
private WeaviateClient client;
@Override
public void createSchema(String kid, String modelName) {
String protocol = configService.getConfigValue("weaviate", "protocol");
String host = configService.getConfigValue("weaviate", "host");
String className = configService.getConfigValue("weaviate", "classname");
String className = configService.getConfigValue("weaviate", "classname")+kid;
// 创建 Weaviate 客户端
client= new WeaviateClient(new Config(protocol, host));
// 检查类是否存在,如果不存在就创建 schema
Result<Schema> schemaResult = client.schema().getter().run();
Schema schema = schemaResult.getResult();
boolean classExists = false;
for (WeaviateClass weaviateClass : schema.getClasses()) {
if (weaviateClass.getClassName().equals(className)) {
classExists = true;
break;
}
}
if (!classExists) {
// 类不存在,创建 schema
WeaviateClass build = WeaviateClass.builder()
.className(className)
.vectorizer("none")
.properties(
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
)
.build();
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
if (createResult.hasErrors()) {
log.error("Schema 创建失败: {}", createResult.getError());
} else {
log.info("Schema 创建成功: {}", className);
}
}
embeddingStore = WeaviateEmbeddingStore.builder()
.scheme(protocol)
.host(host)
.objectClass(className+kid)
.objectClass(className)
.scheme(protocol)
.avoidDups(true)
.consistencyLevel("ALL")
@@ -61,33 +98,98 @@ public class VectorStoreServiceImpl implements VectorStoreService {
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
List<String> chunkList = storeEmbeddingBo.getChunkList();
for (String s : chunkList) {
Embedding embedding = embeddingModel.embed(s).content();
TextSegment segment = TextSegment.from(s);
embeddingStore.add(embedding, segment);
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
log.info("向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid",fid,
"kid", kid,
"docId", docId
);
Float[] vector = toObjectArray(embedding.vector());
client.data().creator()
.withClassName("LocalKnowledge" + kid) // 注意替换成实际类名
.withProperties(properties)
.withVector(vector)
.run();
}
long endTime = System.currentTimeMillis();
log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"");
}
private static Float[] toObjectArray(float[] primitive) {
Float[] result = new Float[primitive.length];
for (int i = 0; i < primitive.length; i++) {
result[i] = primitive[i]; // 自动装箱
}
return result;
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(queryVectorBo.getMaxResults())
.build();
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
List<String> results = new ArrayList<>();
matches.forEach(embeddingMatch -> results.add(embeddingMatch.embedded().text()));
return results;
}
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
vectorStrings.add(String.valueOf(v));
}
String vectorStr = String.join(",", vectorStrings);
String className = configService.getConfigValue("weaviate", "classname") ;
// 构建 GraphQL 查询
String graphQLQuery = String.format(
"{\n" +
" Get {\n" +
" %s(nearVector: {vector: [%s], certainty: %f} limit: %d) {\n" +
" text\n" +
" fid\n" +
" kid\n" +
" docId\n" +
" _additional {\n" +
" distance\n" +
" id\n" +
" }\n" +
" }\n" +
" }\n" +
"}",
className+ queryVectorBo.getKid(),
vectorStr,
queryVectorBo.getMaxResults()
);
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
List<String> resultList = new ArrayList<>();
if (result != null && !result.hasErrors()) {
Object data = result.getResult().getData();
JSONObject entries = new JSONObject(data);
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
if(objects.isEmpty()){
return resultList;
}
for (Object object : objects) {
Map<String, String> map = (Map<String, String>) object;
String content = map.get("text");
resultList.add( content);
}
return resultList;
} else {
log.error("GraphQL 查询失败: {}", result.getError());
return resultList;
}
}
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
public void removeById(String id, String modelName) {
String protocol = configService.getConfigValue("weaviate", "protocol");
String host = configService.getConfigValue("weaviate", "host");
String className = configService.getConfigValue("weaviate", "classname");
@@ -102,6 +204,46 @@ public class VectorStoreServiceImpl implements VectorStoreService {
}
}
@Override
public void removeByDocId(String docId, String kid) {
String className = configService.getConfigValue("weaviate", "classname") + kid;
// 构建 Where 条件
WhereFilter whereFilter = WhereFilter.builder()
.path("docId")
.operator(Operator.Equal)
.valueText(docId)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 docId={} 的所有向量数据", docId);
} else {
log.error("删除失败: {}", result.getError());
}
}
@Override
public void removeByFid(String fid, String kid) {
String className = configService.getConfigValue("weaviate", "classname") + kid;
// 构建 Where 条件
WhereFilter whereFilter = WhereFilter.builder()
.path("fid")
.operator(Operator.Equal)
.valueText(fid)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 fid={} 的所有向量数据", fid);
} else {
log.error("删除失败: {}", result.getError());
}
}
/**
* 获取向量模型
*/

View File

@@ -47,46 +47,57 @@ public class ChatCostServiceImpl implements IChatCostService {
*/
@Override
public void deductToken(ChatRequest chatRequest) {
if(chatRequest.getUserId()==null || chatRequest.getSessionId()==null){
return;
}
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
System.out.println("deductToken->本次提交token数 : "+tokens);
String modelName = chatRequest.getModel();
ChatMessageBo chatMessageBo = new ChatMessageBo();
// 设置用户id
chatMessageBo.setUserId(chatRequest.getUserId());
// 设置对话角色
chatMessageBo.setRole(chatRequest.getRole());
// 设置会话id
chatMessageBo.setSessionId(chatRequest.getSessionId());
// 设置对话角色
chatMessageBo.setRole(chatRequest.getRole());
// 设置对话内容
chatMessageBo.setContent(chatRequest.getPrompt());
// 计算总token数
// 设置模型名字
chatMessageBo.setModelName(chatRequest.getModel());
// 获得记录的累计token数
ChatUsageToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), modelName);
if (chatToken == null) {
chatToken = new ChatUsageToken();
chatToken.setToken(0);
}
// 计算总token数
int totalTokens = chatToken.getToken() + tokens;
// 如果总token数大于等于1000,进行费用扣除
if (totalTokens >= 1000) {
// 计算费用
int token1 = totalTokens / 1000;
int token2 = totalTokens % 1000;
if (token2 > 0) {
// 保存剩余tokens
chatToken.setModelName(modelName);
chatToken.setUserId(chatMessageBo.getUserId());
chatToken.setToken(token2);
chatTokenService.editToken(chatToken);
} else {
chatTokenService.resetToken(chatMessageBo.getUserId(), modelName);
}
//当前未付费token
int token = chatToken.getToken();
System.out.println("deductToken->未付费的token数 : "+token);
System.out.println("deductToken->本次提交+未付费token数 : "+totalTokens);
//扣费核心逻辑总token大于100就要对未结清的token进行扣费
if (totalTokens >= 100) {// 如果总token数大于等于100,进行费用扣除
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
double cost = chatModelVo.getModelPrice();
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
@@ -95,22 +106,42 @@ public class ChatCostServiceImpl implements IChatCostService {
chatMessageBo.setDeductCost(cost);
}else {
// 按token扣费
Double numberCost = token1 * cost;
Double numberCost = totalTokens * cost;
System.out.println("deductToken->按token扣费 计算token数量: "+totalTokens);
System.out.println("deductToken->按token扣费 每token的价格: "+cost);
deductUserBalance(chatMessageBo.getUserId(), numberCost);
chatMessageBo.setDeductCost(numberCost);
// 保存剩余tokens
chatToken.setModelName(modelName);
chatToken.setUserId(chatMessageBo.getUserId());
chatToken.setToken(0);//因为判断大于100token直接全部计算扣除了所以这里直接=0就可以了
chatTokenService.editToken(chatToken);
}
} else {
deductUserBalance(chatMessageBo.getUserId(), 0.0);
//不满100Token,不需要进行扣费啊啊啊
//deductUserBalance(chatMessageBo.getUserId(), 0.0);
chatMessageBo.setDeductCost(0d);
chatMessageBo.setRemark("不满1kToken,计入下一次!");
chatMessageBo.setRemark("不满100Token,计入下一次!");
System.out.println("deductToken->不满100Token,计入下一次!");
chatToken.setToken(totalTokens);
chatToken.setModelName(chatMessageBo.getModelName());
chatToken.setUserId(chatMessageBo.getUserId());
chatTokenService.editToken(chatToken);
}
// 保存消息记录
chatMessageService.insertByBo(chatMessageBo);
System.out.println("deductToken->chatMessageService.insertByBo(: "+chatMessageBo);
System.out.println("----------------------------------------");
}
/**
@@ -121,15 +152,25 @@ public class ChatCostServiceImpl implements IChatCostService {
*/
@Override
public void deductUserBalance(Long userId, Double numberCost) {
SysUser sysUser = sysUserMapper.selectById(userId);
if (sysUser == null) {
return;
}
Double userBalance = sysUser.getUserBalance();
System.out.println("deductUserBalance->准备扣除numberCost: "+numberCost);
System.out.println("deductUserBalance->剩余金额userBalance: "+userBalance);
if (userBalance < numberCost || userBalance == 0) {
throw new ServiceException("余额不足, 请充值");
}
sysUserMapper.update(null,
new LambdaUpdateWrapper<SysUser>()
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))

View File

@@ -81,6 +81,26 @@ public class SseServiceImpl implements ISseService {
chatRequest.setRole(Message.Role.USER.getName());
if(LoginHelper.isLogin()){
// 设置用户id
chatRequest.setUserId(LoginHelper.getUserId());
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
{
// 设置会话id
if (chatRequest.getUuid() == null){
//暂时随机生成会话id
chatRequest.setSessionId(System.currentTimeMillis());
}else{
//这里或许需要修改一下这里应该用uuid 或者 前端传递 sessionId
chatRequest.setSessionId(chatRequest.getUuid());
}
}
// 保存消息记录 并扣除费用
chatCostService.deductToken(chatRequest);
chatRequest.setUserId(chatCostService.getUserId());