mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-12 19:17:20 +00:00
Compare commits
4 Commits
138fa5f0e9
...
63ec00cd71
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63ec00cd71 | ||
|
|
7eebd87cc8 | ||
|
|
285aa2ae62 | ||
|
|
99114d3301 |
@@ -61,4 +61,10 @@ public class ChatRequest {
|
|||||||
*/
|
*/
|
||||||
private String role;
|
private String role;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 对话id(每个聊天窗口都不一样)
|
||||||
|
*/
|
||||||
|
private Long uuid;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,5 +19,7 @@ public interface VectorStoreService {
|
|||||||
|
|
||||||
void removeById(String id,String modelName);
|
void removeById(String id,String modelName);
|
||||||
|
|
||||||
|
void removeByDocId(String docId, String kid);
|
||||||
|
|
||||||
|
void removeByFid(String fid, String kid);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,25 @@
|
|||||||
package org.ruoyi.service.impl;
|
package org.ruoyi.service.impl;
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONObject;
|
||||||
import com.google.protobuf.ServiceException;
|
import com.google.protobuf.ServiceException;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
||||||
import io.weaviate.client.Config;
|
import io.weaviate.client.Config;
|
||||||
import io.weaviate.client.WeaviateClient;
|
import io.weaviate.client.WeaviateClient;
|
||||||
import io.weaviate.client.base.Result;
|
import io.weaviate.client.base.Result;
|
||||||
|
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
|
||||||
|
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
|
||||||
|
import io.weaviate.client.v1.filters.Operator;
|
||||||
|
import io.weaviate.client.v1.filters.WhereFilter;
|
||||||
|
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||||
|
import io.weaviate.client.v1.schema.model.Property;
|
||||||
|
import io.weaviate.client.v1.schema.model.Schema;
|
||||||
|
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -21,9 +28,7 @@ import org.ruoyi.domain.bo.QueryVectorBo;
|
|||||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||||
import org.ruoyi.service.VectorStoreService;
|
import org.ruoyi.service.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import java.util.*;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 向量库管理
|
* 向量库管理
|
||||||
@@ -38,17 +43,49 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|||||||
private final ConfigService configService;
|
private final ConfigService configService;
|
||||||
|
|
||||||
private EmbeddingStore<TextSegment> embeddingStore;
|
private EmbeddingStore<TextSegment> embeddingStore;
|
||||||
|
private WeaviateClient client;
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void createSchema(String kid, String modelName) {
|
public void createSchema(String kid, String modelName) {
|
||||||
String protocol = configService.getConfigValue("weaviate", "protocol");
|
String protocol = configService.getConfigValue("weaviate", "protocol");
|
||||||
String host = configService.getConfigValue("weaviate", "host");
|
String host = configService.getConfigValue("weaviate", "host");
|
||||||
String className = configService.getConfigValue("weaviate", "classname");
|
String className = configService.getConfigValue("weaviate", "classname")+kid;
|
||||||
|
// 创建 Weaviate 客户端
|
||||||
|
client= new WeaviateClient(new Config(protocol, host));
|
||||||
|
// 检查类是否存在,如果不存在就创建 schema
|
||||||
|
Result<Schema> schemaResult = client.schema().getter().run();
|
||||||
|
Schema schema = schemaResult.getResult();
|
||||||
|
boolean classExists = false;
|
||||||
|
for (WeaviateClass weaviateClass : schema.getClasses()) {
|
||||||
|
if (weaviateClass.getClassName().equals(className)) {
|
||||||
|
classExists = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!classExists) {
|
||||||
|
// 类不存在,创建 schema
|
||||||
|
WeaviateClass build = WeaviateClass.builder()
|
||||||
|
.className(className)
|
||||||
|
.vectorizer("none")
|
||||||
|
.properties(
|
||||||
|
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
|
||||||
|
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
|
||||||
|
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
|
||||||
|
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
|
||||||
|
if (createResult.hasErrors()) {
|
||||||
|
log.error("Schema 创建失败: {}", createResult.getError());
|
||||||
|
} else {
|
||||||
|
log.info("Schema 创建成功: {}", className);
|
||||||
|
}
|
||||||
|
}
|
||||||
embeddingStore = WeaviateEmbeddingStore.builder()
|
embeddingStore = WeaviateEmbeddingStore.builder()
|
||||||
.scheme(protocol)
|
.scheme(protocol)
|
||||||
.host(host)
|
.host(host)
|
||||||
.objectClass(className+kid)
|
.objectClass(className)
|
||||||
.scheme(protocol)
|
.scheme(protocol)
|
||||||
.avoidDups(true)
|
.avoidDups(true)
|
||||||
.consistencyLevel("ALL")
|
.consistencyLevel("ALL")
|
||||||
@@ -61,33 +98,98 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|||||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
||||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||||
for (String s : chunkList) {
|
List<String> fidList = storeEmbeddingBo.getFids();
|
||||||
Embedding embedding = embeddingModel.embed(s).content();
|
String kid = storeEmbeddingBo.getKid();
|
||||||
TextSegment segment = TextSegment.from(s);
|
String docId = storeEmbeddingBo.getDocId();
|
||||||
embeddingStore.add(embedding, segment);
|
log.info("向量存储条数记录: " + chunkList.size());
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
for (int i = 0; i < chunkList.size(); i++) {
|
||||||
|
String text = chunkList.get(i);
|
||||||
|
String fid = fidList.get(i);
|
||||||
|
Embedding embedding = embeddingModel.embed(text).content();
|
||||||
|
Map<String, Object> properties = Map.of(
|
||||||
|
"text", text,
|
||||||
|
"fid",fid,
|
||||||
|
"kid", kid,
|
||||||
|
"docId", docId
|
||||||
|
);
|
||||||
|
Float[] vector = toObjectArray(embedding.vector());
|
||||||
|
client.data().creator()
|
||||||
|
.withClassName("LocalKnowledge" + kid) // 注意替换成实际类名
|
||||||
|
.withProperties(properties)
|
||||||
|
.withVector(vector)
|
||||||
|
.run();
|
||||||
}
|
}
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"秒");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static Float[] toObjectArray(float[] primitive) {
|
||||||
|
Float[] result = new Float[primitive.length];
|
||||||
|
for (int i = 0; i < primitive.length; i++) {
|
||||||
|
result[i] = primitive[i]; // 自动装箱
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
@Override
|
@Override
|
||||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||||
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
||||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
||||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
float[] vector = queryEmbedding.vector();
|
||||||
.queryEmbedding(queryEmbedding)
|
List<String> vectorStrings = new ArrayList<>();
|
||||||
.maxResults(queryVectorBo.getMaxResults())
|
for (float v : vector) {
|
||||||
.build();
|
vectorStrings.add(String.valueOf(v));
|
||||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
|
}
|
||||||
List<String> results = new ArrayList<>();
|
String vectorStr = String.join(",", vectorStrings);
|
||||||
matches.forEach(embeddingMatch -> results.add(embeddingMatch.embedded().text()));
|
String className = configService.getConfigValue("weaviate", "classname") ;
|
||||||
return results;
|
// 构建 GraphQL 查询
|
||||||
}
|
String graphQLQuery = String.format(
|
||||||
|
"{\n" +
|
||||||
|
" Get {\n" +
|
||||||
|
" %s(nearVector: {vector: [%s], certainty: %f} limit: %d) {\n" +
|
||||||
|
" text\n" +
|
||||||
|
" fid\n" +
|
||||||
|
" kid\n" +
|
||||||
|
" docId\n" +
|
||||||
|
" _additional {\n" +
|
||||||
|
" distance\n" +
|
||||||
|
" id\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
"}",
|
||||||
|
className+ queryVectorBo.getKid(),
|
||||||
|
vectorStr,
|
||||||
|
queryVectorBo.getMaxResults()
|
||||||
|
);
|
||||||
|
|
||||||
|
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
|
||||||
|
List<String> resultList = new ArrayList<>();
|
||||||
|
if (result != null && !result.hasErrors()) {
|
||||||
|
Object data = result.getResult().getData();
|
||||||
|
JSONObject entries = new JSONObject(data);
|
||||||
|
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
|
||||||
|
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
|
||||||
|
if(objects.isEmpty()){
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
for (Object object : objects) {
|
||||||
|
Map<String, String> map = (Map<String, String>) object;
|
||||||
|
String content = map.get("text");
|
||||||
|
resultList.add( content);
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
} else {
|
||||||
|
log.error("GraphQL 查询失败: {}", result.getError());
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
String protocol = configService.getConfigValue("weaviate", "protocol");
|
String protocol = configService.getConfigValue("weaviate", "protocol");
|
||||||
String host = configService.getConfigValue("weaviate", "host");
|
String host = configService.getConfigValue("weaviate", "host");
|
||||||
String className = configService.getConfigValue("weaviate", "classname");
|
String className = configService.getConfigValue("weaviate", "classname");
|
||||||
@@ -102,6 +204,46 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void removeByDocId(String docId, String kid) {
|
||||||
|
String className = configService.getConfigValue("weaviate", "classname") + kid;
|
||||||
|
// 构建 Where 条件
|
||||||
|
WhereFilter whereFilter = WhereFilter.builder()
|
||||||
|
.path("docId")
|
||||||
|
.operator(Operator.Equal)
|
||||||
|
.valueText(docId)
|
||||||
|
.build();
|
||||||
|
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
|
||||||
|
Result<BatchDeleteResponse> result = deleter.withClassName(className)
|
||||||
|
.withWhere(whereFilter)
|
||||||
|
.run();
|
||||||
|
if (result != null && !result.hasErrors()) {
|
||||||
|
log.info("成功删除 docId={} 的所有向量数据", docId);
|
||||||
|
} else {
|
||||||
|
log.error("删除失败: {}", result.getError());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void removeByFid(String fid, String kid) {
|
||||||
|
String className = configService.getConfigValue("weaviate", "classname") + kid;
|
||||||
|
// 构建 Where 条件
|
||||||
|
WhereFilter whereFilter = WhereFilter.builder()
|
||||||
|
.path("fid")
|
||||||
|
.operator(Operator.Equal)
|
||||||
|
.valueText(fid)
|
||||||
|
.build();
|
||||||
|
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
|
||||||
|
Result<BatchDeleteResponse> result = deleter.withClassName(className)
|
||||||
|
.withWhere(whereFilter)
|
||||||
|
.run();
|
||||||
|
if (result != null && !result.hasErrors()) {
|
||||||
|
log.info("成功删除 fid={} 的所有向量数据", fid);
|
||||||
|
} else {
|
||||||
|
log.error("删除失败: {}", result.getError());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取向量模型
|
* 获取向量模型
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -47,46 +47,57 @@ public class ChatCostServiceImpl implements IChatCostService {
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void deductToken(ChatRequest chatRequest) {
|
public void deductToken(ChatRequest chatRequest) {
|
||||||
|
|
||||||
|
|
||||||
if(chatRequest.getUserId()==null || chatRequest.getSessionId()==null){
|
if(chatRequest.getUserId()==null || chatRequest.getSessionId()==null){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
|
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
|
||||||
|
|
||||||
|
System.out.println("deductToken->本次提交token数 : "+tokens);
|
||||||
|
|
||||||
String modelName = chatRequest.getModel();
|
String modelName = chatRequest.getModel();
|
||||||
|
|
||||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||||
|
|
||||||
// 设置用户id
|
// 设置用户id
|
||||||
chatMessageBo.setUserId(chatRequest.getUserId());
|
chatMessageBo.setUserId(chatRequest.getUserId());
|
||||||
// 设置对话角色
|
|
||||||
chatMessageBo.setRole(chatRequest.getRole());
|
|
||||||
// 设置会话id
|
// 设置会话id
|
||||||
chatMessageBo.setSessionId(chatRequest.getSessionId());
|
chatMessageBo.setSessionId(chatRequest.getSessionId());
|
||||||
|
|
||||||
|
// 设置对话角色
|
||||||
|
chatMessageBo.setRole(chatRequest.getRole());
|
||||||
|
|
||||||
// 设置对话内容
|
// 设置对话内容
|
||||||
chatMessageBo.setContent(chatRequest.getPrompt());
|
chatMessageBo.setContent(chatRequest.getPrompt());
|
||||||
|
|
||||||
// 计算总token数
|
// 设置模型名字
|
||||||
|
chatMessageBo.setModelName(chatRequest.getModel());
|
||||||
|
|
||||||
|
// 获得记录的累计token数
|
||||||
ChatUsageToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), modelName);
|
ChatUsageToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), modelName);
|
||||||
|
|
||||||
|
|
||||||
if (chatToken == null) {
|
if (chatToken == null) {
|
||||||
chatToken = new ChatUsageToken();
|
chatToken = new ChatUsageToken();
|
||||||
chatToken.setToken(0);
|
chatToken.setToken(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算总token数
|
||||||
int totalTokens = chatToken.getToken() + tokens;
|
int totalTokens = chatToken.getToken() + tokens;
|
||||||
// 如果总token数大于等于1000,进行费用扣除
|
|
||||||
if (totalTokens >= 1000) {
|
//当前未付费token
|
||||||
// 计算费用
|
int token = chatToken.getToken();
|
||||||
int token1 = totalTokens / 1000;
|
|
||||||
int token2 = totalTokens % 1000;
|
System.out.println("deductToken->未付费的token数 : "+token);
|
||||||
if (token2 > 0) {
|
System.out.println("deductToken->本次提交+未付费token数 : "+totalTokens);
|
||||||
// 保存剩余tokens
|
|
||||||
chatToken.setModelName(modelName);
|
|
||||||
chatToken.setUserId(chatMessageBo.getUserId());
|
//扣费核心逻辑(总token大于100就要对未结清的token进行扣费)
|
||||||
chatToken.setToken(token2);
|
if (totalTokens >= 100) {// 如果总token数大于等于100,进行费用扣除
|
||||||
chatTokenService.editToken(chatToken);
|
|
||||||
} else {
|
|
||||||
chatTokenService.resetToken(chatMessageBo.getUserId(), modelName);
|
|
||||||
}
|
|
||||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
|
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
|
||||||
double cost = chatModelVo.getModelPrice();
|
double cost = chatModelVo.getModelPrice();
|
||||||
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
|
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
|
||||||
@@ -95,22 +106,42 @@ public class ChatCostServiceImpl implements IChatCostService {
|
|||||||
chatMessageBo.setDeductCost(cost);
|
chatMessageBo.setDeductCost(cost);
|
||||||
}else {
|
}else {
|
||||||
// 按token扣费
|
// 按token扣费
|
||||||
Double numberCost = token1 * cost;
|
Double numberCost = totalTokens * cost;
|
||||||
|
System.out.println("deductToken->按token扣费 计算token数量: "+totalTokens);
|
||||||
|
System.out.println("deductToken->按token扣费 每token的价格: "+cost);
|
||||||
|
|
||||||
deductUserBalance(chatMessageBo.getUserId(), numberCost);
|
deductUserBalance(chatMessageBo.getUserId(), numberCost);
|
||||||
chatMessageBo.setDeductCost(numberCost);
|
chatMessageBo.setDeductCost(numberCost);
|
||||||
|
|
||||||
|
// 保存剩余tokens
|
||||||
|
chatToken.setModelName(modelName);
|
||||||
|
chatToken.setUserId(chatMessageBo.getUserId());
|
||||||
|
chatToken.setToken(0);//因为判断大于100token直接全部计算扣除了所以这里直接=0就可以了
|
||||||
|
chatTokenService.editToken(chatToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
deductUserBalance(chatMessageBo.getUserId(), 0.0);
|
//不满100Token,不需要进行扣费啊啊啊
|
||||||
|
//deductUserBalance(chatMessageBo.getUserId(), 0.0);
|
||||||
chatMessageBo.setDeductCost(0d);
|
chatMessageBo.setDeductCost(0d);
|
||||||
chatMessageBo.setRemark("不满1kToken,计入下一次!");
|
chatMessageBo.setRemark("不满100Token,计入下一次!");
|
||||||
|
System.out.println("deductToken->不满100Token,计入下一次!");
|
||||||
chatToken.setToken(totalTokens);
|
chatToken.setToken(totalTokens);
|
||||||
chatToken.setModelName(chatMessageBo.getModelName());
|
chatToken.setModelName(chatMessageBo.getModelName());
|
||||||
chatToken.setUserId(chatMessageBo.getUserId());
|
chatToken.setUserId(chatMessageBo.getUserId());
|
||||||
chatTokenService.editToken(chatToken);
|
chatTokenService.editToken(chatToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// 保存消息记录
|
// 保存消息记录
|
||||||
chatMessageService.insertByBo(chatMessageBo);
|
chatMessageService.insertByBo(chatMessageBo);
|
||||||
|
|
||||||
|
System.out.println("deductToken->chatMessageService.insertByBo(: "+chatMessageBo);
|
||||||
|
System.out.println("----------------------------------------");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -121,15 +152,25 @@ public class ChatCostServiceImpl implements IChatCostService {
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void deductUserBalance(Long userId, Double numberCost) {
|
public void deductUserBalance(Long userId, Double numberCost) {
|
||||||
|
|
||||||
SysUser sysUser = sysUserMapper.selectById(userId);
|
SysUser sysUser = sysUserMapper.selectById(userId);
|
||||||
if (sysUser == null) {
|
if (sysUser == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Double userBalance = sysUser.getUserBalance();
|
Double userBalance = sysUser.getUserBalance();
|
||||||
|
|
||||||
|
|
||||||
|
System.out.println("deductUserBalance->准备扣除:numberCost: "+numberCost);
|
||||||
|
System.out.println("deductUserBalance->剩余金额:userBalance: "+userBalance);
|
||||||
|
|
||||||
|
|
||||||
if (userBalance < numberCost || userBalance == 0) {
|
if (userBalance < numberCost || userBalance == 0) {
|
||||||
throw new ServiceException("余额不足, 请充值");
|
throw new ServiceException("余额不足, 请充值");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
sysUserMapper.update(null,
|
sysUserMapper.update(null,
|
||||||
new LambdaUpdateWrapper<SysUser>()
|
new LambdaUpdateWrapper<SysUser>()
|
||||||
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
|
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
|
||||||
|
|||||||
@@ -81,6 +81,26 @@ public class SseServiceImpl implements ISseService {
|
|||||||
chatRequest.setRole(Message.Role.USER.getName());
|
chatRequest.setRole(Message.Role.USER.getName());
|
||||||
|
|
||||||
if(LoginHelper.isLogin()){
|
if(LoginHelper.isLogin()){
|
||||||
|
|
||||||
|
// 设置用户id
|
||||||
|
chatRequest.setUserId(LoginHelper.getUserId());
|
||||||
|
|
||||||
|
|
||||||
|
//待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId)
|
||||||
|
//待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId)
|
||||||
|
//待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId)
|
||||||
|
{
|
||||||
|
// 设置会话id
|
||||||
|
if (chatRequest.getUuid() == null){
|
||||||
|
//暂时随机生成会话id
|
||||||
|
chatRequest.setSessionId(System.currentTimeMillis());
|
||||||
|
}else{
|
||||||
|
//这里或许需要修改一下,这里应该用uuid 或者 前端传递 sessionId
|
||||||
|
chatRequest.setSessionId(chatRequest.getUuid());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// 保存消息记录 并扣除费用
|
// 保存消息记录 并扣除费用
|
||||||
chatCostService.deductToken(chatRequest);
|
chatCostService.deductToken(chatRequest);
|
||||||
chatRequest.setUserId(chatCostService.getUserId());
|
chatRequest.setUserId(chatCostService.getUserId());
|
||||||
|
|||||||
Reference in New Issue
Block a user