From 8d285e1abcd9077ce89effa501db199683c7c3b8 Mon Sep 17 00:00:00 2001 From: Chuck1sn Date: Wed, 25 Jun 2025 15:10:02 +0800 Subject: [PATCH] add document ingestor --- backend/build.gradle.kts | 4 +- .../mjga/config/ai/ChatModelInitializer.java | 8 +- .../mjga/config/ai/EmbeddingInitializer.java | 39 ++++++++- .../config/security/WebSecurityConfig.java | 1 + .../com/zl/mjga/controller/AiController.java | 14 +++ .../controller/IdentityAccessController.java | 47 +--------- .../com/zl/mjga/service/EmbeddingService.java | 25 +++++- .../com/zl/mjga/service/UploadService.java | 85 +++++++++++++++++++ 8 files changed, 173 insertions(+), 50 deletions(-) create mode 100644 backend/src/main/java/com/zl/mjga/service/UploadService.java diff --git a/backend/build.gradle.kts b/backend/build.gradle.kts index 0047b5f..4464b35 100644 --- a/backend/build.gradle.kts +++ b/backend/build.gradle.kts @@ -32,7 +32,7 @@ sourceSets { group = "com.zl.mjga" version = "1.0.0" description = "make java great again!" -java.sourceCompatibility = JavaVersion.VERSION_17 +java.sourceCompatibility = JavaVersion.VERSION_21 configurations { compileOnly { @@ -64,6 +64,8 @@ dependencies { implementation("dev.langchain4j:langchain4j-open-ai:1.0.0") implementation("dev.langchain4j:langchain4j-pgvector:1.0.1-beta6") implementation("dev.langchain4j:langchain4j-community-zhipu-ai:1.0.1-beta6") + implementation("dev.langchain4j:langchain4j-easy-rag:1.1.0-beta7") + implementation("dev.langchain4j:langchain4j-document-loader-amazon-s3:1.1.0-beta7") implementation("io.projectreactor:reactor-core:3.7.6") testImplementation("org.testcontainers:junit-jupiter:$testcontainersVersion") testImplementation("org.testcontainers:postgresql:$testcontainersVersion") diff --git a/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java b/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java index eac37d4..ea27d5f 100644 --- a/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java +++ 
b/backend/src/main/java/com/zl/mjga/config/ai/ChatModelInitializer.java @@ -3,9 +3,12 @@ package com.zl.mjga.config.ai; import com.zl.mjga.component.PromptConfiguration; import com.zl.mjga.service.LlmService; import dev.langchain4j.community.model.zhipu.ZhipuAiStreamingChatModel; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; import dev.langchain4j.service.AiServices; +import dev.langchain4j.store.embedding.EmbeddingStore; import lombok.RequiredArgsConstructor; import org.jooq.generated.mjga.enums.LlmCodeEnum; import org.jooq.generated.mjga.tables.pojos.AiLlmConfig; @@ -54,11 +57,14 @@ public class ChatModelInitializer { @Bean @DependsOn("flywayInitializer") - public AiChatAssistant zhiPuChatAssistant(ZhipuAiStreamingChatModel zhipuChatModel) { + public AiChatAssistant zhiPuChatAssistant( + ZhipuAiStreamingChatModel zhipuChatModel, + EmbeddingStore zhiPuLibraryEmbeddingStore) { return AiServices.builder(AiChatAssistant.class) .streamingChatModel(zhipuChatModel) .systemMessageProvider(chatMemoryId -> promptConfiguration.getSystem()) .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10)) + .contentRetriever(EmbeddingStoreContentRetriever.from(zhiPuLibraryEmbeddingStore)) .build(); } diff --git a/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingInitializer.java b/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingInitializer.java index bc81f98..891b63d 100644 --- a/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingInitializer.java +++ b/backend/src/main/java/com/zl/mjga/config/ai/EmbeddingInitializer.java @@ -1,10 +1,14 @@ package com.zl.mjga.config.ai; +import com.zl.mjga.config.minio.MinIoConfig; import com.zl.mjga.service.LlmService; import dev.langchain4j.community.model.zhipu.ZhipuAiEmbeddingModel; +import 
dev.langchain4j.data.document.loader.amazon.s3.AmazonS3DocumentLoader; +import dev.langchain4j.data.document.loader.amazon.s3.AwsCredentials; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; import jakarta.annotation.Resource; import lombok.RequiredArgsConstructor; @@ -42,7 +46,7 @@ public class EmbeddingInitializer { } @Bean - public EmbeddingStore zhiPuEmbeddingStore(EmbeddingModel zhipuEmbeddingModel) { + public EmbeddingStore zhiPuEmbeddingStore() { String hostPort = env.getProperty("DATABASE_HOST_PORT"); String host = hostPort.split(":")[0]; return PgVectorEmbeddingStore.builder() @@ -55,4 +59,37 @@ public class EmbeddingInitializer { .dimension(2048) .build(); } + + @Bean + public EmbeddingStore zhiPuLibraryEmbeddingStore() { + String hostPort = env.getProperty("DATABASE_HOST_PORT"); + String host = hostPort.split(":")[0]; + return PgVectorEmbeddingStore.builder() + .host(host) + .port(env.getProperty("DATABASE_EXPOSE_PORT", Integer.class)) + .database(env.getProperty("DATABASE_DB")) + .user(env.getProperty("DATABASE_USER")) + .password(env.getProperty("DATABASE_PASSWORD")) + .table("mjga.zhipu_library_embedding_store") + .dimension(2048) + .build(); + } + + @Bean + public EmbeddingStoreIngestor zhipuEmbeddingStoreIngestor( + EmbeddingStore zhiPuLibraryEmbeddingStore, EmbeddingModel zhipuEmbeddingModel) { + return EmbeddingStoreIngestor.builder() + .embeddingModel(zhipuEmbeddingModel) + .embeddingStore(zhiPuLibraryEmbeddingStore) + .build(); + } + + @Bean + public AmazonS3DocumentLoader amazonS3DocumentLoader(MinIoConfig minIoConfig) { + return AmazonS3DocumentLoader.builder() + .endpointUrl(minIoConfig.getEndpoint()) + .forcePathStyle(true) + .awsCredentials(new AwsCredentials(minIoConfig.getAccessKey(), 
minIoConfig.getSecretKey())) + .build(); + } } diff --git a/backend/src/main/java/com/zl/mjga/config/security/WebSecurityConfig.java b/backend/src/main/java/com/zl/mjga/config/security/WebSecurityConfig.java index 6592366..d43c3cd 100644 --- a/backend/src/main/java/com/zl/mjga/config/security/WebSecurityConfig.java +++ b/backend/src/main/java/com/zl/mjga/config/security/WebSecurityConfig.java @@ -41,6 +41,7 @@ public class WebSecurityConfig { new AntPathRequestMatcher("/auth/sign-up", HttpMethod.POST.name()), new AntPathRequestMatcher("/v3/api-docs/**", HttpMethod.GET.name()), new AntPathRequestMatcher("/swagger-ui/**", HttpMethod.GET.name()), + new AntPathRequestMatcher("/ai/library/upload", HttpMethod.POST.name()), new AntPathRequestMatcher("/swagger-ui.html", HttpMethod.GET.name()), new AntPathRequestMatcher("/error")); } diff --git a/backend/src/main/java/com/zl/mjga/controller/AiController.java b/backend/src/main/java/com/zl/mjga/controller/AiController.java index 3c912b9..969a035 100644 --- a/backend/src/main/java/com/zl/mjga/controller/AiController.java +++ b/backend/src/main/java/com/zl/mjga/controller/AiController.java @@ -9,6 +9,7 @@ import com.zl.mjga.repository.*; import com.zl.mjga.service.AiChatService; import com.zl.mjga.service.EmbeddingService; import com.zl.mjga.service.LlmService; +import com.zl.mjga.service.UploadService; import dev.langchain4j.service.TokenStream; import jakarta.validation.Valid; import java.security.Principal; @@ -24,6 +25,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; @@ -42,6 +44,7 @@ public class AiController { private final RoleRepository repository; private final PermissionRepository permissionRepository; private final RoleRepository 
roleRepository; + private final UploadService uploadService; @PostMapping(value = "/action/execute", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux actionExecute(Principal principal, @RequestBody String userMessage) { @@ -169,4 +172,15 @@ public class AiController { void createNewConversation(Principal principal) { aiChatService.evictChatMemory(principal.getName()); } + + @PostMapping( + value = "/library/upload", + consumes = MediaType.MULTIPART_FORM_DATA_VALUE, + produces = MediaType.TEXT_PLAIN_VALUE) + public String uploadLibraryFile(@RequestPart("file") MultipartFile multipartFile) + throws Exception { + String objectName = uploadService.uploadLibraryFile(multipartFile); + embeddingService.ingestDocument(objectName); + return objectName; + } } diff --git a/backend/src/main/java/com/zl/mjga/controller/IdentityAccessController.java b/backend/src/main/java/com/zl/mjga/controller/IdentityAccessController.java index bf28040..577f08a 100644 --- a/backend/src/main/java/com/zl/mjga/controller/IdentityAccessController.java +++ b/backend/src/main/java/com/zl/mjga/controller/IdentityAccessController.java @@ -1,6 +1,5 @@ package com.zl.mjga.controller; -import com.zl.mjga.config.minio.MinIoConfig; import com.zl.mjga.dto.PageRequestDto; import com.zl.mjga.dto.PageResponseDto; import com.zl.mjga.dto.department.DepartmentBindDto; @@ -13,17 +12,11 @@ import com.zl.mjga.repository.PermissionRepository; import com.zl.mjga.repository.RoleRepository; import com.zl.mjga.repository.UserRepository; import com.zl.mjga.service.IdentityAccessService; -import io.minio.MinioClient; -import io.minio.PutObjectArgs; +import com.zl.mjga.service.UploadService; import jakarta.validation.Valid; -import java.awt.image.BufferedImage; import java.security.Principal; -import java.time.Instant; import java.util.List; -import javax.imageio.ImageIO; import lombok.RequiredArgsConstructor; -import org.apache.commons.lang3.RandomStringUtils; -import org.apache.commons.lang3.StringUtils; import 
org.jooq.generated.mjga.tables.pojos.User; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -41,8 +34,7 @@ public class IdentityAccessController { private final UserRepository userRepository; private final RoleRepository roleRepository; private final PermissionRepository permissionRepository; - private final MinioClient minioClient; - private final MinIoConfig minIoConfig; + private final UploadService uploadService; @PreAuthorize("hasAuthority(T(com.zl.mjga.model.urp.EPermission).WRITE_USER_ROLE_PERMISSION)") @PostMapping( @@ -50,40 +42,7 @@ public class IdentityAccessController { consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.TEXT_PLAIN_VALUE) public String uploadAvatar(@RequestPart("file") MultipartFile multipartFile) throws Exception { - String originalFilename = multipartFile.getOriginalFilename(); - if (StringUtils.isEmpty(originalFilename)) { - throw new BusinessException("文件名不能为空"); - } - String contentType = multipartFile.getContentType(); - String extension = ""; - if ("image/jpeg".equals(contentType)) { - extension = ".jpg"; - } else if ("image/png".equals(contentType)) { - extension = ".png"; - } - String objectName = - String.format( - "/avatar/%d%s%s", - Instant.now().toEpochMilli(), - RandomStringUtils.insecure().nextAlphabetic(6), - extension); - if (multipartFile.isEmpty()) { - throw new BusinessException("上传的文件不能为空"); - } - long size = multipartFile.getSize(); - if (size > 200 * 1024) { - throw new BusinessException("头像文件大小不能超过200KB"); - } - BufferedImage img = ImageIO.read(multipartFile.getInputStream()); - if (img == null) { - throw new BusinessException("非法的上传文件"); - } - minioClient.putObject( - PutObjectArgs.builder().bucket(minIoConfig.getDefaultBucket()).object(objectName).stream( - multipartFile.getInputStream(), size, -1) - .contentType(multipartFile.getContentType()) - .build()); - return objectName; + return uploadService.uploadAvatarFile(multipartFile); } @GetMapping("/me") 
diff --git a/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java b/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java index 8967838..8c9eaaa 100644 --- a/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java +++ b/backend/src/main/java/com/zl/mjga/service/EmbeddingService.java @@ -3,25 +3,30 @@ package com.zl.mjga.service; import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; import com.zl.mjga.config.ai.ZhiPuEmbeddingModelConfig; +import com.zl.mjga.config.minio.MinIoConfig; import com.zl.mjga.model.urp.Actions; +import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.document.loader.amazon.s3.AmazonS3DocumentLoader; +import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.EmbeddingSearchRequest; -import dev.langchain4j.store.embedding.EmbeddingSearchResult; -import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.*; import dev.langchain4j.store.embedding.filter.Filter; +import io.minio.errors.*; import jakarta.annotation.PostConstruct; import java.util.HashMap; import java.util.Map; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Configuration; import org.springframework.stereotype.Service; @Configuration @RequiredArgsConstructor @Service +@Slf4j public class EmbeddingService { private final EmbeddingModel zhipuEmbeddingModel; @@ -30,6 +35,20 @@ public class EmbeddingService { private final ZhiPuEmbeddingModelConfig zhiPuEmbeddingModelConfig; + private final AmazonS3DocumentLoader amazonS3DocumentLoader; + + private final EmbeddingStoreIngestor zhiPuEmbeddingStoreIngestor; + + private final MinIoConfig 
minIoConfig; + + public void ingestDocument(String objectName) { + Document document = + amazonS3DocumentLoader.loadDocument( + minIoConfig.getDefaultBucket(), objectName, new ApacheTikaDocumentParser()); + IngestionResult ingest = zhiPuEmbeddingStoreIngestor.ingest(document); + log.info("Ingest document finished {}", ingest); + } + public Map searchAction(String message) { Map result = new HashMap<>(); EmbeddingSearchRequest embeddingSearchRequest = diff --git a/backend/src/main/java/com/zl/mjga/service/UploadService.java b/backend/src/main/java/com/zl/mjga/service/UploadService.java new file mode 100644 index 0000000..654dfad --- /dev/null +++ b/backend/src/main/java/com/zl/mjga/service/UploadService.java @@ -0,0 +1,85 @@ +package com.zl.mjga.service; + +import com.zl.mjga.config.minio.MinIoConfig; +import com.zl.mjga.exception.BusinessException; +import io.minio.*; +import java.awt.image.BufferedImage; +import java.time.Instant; +import javax.imageio.ImageIO; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.RandomStringUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; +import org.springframework.web.multipart.MultipartFile; + +@Service +@RequiredArgsConstructor +@Slf4j +public class UploadService { + + private final MinioClient minioClient; + private final MinIoConfig minIoConfig; + + public String uploadAvatarFile(MultipartFile multipartFile) throws Exception { + String originalFilename = multipartFile.getOriginalFilename(); + if (StringUtils.isEmpty(originalFilename)) { + throw new BusinessException("文件名不能为空"); + } + String contentType = multipartFile.getContentType(); + String extension = ""; + if ("image/jpeg".equals(contentType)) { + extension = ".jpg"; + } else if ("image/png".equals(contentType)) { + extension = ".png"; + } + String objectName = + String.format( + "/avatar/%d%s%s", + Instant.now().toEpochMilli(), + 
RandomStringUtils.insecure().nextAlphabetic(6), + extension); + if (multipartFile.isEmpty()) { + throw new BusinessException("上传的文件不能为空"); + } + long size = multipartFile.getSize(); + if (size > 200 * 1024) { + throw new BusinessException("头像大小不能超过200KB"); + } + BufferedImage img = ImageIO.read(multipartFile.getInputStream()); + if (img == null) { + throw new BusinessException("非法的上传文件"); + } + minioClient.putObject( + PutObjectArgs.builder().bucket(minIoConfig.getDefaultBucket()).object(objectName).stream( + multipartFile.getInputStream(), size, -1) + .contentType(multipartFile.getContentType()) + .build()); + return objectName; + } + + public String uploadLibraryFile(MultipartFile multipartFile) throws Exception { + String originalFilename = multipartFile.getOriginalFilename(); + if (StringUtils.isEmpty(originalFilename)) { + throw new BusinessException("文件名不能为空"); + } + String objectName = String.format("/library/%s", originalFilename); + if (multipartFile.isEmpty()) { + throw new BusinessException("上传的文件不能为空"); + } + long size = multipartFile.getSize(); + if (size > 1024 * 1024) { + throw new BusinessException("知识库文档大小不能超过1MB"); + } + String contentType = multipartFile.getContentType(); + if (!StringUtils.startsWith(contentType, "text/")) { + throw new BusinessException("非法的上传文件"); + } + minioClient.putObject( + PutObjectArgs.builder().bucket(minIoConfig.getDefaultBucket()).object(objectName).stream( + multipartFile.getInputStream(), size, -1) + .contentType(multipartFile.getContentType()) + .build()); + return objectName; + } +}