mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 21:33:40 +00:00
feat: 增加联网查询功能
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.common.chat.demo.zhipu;
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
@@ -17,34 +17,11 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
|
||||
import com.zhipu.oapi.core.response.HttpxBinaryResponseContent;
|
||||
import com.zhipu.oapi.service.v4.batchs.BatchCreateParams;
|
||||
import com.zhipu.oapi.service.v4.batchs.BatchResponse;
|
||||
import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse;
|
||||
import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse;
|
||||
import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest;
|
||||
import com.zhipu.oapi.service.v4.file.*;
|
||||
import com.zhipu.oapi.service.v4.fine_turning.*;
|
||||
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import com.zhipu.oapi.service.v4.model.*;
|
||||
import io.reactivex.Flowable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
|
||||
public class WebSearchToolsTest {
|
||||
@@ -1,122 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo.zhipu;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.zhipu.oapi.ClientV4;
|
||||
import com.zhipu.oapi.Constants;
|
||||
import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory;
|
||||
import com.zhipu.oapi.service.v4.model.*;
|
||||
import io.reactivex.Flowable;
|
||||
import org.junit.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
||||
public class AllToolsTest {
|
||||
|
||||
private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class);
|
||||
private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0";
|
||||
|
||||
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
|
||||
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
||||
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
||||
.build();
|
||||
private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper();
|
||||
// 请自定义自己的业务id
|
||||
private static final String requestIdTemplate = "mycompany-%d";
|
||||
|
||||
|
||||
@Test
|
||||
public void test1() throws JsonProcessingException {
|
||||
|
||||
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气");
|
||||
messages.add(chatMessage);
|
||||
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||
// 函数调用参数构建部分
|
||||
List<ChatTool> chatToolList = new ArrayList<>();
|
||||
ChatTool chatTool = new ChatTool();
|
||||
|
||||
chatTool.setType("code_interpreter");
|
||||
ObjectNode objectNode = mapper.createObjectNode();
|
||||
objectNode.put("code", "北京天气");
|
||||
// chatTool.set(chatFunction);
|
||||
chatToolList.add(chatTool);
|
||||
|
||||
|
||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||
.model("glm-4-alltools")
|
||||
.stream(Boolean.TRUE)
|
||||
.invokeMethod(Constants.invokeMethod)
|
||||
.messages(messages)
|
||||
.tools(chatToolList)
|
||||
.toolChoice("auto")
|
||||
.requestId(requestId)
|
||||
.build();
|
||||
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
|
||||
if (sseModelApiResp.isSuccess()) {
|
||||
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||
List<Choice> choices = new ArrayList<>();
|
||||
AtomicReference<ChatMessageAccumulator> lastAccumulator = new AtomicReference<>();
|
||||
|
||||
mapStreamToAccumulator(sseModelApiResp.getFlowable())
|
||||
.doOnNext(accumulator -> {
|
||||
{
|
||||
if (isFirst.getAndSet(false)) {
|
||||
logger.info("Response: ");
|
||||
}
|
||||
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
|
||||
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
|
||||
logger.info("tool_calls: {}", jsonString);
|
||||
}
|
||||
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
|
||||
logger.info(accumulator.getDelta().getContent());
|
||||
}
|
||||
choices.add(accumulator.getChoice());
|
||||
lastAccumulator.set(accumulator);
|
||||
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> System.out.println("Stream completed."))
|
||||
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
|
||||
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
|
||||
|
||||
ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
|
||||
ModelData data = new ModelData();
|
||||
data.setChoices(choices);
|
||||
if (chatMessageAccumulator != null) {
|
||||
data.setUsage(chatMessageAccumulator.getUsage());
|
||||
data.setId(chatMessageAccumulator.getId());
|
||||
data.setCreated(chatMessageAccumulator.getCreated());
|
||||
}
|
||||
data.setRequestId(chatCompletionRequest.getRequestId());
|
||||
sseModelApiResp.setFlowable(null);// 打印前置空
|
||||
sseModelApiResp.setData(data);
|
||||
}
|
||||
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
|
||||
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
|
||||
|
||||
client.getConfig().getHttpClient().connectionPool().evictAll();
|
||||
// List all active threads
|
||||
for (Thread t : Thread.getAllStackTraces().keySet()) {
|
||||
logger.info("Thread: " + t.getName() + " State: " + t.getState());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
|
||||
return flowable.map(chunk -> {
|
||||
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
|
||||
});
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user