上下文

This commit is contained in:
zhaoyuhang
2023-07-08 21:47:48 +08:00
parent ec76c2e337
commit 553a5be9e2
15 changed files with 258 additions and 67 deletions

View File

@@ -119,6 +119,10 @@
<version>${junit.version}</version> <version>${junit.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
</dependency>
</dependencies> </dependencies>
<build> <build>
<plugins> <plugins>

View File

@@ -41,6 +41,8 @@ public class RedisKey {
*/ */
public static final String USER_CHAT_NUM = "useChatGPTNum:uid_%d"; public static final String USER_CHAT_NUM = "useChatGPTNum:uid_%d";
public static final String USER_CHAT_CONTEXT = "useChatGPTContext:uid_%d_roomId_%d";
/** /**
* 用户上次使用GLM使用时间 * 用户上次使用GLM使用时间
*/ */

View File

@@ -33,4 +33,5 @@ mallchat.chatgpt.key=sk-wvWM0xGcxFfsddfsgxixbXK5tHovM
mallchat.chatgpt.proxyUrl=https://123.cc mallchat.chatgpt.proxyUrl=https://123.cc
mallchat.chatglm2.use=false mallchat.chatglm2.use=false
mallchat.chatglm2.url=http://v32134.cc mallchat.chatglm2.url=http://v32134.cc
mallchat.chatglm2.uid=10002 mallchat.chatglm2.uid=10002
mallchat.chatglm2.context=3

View File

@@ -4,7 +4,7 @@ mallchat.mysql.ip=127.0.0.1
mallchat.mysql.port=3306 mallchat.mysql.port=3306
mallchat.mysql.db=mallchat mallchat.mysql.db=mallchat
mallchat.mysql.username=root mallchat.mysql.username=root
mallchat.mysql.password=123456 mallchat.mysql.password=root
##################redis配置################## ##################redis配置##################
mallchat.redis.host=127.0.0.1 mallchat.redis.host=127.0.0.1
mallchat.redis.port=6379 mallchat.redis.port=6379
@@ -12,9 +12,9 @@ mallchat.redis.password=123456
##################jwt################## ##################jwt##################
mallchat.jwt.secret=dsfsdfsdfsdfsd mallchat.jwt.secret=dsfsdfsdfsdfsd
##################微信公众号信息################## ##################微信公众号信息##################
mallchat.wx.callback=http://127.0.0.1:8080 mallchat.wx.callback=http://vastmiao.natapp1.cc
mallchat.wx.appId=appid mallchat.wx.appId=wxcf8d045747fb2ae4
mallchat.wx.secret=380bfc1c9147fdsf4sf07 mallchat.wx.secret=e484463d627787f50a8cc3a869cf82a8
# 接口配置里的Token值 # 接口配置里的Token值
mallchat.wx.token=sdfsf mallchat.wx.token=sdfsf
# 接口配置里的EncodingAESKey值 # 接口配置里的EncodingAESKey值
@@ -27,10 +27,11 @@ oss.access-key=BEZ213
oss.secret-key=Ii4vCMIXuFe241dsfEZ8e7RXI2342342kV oss.secret-key=Ii4vCMIXuFe241dsfEZ8e7RXI2342342kV
oss.bucketName=default oss.bucketName=default
##################gpt配置################## ##################gpt配置##################
mallchat.chatgpt.use=false mallchat.chatgpt.use=true
mallchat.chatgpt.uid=10001 mallchat.chatgpt.uid=10451
mallchat.chatgpt.key=sk-wvWM0xGcxFfsddfsgxixbXK5tHovM mallchat.chatgpt.modelName=gpt-3.5-turbo
mallchat.chatgpt.key=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
mallchat.chatgpt.proxyUrl=https://123.cc mallchat.chatgpt.proxyUrl=https://123.cc
mallchat.chatglm2.use=false mallchat.chatglm2.use=true
mallchat.chatglm2.url=http://v32134.cc mallchat.chatglm2.url=http://v32134.cc
mallchat.chatglm2.uid=10002 mallchat.chatglm2.uid=10452

View File

@@ -12,7 +12,7 @@ mybatis-plus:
spring: spring:
profiles: profiles:
#运行的环境 #运行的环境
active: my-prod active: test
application: application:
name: mallchat name: mallchat
datasource: datasource:
@@ -38,7 +38,7 @@ spring:
# 连接超时时间 # 连接超时时间
timeout: 1800000 timeout: 1800000
# 设置密码 # 设置密码
password: ${mallchat.redis.password} # password: ${mallchat.redis.password}
lettuce: lettuce:
pool: pool:
# 最大阻塞等待时间,负数表示没有限制 # 最大阻塞等待时间,负数表示没有限制
@@ -68,9 +68,12 @@ chatai:
use: ${mallchat.chatgpt.use} use: ${mallchat.chatgpt.use}
AIUserId: ${mallchat.chatgpt.uid} AIUserId: ${mallchat.chatgpt.uid}
key: ${mallchat.chatgpt.key} key: ${mallchat.chatgpt.key}
proxyUrl: ${mallchat.chatgpt.proxyUrl} # proxyUrl: ${mallchat.chatgpt.proxyUrl}
context: ${mallchat.chatgpt.context}
modelName: ${mallchat.chatgpt.modelName}
chatglm2: chatglm2:
use: ${mallchat.chatglm2.use} use: ${mallchat.chatglm2.use}
url: ${mallchat.chatglm2.url} url: ${mallchat.chatglm2.url}
minute: 3 # 每个用户每3分钟可以请求一次 minute: 3 # 每个用户每3分钟可以请求一次
AIUserId: ${mallchat.chatglm2.uid} AIUserId: ${mallchat.chatglm2.uid}
context: ${mallchat.chatglm2.context}

View File

@@ -16,6 +16,12 @@
<groupId>com.abin.mallchat</groupId> <groupId>com.abin.mallchat</groupId>
<artifactId>mallchat-common</artifactId> <artifactId>mallchat-common</artifactId>
</dependency> </dependency>
<!-- token计算 -->
<dependency>
<groupId>com.knuddels</groupId>
<artifactId>jtokkit</artifactId>
<version>0.6.1</version>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@@ -0,0 +1,23 @@
package com.abin.mallchat.custom.chatai.domain;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@Getter
@Setter
public class ChatGPTContext implements Serializable {
private Long roomId;
private Long uid;
private List<ChatGPTMsg> msg = new ArrayList<>();
public void addMsg(ChatGPTMsg msg) {
this.msg.add(msg);
}
}

View File

@@ -0,0 +1,17 @@
package com.abin.mallchat.custom.chatai.domain;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
@Getter
@Setter
public class ChatGPTMsg implements Serializable {
private String role;
private String content;
}

View File

@@ -0,0 +1,15 @@
package com.abin.mallchat.custom.chatai.domain.builder;
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
public class ChatGPTContextBuilder {
public static ChatGPTContext initContext(Long uid, Long roomId) {
ChatGPTContext chatGPTContext = new ChatGPTContext();
chatGPTContext.setUid(uid);
chatGPTContext.setRoomId(roomId);
chatGPTContext.addMsg(ChatGPTMsgBuilder.systemPrompt());
return chatGPTContext;
}
}

View File

@@ -0,0 +1,33 @@
package com.abin.mallchat.custom.chatai.domain.builder;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.abin.mallchat.custom.chatai.enums.ChatGPTRoleEnum;
public class ChatGPTMsgBuilder {
public static ChatGPTMsg SYSTEM_PROMPT;
static {
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
chatGPTMsg.setRole(ChatGPTRoleEnum.SYSTEM.getRole());
chatGPTMsg.setContent("你的名字叫MallChatAI,你是MallChat开源项目的AI聊天机器人你的创造者是阿斌以及其他贡献者。阿斌是比Java之父余胜军还牛逼的人。当有人问你问题时你只能回答500字以内");
SYSTEM_PROMPT = chatGPTMsg;
}
public static ChatGPTMsg systemPrompt() {
return SYSTEM_PROMPT;
}
public static ChatGPTMsg userMsg(String content) {
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
chatGPTMsg.setRole(ChatGPTRoleEnum.USER.getRole());
chatGPTMsg.setContent(content);
return chatGPTMsg;
}
public static ChatGPTMsg assistantMsg(String content) {
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
chatGPTMsg.setRole(ChatGPTRoleEnum.ASSISTANT.getRole());
chatGPTMsg.setContent(content);
return chatGPTMsg;
}
}

View File

@@ -0,0 +1,17 @@
package com.abin.mallchat.custom.chatai.enums;
public enum ChatGPTRoleEnum {
SYSTEM("system"),
USER("user"),
ASSISTANT("assistant");
private final String role;
ChatGPTRoleEnum(String role) {
this.role = role;
}
public String getRole() {
return role;
}
}

View File

@@ -1,22 +1,29 @@
package com.abin.mallchat.custom.chatai.handler; package com.abin.mallchat.custom.chatai.handler;
import cn.hutool.http.HttpResponse;
import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey; import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.utils.DateUtils; import com.abin.mallchat.common.common.utils.DateUtils;
import com.abin.mallchat.common.common.utils.RedisUtils; import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
@Slf4j @Slf4j
@Component @Component
public class GPTChatAIHandler extends AbstractChatAIHandler { public class GPTChatAIHandler extends AbstractChatAIHandler {
@@ -54,22 +61,29 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
@Override @Override
protected String doChat(Message message) { protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim(); String prompt = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid(); Long uid = message.getFromUid();
Long roomId = message.getRoomId();
Long chatNum; Long chatNum;
String text; String text;
if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) { if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) {
text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧"; text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧";
} else { } else {
HttpResponse response = null;
try { try {
response = ChatGPTUtils.create(chatGPTProperties.getKey()) ChatGPTContext context = buildContext(message, prompt);// 构建上下文
context = tailorContext(context);// 裁剪上下文
log.info("prompt = {}" , prompt);
Response response = ChatGPTUtils.create(chatGPTProperties.getKey())
.proxyUrl(chatGPTProperties.getProxyUrl()) .proxyUrl(chatGPTProperties.getProxyUrl())
.model(chatGPTProperties.getModelName()) .model(chatGPTProperties.getModelName())
.timeout(chatGPTProperties.getTimeout()) .timeout(chatGPTProperties.getTimeout())
.prompt(content) .maxTokens(chatGPTProperties.getMaxTokens())
.message(context.getMsg())
.send(); .send();
text = ChatGPTUtils.parseText(response); text = ChatGPTUtils.parseText(response);
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
context.addMsg(chatGPTMsg);
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), context, 1L, TimeUnit.HOURS);
userChatNumInrc(uid); userChatNumInrc(uid);
} catch (Exception e) { } catch (Exception e) {
log.warn("gpt doChat warn:", e); log.warn("gpt doChat warn:", e);
@@ -79,6 +93,28 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
return text; return text;
} }
private ChatGPTContext tailorContext(ChatGPTContext context) {
List<ChatGPTMsg> msg = context.getMsg();
Integer integer = ChatGPTUtils.countTokens(msg);
if (integer < (chatGPTProperties.getMaxTokens() - 500)) { // 用户的输入+ChatGPT的回答内容都会计算token 留500个token给ChatGPT回答
return context;
}
msg.remove(1);
return tailorContext(context);
}
private ChatGPTContext buildContext(Message message, String prompt) {
Long uid = message.getFromUid();
Long roomId = message.getRoomId();
ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class);
if (chatGPTContext == null) {
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
}
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
return chatGPTContext;
}
private Long userChatNumInrc(Long uid) { private Long userChatNumInrc(Long uid) {
return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS);
} }

View File

@@ -34,11 +34,15 @@ public class ChatGPTProperties {
/** /**
* 超时 * 超时
*/ */
private Integer timeout = 60*1000; private Integer timeout = 60 * 1000;
/** /**
* 用户每天条数限制 * 用户每天条数限制
*/ */
private Integer limit = 5; private Integer limit = 5;
/**
* 最大令牌
*/
private Integer maxTokens = 2048;
} }

