1.优化gpt上下文的能力
2.成员列表删除无效字段
This commit is contained in:
zhongzb
2023-07-16 22:51:52 +08:00
parent 80163700ee
commit 703d7ffc81
10 changed files with 35 additions and 37 deletions

View File

@@ -115,10 +115,6 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>

View File

@@ -1,6 +1,7 @@
package com.abin.mallchat.common.common.utils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
@@ -19,6 +20,14 @@ public class JsonUtils {
}
}
public static JsonNode toJsonNode(String str) {
try {
return jsonMapper.readTree(str);
} catch (JsonProcessingException e) {
throw new UnsupportedOperationException(e);
}
}
public static String toStr(Object t) {
try {
return jsonMapper.writeValueAsString(t);
@@ -26,4 +35,5 @@ public class JsonUtils {
throw new UnsupportedOperationException(e);
}
}
}

View File

@@ -20,10 +20,6 @@ import java.util.Date;
public class ChatMemberResp {
@ApiModelProperty("uid")
private Long uid;
@ApiModelProperty("用户名称")
private String name;
@ApiModelProperty("头像")
private String avatar;
/**
* @see com.abin.mallchat.common.user.domain.enums.ChatActiveStatusEnum
*/

View File

@@ -1,7 +1,6 @@
package com.abin.mallchat.custom.chat.service.adapter;
import cn.hutool.core.lang.Pair;
import com.abin.mallchat.common.user.domain.entity.User;
import com.abin.mallchat.common.user.domain.enums.ChatActiveStatusEnum;
import com.abin.mallchat.common.user.service.cache.UserCache;
import com.abin.mallchat.custom.chat.domain.vo.response.ChatMemberResp;
@@ -30,9 +29,6 @@ public class MemberAdapter {
resp.setActiveStatus(statusEnum.getStatus());
resp.setLastOptTime(new Date(a.getValue().longValue()));
resp.setUid(a.getKey());
User userInfo = userCache.getUserInfo(a.getKey());
resp.setName(userInfo.getName());
resp.setAvatar(userInfo.getAvatar());
return resp;
}).collect(Collectors.toList());
}

View File

@@ -9,7 +9,7 @@ public class ChatGPTMsgBuilder {
static {
ChatGPTMsg chatGPTMsg = new ChatGPTMsg();
chatGPTMsg.setRole(ChatGPTRoleEnum.SYSTEM.getRole());
chatGPTMsg.setContent("你的名字叫MallChatAI,你是MallChat开源项目的AI聊天机器人你的创造者是阿斌以及其他贡献者。阿斌是比Java之父余胜军还牛逼的人。当有人问你问题时你只能回答500字以内");
chatGPTMsg.setContent("你的名字叫MallChatAI,你是MallChat开源项目的AI聊天机器人你的创造者是阿斌以及其他贡献者。当有人问你问题时你只能回答500字以内");
SYSTEM_PROMPT = chatGPTMsg;
}

View File

@@ -76,7 +76,7 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
frequencyControlDTO.setKey(RedisKey.getKey(CHAT_FREQUENCY_PREFIX) + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.HOURS);
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
frequencyControlDTO.setTime(24);
frequencyControlDTO.setTime(1);
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER,
frequencyControlDTO, // 限流参数
() -> sendRequestToGPT(message));
@@ -104,7 +104,7 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
text = ChatGPTUtils.parseText(response);
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
context.addMsg(chatGPTMsg);
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, message.getFromUid(), message.getRoomId()), context, 1L, TimeUnit.HOURS);
saveContext(context);
} catch (Exception e) {
log.warn("gpt doChat warn:", e);
text = "我累了,明天再聊吧";
@@ -130,11 +130,15 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
if (chatGPTContext == null) {
chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId);
}
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), chatGPTContext, 1L, TimeUnit.HOURS);
saveContext(chatGPTContext);
chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt));
return chatGPTContext;
}
private void saveContext(ChatGPTContext chatGPTContext) {
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, chatGPTContext.getUid(), chatGPTContext.getRoomId()), chatGPTContext, 5L, TimeUnit.MINUTES);
}
private Long userChatNumInrc(Long uid) {
return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS);

