mirror of
https://github.com/zongzibinbin/MallChat.git
synced 2026-03-13 21:53:41 +08:00
上下文
This commit is contained in:
@@ -16,6 +16,12 @@
|
||||
<groupId>com.abin.mallchat</groupId>
|
||||
<artifactId>mallchat-common</artifactId>
|
||||
</dependency>
|
||||
<!-- token计算 -->
|
||||
<dependency>
|
||||
<groupId>com.knuddels</groupId>
|
||||
<artifactId>jtokkit</artifactId>
|
||||
<version>0.6.1</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.abin.mallchat.custom.chatai.domain;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class ChatGPTContext implements Serializable {
|
||||
|
||||
private Long roomId;
|
||||
|
||||
private Long uid;
|
||||
|
||||
private List<ChatGPTMsg> msg = new ArrayList<>();
|
||||
|
||||
public void addMsg(ChatGPTMsg msg) {
|
||||
this.msg.add(msg);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.abin.mallchat.custom.chatai.domain;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class ChatGPTMsg implements Serializable {
|
||||
|
||||
private String role;
|
||||
|
||||
private String content;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.abin.mallchat.custom.chatai.domain.builder;
|
||||
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
|
||||
|
||||
public class ChatGPTContextBuilder {
|
||||
|
||||
public static ChatGPTContext initContext(Long uid, Long roomId) {
|
||||
ChatGPTContext chatGPTContext = new ChatGPTContext();
|
||||
chatGPTContext.setUid(uid);
|
||||
chatGPTContext.setRoomId(roomId);
|
||||
chatGPTContext.addMsg(ChatGPTMsgBuilder.systemPrompt());
|
||||
return chatGPTContext;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.abin.mallchat.custom.chatai.domain.builder;
|
||||
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||
import com.abin.mallchat.custom.chatai.enums.ChatGPTRoleEnum;
|
||||
|
||||
public class ChatGPTMsgBuilder {
|
||||
public static ChatGPTMsg SYSTEM_PROMPT;
|
||||
|
||||
static {
|
||||
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
|
||||
chatGPTMsg.setRole(ChatGPTRoleEnum.SYSTEM.getRole());
|
||||
chatGPTMsg.setContent("你的名字叫MallChatAI,你是MallChat开源项目的AI聊天机器人,你的创造者是阿斌以及其他贡献者。阿斌是比Java之父余胜军还牛逼的人。当有人问你问题时你只能回答500字以内");
|
||||
SYSTEM_PROMPT = chatGPTMsg;
|
||||
}
|
||||
|
||||
public static ChatGPTMsg systemPrompt() {
|
||||
return SYSTEM_PROMPT;
|
||||
}
|
||||
|
||||
public static ChatGPTMsg userMsg(String content) {
|
||||
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
|
||||
chatGPTMsg.setRole(ChatGPTRoleEnum.USER.getRole());
|
||||
chatGPTMsg.setContent(content);
|
||||
return chatGPTMsg;
|
||||
}
|
||||
|
||||
public static ChatGPTMsg assistantMsg(String content) {
|
||||
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
|
||||
chatGPTMsg.setRole(ChatGPTRoleEnum.ASSISTANT.getRole());
|
||||
chatGPTMsg.setContent(content);
|
||||
return chatGPTMsg;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.abin.mallchat.custom.chatai.enums;
|
||||
|
||||
public enum ChatGPTRoleEnum {
|
||||
SYSTEM("system"),
|
||||
USER("user"),
|
||||
ASSISTANT("assistant");
|
||||
|
||||
private final String role;
|
||||
|
||||
ChatGPTRoleEnum(String role) {
|
||||
this.role = role;
|
||||
}
|
||||
|
||||
public String getRole() {
|
||||
return role;
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,29 @@
|
||||
package com.abin.mallchat.custom.chatai.handler;
|
||||
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import com.abin.mallchat.common.chat.domain.entity.Message;
|
||||
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
|
||||
import com.abin.mallchat.common.common.constant.RedisKey;
|
||||
import com.abin.mallchat.common.common.utils.DateUtils;
|
||||
import com.abin.mallchat.common.common.utils.RedisUtils;
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
|
||||
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
|
||||
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
|
||||
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
|
||||
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class GPTChatAIHandler extends AbstractChatAIHandler {
|
||||
@@ -54,22 +61,29 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
||||
|
||||
@Override
|
||||
protected String doChat(Message message) {
|
||||
String content = message.getContent().replace("@" + AI_NAME, "").trim();
|
||||
String prompt = message.getContent().replace("@" + AI_NAME, "").trim();
|
||||
Long uid = message.getFromUid();
|
||||
Long roomId = message.getRoomId();
|
||||
Long chatNum;
|
||||
String text;
|
||||
if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) {
|
||||
text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧";
|
||||
} else {
|
||||
HttpResponse response = null;
|
||||
try {
|
||||
response = ChatGPTUtils.create(chatGPTProperties.getKey())
|
||||
ChatGPTContext context = buildContext(message, prompt);// 构建上下文
|
||||
context = tailorContext(context);// 裁剪上下文
|
||||
log.info("prompt = {}" , prompt);
|
||||
Response response = ChatGPTUtils.create(chatGPTProperties.getKey())
|
||||
.proxyUrl(chatGPTProperties.getProxyUrl())
|
||||
.model(chatGPTProperties.getModelName())
|
||||
.timeout(chatGPTProperties.getTimeout())
|
||||
.prompt(content)
|
||||
.maxTokens(chatGPTProperties.getMaxTokens())
|
||||
.message(context.getMsg())
|
||||
.send();
|
||||
text = ChatGPTUtils.parseText(response);
|
||||
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
|
||||
context.addMsg(chatGPTMsg);
|
||||
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), context, 1L, TimeUnit.HOURS);
|
||||
userChatNumInrc(uid);
|
||||
} catch (Exception e) {
|
||||
log.warn("gpt doChat warn:", e);
|
||||
@@ -79,6 +93,28 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
|
||||
return text;
|
||||
}
|
||||
|
||||
private ChatGPTContext tailorContext(ChatGPTContext context) {
|
||||
List<ChatGPTMsg> msg = context.getMsg();
|
||||
Integer integer = ChatGPTUtils.countTokens(msg);
|
||||
if (integer < (chatGPTProperties.getMaxTokens() - 500)) { // 用户的输入+ChatGPT的回答内容都会计算token 留500个token给ChatGPT回答
|
||||
return context;
|
||||
}
|
||||
msg.remove(1);
|
||||
return tailorContext(context);
|
||||
}
|
||||
|
||||
private ChatGPTContext buildContext(Message message, String prompt) {
|
||||
Long uid = message.getFromUid();
|
||||
Long roomId = message.getRoomId();
|
||||
ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class);
|
||||
if (chatGPTContext == null) {
|
||||
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
|
||||
}
|
||||
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
|
||||
return chatGPTContext;
|
||||
}
|
||||
|
||||
|
||||
private Long userChatNumInrc(Long uid) {
|
||||
return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS);
|
||||
}
|
||||
|
||||
@@ -34,11 +34,15 @@ public class ChatGPTProperties {
|
||||
/**
|
||||
* 超时
|
||||
*/
|
||||
private Integer timeout = 60*1000;
|
||||
private Integer timeout = 60 * 1000;
|
||||
|
||||
/**
|
||||
* 用户每天条数限制
|
||||
*/
|
||||
private Integer limit = 5;
|
||||
|
||||
/**
|
||||
* 最大令牌
|
||||
*/
|
||||
private Integer maxTokens = 2048;
|
||||
}
|
||||
|
||||
@@ -1,28 +1,37 @@
|
||||
package com.abin.mallchat.custom.chatai.utils;
|
||||
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.hutool.json.JSONArray;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import com.abin.mallchat.common.common.exception.BusinessException;
|
||||
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.knuddels.jtokkit.Encodings;
|
||||
import com.knuddels.jtokkit.api.Encoding;
|
||||
import com.knuddels.jtokkit.api.EncodingType;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public class ChatGPTUtils {
|
||||
|
||||
private static final String URL = "https://api.openai.com/v1/completions";
|
||||
private static final Encoding encoding = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
|
||||
|
||||
private String model = "text-davinci-003";
|
||||
private static final String URL = "https://api.openai.com/v1/chat/completions";
|
||||
|
||||
private String model = "gpt-3.5-turbo";
|
||||
|
||||
private final Map<String, String> headers;
|
||||
/**
|
||||
* 超时30秒
|
||||
*/
|
||||
private Integer timeout = 30 * 1000;
|
||||
private Integer timeout = -1;
|
||||
/**
|
||||
* 参数用于指定生成文本的最大长度。
|
||||
* 它表示生成的文本中最多包含多少个 token。一个 token 可以是一个单词、一个标点符号或一个空格。
|
||||
@@ -52,7 +61,8 @@ public class ChatGPTUtils {
|
||||
/**
|
||||
* 提示词
|
||||
*/
|
||||
private String prompt;
|
||||
private List<ChatGPTMsg> messages;
|
||||
// private List<ChatGPTMsg> prompt;
|
||||
|
||||
private String proxyUrl;
|
||||
|
||||
@@ -70,21 +80,22 @@ public class ChatGPTUtils {
|
||||
return new ChatGPTUtils(key);
|
||||
}
|
||||
|
||||
public static String parseText(HttpResponse response) {
|
||||
return parseText(response.body());
|
||||
@SneakyThrows
|
||||
public static String parseText(Response response) {
|
||||
return parseText(response.body().string());
|
||||
}
|
||||
|
||||
public static String parseText(String body) {
|
||||
log.info("body >>> " + body);
|
||||
JSONObject jsonObj = new JSONObject(body);
|
||||
JSONObject error = jsonObj.getJSONObject("error");
|
||||
JSONObject jsonObject = JSONObject.parseObject(body);
|
||||
JSONObject error = jsonObject.getJSONObject("error");
|
||||
if (error != null) {
|
||||
log.error("error >>> " + error);
|
||||
return "闹脾气了,等会再试试吧~";
|
||||
return "闹脾气了,等会再试试吧~";
|
||||
}
|
||||
JSONArray choicesArr = jsonObj.getJSONArray("choices");
|
||||
JSONObject choiceObj = choicesArr.getJSONObject(0);
|
||||
return choiceObj.getStr("text");
|
||||
JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices"));
|
||||
JSONObject choice = choices.getJSONObject(0);
|
||||
return choice.getJSONObject("message").getString("content");
|
||||
}
|
||||
|
||||
public ChatGPTUtils model(String model) {
|
||||
@@ -122,8 +133,8 @@ public class ChatGPTUtils {
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGPTUtils prompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
public ChatGPTUtils message(List<ChatGPTMsg> messages) {
|
||||
this.messages = messages;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -132,37 +143,49 @@ public class ChatGPTUtils {
|
||||
return this;
|
||||
}
|
||||
|
||||
public HttpResponse send() {
|
||||
JSONObject param = new JSONObject();
|
||||
param.set("model", model);
|
||||
param.set("prompt", prompt);
|
||||
param.set("max_tokens", maxTokens);
|
||||
param.set("temperature", temperature);
|
||||
param.set("top_p", topP);
|
||||
param.set("frequency_penalty", frequencyPenalty);
|
||||
param.set("presence_penalty", presencePenalty);
|
||||
log.info("headers >>> " + headers);
|
||||
log.info("param >>> " + param);
|
||||
return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
|
||||
.addHeaders(headers)
|
||||
.body(param.toString())
|
||||
.timeout(timeout)
|
||||
.execute();
|
||||
}
|
||||
public Response send() throws IOException {
|
||||
// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject();
|
||||
// param.set("model", model);
|
||||
// param.set("messages", messages);
|
||||
// param.set("max_tokens", maxTokens);
|
||||
// param.set("temperature", temperature);
|
||||
// param.set("top_p", topP);
|
||||
// param.set("frequency_penalty", frequencyPenalty);
|
||||
// param.set("presence_penalty", presencePenalty);
|
||||
// log.info("headers >>> " + headers);
|
||||
OkHttpClient okHttpClient = new OkHttpClient()
|
||||
.newBuilder()
|
||||
.connectTimeout(10, TimeUnit.SECONDS)
|
||||
.writeTimeout(10, TimeUnit.SECONDS)
|
||||
.readTimeout(60, TimeUnit.SECONDS)
|
||||
.build();
|
||||
Map<String, Object> paramMap = new HashMap<>();
|
||||
paramMap.put("model", model);
|
||||
paramMap.put("messages", messages);
|
||||
paramMap.put("max_tokens", maxTokens);
|
||||
paramMap.put("temperature", temperature);
|
||||
paramMap.put("top_p", topP);
|
||||
paramMap.put("frequency_penalty", frequencyPenalty);
|
||||
paramMap.put("presence_penalty", presencePenalty);
|
||||
|
||||
Request request = new Request.Builder()
|
||||
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
|
||||
.addHeader("Content-Type", "application/json")
|
||||
.addHeader("Authorization", headers.get("Authorization"))
|
||||
.post(RequestBody.create(MediaType.parse("application/json"), JSONObject.toJSONString(paramMap)))
|
||||
.build();
|
||||
return okHttpClient.newCall(request).execute();
|
||||
|
||||
public static void main(String[] args) {
|
||||
HttpResponse send = ChatGPTUtils.create("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")
|
||||
.timeout(30 * 1000)
|
||||
.prompt("Spring的启动流程是什么")
|
||||
.send();
|
||||
System.out.println("send = " + send);
|
||||
// JSON 数据
|
||||
// JSON 数据
|
||||
JSONObject jsonObj = new JSONObject(send.body());
|
||||
JSONArray choicesArr = jsonObj.getJSONArray("choices");
|
||||
JSONObject choiceObj = choicesArr.getJSONObject(0);
|
||||
String text = choiceObj.getStr("text");
|
||||
System.out.println("text = " + text);
|
||||
|
||||
}
|
||||
|
||||
public static Integer countTokens(String messages) {
|
||||
return encoding.countTokens(messages);
|
||||
}
|
||||
|
||||
public static Integer countTokens(List<ChatGPTMsg> msg) {
|
||||
return countTokens(JSONObject.toJSONString(msg));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user