feat: 增加联网查询功能

This commit is contained in:
ageer
2025-03-12 00:17:47 +08:00
parent 6a1b544545
commit d8fda15593
10 changed files with 119 additions and 1067 deletions

View File

@@ -1,4 +1,4 @@
package org.ruoyi.common.chat.demo.zhipu;
package org.ruoyi.common.chat.demo;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
@@ -17,34 +17,11 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.zhipu.oapi.core.response.HttpxBinaryResponseContent;
import com.zhipu.oapi.service.v4.batchs.BatchCreateParams;
import com.zhipu.oapi.service.v4.batchs.BatchResponse;
import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse;
import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse;
import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest;
import com.zhipu.oapi.service.v4.file.*;
import com.zhipu.oapi.service.v4.fine_turning.*;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import com.zhipu.oapi.service.v4.model.*;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
public class WebSearchToolsTest {

View File

@@ -1,122 +0,0 @@
package org.ruoyi.common.chat.demo.zhipu;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.Constants;
import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory;
import com.zhipu.oapi.service.v4.model.*;
import io.reactivex.Flowable;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class AllToolsTest {
private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class);
private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0";
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper();
// 请自定义自己的业务id
private static final String requestIdTemplate = "mycompany-%d";
@Test
public void test1() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
// 函数调用参数构建部分
List<ChatTool> chatToolList = new ArrayList<>();
ChatTool chatTool = new ChatTool();
chatTool.setType("code_interpreter");
ObjectNode objectNode = mapper.createObjectNode();
objectNode.put("code", "北京天气");
// chatTool.set(chatFunction);
chatToolList.add(chatTool);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model("glm-4-alltools")
.stream(Boolean.TRUE)
.invokeMethod(Constants.invokeMethod)
.messages(messages)
.tools(chatToolList)
.toolChoice("auto")
.requestId(requestId)
.build();
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
if (sseModelApiResp.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<Choice> choices = new ArrayList<>();
AtomicReference<ChatMessageAccumulator> lastAccumulator = new AtomicReference<>();
mapStreamToAccumulator(sseModelApiResp.getFlowable())
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
logger.info("tool_calls: {}", jsonString);
}
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
logger.info(accumulator.getDelta().getContent());
}
choices.add(accumulator.getChoice());
lastAccumulator.set(accumulator);
}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
ModelData data = new ModelData();
data.setChoices(choices);
if (chatMessageAccumulator != null) {
data.setUsage(chatMessageAccumulator.getUsage());
data.setId(chatMessageAccumulator.getId());
data.setCreated(chatMessageAccumulator.getCreated());
}
data.setRequestId(chatCompletionRequest.getRequestId());
sseModelApiResp.setFlowable(null);// 打印前置空
sseModelApiResp.setData(data);
}
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
client.getConfig().getHttpClient().connectionPool().evictAll();
// List all active threads
for (Thread t : Thread.getAllStackTraces().keySet()) {
logger.info("Thread: " + t.getName() + " State: " + t.getState());
}
}
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
return flowable.map(chunk -> {
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
});
}
}

View File

