diff --git a/backend/build.gradle.kts b/backend/build.gradle.kts index 0b141bd..81114d6 100644 --- a/backend/build.gradle.kts +++ b/backend/build.gradle.kts @@ -61,6 +61,7 @@ dependencies { implementation("org.springframework.boot:spring-boot-starter-quartz") implementation("dev.langchain4j:langchain4j:1.0.0") implementation("dev.langchain4j:langchain4j-open-ai:1.0.0") + implementation("dev.langchain4j:langchain4j-pgvector:1.0.1-beta6") implementation("dev.langchain4j:langchain4j-community-zhipu-ai:1.0.1-beta6") implementation("io.projectreactor:reactor-core:3.7.6") testImplementation("org.testcontainers:junit-jupiter:$testcontainersVersion") diff --git a/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingConfig.java b/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingConfig.java new file mode 100644 index 0000000..a0ac806 --- /dev/null +++ b/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingConfig.java @@ -0,0 +1,45 @@ +package com.zl.mjga.config.ai; + +import dev.langchain4j.community.model.zhipu.ZhipuAiEmbeddingModel; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; +import jakarta.annotation.Resource; +import lombok.RequiredArgsConstructor; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.DependsOn; +import org.springframework.core.env.Environment; + +@Configuration +@RequiredArgsConstructor +public class EmbeddingConfig { + + @Resource private Environment env; + + @Bean + @DependsOn("flywayInitializer") + public EmbeddingModel zhipuEmbeddingModel(ZhiPuConfiguration zhiPuConfiguration) { + return ZhipuAiEmbeddingModel.builder() + .apiKey(zhiPuConfiguration.getApiKey()) + .model(zhiPuConfiguration.getEmbeddingModel()) + .dimensions(2048) + .build(); + } + + @Bean + public EmbeddingStore zhiPuEmbeddingStore(EmbeddingModel zhipuEmbeddingModel) { + String hostPort = env.getProperty("DATABASE_HOST_PORT"); + String host = hostPort.split(":")[0]; + return PgVectorEmbeddingStore.builder() + .host(host) + .port(env.getProperty("DATABASE_EXPOSE_PORT", Integer.class)) + .database(env.getProperty("DATABASE_DB")) + .user(env.getProperty("DATABASE_USER")) + .password(env.getProperty("DATABASE_PASSWORD")) + .table("mjga.zhipu_embedding_store") + .dimension(2048) + .build(); + } +} diff --git a/backend/src/main/java/com/zl/mjga/config/ai/ZhiPuConfiguration.java b/backend/src/main/java/com/zl/mjga/config/ai/ZhiPuConfiguration.java index f94a679..603cb2a 100644 --- a/backend/src/main/java/com/zl/mjga/config/ai/ZhiPuConfiguration.java +++ b/backend/src/main/java/com/zl/mjga/config/ai/ZhiPuConfiguration.java @@ -11,10 +11,12 @@ public class ZhiPuConfiguration { private String baseUrl; private String apiKey; private String modelName; + private String embeddingModel; public void init(AiLlmConfig config) { this.baseUrl = config.getUrl(); this.apiKey = config.getApiKey(); this.modelName = config.getModelName(); + this.embeddingModel = config.getEmbeddingModel(); } } diff --git a/backend/src/main/java/com/zl/mjga/controller/AiController.java b/backend/src/main/java/com/zl/mjga/controller/AiController.java index b77f900..0d1d2a1 100644 --- a/backend/src/main/java/com/zl/mjga/controller/AiController.java +++ b/backend/src/main/java/com/zl/mjga/controller/AiController.java @@ -4,15 +4,20 @@ import com.zl.mjga.dto.PageRequestDto; import com.zl.mjga.dto.PageResponseDto; import com.zl.mjga.dto.ai.LlmQueryDto; import com.zl.mjga.dto.ai.LlmVm; +import com.zl.mjga.exception.BusinessException; import com.zl.mjga.service.AiChatService; +import com.zl.mjga.service.EmbeddingService; import com.zl.mjga.service.LlmService; import dev.langchain4j.service.TokenStream; import jakarta.validation.Valid; import java.security.Principal; import java.time.Duration; import java.util.List; +import java.util.Map; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.jooq.generated.mjga.enums.LlmCodeEnum; +import org.jooq.generated.mjga.tables.pojos.AiLlmConfig; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.security.access.prepost.PreAuthorize; @@ -28,6 +33,7 @@ public class AiController { private final AiChatService aiChatService; private final LlmService llmService; + private final EmbeddingService embeddingService; @PostMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux chat(Principal principal, @RequestBody String userMessage) { @@ -57,4 +63,13 @@ public class AiController { @ModelAttribute PageRequestDto pageRequestDto, @ModelAttribute LlmQueryDto llmQueryDto) { return llmService.pageQueryLlm(pageRequestDto, llmQueryDto); } + + @PostMapping("/action/chat") + public Map actionChat(@RequestBody String message) { + AiLlmConfig aiLlmConfig = llmService.loadConfig(LlmCodeEnum.ZHI_PU); + if (!aiLlmConfig.getEnable()) { + throw new BusinessException("命令模型未启用,请开启后再试。"); + } + return embeddingService.searchAction(message); + } } diff --git a/backend/src/main/java/com/zl/mjga/model/urp/Actions.java b/backend/src/main/java/com/zl/mjga/model/urp/Actions.java new file mode 100644 index 0000000..30ec51c --- /dev/null +++ b/backend/src/main/java/com/zl/mjga/model/urp/Actions.java @@ -0,0 +1,14 @@ +package com.zl.mjga.model.urp; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +@Getter +public enum Actions { + CREATE_USER("CREATE_USER", "创建新用户"), + DELETE_USER("DELETE_USER", "删除用户"); + public static final String INDEX_KEY = "action"; + private final String code; + private final String content; +} diff --git a/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java b/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java new file mode 100644 index 0000000..797d898 --- /dev/null +++ b/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java @@ -0,0 +1,65 @@ +package com.zl.mjga.service; + +import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; + +import com.zl.mjga.model.urp.Actions; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import jakarta.annotation.PostConstruct; +import java.util.HashMap; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.springframework.context.annotation.Configuration; +import org.springframework.stereotype.Service; + +@Configuration +@RequiredArgsConstructor +@Service +public class EmbeddingService { + + private final EmbeddingModel zhipuEmbeddingModel; + + private final EmbeddingStore zhiPuEmbeddingStore; + + public Map searchAction(String message) { + Map result = new HashMap<>(); + EmbeddingSearchRequest embeddingSearchRequest = + EmbeddingSearchRequest.builder() + .queryEmbedding(zhipuEmbeddingModel.embed(message).content()) + .build(); + EmbeddingSearchResult embeddingSearchResult = + zhiPuEmbeddingStore.search(embeddingSearchRequest); + if (!embeddingSearchResult.matches().isEmpty()) { + result = embeddingSearchResult.matches().getFirst().embedded().metadata().toMap(); + } + return result; + } + + @PostConstruct + public void initActionIndex() { + for (Actions action : Actions.values()) { + Embedding queryEmbedding = zhipuEmbeddingModel.embed(action.getContent()).content(); + Filter createUserFilter = metadataKey(Actions.INDEX_KEY).isEqualTo(action.getCode()); + EmbeddingSearchRequest embeddingSearchRequest = + EmbeddingSearchRequest.builder() + .queryEmbedding(queryEmbedding) + .filter(createUserFilter) + .build(); + EmbeddingSearchResult embeddingSearchResult = + zhiPuEmbeddingStore.search(embeddingSearchRequest); + if (embeddingSearchResult.matches().isEmpty()) { + TextSegment segment = + TextSegment.from( + action.getContent(), Metadata.metadata(Actions.INDEX_KEY, action.getCode())); + Embedding embedding = zhipuEmbeddingModel.embed(segment).content(); + zhiPuEmbeddingStore.add(embedding, segment); + } + } + } +} diff --git a/backend/src/main/resources/db/migration/V1_0_0__init_table.sql b/backend/src/main/resources/db/migration/V1_0_0__init_table.sql index 7a25894..9f1c677 100644 --- a/backend/src/main/resources/db/migration/V1_0_0__init_table.sql +++ b/backend/src/main/resources/db/migration/V1_0_0__init_table.sql @@ -77,6 +77,7 @@ CREATE TABLE mjga.ai_llm_config ( name VARCHAR(255) NOT NULL UNIQUE, code mjga.llm_code_enum NOT NULL UNIQUE, model_name VARCHAR(255) NOT NULL, + embedding_model VARCHAR(255) NOT NULL, api_key VARCHAR(255) NOT NULL, url VARCHAR(255) NOT NULL, enable BOOLEAN NOT NULL DEFAULT true, diff --git a/backend/src/main/resources/db/migration/V1_0_1__insert_init_table.sql b/backend/src/main/resources/db/migration/V1_0_1__insert_init_table.sql index 05b4ee9..70f8f7e 100644 --- a/backend/src/main/resources/db/migration/V1_0_1__insert_init_table.sql +++ b/backend/src/main/resources/db/migration/V1_0_1__insert_init_table.sql @@ -33,7 +33,7 @@ VALUES (1, 1), (1, 9), (1, 10); -INSERT INTO mjga.ai_llm_config (name,code,model_name, api_key, url, enable, priority) +INSERT INTO mjga.ai_llm_config (name,code,model_name,embedding_model, api_key, url, enable, priority) VALUES - ('DeepSeek','DEEP_SEEK','deepseek-chat','your_api_key', 'https://api.deepseek.com', false, 0), - ('智谱清言','ZHI_PU','glm-4-flash', 'your_api_key', 'https://open.bigmodel.cn/', false, 1); \ No newline at end of file + ('DeepSeek','DEEP_SEEK','deepseek-chat','embedding-model-name','your_api_key', 'https://api.deepseek.com', false, 0), + ('智谱清言','ZHI_PU','glm-4-flash','embedding-model-name', 'your_api_key', 'https://open.bigmodel.cn/', false, 1); \ No newline at end of file