mirror of
https://github.com/zongzibinbin/MallChat.git
synced 2026-03-13 21:53:41 +08:00
feat:上下文
This commit is contained in:
@@ -2,11 +2,13 @@ package com.abin.mallchat.custom.chatai.domain;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ToString
|
||||
public class ChatGPTMsg implements Serializable {
|
||||
|
||||
private String role;
|
||||
|
||||
@@ -3,16 +3,15 @@ package com.abin.mallchat.custom.chatai.handler;
|
||||
import com.abin.mallchat.common.chat.domain.entity.Message;
|
||||
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
|
||||
import com.abin.mallchat.common.common.constant.RedisKey;
|
||||
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
|
||||
import com.abin.mallchat.common.common.exception.FrequencyControlException;
|
||||
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
|
||||
import com.abin.mallchat.common.common.utils.DateUtils;
|
||||
import com.abin.mallchat.common.common.utils.RedisUtils;
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
|
||||
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
|
||||
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
|
||||
import com.abin.mallchat.common.common.exception.FrequencyControlException;
|
||||
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
|
||||
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
|
||||
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
|
||||
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
|
||||
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
|
||||
@@ -27,7 +26,6 @@ import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
|
||||
|
||||
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
|
||||
|
||||
@Slf4j
|
||||
@@ -72,15 +70,16 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
||||
|
||||
@Override
|
||||
protected String doChat(Message message) {
|
||||
String content = message.getContent().replace("@" + AI_NAME, "").trim();
|
||||
Long uid = message.getFromUid();
|
||||
try {
|
||||
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
|
||||
frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid);
|
||||
frequencyControlDTO.setKey(RedisKey.getKey(CHAT_FREQUENCY_PREFIX) + ":" + uid);
|
||||
frequencyControlDTO.setUnit(TimeUnit.HOURS);
|
||||
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
|
||||
frequencyControlDTO.setTime(24);
|
||||
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
|
||||
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER,
|
||||
frequencyControlDTO, // 限流参数
|
||||
() -> sendRequestToGPT(message));
|
||||
} catch (FrequencyControlException e) {
|
||||
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
|
||||
} catch (Throwable e) {
|
||||
@@ -88,20 +87,24 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
||||
}
|
||||
}
|
||||
|
||||
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
|
||||
String content = gptRequestDTO.getContent();
|
||||
Long roomId = message.getRoomId();
|
||||
Long chatNum;
|
||||
|
||||
private String sendRequestToGPT(Message message) {
|
||||
ChatGPTContext context = buildContext(message);// 构建上下文
|
||||
context = tailorContext(context);// 裁剪上下文
|
||||
log.info("context = {}", context);
|
||||
String text;
|
||||
HttpResponse response = null;
|
||||
try {
|
||||
response = ChatGPTUtils.create(chatGPTProperties.getKey())
|
||||
Response response = ChatGPTUtils.create(chatGPTProperties.getKey())
|
||||
.proxyUrl(chatGPTProperties.getProxyUrl())
|
||||
.model(chatGPTProperties.getModelName())
|
||||
.timeout(chatGPTProperties.getTimeout())
|
||||
.prompt(content)
|
||||
.maxTokens(chatGPTProperties.getMaxTokens())
|
||||
.message(context.getMsg())
|
||||
.send();
|
||||
text = ChatGPTUtils.parseText(response);
|
||||
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
|
||||
context.addMsg(chatGPTMsg);
|
||||
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, message.getFromUid(), message.getRoomId()), context, 1L, TimeUnit.HOURS);
|
||||
} catch (Exception e) {
|
||||
log.warn("gpt doChat warn:", e);
|
||||
text = "我累了,明天再聊吧";
|
||||
@@ -119,13 +122,15 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
||||
return tailorContext(context);
|
||||
}
|
||||
|
||||
private ChatGPTContext buildContext(Message message, String prompt) {
|
||||
private ChatGPTContext buildContext(Message message) {
|
||||
String prompt = message.getContent().replace("@" + AI_NAME, "").trim();
|
||||
Long uid = message.getFromUid();
|
||||
Long roomId = message.getRoomId();
|
||||
ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class);
|
||||
if (chatGPTContext == null) {
|
||||
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
|
||||
}
|
||||
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), chatGPTContext, 1L, TimeUnit.HOURS);
|
||||
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
|
||||
return chatGPTContext;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.abin.mallchat.custom.chatai.utils;
|
||||
|
||||
import com.abin.mallchat.common.common.exception.BusinessException;
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.knuddels.jtokkit.Encodings;
|
||||
import com.knuddels.jtokkit.api.Encoding;
|
||||
@@ -13,10 +12,9 @@ import okhttp3.*;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class ChatGPTUtils {
|
||||
@@ -85,17 +83,23 @@ public class ChatGPTUtils {
|
||||
return parseText(response.body().string());
|
||||
}
|
||||
|
||||
|
||||
public static String parseText(String body) {
|
||||
log.info("body >>> " + body);
|
||||
JSONObject jsonObject = JSONObject.parseObject(body);
|
||||
JSONObject error = jsonObject.getJSONObject("error");
|
||||
if (error != null) {
|
||||
log.error("error >>> " + error);
|
||||
// log.info("body >>> " + body);
|
||||
try {
|
||||
return Arrays.stream(body.split("data:"))
|
||||
.map(String::trim)
|
||||
.filter(x -> StringUtils.isNotBlank(x) && !"[DONE]".endsWith(x))
|
||||
.map(x -> JSONObject.parseObject(x)
|
||||
.getJSONArray("choices")
|
||||
.getJSONObject(0)
|
||||
.getJSONObject("delta")
|
||||
.getString("content")
|
||||
).filter(Objects::nonNull).collect(Collectors.joining());
|
||||
} catch (Exception e) {
|
||||
log.error("parseText error e:", e);
|
||||
return "闹脾气了,等会再试试吧~";
|
||||
}
|
||||
JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices"));
|
||||
JSONObject choice = choices.getJSONObject(0);
|
||||
return choice.getJSONObject("message").getString("content");
|
||||
}
|
||||
|
||||
public ChatGPTUtils model(String model) {
|
||||
@@ -144,15 +148,6 @@ public class ChatGPTUtils {
|
||||
}
|
||||
|
||||
public Response send() throws IOException {
|
||||
// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject();
|
||||
// param.set("model", model);
|
||||
// param.set("messages", messages);
|
||||
// param.set("max_tokens", maxTokens);
|
||||
// param.set("temperature", temperature);
|
||||
// param.set("top_p", topP);
|
||||
// param.set("frequency_penalty", frequencyPenalty);
|
||||
// param.set("presence_penalty", presencePenalty);
|
||||
// log.info("headers >>> " + headers);
|
||||
OkHttpClient okHttpClient = new OkHttpClient()
|
||||
.newBuilder()
|
||||
.connectTimeout(10, TimeUnit.SECONDS)
|
||||
@@ -167,7 +162,9 @@ public class ChatGPTUtils {
|
||||
paramMap.put("top_p", topP);
|
||||
paramMap.put("frequency_penalty", frequencyPenalty);
|
||||
paramMap.put("presence_penalty", presencePenalty);
|
||||
paramMap.put("stream", true);
|
||||
|
||||
log.info("paramMap >>> " + JSONObject.toJSONString(paramMap));
|
||||
Request request = new Request.Builder()
|
||||
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
|
||||
.addHeader("Content-Type", "application/json")
|
||||
|
||||
Reference in New Issue
Block a user