mirror of
https://github.com/zongzibinbin/MallChat.git
synced 2026-03-14 06:03:42 +08:00
feat:上下文
This commit is contained in:
@@ -2,11 +2,13 @@ package com.abin.mallchat.custom.chatai.domain;
|
|||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
|
@ToString
|
||||||
public class ChatGPTMsg implements Serializable {
|
public class ChatGPTMsg implements Serializable {
|
||||||
|
|
||||||
private String role;
|
private String role;
|
||||||
|
|||||||
@@ -3,16 +3,15 @@ package com.abin.mallchat.custom.chatai.handler;
|
|||||||
import com.abin.mallchat.common.chat.domain.entity.Message;
|
import com.abin.mallchat.common.chat.domain.entity.Message;
|
||||||
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
|
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
|
||||||
import com.abin.mallchat.common.common.constant.RedisKey;
|
import com.abin.mallchat.common.common.constant.RedisKey;
|
||||||
|
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
|
||||||
|
import com.abin.mallchat.common.common.exception.FrequencyControlException;
|
||||||
|
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
|
||||||
import com.abin.mallchat.common.common.utils.DateUtils;
|
import com.abin.mallchat.common.common.utils.DateUtils;
|
||||||
import com.abin.mallchat.common.common.utils.RedisUtils;
|
import com.abin.mallchat.common.common.utils.RedisUtils;
|
||||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
|
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
|
||||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||||
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
|
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
|
||||||
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
|
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
|
||||||
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
|
|
||||||
import com.abin.mallchat.common.common.exception.FrequencyControlException;
|
|
||||||
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
|
|
||||||
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
|
|
||||||
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
|
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
|
||||||
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
|
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
|
||||||
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
|
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
|
||||||
@@ -27,7 +26,6 @@ import java.util.List;
|
|||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
|
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
|
||||||
|
|
||||||
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
|
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -72,15 +70,16 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected String doChat(Message message) {
|
protected String doChat(Message message) {
|
||||||
String content = message.getContent().replace("@" + AI_NAME, "").trim();
|
|
||||||
Long uid = message.getFromUid();
|
Long uid = message.getFromUid();
|
||||||
try {
|
try {
|
||||||
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
|
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
|
||||||
frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid);
|
frequencyControlDTO.setKey(RedisKey.getKey(CHAT_FREQUENCY_PREFIX) + ":" + uid);
|
||||||
frequencyControlDTO.setUnit(TimeUnit.HOURS);
|
frequencyControlDTO.setUnit(TimeUnit.HOURS);
|
||||||
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
|
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
|
||||||
frequencyControlDTO.setTime(24);
|
frequencyControlDTO.setTime(24);
|
||||||
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
|
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER,
|
||||||
|
frequencyControlDTO, // 限流参数
|
||||||
|
() -> sendRequestToGPT(message));
|
||||||
} catch (FrequencyControlException e) {
|
} catch (FrequencyControlException e) {
|
||||||
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
|
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
@@ -88,20 +87,24 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
|
|
||||||
String content = gptRequestDTO.getContent();
|
private String sendRequestToGPT(Message message) {
|
||||||
Long roomId = message.getRoomId();
|
ChatGPTContext context = buildContext(message);// 构建上下文
|
||||||
Long chatNum;
|
context = tailorContext(context);// 裁剪上下文
|
||||||
|
log.info("context = {}", context);
|
||||||
String text;
|
String text;
|
||||||
HttpResponse response = null;
|
|
||||||
try {
|
try {
|
||||||
response = ChatGPTUtils.create(chatGPTProperties.getKey())
|
Response response = ChatGPTUtils.create(chatGPTProperties.getKey())
|
||||||
.proxyUrl(chatGPTProperties.getProxyUrl())
|
.proxyUrl(chatGPTProperties.getProxyUrl())
|
||||||
.model(chatGPTProperties.getModelName())
|
.model(chatGPTProperties.getModelName())
|
||||||
.timeout(chatGPTProperties.getTimeout())
|
.timeout(chatGPTProperties.getTimeout())
|
||||||
.prompt(content)
|
.maxTokens(chatGPTProperties.getMaxTokens())
|
||||||
|
.message(context.getMsg())
|
||||||
.send();
|
.send();
|
||||||
text = ChatGPTUtils.parseText(response);
|
text = ChatGPTUtils.parseText(response);
|
||||||
|
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
|
||||||
|
context.addMsg(chatGPTMsg);
|
||||||
|
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, message.getFromUid(), message.getRoomId()), context, 1L, TimeUnit.HOURS);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("gpt doChat warn:", e);
|
log.warn("gpt doChat warn:", e);
|
||||||
text = "我累了,明天再聊吧";
|
text = "我累了,明天再聊吧";
|
||||||
@@ -119,13 +122,15 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
|||||||
return tailorContext(context);
|
return tailorContext(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ChatGPTContext buildContext(Message message, String prompt) {
|
private ChatGPTContext buildContext(Message message) {
|
||||||
|
String prompt = message.getContent().replace("@" + AI_NAME, "").trim();
|
||||||
Long uid = message.getFromUid();
|
Long uid = message.getFromUid();
|
||||||
Long roomId = message.getRoomId();
|
Long roomId = message.getRoomId();
|
||||||
ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class);
|
ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class);
|
||||||
if (chatGPTContext == null) {
|
if (chatGPTContext == null) {
|
||||||
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
|
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
|
||||||
}
|
}
|
||||||
|
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), chatGPTContext, 1L, TimeUnit.HOURS);
|
||||||
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
|
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
|
||||||
return chatGPTContext;
|
return chatGPTContext;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.abin.mallchat.custom.chatai.utils;
|
|||||||
|
|
||||||
import com.abin.mallchat.common.common.exception.BusinessException;
|
import com.abin.mallchat.common.common.exception.BusinessException;
|
||||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||||
import com.alibaba.fastjson.JSONArray;
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.knuddels.jtokkit.Encodings;
|
import com.knuddels.jtokkit.Encodings;
|
||||||
import com.knuddels.jtokkit.api.Encoding;
|
import com.knuddels.jtokkit.api.Encoding;
|
||||||
@@ -13,10 +12,9 @@ import okhttp3.*;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ChatGPTUtils {
|
public class ChatGPTUtils {
|
||||||
@@ -85,17 +83,23 @@ public class ChatGPTUtils {
|
|||||||
return parseText(response.body().string());
|
return parseText(response.body().string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static String parseText(String body) {
|
public static String parseText(String body) {
|
||||||
log.info("body >>> " + body);
|
// log.info("body >>> " + body);
|
||||||
JSONObject jsonObject = JSONObject.parseObject(body);
|
try {
|
||||||
JSONObject error = jsonObject.getJSONObject("error");
|
return Arrays.stream(body.split("data:"))
|
||||||
if (error != null) {
|
.map(String::trim)
|
||||||
log.error("error >>> " + error);
|
.filter(x -> StringUtils.isNotBlank(x) && !"[DONE]".endsWith(x))
|
||||||
|
.map(x -> JSONObject.parseObject(x)
|
||||||
|
.getJSONArray("choices")
|
||||||
|
.getJSONObject(0)
|
||||||
|
.getJSONObject("delta")
|
||||||
|
.getString("content")
|
||||||
|
).filter(Objects::nonNull).collect(Collectors.joining());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("parseText error e:", e);
|
||||||
return "闹脾气了,等会再试试吧~";
|
return "闹脾气了,等会再试试吧~";
|
||||||
}
|
}
|
||||||
JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices"));
|
|
||||||
JSONObject choice = choices.getJSONObject(0);
|
|
||||||
return choice.getJSONObject("message").getString("content");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatGPTUtils model(String model) {
|
public ChatGPTUtils model(String model) {
|
||||||
@@ -144,15 +148,6 @@ public class ChatGPTUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public Response send() throws IOException {
|
public Response send() throws IOException {
|
||||||
// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject();
|
|
||||||
// param.set("model", model);
|
|
||||||
// param.set("messages", messages);
|
|
||||||
// param.set("max_tokens", maxTokens);
|
|
||||||
// param.set("temperature", temperature);
|
|
||||||
// param.set("top_p", topP);
|
|
||||||
// param.set("frequency_penalty", frequencyPenalty);
|
|
||||||
// param.set("presence_penalty", presencePenalty);
|
|
||||||
// log.info("headers >>> " + headers);
|
|
||||||
OkHttpClient okHttpClient = new OkHttpClient()
|
OkHttpClient okHttpClient = new OkHttpClient()
|
||||||
.newBuilder()
|
.newBuilder()
|
||||||
.connectTimeout(10, TimeUnit.SECONDS)
|
.connectTimeout(10, TimeUnit.SECONDS)
|
||||||
@@ -167,7 +162,9 @@ public class ChatGPTUtils {
|
|||||||
paramMap.put("top_p", topP);
|
paramMap.put("top_p", topP);
|
||||||
paramMap.put("frequency_penalty", frequencyPenalty);
|
paramMap.put("frequency_penalty", frequencyPenalty);
|
||||||
paramMap.put("presence_penalty", presencePenalty);
|
paramMap.put("presence_penalty", presencePenalty);
|
||||||
|
paramMap.put("stream", true);
|
||||||
|
|
||||||
|
log.info("paramMap >>> " + JSONObject.toJSONString(paramMap));
|
||||||
Request request = new Request.Builder()
|
Request request = new Request.Builder()
|
||||||
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
|
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
|
||||||
.addHeader("Content-Type", "application/json")
|
.addHeader("Content-Type", "application/json")
|
||||||
|
|||||||
Reference in New Issue
Block a user