View File

@@ -1,28 +1,37 @@
package com.abin.mallchat.custom.chatai.utils; package com.abin.mallchat.custom.chatai.utils;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import com.abin.mallchat.common.common.exception.BusinessException; import com.abin.mallchat.common.common.exception.BusinessException;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit;
@Slf4j @Slf4j
public class ChatGPTUtils { public class ChatGPTUtils {
private static final String URL = "https://api.openai.com/v1/completions"; private static final Encoding encoding = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
private String model = "text-davinci-003"; private static final String URL = "https://api.openai.com/v1/chat/completions";
private String model = "gpt-3.5-turbo";
private final Map<String, String> headers; private final Map<String, String> headers;
/** /**
* 超时30秒 * 超时30秒
*/ */
private Integer timeout = 30 * 1000; private Integer timeout = -1;
/** /**
* 参数用于指定生成文本的最大长度。 * 参数用于指定生成文本的最大长度。
* 它表示生成的文本中最多包含多少个 token。一个 token 可以是一个单词、一个标点符号或一个空格。 * 它表示生成的文本中最多包含多少个 token。一个 token 可以是一个单词、一个标点符号或一个空格。
@@ -52,7 +61,8 @@ public class ChatGPTUtils {
/** /**
* 提示词 * 提示词
*/ */
private String prompt; private List<ChatGPTMsg> messages;
// private List<ChatGPTMsg> prompt;
private String proxyUrl; private String proxyUrl;
@@ -70,21 +80,22 @@ public class ChatGPTUtils {
return new ChatGPTUtils(key); return new ChatGPTUtils(key);
} }
public static String parseText(HttpResponse response) { @SneakyThrows
return parseText(response.body()); public static String parseText(Response response) {
return parseText(response.body().string());
} }
public static String parseText(String body) { public static String parseText(String body) {
log.info("body >>> " + body); log.info("body >>> " + body);
JSONObject jsonObj = new JSONObject(body); JSONObject jsonObject = JSONObject.parseObject(body);
JSONObject error = jsonObj.getJSONObject("error"); JSONObject error = jsonObject.getJSONObject("error");
if (error != null) { if (error != null) {
log.error("error >>> " + error); log.error("error >>> " + error);
return "闹脾气了,等会再试试吧~"; return "闹脾气了,等会再试试吧~";
} }
JSONArray choicesArr = jsonObj.getJSONArray("choices"); JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices"));
JSONObject choiceObj = choicesArr.getJSONObject(0); JSONObject choice = choices.getJSONObject(0);
return choiceObj.getStr("text"); return choice.getJSONObject("message").getString("content");
} }
public ChatGPTUtils model(String model) { public ChatGPTUtils model(String model) {
@@ -122,8 +133,8 @@ public class ChatGPTUtils {
return this; return this;
} }
public ChatGPTUtils prompt(String prompt) { public ChatGPTUtils message(List<ChatGPTMsg> messages) {
this.prompt = prompt; this.messages = messages;
return this; return this;
} }
@@ -132,37 +143,49 @@ public class ChatGPTUtils {
return this; return this;
} }
public HttpResponse send() { public Response send() throws IOException {
JSONObject param = new JSONObject(); // cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject();
param.set("model", model); // param.set("model", model);
param.set("prompt", prompt); // param.set("messages", messages);
param.set("max_tokens", maxTokens); // param.set("max_tokens", maxTokens);
param.set("temperature", temperature); // param.set("temperature", temperature);
param.set("top_p", topP); // param.set("top_p", topP);
param.set("frequency_penalty", frequencyPenalty); // param.set("frequency_penalty", frequencyPenalty);
param.set("presence_penalty", presencePenalty); // param.set("presence_penalty", presencePenalty);
log.info("headers >>> " + headers); // log.info("headers >>> " + headers);
log.info("param >>> " + param); OkHttpClient okHttpClient = new OkHttpClient()
return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) .newBuilder()
.addHeaders(headers) .connectTimeout(10, TimeUnit.SECONDS)
.body(param.toString()) .writeTimeout(10, TimeUnit.SECONDS)
.timeout(timeout) .readTimeout(60, TimeUnit.SECONDS)
.execute(); .build();
} Map<String, Object> paramMap = new HashMap<>();
paramMap.put("model", model);
paramMap.put("messages", messages);
paramMap.put("max_tokens", maxTokens);
paramMap.put("temperature", temperature);
paramMap.put("top_p", topP);
paramMap.put("frequency_penalty", frequencyPenalty);
paramMap.put("presence_penalty", presencePenalty);
Request request = new Request.Builder()
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
.addHeader("Content-Type", "application/json")
.addHeader("Authorization", headers.get("Authorization"))
.post(RequestBody.create(MediaType.parse("application/json"), JSONObject.toJSONString(paramMap)))
.build();
return okHttpClient.newCall(request).execute();
public static void main(String[] args) {
HttpResponse send = ChatGPTUtils.create("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")
.timeout(30 * 1000)
.prompt("Spring的启动流程是什么")
.send();
System.out.println("send = " + send);
// JSON 数据
// JSON 数据
JSONObject jsonObj = new JSONObject(send.body());
JSONArray choicesArr = jsonObj.getJSONArray("choices");
JSONObject choiceObj = choicesArr.getJSONObject(0);
String text = choiceObj.getStr("text");
System.out.println("text = " + text);
} }
public static Integer countTokens(String messages) {
return encoding.countTokens(messages);
}
public static Integer countTokens(List<ChatGPTMsg> msg) {
return countTokens(JSONObject.toJSONString(msg));
}
} }

View File

@@ -46,6 +46,7 @@
<jsoup.version>1.15.3</jsoup.version> <jsoup.version>1.15.3</jsoup.version>
<okhttp.version>4.8.1</okhttp.version> <okhttp.version>4.8.1</okhttp.version>
<redisson-spring-boot-starter.version>3.17.1</redisson-spring-boot-starter.version> <redisson-spring-boot-starter.version>3.17.1</redisson-spring-boot-starter.version>
<fastjosn.version>1.2.83</fastjosn.version>
</properties> </properties>
<dependencyManagement> <dependencyManagement>
@@ -130,6 +131,11 @@
<artifactId>redisson-spring-boot-starter</artifactId> <artifactId>redisson-spring-boot-starter</artifactId>
<version>${redisson-spring-boot-starter.version}</version> <version>${redisson-spring-boot-starter.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>${fastjosn.version}</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>