diff --git a/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java b/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java index 4c9727d..0f14e2f 100644 --- a/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java +++ b/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java @@ -1,5 +1,7 @@ package com.zl.mjga.config.ai; +import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; + import com.zl.mjga.component.PromptConfiguration; import com.zl.mjga.service.LlmService; import dev.langchain4j.community.model.zhipu.ZhipuAiStreamingChatModel; @@ -72,6 +74,11 @@ public class ChatModelInitializer { .embeddingModel(zhipuEmbeddingModel) .minScore(0.75) .maxResults(5) + .dynamicFilter( + query -> { + String libraryId = (String) query.metadata().chatMemoryId(); + return metadataKey("libraryId").isEqualTo(libraryId); + }) .build()) .build(); } diff --git a/backend/src/main/java/com/zl/mjga/controller/AiController.java b/backend/src/main/java/com/zl/mjga/controller/AiController.java index df74e09..b97dff0 100644 --- a/backend/src/main/java/com/zl/mjga/controller/AiController.java +++ b/backend/src/main/java/com/zl/mjga/controller/AiController.java @@ -2,6 +2,7 @@ package com.zl.mjga.controller; import com.zl.mjga.dto.PageRequestDto; import com.zl.mjga.dto.PageResponseDto; +import com.zl.mjga.dto.ai.ChatDto; import com.zl.mjga.dto.ai.LlmQueryDto; import com.zl.mjga.dto.ai.LlmVm; import com.zl.mjga.exception.BusinessException; @@ -72,9 +73,9 @@ public class AiController { } @PostMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux chat(Principal principal, @RequestBody String userMessage) { + public Flux chat(Principal principal, @RequestBody ChatDto chatDto) { Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); - TokenStream chat = aiChatService.chatPrecedenceLlmWith(principal.getName(), userMessage); + TokenStream chat = aiChatService.chat(principal.getName(), chatDto); chat.onPartialResponse( text -> sink.tryEmitNext( diff --git a/backend/src/main/java/com/zl/mjga/controller/LibraryController.java b/backend/src/main/java/com/zl/mjga/controller/LibraryController.java index 88d4da7..6ccbc9c 100644 --- a/backend/src/main/java/com/zl/mjga/controller/LibraryController.java +++ b/backend/src/main/java/com/zl/mjga/controller/LibraryController.java @@ -8,7 +8,6 @@ import com.zl.mjga.repository.LibraryRepository; import com.zl.mjga.service.RagService; import com.zl.mjga.service.UploadService; import jakarta.validation.Valid; - import java.util.Comparator; import java.util.List; import lombok.RequiredArgsConstructor; @@ -34,16 +33,16 @@ public class LibraryController { @GetMapping("/libraries") public List queryLibraries() { - return libraryRepository.findAll().stream().sorted( - Comparator.comparing(Library::getId).reversed() - ).toList(); + return libraryRepository.findAll().stream() + .sorted(Comparator.comparing(Library::getId).reversed()) + .toList(); } @GetMapping("/docs") public List queryLibraryDocs(@RequestParam Long libraryId) { - return libraryDocRepository.fetchByLibId(libraryId).stream().sorted( - Comparator.comparing(LibraryDoc::getId).reversed() - ).toList(); + return libraryDocRepository.fetchByLibId(libraryId).stream() + .sorted(Comparator.comparing(LibraryDoc::getId).reversed()) + .toList(); } @GetMapping("/segments") diff --git a/backend/src/main/java/com/zl/mjga/dto/ai/ChatDto.java b/backend/src/main/java/com/zl/mjga/dto/ai/ChatDto.java new file mode 100644 index 0000000..e2a0e13 --- /dev/null +++ b/backend/src/main/java/com/zl/mjga/dto/ai/ChatDto.java @@ -0,0 +1,7 @@ +package com.zl.mjga.dto.ai; + +import com.zl.mjga.model.urp.ChatMode; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; + +public record ChatDto(@NotNull ChatMode mode, Long libraryId, @NotEmpty String message) {} diff --git a/backend/src/main/java/com/zl/mjga/model/urp/ChatMode.java b/backend/src/main/java/com/zl/mjga/model/urp/ChatMode.java new file mode 100644 index 0000000..72184bf --- /dev/null +++ b/backend/src/main/java/com/zl/mjga/model/urp/ChatMode.java @@ -0,0 +1,6 @@ +package com.zl.mjga.model.urp; + +public enum ChatMode { + NORMAL, + WITH_LIBRARY +} diff --git a/backend/src/main/java/com/zl/mjga/service/AiChatService.java b/backend/src/main/java/com/zl/mjga/service/AiChatService.java index 9ea1a5c..d7897cd 100644 --- a/backend/src/main/java/com/zl/mjga/service/AiChatService.java +++ b/backend/src/main/java/com/zl/mjga/service/AiChatService.java @@ -2,6 +2,7 @@ package com.zl.mjga.service; import com.zl.mjga.config.ai.AiChatAssistant; import com.zl.mjga.config.ai.SystemToolAssistant; +import com.zl.mjga.dto.ai.ChatDto; import com.zl.mjga.exception.BusinessException; import dev.langchain4j.service.TokenStream; import java.util.Optional; @@ -39,8 +40,20 @@ public class AiChatService { }; } - public TokenStream chatPrecedenceLlmWith(String sessionIdentifier, String userMessage) { + public TokenStream chat(String sessionIdentifier, ChatDto chatDto) { + return switch (chatDto.mode()) { + case NORMAL -> chatWithPrecedenceLlm(sessionIdentifier, chatDto); + case WITH_LIBRARY -> chatWithLibrary(chatDto.libraryId(), chatDto); + }; + } + + public TokenStream chatWithLibrary(Long libraryId, ChatDto chatDto) { + return zhiPuChatAssistant.chat(String.valueOf(libraryId), chatDto.message()); + } + + public TokenStream chatWithPrecedenceLlm(String sessionIdentifier, ChatDto chatDto) { LlmCodeEnum code = getPrecedenceLlmCode(); + String userMessage = chatDto.message(); return switch (code) { case ZHI_PU -> zhiPuChatAssistant.chat(sessionIdentifier, userMessage); case DEEP_SEEK -> deepSeekChatAssistant.chat(sessionIdentifier, userMessage); diff --git a/backend/src/main/java/com/zl/mjga/service/UploadService.java b/backend/src/main/java/com/zl/mjga/service/UploadService.java index 2f01f6e..7aa45f5 100644 --- a/backend/src/main/java/com/zl/mjga/service/UploadService.java +++ b/backend/src/main/java/com/zl/mjga/service/UploadService.java @@ -71,10 +71,6 @@ public class UploadService { if (size > 1024 * 1024) { throw new BusinessException("知识库文档大小不能超过1MB"); } - String contentType = multipartFile.getContentType(); - if (!StringUtils.startsWith(contentType, "text/")) { - throw new BusinessException("非法的上传文件"); - } minioClient.putObject( PutObjectArgs.builder().bucket(minIoConfig.getDefaultBucket()).object(objectName).stream( multipartFile.getInputStream(), size, -1) diff --git a/frontend/src/api/schema/openapi.json b/frontend/src/api/schema/openapi.json index 8d18abd..8dbaf0b 100644 --- a/frontend/src/api/schema/openapi.json +++ b/frontend/src/api/schema/openapi.json @@ -894,7 +894,7 @@ "content": { "application/json": { "schema": { - "type": "string" + "$ref": "#/components/schemas/ChatDto" } } }, @@ -1580,7 +1580,8 @@ "DocUpdateDto": { "required": [ "enable", - "id" + "id", + "libId" ], "type": "object", "properties": { @@ -1588,6 +1589,10 @@ "type": "integer", "format": "int64" }, + "libId": { + "type": "integer", + "format": "int64" + }, "enable": { "type": "boolean" } @@ -1868,6 +1873,29 @@ } } }, + "ChatDto": { + "required": [ + "message", + "mode" + ], + "type": "object", + "properties": { + "mode": { + "type": "string", + "enum": [ + "NORMAL", + "WITH_LIBRARY" + ] + }, + "libraryId": { + "type": "integer", + "format": "int64" + }, + "message": { + "type": "string" + } + } + }, "PageRequestDto": { "type": "object", "properties": { diff --git a/frontend/src/api/types/schema.d.ts b/frontend/src/api/types/schema.d.ts index 9a825b0..13459af 100644 --- a/frontend/src/api/types/schema.d.ts +++ b/frontend/src/api/types/schema.d.ts @@ -783,6 +783,8 @@ export interface components { DocUpdateDto: { /** Format: int64 */ id: number; + /** Format: int64 */ + libId: number; enable: boolean; }; LlmVm: { @@ -867,6 +869,13 @@ export interface components { username: string; password: string; }; + ChatDto: { + /** @enum {string} */ + mode: "NORMAL" | "WITH_LIBRARY"; + /** Format: int64 */ + libraryId?: number; + message: string; + }; PageRequestDto: { /** Format: int64 */ page?: number; @@ -1888,7 +1897,7 @@ export interface operations { }; requestBody: { content: { - "application/json": string; + "application/json": components["schemas"]["ChatDto"]; }; }; responses: { diff --git a/frontend/src/components/common/Assistant.vue b/frontend/src/components/common/Assistant.vue index 387469c..ba0a6b4 100644 --- a/frontend/src/components/common/Assistant.vue +++ b/frontend/src/components/common/Assistant.vue @@ -5,14 +5,18 @@
  • -
    + :class="['flex flex-col leading-1.5 p-4 border-gray-200 max-w-[calc(100%-40px)]', chatElement.isUser ? 'bg-blue-100 rounded-tl-xl rounded-bl-xl rounded-br-xl' : 'bg-gray-100 rounded-e-xl rounded-es-xl']">
    {{ chatElement.username }} + + {{ chatElement.libraryName }} +
    - +
    + + + +
    @@ -51,7 +65,7 @@ " required>
    -