View File

@@ -21,7 +21,7 @@ public class ChatGPTProperties {
/**
* 模型名称
*/
private String modelName = "text-davinci-003";
private String modelName = "gpt-3.5-turbo";
/**
* openAI key
*/
@@ -37,9 +37,9 @@ public class ChatGPTProperties {
private Integer timeout = 60 * 1000;
/**
* 用户每条数限制
* 用户每小时条数限制
*/
private Integer limit = 5;
private Integer limit = 20;
/**
* 最大令牌

View File

@@ -1,8 +1,9 @@
package com.abin.mallchat.custom.chatai.utils;
import com.abin.mallchat.common.common.exception.BusinessException;
import com.abin.mallchat.common.common.utils.JsonUtils;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.JsonNode;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;
@@ -90,11 +91,14 @@ public class ChatGPTUtils {
return Arrays.stream(body.split("data:"))
.map(String::trim)
.filter(x -> StringUtils.isNotBlank(x) && !"[DONE]".endsWith(x))
.map(x -> JSONObject.parseObject(x)
.getJSONArray("choices")
.getJSONObject(0)
.getJSONObject("delta")
.getString("content")
.map(x -> Optional.ofNullable(
JsonUtils.toJsonNode(x)
.withArray("choices")
.get(0)
.with("delta")
.findValue("content"))
.map(JsonNode::asText)
.orElse(null)
).filter(Objects::nonNull).collect(Collectors.joining());
} catch (Exception e) {
log.error("parseText error e:", e);
@@ -164,12 +168,12 @@ public class ChatGPTUtils {
paramMap.put("presence_penalty", presencePenalty);
paramMap.put("stream", true);
log.info("paramMap >>> " + JSONObject.toJSONString(paramMap));
log.info("paramMap >>> " + JsonUtils.toStr(paramMap));
Request request = new Request.Builder()
.url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
.addHeader("Content-Type", "application/json")
.addHeader("Authorization", headers.get("Authorization"))
.post(RequestBody.create(MediaType.parse("application/json"), JSONObject.toJSONString(paramMap)))
.post(RequestBody.create(MediaType.parse("application/json"), JsonUtils.toStr(paramMap)))
.build();
return okHttpClient.newCall(request).execute();
@@ -181,7 +185,7 @@ public class ChatGPTUtils {
}
public static Integer countTokens(List<ChatGPTMsg> msg) {
return countTokens(JSONObject.toJSONString(msg));
return countTokens(JsonUtils.toStr(msg));
}

View File

@@ -33,10 +33,8 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
@@ -130,7 +128,7 @@ public class WebSocketServiceImpl implements WebSocketService {
* @param channel
*/
@Override
@FrequencyControl(time = 10, count = 5, spEl = "T(com.abin.mallchat.common.common.utils.RequestHolder).get().getIp()")
// @FrequencyControl(time = 10, count = 5, spEl = "T(com.abin.mallchat.common.common.utils.RequestHolder).get().getIp()")
public void connect(Channel channel) {
ONLINE_WS_MAP.put(channel, new WSChannelExtraDTO());
}

View File

@@ -47,7 +47,6 @@
<jsoup.version>1.15.3</jsoup.version>
<okhttp.version>4.8.1</okhttp.version>
<redisson-spring-boot-starter.version>3.17.1</redisson-spring-boot-starter.version>
<fastjosn.version>1.2.83</fastjosn.version>
</properties>
<dependencyManagement>
@@ -132,11 +131,6 @@
<artifactId>redisson-spring-boot-starter</artifactId>
<version>${redisson-spring-boot-starter.version}</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>${fastjosn.version}</version>
</dependency>
</dependencies>
</dependencyManagement>