feat: 重构模块

This commit is contained in:
ageerle
2025-04-10 17:25:23 +08:00
parent 3be9005f95
commit 2509099146
653 changed files with 1000 additions and 165766 deletions

View File

@@ -0,0 +1,55 @@
package org.ruoyi.chat.config;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
import org.ruoyi.common.core.service.ConfigService;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
/**
* Chat配置类
*
* @date: 2023/5/16
*/
@Configuration
@RequiredArgsConstructor
public class ChatConfig {
@Getter
private OpenAiStreamClient openAiStreamClient;
private final ConfigService configService;
@Bean
public OpenAiStreamClient openAiStreamClient() {
String apiHost = configService.getConfigValue("chat", "apiHost");
String apiKey = configService.getConfigValue("chat", "apiKey");
openAiStreamClient = createOpenAiStreamClient(apiHost,apiKey);
return openAiStreamClient;
}
public OpenAiStreamClient createOpenAiStreamClient(String apiHost, String apiKey) {
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
OkHttpClient okHttpClient = new OkHttpClient.Builder()
.addInterceptor(httpLoggingInterceptor)
.connectTimeout(30, TimeUnit.SECONDS)
.writeTimeout(600, TimeUnit.SECONDS)
.readTimeout(600, TimeUnit.SECONDS)
.build();
return OpenAiStreamClient.builder()
.apiHost(apiHost)
.apiKey(Collections.singletonList(apiKey))
.keyStrategy(new KeyRandomStrategy())
.okHttpClient(okHttpClient)
.build();
}
}

View File

@@ -0,0 +1,43 @@
package org.ruoyi.chat.config;
import jakarta.annotation.PostConstruct;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.ruoyi.common.core.utils.OkHttpUtil;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component
@RequiredArgsConstructor
public class OkHttpConfig {
private final IChatModelService chatModelService;
private final Map<String, OkHttpUtil> okHttpUtilMap = new HashMap<>();
@Getter
private String generate;
@PostConstruct
public void init() {
initializeOkHttpUtil("suno");
initializeOkHttpUtil("luma");
initializeOkHttpUtil("ppt");
}
private void initializeOkHttpUtil(String modelName) {
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
OkHttpUtil okHttpUtil = new OkHttpUtil();
okHttpUtil.setApiHost(chatModelVo.getApiHost());
okHttpUtil.setApiKey(chatModelVo.getApiKey());
generate = String.valueOf(chatModelVo.getModelPrice());
okHttpUtilMap.put(modelName, okHttpUtil);
}
public OkHttpUtil getOkHttpUtil(String modelName) {
return okHttpUtilMap.get(modelName);
}
}

View File

@@ -6,11 +6,9 @@ import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.ruoyi.common.core.domain.R;
import org.ruoyi.common.core.domain.model.LoginUser;
@@ -19,7 +17,6 @@ import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.domain.bo.ChatMessageBo;
import org.ruoyi.domain.vo.ChatMessageVo;
import org.ruoyi.service.IChatMessageService;
import org.springframework.core.io.Resource;

View File

@@ -1,58 +0,0 @@
package org.ruoyi.chat.controller;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.ruoyi.common.core.domain.R;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.common.web.core.BaseController;
import org.ruoyi.system.domain.vo.cover.CoverParamVo;
import org.ruoyi.system.domain.vo.cover.CoverVo;
import org.ruoyi.system.domain.vo.cover.CoverCallbackVo;
import org.ruoyi.system.domain.vo.cover.MusicVo;
import org.ruoyi.system.service.ICoverService;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* 绘声美音-翻唱
*
* @author NSL
* @since 2024-12-25
*/
@Api(tags = "歌曲翻唱")
@RequiredArgsConstructor
@RestController
@RequestMapping("/cover")
public class CoverController extends BaseController {
private final ICoverService coverService;
@ApiOperation(value = "查找歌曲")
@GetMapping("/searchMusic")
public R<List<MusicVo>> searchMusic(String musicName) {
return R.ok(coverService.searchMusic(musicName));
}
@ApiOperation(value = "翻唱歌曲")
@PostMapping("/saveCoverTask")
public R<Void> saveCoverTask(@RequestBody CoverParamVo coverParamVo) {
coverService.saveCoverTask(coverParamVo);
return R.ok("翻唱歌曲处理中请等待10分钟-30分钟翻唱结果请到翻唱记录中查询");
}
@ApiOperation(value = "查询翻唱记录")
@PostMapping("/searchCoverRecord")
public R<TableDataInfo<CoverVo>> searchCoverRecord(@RequestBody PageQuery pageQuery) {
return R.ok(coverService.searchCoverRecord(pageQuery));
}
@ApiOperation(value = "翻唱回调接口")
@PostMapping("/callback")
public R<Void> callback(@RequestBody CoverCallbackVo coverCallbackVo) {
coverService.callback(coverCallbackVo);
return R.ok();
}
}

