diff --git a/ruoyi-admin/src/main/resources/application.yml b/ruoyi-admin/src/main/resources/application.yml index f6abd846..1992a5ab 100644 --- a/ruoyi-admin/src/main/resources/application.yml +++ b/ruoyi-admin/src/main/resources/application.yml @@ -162,7 +162,7 @@ tenant: - sys_user_role knowledge-role: - enable: true + enable: false # MyBatisPlus配置 # https://baomidou.com/config/ diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java index 71ebc4e5..f0eacdfc 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java @@ -82,4 +82,9 @@ public class ChatRequest { */ private String token; + /** + * 消息ID(保存消息成功后设置,用于后续扣费更新) + */ + private Long messageId; + } diff --git a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/ChatMessage.java b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/ChatMessage.java index 4adcded6..6a4fac2c 100644 --- a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/ChatMessage.java +++ b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/ChatMessage.java @@ -69,5 +69,10 @@ public class ChatMessage extends BaseEntity { */ private String remark; + /** + * 计费类型(1-token计费,2-次数计费,null-普通消息) + */ + private String billingType; + } diff --git a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/bo/ChatMessageBo.java b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/bo/ChatMessageBo.java index fd228f73..5b22b008 100644 --- a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/bo/ChatMessageBo.java +++ b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/bo/ChatMessageBo.java @@ -75,5 +75,10 @@ public class ChatMessageBo extends BaseEntity { @NotBlank(message = "备注不能为空", groups = { AddGroup.class, EditGroup.class }) private String remark; + /** + * 计费类型(1-token计费,2-次数计费,null-普通消息) + */ + private String billingType; + } diff --git a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatMessageVo.java b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatMessageVo.java index 00920ef9..140398fc 100644 --- a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatMessageVo.java +++ b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatMessageVo.java @@ -4,6 +4,8 @@ import com.alibaba.excel.annotation.ExcelIgnoreUnannotated; import com.alibaba.excel.annotation.ExcelProperty; import io.github.linpeilie.annotations.AutoMapper; import lombok.Data; +import org.ruoyi.common.excel.annotation.ExcelDictFormat; +import org.ruoyi.common.excel.convert.ExcelDictConvert; import org.ruoyi.domain.ChatMessage; import java.io.Serial; @@ -73,6 +75,13 @@ public class ChatMessageVo implements Serializable { @ExcelProperty(value = "模型名称") private String modelName; + /** + * 计费类型(1-token计费,2-次数计费) + */ + @ExcelProperty(value = "计费类型", converter = ExcelDictConvert.class) + @ExcelDictFormat(dictType = "sys_model_billing") + private String billingType; + /** * 备注 */ @@ -87,4 +96,5 @@ public class ChatMessageVo implements Serializable { private Date createTime; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/event/ChatMessageCreatedEvent.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/event/ChatMessageCreatedEvent.java new file mode 100644 index 00000000..be067f15 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/event/ChatMessageCreatedEvent.java @@ -0,0 +1,34 @@ +package org.ruoyi.chat.event; + +import org.springframework.context.ApplicationEvent; + +/** + * 聊天消息创建事件(用于异步计费/累计等) + */ +public class ChatMessageCreatedEvent extends ApplicationEvent { + + private final Long userId; + private final Long sessionId; + private final String modelName; + private final String role; + private final String content; + private final Long messageId; + + public ChatMessageCreatedEvent(Long userId, Long sessionId, String modelName, String role, String content, Long messageId) { + super(userId); + this.userId = userId; + this.sessionId = sessionId; + this.modelName = modelName; + this.role = role; + this.content = content; + this.messageId = messageId; + } + + public Long getUserId() { return userId; } + public Long getSessionId() { return sessionId; } + public String getModelName() { return modelName; } + public String getRole() { return role; } + public String getContent() { return content; } + public Long getMessageId() { return messageId; } +} + diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/factory/ChatServiceFactory.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/factory/ChatServiceFactory.java index f1e88a9a..8fec93e9 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/factory/ChatServiceFactory.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/factory/ChatServiceFactory.java @@ -1,6 +1,8 @@ package org.ruoyi.chat.factory; +import org.ruoyi.chat.service.chat.IChatCostService; import org.ruoyi.chat.service.chat.IChatService; +import org.ruoyi.chat.service.chat.proxy.BillingChatServiceProxy; import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; @@ -18,13 +20,18 @@ import java.util.concurrent.ConcurrentHashMap; @Component public class ChatServiceFactory implements ApplicationContextAware { private final Map chatServiceMap = new ConcurrentHashMap<>(); + private IChatCostService chatCostService; @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + // 获取计费服务 + this.chatCostService = applicationContext.getBean(IChatCostService.class); + // 初始化时收集所有IChatService的实现 Map serviceMap = applicationContext.getBeansOfType(IChatService.class); for (IChatService service : serviceMap.values()) { - if (service != null) { + if (service != null && !isBillingProxy(service)) { + // 只收集非代理的原始服务 chatServiceMap.put(service.getCategory(), service); } } @@ -32,12 +39,33 @@ public class ChatServiceFactory implements ApplicationContextAware { /** * 根据模型类别获取对应的聊天服务实现 + * 自动应用计费代理包装 */ public IChatService getChatService(String category) { + IChatService originalService = chatServiceMap.get(category); + if (originalService == null) { + throw new IllegalArgumentException("不支持的模型类别: " + category); + } + + // 自动包装为计费代理 + return new BillingChatServiceProxy(originalService, chatCostService); + } + + /** + * 获取原始服务(不包装代理) + */ + public IChatService getOriginalService(String category) { IChatService service = chatServiceMap.get(category); if (service == null) { throw new IllegalArgumentException("不支持的模型类别: " + category); } return service; } + + /** + * 判断是否为计费代理实例 + */ + private boolean isBillingProxy(IChatService service) { + return service instanceof BillingChatServiceProxy; + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/BillingEventListener.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/BillingEventListener.java new file mode 100644 index 00000000..91631f87 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/BillingEventListener.java @@ -0,0 +1,49 @@ +package org.ruoyi.chat.listener; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.chat.event.ChatMessageCreatedEvent; +import org.ruoyi.chat.service.chat.IChatCostService; +import org.ruoyi.common.chat.request.ChatRequest; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Component; +import org.springframework.context.event.EventListener; +import org.springframework.transaction.event.TransactionPhase; +import org.springframework.transaction.event.TransactionalEventListener; + +@Slf4j +@Component +@RequiredArgsConstructor +public class BillingEventListener { + + private final IChatCostService chatCostService; + + @Async + @EventListener + public void onChatMessageCreated(ChatMessageCreatedEvent event) { + log.debug("BillingEventListener->接收到计费事件,用户ID: {},会话ID: {},模型: {}", + event.getUserId(), event.getSessionId(), event.getModelName()); + try { + ChatRequest chatRequest = new ChatRequest(); + chatRequest.setUserId(event.getUserId()); + chatRequest.setSessionId(event.getSessionId()); + chatRequest.setModel(event.getModelName()); + chatRequest.setRole(event.getRole()); + chatRequest.setPrompt(event.getContent()); + chatRequest.setMessageId(event.getMessageId()); // 设置消息ID + // 异步执行计费累计与扣费 + log.debug("BillingEventListener->开始执行计费逻辑"); + chatCostService.deductToken(chatRequest); + log.debug("BillingEventListener->计费逻辑执行完成"); + } catch (Exception ex) { + // 由于已有预检查,这里的异常主要是系统异常(数据库连接等) + // 记录错误但不中断异步线程 + log.error("BillingEventListener->异步计费异常,用户ID: {},模型: {},错误: {}", + event.getUserId(), event.getModelName(), ex.getMessage(), ex); + + // TODO: 可以考虑加入重试机制或者错误通知机制 + // 例如:发送到死信队列,或者通知运维人员 + } + } +} + diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/SSEEventSourceListener.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/SSEEventSourceListener.java index d10f8036..52e94b9e 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/SSEEventSourceListener.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/listener/SSEEventSourceListener.java @@ -84,6 +84,8 @@ public class SSEEventSourceListener extends EventSourceListener { emitter.complete(); // 清理失败回调(以 emitter 为键) RetryNotifier.clear(emitter); + // 🔥 注释:AI回复的保存和计费已由BillingChatServiceProxy统一处理,此处代码已废弃 + /* // 扣除费用 ChatRequest chatRequest = new ChatRequest(); // 设置对话角色 @@ -94,7 +96,10 @@ public class SSEEventSourceListener extends EventSourceListener { chatRequest.setPrompt(stringBuffer.toString()); // 记录会话token BaseContext.setCurrentToken(token); - chatCostService.deductToken(chatRequest); + // 先保存助手消息,再发布异步计费事件 + chatCostService.saveMessage(chatRequest); + chatCostService.publishBillingEvent(chatRequest); + */ return; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/IChatCostService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/IChatCostService.java index 73c0c443..ccdd0617 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/IChatCostService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/IChatCostService.java @@ -19,6 +19,22 @@ public interface IChatCostService { void deductToken(ChatRequest chatRequest); + /** + * 保存聊天消息记录(不进行计费) + * + * @param chatRequest 对话信息 + */ + void saveMessage(ChatRequest chatRequest); + + + + /** + * 仅发布异步计费事件(不做入库) + * + * @param chatRequest 对话信息 + */ + void publishBillingEvent(ChatRequest chatRequest); + /** * 直接扣除用户的余额 * @@ -45,4 +61,12 @@ public interface IChatCostService { * 获取登录用户id */ Long getUserId(); + + /** + * 检查用户余额是否足够支付预估费用 + * + * @param chatRequest 对话信息 + * @return true=余额充足,false=余额不足 + */ + boolean checkBalanceSufficient(ChatRequest chatRequest); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/ChatCostServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/ChatCostServiceImpl.java index c460bba4..16507b94 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/ChatCostServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/ChatCostServiceImpl.java @@ -3,7 +3,10 @@ package org.ruoyi.chat.service.chat.impl; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import java.math.BigDecimal; +import java.math.RoundingMode; import org.ruoyi.chat.enums.BillingType; +import org.ruoyi.chat.event.ChatMessageCreatedEvent; import org.ruoyi.chat.enums.UserGradeType; import org.ruoyi.chat.service.chat.IChatCostService; import org.ruoyi.common.chat.request.ChatRequest; @@ -20,6 +23,7 @@ import org.ruoyi.service.IChatModelService; import org.ruoyi.service.IChatTokenService; import org.ruoyi.system.domain.SysUser; import org.ruoyi.system.mapper.SysUserMapper; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.stereotype.Service; @@ -42,108 +46,309 @@ public class ChatCostServiceImpl implements IChatCostService { private final IChatModelService chatModelService; + private final ApplicationEventPublisher eventPublisher; + /** - * 扣除用户余额 + * 扣除用户余额(仅计费与累计,不保存消息) */ @Override public void deductToken(ChatRequest chatRequest) { - - - if(chatRequest.getUserId()==null || chatRequest.getSessionId()==null){ + if (chatRequest.getUserId() == null) { + log.warn("deductToken->用户ID为空,跳过计费"); return; } - int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt()); - - System.out.println("deductToken->本次提交token数 : "+tokens); + log.debug("deductToken->本次提交token数: {}", tokens); String modelName = chatRequest.getModel(); + ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName); + BigDecimal unitPrice = BigDecimal.valueOf(chatModelVo.getModelPrice()); - ChatMessageBo chatMessageBo = new ChatMessageBo(); + // 按次计费:每次调用都直接扣费,不累计token + if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) { + BigDecimal numberCost = unitPrice.setScale(2, RoundingMode.HALF_UP); + deductUserBalance(chatRequest.getUserId(), numberCost.doubleValue()); + log.debug("deductToken->按次数扣费,费用: {},模型: {}", numberCost, modelName); - // 设置用户id - chatMessageBo.setUserId(chatRequest.getUserId()); - // 设置会话id - chatMessageBo.setSessionId(chatRequest.getSessionId()); + // 清理可能存在的历史累计token(模型计费方式可能发生过变更) + ChatUsageToken existingToken = chatTokenService.queryByUserId(chatRequest.getUserId(), modelName); + if (existingToken != null && existingToken.getToken() > 0) { + existingToken.setToken(0); + chatTokenService.editToken(existingToken); + log.debug("deductToken->按次计费,清理历史累计token: {}", existingToken.getToken()); + } - // 设置对话角色 - chatMessageBo.setRole(chatRequest.getRole()); + // 更新消息的计费信息到备注 + updateMessageBilling(chatRequest, tokens, numberCost.doubleValue(), chatModelVo.getModelType()); + return; + } - // 设置对话内容 - chatMessageBo.setContent(chatRequest.getPrompt()); - - // 设置模型名字 - chatMessageBo.setModelName(chatRequest.getModel()); + // 按token计费:累加并按阈值批量扣费,保留余数 + final int threshold = 1000; // 获得记录的累计token数 - ChatUsageToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), modelName); - - + // TODO: 这里存在并发竞态条件,需要在chatTokenService层面添加乐观锁或分布式锁 + ChatUsageToken chatToken = chatTokenService.queryByUserId(chatRequest.getUserId(), modelName); if (chatToken == null) { chatToken = new ChatUsageToken(); chatToken.setToken(0); + chatToken.setModelName(modelName); + chatToken.setUserId(chatRequest.getUserId()); } - // 计算总token数 - int totalTokens = chatToken.getToken() + tokens; + int previousUnpaid = chatToken.getToken(); + int totalTokens = previousUnpaid + tokens; + log.debug("deductToken->未付费token数: {},本次累计后总数: {}", previousUnpaid, totalTokens); - //当前未付费token - int token = chatToken.getToken(); + int billable = (totalTokens / threshold) * threshold; // 可计费整批token数 + int remainder = totalTokens - billable; // 结算后保留的余数 - System.out.println("deductToken->未付费的token数 : "+token); - System.out.println("deductToken->本次提交+未付费token数 : "+totalTokens); + if (billable > 0) { + // 计算批次数:每1000个Token为一批,每批扣费单价 + int batches = billable / threshold; + BigDecimal numberCost = unitPrice + .multiply(BigDecimal.valueOf(batches)) + .setScale(2, RoundingMode.HALF_UP); + log.debug("deductToken->按token扣费,结算token数量: {},批次数: {},单价: {},费用: {}", + billable, batches, unitPrice, numberCost); - - //扣费核心逻辑(总token大于100就要对未结清的token进行扣费) - if (totalTokens >= 100) {// 如果总token数大于等于100,进行费用扣除 - - ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName); - double cost = chatModelVo.getModelPrice(); - if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) { - // 按次数扣费 - deductUserBalance(chatMessageBo.getUserId(), cost); - chatMessageBo.setDeductCost(cost); - }else { - // 按token扣费 - Double numberCost = totalTokens * cost; - System.out.println("deductToken->按token扣费 计算token数量: "+totalTokens); - System.out.println("deductToken->按token扣费 每token的价格: "+cost); - - deductUserBalance(chatMessageBo.getUserId(), numberCost); - chatMessageBo.setDeductCost(numberCost); - - // 保存剩余tokens + try { + // 先尝试扣费 + deductUserBalance(chatRequest.getUserId(), numberCost.doubleValue()); + // 扣费成功后,保存余数 chatToken.setModelName(modelName); - chatToken.setUserId(chatMessageBo.getUserId()); - chatToken.setToken(0);//因为判断大于100token直接全部计算扣除了所以这里直接=0就可以了 + chatToken.setUserId(chatRequest.getUserId()); + chatToken.setToken(remainder); chatTokenService.editToken(chatToken); + log.debug("deductToken->扣费成功,更新余数: {}", remainder); + + // 更新消息的计费信息到备注 + updateMessageBilling(chatRequest, billable, numberCost.doubleValue(), chatModelVo.getModelType()); + } catch (ServiceException e) { + // 余额不足时,不更新token累计,保持原有累计数 + log.warn("deductToken->余额不足,本次token累计保持不变: {}", totalTokens); + throw e; // 重新抛出异常 } - - - } else { - //不满100Token,不需要进行扣费啊啊啊 - //deductUserBalance(chatMessageBo.getUserId(), 0.0); - chatMessageBo.setDeductCost(0d); - chatMessageBo.setRemark("不满100Token,计入下一次!"); - System.out.println("deductToken->不满100Token,计入下一次!"); + // 未达阈值,累积token + log.debug("deductToken->未达到计费阈值({}),累积到下次", threshold); + chatToken.setModelName(modelName); + chatToken.setUserId(chatRequest.getUserId()); chatToken.setToken(totalTokens); - chatToken.setModelName(chatMessageBo.getModelName()); - chatToken.setUserId(chatMessageBo.getUserId()); chatTokenService.editToken(chatToken); + + // 虽未扣费,但要更新消息的基本信息(实际token数、计费类型等) + updateMessageWithoutBilling(chatRequest, tokens, chatModelVo.getModelType()); + } + } + + /** + * 保存聊天消息记录(不进行计费) + * 保存成功后将消息ID设置到ChatRequest中,供后续扣费使用 + */ + @Override + public void saveMessage(ChatRequest chatRequest) { + if (chatRequest.getUserId() == null || chatRequest.getSessionId() == null) { + log.warn("saveMessage->用户ID或会话ID为空,跳过保存消息"); + return; } + // 验证消息内容 + if (chatRequest.getPrompt() == null || chatRequest.getPrompt().trim().isEmpty()) { + log.warn("saveMessage->消息内容为空,跳过保存"); + return; + } + ChatMessageBo chatMessageBo = new ChatMessageBo(); + chatMessageBo.setUserId(chatRequest.getUserId()); + chatMessageBo.setSessionId(chatRequest.getSessionId()); + chatMessageBo.setRole(chatRequest.getRole()); + chatMessageBo.setContent(chatRequest.getPrompt().trim()); + chatMessageBo.setModelName(chatRequest.getModel()); +// // 基础消息信息,计费相关数据(tokens、费用、计费类型等)在扣费时统一设置 +// chatMessageBo.setTotalTokens(0); // 初始设为0,扣费时更新 +// chatMessageBo.setDeductCost(null); +// chatMessageBo.setBillingType(null); +// chatMessageBo.setRemark("用户消息"); - // 保存消息记录 - chatMessageService.insertByBo(chatMessageBo); - - System.out.println("deductToken->chatMessageService.insertByBo(: "+chatMessageBo); - System.out.println("----------------------------------------"); + try { + chatMessageService.insertByBo(chatMessageBo); + // 保存成功后,将生成的消息ID设置到ChatRequest中 + chatRequest.setMessageId(chatMessageBo.getId()); + log.debug("saveMessage->成功保存消息,消息ID: {}, 用户ID: {}, 会话ID: {}", + chatMessageBo.getId(), chatRequest.getUserId(), chatRequest.getSessionId()); + } catch (Exception e) { + log.error("saveMessage->保存消息失败", e); + throw new ServiceException("保存消息失败"); + } } + + + @Override + public void publishBillingEvent(ChatRequest chatRequest) { + log.debug("publishBillingEvent->发布计费事件,用户ID: {},会话ID: {},模型: {}", + chatRequest.getUserId(), chatRequest.getSessionId(), chatRequest.getModel()); + + // 预检查:评估可能的扣费金额,如果余额不足则直接抛异常 + try { + preCheckBalance(chatRequest); + } catch (ServiceException e) { + log.warn("publishBillingEvent->预检查余额不足,用户ID: {},模型: {}", + chatRequest.getUserId(), chatRequest.getModel()); + throw e; // 直接抛出,阻止消息保存和对话继续 + } + + eventPublisher.publishEvent(new ChatMessageCreatedEvent( + chatRequest.getUserId(), + chatRequest.getSessionId(), + chatRequest.getModel(), + chatRequest.getRole(), + chatRequest.getPrompt(), + chatRequest.getMessageId() + )); + log.debug("publishBillingEvent->计费事件发布完成"); + } + + /** + * 预检查用户余额是否足够支付可能的费用 + */ + private void preCheckBalance(ChatRequest chatRequest) { + if (chatRequest.getUserId() == null) { + return; + } + + int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt()); + String modelName = chatRequest.getModel(); + ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName); + BigDecimal unitPrice = BigDecimal.valueOf(chatModelVo.getModelPrice()); + + // 按次计费:直接检查单次费用 + if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) { + BigDecimal numberCost = unitPrice.setScale(2, RoundingMode.HALF_UP); + checkUserBalanceWithoutDeduct(chatRequest.getUserId(), numberCost.doubleValue()); + return; + } + + // 按token计费:检查累计后可能的费用 + final int threshold = 1000; + ChatUsageToken chatToken = chatTokenService.queryByUserId(chatRequest.getUserId(), modelName); + int previousUnpaid = (chatToken == null) ? 0 : chatToken.getToken(); + int totalTokens = previousUnpaid + tokens; + + int billable = (totalTokens / threshold) * threshold; + if (billable > 0) { + // 计算批次数:每1000个Token为一批,每批扣费单价 + int batches = billable / threshold; + BigDecimal numberCost = unitPrice + .multiply(BigDecimal.valueOf(batches)) + .setScale(2, RoundingMode.HALF_UP); + checkUserBalanceWithoutDeduct(chatRequest.getUserId(), numberCost.doubleValue()); + } + } + + /** + * 检查用户余额是否足够,但不扣除 + */ + private void checkUserBalanceWithoutDeduct(Long userId, Double numberCost) { + SysUser sysUser = sysUserMapper.selectById(userId); + if (sysUser == null) { + throw new ServiceException("用户不存在"); + } + + BigDecimal userBalance = BigDecimal.valueOf(sysUser.getUserBalance() == null ? 0D : sysUser.getUserBalance()) + .setScale(2, RoundingMode.HALF_UP); + BigDecimal cost = BigDecimal.valueOf(numberCost == null ? 0D : numberCost) + .setScale(2, RoundingMode.HALF_UP); + + if (userBalance.compareTo(cost) < 0 || userBalance.compareTo(BigDecimal.ZERO) == 0) { + throw new ServiceException("余额不足, 请充值。当前余额: " + userBalance + ",需要: " + cost); + } + } + + /** + * 更新消息的基本信息(不涉及扣费) + */ + private void updateMessageWithoutBilling(ChatRequest chatRequest, int actualTokens, String billingTypeCode) { + // 检查是否有消息ID可以更新 + if (chatRequest.getMessageId() == null) { + log.warn("updateMessageWithoutBilling->消息ID为空,无法更新基本信息"); + return; + } + + try { + // 创建更新对象,只更新基本信息,不涉及扣费 + ChatMessageBo updateMessage = new ChatMessageBo(); + updateMessage.setId(chatRequest.getMessageId()); + updateMessage.setTotalTokens(actualTokens); // 设置实际token数 + updateMessage.setBillingType(billingTypeCode); // 设置计费类型 + updateMessage.setRemark("用户消息(累计中,未达扣费阈值)"); // 说明状态 + + // 更新消息 + chatMessageService.updateByBo(updateMessage); + log.debug("updateMessageWithoutBilling->更新消息基本信息成功,消息ID: {}, 实际tokens: {}, 计费类型: {}", + chatRequest.getMessageId(), actualTokens, billingTypeCode); + } catch (Exception e) { + log.error("updateMessageWithoutBilling->更新消息基本信息失败,消息ID: {}", chatRequest.getMessageId(), e); + // 更新失败不影响主流程,只记录错误日志 + } + } + + /** + * 更新消息的计费信息到备注字段 + */ + private void updateMessageBilling(ChatRequest chatRequest, int billedTokens, double cost, String billingTypeCode) { + // 检查是否有消息ID可以更新 + if (chatRequest.getMessageId() == null) { + log.warn("updateMessageBilling->消息ID为空,无法更新计费信息"); + return; + } + + try { + // 计算本次消息的实际token数 + int actualTokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt()); + + // 构建计费信息 + String billingInfo = buildBillingInfo(billingTypeCode, billedTokens, cost); + + // 创建更新对象 + ChatMessageBo updateMessage = new ChatMessageBo(); + updateMessage.setId(chatRequest.getMessageId()); + updateMessage.setTotalTokens(actualTokens); // 设置实际token数 + updateMessage.setDeductCost(cost); + updateMessage.setRemark(billingInfo); + updateMessage.setBillingType(billingTypeCode); + + // 更新消息 + chatMessageService.updateByBo(updateMessage); + log.debug("updateMessageBilling->更新消息计费信息成功,消息ID: {}, 实际tokens: {}, 计费tokens: {}, 费用: {}", + chatRequest.getMessageId(), actualTokens, billedTokens, cost); + } catch (Exception e) { + log.error("updateMessageBilling->更新消息计费信息失败,消息ID: {}", chatRequest.getMessageId(), e); + // 更新失败不影响主流程,只记录错误日志 + } + } + + /** + * 构建计费信息字符串 + */ + private String buildBillingInfo(String billingTypeCode, int billedTokens, double cost) { + // 使用枚举获取计费类型并构建计费信息 + BillingType billingType = BillingType.fromCode(billingTypeCode); + if (billingType != null) { + return switch (billingType) { + case TIMES -> String.format("%s:消耗 %d tokens,扣费 %.2f 元", billingType.getDescription(), billedTokens, cost); + case TOKEN -> String.format("%s:结算 %d tokens,扣费 %.2f 元", billingType.getDescription(), billedTokens, cost); + }; + } else { + return String.format("系统计费:处理 %d tokens,扣费 %.2f 元", billedTokens, cost); + } + } + + + /** * 从用户余额中扣除费用 * @@ -158,22 +363,26 @@ public class ChatCostServiceImpl implements IChatCostService { return; } - Double userBalance = sysUser.getUserBalance(); + BigDecimal userBalance = BigDecimal.valueOf(sysUser.getUserBalance() == null ? 0D : sysUser.getUserBalance()) + .setScale(2, RoundingMode.HALF_UP); + BigDecimal cost = BigDecimal.valueOf(numberCost == null ? 0D : numberCost) + .setScale(2, RoundingMode.HALF_UP); + log.debug("deductUserBalance->准备扣除: {},当前余额: {}", cost, userBalance); - System.out.println("deductUserBalance->准备扣除:numberCost: "+numberCost); - System.out.println("deductUserBalance->剩余金额:userBalance: "+userBalance); - - - if (userBalance < numberCost || userBalance == 0) { + if (userBalance.compareTo(cost) < 0 || userBalance.compareTo(BigDecimal.ZERO) == 0) { throw new ServiceException("余额不足, 请充值"); } - + BigDecimal newBalance = userBalance.subtract(cost); + if (newBalance.compareTo(BigDecimal.ZERO) < 0) { + newBalance = BigDecimal.ZERO; + } + newBalance = newBalance.setScale(2, RoundingMode.HALF_UP); sysUserMapper.update(null, new LambdaUpdateWrapper() - .set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0)) + .set(SysUser::getUserBalance, newBalance.doubleValue()) .eq(SysUser::getUserId, userId)); } @@ -193,6 +402,7 @@ public class ChatCostServiceImpl implements IChatCostService { chatMessageBo.setContent(prompt); chatMessageBo.setDeductCost(cost); chatMessageBo.setTotalTokens(0); + chatMessageBo.setRemark(String.format("任务计费:%s,扣费 %.2f 元", type, cost)); chatMessageService.insertByBo(chatMessageBo); } @@ -218,4 +428,29 @@ public class ChatCostServiceImpl implements IChatCostService { } return loginUser.getUserId(); } + + /** + * 检查用户余额是否足够支付预估费用 + */ + @Override + public boolean checkBalanceSufficient(ChatRequest chatRequest) { + if (chatRequest.getUserId() == null) { + log.warn("checkBalanceSufficient->用户ID为空,视为余额不足"); + return false; + } + + try { + // 重用现有的预检查逻辑,但不抛异常,只返回boolean + preCheckBalance(chatRequest); + return true; // 预检查通过,余额充足 + } catch (ServiceException e) { + log.debug("checkBalanceSufficient->余额不足,用户ID: {}, 模型: {}, 错误: {}", + chatRequest.getUserId(), chatRequest.getModel(), e.getMessage()); + return false; // 预检查失败,余额不足 + } catch (Exception e) { + log.error("checkBalanceSufficient->检查余额时发生异常,用户ID: {}, 模型: {}", + chatRequest.getUserId(), chatRequest.getModel(), e); + return false; // 异常情况视为余额不足,保守处理 + } + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/DifyServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/DifyServiceImpl.java index 51f9a960..43608516 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/DifyServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/DifyServiceImpl.java @@ -91,6 +91,8 @@ public class DifyServiceImpl implements IChatService { public void onMessageEnd(MessageEndEvent event) { emitter.complete(); log.info("消息结束,完整消息ID: {}", event.getMessageId()); + // 扣除费用 + ChatRequest chatRequestResponse = new ChatRequest(); // 更新conversationId if (StrUtil.isBlank(sessionInfo.getConversationId())) { String conversationId = event.getConversationId(); @@ -104,16 +106,16 @@ public class DifyServiceImpl implements IChatService { chatSessionBo.setSessionContent(sessionInfo.getSessionContent()); chatSessionBo.setRemark(sessionInfo.getRemark()); chatSessionService.updateByBo(chatSessionBo); + chatRequestResponse.setMessageId(chatSessionBo.getId()); } - // 扣除费用 - ChatRequest chatRequestResponse = new ChatRequest(); + // 设置对话角色 - chatRequestResponse.setRole(Message.Role.ASSISTANT.getName()); - chatRequestResponse.setModel(chatRequest.getModel()); - chatRequestResponse.setUserId(chatRequest.getUserId()); - chatRequestResponse.setSessionId(chatRequest.getSessionId()); - chatRequestResponse.setPrompt(respMessage.toString()); - chatCostService.deductToken(chatRequestResponse); +// chatRequestResponse.setRole(Message.Role.ASSISTANT.getName()); +// chatRequestResponse.setModel(chatRequest.getModel()); +// chatRequestResponse.setUserId(chatRequest.getUserId()); +// chatRequestResponse.setSessionId(chatRequest.getSessionId()); +// chatRequestResponse.setPrompt(respMessage.toString()); +// chatCostService.deductToken(chatRequestResponse); RetryNotifier.clear(emitter); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java index a3892eeb..0250177a 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java @@ -116,8 +116,7 @@ public class SseServiceImpl implements ISseService { } - // 保存消息记录 并扣除费用 - chatCostService.deductToken(chatRequest); + chatRequest.setUserId(chatCostService.getUserId()); if (chatRequest.getSessionId() == null) { ChatSessionBo chatSessionBo = new ChatSessionBo(); @@ -127,11 +126,15 @@ public class SseServiceImpl implements ISseService { chatSessionService.insertByBo(chatSessionBo); chatRequest.setSessionId(chatSessionBo.getId()); } + + // 保存用户消息 + chatCostService.saveMessage(chatRequest); } // 自动选择模型并获取对应的聊天服务 IChatService chatService = autoSelectModelAndGetService(chatRequest); - // 仅当 autoSelectModel = true 时,才启用重试与降级 + // 用户消息只保存不计费,AI回复由BillingChatServiceProxy自动处理计费 + // chatCostService.publishBillingEvent(chatRequest); // 用户输入不计费 if (Boolean.TRUE.equals(chatRequest.getAutoSelectModel())) { ChatModelVo currentModel = this.chatModelVo; String currentCategory = currentModel.getCategory(); diff --git a/script/sql/update/chat-message-billing-type.sql b/script/sql/update/chat-message-billing-type.sql new file mode 100644 index 00000000..364645a8 --- /dev/null +++ b/script/sql/update/chat-message-billing-type.sql @@ -0,0 +1,4 @@ +-- 为 chat_message 表添加 billing_type 字段 +ALTER TABLE chat_message + ADD COLUMN billing_type char NULL COMMENT '计费类型(1-token计费,2-次数计费,null-普通消息)'; +