feat:上下文

This commit is contained in:
zhaoyuhang
2023-07-10 23:11:41 +08:00
parent 0cb42d4adb
commit d8cc001eb8
3 changed files with 41 additions and 37 deletions

View File

@@ -2,11 +2,13 @@ package com.abin.mallchat.custom.chatai.domain;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import java.io.Serializable;
@Getter
@Setter
@ToString
public class ChatGPTMsg implements Serializable {
private String role;

View File

@@ -3,16 +3,15 @@ package com.abin.mallchat.custom.chatai.handler;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.common.common.utils.DateUtils;
import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
@@ -27,7 +26,6 @@ import java.util.List;
import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j
@@ -72,15 +70,16 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
@Override
protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid();
try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setKey(RedisKey.getKey(CHAT_FREQUENCY_PREFIX) + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.HOURS);
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
frequencyControlDTO.setTime(24);
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER,
frequencyControlDTO, // 限流参数
() -> sendRequestToGPT(message));
} catch (FrequencyControlException e) {
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
} catch (Throwable e) {
@@ -88,20 +87,24 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
}
}
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
Long roomId = message.getRoomId();
Long chatNum;
private String sendRequestToGPT(Message message) {
ChatGPTContext context = buildContext(message);// 构建上下文
context = tailorContext(context);// 裁剪上下文
log.info("context = {}", context);
String text;
HttpResponse response = null;
try {
response = ChatGPTUtils.create(chatGPTProperties.getKey())
Response response = ChatGPTUtils.create(chatGPTProperties.getKey())
.proxyUrl(chatGPTProperties.getProxyUrl())
.model(chatGPTProperties.getModelName())
.timeout(chatGPTProperties.getTimeout())
.prompt(content)
.maxTokens(chatGPTProperties.getMaxTokens())
.message(context.getMsg())
.send();
text = ChatGPTUtils.parseText(response);
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
context.addMsg(chatGPTMsg);
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, message.getFromUid(), message.getRoomId()), context, 1L, TimeUnit.HOURS);
} catch (Exception e) {
log.warn("gpt doChat warn:", e);
text = "我累了,明天再聊吧";
@@ -119,13 +122,15 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
return tailorContext(context);
}
private ChatGPTContext buildContext(Message message, String prompt) {
private ChatGPTContext buildContext(Message message) {
String prompt = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid();
Long roomId = message.getRoomId();
ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class);
if (chatGPTContext == null) {
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
}
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), chatGPTContext, 1L, TimeUnit.HOURS);
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
return chatGPTContext;
}

View File

@@ -2,7 +2,6 @@ package com.abin.mallchat.custom.chatai.utils;
import com.abin.mallchat.common.common.exception.BusinessException;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
@@ -13,10 +12,9 @@ import okhttp3.*;
import org.apache.commons.lang3.StringUtils;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@Slf4j
public class ChatGPTUtils {
@@ -85,17 +83,23 @@ public class ChatGPTUtils {
return parseText(response.body().string());
}
public static String parseText(String body) {
log.info("body >>> " + body);
JSONObject jsonObject = JSONObject.parseObject(body);
JSONObject error = jsonObject.getJSONObject("error");
if (error != null) {
log.error("error >>> " + error);
// log.info("body >>> " + body);
try {
return Arrays.stream(body.split("data:"))
.map(String::trim)
.filter(x -> StringUtils.isNotBlank(x) && !"[DONE]".endsWith(x))
.map(x -> JSONObject.parseObject(x)
.getJSONArray("choices")
.getJSONObject(0)
.getJSONObject("delta")
.getString("content")
).filter(Objects::nonNull).collect(Collectors.joining());
} catch (Exception e) {
log.error("parseText error e:", e);
return "闹脾气了,等会再试试吧~";
}
JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices"));
JSONObject choice = choices.getJSONObject(0);
return choice.getJSONObject("message").getString("content");
}
public ChatGPTUtils model(String model) {
@@ -144,15 +148,6 @@ public class ChatGPTUtils {
}
public Response send() throws IOException {
// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject();
// param.set("model", model);
// param.set("messages", messages);
// param.set("max_tokens", maxTokens);
// param.set("temperature", temperature);
// param.set("top_p", topP);
// param.set("frequency_penalty", frequencyPenalty);
// param.set("presence_penalty", presencePenalty);
// log.info("headers >>> " + headers);
OkHttpClient okHttpClient = new OkHttpClient()
.newBuilder()
.connectTimeout(10, TimeUnit.SECONDS)
@@ -167,7 +162,9 @@ public class ChatGPTUtils {
paramMap.put("top_p", topP);
paramMap.put("frequency_penalty", frequencyPenalty);
paramMap.put("presence_penalty", presencePenalty);
paramMap.put("stream", true);
log.info("paramMap >>> " + JSONObject.toJSONString(paramMap));
Request request = new Request.Builder()
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
.addHeader("Content-Type", "application/json")