mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 13:23:42 +00:00
feat: 增加联网查询功能
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
package org.ruoyi.common.chat.demo.zhipu;
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.core.type.TypeReference;
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
@@ -17,34 +17,11 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
|||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
|
|
||||||
import com.zhipu.oapi.core.response.HttpxBinaryResponseContent;
|
|
||||||
import com.zhipu.oapi.service.v4.batchs.BatchCreateParams;
|
|
||||||
import com.zhipu.oapi.service.v4.batchs.BatchResponse;
|
|
||||||
import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse;
|
|
||||||
import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse;
|
|
||||||
import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest;
|
|
||||||
import com.zhipu.oapi.service.v4.file.*;
|
|
||||||
import com.zhipu.oapi.service.v4.fine_turning.*;
|
|
||||||
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
|
||||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
|
||||||
import com.zhipu.oapi.service.v4.model.*;
|
import com.zhipu.oapi.service.v4.model.*;
|
||||||
import io.reactivex.Flowable;
|
import io.reactivex.Flowable;
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
|
||||||
|
|
||||||
|
|
||||||
public class WebSearchToolsTest {
|
public class WebSearchToolsTest {
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
package org.ruoyi.common.chat.demo.zhipu;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
|
||||||
import com.zhipu.oapi.ClientV4;
|
|
||||||
import com.zhipu.oapi.Constants;
|
|
||||||
import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory;
|
|
||||||
import com.zhipu.oapi.service.v4.model.*;
|
|
||||||
import io.reactivex.Flowable;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
|
|
||||||
public class AllToolsTest {
|
|
||||||
|
|
||||||
private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class);
|
|
||||||
private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0";
|
|
||||||
|
|
||||||
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
|
|
||||||
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
|
||||||
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
|
||||||
.build();
|
|
||||||
private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper();
|
|
||||||
// 请自定义自己的业务id
|
|
||||||
private static final String requestIdTemplate = "mycompany-%d";
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test1() throws JsonProcessingException {
|
|
||||||
|
|
||||||
|
|
||||||
List<ChatMessage> messages = new ArrayList<>();
|
|
||||||
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气");
|
|
||||||
messages.add(chatMessage);
|
|
||||||
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
|
||||||
// 函数调用参数构建部分
|
|
||||||
List<ChatTool> chatToolList = new ArrayList<>();
|
|
||||||
ChatTool chatTool = new ChatTool();
|
|
||||||
|
|
||||||
chatTool.setType("code_interpreter");
|
|
||||||
ObjectNode objectNode = mapper.createObjectNode();
|
|
||||||
objectNode.put("code", "北京天气");
|
|
||||||
// chatTool.set(chatFunction);
|
|
||||||
chatToolList.add(chatTool);
|
|
||||||
|
|
||||||
|
|
||||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
|
||||||
.model("glm-4-alltools")
|
|
||||||
.stream(Boolean.TRUE)
|
|
||||||
.invokeMethod(Constants.invokeMethod)
|
|
||||||
.messages(messages)
|
|
||||||
.tools(chatToolList)
|
|
||||||
.toolChoice("auto")
|
|
||||||
.requestId(requestId)
|
|
||||||
.build();
|
|
||||||
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
|
|
||||||
if (sseModelApiResp.isSuccess()) {
|
|
||||||
AtomicBoolean isFirst = new AtomicBoolean(true);
|
|
||||||
List<Choice> choices = new ArrayList<>();
|
|
||||||
AtomicReference<ChatMessageAccumulator> lastAccumulator = new AtomicReference<>();
|
|
||||||
|
|
||||||
mapStreamToAccumulator(sseModelApiResp.getFlowable())
|
|
||||||
.doOnNext(accumulator -> {
|
|
||||||
{
|
|
||||||
if (isFirst.getAndSet(false)) {
|
|
||||||
logger.info("Response: ");
|
|
||||||
}
|
|
||||||
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
|
|
||||||
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
|
|
||||||
logger.info("tool_calls: {}", jsonString);
|
|
||||||
}
|
|
||||||
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
|
|
||||||
logger.info(accumulator.getDelta().getContent());
|
|
||||||
}
|
|
||||||
choices.add(accumulator.getChoice());
|
|
||||||
lastAccumulator.set(accumulator);
|
|
||||||
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.doOnComplete(() -> System.out.println("Stream completed."))
|
|
||||||
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
|
|
||||||
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
|
|
||||||
|
|
||||||
ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
|
|
||||||
ModelData data = new ModelData();
|
|
||||||
data.setChoices(choices);
|
|
||||||
if (chatMessageAccumulator != null) {
|
|
||||||
data.setUsage(chatMessageAccumulator.getUsage());
|
|
||||||
data.setId(chatMessageAccumulator.getId());
|
|
||||||
data.setCreated(chatMessageAccumulator.getCreated());
|
|
||||||
}
|
|
||||||
data.setRequestId(chatCompletionRequest.getRequestId());
|
|
||||||
sseModelApiResp.setFlowable(null);// 打印前置空
|
|
||||||
sseModelApiResp.setData(data);
|
|
||||||
}
|
|
||||||
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
|
|
||||||
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
|
|
||||||
|
|
||||||
client.getConfig().getHttpClient().connectionPool().evictAll();
|
|
||||||
// List all active threads
|
|
||||||
for (Thread t : Thread.getAllStackTraces().keySet()) {
|
|
||||||
logger.info("Thread: " + t.getName() + " State: " + t.getState());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
|
|
||||||
return flowable.map(chunk -> {
|
|
||||||
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,6 @@ import org.ruoyi.common.core.exception.base.BaseException;
|
|||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
import org.ruoyi.common.satoken.utils.LoginHelper;
|
import org.ruoyi.common.satoken.utils.LoginHelper;
|
||||||
import org.ruoyi.knowledge.service.EmbeddingService;
|
|
||||||
import org.ruoyi.system.domain.bo.ChatMessageBo;
|
import org.ruoyi.system.domain.bo.ChatMessageBo;
|
||||||
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
||||||
import org.ruoyi.system.domain.vo.ChatMessageVo;
|
import org.ruoyi.system.domain.vo.ChatMessageVo;
|
||||||
@@ -48,7 +47,6 @@ public class ChatController {
|
|||||||
|
|
||||||
private final IChatMessageService chatMessageService;
|
private final IChatMessageService chatMessageService;
|
||||||
|
|
||||||
private final EmbeddingService embeddingService;
|
|
||||||
/**
|
/**
|
||||||
* 聊天接口
|
* 聊天接口
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package org.ruoyi.knowledge.chain.split;
|
|||||||
|
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
||||||
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
|
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
@@ -29,7 +30,7 @@ public class CharacterTextSplitter implements TextSplitter {
|
|||||||
int textBlockSize = knowledgeInfoVo.getTextBlockSize();
|
int textBlockSize = knowledgeInfoVo.getTextBlockSize();
|
||||||
int overlapChar = knowledgeInfoVo.getOverlapChar();
|
int overlapChar = knowledgeInfoVo.getOverlapChar();
|
||||||
List<String> chunkList = new ArrayList<>();
|
List<String> chunkList = new ArrayList<>();
|
||||||
if (content.contains(knowledgeSeparator)) {
|
if (content.contains(knowledgeSeparator) && StringUtils.isNotBlank(knowledgeSeparator)) {
|
||||||
// 按自定义分隔符切分
|
// 按自定义分隔符切分
|
||||||
String[] chunks = content.split(knowledgeSeparator);
|
String[] chunks = content.split(knowledgeSeparator);
|
||||||
chunkList.addAll(Arrays.asList(chunks));
|
chunkList.addAll(Arrays.asList(chunks));
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package org.ruoyi.knowledge.chain.vectorstore;
|
package org.ruoyi.knowledge.chain.vectorstore;
|
||||||
|
|
||||||
|
import cn.hutool.core.util.StrUtil;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
||||||
@@ -23,8 +24,13 @@ public class VectorStoreFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public VectorStore getVectorStore(String kid){
|
public VectorStore getVectorStore(String kid){
|
||||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
|
String vectorModel = "weaviate";
|
||||||
String vectorModel = knowledgeInfoVo.getVector();
|
if (StrUtil.isNotEmpty(kid)) {
|
||||||
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
|
||||||
|
if (knowledgeInfoVo != null && StrUtil.isNotEmpty(knowledgeInfoVo.getVector())) {
|
||||||
|
vectorModel = knowledgeInfoVo.getVector();
|
||||||
|
}
|
||||||
|
}
|
||||||
if ("weaviate".equals(vectorModel)){
|
if ("weaviate".equals(vectorModel)){
|
||||||
return weaviateVectorStore;
|
return weaviateVectorStore;
|
||||||
}else if ("milvus".equals(vectorModel)){
|
}else if ("milvus".equals(vectorModel)){
|
||||||
|
|||||||
@@ -1,212 +0,0 @@
|
|||||||
package org.ruoyi.system.plugin;
|
|
||||||
|
|
||||||
|
|
||||||
import cn.hutool.json.JSONUtil;
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import okhttp3.OkHttpClient;
|
|
||||||
import okhttp3.logging.HttpLoggingInterceptor;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.ruoyi.common.chat.demo.ConsoleEventSourceListenerV3;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.Message;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.Parameters;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.tool.Tools;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction;
|
|
||||||
import org.ruoyi.common.chat.openai.OpenAiClient;
|
|
||||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
|
||||||
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
|
||||||
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
|
||||||
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
|
|
||||||
import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.CountDownLatch;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
public class WebSearchPlugin {
|
|
||||||
|
|
||||||
private OpenAiClient openAiClient;
|
|
||||||
private OpenAiStreamClient openAiStreamClient;
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void before() {
|
|
||||||
//可以为null
|
|
||||||
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
|
|
||||||
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
|
||||||
//!!!!千万别再生产或者测试环境打开BODY级别日志!!!!
|
|
||||||
//!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!!
|
|
||||||
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
|
||||||
OkHttpClient okHttpClient = new OkHttpClient
|
|
||||||
.Builder()
|
|
||||||
// .proxy(proxy)
|
|
||||||
.addInterceptor(httpLoggingInterceptor)
|
|
||||||
.addInterceptor(new OpenAiResponseInterceptor())
|
|
||||||
.connectTimeout(10, TimeUnit.SECONDS)
|
|
||||||
.writeTimeout(30, TimeUnit.SECONDS)
|
|
||||||
.readTimeout(30, TimeUnit.SECONDS)
|
|
||||||
.build();
|
|
||||||
openAiClient = OpenAiClient.builder()
|
|
||||||
//支持多key传入,请求时候随机选择
|
|
||||||
.apiKey(Arrays.asList("xx"))
|
|
||||||
//自定义key的获取策略:默认KeyRandomStrategy
|
|
||||||
//.keyStrategy(new KeyRandomStrategy())
|
|
||||||
.keyStrategy(new KeyRandomStrategy())
|
|
||||||
.okHttpClient(okHttpClient)
|
|
||||||
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
|
||||||
.apiHost("https://open.bigmodel.cn/")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
openAiStreamClient = OpenAiStreamClient.builder()
|
|
||||||
//支持多key传入,请求时候随机选择
|
|
||||||
.apiKey(Arrays.asList("xx"))
|
|
||||||
//自定义key的获取策略:默认KeyRandomStrategy
|
|
||||||
.keyStrategy(new KeyRandomStrategy())
|
|
||||||
.authInterceptor(new DynamicKeyOpenAiAuthInterceptor())
|
|
||||||
.okHttpClient(okHttpClient)
|
|
||||||
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
|
||||||
.apiHost("https://open.bigmodel.cn/")
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test() {
|
|
||||||
Message message = Message.builder().role(Message.Role.USER).content("今天武汉天气怎么样").build();
|
|
||||||
ChatCompletion chatCompletion = ChatCompletion
|
|
||||||
.builder()
|
|
||||||
.messages(Collections.singletonList(message))
|
|
||||||
// .tools(Collections.singletonList(tools))
|
|
||||||
.model("web-search-pro")
|
|
||||||
.build();
|
|
||||||
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
|
||||||
|
|
||||||
System.out.printf("chatCompletionResponse=%s\n", JSONUtil.toJsonStr(chatCompletionResponse));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void streamToolsChat() {
|
|
||||||
|
|
||||||
CountDownLatch countDownLatch = new CountDownLatch(1);
|
|
||||||
ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch);
|
|
||||||
|
|
||||||
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
|
||||||
//属性一
|
|
||||||
JSONObject wordLength = new JSONObject();
|
|
||||||
wordLength.put("type", "number");
|
|
||||||
wordLength.put("description", "词语的长度");
|
|
||||||
//属性二
|
|
||||||
JSONObject language = new JSONObject();
|
|
||||||
language.put("type", "string");
|
|
||||||
language.put("enum", Arrays.asList("zh", "en"));
|
|
||||||
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
|
||||||
//参数
|
|
||||||
JSONObject properties = new JSONObject();
|
|
||||||
properties.put("wordLength", wordLength);
|
|
||||||
properties.put("language", language);
|
|
||||||
Parameters parameters = Parameters.builder()
|
|
||||||
.type("object")
|
|
||||||
.properties(properties)
|
|
||||||
.required(Collections.singletonList("wordLength")).build();
|
|
||||||
Tools tools = Tools.builder()
|
|
||||||
.type(Tools.Type.FUNCTION.getName())
|
|
||||||
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ChatCompletion chatCompletion = ChatCompletion
|
|
||||||
.builder()
|
|
||||||
.messages(Collections.singletonList(message))
|
|
||||||
.tools(Collections.singletonList(tools))
|
|
||||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
|
||||||
.build();
|
|
||||||
openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener);
|
|
||||||
|
|
||||||
try {
|
|
||||||
countDownLatch.await();
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
|
|
||||||
ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls();
|
|
||||||
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
|
|
||||||
String oneWord = getOneWord(wordParam);
|
|
||||||
|
|
||||||
|
|
||||||
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
|
|
||||||
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
|
|
||||||
//构造tool call
|
|
||||||
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
|
|
||||||
String content
|
|
||||||
= "{ " +
|
|
||||||
"\"wordLength\": \"3\", " +
|
|
||||||
"\"language\": \"zh\", " +
|
|
||||||
"\"word\": \"" + oneWord + "\"," +
|
|
||||||
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
|
||||||
"}";
|
|
||||||
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
|
|
||||||
List<Message> messageList = Arrays.asList(message, message2, message3);
|
|
||||||
ChatCompletion chatCompletionV2 = ChatCompletion
|
|
||||||
.builder()
|
|
||||||
.messages(messageList)
|
|
||||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
CountDownLatch countDownLatch1 = new CountDownLatch(1);
|
|
||||||
openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch));
|
|
||||||
try {
|
|
||||||
countDownLatch1.await();
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
countDownLatch1.await();
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@Builder
|
|
||||||
static class WordParam {
|
|
||||||
private int wordLength;
|
|
||||||
@Builder.Default
|
|
||||||
private String language = "zh";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 获取一个词语(根据语言和字符长度查询)
|
|
||||||
* @param wordParam
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public String getOneWord(WordParam wordParam) {
|
|
||||||
|
|
||||||
List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
|
|
||||||
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
|
|
||||||
if (wordParam.getLanguage().equals("zh")) {
|
|
||||||
for (String e : zh) {
|
|
||||||
if (e.length() == wordParam.getWordLength()) {
|
|
||||||
return e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (wordParam.getLanguage().equals("en")) {
|
|
||||||
for (String e : en) {
|
|
||||||
if (e.length() == wordParam.getWordLength()) {
|
|
||||||
return e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "西瓜";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -2,6 +2,7 @@ package org.ruoyi.system.service;
|
|||||||
|
|
||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
|
import org.ruoyi.system.domain.SysModel;
|
||||||
import org.ruoyi.system.domain.bo.SysModelBo;
|
import org.ruoyi.system.domain.bo.SysModelBo;
|
||||||
import org.ruoyi.system.domain.vo.SysModelVo;
|
import org.ruoyi.system.domain.vo.SysModelVo;
|
||||||
|
|
||||||
@@ -45,4 +46,9 @@ public interface ISysModelService {
|
|||||||
* 校验并批量删除系统模型信息
|
* 校验并批量删除系统模型信息
|
||||||
*/
|
*/
|
||||||
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据模型名称查询模型
|
||||||
|
*/
|
||||||
|
SysModel selectModelByName(String modelName);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package org.ruoyi.system.service.impl;
|
package org.ruoyi.system.service.impl;
|
||||||
|
|
||||||
import cn.dev33.satoken.stp.StpUtil;
|
import cn.dev33.satoken.stp.StpUtil;
|
||||||
import cn.hutool.core.collection.CollectionUtil;
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.zhipu.oapi.ClientV4;
|
||||||
|
import com.zhipu.oapi.service.v4.tools.*;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||||
@@ -17,10 +19,7 @@ import org.ruoyi.common.chat.config.LocalCache;
|
|||||||
import org.ruoyi.common.chat.domain.request.ChatRequest;
|
import org.ruoyi.common.chat.domain.request.ChatRequest;
|
||||||
import org.ruoyi.common.chat.domain.request.Dall3Request;
|
import org.ruoyi.common.chat.domain.request.Dall3Request;
|
||||||
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
import org.ruoyi.common.chat.entity.chat.*;
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.Content;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.Message;
|
|
||||||
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
||||||
import org.ruoyi.common.chat.entity.images.Image;
|
import org.ruoyi.common.chat.entity.images.Image;
|
||||||
import org.ruoyi.common.chat.entity.images.ImageResponse;
|
import org.ruoyi.common.chat.entity.images.ImageResponse;
|
||||||
@@ -33,17 +32,15 @@ import org.ruoyi.common.chat.plugin.CmdPlugin;
|
|||||||
import org.ruoyi.common.chat.plugin.CmdReq;
|
import org.ruoyi.common.chat.plugin.CmdReq;
|
||||||
import org.ruoyi.common.chat.plugin.SqlPlugin;
|
import org.ruoyi.common.chat.plugin.SqlPlugin;
|
||||||
import org.ruoyi.common.chat.plugin.SqlReq;
|
import org.ruoyi.common.chat.plugin.SqlReq;
|
||||||
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
|
|
||||||
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
||||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||||
import org.ruoyi.common.core.exception.base.BaseException;
|
import org.ruoyi.common.core.exception.base.BaseException;
|
||||||
import org.ruoyi.common.core.service.ConfigService;
|
import org.ruoyi.common.core.service.ConfigService;
|
||||||
import org.ruoyi.common.core.utils.StringUtils;
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
import org.ruoyi.common.satoken.utils.LoginHelper;
|
import org.ruoyi.common.satoken.utils.LoginHelper;
|
||||||
|
import org.ruoyi.system.domain.SysModel;
|
||||||
import org.ruoyi.system.domain.bo.ChatMessageBo;
|
import org.ruoyi.system.domain.bo.ChatMessageBo;
|
||||||
import org.ruoyi.system.domain.bo.SysModelBo;
|
|
||||||
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
||||||
import org.ruoyi.system.domain.vo.SysModelVo;
|
|
||||||
import org.ruoyi.system.listener.SSEEventSourceListener;
|
import org.ruoyi.system.listener.SSEEventSourceListener;
|
||||||
import org.ruoyi.system.service.*;
|
import org.ruoyi.system.service.*;
|
||||||
import org.springframework.core.io.InputStreamResource;
|
import org.springframework.core.io.InputStreamResource;
|
||||||
@@ -65,6 +62,9 @@ import java.util.ArrayList;
|
|||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@@ -76,18 +76,21 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
private final ChatConfig chatConfig;
|
private final ChatConfig chatConfig;
|
||||||
|
|
||||||
|
|
||||||
private final IChatCostService chatService;
|
private final IChatCostService chatService;
|
||||||
|
|
||||||
private final IChatMessageService chatMessageService;
|
private final IChatMessageService chatMessageService;
|
||||||
|
|
||||||
private final ISysModelService sysModelService;
|
private final ISysModelService sysModelService;
|
||||||
|
|
||||||
private final ISysUserService userService;
|
|
||||||
|
|
||||||
private final ConfigService configService;
|
private final ConfigService configService;
|
||||||
|
|
||||||
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
|
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
|
||||||
|
|
||||||
|
private static final String requestIdTemplate = "mycompany-%d";
|
||||||
|
|
||||||
|
private static final ObjectMapper mapper = new ObjectMapper();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||||
@@ -96,11 +99,10 @@ public class SseServiceImpl implements ISseService {
|
|||||||
// 获取对话消息列表
|
// 获取对话消息列表
|
||||||
List<Message> messages = chatRequest.getMessages();
|
List<Message> messages = chatRequest.getMessages();
|
||||||
try {
|
try {
|
||||||
|
String chatString = null;
|
||||||
if (StpUtil.isLogin()) {
|
if (StpUtil.isLogin()) {
|
||||||
LocalCache.CACHE.put("userId", getUserId());
|
LocalCache.CACHE.put("userId", getUserId());
|
||||||
Object content = messages.get(messages.size() - 1).getContent();
|
Object content = messages.get(messages.size() - 1).getContent();
|
||||||
|
|
||||||
String chatString = "";
|
|
||||||
if (content instanceof List<?> listContent) {
|
if (content instanceof List<?> listContent) {
|
||||||
if (!listContent.isEmpty() && listContent.get(0) instanceof Content) {
|
if (!listContent.isEmpty() && listContent.get(0) instanceof Content) {
|
||||||
chatString = ((Content) listContent.get(0)).getText();
|
chatString = ((Content) listContent.get(0)).getText();
|
||||||
@@ -123,39 +125,89 @@ public class SseServiceImpl implements ISseService {
|
|||||||
throw new BaseException("文本不合规,请修改!");
|
throw new BaseException("文本不合规,请修改!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//根据模型名称查询模型信息
|
String model = chatRequest.getModel();
|
||||||
SysModelBo sysModelBo = new SysModelBo();
|
|
||||||
// 如果是gpts系列模型
|
// 如果是gpts系列模型
|
||||||
if (chatRequest.getModel().startsWith("gpt-4-gizmo")) {
|
if (chatRequest.getModel().startsWith("gpt-4-gizmo")) {
|
||||||
sysModelBo.setModelName("gpt-4-gizmo");
|
model = "gpt-4-gizmo";
|
||||||
} else {
|
|
||||||
sysModelBo.setModelName(chatRequest.getModel());
|
|
||||||
}
|
}
|
||||||
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
SysModel sysModel = sysModelService.selectModelByName(model);
|
||||||
|
if (sysModel != null) {
|
||||||
if (CollectionUtil.isEmpty(sysModelList)) {
|
|
||||||
// 如果模型不存在默认使用token扣费方式
|
// 如果模型不存在默认使用token扣费方式
|
||||||
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
||||||
} else {
|
} else {
|
||||||
openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModelList.get(0).getApiHost(), sysModelList.get(0).getApiKey());
|
openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModel.getApiHost(), sysModel.getApiKey());
|
||||||
// 模型设置默认提示词
|
// 模型设置默认提示词
|
||||||
SysModelVo firstModel = sysModelList.get(0);
|
|
||||||
if (StringUtils.isNotEmpty(firstModel.getSystemPrompt())) {
|
if (StringUtils.isNotEmpty(sysModel.getSystemPrompt())) {
|
||||||
Message sysMessage = Message.builder().content(firstModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
|
Message sysMessage = Message.builder().content(sysModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
|
||||||
messages.add(sysMessage);
|
messages.add(sysMessage);
|
||||||
}
|
}
|
||||||
// 计费类型: 1 token扣费 2 次数扣费
|
// 计费类型: 1 token扣费 2 次数扣费
|
||||||
if ("2".equals(firstModel.getModelType())) {
|
if ("2".equals(sysModel.getModelType())) {
|
||||||
processByModelPrice(firstModel, chatMessageBo);
|
processByModelPrice(sysModel, chatMessageBo);
|
||||||
} else {
|
} else {
|
||||||
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if("openCmd".equals(chatRequest.getModel())) {
|
String configValue = configService.getConfigValue("zhipu", "key");
|
||||||
|
// 添加联网信息
|
||||||
|
if(StringUtils.isNotEmpty(configValue)){
|
||||||
|
ClientV4 client = new ClientV4.Builder(configValue)
|
||||||
|
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
||||||
|
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
SearchChatMessage jsonNodes = new SearchChatMessage();
|
||||||
|
jsonNodes.setRole(Message.Role.USER.getName());
|
||||||
|
jsonNodes.setContent(chatString);
|
||||||
|
|
||||||
|
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||||
|
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
||||||
|
.model("web-search-pro")
|
||||||
|
.stream(Boolean.TRUE)
|
||||||
|
.messages(Collections.singletonList(jsonNodes))
|
||||||
|
.requestId(requestId)
|
||||||
|
.build();
|
||||||
|
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
|
||||||
|
List<ChoiceDelta> choices = new ArrayList<>();
|
||||||
|
if (webSearchApiResponse.isSuccess()) {
|
||||||
|
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||||
|
|
||||||
|
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
|
||||||
|
|
||||||
|
webSearchApiResponse.getFlowable().map(result -> result)
|
||||||
|
.doOnNext(accumulator -> {
|
||||||
|
{
|
||||||
|
if (isFirst.getAndSet(false)) {
|
||||||
|
log.info("Response: ");
|
||||||
|
}
|
||||||
|
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
|
||||||
|
if (delta != null && delta.getToolCalls() != null) {
|
||||||
|
log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
|
||||||
|
}
|
||||||
|
choices.add(delta);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.doOnComplete(() -> System.out.println("Stream completed."))
|
||||||
|
.doOnError(throwable -> System.err.println("Error: " + throwable))
|
||||||
|
.blockingSubscribe();
|
||||||
|
|
||||||
|
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
|
||||||
|
|
||||||
|
webSearchApiResponse.setFlowable(null);// 打印前置空
|
||||||
|
webSearchApiResponse.setData(chatMessageAccumulator);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Message message = Message.builder().role(Message.Role.ASSISTANT).content(choices.get(1).getToolCalls().toString()).build();
|
||||||
|
messages.add(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ("openCmd".equals(chatRequest.getModel())) {
|
||||||
sseEmitter.send(cmdPlugin(messages));
|
sseEmitter.send(cmdPlugin(messages));
|
||||||
sseEmitter.complete();
|
sseEmitter.complete();
|
||||||
}else if ("sqlPlugin".equals(chatRequest.getModel())){
|
} else if ("sqlPlugin".equals(chatRequest.getModel())) {
|
||||||
sseEmitter.send(sqlPlugin(messages));
|
sseEmitter.send(sqlPlugin(messages));
|
||||||
sseEmitter.complete();
|
sseEmitter.complete();
|
||||||
} else {
|
} else {
|
||||||
@@ -229,7 +281,7 @@ public class SseServiceImpl implements ISseService {
|
|||||||
* @param model 模型信息
|
* @param model 模型信息
|
||||||
* @param chatMessageBo 对话信息
|
* @param chatMessageBo 对话信息
|
||||||
*/
|
*/
|
||||||
private void processByModelPrice(SysModelVo model, ChatMessageBo chatMessageBo) {
|
private void processByModelPrice(SysModel model, ChatMessageBo chatMessageBo) {
|
||||||
double cost = model.getModelPrice();
|
double cost = model.getModelPrice();
|
||||||
chatService.deductUserBalance(getUserId(), cost);
|
chatService.deductUserBalance(getUserId(), cost);
|
||||||
chatMessageBo.setDeductCost(cost);
|
chatMessageBo.setDeductCost(cost);
|
||||||
@@ -316,16 +368,14 @@ public class SseServiceImpl implements ISseService {
|
|||||||
.style(request.getStyle())
|
.style(request.getStyle())
|
||||||
.build();
|
.build();
|
||||||
ImageResponse imageResponse = openAiStreamClient.genImages(image);
|
ImageResponse imageResponse = openAiStreamClient.genImages(image);
|
||||||
SysModelBo sysModelBo = new SysModelBo();
|
SysModel sysModel = sysModelService.selectModelByName(request.getModel());
|
||||||
sysModelBo.setModelName(request.getModel());
|
|
||||||
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
|
||||||
//chatService.deductUserBalance(getUserId(),sysModelList.get(0).getModelPrice());
|
//chatService.deductUserBalance(getUserId(),sysModelList.get(0).getModelPrice());
|
||||||
// 保存消息记录
|
// 保存消息记录
|
||||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||||
chatMessageBo.setUserId(getUserId());
|
chatMessageBo.setUserId(getUserId());
|
||||||
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
|
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
|
||||||
chatMessageBo.setContent(request.getPrompt());
|
chatMessageBo.setContent(request.getPrompt());
|
||||||
chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
|
chatMessageBo.setDeductCost(sysModel.getModelPrice());
|
||||||
chatMessageBo.setTotalTokens(0);
|
chatMessageBo.setTotalTokens(0);
|
||||||
chatMessageService.insertByBo(chatMessageBo);
|
chatMessageService.insertByBo(chatMessageBo);
|
||||||
return imageResponse.getData();
|
return imageResponse.getData();
|
||||||
@@ -342,16 +392,14 @@ public class SseServiceImpl implements ISseService {
|
|||||||
.n(1)
|
.n(1)
|
||||||
.build();
|
.build();
|
||||||
ImageResponse imageResponse = openAiStreamClient.genImages(image);
|
ImageResponse imageResponse = openAiStreamClient.genImages(image);
|
||||||
SysModelBo sysModelBo = new SysModelBo();
|
SysModel dall3 = sysModelService.selectModelByName("dall3");
|
||||||
sysModelBo.setModelName("dall3");
|
|
||||||
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
|
||||||
chatService.deductUserBalance(Long.valueOf(userId), 0.3);
|
chatService.deductUserBalance(Long.valueOf(userId), 0.3);
|
||||||
// 保存消息记录
|
// 保存消息记录
|
||||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||||
chatMessageBo.setUserId(getUserId());
|
chatMessageBo.setUserId(getUserId());
|
||||||
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
|
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
|
||||||
chatMessageBo.setContent(prompt);
|
chatMessageBo.setContent(prompt);
|
||||||
chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
|
chatMessageBo.setDeductCost(dall3.getModelPrice());
|
||||||
chatMessageBo.setTotalTokens(0);
|
chatMessageBo.setTotalTokens(0);
|
||||||
chatMessageService.insertByBo(chatMessageBo);
|
chatMessageService.insertByBo(chatMessageBo);
|
||||||
return imageResponse.getData();
|
return imageResponse.getData();
|
||||||
@@ -527,12 +575,9 @@ public class SseServiceImpl implements ISseService {
|
|||||||
chatMessageBo.setDeductCost(0.01);
|
chatMessageBo.setDeductCost(0.01);
|
||||||
chatMessageBo.setTotalTokens(0);
|
chatMessageBo.setTotalTokens(0);
|
||||||
chatMessageService.insertByBo(chatMessageBo);
|
chatMessageService.insertByBo(chatMessageBo);
|
||||||
|
|
||||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||||
|
|
||||||
List<Message> messageList = new ArrayList<>();
|
List<Message> messageList = new ArrayList<>();
|
||||||
|
Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一位精通各国语言的翻译大师\n" +
|
||||||
Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一名翻译老师\n" +
|
|
||||||
"\n" +
|
"\n" +
|
||||||
"请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
|
"请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
@@ -563,25 +608,21 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SseEmitter ollamaChat(ChatRequest chatRequest) {
|
public SseEmitter ollamaChat(ChatRequest chatRequest) {
|
||||||
|
String[] parts = chatRequest.getModel().split("ollama-");
|
||||||
|
SysModel sysModel = sysModelService.selectModelByName(parts[1]);
|
||||||
final SseEmitter emitter = new SseEmitter();
|
final SseEmitter emitter = new SseEmitter();
|
||||||
String host = "http://localhost:11434/";
|
String host = sysModel.getApiHost();
|
||||||
|
|
||||||
List<Message> msgList = chatRequest.getMessages();
|
List<Message> msgList = chatRequest.getMessages();
|
||||||
Message message = msgList.get(msgList.size() - 1);
|
Message message = msgList.get(msgList.size() - 1);
|
||||||
|
OllamaAPI api = new OllamaAPI(host);
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
api.setRequestTimeoutSeconds(100);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(sysModel.getModelName());
|
||||||
ollamaAPI.setRequestTimeoutSeconds(100);
|
|
||||||
|
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("qwen2.5:7b");
|
|
||||||
|
|
||||||
OllamaChatRequestModel requestModel = builder
|
OllamaChatRequestModel requestModel = builder
|
||||||
.withMessage(OllamaChatMessageRole.USER,
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
message.getContent().toString())
|
message.getContent().toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
// 异步执行 OllAma API 调用
|
||||||
// 异步执行 Ollama API 调用
|
|
||||||
CompletableFuture.runAsync(() -> {
|
CompletableFuture.runAsync(() -> {
|
||||||
try {
|
try {
|
||||||
StringBuilder response = new StringBuilder();
|
StringBuilder response = new StringBuilder();
|
||||||
@@ -595,14 +636,12 @@ public class SseServiceImpl implements ISseService {
|
|||||||
sendErrorEvent(emitter, e.getMessage());
|
sendErrorEvent(emitter, e.getMessage());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ollamaAPI.chat(requestModel, streamHandler);
|
api.chat(requestModel, streamHandler);
|
||||||
emitter.complete();
|
emitter.complete();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
sendErrorEvent(emitter, e.getMessage());
|
sendErrorEvent(emitter, e.getMessage());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
return emitter;
|
return emitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -620,6 +659,4 @@ public class SseServiceImpl implements ISseService {
|
|||||||
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
||||||
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,4 +107,11 @@ public class SysModelServiceImpl implements ISysModelService {
|
|||||||
}
|
}
|
||||||
return baseMapper.deleteBatchIds(ids) > 0;
|
return baseMapper.deleteBatchIds(ids) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SysModel selectModelByName(String modelName) {
|
||||||
|
return baseMapper.selectOne(
|
||||||
|
new LambdaQueryWrapper<SysModel>().eq(SysModel::getModelName, modelName)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user