View File

@@ -8,6 +8,7 @@ import lombok.extern.slf4j.Slf4j;
import okhttp3.Request;
import org.apache.commons.lang3.math.NumberUtils;
import org.ruoyi.chat.domain.InsightFace;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.util.MjOkHttpUtil;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

View File

@@ -4,11 +4,13 @@ import cn.hutool.json.JSONUtil;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Request;
import org.apache.commons.lang3.math.NumberUtils;
import org.ruoyi.chat.config.OkHttpConfig;
import org.ruoyi.chat.domain.bo.GenerateLuma;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.core.utils.OkHttpUtil;
import org.ruoyi.system.cofing.OkHttpConfig;
import org.ruoyi.system.domain.GenerateLuma;
import org.springframework.web.bind.annotation.*;
/**

View File

@@ -1,79 +0,0 @@
package org.ruoyi.chat.controller;
import com.alibaba.fastjson.JSONObject;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.ruoyi.common.core.domain.R;
import org.ruoyi.common.web.core.BaseController;
import org.ruoyi.system.domain.vo.ppt.*;
import org.ruoyi.system.service.IPptService;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* AI_PPT
*
* @author NSL
* @since 2024-12-30
*/
@Api(tags = "AI-PPT")
@RequiredArgsConstructor
@RestController
@RequestMapping("/ppt")
public class PptController extends BaseController {
private final IPptService pptService;
@ApiOperation(value = "获取API Token")
@GetMapping("/getApiToken")
public R<String> getApiToken() {
return R.ok(pptService.getApiToken());
}
@ApiOperation(value = "同步流式生成 PPT")
@PostMapping("/syncStreamGeneratePpt")
public R<Void> syncStreamGeneratePpt(String title) {
pptService.syncStreamGeneratePpt(title);
return R.ok();
}
@ApiOperation(value = "查询所有PPT列表")
@PostMapping("/selectPptList")
public R<Void> selectPptList(@RequestBody PptAllQueryDto pptQueryVo) {
pptService.selectPptList(pptQueryVo);
return R.ok();
}
@ApiOperation(value = "生成大纲")
@PostMapping(value = "/generateOutline", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
public SseEmitter generateOutline(@RequestBody PptGenerateOutlineDto generateOutlineDto) {
return pptService.generateOutline(generateOutlineDto);
}
@ApiOperation(value = "生成大纲内容")
@PostMapping(value = "/generateContent", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
public SseEmitter generateOutline(@RequestBody PptGenerateContentDto generateContentDto) {
return pptService.generateContent(generateContentDto);
}
@ApiOperation(value = "分页查询 PPT 模板")
@PostMapping("/getTemplates")
public R<JSONObject> getPptTemplates(@RequestBody PptTemplateQueryDto pptQueryVo) {
return R.ok(pptService.getPptTemplates(pptQueryVo));
}
@ApiOperation(value = "生成 PPT")
@PostMapping("/generatePptx")
public R<JSONObject> generatePptx(@RequestBody PptGeneratePptxDto pptQueryVo) {
return R.ok(pptService.generatePptx(pptQueryVo));
}
@ApiOperation(value = "生成PPT成功回调接口")
@PostMapping("/successCallback")
public R<Void> successCallback() {
pptService.successCallback();
return R.ok();
}
}

View File

@@ -9,6 +9,7 @@ import okhttp3.Request;
import org.apache.commons.lang3.math.NumberUtils;
import org.ruoyi.chat.domain.dto.*;
import org.ruoyi.chat.enums.ActionType;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.util.MjOkHttpUtil;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

View File

@@ -4,12 +4,14 @@ import cn.hutool.json.JSONUtil;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Request;
import org.apache.commons.lang3.math.NumberUtils;
import org.ruoyi.chat.config.OkHttpConfig;
import org.ruoyi.chat.domain.bo.GenerateLyric;
import org.ruoyi.chat.domain.bo.GenerateSuno;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.core.utils.OkHttpUtil;
import org.ruoyi.system.cofing.OkHttpConfig;
import org.ruoyi.system.domain.GenerateLyric;
import org.ruoyi.system.domain.GenerateSuno;
import org.springframework.web.bind.annotation.*;
@RestController

View File

@@ -1,51 +0,0 @@
package org.ruoyi.chat.controller;
import lombok.RequiredArgsConstructor;
import org.ruoyi.common.core.domain.R;
import org.ruoyi.common.web.core.BaseController;
import org.ruoyi.system.request.RoleListDto;
import org.ruoyi.system.request.SimpleGenerateRequest;
import org.ruoyi.system.response.SimpleGenerateDataResponse;
import org.ruoyi.system.response.rolelist.ChatAppStoreVO;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* 应用市场
*
* @author Lion Li
* @date 2024-03-19
*/
@RequiredArgsConstructor
@RestController
@RequestMapping("/system/voice")
public class VoiceController extends BaseController {
private final IChatAppStoreService voiceRoleService;
/**
* 实时语音生成
*/
@PostMapping("/simpleGenerate")
public R<SimpleGenerateDataResponse> simpleGenerate(@RequestBody SimpleGenerateRequest simpleGenerateRequest) {
return R.ok(voiceRoleService.simpleGenerate(simpleGenerateRequest));
}
/**
* 角色市场
*/
@GetMapping("/roleList")
public R<List<ChatAppStoreVO>> roleList() {
return R.ok(voiceRoleService.roleList());
}
/**
* 收藏角色
*/
@PostMapping("/copyRole")
public R<String> copyRole(@RequestBody RoleListDto roleListDto) {
voiceRoleService.copyRole(roleListDto);
return R.ok();
}
}

View File

@@ -0,0 +1,22 @@
package org.ruoyi.chat.domain.bo;
import lombok.Data;
/**
* 描述:文生视频请求对象
*
* @author ageerle@163.com
* date 2024/6/27
*/
@Data
public class GenerateLuma {
private String aspect_ratio;
private boolean expand_prompt;
private String image_url;
private String user_prompt;
}

View File

@@ -0,0 +1,23 @@
package org.ruoyi.chat.domain.bo;
import lombok.Data;
/**
* 描述:生成歌词
*
* @author ageerle@163.com
* date 2024/6/27
*/
@Data
public class GenerateLyric {
/**
* 歌词提示词
*/
private String prompt;
/**
* 回调地址
*/
private String notify_hook;
}

View File

@@ -0,0 +1,64 @@
package org.ruoyi.chat.domain.bo;
import lombok.Data;
import java.io.Serializable;
/**
* @author WangLe
*/
@Data
public class GenerateSuno implements Serializable {
/**
* 歌词 (自定义模式专用)
*/
private String prompt;
/**
* mv模型chirp-v3-0、chirp-v3-5。不写默认 chirp-v3-0
*/
private String mv;
/**
* 标题(自定义模式专用)
*/
private String title;
/**
* 风格标签(自定义模式专用)
*/
private String tags;
/**
* 是否生成纯音乐true 为生成纯音乐
*/
private boolean make_instrumental;
/**
* 任务id用于对之前的任务再操作
*/
private String task_id;
/**
* float歌曲延长时间单位秒
*/
private int continue_at;
/**
* 歌曲id需要续写哪首歌
*/
private String continue_clip_id;
/**
* 灵感模式提示词(灵感模式专用)
*/
private String gpt_description_prompt;
/**
* 回调地址
*/
private String notify_hook;
}

View File

@@ -0,0 +1,34 @@
package org.ruoyi.chat.enums;
import lombok.Getter;
@Getter
public enum BillingType {
TOKEN("1", "token扣费"), // token扣费
TIMES("2", "次数扣费"); // 次数扣费
private final String code;
private final String description;
BillingType(String code, String description) {
this.code = code;
this.description = description;
}
public static BillingType fromCode(String code) {
for (BillingType type : values()) {
if (type.getCode().equals(code)) {
return type;
}
}
return null;
}
public String getCode() {
return code;
}
public String getDescription() {
return description;
}
}

View File

@@ -0,0 +1,34 @@
package org.ruoyi.chat.enums;
import lombok.Getter;
@Getter
public enum UserGradeType {
UNPAID("0", "未付费"), // 未付费用户
PAID("1", "已付费"); // 已付费用户
private final String code;
private final String description;
UserGradeType(String code, String description) {
this.code = code;
this.description = description;
}
public static UserGradeType fromCode(String code) {
for (UserGradeType type : values()) {
if (type.getCode().equals(code)) {
return type;
}
}
return null;
}
public String getCode() {
return code;
}
public String getDescription() {
return description;
}
}

View File

@@ -11,7 +11,10 @@ import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.core.utils.SpringUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@@ -40,6 +43,9 @@ public class SSEEventSourceListener extends EventSourceListener {
private StringBuilder stringBuffer;
private String modelName;
private static final IChatCostService chatCostService = SpringUtils.getBean(IChatCostService.class);
/**
* {@inheritDoc}
*/
@@ -55,11 +61,15 @@ public class SSEEventSourceListener extends EventSourceListener {
@Override
public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
try {
if ("[DONE]".equals(data)) {
//成功响应
emitter.complete();
// 扣除费用 (消耗字符 模型名称)
// 扣除费用
ChatRequest chatRequest = new ChatRequest();
chatRequest.setModel(modelName);
chatRequest.setPrompt(stringBuffer.toString());
chatCostService.deductToken(chatRequest);
return;
}
// 解析返回内容

View File

@@ -1,6 +1,6 @@
package org.ruoyi.chat.service.chat;
import org.ruoyi.domain.bo.ChatMessageBo;
import org.ruoyi.common.chat.request.ChatRequest;
/**
* 计费管理Service接口
@@ -11,16 +11,16 @@ import org.ruoyi.domain.bo.ChatMessageBo;
public interface IChatCostService {
/**
* 根据消耗的tokens扣除余额
* 扣除余额并且保存记录
*
* @param chatMessageBo
* @param chatRequest 对话信息
* @return 结果
*/
void deductToken(ChatMessageBo chatMessageBo);
void deductToken(ChatRequest chatRequest);
/**
* 扣除用户的余额
* 直接扣除用户的余额
*
*/
void deductUserBalance(Long userId, Double numberCost);

View File

@@ -1,7 +1,7 @@
package org.ruoyi.chat.service.chat;
import jakarta.servlet.http.HttpServletRequest;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
@@ -62,4 +62,13 @@ public interface ISseService {
* @return 回复内容
*/
String wxCpChat(String prompt);
/**
* 联网查询
*
* @param prompt 提示词
* @return 查询内容
*/
String webSearch (String prompt);
}

View File

@@ -0,0 +1,166 @@
package org.ruoyi.chat.service.chat.impl;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.enums.BillingType;
import org.ruoyi.chat.enums.UserGradeType;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.utils.TikTokensUtil;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.domain.ChatToken;
import org.ruoyi.domain.bo.ChatMessageBo;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatMessageService;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.IChatTokenService;
import org.ruoyi.system.domain.SysUser;
import org.ruoyi.system.mapper.SysUserMapper;
import org.springframework.stereotype.Service;
/**
* 计费管理Service实现
*
* @author ageerle
* @date 2025-04-08
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatCostServiceImpl implements IChatCostService {
private final SysUserMapper sysUserMapper;
private final IChatMessageService chatMessageService;
private final IChatTokenService chatTokenService;
private final IChatModelService chatModelService;
/**
* 扣除用户余额
*/
public void deductToken(ChatRequest chatRequest) {
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
String modelName = chatRequest.getModel();
ChatMessageBo chatMessageBo = new ChatMessageBo();
// 计算总token数
ChatToken chatToken = chatTokenService.queryByUserId(getUserId(), modelName);
if (chatToken == null) {
chatToken = new ChatToken();
chatToken.setToken(0);
}
int totalTokens = chatToken.getToken() + tokens;
// 如果总token数大于等于1000,进行费用扣除
if (totalTokens >= 1000) {
// 计算费用
int token1 = totalTokens / 1000;
int token2 = totalTokens % 1000;
if (token2 > 0) {
// 保存剩余tokens
chatToken.setModelName(modelName);
chatToken.setUserId(getUserId());
chatToken.setToken(token2);
chatTokenService.editToken(chatToken);
} else {
chatTokenService.resetToken(getUserId(), modelName);
}
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
double cost = chatModelVo.getModelPrice();
if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
// 按次数扣费
deductUserBalance(getUserId(), cost);
chatMessageBo.setDeductCost(cost);
}else {
// 按token扣费
Double numberCost = token1 * cost;
deductUserBalance(chatMessageBo.getUserId(), numberCost);
chatMessageBo.setDeductCost(numberCost);
}
chatMessageBo.setContent(chatRequest.getPrompt());
} else {
deductUserBalance(getUserId(), 0.0);
chatMessageBo.setDeductCost(0d);
chatMessageBo.setRemark("不满1kToken,计入下一次!");
chatToken.setToken(totalTokens);
chatToken.setModelName(chatMessageBo.getModelName());
chatToken.setUserId(chatMessageBo.getUserId());
chatTokenService.editToken(chatToken);
}
// 保存消息记录
chatMessageService.insertByBo(chatMessageBo);
}
/**
* 从用户余额中扣除费用
*
* @param userId 用户ID
* @param numberCost 要扣除的费用
*/
@Override
public void deductUserBalance(Long userId, Double numberCost) {
SysUser sysUser = sysUserMapper.selectById(userId);
if (sysUser == null) {
return;
}
Double userBalance = sysUser.getUserBalance();
if (userBalance < numberCost || userBalance == 0) {
throw new ServiceException("余额不足, 请充值");
}
sysUserMapper.update(null,
new LambdaUpdateWrapper<SysUser>()
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
.eq(SysUser::getUserId, userId));
}
/**
* 扣除任务费用
*/
@Override
public void taskDeduct(String type,String prompt, double cost) {
// 判断用户是否付费
checkUserGrade();
// 扣除费用
deductUserBalance(getUserId(), cost);
// 保存消息记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(type);
chatMessageBo.setContent(prompt);
chatMessageBo.setDeductCost(cost);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
}
/**
* 判断用户是否付费
*/
@Override
public void checkUserGrade() {
SysUser sysUser = sysUserMapper.selectById(getUserId());
if(UserGradeType.UNPAID.getCode().equals(sysUser.getUserGrade())){
throw new BaseException("该模型仅限付费用户使用。请升级套餐,开启高效体验之旅!");
}
}
/**
* 获取用户Id
*/
public Long getUserId() {
LoginUser loginUser = LoginHelper.getLoginUser();
if (loginUser == null) {
throw new BaseException("用户未登录!");
}
return loginUser.getUserId();
}
}

View File

@@ -1,8 +1,11 @@
package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.extra.servlet.ServletUtil;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.ServiceException;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.service.v4.tools.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
@@ -13,11 +16,13 @@ import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.chat.config.ChatConfig;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.chat.util.IpUtil;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
@@ -26,12 +31,17 @@ import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.core.utils.file.FileUtils;
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
import org.ruoyi.common.redis.utils.RedisUtils;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.EmbeddingService;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.VectorStoreService;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.http.MediaType;
@@ -48,8 +58,12 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@Service
@Slf4j
@@ -62,33 +76,80 @@ public class SseServiceImpl implements ISseService {
private final IChatModelService chatModelService;
private final EmbeddingService embeddingService;
private final VectorStoreService vectorStore;
private final ConfigService configService;
private final IChatCostService chatCostService;
private static final String requestIdTemplate = "mycompany-%d";
private static final ObjectMapper mapper = new ObjectMapper();
@Override
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
SseEmitter sseEmitter = new SseEmitter(0L);
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
// 用户对话内容
String chatString = null;
try {
if (StpUtil.isLogin()) {
// 通过模型名称查询模型信息
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
// 构建api请求客户端
openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
// 设置默认提示词
Message sysMessage = Message.builder().content(chatModelVo.getSystemPrompt()).role(Message.Role.SYSTEM).build();
messages.add(0,sysMessage);
// 模型设置默认提示词
// 查询向量库相关信息加入到上下文
if(chatRequest.getKid()!=null){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
List<String> nearestList;
List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意回答问题时须严格根据我给你的系统上下文内容原文进行回答请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
knMessages.add(userMessage);
messages.addAll(knMessages);
}
// 是否开启联网查询
// 获取用户对话信息
Object content = messages.get(messages.size() - 1).getContent();
if (content instanceof List<?> listContent) {
if (CollectionUtil.isNotEmpty(listContent)) {
chatString = listContent.get(0).toString();
}
} else if (content instanceof String) {
chatString = (String) content;
}
// 加载联网信息
if(chatRequest.getSearch()){
Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
messages.add(message);
}
}else {
// 未登录用户限制对话次数,默认5次
String clientIp = ServletUtil.getClientIP((javax.servlet.http.HttpServletRequest) request,"X-Forwarded-For");
// 未登录用户限制对话次数
String clientIp = IpUtil.getClientIp(request);
// 访客每天默认只能对话5次
int timeWindowInSeconds = 5;
String redisKey = "visitor:" + clientIp;
String redisKey = "clientIp:" + clientIp;
int count = 0;
if (RedisUtils.getCacheObject(redisKey) == null) {
// 当前访问次数
// 缓存有效时间1天
RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
}else {
count = RedisUtils.getCacheObject(redisKey);
@@ -104,13 +165,11 @@ public class SseServiceImpl implements ISseService {
.builder()
.messages(messages)
.model(chatRequest.getModel())
.temperature(chatRequest.getTemperature())
.topP(chatRequest.getTop_p())
.stream(true)
.stream(chatRequest.getStream())
.build();
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
// 保存消息记录 并扣除费用
chatCostService.deductToken(chatRequest);
} catch (Exception e) {
String message = e.getMessage();
sendErrorEvent(sseEmitter, message);
@@ -147,7 +206,6 @@ public class SseServiceImpl implements ISseService {
if (body != null) {
// 将ResponseBody转换为InputStreamResource
InputStreamResource resource = new InputStreamResource(body.byteStream());
// 创建并返回ResponseEntity
return ResponseEntity.ok()
.contentType(MediaType.parseMediaType("audio/mpeg"))
@@ -289,4 +347,58 @@ public class SseServiceImpl implements ISseService {
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
}
public String webSearch (String prompt) {
String zhipuValue = configService.getConfigValue("zhipu", "key");
if(StringUtils.isEmpty(zhipuValue)){
throw new IllegalStateException("zhipu config value is empty,请在chat_config中配置zhipu key信息");
}else {
ClientV4 client = new ClientV4.Builder(zhipuValue)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
SearchChatMessage jsonNodes = new SearchChatMessage();
jsonNodes.setRole(Message.Role.USER.getName());
jsonNodes.setContent(prompt);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
.model("web-search-pro")
.stream(Boolean.TRUE)
.messages(Collections.singletonList(jsonNodes))
.requestId(requestId)
.build();
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
List<ChoiceDelta> choices = new ArrayList<>();
if (webSearchApiResponse.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
webSearchApiResponse.getFlowable().map(result -> result)
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
log.info("Response: ");
}
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
if (delta != null && delta.getToolCalls() != null) {
log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
}
choices.add(delta);
}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable))
.blockingSubscribe();
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
webSearchApiResponse.setFlowable(null);
webSearchApiResponse.setData(chatMessageAccumulator);
}
return choices.get(1).getToolCalls().toString();
}
}
}

View File

@@ -4,7 +4,6 @@ import jakarta.annotation.Resource;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.entity.embeddings.Embedding;
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;

View File

@@ -0,0 +1,51 @@
package org.ruoyi.chat.util;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.util.StringUtils;
/**
* @author WangLe
*/
public class IpUtil {
public static String getClientIp(HttpServletRequest request) {
String ip = null;
// 获取 X-Forwarded-For 中的第一个非 unknown 的 IP
String xForwardedFor = request.getHeader("X-Forwarded-For");
if (StringUtils.hasLength(xForwardedFor) && !"unknown".equalsIgnoreCase(xForwardedFor)) {
String[] ipAddresses = xForwardedFor.split(",");
for (String ipAddress : ipAddresses) {
if (StringUtils.hasLength(ipAddress) && !"unknown".equalsIgnoreCase(ipAddress.trim())) {
ip = ipAddress.trim();
break;
}
}
}
// 如果 X-Forwarded-For 中没有找到,则依次尝试其他 header
if (ip == null) {
ip = request.getHeader("X-Real-IP");
}
if (ip == null || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (ip == null || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
// 如果以上都没有获取到,则使用 RemoteAddr
if (ip == null || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return ip;
}
}