@@ -1,646 +0,0 @@
package org.ruoyi.common.chat.demo.zhipu;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.Constants;
import com.zhipu.oapi.core.response.HttpxBinaryResponseContent;
import com.zhipu.oapi.service.v4.batchs.BatchCreateParams;
import com.zhipu.oapi.service.v4.batchs.BatchResponse;
import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse;
import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse;
import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest;
import com.zhipu.oapi.service.v4.file.*;
import com.zhipu.oapi.service.v4.fine_turning.*;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import com.zhipu.oapi.service.v4.model.*;
import io.reactivex.Flowable;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
public class V4Test {
private final static Logger logger = LoggerFactory.getLogger(V4Test.class);
private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0";
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
.enableTokenCache()
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
// 请自定义自己的业务id
private static final String requestIdTemplate = "mycompany-%d";
private static final ObjectMapper mapper = new ObjectMapper();
public static ObjectMapper defaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
return mapper;
}
@Test
public void test() {
}
/**
* sse-V4function调用
*/
@Test
public void testFunctionSSE() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
// 函数调用参数构建部分
List<ChatTool> chatToolList = new ArrayList<>();
ChatTool chatTool = new ChatTool();
chatTool.setType(ChatToolType.FUNCTION.value());
ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters();
chatFunctionParameters.setType("object");
Map<String, Object> properties = new HashMap<>();
properties.put("location", new HashMap<String, Object>() {{
put("type", "string");
put("description", "城市,如:北京");
}});
properties.put("unit", new HashMap<String, Object>() {{
put("type", "string");
put("enum", new ArrayList<String>() {{
add("celsius");
add("fahrenheit");
}});
}});
chatFunctionParameters.setProperties(properties);
ChatFunction chatFunction = ChatFunction.builder()
.name("get_weather")
.description("Get the current weather of a location")
.parameters(chatFunctionParameters)
.build();
chatTool.setFunction(chatFunction);
chatToolList.add(chatTool);
HashMap<String, Object> extraJson = new HashMap<>();
extraJson.put("temperature", 0.5);
extraJson.put("max_tokens", 50);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.TRUE)
.messages(messages)
.requestId(requestId)
.tools(chatToolList)
.toolChoice("auto")
.extraJson(extraJson)
.build();
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
if (sseModelApiResp.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<Choice> choices = new ArrayList<>();
ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable())
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
logger.info("tool_calls: {}", jsonString);
}
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
logger.info(accumulator.getDelta().getContent());
}
choices.add(accumulator.getChoice());
}
})
.doOnComplete(System.out::println)
.lastElement()
.blockingGet();
ModelData data = new ModelData();
data.setChoices(choices);
data.setUsage(chatMessageAccumulator.getUsage());
data.setId(chatMessageAccumulator.getId());
data.setCreated(chatMessageAccumulator.getCreated());
data.setRequestId(chatCompletionRequest.getRequestId());
sseModelApiResp.setFlowable(null);// 打印前置空
sseModelApiResp.setData(data);
}
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
}
/**
* sse-V4非function调用
*/
@Test
public void testNonFunctionSSE() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大");
messages.add(chatMessage);
HashMap<String, Object> extraJson = new HashMap<>();
extraJson.put("temperature", 0.5);
extraJson.put("max_tokens", 3);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.TRUE)
.messages(messages)
.requestId(requestId)
.extraJson(extraJson)
.build();
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
// stream 处理方法
if (sseModelApiResp.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<Choice> choices = new ArrayList<>();
ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable())
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
logger.info("tool_calls: {}", jsonString);
}
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
logger.info("accumulator.getDelta().getContent(): {}", accumulator.getDelta().getContent());
}
choices.add(accumulator.getChoice());
}
})
.doOnComplete(System.out::println)
.lastElement()
.blockingGet();
ModelData data = new ModelData();
data.setChoices(choices);
data.setUsage(chatMessageAccumulator.getUsage());
data.setId(chatMessageAccumulator.getId());
data.setCreated(chatMessageAccumulator.getCreated());
data.setRequestId(chatCompletionRequest.getRequestId());
sseModelApiResp.setFlowable(null);// 打印前置空
sseModelApiResp.setData(data);
}
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
}
/**
* V4-同步function调用
*/
@Test
public void testFunctionInvoke() {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "你可以做什么");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
// 函数调用参数构建部分
List<ChatTool> chatToolList = new ArrayList<>();
ChatTool chatTool = new ChatTool();
chatTool.setType(ChatToolType.FUNCTION.value());
ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters();
chatFunctionParameters.setType("object");
Map<String, Object> properties = new HashMap<>();
properties.put("location", new HashMap<String, Object>() {{
put("type", "string");
put("description", "城市,如:北京");
}});
properties.put("unit", new HashMap<String, Object>() {{
put("type", "string");
put("enum", new ArrayList<String>() {{
add("celsius");
add("fahrenheit");
}});
}});
chatFunctionParameters.setProperties(properties);
ChatFunction chatFunction = ChatFunction.builder()
.name("get_weather")
.description("Get the current weather of a location")
.parameters(chatFunctionParameters)
.build();
chatTool.setFunction(chatFunction);
ChatTool chatTool1 = new ChatTool();
chatTool1.setType(ChatToolType.WEB_SEARCH.value());
WebSearch webSearch = new WebSearch();
webSearch.setSearch_query("清华的升学率");
webSearch.setSearch_result(true);
webSearch.setEnable(false);
chatTool1.setWeb_search(webSearch);
chatToolList.add(chatTool);
chatToolList.add(chatTool1);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.FALSE)
.invokeMethod(Constants.invokeMethod)
.messages(messages)
.requestId(requestId)
.tools(chatToolList)
.toolChoice("auto")
.build();
ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest);
try {
logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp));
} catch (JsonProcessingException e) {
logger.error("model output error", e);
}
}
/**
* V4-同步非function调用
*/
@Test
public void testNonFunctionInvoke() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
HashMap<String, Object> extraJson = new HashMap<>();
extraJson.put("temperature", 0.5);
extraJson.put("max_tokens", 3);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.FALSE)
.invokeMethod(Constants.invokeMethod)
.messages(messages)
.requestId(requestId)
.extraJson(extraJson)
.build();
ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest);
logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp));
}
/**
* V4-同步非function调用
*/
@Test
public void testCharGlmInvoke() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
HashMap<String, Object> extraJson = new HashMap<>();
extraJson.put("temperature", 0.5);
ChatMeta meta = new ChatMeta();
meta.setUser_info("我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。");
meta.setBot_info("苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。");
meta.setBot_name("苏梦远");
meta.setUser_name("陆星辰");
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelCharGLM3)
.stream(Boolean.FALSE)
.invokeMethod(Constants.invokeMethod)
.messages(messages)
.requestId(requestId)
.meta(meta)
.extraJson(extraJson)
.build();
ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest);
logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp));
}
/**
* V4异步调用
*/
@Test
public void testAsyncInvoke() throws JsonProcessingException {
String taskId = getAsyncTaskId();
testQueryResult(taskId);
}
//
/**
* 文生图
*/
@Test
public void testCreateImage() throws JsonProcessingException {
CreateImageRequest createImageRequest = new CreateImageRequest();
createImageRequest.setModel(Constants.ModelCogView);
createImageRequest.setPrompt("Futuristic cloud data center, showcasing advanced technologgy and a high-tech atmosp\n" +
"here. The image should depict a spacious, well-lit interior with rows of server racks, glo\n" +
"wing lights, and digital displays. Include abstract representattions of data streams and\n" +
"onnectivity, symbolizing the essence of cloud computing. Thee style should be modern a\n" +
"nd sleek, with a focus on creating a sense of innovaticon and cutting-edge technology\n" +
"The overall ambiance should convey the power and effciency of cloud services in a visu\n" +
"ally engaging way.");
createImageRequest.setRequestId("test11111111111111");
ImageApiResponse imageApiResponse = client.createImage(createImageRequest);
logger.info("imageApiResponse: {}", mapper.writeValueAsString(imageApiResponse));
}
//
// /**
// * 图生文
// */
// @Test
// public void testImageToWord() throws JsonProcessingException {
// List<ChatMessage> messages = new ArrayList<>();
// List<Map<String, Object>> contentList = new ArrayList<>();
// Map<String, Object> textMap = new HashMap<>();
// textMap.put("type", "text");
// textMap.put("text", "图里有什么");
// Map<String, Object> typeMap = new HashMap<>();
// typeMap.put("type", "image_url");
// Map<String, Object> urlMap = new HashMap<>();
// urlMap.put("url", "https://sfile.chatglm.cn/testpath/275ae5b6-5390-51ca-a81a-60332d1a7cac_0.png");
// typeMap.put("image_url", urlMap);
// contentList.add(textMap);
// contentList.add(typeMap);
// ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), contentList);
// messages.add(chatMessage);
// String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
//
//
// ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
// .model(Constants.ModelChatGLM4V)
// .stream(Boolean.FALSE)
// .invokeMethod(Constants.invokeMethod)
// .messages(messages)
// .requestId(requestId)
// .build();
// ModelApiResponse modelApiResponse = client.invokeModelApi(chatCompletionRequest);
// logger.info("model output: {}", mapper.writeValueAsString(modelApiResponse));
// }
//
/**
* 向量模型V4
*/
@Test
public void testEmbeddings() throws JsonProcessingException {
EmbeddingRequest embeddingRequest = new EmbeddingRequest();
embeddingRequest.setInput("hello world");
embeddingRequest.setModel(Constants.ModelEmbedding2);
EmbeddingApiResponse apiResponse = client.invokeEmbeddingsApi(embeddingRequest);
logger.info("model output: {}", mapper.writeValueAsString(apiResponse));
}
/**
* V4微调上传数据集
*/
@Test
public void testUploadFile() throws JsonProcessingException {
String filePath = "demo.jsonl";
String path = ClassLoader.getSystemResource(filePath).getPath();
String purpose = "fine-tune";
UploadFileRequest request = UploadFileRequest.builder()
.purpose(purpose)
.filePath(path)
.build();
FileApiResponse fileApiResponse = client.invokeUploadFileApi(request);
logger.info("model output: {}", mapper.writeValueAsString(fileApiResponse));
}
/**
* 微调V4-查询上传文件列表
*/
@Test
public void testQueryUploadFileList() throws JsonProcessingException {
QueryFilesRequest queryFilesRequest = new QueryFilesRequest();
QueryFileApiResponse queryFileApiResponse = client.queryFilesApi(queryFilesRequest);
logger.info("model output: {}", mapper.writeValueAsString(queryFileApiResponse));
}
@Test
public void testFileContent() throws IOException {
try {
HttpxBinaryResponseContent httpxBinaryResponseContent = client.fileContent("20240514_ea19d21b-d256-4586-b0df-e80a45e3c286");
String filePath = "demo_output.jsonl";
String resourcePath = V4Test.class.getClassLoader().getResource("").getPath();
httpxBinaryResponseContent.streamToFile(resourcePath + "1" + filePath, 1000);
} catch (IOException e) {
logger.error("file content error", e);
}
}
//// @Test
//// public void deletedFile() throws IOException {
//// FileDelResponse fileDelResponse = client.deletedFile("20240514_ea19d21b-d256-4586-b0df-e80a45e3c286");
////
//// logger.info("model output: {}", mapper.writeValueAsString(fileDelResponse));
////
//// }
//
//
/**
* 微调V4-创建微调任务
*/
@Test
public void testCreateFineTuningJob() throws JsonProcessingException {
FineTuningJobRequest request = new FineTuningJobRequest();
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
request.setRequestId(requestId);
request.setModel("chatglm3-6b");
request.setTraining_file("file-20240118082608327-kp8qr");
CreateFineTuningJobApiResponse createFineTuningJobApiResponse = client.createFineTuningJob(request);
logger.info("model output: {}", mapper.writeValueAsString(createFineTuningJobApiResponse));
}
/**
* 微调V4-查询微调任务
*/
@Test
public void testRetrieveFineTuningJobs() throws JsonProcessingException {
QueryFineTuningJobRequest queryFineTuningJobRequest = new QueryFineTuningJobRequest();
queryFineTuningJobRequest.setJobId("ftjob-20240429172916475-fb7r9");
// queryFineTuningJobRequest.setLimit(1);
// queryFineTuningJobRequest.setAfter(1);
QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.retrieveFineTuningJobs(queryFineTuningJobRequest);
logger.info("model output: {}", mapper.writeValueAsString(queryFineTuningJobApiResponse));
}
/**
* 微调V4-查询微调任务
*/
@Test
public void testFueryFineTuningJobsEvents() throws JsonProcessingException {
QueryFineTuningJobRequest queryFineTuningJobRequest = new QueryFineTuningJobRequest();
queryFineTuningJobRequest.setJobId("ftjob-20240429172916475-fb7r9");
QueryFineTuningEventApiResponse queryFineTuningEventApiResponse = client.queryFineTuningJobsEvents(queryFineTuningJobRequest);
logger.info("model output: {}", mapper.writeValueAsString(queryFineTuningEventApiResponse));
}
/**
* testQueryPersonalFineTuningJobs V4-查询个人微调任务
*/
@Test
public void testQueryPersonalFineTuningJobs() throws JsonProcessingException {
QueryPersonalFineTuningJobRequest queryPersonalFineTuningJobRequest = new QueryPersonalFineTuningJobRequest();
queryPersonalFineTuningJobRequest.setLimit(1);
QueryPersonalFineTuningJobApiResponse queryPersonalFineTuningJobApiResponse = client.queryPersonalFineTuningJobs(queryPersonalFineTuningJobRequest);
logger.info("model output: {}", mapper.writeValueAsString(queryPersonalFineTuningJobApiResponse));
}
@Test
public void testBatchesCreate() {
BatchCreateParams batchCreateParams = new BatchCreateParams(
"24h",
"/v4/chat/completions",
"20240514_ea19d21b-d256-4586-b0df-e80a45e3c286",
new HashMap<String, String>() {{
put("key1", "value1");
put("key2", "value2");
}}
);
BatchResponse batchResponse = client.batchesCreate(batchCreateParams);
logger.info("output: {}", batchResponse);
// output: BatchResponse(code=200, msg=调用成功, success=true, data=Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847751822, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=null, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=null, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null))
}
@Test
public void testDeleteFineTuningJob() {
FineTuningJobIdRequest request = FineTuningJobIdRequest.builder().jobId("test").build();
QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.deleteFineTuningJob(request);
logger.info("output: {}", queryFineTuningJobApiResponse);
}
@Test
public void testCancelFineTuningJob() {
FineTuningJobIdRequest request = FineTuningJobIdRequest.builder().jobId("test").build();
QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.cancelFineTuningJob(request);
logger.info("output: {}", queryFineTuningJobApiResponse);
}
@Test
public void testBatchesRetrieve() {
BatchResponse batchResponse = client.batchesRetrieve("batch_1791021399316246528");
logger.info("output: {}", batchResponse);
}
@Test
public void testDeleteFineTuningModel() {
FineTuningJobModelRequest request = FineTuningJobModelRequest.builder().fineTunedModel("test").build();
FineTunedModelsStatusResponse fineTunedModelsStatusResponse = client.deleteFineTuningModel(request);
logger.info("output: {}", fineTunedModelsStatusResponse);
// output: BatchResponse(code=200, msg=调用成功, success=true, data=Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847752000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null))
}
@Test
public void testBatchesList() {
QueryBatchRequest queryBatchRequest = new QueryBatchRequest();
queryBatchRequest.setLimit(10);
QueryBatchResponse queryBatchResponse = client.batchesList(queryBatchRequest);
logger.info("output: {}", queryBatchResponse);
// output: QueryBatchResponse(code=200, msg=调用成功, success=true, data=BatchPage(object=list, data=[Batch(id=batch_1790291013237211136, completionWindow=24h, createdAt=1715673614000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=completed, cancelledAt=null, cancellingAt=1715673699000, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={description=job test}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null), Batch(id=batch_1790292763050508288, completionWindow=24h, createdAt=1715674031000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=completed, cancelledAt=null, cancellingAt=null, completedAt=1715766416000, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=1715754569000, inProgressAt=null, metadata={description=job test}, outputFileId=1715766415_e5a77222855a406ca8a082de28549c99, requestCounts=BatchRequestCounts(completed=2, failed=0, total=2), error=null), Batch(id=batch_1791021114887909376, completionWindow=24h, createdAt=1715847684000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null), Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847752000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)], error=null))
}
@Test
public void testBatchesCancel() throws JsonProcessingException {
getAsyncTaskId();
}
private static String getAsyncTaskId() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
// 函数调用参数构建部分
List<ChatTool> chatToolList = new ArrayList<>();
ChatTool chatTool = new ChatTool();
chatTool.setType(ChatToolType.FUNCTION.value());
ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters();
chatFunctionParameters.setType("object");
Map<String, Object> properties = new HashMap<>();
properties.put("location", new HashMap<String, Object>() {{
put("type", "string");
put("description", "城市,如:北京");
}});
properties.put("unit", new HashMap<String, Object>() {{
put("type", "string");
put("enum", new ArrayList<String>() {{
add("celsius");
add("fahrenheit");
}});
}});
chatFunctionParameters.setProperties(properties);
ChatFunction chatFunction = ChatFunction.builder()
.name("get_weather")
.description("Get the current weather of a location")
.parameters(chatFunctionParameters)
.build();
chatTool.setFunction(chatFunction);
chatToolList.add(chatTool);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.FALSE)
.invokeMethod(Constants.invokeMethodAsync)
.messages(messages)
.requestId(requestId)
.tools(chatToolList)
.toolChoice("auto")
.build();
ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest);
logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp));
return invokeModelApiResp.getData().getId();
}
private static void testQueryResult(String taskId) throws JsonProcessingException {
QueryModelResultRequest request = new QueryModelResultRequest();
request.setTaskId(taskId);
QueryModelResultResponse queryResultResp = client.queryModelResult(request);
logger.info("model output {}", mapper.writeValueAsString(queryResultResp));
}
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
return flowable.map(chunk -> {
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
});
}
}

