diff --git a/build.gradle b/build.gradle index 5116db6..e7c93c9 100644 --- a/build.gradle +++ b/build.gradle @@ -31,25 +31,28 @@ ext { } dependencies { - implementation("com.fasterxml.jackson.module:jackson-module-kotlin") implementation 'org.springframework.boot:spring-boot-starter-data-jpa' implementation 'org.springframework.boot:spring-boot-starter-validation' implementation 'org.springframework.boot:spring-boot-starter-web' implementation 'com.fasterxml.jackson.module:jackson-module-kotlin' implementation 'org.jetbrains.kotlin:kotlin-reflect' -// implementation 'org.springframework.ai:spring-ai-starter-model-openai' + implementation 'org.springframework.ai:spring-ai-starter-model-openai' compileOnly 'org.projectlombok:lombok' annotationProcessor 'org.projectlombok:lombok' testImplementation 'org.springframework.boot:spring-boot-starter-test' testImplementation 'org.jetbrains.kotlin:kotlin-test-junit5' testImplementation("org.mockito.kotlin:mockito-kotlin:5.3.1") testRuntimeOnly 'org.junit.platform:junit-platform-launcher' + + // postgresql implementation 'org.postgresql:postgresql' + implementation 'org.springframework.ai:spring-ai-starter-vector-store-pgvector' + // test testImplementation "org.springframework.boot:spring-boot-testcontainers" testImplementation "org.testcontainers:postgresql" testImplementation "org.testcontainers:junit-jupiter" - testImplementation "com.pgvector:pgvector:0.1.6" + testImplementation 'org.springframework.ai:spring-ai-starter-vector-store-pgvector' // swagger implementation 'org.springdoc:springdoc-openapi-starter-webmvc-ui:2.7.0' @@ -57,9 +60,6 @@ dependencies { // s3 implementation(platform("software.amazon.awssdk:bom:2.25.70")) implementation("software.amazon.awssdk:s3") - - // pgvector - implementation("com.pgvector:pgvector:0.1.6") } dependencyManagement { diff --git a/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt b/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt 
index a2c6197..f45cd31 100644 --- a/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt +++ b/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt @@ -17,10 +17,10 @@ import simplerag.ragback.domain.document.repository.TagRepository import simplerag.ragback.global.error.CustomException import simplerag.ragback.global.error.ErrorCode import simplerag.ragback.global.error.FileException -import simplerag.ragback.global.util.S3Type -import simplerag.ragback.global.util.S3Util -import simplerag.ragback.global.util.computeMetricsStreaming -import simplerag.ragback.global.util.resolveContentType +import simplerag.ragback.global.util.s3.S3Type +import simplerag.ragback.global.util.s3.S3Util +import simplerag.ragback.global.util.converter.computeMetricsStreaming +import simplerag.ragback.global.util.converter.resolveContentType import java.util.* @Service diff --git a/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt b/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt index 568c8bf..9627d9d 100644 --- a/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt +++ b/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt @@ -1,6 +1,7 @@ package simplerag.ragback.domain.index.dto import jakarta.validation.constraints.NotBlank +import jakarta.validation.constraints.NotEmpty import jakarta.validation.constraints.Positive import jakarta.validation.constraints.PositiveOrZero import org.hibernate.validator.constraints.Length @@ -8,6 +9,10 @@ import simplerag.ragback.domain.index.entity.enums.EmbeddingModel import simplerag.ragback.domain.index.entity.enums.SimilarityMetric data class IndexCreateRequest( + + @field:NotEmpty + val dataFileId: List, + @field:Length(max = 255) @field:NotBlank val snapshotName: String, diff --git a/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexResponseDTO.kt 
/**
 * Converts text into an embedding vector.
 */
interface Embedder {
    /** Vector length this embedder declares for its output. */
    val dim: Int

    /** Returns the embedding vector computed for [text]. */
    fun embed(text: String): FloatArray
}

/**
 * [Embedder] that delegates to Spring AI's [OpenAiEmbeddingModel].
 */
@Component
class OpenAIEmbedder(
    private val openAiEmbeddingModel: OpenAiEmbeddingModel,
) : Embedder {

    // NOTE(review): presumably the output size of the configured OpenAI embedding model — confirm.
    override val dim: Int = EMBEDDING_DIM

    /** Forwards [text] to the injected model and returns its vector unchanged. */
    override fun embed(text: String): FloatArray {
        return openAiEmbeddingModel.embed(text)
    }

    private companion object {
        const val EMBEDDING_DIM = 1536
    }
}
+1,5 @@ package simplerag.ragback.domain.index.entity -import com.pgvector.PGvector import jakarta.persistence.* import simplerag.ragback.global.entity.BaseEntity @@ -9,12 +8,11 @@ import simplerag.ragback.global.entity.BaseEntity @Table(name = "chunk_embeddings") class ChunkEmbedding( - @Column(name = "content", nullable = false) - @Lob + @Column(name = "content", nullable = false, columnDefinition = "text") val content: String, - @Column(name = "embedding", columnDefinition = "vector") - var embedding: PGvector, + @Column(name = "embedding", columnDefinition = "vector(1536)", nullable = false) + var embedding: FloatArray, @Column(name = "embedding_dim", nullable = false) val embeddingDim: Int, diff --git a/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt b/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt index 6cf2983..e5971b5 100644 --- a/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt +++ b/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt @@ -3,24 +3,61 @@ package simplerag.ragback.domain.index.service import org.springframework.data.repository.findByIdOrNull import org.springframework.stereotype.Service import org.springframework.transaction.annotation.Transactional +import simplerag.ragback.domain.document.entity.DataFile +import simplerag.ragback.domain.document.repository.DataFileRepository import simplerag.ragback.domain.index.dto.* +import simplerag.ragback.domain.index.embed.Embedder +import simplerag.ragback.domain.index.entity.ChunkEmbedding import simplerag.ragback.domain.index.entity.Index import simplerag.ragback.domain.index.repository.IndexRepository +import simplerag.ragback.global.error.CustomException import simplerag.ragback.global.error.ErrorCode import simplerag.ragback.global.error.IndexException +import simplerag.ragback.global.util.loader.ContentLoader +import simplerag.ragback.global.util.TextChunker @Service class IndexService( private val 
/**
 * Creates an index for the request and attaches one [ChunkEmbedding] per text chunk
 * of every referenced data file.
 *
 * Order of work:
 * 1. `validateOverlap` (defined elsewhere in this class) checks overlap against chunk size.
 * 2. The requested data files are looked up; every distinct id must resolve.
 * 3. The embedder's dimension must equal the dimension of the requested embedding model.
 * 4. The index is saved, then each file's content is loaded from its `fileUrl`, chunked,
 *    embedded, and added to `index.chunkEmbeddings`.
 *
 * Files whose loaded content is blank contribute no chunks.
 *
 * @param req index settings plus the ids of the data files to embed
 * @return preview built from the saved index
 * @throws CustomException with [ErrorCode.NOT_FOUND] when a requested id does not resolve,
 *   or [ErrorCode.INVALID_INPUT] when the embedder and model dimensions differ
 */
@Transactional
fun createIndex(req: IndexCreateRequest): IndexPreviewResponse {
    validateOverlap(req.overlapSize, req.chunkingSize)

    // Compare against the distinct ids: a request that repeats an id must not be
    // reported as "not found" just because the lookup yields fewer rows than the raw list.
    // NOTE(review): assumes findAllById returns each matching row once — confirm.
    val requestedIds = req.dataFileId.distinct()
    val files: List<DataFile> = dataFileRepository.findAllById(requestedIds)
    if (files.size != requestedIds.size) {
        throw CustomException(ErrorCode.NOT_FOUND, "Some dataFileIds not found")
    }

    if (embedder.dim != req.embeddingModel.dim) {
        throw CustomException(ErrorCode.INVALID_INPUT, "Embedding dim mismatch: model=${req.embeddingModel.dim}, embedder=${embedder.dim}")
    }

    val index = indexRepository.save(Index.toIndex(req))

    for (file in files) {
        // The raw document text is intentionally not printed or logged here.
        val content = contentLoader.load(file.fileUrl)
        if (content.isBlank()) continue

        TextChunker.chunkByCharsSeq(content, req.chunkingSize, req.overlapSize).forEach { chunk ->
            // NOTE(review): presumably persisted through the index's cascade settings — confirm.
            index.chunkEmbeddings.add(
                ChunkEmbedding(
                    content = chunk,
                    embedding = embedder.embed(chunk),
                    embeddingDim = embedder.dim,
                    index = index,
                )
            )
        }
    }

    return IndexPreviewResponse.toIndexPreviewResponse(index)
}
/**
 * Splits text into fixed-size character windows that may overlap.
 */
object TextChunker {

    // Compiled once instead of on every normalize() call.
    private val HORIZONTAL_WHITESPACE_RUN = Regex("[ \t]+")
    private val THREE_OR_MORE_NEWLINES = Regex("\\n{3,}")

    /**
     * Normalizes [raw] and returns its chunks lazily.
     *
     * Each chunk holds at most [size] characters and consecutive chunks start
     * `size - overlap` characters apart. The last chunk ends exactly at the end of the
     * normalized text and may be shorter than [size]. Blank input yields no chunks; text
     * that fits into one window is returned as a single chunk.
     *
     * Arguments are validated when this function is called, not when the returned
     * sequence is first iterated.
     *
     * @param raw text to split
     * @param size maximum chunk length in characters, at least 1
     * @param overlap characters shared by neighbouring chunks, in `0..size-1`
     * @return lazily produced chunks in document order
     * @throws IllegalArgumentException when [size] or [overlap] is out of range
     */
    fun chunkByCharsSeq(raw: String, size: Int, overlap: Int): Sequence<String> {
        // Kept outside the sequence builder: inside it, the checks would only run on first iteration.
        require(size >= 1) { "chunk size must be >= 1" }
        require(overlap in 0 until size) { "overlap must be 0..size-1" }

        val text = normalize(raw)
        if (text.isBlank()) return emptySequence()
        if (text.length <= size) return sequenceOf(text)

        val step = size - overlap
        return sequence {
            var start = 0
            while (start < text.length) {
                val end = (start + size).coerceAtMost(text.length)
                yield(text.substring(start, end))
                // Stop once a window reaches the end so no redundant tail windows are emitted.
                if (end == text.length) break
                start += step
            }
        }
    }

    /**
     * Converts line endings to `\n`, collapses runs of spaces/tabs to one space, limits
     * consecutive newlines to two, and trims the result.
     */
    private fun normalize(s: String): String =
        s.replace("\r\n", "\n").replace("\r", "\n")
            .replace(HORIZONTAL_WHITESPACE_RUN, " ")
            .replace(THREE_OR_MORE_NEWLINES, "\n\n")
            .trim()
}
/**
 * Loads the textual content behind a URL.
 *
 * Declared as a functional interface so that a loader can also be supplied as a lambda,
 * for example a stub in tests; class-based implementations are unaffected.
 */
fun interface ContentLoader {
    /** Returns the content found at [url] as a string. */
    fun load(url: String): String
}
/**
 * [ContentLoader] that fetches content with an HTTP GET through a [RestTemplate].
 *
 * The template reads string bodies with a [StringHttpMessageConverter] whose default
 * charset is UTF-8, and uses explicit connect/read timeouts so that a stalled remote
 * host cannot block the caller indefinitely.
 */
@Component
class HttpContentLoader : ContentLoader {

    private val restTemplate: RestTemplate = RestTemplate(
        // Fully qualified so the file's import block does not need to change.
        org.springframework.http.client.SimpleClientHttpRequestFactory().apply {
            setConnectTimeout(CONNECT_TIMEOUT_MS)
            setReadTimeout(READ_TIMEOUT_MS)
        }
    ).apply {
        // Remove the existing String converter(s), then put a UTF-8 converter at the front.
        val replaced = messageConverters.filterNot { it is StringHttpMessageConverter }.toMutableList()
        replaced.add(0, StringHttpMessageConverter(StandardCharsets.UTF_8))
        messageConverters = replaced
    }

    /**
     * Issues a GET for [url] and returns the response body.
     *
     * @return the body as a string, or an empty string when the response has no body
     */
    override fun load(url: String): String {
        return restTemplate.getForObject(url, String::class.java) ?: ""
    }

    private companion object {
        // NOTE(review): chosen as conservative defaults; tune for the expected file sizes.
        const val CONNECT_TIMEOUT_MS = 5_000
        const val READ_TIMEOUT_MS = 30_000
    }
}
/**
 * Builds the URL of the object stored under [key].
 *
 * The URL is derived from this class's `s3` client and `bucket` instead of a hard-coded
 * host, so the bucket name and region are not baked into the code.
 *
 * NOTE(review): the SDK is expected to percent-encode the key in the returned URL;
 * callers that fetch it over HTTP should avoid encoding it a second time — confirm.
 *
 * @param key object key inside the bucket
 * @return external form of the object's URL
 */
override fun urlFromKey(key: String): String =
    s3.utilities()
        .getUrl { it.bucket(bucket).key(key) }
        .toExternalForm()
org.springframework.beans.factory.annotation.Autowired @@ -14,7 +13,6 @@ import org.springframework.transaction.annotation.Transactional import org.springframework.transaction.support.TransactionTemplate import org.springframework.web.multipart.MultipartFile import org.testcontainers.containers.PostgreSQLContainer -import org.testcontainers.junit.jupiter.Container import org.testcontainers.utility.DockerImageName import simplerag.ragback.domain.document.dto.DataFileBulkCreateRequest import simplerag.ragback.domain.document.dto.DataFileCreateItem @@ -26,18 +24,18 @@ import simplerag.ragback.global.error.CustomException import simplerag.ragback.global.error.ErrorCode import simplerag.ragback.global.error.FileException import simplerag.ragback.global.storage.FakeS3Util -import simplerag.ragback.global.util.S3Type -import simplerag.ragback.global.util.sha256Hex +import simplerag.ragback.global.util.s3.S3Type +import simplerag.ragback.global.util.converter.sha256Hex import java.security.MessageDigest @SpringBootTest @ActiveProfiles("test") class DataFileServiceTest( - private val dataFileService: DataFileService, - private val dataFileRepository: DataFileRepository, - private val tagRepository: TagRepository, - private val dataFileTagRepository: DataFileTagRepository, - private val s3Util: FakeS3Util + @Autowired val dataFileService: DataFileService, + @Autowired val dataFileRepository: DataFileRepository, + @Autowired val tagRepository: TagRepository, + @Autowired val dataFileTagRepository: DataFileTagRepository, + @Autowired val s3Util: FakeS3Util ) { diff --git a/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt b/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt index 097603a..3dd32f1 100644 --- a/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt +++ b/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt @@ -27,8 +27,8 @@ import org.testcontainers.utility.DockerImageName 
@ActiveProfiles("test") @TestConstructor(autowireMode = TestConstructor.AutowireMode.ALL) class IndexServiceTest( - val indexService: IndexService, - val indexRepository: IndexRepository, + @Autowired val indexService: IndexService, + @Autowired val indexRepository: IndexRepository, ) {