This commit is contained in:
ageer
2024-02-27 23:10:23 +08:00
parent fc01a64070
commit 363b2a625a
5 changed files with 5 additions and 563 deletions

View File

@@ -63,9 +63,11 @@
</div>
## 语音克隆
<div>
<video src="./video/01.mp4"></video>
</div>
<video src="https://github.com/ageerle/ruoyi-ai/blob/main/video/81d065fabda5c26f66b514442dce74a3.mp4?raw=true" controls="controls">
您的浏览器不支持 video 标签。
</video>
## 私有知识库管理(开发中)
<div>

View File

@@ -1,27 +0,0 @@
package com.xmzs.system.service;
import com.xmzs.system.domain.bo.ChatMessageBo;
/**
* @author hncboy
* @date 2023/3/22 19:41
* 聊天相关业务接口
*/
public interface ChatService {
/**
* 根据消耗的tokens扣除余额
*
* @param chatMessageBo
* @return 结果
*/
void deductToken(ChatMessageBo chatMessageBo);
/**
* 扣除用户的余额
*
*/
void deductUserBalance(Long userId, Double numberCost);
}

View File

@@ -1,49 +0,0 @@
package com.xmzs.system.service;
import com.xmzs.common.chat.domain.request.ChatRequest;
import com.xmzs.common.chat.domain.request.Dall3Request;
import com.xmzs.common.chat.entity.images.Item;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.List;
/**
* 描述:
*
* @author https:www.unfbx.com
* @date 2023-04-08
*/
public interface SseService {
/**
* 客户端发送消息到服务端
* @param chatRequest
*/
SseEmitter sseChat(ChatRequest chatRequest);
/**
* 绘画接口
* @param request
*/
List<Item> dall3(Dall3Request request);
/**
* mj绘画接口
*/
void mjTask();
/**
* 中转接口
*/
SseEmitter transitChat(ChatRequest chatRequest);
/**
* azure 聊天接口
*
* @param chatRequest
* @return
*/
SseEmitter azureChat(ChatRequest chatRequest);
}

View File

@@ -1,104 +0,0 @@
package com.xmzs.system.service.impl;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.xmzs.common.chat.entity.chat.BaseChatCompletion;
import com.xmzs.common.chat.entity.chat.ChatCompletion;
import com.xmzs.common.core.exception.ServiceException;
import com.xmzs.system.domain.ChatToken;
import com.xmzs.system.domain.SysUser;
import com.xmzs.system.domain.bo.ChatMessageBo;
import com.xmzs.system.mapper.SysUserMapper;
import com.xmzs.system.service.ChatService;
import com.xmzs.system.service.IChatMessageService;
import com.xmzs.system.service.IChatTokenService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
/**
* @author hncboy
* @date 2023/3/22 19:41
* 聊天相关业务实现类
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
private final SysUserMapper sysUserMapper;
private final IChatMessageService chatMessageService;
private final IChatTokenService chatTokenService;
/**
* 根据消耗的tokens扣除余额
*
* @param chatMessageBo
*
*/
public void deductToken(ChatMessageBo chatMessageBo) {
// 计算总token数
ChatToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), chatMessageBo.getModelName());
if(chatToken == null){
chatToken = new ChatToken();
chatToken.setToken(0);
}
int totalTokens = chatToken.getToken()+ chatMessageBo.getTotalTokens();
// 如果总token数大于等于1000,进行费用扣除
if (totalTokens >= 1000) {
// 计算费用
int token1 = totalTokens / 1000;
int token2 = totalTokens % 1000;
if(token2 > 0){
// 保存剩余tokens
chatToken.setToken(token2);
chatTokenService.editToken(chatToken);
}else {
chatTokenService.resetToken(chatMessageBo.getUserId(), chatMessageBo.getModelName());
}
chatMessageBo.setDeductCost(token1 * ChatCompletion.getModelCost(chatMessageBo.getModelName()));
// 扣除用户余额
deductUserBalance(chatMessageBo.getUserId(), chatMessageBo.getDeductCost());
} else {
chatMessageBo.setDeductCost(0d);
chatMessageBo.setRemark("不满1kToken,计入下一次!");
chatToken.setToken(totalTokens);
chatToken.setModelName(chatMessageBo.getModelName());
chatToken.setUserId(chatMessageBo.getUserId());
chatTokenService.editToken(chatToken);
}
// 保存消息记录
chatMessageService.insertByBo(chatMessageBo);
}
/**
* 从用户余额中扣除指定费用
*
* @param userId 用户ID
* @param numberCost 要扣除的费用
*/
@Override
public void deductUserBalance(Long userId, Double numberCost) {
SysUser sysUser = sysUserMapper.selectById(userId);
if (sysUser == null) {
return;
}
Double userBalance = sysUser.getUserBalance();
if (userBalance < numberCost) {
throw new ServiceException("余额不足,请联系管理员充值!");
}
sysUserMapper.update(null,
new LambdaUpdateWrapper<SysUser>()
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
.eq(SysUser::getUserId, userId));
}
}