View File

@@ -17,7 +17,6 @@ import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.knowledge.service.EmbeddingService;
import org.ruoyi.system.domain.bo.ChatMessageBo;
import org.ruoyi.system.domain.request.translation.TranslationRequest;
import org.ruoyi.system.domain.vo.ChatMessageVo;
@@ -48,7 +47,6 @@ public class ChatController {
private final IChatMessageService chatMessageService;
private final EmbeddingService embeddingService;
/**
* 聊天接口
*/

View File

@@ -2,6 +2,7 @@ package org.ruoyi.knowledge.chain.split;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
import org.springframework.context.annotation.Lazy;
@@ -29,7 +30,7 @@ public class CharacterTextSplitter implements TextSplitter {
int textBlockSize = knowledgeInfoVo.getTextBlockSize();
int overlapChar = knowledgeInfoVo.getOverlapChar();
List<String> chunkList = new ArrayList<>();
if (content.contains(knowledgeSeparator)) {
if (content.contains(knowledgeSeparator) && StringUtils.isNotBlank(knowledgeSeparator)) {
// 按自定义分隔符切分
String[] chunks = content.split(knowledgeSeparator);
chunkList.addAll(Arrays.asList(chunks));

View File

@@ -1,5 +1,6 @@
package org.ruoyi.knowledge.chain.vectorstore;
import cn.hutool.core.util.StrUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
@@ -23,8 +24,13 @@ public class VectorStoreFactory {
}
public VectorStore getVectorStore(String kid){
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
String vectorModel = knowledgeInfoVo.getVector();
String vectorModel = "weaviate";
if (StrUtil.isNotEmpty(kid)) {
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
if (knowledgeInfoVo != null && StrUtil.isNotEmpty(knowledgeInfoVo.getVector())) {
vectorModel = knowledgeInfoVo.getVector();
}
}
if ("weaviate".equals(vectorModel)){
return weaviateVectorStore;
}else if ("milvus".equals(vectorModel)){

View File

@@ -1,212 +0,0 @@
package org.ruoyi.system.plugin;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson.JSONObject;
import lombok.Builder;
import lombok.Data;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.junit.Before;
import org.junit.Test;
import org.ruoyi.common.chat.demo.ConsoleEventSourceListenerV3;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.chat.Parameters;
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
import org.ruoyi.common.chat.entity.chat.tool.Tools;
import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction;
import org.ruoyi.common.chat.openai.OpenAiClient;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
public class WebSearchPlugin {
private OpenAiClient openAiClient;
private OpenAiStreamClient openAiStreamClient;
@Before
public void before() {
//可以为null
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
//千万别再生产或者测试环境打开BODY级别日志
//生产或者测试环境建议设置为这三种级别NONE,BASIC,HEADERS,
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
// .proxy(proxy)
.addInterceptor(httpLoggingInterceptor)
.addInterceptor(new OpenAiResponseInterceptor())
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(30, TimeUnit.SECONDS)
.readTimeout(30, TimeUnit.SECONDS)
.build();
openAiClient = OpenAiClient.builder()
//支持多key传入请求时候随机选择
.apiKey(Arrays.asList("xx"))
//自定义key的获取策略默认KeyRandomStrategy
//.keyStrategy(new KeyRandomStrategy())
.keyStrategy(new KeyRandomStrategy())
.okHttpClient(okHttpClient)
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复openai ,获取免费的测试代理地址)
.apiHost("https://open.bigmodel.cn/")
.build();
openAiStreamClient = OpenAiStreamClient.builder()
//支持多key传入请求时候随机选择
.apiKey(Arrays.asList("xx"))
//自定义key的获取策略默认KeyRandomStrategy
.keyStrategy(new KeyRandomStrategy())
.authInterceptor(new DynamicKeyOpenAiAuthInterceptor())
.okHttpClient(okHttpClient)
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复openai ,获取免费的测试代理地址)
.apiHost("https://open.bigmodel.cn/")
.build();
}
@Test
public void test() {
Message message = Message.builder().role(Message.Role.USER).content("今天武汉天气怎么样").build();
ChatCompletion chatCompletion = ChatCompletion
.builder()
.messages(Collections.singletonList(message))
// .tools(Collections.singletonList(tools))
.model("web-search-pro")
.build();
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
System.out.printf("chatCompletionResponse=%s\n", JSONUtil.toJsonStr(chatCompletionResponse));
}
@Test
public void streamToolsChat() {
CountDownLatch countDownLatch = new CountDownLatch(1);
ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch);
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语并解释下词语对应物品的用途").build();
//属性一
JSONObject wordLength = new JSONObject();
wordLength.put("type", "number");
wordLength.put("description", "词语的长度");
//属性二
JSONObject language = new JSONObject();
language.put("type", "string");
language.put("enum", Arrays.asList("zh", "en"));
language.put("description", "语言类型例如zh代表中文、en代表英语");
//参数
JSONObject properties = new JSONObject();
properties.put("wordLength", wordLength);
properties.put("language", language);
Parameters parameters = Parameters.builder()
.type("object")
.properties(properties)
.required(Collections.singletonList("wordLength")).build();
Tools tools = Tools.builder()
.type(Tools.Type.FUNCTION.getName())
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
.build();
ChatCompletion chatCompletion = ChatCompletion
.builder()
.messages(Collections.singletonList(message))
.tools(Collections.singletonList(tools))
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener);
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls();
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
String oneWord = getOneWord(wordParam);
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
//构造tool call
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
String content
= "{ " +
"\"wordLength\": \"3\", " +
"\"language\": \"zh\", " +
"\"word\": \"" + oneWord + "\"," +
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
"}";
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
List<Message> messageList = Arrays.asList(message, message2, message3);
ChatCompletion chatCompletionV2 = ChatCompletion
.builder()
.messages(messageList)
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
CountDownLatch countDownLatch1 = new CountDownLatch(1);
openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch));
try {
countDownLatch1.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
try {
countDownLatch1.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
@Data
@Builder
static class WordParam {
private int wordLength;
@Builder.Default
private String language = "zh";
}
/**
* 获取一个词语(根据语言和字符长度查询)
* @param wordParam
* @return
*/
public String getOneWord(WordParam wordParam) {
List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
if (wordParam.getLanguage().equals("zh")) {
for (String e : zh) {
if (e.length() == wordParam.getWordLength()) {
return e;
}
}
}
if (wordParam.getLanguage().equals("en")) {
for (String e : en) {
if (e.length() == wordParam.getWordLength()) {
return e;
}
}
}
return "西瓜";
}
}

View File

@@ -2,6 +2,7 @@ package org.ruoyi.system.service;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.system.domain.SysModel;
import org.ruoyi.system.domain.bo.SysModelBo;
import org.ruoyi.system.domain.vo.SysModelVo;
@@ -45,4 +46,9 @@ public interface ISysModelService {
* 校验并批量删除系统模型信息
*/
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
/**
* 根据模型名称查询模型
*/
SysModel selectModelByName(String modelName);
}

View File

@@ -1,8 +1,10 @@
package org.ruoyi.system.service.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.collection.CollectionUtil;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.service.v4.tools.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
@@ -17,10 +19,7 @@ import org.ruoyi.common.chat.config.LocalCache;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.domain.request.Dall3Request;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Content;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.chat.*;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.images.Image;
import org.ruoyi.common.chat.entity.images.ImageResponse;
@@ -33,17 +32,15 @@ import org.ruoyi.common.chat.plugin.CmdPlugin;
import org.ruoyi.common.chat.plugin.CmdReq;
import org.ruoyi.common.chat.plugin.SqlPlugin;
import org.ruoyi.common.chat.plugin.SqlReq;
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
import org.ruoyi.common.chat.utils.TikTokensUtil;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.system.domain.SysModel;
import org.ruoyi.system.domain.bo.ChatMessageBo;
import org.ruoyi.system.domain.bo.SysModelBo;
import org.ruoyi.system.domain.request.translation.TranslationRequest;
import org.ruoyi.system.domain.vo.SysModelVo;
import org.ruoyi.system.listener.SSEEventSourceListener;
import org.ruoyi.system.service.*;
import org.springframework.core.io.InputStreamResource;
@@ -65,6 +62,9 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@Service
@@ -76,18 +76,21 @@ public class SseServiceImpl implements ISseService {
private final ChatConfig chatConfig;
private final IChatCostService chatService;
private final IChatMessageService chatMessageService;
private final ISysModelService sysModelService;
private final ISysUserService userService;
private final ConfigService configService;
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
private static final String requestIdTemplate = "mycompany-%d";
private static final ObjectMapper mapper = new ObjectMapper();
@Override
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
openAiStreamClient = chatConfig.getOpenAiStreamClient();
@@ -96,11 +99,10 @@ public class SseServiceImpl implements ISseService {
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
try {
String chatString = null;
if (StpUtil.isLogin()) {
LocalCache.CACHE.put("userId", getUserId());
Object content = messages.get(messages.size() - 1).getContent();
String chatString = "";
if (content instanceof List<?> listContent) {
if (!listContent.isEmpty() && listContent.get(0) instanceof Content) {
chatString = ((Content) listContent.get(0)).getText();
@@ -123,39 +125,89 @@ public class SseServiceImpl implements ISseService {
throw new BaseException("文本不合规,请修改!");
}
}
//根据模型名称查询模型信息
SysModelBo sysModelBo = new SysModelBo();
String model = chatRequest.getModel();
// 如果是gpts系列模型
if (chatRequest.getModel().startsWith("gpt-4-gizmo")) {
sysModelBo.setModelName("gpt-4-gizmo");
} else {
sysModelBo.setModelName(chatRequest.getModel());
model = "gpt-4-gizmo";
}
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
if (CollectionUtil.isEmpty(sysModelList)) {
SysModel sysModel = sysModelService.selectModelByName(model);
if (sysModel != null) {
// 如果模型不存在默认使用token扣费方式
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
} else {
openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModelList.get(0).getApiHost(), sysModelList.get(0).getApiKey());
openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModel.getApiHost(), sysModel.getApiKey());
// 模型设置默认提示词
SysModelVo firstModel = sysModelList.get(0);
if (StringUtils.isNotEmpty(firstModel.getSystemPrompt())) {
Message sysMessage = Message.builder().content(firstModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
if (StringUtils.isNotEmpty(sysModel.getSystemPrompt())) {
Message sysMessage = Message.builder().content(sysModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
messages.add(sysMessage);
}
// 计费类型: 1 token扣费 2 次数扣费
if ("2".equals(firstModel.getModelType())) {
processByModelPrice(firstModel, chatMessageBo);
if ("2".equals(sysModel.getModelType())) {
processByModelPrice(sysModel, chatMessageBo);
} else {
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
}
}
}
if("openCmd".equals(chatRequest.getModel())) {
String configValue = configService.getConfigValue("zhipu", "key");
// 添加联网信息
if(StringUtils.isNotEmpty(configValue)){
ClientV4 client = new ClientV4.Builder(configValue)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
SearchChatMessage jsonNodes = new SearchChatMessage();
jsonNodes.setRole(Message.Role.USER.getName());
jsonNodes.setContent(chatString);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
.model("web-search-pro")
.stream(Boolean.TRUE)
.messages(Collections.singletonList(jsonNodes))
.requestId(requestId)
.build();
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
List<ChoiceDelta> choices = new ArrayList<>();
if (webSearchApiResponse.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
webSearchApiResponse.getFlowable().map(result -> result)
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
log.info("Response: ");
}
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
if (delta != null && delta.getToolCalls() != null) {
log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
}
choices.add(delta);
}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable))
.blockingSubscribe();
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
webSearchApiResponse.setFlowable(null);// 打印前置空
webSearchApiResponse.setData(chatMessageAccumulator);
}
Message message = Message.builder().role(Message.Role.ASSISTANT).content(choices.get(1).getToolCalls().toString()).build();
messages.add(message);
}
if ("openCmd".equals(chatRequest.getModel())) {
sseEmitter.send(cmdPlugin(messages));
sseEmitter.complete();
}else if ("sqlPlugin".equals(chatRequest.getModel())){
} else if ("sqlPlugin".equals(chatRequest.getModel())) {
sseEmitter.send(sqlPlugin(messages));
sseEmitter.complete();
} else {
@@ -229,7 +281,7 @@ public class SseServiceImpl implements ISseService {
* @param model 模型信息
* @param chatMessageBo 对话信息
*/
private void processByModelPrice(SysModelVo model, ChatMessageBo chatMessageBo) {
private void processByModelPrice(SysModel model, ChatMessageBo chatMessageBo) {
double cost = model.getModelPrice();
chatService.deductUserBalance(getUserId(), cost);
chatMessageBo.setDeductCost(cost);
@@ -316,16 +368,14 @@ public class SseServiceImpl implements ISseService {
.style(request.getStyle())
.build();
ImageResponse imageResponse = openAiStreamClient.genImages(image);
SysModelBo sysModelBo = new SysModelBo();
sysModelBo.setModelName(request.getModel());
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
SysModel sysModel = sysModelService.selectModelByName(request.getModel());
//chatService.deductUserBalance(getUserId(),sysModelList.get(0).getModelPrice());
// 保存消息记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
chatMessageBo.setContent(request.getPrompt());
chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
chatMessageBo.setDeductCost(sysModel.getModelPrice());
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
return imageResponse.getData();
@@ -342,16 +392,14 @@ public class SseServiceImpl implements ISseService {
.n(1)
.build();
ImageResponse imageResponse = openAiStreamClient.genImages(image);
SysModelBo sysModelBo = new SysModelBo();
sysModelBo.setModelName("dall3");
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
SysModel dall3 = sysModelService.selectModelByName("dall3");
chatService.deductUserBalance(Long.valueOf(userId), 0.3);
// 保存消息记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
chatMessageBo.setContent(prompt);
chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
chatMessageBo.setDeductCost(dall3.getModelPrice());
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
return imageResponse.getData();
@@ -527,12 +575,9 @@ public class SseServiceImpl implements ISseService {
chatMessageBo.setDeductCost(0.01);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
openAiStreamClient = chatConfig.getOpenAiStreamClient();
List<Message> messageList = new ArrayList<>();
Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一名翻译老师\n" +
Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一位精通各国语言的翻译大师\n" +
"\n" +
"请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
"\n" +
@@ -563,25 +608,21 @@ public class SseServiceImpl implements ISseService {
@Override
public SseEmitter ollamaChat(ChatRequest chatRequest) {
String[] parts = chatRequest.getModel().split("ollama-");
SysModel sysModel = sysModelService.selectModelByName(parts[1]);
final SseEmitter emitter = new SseEmitter();
String host = "http://localhost:11434/";
String host = sysModel.getApiHost();
List<Message> msgList = chatRequest.getMessages();
Message message = msgList.get(msgList.size() - 1);
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(100);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("qwen2.5:7b");
OllamaAPI api = new OllamaAPI(host);
api.setRequestTimeoutSeconds(100);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(sysModel.getModelName());
OllamaChatRequestModel requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
message.getContent().toString())
.build();
// 异步执行 Ollama API 调用
// 异步执行 OllAma API 调用
CompletableFuture.runAsync(() -> {
try {
StringBuilder response = new StringBuilder();
@@ -595,14 +636,12 @@ public class SseServiceImpl implements ISseService {
sendErrorEvent(emitter, e.getMessage());
}
};
ollamaAPI.chat(requestModel, streamHandler);
api.chat(requestModel, streamHandler);
emitter.complete();
} catch (Exception e) {
sendErrorEvent(emitter, e.getMessage());
}
});
return emitter;
}
@@ -620,6 +659,4 @@ public class SseServiceImpl implements ISseService {
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
}
}

View File

@@ -107,4 +107,11 @@ public class SysModelServiceImpl implements ISysModelService {
}
return baseMapper.deleteBatchIds(ids) > 0;
}
@Override
public SysModel selectModelByName(String modelName) {
return baseMapper.selectOne(
new LambdaQueryWrapper<SysModel>().eq(SysModel::getModelName, modelName)
);
}
}