2 Commits

Author SHA1 Message Date
Administrator
affdc5e3a6 问题概述
1.保存消息和计费逻辑存在耦合
2.修改计费逻辑:
按次计费被阈值限制:旧逻辑把 TIMES 分支放在 totalTokens ≥ 100 的大分支里,导致没到100 token时不扣费,违背“每次调用就扣费”的语义。
token累计不当:TIMES 分支只扣费不处理累计,同时在 totalTokens < 100 时不会进入任何TIMES逻辑,累计会无意义增长。
粒度不稳定:TOKEN 计费一旦达阈值就把 total 全扣完并清零,不利于对账与用户体验。
打印方式:使用 System.out.println,不利于生产追踪。

改动要点
1.新增独立方法
saveMessage(ChatRequest): 只落库。
publishBillingEvent(ChatRequest): 只发布异步计费事件。
保留组合方法 saveMessageAndPublishEvent(ChatRequest) 以便需要一行调用时使用。
调用处已改为“先保存,再发布事件”
SseServiceImpl: 先 saveMessage,再 publishBillingEvent。
SSEEventSourceListener: 同上。
DifyServiceImpl: 同上。

2.计费模式分流:
TIMES:每次调用直接扣费,不累计。
TOKEN:按阈值(100)批量扣费,保留余数,账单颗粒稳定。
保留余数:total = prev + delta;billable = floor(total/threshold)threshold;remainder = total % threshold。
日志替换:统一使用 log.debug。
结构更清晰、可维护。
所有金额计算统一用 BigDecimal,保留两位小数,RoundingMode.HALF_UP
按次计费:每次直接扣费(BigDecimal),边界转 Double
按 token 计费:按阈值批量结算,保留余数;费用=单价(BigDecimal)×可结算token数
1. 消息分类存储
用户消息:role="user", deductCost=null, totalTokens=本次token数, remark="用户消息"
系统账单:role="system", deductCost=实际扣费, totalTokens=计费token数, remark="TIMES_BILLING/TOKEN_BILLING"
2. 数据流程
用户发送消息 → 预检查余额 → 保存用户消息 → 发布计费事件 → 异步扣费 → 保存账单记录
2025-08-14 14:00:48 +08:00
Administrator
5a2e08f87d 问题概述
1.保存消息和计费逻辑存在耦合
2.修改计费逻辑:
按次计费被阈值限制:旧逻辑把 TIMES 分支放在 totalTokens ≥ 100 的大分支里,导致没到100 token时不扣费,违背“每次调用就扣费”的语义。
token累计不当:TIMES 分支只扣费不处理累计,同时在 totalTokens < 100 时不会进入任何TIMES逻辑,累计会无意义增长。
粒度不稳定:TOKEN 计费一旦达阈值就把 total 全扣完并清零,不利于对账与用户体验。
打印方式:使用 System.out.println,不利于生产追踪。
3.建议数据库不要存扣除金额和累计消耗token,消息表里不需要存“累计到目前为止多少”,否则每条消息都变成快照,既冗余又易不一致

改动要点
1.新增独立方法
saveMessage(ChatRequest): 只落库。
publishBillingEvent(ChatRequest): 只发布异步计费事件。
保留组合方法 saveMessageAndPublishEvent(ChatRequest) 以便需要一行调用时使用。
调用处已改为“先保存,再发布事件”
SseServiceImpl: 先 saveMessage,再 publishBillingEvent。
SSEEventSourceListener: 同上。
DifyServiceImpl: 同上。

2.计费模式分流:
TIMES:每次调用直接扣费,不累计。
TOKEN:按阈值(100)批量扣费,保留余数,账单颗粒稳定。
保留余数:total = prev + delta;billable = floor(total/threshold)threshold;remainder = total % threshold。
日志替换:统一使用 log.debug。
结构更清晰、可维护。
所有金额计算统一用 BigDecimal,保留两位小数,RoundingMode.HALF_UP
按次计费:每次直接扣费(BigDecimal),边界转 Double
按 token 计费:按阈值批量结算,保留余数;费用=单价(BigDecimal)×可结算token数
2025-08-08 13:39:37 +08:00
59 changed files with 1010 additions and 1284 deletions

View File

@@ -99,10 +99,6 @@
> 💡 **小贴士**:建议将 PR 提交到 GitHub我们会自动同步到其他代码托管平台
<a href="https://openomy.com/ageerle/ruoyi-ai" target="_blank" style="display: block; width: 100%;" align="center">
<img src="https://openomy.com/svg?repo=ageerle/ruoyi-ai&chart=bubble&latestMonth=3" target="_blank" alt="Contribution Leaderboard" style="display: block; width: 100%;" />
</a>
## 📄 开源协议
本项目采用 **MIT 开源协议**,详情请查看 [LICENSE](LICENSE) 文件。

View File

@@ -67,19 +67,4 @@ public class ChatRequest {
*/
private Long uuid;
/**
* 是否有附件
*/
private Boolean hasAttachment;
/**
* 是否自动切换模型
*/
private Boolean autoSelectModel;
/**
* 会话令牌为避免在非Web线程中获取Request入口处注入
*/
private String token;
}

View File

@@ -9,7 +9,6 @@ import org.ruoyi.core.domain.BaseEntity;
import java.io.Serial;
/**
* 聊天模型对象 chat_model
*
@@ -76,11 +75,6 @@ public class ChatModel extends BaseEntity {
*/
private String apiKey;
/**
* 优先级
*/
private Integer priority;
/**
* 备注
*/

View File

@@ -74,11 +74,6 @@ public class ChatModelBo extends BaseEntity {
@NotBlank(message = "请求地址不能为空", groups = { AddGroup.class, EditGroup.class })
private String apiHost;
/**
* 优先级
*/
private Integer priority;
/**
* 密钥
*/

View File

@@ -14,6 +14,7 @@ import java.io.Serializable;
/**
* 聊天模型视图对象 chat_model
*
@@ -89,17 +90,10 @@ public class ChatModelVo implements Serializable {
@ExcelProperty(value = "密钥")
private String apiKey;
/**
* 优先级(值越大优先级越高)
*/
@ExcelProperty(value = "优先级")
private Integer priority;
/**
* 备注
*/
@ExcelProperty(value = "备注")
private String remark;
}
}

View File