View File

@@ -1,380 +0,0 @@
package com.xmzs.system.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.json.JSONUtil;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.*;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.IterableStream;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.xmzs.common.chat.config.LocalCache;
import com.xmzs.common.chat.constant.OpenAIConst;
import com.xmzs.common.chat.domain.request.ChatRequest;
import com.xmzs.common.chat.domain.request.Dall3Request;
import com.xmzs.common.chat.entity.chat.*;
import com.xmzs.common.chat.entity.images.Image;
import com.xmzs.common.chat.entity.images.ImageResponse;
import com.xmzs.common.chat.entity.images.Item;
import com.xmzs.common.chat.entity.images.ResponseFormat;
import com.xmzs.common.chat.openai.OpenAiStreamClient;
import com.xmzs.common.chat.utils.TikTokensUtil;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.ServiceException;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.core.utils.StringUtils;
import com.xmzs.common.satoken.utils.LoginHelper;
import com.xmzs.common.translation.annotation.Translation;
import com.xmzs.system.domain.SysUser;
import com.xmzs.system.domain.bo.ChatMessageBo;
import com.xmzs.system.listener.SSEEventSourceListener;
import com.xmzs.system.mapper.SysUserMapper;
import com.xmzs.system.service.ChatService;
import com.xmzs.system.service.IChatMessageService;
import com.xmzs.system.service.SseService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import com.azure.ai.openai.models.ImageGenerationOptions;
import com.azure.core.models.ResponseError;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ImageGenerationData;
import com.azure.ai.openai.models.ImageGenerationOptions;
import com.azure.ai.openai.models.ImageGenerations;
import com.azure.core.credential.AzureKeyCredential;
/**
* 描述:
*
* @author https:www.unfbx.com
* @date 2023-04-08
*/
@Service
@Slf4j
@RequiredArgsConstructor
public class SseServiceImpl implements SseService {
private final OpenAiStreamClient openAiStreamClient;
private final ChatService chatService;
private final SysUserMapper sysUserMapper;
private final IChatMessageService chatMessageService;
@Value("${transit.apiKey}")
private String API_KEY;
@Value("${transit.apiHost}")
private String API_HOST;
private static final String DONE_SIGNAL = "[DONE]";
@Override
@Transactional
public SseEmitter sseChat(ChatRequest chatRequest) {
LocalCache.CACHE.put("userId",getUserId());
SseEmitter sseEmitter = new SseEmitter(0L);
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
checkUserGrade(sseEmitter, chatRequest.getModel());
// 获取对话消息列表
List<Message> msgList = chatRequest.getMessages();
// 图文识别上下文信息
List<Content> contentList = chatRequest.getContent();
// 图文识别模型
if (ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName().equals(chatRequest.getModel())) {
MessagePicture message = MessagePicture.builder().role(Message.Role.USER.getName()).content(contentList).build();
ChatCompletionWithPicture chatCompletion = ChatCompletionWithPicture
.builder()
.messages(Collections.singletonList(message))
.model(chatRequest.getModel())
.temperature(chatRequest.getTemperature())
.topP(chatRequest.getTop_p())
.stream(true)
.build();
openAiStreamClient.streamChatCompletion(chatCompletion, openAIEventSourceListener);
// 扣除图文对话费用
chatService.deductUserBalance(getUserId(),OpenAIConst.GPT4_COST);
String text = contentList.get(contentList.size() - 1).getText();
// 保存消息记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(chatRequest.getModel());
chatMessageBo.setContent(text);
chatMessageBo.setDeductCost(OpenAIConst.GPT4_COST);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
} else {
ChatCompletion completion = ChatCompletion
.builder()
.messages(msgList)
.model(chatRequest.getModel())
.temperature(chatRequest.getTemperature())
.topP(chatRequest.getTop_p())
.stream(true)
.build();
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
Message message = msgList.get(msgList.size() - 1);
// 扣除余额
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), msgList);
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(chatRequest.getModel());
chatMessageBo.setContent(message.getContent());
chatMessageBo.setTotalTokens(tokens);
chatService.deductToken(chatMessageBo);
}
return sseEmitter;
}
/**
* dall-e-3绘画接口
*
* @param request
* @return
*/
public List<Item> dall3(Dall3Request request) {
checkUserGrade(null,"");
// DALL3 绘图模型
Image image = Image.builder()
.responseFormat(ResponseFormat.URL.getName())
.model(Image.Model.DALL_E_3.getName())
.prompt(request.getPrompt())
.n(1)
.quality(request.getQuality())
.size(request.getSize())
.style(request.getStyle())
.build();
ImageResponse imageResponse = openAiStreamClient.genImages(image);
// 扣除费用
if(Objects.equals(request.getSize(), "1792x1024") || Objects.equals(request.getSize(), "1024x1792")){
chatService.deductUserBalance(getUserId(),OpenAIConst.DALL3_HD_COST);
}else {
chatService.deductUserBalance(getUserId(),OpenAIConst.DALL3_COST);
}
// 保存扣费记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
chatMessageBo.setContent(request.getPrompt());
chatMessageBo.setDeductCost(OpenAIConst.GPT4_COST);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
return imageResponse.getData();
}
@Override
public void mjTask() {
// 检验是否是免费用户
checkUserGrade(null,"");
chatService.deductUserBalance(getUserId(),OpenAIConst.MJ_COST);
// 保存扣费记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName("mj");
chatMessageBo.setContent("mj绘图");
chatMessageBo.setDeductCost(OpenAIConst.GPT4_COST);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
}
/**
* 中转接口
*
* @param chatRequest
* @return
*/
@Override
public SseEmitter transitChat(ChatRequest chatRequest) {
// 获取对话消息列表
List<Message> msgList = chatRequest.getMessages();
Message message = msgList.get(msgList.size() - 1);
SseEmitter emitter = new SseEmitter(0L);
checkUserGrade(emitter, chatRequest.getModel());
ChatCompletion completion = ChatCompletion
.builder()
.messages(chatRequest.getMessages())
.model(chatRequest.getModel())
.temperature(chatRequest.getTemperature())
.topP(chatRequest.getTop_p())
.stream(true)
.build();
// 启动一个新的线程来处理数据流
new Thread(() -> {
// 启动一个新的线程来处理数据流
try {
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(completion);
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(API_HOST + "v1/chat/completions"))
.header("Authorization", "Bearer " + API_KEY)
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
.build();
// 发送请求并获取响应体作为InputStream
HttpResponse<InputStream> response = HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofInputStream());
// 使用正确的字符编码将InputStream包装为InputStreamReader然后创建BufferedReader
BufferedReader reader = new BufferedReader(new InputStreamReader(response.body()));
String line;
while ((line = reader.readLine()) != null) {
if (line.startsWith("data: ")) {
String data = line.replace("data: ", "");
emitter.send(data, MediaType.TEXT_PLAIN);
if (data.equals(DONE_SIGNAL)) {
//成功响应
emitter.complete();
}
}
}
// 关闭资源
reader.close();
} catch (Exception e) {
emitter.complete();
throw new ServiceException("调用中转接口失败:"+e.getMessage());
}
}).start();
chatService.deductUserBalance(getUserId(),OpenAIConst.GPT4_COST);
// 保存消息记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(chatRequest.getModel());
chatMessageBo.setContent(message.getContent());
chatMessageBo.setDeductCost(OpenAIConst.GPT4_COST);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
return emitter;
}
public static void main(String[] args) {
String azureOpenaiKey = "-";
String endpoint = "-";
String deploymentOrModelName = "-";
OpenAIClient client = new OpenAIClientBuilder()
.endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey))
.buildClient();
ImageGenerationOptions imageGenerationOptions = new ImageGenerationOptions(
"A drawing of the Seattle skyline in the style of Van Gogh");
ImageGenerations images = client.getImageGenerations(deploymentOrModelName, imageGenerationOptions);
for (ImageGenerationData imageGenerationData : images.getData()) {
System.out.printf(
"Image location URL that provides temporary access to download the generated image is %s.%n",
imageGenerationData.getUrl());
}
}
public SseEmitter azureChat(ChatRequest chatRequest) {
String azureOpenaiKey = "-";
String endpoint = "-";
String deploymentOrModelId = "-";
OpenAIClient client = new OpenAIClientBuilder()
.endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey))
.buildClient();
final SseEmitter emitter = new SseEmitter();
// 使用线程池异步执行
ExecutorService service = Executors.newSingleThreadExecutor();
service.execute(() -> {
try {
// 获取对话消息列表
List<Message> chatMessages = chatRequest.getMessages();
List<ChatRequestMessage> messages = new ArrayList<>();
chatMessages.forEach(
e->{
ChatRequestMessage chatMessage;
if(Message.Role.SYSTEM.getName().equals(e.getRole())){
chatMessage = new ChatRequestSystemMessage(e.getContent());
}else {
chatMessage = new ChatRequestUserMessage(e.getContent());
}
messages.add(chatMessage);
}
);
// 获取流式响应
IterableStream<ChatCompletions> chatCompletionsStream = client.getChatCompletionsStream(deploymentOrModelId, new ChatCompletionsOptions(messages));
// 遍历流式响应并发送到客户端
for (ChatCompletions chatCompletion : chatCompletionsStream) {
if(CollectionUtil.isEmpty(chatCompletion.getChoices())){
continue;
}
log.info("json ======{}", JSONUtil.toJsonStr(chatCompletion));
emitter.send(chatCompletion);
}
emitter.complete();
} catch (Exception e) {
emitter.completeWithError(e);
}
});
return emitter;
}
/**
* 判断用户是否付费
*/
public void checkUserGrade(SseEmitter emitter, String model) {
SysUser sysUser = sysUserMapper.selectById(getUserId());
if(StringUtils.isEmpty(model)){
if("0".equals(sysUser.getUserGrade())){
throw new ServiceException("免费用户暂时不支持此模型,请切换gpt-3.5-turbo模型或者点击《进入市场选购您的商品》充值后使用!",500);
}
}
// TODO 添加枚举
if ("0".equals(sysUser.getUserGrade()) && !ChatCompletion.Model.GPT_3_5_TURBO.getName().equals(model)) {
// 创建并发送一个名为 "error" 的事件,带有错误消息和状态码
SseEmitter.SseEventBuilder event = SseEmitter.event()
.name("error") // 客户端将监听这个事件名
.data("免费用户暂时不支持此模型,请切换gpt-3.5-turbo模型或者点击《进入市场选购您的商品》充值后使用!");
try {
emitter.send(event);
} catch (IOException e) {
throw new RuntimeException(e);
}
emitter.complete();
}
}
/**
* 获取用户Id
*
* @return
*/
public Long getUserId(){
LoginUser loginUser = LoginHelper.getLoginUser();
if (loginUser == null) {
throw new BaseException("用户未登录!");
}
return loginUser.getUserId();
}
}