@@ -57,17 +57,6 @@ public interface IChatModelService {
* 通过模型分类获取模型信息
*/
ChatModelVo selectModelByCategory(String image);
/**
* 通过模型分类获取优先级最高的模型信息
*/
ChatModelVo selectModelByCategoryWithHighestPriority(String category);
/**
* 在同一分类下,查找优先级小于当前优先级的最高优先级模型(用于降级)。
*/
ChatModelVo selectFallbackModelByCategoryAndLessPriority(String category, Integer currentPriority);
/**
* 获取ppt模型信息
*/

View File

@@ -46,11 +46,4 @@ public interface IPromptTemplateService {
* 校验并批量删除提示词模板信息
*/
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
/**
* 根据分类查询提示词模板
*
* @param category 分类
*/
PromptTemplateVo queryByCategory(String category);
}

View File

@@ -136,33 +136,6 @@ public class ChatModelServiceImpl implements IChatModelService {
public ChatModelVo selectModelByCategory(String category) {
return baseMapper.selectVoOne(Wrappers.<ChatModel>lambdaQuery().eq(ChatModel::getCategory, category));
}
/**
* 通过模型分类获取优先级最高的模型信息
*/
@Override
public ChatModelVo selectModelByCategoryWithHighestPriority(String category) {
return baseMapper.selectVoOne(
Wrappers.<ChatModel>lambdaQuery()
.eq(ChatModel::getCategory, category)
.orderByDesc(ChatModel::getPriority)
.last("LIMIT 1")
);
}
/**
* 在同一分类下,查找优先级小于当前优先级的最高优先级模型(用于降级)。
*/
@Override
public ChatModelVo selectFallbackModelByCategoryAndLessPriority(String category, Integer currentPriority) {
return baseMapper.selectVoOne(
Wrappers.<ChatModel>lambdaQuery()
.eq(ChatModel::getCategory, category)
.lt(ChatModel::getPriority, currentPriority)
.orderByDesc(ChatModel::getPriority)
.last("LIMIT 1")
);
}
@Override
public ChatModel getPPT() {

View File

@@ -109,13 +109,4 @@ public class PromptTemplateServiceImpl implements IPromptTemplateService {
}
return baseMapper.deleteBatchIds(ids) > 0;
}
@Override
public PromptTemplateVo queryByCategory(String category) {
LambdaQueryWrapper<PromptTemplate> queryWrapper = Wrappers.lambdaQuery(PromptTemplate.class);
queryWrapper.eq(PromptTemplate::getCategory, category);
queryWrapper.orderByDesc(PromptTemplate::getUpdateTime);
queryWrapper.last("limit 1");
return baseMapper.selectVoOne(queryWrapper);
}
}

View File

@@ -31,7 +31,7 @@ public interface IKnowledgeInfoService {
/**
* 查询知识库列表
*/
TableDataInfo<KnowledgeInfoVo> queryPageListByRole(KnowledgeInfoBo bo, PageQuery pageQuery);
TableDataInfo<KnowledgeInfoVo> queryPageListByRole(PageQuery pageQuery);
/**
* 查询知识库列表

View File

@@ -42,7 +42,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
private final ConfigService configService;
// private EmbeddingStore<TextSegment> embeddingStore;
private EmbeddingStore<TextSegment> embeddingStore;
private WeaviateClient client;
@@ -82,14 +82,14 @@ public class VectorStoreServiceImpl implements VectorStoreService {
log.info("Schema 创建成功: {}", className);
}
}
// embeddingStore = WeaviateEmbeddingStore.builder()
// .scheme(protocol)
// .host(host)
// .objectClass(className)
// .scheme(protocol)
// .avoidDups(true)
// .consistencyLevel("ALL")
// .build();
embeddingStore = WeaviateEmbeddingStore.builder()
.scheme(protocol)
.host(host)
.objectClass(className)
.scheme(protocol)
.avoidDups(true)
.consistencyLevel("ALL")
.build();
}
@Override
@@ -148,7 +148,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
String graphQLQuery = String.format(
"{\n" +
" Get {\n" +
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
" %s(nearVector: {vector: [%s], certainty: %f} limit: %d) {\n" +
" text\n" +
" fid\n" +
" kid\n" +

View File

@@ -16,15 +16,6 @@ import java.util.List;
public interface ISysDictTypeService {
/**
* Select all dictionary types based on the specified conditions
*
* @param dictType The business object containing query conditions for dictionary types
* @return TableDataInfo containing a list of SysDictTypeVo objects that match the query criteria
*/
TableDataInfo<SysDictTypeVo> selectAll(SysDictTypeBo dictType);
TableDataInfo<SysDictTypeVo> selectPageDictTypeList(SysDictTypeBo dictType, PageQuery pageQuery);
/**

View File

@@ -10,7 +10,6 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.RequiredArgsConstructor;
import org.ruoyi.common.core.constant.CacheConstants;
import org.ruoyi.common.core.constant.CacheNames;
import org.ruoyi.common.core.constant.HttpStatus;
import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.common.core.service.DictService;
import org.ruoyi.common.core.utils.MapstructUtils;
@@ -51,18 +50,6 @@ public class SysDictTypeServiceImpl implements ISysDictTypeService, DictService
private final SysDictTypeMapper baseMapper;
private final SysDictDataMapper dictDataMapper;
@Override
public TableDataInfo<SysDictTypeVo> selectAll(SysDictTypeBo dictType) {
LambdaQueryWrapper<SysDictType> lqw = buildQueryWrapper(dictType);
// 2. 查询所有数据(不分页)
List<SysDictTypeVo> list = baseMapper.selectVoList(lqw);
TableDataInfo<SysDictTypeVo> rspData = new TableDataInfo<>();
rspData.setCode(HttpStatus.SUCCESS); // 200
rspData.setMsg("查询成功");
rspData.setRows(list);
rspData.setTotal(list.size()); // 总数为列表大小
return rspData;
}
@Override
public TableDataInfo<SysDictTypeVo> selectPageDictTypeList(SysDictTypeBo dictType, PageQuery pageQuery) {
LambdaQueryWrapper<SysDictType> lqw = buildQueryWrapper(dictType);

View File

@@ -7,6 +7,7 @@ import jakarta.validation.constraints.NotNull;
import lombok.RequiredArgsConstructor;
import org.ruoyi.chat.config.KnowledgeRoleConfig;
import org.ruoyi.common.core.domain.R;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.validate.AddGroup;
import org.ruoyi.common.excel.utils.ExcelUtil;
import org.ruoyi.common.log.annotation.Log;
@@ -30,7 +31,6 @@ import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
import java.util.Objects;
/**
* 知识库管理
@@ -60,9 +60,7 @@ public class KnowledgeController extends BaseController {
if (!StpUtil.isLogin()) {
throw new SecurityException("请先去登录!");
}
if (!Objects.equals(LoginHelper.getUserId(), 1L)) {
bo.setUid(LoginHelper.getUserId());
}
bo.setUid(LoginHelper.getUserId());
return knowledgeInfoService.queryPageList(bo, pageQuery);
}
@@ -74,16 +72,14 @@ public class KnowledgeController extends BaseController {
if (!StpUtil.isLogin()) {
throw new SecurityException("请先去登录!");
}
LoginUser loginUser = LoginHelper.getLoginUser();
// 管理员跳过权限
if (Objects.equals(LoginHelper.getUserId(), 1L)) {
return knowledgeInfoService.queryPageList(bo, pageQuery);
} else if (!knowledgeRoleConfig.getEnable()) {
if (loginUser.getUserId().equals(1L) || !knowledgeRoleConfig.getEnable()) {
bo.setUid(LoginHelper.getUserId());
return knowledgeInfoService.queryPageList(bo, pageQuery);
} else {
bo.setUid(LoginHelper.getUserId());
return knowledgeInfoService.queryPageListByRole(bo, pageQuery);
return knowledgeInfoService.queryPageListByRole(pageQuery);
}
}

View File

@@ -1,22 +0,0 @@
package org.ruoyi.chat.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 提示词模板分类
*
* @author evo
*/
@Getter
@AllArgsConstructor
public enum promptTemplateEnum {
CHAT(1, "chat"),
VECTOR(2, "vector"),
;
private final Integer code;
private final String desc;
}

View File

@@ -0,0 +1,31 @@
package org.ruoyi.chat.event;
import org.springframework.context.ApplicationEvent;
/**
* 聊天消息创建事件(用于异步计费/累计等)
*/
public class ChatMessageCreatedEvent extends ApplicationEvent {
private final Long userId;
private final Long sessionId;
private final String modelName;
private final String role;
private final String content;
public ChatMessageCreatedEvent(Long userId, Long sessionId, String modelName, String role, String content) {
super(userId);
this.userId = userId;
this.sessionId = sessionId;
this.modelName = modelName;
this.role = role;
this.content = content;
}
public Long getUserId() { return userId; }
public Long getSessionId() { return sessionId; }
public String getModelName() { return modelName; }
public String getRole() { return role; }
public String getContent() { return content; }
}

View File

@@ -0,0 +1,48 @@
package org.ruoyi.chat.listener;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.event.ChatMessageCreatedEvent;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.chat.request.ChatRequest;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.context.event.EventListener;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
@Slf4j
@Component
@RequiredArgsConstructor
public class BillingEventListener {
private final IChatCostService chatCostService;
@Async
@EventListener
public void onChatMessageCreated(ChatMessageCreatedEvent event) {
log.debug("BillingEventListener->接收到计费事件用户ID: {}会话ID: {},模型: {}",
event.getUserId(), event.getSessionId(), event.getModelName());
try {
ChatRequest chatRequest = new ChatRequest();
chatRequest.setUserId(event.getUserId());
chatRequest.setSessionId(event.getSessionId());
chatRequest.setModel(event.getModelName());
chatRequest.setRole(event.getRole());
chatRequest.setPrompt(event.getContent());
// 异步执行计费累计与扣费
log.debug("BillingEventListener->开始执行计费逻辑");
chatCostService.deductToken(chatRequest);
log.debug("BillingEventListener->计费逻辑执行完成");
} catch (Exception ex) {
// 由于已有预检查,这里的异常主要是系统异常(数据库连接等)
// 记录错误但不中断异步线程
log.error("BillingEventListener->异步计费异常用户ID: {},模型: {},错误: {}",
event.getUserId(), event.getModelName(), ex.getMessage(), ex);
// TODO: 可以考虑加入重试机制或者错误通知机制
// 例如:发送到死信队列,或者通知运维人员
}
}
}

View File

@@ -14,8 +14,6 @@ import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.util.SSEUtil;
@Slf4j
@Component
@@ -23,18 +21,12 @@ import org.ruoyi.chat.util.SSEUtil;
public class FastGPTSSEEventSourceListener extends EventSourceListener {
private SseEmitter emitter;
private Long sessionId;
@Autowired(required = false)
public FastGPTSSEEventSourceListener(SseEmitter emitter) {
this.emitter = emitter;
}
public FastGPTSSEEventSourceListener(SseEmitter emitter, Long sessionId) {
this.emitter = emitter;
this.sessionId = sessionId;
}
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("FastGPT sse连接成功");
@@ -48,7 +40,6 @@ public class FastGPTSSEEventSourceListener extends EventSourceListener {
if ("flowResponses".equals(type)){
emitter.send(data);
emitter.complete();
RetryNotifier.clear(emitter);
} else {
emitter.send(data);
}
@@ -66,20 +57,13 @@ public class FastGPTSSEEventSourceListener extends EventSourceListener {
@SneakyThrows
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (Objects.isNull(response)) {
SSEUtil.sendErrorEvent(emitter, t != null ? t.getMessage() : "SSE连接失败");
RetryNotifier.notifyFailure(emitter);
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
String msg = body.string();
log.error("FastGPT sse连接异常data{},异常:{}", msg, t);
SSEUtil.sendErrorEvent(emitter, msg);
RetryNotifier.notifyFailure(emitter);
log.error("FastGPT sse连接异常data{},异常:{}", body.string(), t);
} else {
log.error("FastGPT sse连接异常data{},异常:{}", response, t);
SSEUtil.sendErrorEvent(emitter, String.valueOf(response));
RetryNotifier.notifyFailure(emitter);
}
eventSource.cancel();
}

View File

@@ -21,8 +21,6 @@ import org.ruoyi.common.core.utils.SpringUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.chat.support.RetryNotifier;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
@@ -46,15 +44,12 @@ public class SSEEventSourceListener extends EventSourceListener {
private String token;
private boolean retryEnabled;
@Autowired(required = false)
public SSEEventSourceListener(SseEmitter emitter,Long userId,Long sessionId, String token, boolean retryEnabled) {
public SSEEventSourceListener(SseEmitter emitter,Long userId,Long sessionId, String token) {
this.emitter = emitter;
this.userId = userId;
this.sessionId = sessionId;
this.token = token;
this.retryEnabled = retryEnabled;
}
@@ -82,8 +77,6 @@ public class SSEEventSourceListener extends EventSourceListener {
if ("[DONE]".equals(data)) {
//成功响应
emitter.complete();
// 清理失败回调(以 emitter 为键)
RetryNotifier.clear(emitter);
// 扣除费用
ChatRequest chatRequest = new ChatRequest();
// 设置对话角色
@@ -94,7 +87,9 @@ public class SSEEventSourceListener extends EventSourceListener {
chatRequest.setPrompt(stringBuffer.toString());
// 记录会话token
BaseContext.setCurrentToken(token);
chatCostService.deductToken(chatRequest);
// 先保存助手消息,再发布异步计费事件
chatCostService.saveMessage(chatRequest);
chatCostService.publishBillingEvent(chatRequest);
return;
}
@@ -120,38 +115,19 @@ public class SSEEventSourceListener extends EventSourceListener {
@Override
public void onClosed(EventSource eventSource) {
log.info("OpenAI关闭sse连接...");
// 清理失败回调
RetryNotifier.clear(emitter);
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (Objects.isNull(response)) {
// 透传错误到前端
SSEUtil.sendErrorEvent(emitter, t != null ? t.getMessage() : "SSE连接失败");
if (retryEnabled) {
// 通知重试(以 emitter 为键)
RetryNotifier.notifyFailure(emitter);
} else {
emitter.complete();
}
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
String msg = body.string();
log.error("OpenAI sse连接异常data{},异常:{}", msg, t);
SSEUtil.sendErrorEvent(emitter, msg);
log.error("OpenAI sse连接异常data{},异常:{}", body.string(), t);
} else {
log.error("OpenAI sse连接异常data{},异常:{}", response, t);
SSEUtil.sendErrorEvent(emitter, String.valueOf(response));
}
if (retryEnabled) {
// 通知重试
RetryNotifier.notifyFailure(emitter);
} else {
emitter.complete();
}
eventSource.cancel();
}

View File

@@ -19,6 +19,22 @@ public interface IChatCostService {
void deductToken(ChatRequest chatRequest);
/**
* 保存聊天消息记录(不进行计费)
*
* @param chatRequest 对话信息
*/
void saveMessage(ChatRequest chatRequest);
/**
* 仅发布异步计费事件(不做入库)
*
* @param chatRequest 对话信息
*/
void publishBillingEvent(ChatRequest chatRequest);
/**
* 直接扣除用户的余额
*

View File

@@ -3,7 +3,10 @@ package org.ruoyi.chat.service.chat.impl;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.math.BigDecimal;
import java.math.RoundingMode;
import org.ruoyi.chat.enums.BillingType;
import org.ruoyi.chat.event.ChatMessageCreatedEvent;
import org.ruoyi.chat.enums.UserGradeType;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.chat.request.ChatRequest;
@@ -20,6 +23,7 @@ import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.IChatTokenService;
import org.ruoyi.system.domain.SysUser;
import org.ruoyi.system.mapper.SysUserMapper;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
@@ -42,106 +46,248 @@ public class ChatCostServiceImpl implements IChatCostService {
private final IChatModelService chatModelService;
private final ApplicationEventPublisher eventPublisher;
/**
* 扣除用户余额
* 扣除用户余额(仅计费与累计,不保存消息)
*/
@Override
public void deductToken(ChatRequest chatRequest) {
if(chatRequest.getUserId()==null || chatRequest.getSessionId()==null){
if (chatRequest.getUserId() == null) {
log.warn("deductToken->用户ID为空跳过计费");
return;
}
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
System.out.println("deductToken->本次提交token数 : "+tokens);
log.debug("deductToken->本次提交token数: {}", tokens);
String modelName = chatRequest.getModel();
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
BigDecimal unitPrice = BigDecimal.valueOf(chatModelVo.getModelPrice());
ChatMessageBo chatMessageBo = new ChatMessageBo();
// 按次计费每次调用都直接扣费不累计token
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
BigDecimal numberCost = unitPrice.setScale(2, RoundingMode.HALF_UP);
deductUserBalance(chatRequest.getUserId(), numberCost.doubleValue());
log.debug("deductToken->按次数扣费,费用: {},模型: {}", numberCost, modelName);
// 清理可能存在的历史累计token模型计费方式可能发生过变更
ChatUsageToken existingToken = chatTokenService.queryByUserId(chatRequest.getUserId(), modelName);
if (existingToken != null && existingToken.getToken() > 0) {
existingToken.setToken(0);
chatTokenService.editToken(existingToken);
log.debug("deductToken->按次计费清理历史累计token: {}", existingToken.getToken());
}
// 记录账单消息
saveBillingRecord(chatRequest, tokens, numberCost.doubleValue(), "TIMES_BILLING");
return;
}
// 设置用户id
chatMessageBo.setUserId(chatRequest.getUserId());
// 设置会话id
chatMessageBo.setSessionId(chatRequest.getSessionId());
// 设置对话角色
chatMessageBo.setRole(chatRequest.getRole());
// 设置对话内容
chatMessageBo.setContent(chatRequest.getPrompt());
// 设置模型名字
chatMessageBo.setModelName(chatRequest.getModel());
// 按token计费累加并按阈值批量扣费保留余数
final int threshold = 100;
// 获得记录的累计token数
ChatUsageToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), modelName);
// TODO: 这里存在并发竞态条件需要在chatTokenService层面添加乐观锁或分布式锁
ChatUsageToken chatToken = chatTokenService.queryByUserId(chatRequest.getUserId(), modelName);
if (chatToken == null) {
chatToken = new ChatUsageToken();
chatToken.setToken(0);
chatToken.setModelName(modelName);
chatToken.setUserId(chatRequest.getUserId());
}
// 计算总token数
int totalTokens = chatToken.getToken() + tokens;
int previousUnpaid = chatToken.getToken();
int totalTokens = previousUnpaid + tokens;
log.debug("deductToken->未付费token数: {},本次累计后总数: {}", previousUnpaid, totalTokens);
//当前未付费token
int token = chatToken.getToken();
int billable = (totalTokens / threshold) * threshold; // 可计费整批token
int remainder = totalTokens - billable; // 结算后保留的余数
System.out.println("deductToken->未付费的token数 : "+token);
System.out.println("deductToken->本次提交+未付费token数 : "+totalTokens);
//扣费核心逻辑总token大于100就要对未结清的token进行扣费
if (totalTokens >= 100) {// 如果总token数大于等于100,进行费用扣除
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
double cost = chatModelVo.getModelPrice();
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
// 按次数扣费
deductUserBalance(chatMessageBo.getUserId(), cost);
chatMessageBo.setDeductCost(cost);
}else {
// 按token扣费
Double numberCost = totalTokens * cost;
System.out.println("deductToken->按token扣费 计算token数量: "+totalTokens);
System.out.println("deductToken->按token扣费 每token的价格: "+cost);
deductUserBalance(chatMessageBo.getUserId(), numberCost);
chatMessageBo.setDeductCost(numberCost);
// 保存剩余tokens
if (billable > 0) {
BigDecimal numberCost = unitPrice
.multiply(BigDecimal.valueOf(billable))
.setScale(2, RoundingMode.HALF_UP);
log.debug("deductToken->按token扣费结算token数量: {},单价: {},费用: {}", billable, unitPrice, numberCost);
try {
// 先尝试扣费
deductUserBalance(chatRequest.getUserId(), numberCost.doubleValue());
// 扣费成功后,保存余数
chatToken.setModelName(modelName);
chatToken.setUserId(chatMessageBo.getUserId());
chatToken.setToken(0);//因为判断大于100token直接全部计算扣除了所以这里直接=0就可以了
chatToken.setUserId(chatRequest.getUserId());
chatToken.setToken(remainder);
chatTokenService.editToken(chatToken);
log.debug("deductToken->扣费成功,更新余数: {}", remainder);
// 记录账单消息
saveBillingRecord(chatRequest, billable, numberCost.doubleValue(), "TOKEN_BILLING");
} catch (ServiceException e) {
// 余额不足时不更新token累计保持原有累计数
log.warn("deductToken->余额不足本次token累计保持不变: {}", totalTokens);
throw e; // 重新抛出异常
}
} else {
//不满100Token,不需要进行扣费啊啊啊
//deductUserBalance(chatMessageBo.getUserId(), 0.0);
chatMessageBo.setDeductCost(0d);
chatMessageBo.setRemark("不满100Token,计入下一次!");
System.out.println("deductToken->不满100Token,计入下一次!");
// 未达阈值累积token
log.debug("deductToken->未达到计费阈值({}),累积到下次", threshold);
chatToken.setModelName(modelName);
chatToken.setUserId(chatRequest.getUserId());
chatToken.setToken(totalTokens);
chatToken.setModelName(chatMessageBo.getModelName());
chatToken.setUserId(chatMessageBo.getUserId());
chatTokenService.editToken(chatToken);
}
}
/**
* 保存聊天消息记录(不进行计费)
*/
@Override
public void saveMessage(ChatRequest chatRequest) {
if (chatRequest.getUserId() == null || chatRequest.getSessionId() == null) {
log.warn("saveMessage->用户ID或会话ID为空跳过保存消息");
return;
}
// 验证消息内容
if (chatRequest.getPrompt() == null || chatRequest.getPrompt().trim().isEmpty()) {
log.warn("saveMessage->消息内容为空,跳过保存");
return;
}
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(chatRequest.getUserId());
chatMessageBo.setSessionId(chatRequest.getSessionId());
chatMessageBo.setRole(chatRequest.getRole());
chatMessageBo.setContent(chatRequest.getPrompt().trim());
chatMessageBo.setModelName(chatRequest.getModel());
// 计算并保存本次消息的token数
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
chatMessageBo.setTotalTokens(tokens);
// 普通消息不涉及扣费deductCost保持null
chatMessageBo.setDeductCost(null);
chatMessageBo.setRemark("用户消息");
try {
chatMessageService.insertByBo(chatMessageBo);
log.debug("saveMessage->成功保存消息用户ID: {}, 会话ID: {}, tokens: {}",
chatRequest.getUserId(), chatRequest.getSessionId(), tokens);
} catch (Exception e) {
log.error("saveMessage->保存消息失败", e);
throw new ServiceException("保存消息失败");
}
}
// 保存消息记录
chatMessageService.insertByBo(chatMessageBo);
System.out.println("deductToken->chatMessageService.insertByBo(: "+chatMessageBo);
System.out.println("----------------------------------------");
@Override
public void publishBillingEvent(ChatRequest chatRequest) {
log.debug("publishBillingEvent->发布计费事件用户ID: {}会话ID: {},模型: {}",
chatRequest.getUserId(), chatRequest.getSessionId(), chatRequest.getModel());
// 预检查:评估可能的扣费金额,如果余额不足则直接抛异常
try {
preCheckBalance(chatRequest);
} catch (ServiceException e) {
log.warn("publishBillingEvent->预检查余额不足用户ID: {},模型: {}",
chatRequest.getUserId(), chatRequest.getModel());
throw e; // 直接抛出,阻止消息保存和对话继续
}
eventPublisher.publishEvent(new ChatMessageCreatedEvent(
chatRequest.getUserId(),
chatRequest.getSessionId(),
chatRequest.getModel(),
chatRequest.getRole(),
chatRequest.getPrompt()
));
log.debug("publishBillingEvent->计费事件发布完成");
}
/**
* 预检查用户余额是否足够支付可能的费用
*/
private void preCheckBalance(ChatRequest chatRequest) {
if (chatRequest.getUserId() == null) {
return;
}
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
String modelName = chatRequest.getModel();
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
BigDecimal unitPrice = BigDecimal.valueOf(chatModelVo.getModelPrice());
// 按次计费:直接检查单次费用
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
BigDecimal numberCost = unitPrice.setScale(2, RoundingMode.HALF_UP);
checkUserBalanceWithoutDeduct(chatRequest.getUserId(), numberCost.doubleValue());
return;
}
// 按token计费检查累计后可能的费用
final int threshold = 100;
ChatUsageToken chatToken = chatTokenService.queryByUserId(chatRequest.getUserId(), modelName);
int previousUnpaid = (chatToken == null) ? 0 : chatToken.getToken();
int totalTokens = previousUnpaid + tokens;
int billable = (totalTokens / threshold) * threshold;
if (billable > 0) {
BigDecimal numberCost = unitPrice
.multiply(BigDecimal.valueOf(billable))
.setScale(2, RoundingMode.HALF_UP);
checkUserBalanceWithoutDeduct(chatRequest.getUserId(), numberCost.doubleValue());
}
}
/**
* 检查用户余额是否足够,但不扣除
*/
private void checkUserBalanceWithoutDeduct(Long userId, Double numberCost) {
SysUser sysUser = sysUserMapper.selectById(userId);
if (sysUser == null) {
throw new ServiceException("用户不存在");
}
BigDecimal userBalance = BigDecimal.valueOf(sysUser.getUserBalance() == null ? 0D : sysUser.getUserBalance())
.setScale(2, RoundingMode.HALF_UP);
BigDecimal cost = BigDecimal.valueOf(numberCost == null ? 0D : numberCost)
.setScale(2, RoundingMode.HALF_UP);
if (userBalance.compareTo(cost) < 0 || userBalance.compareTo(BigDecimal.ZERO) == 0) {
throw new ServiceException("余额不足, 请充值。当前余额: " + userBalance + ",需要: " + cost);
}
}
/**
* 保存账单记录
*/
private void saveBillingRecord(ChatRequest chatRequest, int billedTokens, double cost, String billingType) {
try {
ChatMessageBo billingMessage = new ChatMessageBo();
billingMessage.setUserId(chatRequest.getUserId());
billingMessage.setSessionId(chatRequest.getSessionId());
billingMessage.setRole("system"); // 系统账单消息
billingMessage.setModelName(chatRequest.getModel());
billingMessage.setTotalTokens(billedTokens);
billingMessage.setDeductCost(cost);
billingMessage.setRemark(billingType);
// 构建账单消息内容
String content;
if ("TIMES_BILLING".equals(billingType)) {
content = String.format("按次计费:消耗 %d tokens扣费 %.2f 元", billedTokens, cost);
} else {
content = String.format("按量计费:结算 %d tokens扣费 %.2f 元", billedTokens, cost);
}
billingMessage.setContent(content);
chatMessageService.insertByBo(billingMessage);
log.debug("saveBillingRecord->保存账单记录成功用户ID: {}, 计费类型: {}, 费用: {}",
chatRequest.getUserId(), billingType, cost);
} catch (Exception e) {
log.error("saveBillingRecord->保存账单记录失败", e);
// 账单记录失败不影响主流程,只记录错误日志
}
}
/**
@@ -158,22 +304,26 @@ public class ChatCostServiceImpl implements IChatCostService {
return;
}
Double userBalance = sysUser.getUserBalance();
BigDecimal userBalance = BigDecimal.valueOf(sysUser.getUserBalance() == null ? 0D : sysUser.getUserBalance())
.setScale(2, RoundingMode.HALF_UP);
BigDecimal cost = BigDecimal.valueOf(numberCost == null ? 0D : numberCost)
.setScale(2, RoundingMode.HALF_UP);
log.debug("deductUserBalance->准备扣除: {},当前余额: {}", cost, userBalance);
System.out.println("deductUserBalance->准备扣除numberCost: "+numberCost);
System.out.println("deductUserBalance->剩余金额userBalance: "+userBalance);
if (userBalance < numberCost || userBalance == 0) {
if (userBalance.compareTo(cost) < 0 || userBalance.compareTo(BigDecimal.ZERO) == 0) {
throw new ServiceException("余额不足, 请充值");
}
BigDecimal newBalance = userBalance.subtract(cost);
if (newBalance.compareTo(BigDecimal.ZERO) < 0) {
newBalance = BigDecimal.ZERO;
}
newBalance = newBalance.setScale(2, RoundingMode.HALF_UP);
sysUserMapper.update(null,
new LambdaUpdateWrapper<SysUser>()
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
.set(SysUser::getUserBalance, newBalance.doubleValue())
.eq(SysUser::getUserId, userId));
}

View File

@@ -20,8 +20,6 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Collections;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
* 扣子聊天管理
@@ -55,25 +53,19 @@ public class CozeServiceImpl implements IChatService {
Flowable<ChatEvent> resp = coze.chat().stream(req);
ExecutorService executor = Executors.newFixedThreadPool(10);
executor.submit(() -> {
try {
resp.blockingForEach(
event -> {
if (ChatEventType.CONVERSATION_MESSAGE_DELTA.equals(event.getEvent())) {
emitter.send(event.getMessage().getContent());
log.info("coze: {}", event.getMessage().getContent());
}
if (ChatEventType.CONVERSATION_CHAT_COMPLETED.equals(event.getEvent())) {
emitter.complete();
log.info("Token usage: {}", event.getChat().getUsage().getTokenCount());
RetryNotifier.clear(emitter);
}
resp.blockingForEach(
event -> {
if (ChatEventType.CONVERSATION_MESSAGE_DELTA.equals(event.getEvent())) {
emitter.send(event.getMessage().getContent());
log.info("coze: {}", event.getMessage().getContent());
}
);
} catch (Exception ex) {
ChatServiceHelper.onStreamError(emitter, ex.getMessage());
} finally {
coze.shutdownExecutor();
}
if (ChatEventType.CONVERSATION_CHAT_COMPLETED.equals(event.getEvent())) {
emitter.complete();
log.info("Token usage: {}", event.getChat().getUsage().getTokenCount());
}
}
);
coze.shutdownExecutor();
});

View File

@@ -9,13 +9,13 @@ import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.support.ChatServiceHelper;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* deepseek
*/
@@ -57,14 +57,11 @@ public class DeepSeekChatImpl implements IChatService {
@Override
public void onError(Throwable error) {
System.err.println("错误: " + error.getMessage());
ChatServiceHelper.onStreamError(emitter, error.getMessage());
}
});
} catch (Exception e) {
log.error("deepseek请求失败{}", e.getMessage());
// 同步异常直接通知失败
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -25,10 +25,8 @@ import org.ruoyi.service.IChatSessionService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.ruoyi.chat.support.ChatServiceHelper;
import java.util.Objects;
import org.ruoyi.chat.support.RetryNotifier;
/**
* dify 聊天管理
@@ -113,25 +111,23 @@ public class DifyServiceImpl implements IChatService {
chatRequestResponse.setUserId(chatRequest.getUserId());
chatRequestResponse.setSessionId(chatRequest.getSessionId());
chatRequestResponse.setPrompt(respMessage.toString());
chatCostService.deductToken(chatRequestResponse);
RetryNotifier.clear(emitter);
// 先保存助手消息,再发布异步计费事件
chatCostService.saveMessage(chatRequestResponse);
chatCostService.publishBillingEvent(chatRequestResponse);
}
@Override
public void onError(ErrorEvent event) {
System.err.println("错误: " + event.getMessage());
ChatServiceHelper.onStreamError(emitter, event.getMessage());
}
@Override
public void onException(Throwable throwable) {
System.err.println("异常: " + throwable.getMessage());
ChatServiceHelper.onStreamError(emitter, throwable.getMessage());
}
});
} catch (Exception e) {
log.error("dify请求失败{}", e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -33,7 +33,7 @@ public class FastGPTServiceImpl implements IChatService {
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
OpenAiStreamClient openAiStreamClient = ChatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
List<Message> messages = chatRequest.getMessages();
FastGPTSSEEventSourceListener listener = new FastGPTSSEEventSourceListener(emitter, chatRequest.getSessionId());
FastGPTSSEEventSourceListener listener = new FastGPTSSEEventSourceListener(emitter);
FastGPTChatCompletion completion = FastGPTChatCompletion
.builder()
.messages(messages)
@@ -41,12 +41,7 @@ public class FastGPTServiceImpl implements IChatService {
.detail(true)
.stream(true)
.build();
try {
openAiStreamClient.streamChatCompletion(completion, listener);
} catch (Exception ex) {
org.ruoyi.chat.support.RetryNotifier.notifyFailure(chatRequest.getSessionId());
throw ex;
}
openAiStreamClient.streamChatCompletion(completion, listener);
return emitter;
}

View File

@@ -18,7 +18,6 @@ import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.*;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
* 图片识别模型
@@ -129,10 +128,10 @@ public class ImageServiceImpl implements IChatService {
OpenAiStreamClient openAiStreamClient = ChatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
List<Message> messages = chatRequest.getMessages();
// 获取会话token从入口透传避免非Web线程取值报错
String token = chatRequest.getToken();
// 获取会话token
String token = StpUtil.getTokenValue();
// 创建 SSE 事件源监听器
SSEEventSourceListener listener = ChatServiceHelper.createOpenAiListener(emitter, chatRequest);
SSEEventSourceListener listener = new SSEEventSourceListener(emitter, chatRequest.getUserId(), chatRequest.getSessionId(), token);
// 构建聊天完成请求
ChatCompletion completion = ChatCompletion
@@ -143,12 +142,7 @@ public class ImageServiceImpl implements IChatService {
.build();
// 发起流式聊天完成请求
try {
openAiStreamClient.streamChatCompletion(completion, listener);
} catch (Exception ex) {
ChatServiceHelper.onStreamError(emitter, ex.getMessage());
throw ex;
}
openAiStreamClient.streamChatCompletion(completion, listener);
return emitter;
}

View File

@@ -22,8 +22,6 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
@@ -67,14 +65,13 @@ public class OllamaServiceImpl implements IChatService {
try {
emitter.send(substr);
} catch (IOException e) {
ChatServiceHelper.onStreamError(emitter, e.getMessage());
SSEUtil.sendErrorEvent(emitter, e.getMessage());
}
};
api.chat(requestModel, streamHandler);
emitter.complete();
RetryNotifier.clear(emitter);
} catch (Exception e) {
ChatServiceHelper.onStreamError(emitter, e.getMessage());
SSEUtil.sendErrorEvent(emitter, e.getMessage());
}
});

View File

@@ -1,12 +1,12 @@
package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import io.modelcontextprotocol.client.McpSyncClient;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.config.ChatConfig;
import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.support.ChatServiceHelper;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
@@ -57,19 +57,15 @@ public class OpenAIServiceImpl implements IChatService {
Message userMessage = Message.builder().content("工具返回信息:"+toolString).role(Message.Role.USER).build();
messages.add(userMessage);
}
SSEEventSourceListener listener = ChatServiceHelper.createOpenAiListener(emitter, chatRequest);
String token = StpUtil.getTokenValue();
SSEEventSourceListener listener = new SSEEventSourceListener(emitter,chatRequest.getUserId(),chatRequest.getSessionId(), token);
ChatCompletion completion = ChatCompletion
.builder()
.messages(messages)
.model(chatRequest.getModel())
.stream(true)
.build();
try {
openAiStreamClient.streamChatCompletion(completion, listener);
} catch (Exception ex) {
ChatServiceHelper.onStreamError(emitter, ex.getMessage());
throw ex;
}
openAiStreamClient.streamChatCompletion(completion, listener);
return emitter;
}

View File

@@ -14,7 +14,6 @@ import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
@@ -52,18 +51,15 @@ public class QianWenAiChatServiceImpl implements IChatService {
public void onCompleteResponse(ChatResponse completeResponse) {
emitter.complete();
log.info("消息结束完整消息ID: {}", completeResponse);
org.ruoyi.chat.support.RetryNotifier.clear(emitter);
}
@Override
public void onError(Throwable error) {
error.printStackTrace();
ChatServiceHelper.onStreamError(emitter, error.getMessage());
}
});
} catch (Exception e) {
log.error("千问请求失败:{}", e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -1,18 +1,14 @@
package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.collection.CollectionUtil;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import org.ruoyi.chat.enums.promptTemplateEnum;
import org.ruoyi.chat.factory.ChatServiceFactory;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.chat.support.ChatRetryHelper;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.Message;
@@ -29,11 +25,9 @@ import org.ruoyi.domain.bo.ChatSessionBo;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.domain.vo.PromptTemplateVo;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.IChatSessionService;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.IPromptTemplateService;
import org.ruoyi.service.VectorStoreService;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
@@ -49,8 +43,8 @@ import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* @author ageer
@@ -76,28 +70,19 @@ public class SseServiceImpl implements ISseService {
private ChatModelVo chatModelVo;
// 提示词模板服务
private final IPromptTemplateService promptTemplateService;
@Override
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
SseEmitter sseEmitter = new SseEmitter(0L);
try {
// 记录当前会话令牌,供异步线程使用
try {
chatRequest.setToken(StpUtil.getTokenValue());
} catch (Exception ignore) {
// 保底无token场景下忽略
}
// 构建消息列表
buildChatMessageList(chatRequest);
// 设置对话角色
chatRequest.setRole(Message.Role.USER.getName());
if (LoginHelper.isLogin()) {
if(LoginHelper.isLogin()){
// 设置用户id
// 设置用户id
chatRequest.setUserId(LoginHelper.getUserId());
@@ -106,20 +91,21 @@ public class SseServiceImpl implements ISseService {
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
{
// 设置会话id
if (chatRequest.getUuid() == null) {
if (chatRequest.getUuid() == null){
//暂时随机生成会话id
chatRequest.setSessionId(System.currentTimeMillis());
} else {
}else{
//这里或许需要修改一下这里应该用uuid 或者 前端传递 sessionId
chatRequest.setSessionId(chatRequest.getUuid());
}
}
// 保存消息记录 并扣除费用
chatCostService.deductToken(chatRequest);
// 保存消息,再发布异步计费事件
chatCostService.saveMessage(chatRequest);
chatCostService.publishBillingEvent(chatRequest);
chatRequest.setUserId(chatCostService.getUserId());
if (chatRequest.getSessionId() == null) {
if(chatRequest.getSessionId()==null){
ChatSessionBo chatSessionBo = new ChatSessionBo();
chatSessionBo.setUserId(chatCostService.getUserId());
chatSessionBo.setSessionTitle(getFirst10Characters(chatRequest.getPrompt()));
@@ -128,87 +114,16 @@ public class SseServiceImpl implements ISseService {
chatRequest.setSessionId(chatSessionBo.getId());
}
}
// 自动选择模型并获取对应的聊天服务
IChatService chatService = autoSelectModelAndGetService(chatRequest);
// 仅当 autoSelectModel = true 时,才启用重试与降级
if (Boolean.TRUE.equals(chatRequest.getAutoSelectModel())) {
ChatModelVo currentModel = this.chatModelVo;
String currentCategory = currentModel.getCategory();
ChatRetryHelper.executeWithRetry(
currentModel,
currentCategory,
chatModelService,
sseEmitter,
(modelForTry, onFailure) -> {
// 替换请求中的模型名称
chatRequest.setModel(modelForTry.getModelName());
// 以 emitter 实例为唯一键注册失败回调
RetryNotifier.setFailureCallback(sseEmitter, onFailure);
try {
autoSelectServiceByCategoryAndInvoke(chatRequest, sseEmitter,
modelForTry.getCategory());
} finally {
// 不在此处清理,待下游结束/失败时清理
}
}
);
} else {
// 不重试不降级,直接调用
chatService.chat(chatRequest, sseEmitter);
}
// 根据模型分类调用不同的处理逻辑
IChatService chatService = chatServiceFactory.getChatService(chatModelVo.getCategory());
chatService.chat(chatRequest, sseEmitter);
} catch (Exception e) {
log.error(e.getMessage(), e);
SSEUtil.sendErrorEvent(sseEmitter, e.getMessage());
log.error(e.getMessage(),e);
SSEUtil.sendErrorEvent(sseEmitter,e.getMessage());
}
return sseEmitter;
}
/**
* 自动选择模型并获取对应的聊天服务
*/
private IChatService autoSelectModelAndGetService(ChatRequest chatRequest) {
try {
if (Boolean.TRUE.equals(chatRequest.getHasAttachment())) {
chatModelVo = selectModelByCategory("image");
} else if (Boolean.TRUE.equals(chatRequest.getAutoSelectModel())) {
chatModelVo = selectModelByCategory("chat");
} else {
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
}
if (chatModelVo == null) {
throw new IllegalStateException("未找到模型名称:" + chatRequest.getModel());
}
// 自动设置请求参数中的模型名称
chatRequest.setModel(chatModelVo.getModelName());
// 直接返回对应的聊天服务
return chatServiceFactory.getChatService(chatModelVo.getCategory());
} catch (Exception e) {
log.error("模型选择和服务获取失败: {}", e.getMessage(), e);
throw new IllegalStateException("模型选择和服务获取失败: " + e.getMessage());
}
}
/**
* 根据给定分类获取服务并发起调用(避免在降级时重复选择模型)
*/
private void autoSelectServiceByCategoryAndInvoke(ChatRequest chatRequest, SseEmitter sseEmitter, String category) {
IChatService service = chatServiceFactory.getChatService(category);
service.chat(chatRequest, sseEmitter);
}
/**
* 根据分类选择优先级最高的模型
*/
private ChatModelVo selectModelByCategory(String category) {
ChatModelVo model = chatModelService.selectModelByCategoryWithHighestPriority(category);
if (model == null) {
throw new IllegalStateException("未找到" + category + "分类的模型配置");
}
return model;
}
/**
* 获取对话标题
*
@@ -227,23 +142,69 @@ public class SseServiceImpl implements ISseService {
}
/**
* 构建消息列表
* 构建消息列表
*/
private void buildChatMessageList(ChatRequest chatRequest) {
private void buildChatMessageList(ChatRequest chatRequest){
String sysPrompt;
// 矫正模型名称 如果是gpt-image 则查询image类型模型 获取模型名称
if(chatRequest.getModel().equals("gpt-image")) {
chatModelVo = chatModelService.selectModelByCategory("image");
if (chatModelVo == null) {
log.error("未找到image类型的模型配置");
throw new IllegalStateException("未找到image类型的模型配置");
}
}else{
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
}
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
// 查询向量库相关信息加入到上下文
if(StringUtils.isNotEmpty(chatRequest.getKid())){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
// 通过kid查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKid()));
// 查询向量模型配置信息
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
// 处理知识库相关逻辑
String sysPrompt = processKnowledgeBase(chatRequest, messages);
// 设置系统提示词
Message sysMessage = Message.builder()
.content(sysPrompt)
.role(Message.Role.SYSTEM)
.build();
messages.add(0, sysMessage);
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(content);
queryVectorBo.setKid(chatRequest.getKid());
queryVectorBo.setApiKey(chatModel.getApiKey());
queryVectorBo.setBaseUrl(chatModel.getApiHost());
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
messages.addAll(knMessages);
// 设置知识库系统提示词
sysPrompt = knowledgeInfoVo.getSystemPrompt();
if(StringUtils.isEmpty(sysPrompt)){
sysPrompt ="###角色设定\n" +
"你是一个智能知识助手,专注于利用上下文中的信息来提供准确和相关的回答。\n" +
"###指令\n" +
"当用户的问题与上下文知识匹配时,利用上下文信息进行回答。如果问题与上下文不匹配,运用自身的推理能力生成合适的回答。\n" +
"###限制\n" +
"确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好" +
"当前时间:"+ DateUtils.getDate();
}
}else {
sysPrompt = chatModelVo.getSystemPrompt();
if(StringUtils.isEmpty(sysPrompt)){
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手名字叫熊猫助手。你擅长中英文对话能够理解并处理各种问题提供安全、有帮助、准确的回答。" +
"当前时间:"+ DateUtils.getDate()+
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
}
}
// 设置系统默认提示词
Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build();
messages.add(0,sysMessage);
chatRequest.setSysPrompt(sysPrompt);
// 用户对话内容
String chatString = null;
// 获取用户对话信息
@@ -252,128 +213,13 @@ public class SseServiceImpl implements ISseService {
if (CollectionUtil.isNotEmpty(listContent)) {
chatString = listContent.get(0).toString();
}
} else {
chatString = content.toString();
} else if (content instanceof String) {
chatString = (String) content;
}
// 设置对话信息
chatRequest.setPrompt(chatString);
}
/**
* 处理知识库相关逻辑
*/
private String processKnowledgeBase(ChatRequest chatRequest, List<Message> messages) {
if (StringUtils.isEmpty(chatRequest.getKid())) {
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
}
try {
// 查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKid()));
if (knowledgeInfoVo == null) {
log.warn("知识库信息不存在kid: {}", chatRequest.getKid());
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
}
// 查询向量模型配置信息
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
if (chatModel == null) {
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModelName());
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
}
// 构建向量查询参数
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
// 获取向量查询结果
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
// 添加知识库消息到上下文
addKnowledgeMessages(messages, nearestList);
// 返回知识库系统提示词
return getKnowledgeSystemPrompt(knowledgeInfoVo);
} catch (Exception e) {
log.error("处理知识库信息失败: {}", e.getMessage(), e);
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
}
}
/**
* 构建向量查询参数
*/
private QueryVectorBo buildQueryVectorBo(ChatRequest chatRequest, KnowledgeInfoVo knowledgeInfoVo,
ChatModelVo chatModel) {
String content = chatRequest.getMessages().get(chatRequest.getMessages().size() - 1).getContent().toString();
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(content);
queryVectorBo.setKid(chatRequest.getKid());
queryVectorBo.setApiKey(chatModel.getApiKey());
queryVectorBo.setBaseUrl(chatModel.getApiHost());
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
return queryVectorBo;
}
/**
* 添加知识库消息到上下文
*/
private void addKnowledgeMessages(List<Message> messages, List<String> nearestList) {
for (String prompt : nearestList) {
Message userMessage = Message.builder()
.content(prompt)
.role(Message.Role.USER)
.build();
messages.add(userMessage);
}
}
/**
* 获取知识库系统提示词
*/
private String getKnowledgeSystemPrompt(KnowledgeInfoVo knowledgeInfoVo) {
String sysPrompt = knowledgeInfoVo.getSystemPrompt();
if (StringUtils.isEmpty(sysPrompt)) {
sysPrompt = "###角色设定\n" +
"你是一个智能知识助手,专注于利用上下文中的信息来提供准确和相关的回答。\n" +
"###指令\n" +
"当用户的问题与上下文知识匹配时,利用上下文信息进行回答。如果问题与上下文不匹配,运用自身的推理能力生成合适的回答。\n" +
"###限制\n" +
"确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好\n" +
"当前时间:" + DateUtils.getDate();
}
return sysPrompt;
}
/**
* 获取提示词模板提示词
*/
private String getPromptTemplatePrompt(String category) {
PromptTemplateVo promptTemplateVo = promptTemplateService.queryByCategory(category);
if (Objects.isNull(promptTemplateVo) || StringUtils.isEmpty(promptTemplateVo.getTemplateContent())) {
return getDefaultSystemPrompt();
}
return promptTemplateVo.getTemplateContent();
}
/**
* 获取默认系统提示词
*/
private String getDefaultSystemPrompt() {
String sysPrompt = chatModelVo != null ? chatModelVo.getSystemPrompt() : null;
if (StringUtils.isEmpty(sysPrompt)) {
sysPrompt = "你是一个由RuoYI-AI开发的人工智能助手名字叫RuoYI人工智能助手。"
+ "你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。"
+ "当前时间:" + DateUtils.getDate()
+ "#注意:回复之前注意结合上下文和工具返回内容进行回复。";
}
return sysPrompt;
}
/**
* 文字转语音
@@ -386,8 +232,8 @@ public class SseServiceImpl implements ISseService {
InputStreamResource resource = new InputStreamResource(body.byteStream());
// 创建并返回ResponseEntity
return ResponseEntity.ok()
.contentType(MediaType.parseMediaType("audio/mpeg"))
.body(resource);
.contentType(MediaType.parseMediaType("audio/mpeg"))
.body(resource);
} else {
// 如果ResponseBody为空返回404状态码
return ResponseEntity.notFound().build();

View File

@@ -15,7 +15,6 @@ import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.ruoyi.chat.support.ChatServiceHelper;
@@ -52,14 +51,14 @@ public class ZhipuAiChatServiceImpl implements IChatService {
@SneakyThrows
@Override
public void onError(Throwable error) {
ChatServiceHelper.onStreamError(emitter, error.getMessage());
// System.out.println(error.getMessage());
emitter.send(error.getMessage());
}
@Override
public void onCompleteResponse(ChatResponse response) {
emitter.complete();
log.info("消息结束完整消息ID: {}", response.aiMessage());
org.ruoyi.chat.support.RetryNotifier.clear(emitter);
}
};
@@ -72,7 +71,6 @@ public class ZhipuAiChatServiceImpl implements IChatService {
model.chat(chatRequest.getPrompt(), handler);
} catch (Exception e) {
log.error("智谱清言请求失败:{}", e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -89,7 +89,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
* 根据知识库角色查询知识库列表
*/
@Override
public TableDataInfo<KnowledgeInfoVo> queryPageListByRole(KnowledgeInfoBo bo, PageQuery pageQuery) {
public TableDataInfo<KnowledgeInfoVo> queryPageListByRole(PageQuery pageQuery) {
// 查询用户关联角色
LoginUser loginUser = LoginHelper.getLoginUser();
if (StringUtils.isEmpty(loginUser.getKroleGroupIds()) || StringUtils.isEmpty(loginUser.getKroleGroupType())) {
@@ -122,15 +122,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
return new TableDataInfo<>();
}
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
// 在查询用户创建的知识库条件下,拼接角色分配知识库
lqw.or(q -> q.in(
KnowledgeInfo::getId,
knowledgeRoleRelations.stream()
.map(KnowledgeRoleRelation::getKnowledgeId)
.filter(Objects::nonNull)
.collect(Collectors.toList())
));
LambdaQueryWrapper<KnowledgeInfo> lqw = Wrappers.lambdaQuery();
lqw.in(KnowledgeInfo::getId, knowledgeRoleRelations.stream().map(KnowledgeRoleRelation::getKnowledgeId).filter(Objects::nonNull).collect(Collectors.toList()));
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
return TableDataInfo.build(result);
}

View File

@@ -1,115 +0,0 @@
package org.ruoyi.chat.support;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
/**
* 统一的聊天重试与降级调度器。
*
* 策略:
* - 当前模型最多重试 3 次;仍失败则降级到同分类内、优先级小于当前的最高优先级模型。
* - 降级模型同样最多重试 3 次;仍失败则向前端返回失败信息并停止。
*
* 注意:实现依赖调用方在底层异步失败时执行 onFailure.run() 通知本调度器。
*/
@Slf4j
public class ChatRetryHelper {
public interface AttemptStarter {
void start(ChatModelVo model, Runnable onFailure) throws Exception;
}
public static void executeWithRetry(
ChatModelVo primaryModel,
String category,
IChatModelService chatModelService,
SseEmitter emitter,
AttemptStarter attemptStarter
) {
Objects.requireNonNull(primaryModel, "primaryModel must not be null");
Objects.requireNonNull(category, "category must not be null");
Objects.requireNonNull(chatModelService, "chatModelService must not be null");
Objects.requireNonNull(emitter, "emitter must not be null");
Objects.requireNonNull(attemptStarter, "attemptStarter must not be null");
AtomicInteger mainAttempts = new AtomicInteger(0);
AtomicInteger fallbackAttempts = new AtomicInteger(0);
AtomicBoolean inFallback = new AtomicBoolean(false);
AtomicBoolean scheduling = new AtomicBoolean(false);
class Scheduler {
volatile ChatModelVo current = primaryModel;
volatile ChatModelVo fallback = null;
void startAttempt() {
try {
if (!inFallback.get()) {
if (mainAttempts.incrementAndGet() > 3) {
// 进入降级
inFallback.set(true);
if (fallback == null) {
Integer curPriority = primaryModel.getPriority();
if (curPriority == null) {
curPriority = Integer.MAX_VALUE;
}
fallback = chatModelService.selectFallbackModelByCategoryAndLessPriority(category, curPriority);
}
if (fallback == null) {
SSEUtil.sendErrorEvent(emitter, "当前模型重试3次均失败且无可用降级模型");
emitter.complete();
return;
}
current = fallback;
mainAttempts.set(3); // 锁定
fallbackAttempts.set(0);
}
} else {
if (fallbackAttempts.incrementAndGet() > 3) {
SSEUtil.sendErrorEvent(emitter, "降级模型重试3次仍失败");
emitter.complete();
return;
}
}
Runnable onFailure = () -> {
// 去抖:避免同一次失败触发多次重试
if (scheduling.compareAndSet(false, true)) {
try {
SSEUtil.sendErrorEvent(emitter, (inFallback.get() ? "降级模型" : "当前模型") + "调用失败,准备重试...");
// 立即发起下一次尝试
startAttempt();
} finally {
scheduling.set(false);
}
}
};
attemptStarter.start(current, onFailure);
} catch (Exception ex) {
log.error("启动聊天尝试失败: {}", ex.getMessage(), ex);
SSEUtil.sendErrorEvent(emitter, "启动聊天尝试失败: " + ex.getMessage());
// 直接按失败处理,继续重试/降级
if (scheduling.compareAndSet(false, true)) {
try {
startAttempt();
} finally {
scheduling.set(false);
}
}
}
}
}
new Scheduler().startAttempt();
}
}

View File

@@ -1,45 +0,0 @@
package org.ruoyi.chat.support;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.chat.util.SSEUtil;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* 抽取各聊天实现类的通用逻辑:
* - 创建带开关的 SSE 监听器
* - 统一的流错误处理(根据是否在重试场景决定通知或直接结束)
* - 统一的完成处理(清理回调并 complete
*/
public class ChatServiceHelper {
public static SSEEventSourceListener createOpenAiListener(SseEmitter emitter, ChatRequest chatRequest) {
boolean retryEnabled = Boolean.TRUE.equals(chatRequest.getAutoSelectModel());
return new SSEEventSourceListener(
emitter,
chatRequest.getUserId(),
chatRequest.getSessionId(),
chatRequest.getToken(),
retryEnabled
);
}
public static void onStreamError(SseEmitter emitter, String errorMessage) {
SSEUtil.sendErrorEvent(emitter, errorMessage);
if (RetryNotifier.hasCallback(emitter)) {
RetryNotifier.notifyFailure(emitter);
} else {
emitter.complete();
}
}
public static void onStreamComplete(SseEmitter emitter) {
try {
emitter.complete();
} finally {
RetryNotifier.clear(emitter);
}
}
}

View File

@@ -1,51 +0,0 @@
package org.ruoyi.chat.support;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
* 失败回调通知器基于发射器实例SseEmitter 等对象地址)绑定回调,
* 避免与业务标识绑定,且能跨线程正确关联。
*/
public class RetryNotifier {
private static final Map<Integer, Runnable> FAILURE_CALLBACKS = new ConcurrentHashMap<>();
private static int keyOf(Object obj) {
return System.identityHashCode(obj);
}
public static void setFailureCallback(Object emitterLike, Runnable callback) {
if (emitterLike == null || callback == null) {
return;
}
FAILURE_CALLBACKS.put(keyOf(emitterLike), callback);
}
public static void clear(Object emitterLike) {
if (emitterLike == null) {
return;
}
FAILURE_CALLBACKS.remove(keyOf(emitterLike));
}
public static void notifyFailure(Object emitterLike) {
if (emitterLike == null) {
return;
}
Runnable cb = FAILURE_CALLBACKS.get(keyOf(emitterLike));
if (Objects.nonNull(cb)) {
cb.run();
}
}
public static boolean hasCallback(Object emitterLike) {
if (emitterLike == null) {
return false;
}
return FAILURE_CALLBACKS.containsKey(keyOf(emitterLike));
}
}

View File

@@ -25,6 +25,6 @@ public class SSEUtil {
} catch (IOException e) {
log.error("SSE发送失败: {}", e.getMessage());
}
// 不立即关闭,由上层策略决定是否继续重试或降级
sseEmitter.complete();
}
}

View File

@@ -1,6 +1,5 @@
package org.ruoyi.generator.controller;
import cn.hutool.core.net.URLDecoder;
import jakarta.validation.constraints.NotNull;
import lombok.RequiredArgsConstructor;
import org.ruoyi.common.core.domain.R;
@@ -9,11 +8,10 @@ import org.ruoyi.generator.service.IGenTableService;
import org.ruoyi.generator.service.SchemaFieldService;
import org.springframework.context.annotation.Profile;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.nio.charset.StandardCharsets;
/**
* 代码生成 操作处理
*
@@ -48,18 +46,4 @@ public class GenController extends BaseController {
genTableService.generateCodeToClasspathByTableNames(tableNameStr);
return R.ok("代码生成成功");
}
/**
* 生成前端代码
*
* @param workPath 执行命令路径
* @param previewCode 执行生成前端文件命令
*/
@GetMapping("/batchGenFrontendCode")
public R<String> batchGenFrontendCode(@NotNull(message = "路径不能为空") String workPath, @NotNull(message = "指令不能为空") String previewCode) {
String decodedWorkPath = URLDecoder.decode(workPath, StandardCharsets.UTF_8);
String decodedPreviewCode = URLDecoder.decode(previewCode, StandardCharsets.UTF_8);
genTableService.generateFrontendTemplateFiles(decodedWorkPath, decodedPreviewCode);
return R.ok("代码生成成功");
}
}

View File

@@ -53,7 +53,6 @@ public class SchemaGroupController extends BaseController {
/**
* 获取数据模型分组选择列表
*/
@SaCheckPermission("dev:schemaGroup:select")
@GetMapping("/select")
public R<List<SchemaGroupVo>> select() {
SchemaGroupBo bo = new SchemaGroupBo();

View File

@@ -1,14 +1,13 @@
package org.ruoyi.generator.domain;
import com.baomidou.mybatisplus.annotation.FieldFill;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableLogic;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.core.domain.BaseEntity;
import java.io.Serializable;
import java.util.Date;
import java.io.Serial;
/**
* 数据模型对象 dev_schema
@@ -17,9 +16,12 @@ import java.util.Date;
* @date 2024-01-01
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("dev_schema")
public class Schema implements Serializable {
public class Schema extends BaseEntity {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键
@@ -47,6 +49,41 @@ public class Schema implements Serializable {
*/
private String tableName;
/**
* 表注释
*/
private String comment;
/**
* 存储引擎
*/
private String engine;
/**
* 列表字段
*/
private String listKeys;
/**
* 搜索表单字段
*/
private String searchFormKeys;
/**
* 表单设计
*/
private String designer;
/**
* 状态
*/
private String status;
/**
* 排序
*/
private Integer sort;
/**
* 备注
*/
@@ -59,33 +96,8 @@ public class Schema implements Serializable {
private String delFlag;
/**
* 创建部门
* 租户编号
*/
@TableField(fill = FieldFill.INSERT)
private Long createDept;
/**
* 创建者
*/
@TableField(fill = FieldFill.INSERT)
private Long createBy;
/**
* 创建时间
*/
@TableField(fill = FieldFill.INSERT)
private Date createTime;
/**
* 更新者
*/
@TableField(fill = FieldFill.INSERT_UPDATE)
private Long updateBy;
/**
* 更新时间
*/
@TableField(fill = FieldFill.INSERT_UPDATE)
private Date updateTime;
private String tenantId;
}

View File

@@ -1,14 +1,13 @@
package org.ruoyi.generator.domain;
import com.baomidou.mybatisplus.annotation.FieldFill;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableLogic;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.core.domain.BaseEntity;
import java.io.Serializable;
import java.util.Date;
import java.io.Serial;
/**
* 数据模型字段对象 dev_schema_field
@@ -17,8 +16,12 @@ import java.util.Date;
* @date 2024-01-01
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("dev_schema_field")
public class SchemaField implements Serializable {
public class SchemaField extends BaseEntity {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键
@@ -126,6 +129,15 @@ public class SchemaField implements Serializable {
*/
private String dictType;
/**
* 状态
*/
private String status;
/**
* 扩展JSON
*/
private String extendJson;
/**
* 备注
@@ -139,33 +151,8 @@ public class SchemaField implements Serializable {
private String delFlag;
/**
* 创建部门
* 租户编号
*/
@TableField(fill = FieldFill.INSERT)
private Long createDept;
/**
* 创建者
*/
@TableField(fill = FieldFill.INSERT)
private Long createBy;
/**
* 创建时间
*/
@TableField(fill = FieldFill.INSERT)
private Date createTime;
/**
* 更新者
*/
@TableField(fill = FieldFill.INSERT_UPDATE)
private Long updateBy;
/**
* 更新时间
*/
@TableField(fill = FieldFill.INSERT_UPDATE)
private Date updateTime;
private String tenantId;
}

View File

@@ -1,14 +1,13 @@
package org.ruoyi.generator.domain;
import com.baomidou.mybatisplus.annotation.FieldFill;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableLogic;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.core.domain.BaseEntity;
import java.io.Serializable;
import java.util.Date;
import java.io.Serial;
/**
* 数据模型分组对象 dev_schema_group
@@ -17,8 +16,12 @@ import java.util.Date;
* @date 2024-01-01
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("dev_schema_group")
public class SchemaGroup implements Serializable {
public class SchemaGroup extends BaseEntity {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键
@@ -41,6 +44,16 @@ public class SchemaGroup implements Serializable {
*/
private String icon;
/**
* 排序
*/
private Integer sort;
/**
* 状态
*/
private String status;
/**
* 备注
*/
@@ -53,32 +66,8 @@ public class SchemaGroup implements Serializable {
private String delFlag;
/**
* 创建部门
* 租户编号
*/
@TableField(fill = FieldFill.INSERT)
private Long createDept;
private String tenantId;
/**
* 创建者
*/
@TableField(fill = FieldFill.INSERT)
private Long createBy;
/**
* 创建时间
*/
@TableField(fill = FieldFill.INSERT)
private Date createTime;
/**
* 更新者
*/
@TableField(fill = FieldFill.INSERT_UPDATE)
private Long updateBy;
/**
* 更新时间
*/
@TableField(fill = FieldFill.INSERT_UPDATE)
private Date updateTime;
}

View File

@@ -4,12 +4,12 @@ import io.github.linpeilie.annotations.AutoMapper;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.common.core.validate.AddGroup;
import org.ruoyi.common.core.validate.EditGroup;
import org.ruoyi.core.domain.BaseEntity;
import org.ruoyi.generator.domain.Schema;
import java.io.Serializable;
/**
* 数据模型业务对象 SchemaBo
*
@@ -17,8 +17,9 @@ import java.io.Serializable;
* @date 2024-01-01
*/
@Data
@EqualsAndHashCode(callSuper = true)
@AutoMapper(target = Schema.class, reverseConvertGenerate = false)
public class SchemaBo implements Serializable {
public class SchemaBo extends BaseEntity {
/**
* 主键
@@ -37,12 +38,52 @@ public class SchemaBo implements Serializable {
@NotBlank(message = "模型名称不能为空", groups = {AddGroup.class, EditGroup.class})
private String name;
/**
* 模型编码
*/
@NotBlank(message = "模型编码不能为空", groups = {AddGroup.class, EditGroup.class})
private String code;
/**
* 表名
*/
@NotBlank(message = "表名不能为空", groups = {AddGroup.class, EditGroup.class})
private String tableName;
/**
* 表注释
*/
private String comment;
/**
* 存储引擎
*/
private String engine;
/**
* 列表字段
*/
private String listKeys;
/**
* 搜索表单字段
*/
private String searchFormKeys;
/**
* 表单设计
*/
private String designer;
/**
* 状态
*/
private String status;
/**
* 排序
*/
private Integer sort;
/**
* 备注
*/

View File

@@ -4,12 +4,12 @@ import io.github.linpeilie.annotations.AutoMapper;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.common.core.validate.AddGroup;
import org.ruoyi.common.core.validate.EditGroup;
import org.ruoyi.core.domain.BaseEntity;
import org.ruoyi.generator.domain.SchemaField;
import java.io.Serializable;
/**
* 数据模型字段业务对象 SchemaFieldBo
*
@@ -17,8 +17,9 @@ import java.io.Serializable;
* @date 2024-01-01
*/
@Data
@EqualsAndHashCode(callSuper = true)
@AutoMapper(target = SchemaField.class, reverseConvertGenerate = false)
public class SchemaFieldBo implements Serializable {
public class SchemaFieldBo extends BaseEntity {
/**
* 主键
@@ -35,7 +36,7 @@ public class SchemaFieldBo implements Serializable {
/**
* 模型名称
*/
// @NotNull(message = "模型名称不能为空", groups = {AddGroup.class, EditGroup.class})
@NotNull(message = "模型名称不能为空", groups = {AddGroup.class, EditGroup.class})
private String schemaName;
/**
@@ -130,6 +131,16 @@ public class SchemaFieldBo implements Serializable {
*/
private String dictType;
/**
* 状态
*/
private String status;
/**
* 扩展JSON
*/
private String extendJson;
/**
* 备注
*/

View File

@@ -4,12 +4,12 @@ import io.github.linpeilie.annotations.AutoMapper;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.ruoyi.common.core.validate.AddGroup;
import org.ruoyi.common.core.validate.EditGroup;
import org.ruoyi.core.domain.BaseEntity;
import org.ruoyi.generator.domain.SchemaGroup;
import java.io.Serializable;
/**
* 数据模型分组业务对象 SchemaGroupBo
*
@@ -17,8 +17,9 @@ import java.io.Serializable;
* @date 2024-01-01
*/
@Data
@EqualsAndHashCode(callSuper = true)
@AutoMapper(target = SchemaGroup.class, reverseConvertGenerate = false)
public class SchemaGroupBo implements Serializable {
public class SchemaGroupBo extends BaseEntity {
/**
* 主键
@@ -43,6 +44,16 @@ public class SchemaGroupBo implements Serializable {
*/
private String icon;
/**
* 排序
*/
private Integer sort;
/**
* 状态
*/
private String status;
/**
* 备注
*/

View File

@@ -6,7 +6,9 @@ import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.ruoyi.generator.domain.SchemaField;
import java.io.Serial;
import java.io.Serializable;
import java.util.Date;
/**
* 数据模型字段视图对象 SchemaFieldVo
@@ -18,6 +20,9 @@ import java.io.Serializable;
@AutoMapper(target = SchemaField.class)
public class SchemaFieldVo implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键
*/
@@ -131,8 +136,25 @@ public class SchemaFieldVo implements Serializable {
@Schema(description = "字典类型")
private String dictType;
/**
* 状态
*/
@Schema(description = "状态")
private String status;
/**
* 扩展JSON
*/
private String extendJson;
/**
* 备注
*/
private String remark;
/**
* 创建时间
*/
private Date createTime;
}

View File

@@ -4,6 +4,7 @@ import io.github.linpeilie.annotations.AutoMapper;
import lombok.Data;
import org.ruoyi.generator.domain.SchemaGroup;
import java.io.Serial;
import java.io.Serializable;
import java.util.Date;
@@ -17,6 +18,9 @@ import java.util.Date;
@AutoMapper(target = SchemaGroup.class)
public class SchemaGroupVo implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键
*/
@@ -42,6 +46,11 @@ public class SchemaGroupVo implements Serializable {
*/
private Integer sort;
/**
* 状态
*/
private String status;
/**
* 备注
*/

View File

@@ -19,6 +19,9 @@ import java.util.Date;
@AutoMapper(target = Schema.class)
public class SchemaVo implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键
*/
@@ -43,11 +46,6 @@ public class SchemaVo implements Serializable {
* 表名
*/
private String tableName;
/**
* 字典
*/
private String dictType;
/**
* 表注释

View File

@@ -11,19 +11,26 @@ import org.apache.velocity.app.Velocity;
import org.ruoyi.common.core.constant.Constants;
import org.ruoyi.generator.config.GenConfig;
import org.ruoyi.generator.domain.vo.SchemaFieldVo;
import org.ruoyi.generator.domain.vo.SchemaGroupVo;
import org.ruoyi.generator.domain.vo.SchemaVo;
import org.ruoyi.generator.service.IGenTableService;
import org.ruoyi.generator.service.SchemaFieldService;
import org.ruoyi.generator.service.SchemaGroupService;
import org.ruoyi.generator.service.SchemaService;
import org.ruoyi.generator.util.VelocityInitializer;
import org.ruoyi.generator.util.VelocityUtils;
import org.springframework.stereotype.Service;
import java.io.*;
import java.io.File;
import java.io.FileWriter;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
* 业务 服务层实现
@@ -37,7 +44,6 @@ public class GenTableServiceImpl implements IGenTableService {
private final SchemaService schemaService;
private final SchemaFieldService schemaFieldService;
private final SchemaGroupService schemaGroupService;
/**
* 基于表名称批量生成代码到classpath路径
@@ -53,41 +59,6 @@ public class GenTableServiceImpl implements IGenTableService {
}
}
@Override
public void generateFrontendTemplateFiles(String workPath, String previewCode) {
String os = System.getProperty("os.name").toLowerCase();
ProcessBuilder builder;
if (os.contains("win")) {
// Windows下用 cmd /c 执行 previewCode
builder = new ProcessBuilder("cmd.exe", "/c", previewCode);
} else {
// macOS/Linux 用 bash -c 执行 previewCode
builder = new ProcessBuilder("bash", "-c", previewCode);
}
// 设置工作目录
builder.directory(new File(workPath));
builder.redirectErrorStream(true);
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(
builder.start().getInputStream(),
StandardCharsets.UTF_8
)
)) {
String line;
log.info("执行结果:");
while ((line = reader.readLine()) != null) {
log.info(line);
}
} catch (Exception e) {
log.error("生成前端代码出错", e);
throw new RuntimeException("生成前端代码失败", e);
}
}
/**
* 根据表名称生成代码到classpath
*/
@@ -98,7 +69,6 @@ public class GenTableServiceImpl implements IGenTableService {
log.warn("Schema不存在表名: {}", tableName);
return;
}
// 查询Schema字段信息
List<SchemaFieldVo> fields = schemaFieldService.queryListByTableName(tableName);
if (CollUtil.isEmpty(fields)) {
@@ -158,19 +128,17 @@ public class GenTableServiceImpl implements IGenTableService {
*/
private VelocityContext prepareSchemaContext(SchemaVo schema, List<SchemaFieldVo> fields) {
VelocityContext context = new VelocityContext();
// 从配置文件读取基本配置
String packageName = GenConfig.getPackageName();
String author = GenConfig.getAuthor();
String tablePrefix = GenConfig.getTablePrefix();
boolean autoRemovePre = GenConfig.getAutoRemovePre();
// 处理表名和类名
Long schemaGroupId = schema.getSchemaGroupId();
SchemaGroupVo schemaGroupVo = schemaGroupService.queryById(schemaGroupId);
String tableName = schema.getTableName();
String baseClassName = schema.getTableName();
// 自动去除表前缀
if (autoRemovePre && StrUtil.isNotBlank(tablePrefix)) {
String[] prefixes = tablePrefix.split(",");
@@ -181,12 +149,12 @@ public class GenTableServiceImpl implements IGenTableService {
}
}
}
String className = toCamelCase(baseClassName, true); // 首字母大写的类名SysRole
String classname = toCamelCase(baseClassName, false); // 首字母小写的类名sysRole
String businessName = toCamelCase(baseClassName, false);
String moduleName = schemaGroupVo.getCode();
String moduleName = getModuleName(packageName);
// 基本信息
context.put("tableName", tableName);
context.put("tableComment", schema.getComment());
@@ -200,18 +168,18 @@ public class GenTableServiceImpl implements IGenTableService {
context.put("packageName", packageName);
context.put("moduleName", moduleName);
context.put("businessName", businessName);
// 权限相关
context.put("permissionPrefix", moduleName + ":" + businessName);
context.put("parentMenuId", "2000"); // 默认父菜单ID可配置
// 生成菜单ID
List<Long> menuIds = new ArrayList<>();
for (int i = 0; i < 6; i++) {
menuIds.add(IdUtil.getSnowflakeNextId());
}
context.put("menuIds", menuIds);
// 创建table对象包含menuIds等信息和方法
Map<String, Object> table = new HashMap<>();
table.put("menuIds", menuIds);
@@ -220,19 +188,29 @@ public class GenTableServiceImpl implements IGenTableService {
table.put("className", className);
table.put("classname", classname);
table.put("functionName", schema.getName());
// 添加表类型属性默认为crud类型
table.put("crud", true);
table.put("sub", false);
table.put("tree", false);
// 添加isSuperColumn方法
table.put("isSuperColumn", new Object() {
public boolean isSuperColumn(String javaField) {
// 定义超类字段BaseEntity中的字段
return "createBy".equals(javaField) || "createTime".equals(javaField)
|| "updateBy".equals(javaField) || "updateTime".equals(javaField)
|| "remark".equals(javaField) || "tenantId".equals(javaField);
}
});
context.put("table", table);
// 处理字段信息
List<Map<String, Object>> columns = new ArrayList<>();
Map<String, Object> pkColumn = null;
Set<String> importList = new HashSet<>();
// 添加基础导入
importList.add("java.io.Serializable");
@@ -240,7 +218,7 @@ public class GenTableServiceImpl implements IGenTableService {
Map<String, Object> column = new HashMap<>();
String javaType = getJavaType(field.getType());
String javaField = StrUtil.toCamelCase(field.getCode());
column.put("columnName", field.getCode());
column.put("columnComment", field.getName());
column.put("comment", field.getName()); // 添加comment别名
@@ -248,15 +226,15 @@ public class GenTableServiceImpl implements IGenTableService {
column.put("javaType", javaType);
column.put("javaField", javaField);
column.put("capJavaField", toCamelCase(field.getCode(), true));
// 布尔值dictType属性(兼容两种格式)
// 布尔值属性(兼容两种格式)
boolean isPk = "1".equals(field.getIsPk());
boolean isRequired = "1".equals(field.getIsRequired());
boolean isInsert = "1".equals(field.getIsInsert());
boolean isEdit = "1".equals(field.getIsEdit());
boolean isList = "1".equals(field.getIsList());
boolean isQuery = "1".equals(field.getIsQuery());
column.put("isPk", isPk ? 1 : 0);
column.put("pk", isPk); // 添加pk别名
column.put("isRequired", isRequired);
@@ -269,27 +247,27 @@ public class GenTableServiceImpl implements IGenTableService {
column.put("list", isList); // 添加list别名
column.put("isQuery", isQuery);
column.put("query", isQuery); // 添加query别名
column.put("queryType", field.getQueryType());
column.put("htmlType", field.getHtmlType());
column.put("dictType", field.getDictType());
column.put("sort", field.getSort());
// 添加readConverterExp方法
column.put("readConverterExp", new Object() {
});
// 根据Java类型添加相应的导入
addImportForJavaType(javaType, importList);
columns.add(column);
// 设置主键列
if (isPk) {
pkColumn = column;
}
}
// 如果没有主键,使用第一个字段作为主键
if (pkColumn == null && !columns.isEmpty()) {
pkColumn = columns.get(0);
@@ -297,28 +275,27 @@ public class GenTableServiceImpl implements IGenTableService {
pkColumn.put("isPk", 1);
pkColumn.put("pk", true);
}
context.put("columns", columns);
context.put("pkColumn", pkColumn);
context.put("importList", new ArrayList<>(importList));
return context;
}
/**
* 根据Java类型添加相应的导入
*/
private void addImportForJavaType(String javaType, Set<String> importList) {
switch (javaType) {
case "BigDecimal" -> importList.add("java.math.BigDecimal");
case "Date" -> importList.add("java.util.Date");
case "LocalDateTime" -> importList.add("java.time.LocalDateTime");
case "LocalDate" -> importList.add("java.time.LocalDate");
case "LocalTime" -> importList.add("java.time.LocalTime");
default -> {
}
}
}
* 根据Java类型添加相应的导入
*/
private void addImportForJavaType(String javaType, Set<String> importList) {
switch (javaType) {
case "BigDecimal" -> importList.add("java.math.BigDecimal");
case "Date" -> importList.add("java.util.Date");
case "LocalDateTime" -> importList.add("java.time.LocalDateTime");
case "LocalDate" -> importList.add("java.time.LocalDate");
case "LocalTime" -> importList.add("java.time.LocalTime");
default -> {}
}
}
/**
* 从包名中提取模块名
@@ -342,10 +319,10 @@ public class GenTableServiceImpl implements IGenTableService {
String packageName = GenConfig.getPackageName();
String tablePrefix = GenConfig.getTablePrefix();
boolean autoRemovePre = GenConfig.getAutoRemovePre();
// 处理类名
String baseClassName = schema.getTableName();
// 自动去除表前缀
if (autoRemovePre && StrUtil.isNotBlank(tablePrefix)) {
String[] prefixes = tablePrefix.split(",");
@@ -356,13 +333,13 @@ public class GenTableServiceImpl implements IGenTableService {
}
}
}
String className = toCamelCase(baseClassName, true); // 首字母大写SysRole
// 首字母小写sysRole
String moduleName = getModuleName(packageName);
String javaPath = "src/main/java/";
String mybatisPath = "src/main/resources/mapper/";
if (template.contains("domain.java.vm")) {
return javaPath + packageName.replace(".", "/") + "/domain/" + className + ".java";
} else if (template.contains("mapper.java.vm")) {
@@ -435,17 +412,16 @@ public class GenTableServiceImpl implements IGenTableService {
return "String";
}
String type = dbType.toLowerCase();
if (StrUtil.equalsAny(type, "int", "tinyint")) {
if (type.contains("int") || type.contains("tinyint") || type.contains("smallint")) {
return "Integer";
} else if (StrUtil.equalsAny(type, "bigint")) {
} else if (type.contains("bigint")) {
return "Long";
} else if (StrUtil.equalsAny(type, "decimal", "numeric", "float", "double")) {
} else if (type.contains("decimal") || type.contains("numeric") || type.contains("float") || type.contains(
"double")) {
return "BigDecimal";
} else if (StrUtil.equalsAny(type, "date")) {
return "LocalDate";
} else if (StrUtil.equalsAny(type, "datetime", "timestamp")) {
return "LocalDateTime";
} else if (StrUtil.equalsAny(type, "bit", "boolean")) {
} else if (type.contains("date") || type.contains("time")) {
return "Date";
} else if (type.contains("bit") || type.contains("boolean")) {
return "Boolean";
} else {
return "String";

View File

@@ -79,6 +79,7 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
lqw.eq(StringUtils.isNotBlank(bo.getQueryType()), SchemaField::getQueryType, bo.getQueryType());
lqw.eq(StringUtils.isNotBlank(bo.getHtmlType()), SchemaField::getHtmlType, bo.getHtmlType());
lqw.like(StringUtils.isNotBlank(bo.getDictType()), SchemaField::getDictType, bo.getDictType());
lqw.eq(StringUtils.isNotBlank(bo.getStatus()), SchemaField::getStatus, bo.getStatus());
lqw.orderByAsc(SchemaField::getSort);
return lqw;
}
@@ -149,6 +150,7 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
public List<SchemaFieldVo> queryListBySchemaId(Long schemaId) {
LambdaQueryWrapper<SchemaField> lqw = Wrappers.lambdaQuery();
lqw.eq(SchemaField::getSchemaId, schemaId);
lqw.eq(SchemaField::getStatus, "0"); // 只查询正常状态的字段
lqw.orderByAsc(SchemaField::getSort);
return baseMapper.selectVoList(lqw);
}
@@ -207,9 +209,9 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
Map<String, Object> result = new HashMap<>();
result.put("schemaGroupCode", schemaGroupVo.getCode());
result.put("tableName", schema.getTableName());
result.put("dictType",schema.getDictType());
result.put("tableComment", schema.getComment());
result.put("className", toCamelCase(schema.getTableName(), true));
// result.put("className", StrUtil.toCamelCase(schema.getTableName()));
result.put("tableCamelName", StrUtil.toCamelCase(schema.getTableName()));
result.put("functionName", schema.getName());
result.put("schemaName", schema.getName());
@@ -223,8 +225,6 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
if (pkField != null) {
Map<String, Object> pkColumn = new HashMap<>();
pkColumn.put("columnName", pkField.getCode());
pkColumn.put("dictType", pkField.getDictType());
pkColumn.put("columnComment", pkField.getName());
pkColumn.put("javaField", StrUtil.toCamelCase(pkField.getCode()));
pkColumn.put("javaType", getJavaType(pkField.getType()));
@@ -236,7 +236,6 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
for (SchemaFieldVo field : fields) {
Map<String, Object> column = new HashMap<>();
column.put("columnName", field.getCode());
column.put("dictType", field.getDictType());
column.put("columnComment", field.getName());
column.put("javaField", StrUtil.toCamelCase(field.getCode()));
column.put("javaType", getJavaType(field.getType()));
@@ -266,7 +265,8 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
return false;
}
LambdaQueryWrapper<SchemaField> lqw = Wrappers.lambdaQuery();
lqw.eq(SchemaField::getSchemaId, schemaId);
lqw.eq(SchemaField::getSchemaName, tableName);
lqw.eq(SchemaField::getStatus, "0");
// 检查是否已存在字段数据
List<SchemaFieldVo> existingFields = baseMapper.selectVoList(lqw);
if (CollUtil.isNotEmpty(existingFields)) {
@@ -280,27 +280,20 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
SchemaField field = new SchemaField();
field.setSchemaId(schemaId);
field.setSchemaName(tableName);
field.setDefaultValue((String) columnInfo.get("columnDefault"));
field.setComment((String) columnInfo.get("columnComment"));
field.setName((String) columnInfo.get("columnComment"));
field.setDictType(StrUtil.toCamelCase((String) columnInfo.get("dictType")));
field.setCode(StrUtil.toCamelCase((String) columnInfo.get("columnName")));
field.setType((String) columnInfo.get("dataType"));
field.setLength(Integer.valueOf(String.valueOf(columnInfo.get("columnSize"))));
field.setIsPk((Boolean) columnInfo.get("isPrimaryKey") ? "1" : "0");
field.setIsRequired(!(Boolean) columnInfo.get("isNullable") ? "1" : "0");
if ("1".equals(field.getIsPk())) {
field.setIsInsert("0");
field.setIsEdit("0");
}else {
field.setIsInsert("1");
field.setIsEdit("1");
}
field.setIsInsert("1");
field.setIsEdit("1");
field.setIsList("1");
field.setIsQuery("1");
field.setQueryType("EQ");
field.setHtmlType(getDefaultHtmlType((String) columnInfo.get("dataType")));
field.setSort(sort++);
field.setStatus("0");
// 如果字段名为空,使用字段代码作为名称
if (StringUtils.isBlank(field.getName())) {
field.setName(field.getCode());
@@ -370,15 +363,16 @@ public class SchemaFieldServiceImpl implements SchemaFieldService {
}
String type = dbType.toLowerCase();
if (StrUtil.equalsAny(type, "int", "tinyint", "smallint")) {
if (type.contains("int") || type.contains("tinyint") || type.contains("smallint")) {
return "Integer";
} else if (StrUtil.equalsAny(type, "bigint")) {
} else if (type.contains("bigint")) {
return "Long";
} else if (StrUtil.equalsAny(type, "decimal", "numeric", "float", "double")) {
} else if (type.contains("decimal") || type.contains("numeric") || type.contains("float") || type.contains(
"double")) {
return "BigDecimal";
} else if (StrUtil.equalsAny(type, "date", "datetime","timestamp")) {
} else if (type.contains("date") || type.contains("time")) {
return "Date";
} else if (StrUtil.equalsAny(type, "bit", "boolean")) {
} else if (type.contains("bit") || type.contains("boolean")) {
return "Boolean";
} else {
return "String";

View File

@@ -60,6 +60,9 @@ public class SchemaGroupServiceImpl implements SchemaGroupService {
LambdaQueryWrapper<SchemaGroup> lqw = Wrappers.lambdaQuery();
lqw.like(StringUtils.isNotBlank(bo.getName()), SchemaGroup::getName, bo.getName());
lqw.eq(StringUtils.isNotBlank(bo.getCode()), SchemaGroup::getCode, bo.getCode());
lqw.eq(bo.getSort() != null, SchemaGroup::getSort, bo.getSort());
lqw.eq(StringUtils.isNotBlank(bo.getStatus()), SchemaGroup::getStatus, bo.getStatus());
lqw.orderByAsc(SchemaGroup::getSort);
return lqw;
}
@@ -99,6 +102,9 @@ public class SchemaGroupServiceImpl implements SchemaGroupService {
*/
@Override
public Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid) {
return baseMapper.deleteByIds(ids) > 0;
if (isValid) {
//TODO 做一些业务上的校验,判断是否需要校验
}
return baseMapper.deleteBatchIds(ids) > 0;
}
}

Some files were not shown because too many files have changed in this diff Show More