diff --git a/src/main/kotlin/simplerag/ragback/domain/chat/entity/Model.kt b/src/main/kotlin/simplerag/ragback/domain/chat/entity/Model.kt index 5d7a193..093aa50 100644 --- a/src/main/kotlin/simplerag/ragback/domain/chat/entity/Model.kt +++ b/src/main/kotlin/simplerag/ragback/domain/chat/entity/Model.kt @@ -26,4 +26,4 @@ class Model( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "models_id") val id: Long? = null, -): BaseEntity() \ No newline at end of file +) : BaseEntity() \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFile.kt b/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFile.kt index 1169374..724f10a 100644 --- a/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFile.kt +++ b/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFile.kt @@ -2,7 +2,6 @@ package simplerag.ragback.domain.document.entity import jakarta.persistence.* import simplerag.ragback.global.entity.BaseEntity -import java.time.LocalDateTime @Entity @Table( diff --git a/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFileTag.kt b/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFileTag.kt index 6994eed..a00b9a6 100644 --- a/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFileTag.kt +++ b/src/main/kotlin/simplerag/ragback/domain/document/entity/DataFileTag.kt @@ -21,4 +21,4 @@ class DataFileTag( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "data_files_tags_id") val id: Long? = null, -): BaseEntity() \ No newline at end of file +) : BaseEntity() \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/document/entity/Tag.kt b/src/main/kotlin/simplerag/ragback/domain/document/entity/Tag.kt index fe17647..0ce9934 100644 --- a/src/main/kotlin/simplerag/ragback/domain/document/entity/Tag.kt +++ b/src/main/kotlin/simplerag/ragback/domain/document/entity/Tag.kt @@ -16,4 +16,4 @@ class Tag( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "tags_id") val id: Long? = null, -): BaseEntity() \ No newline at end of file +) : BaseEntity() \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/document/repository/DataFileTagRepository.kt b/src/main/kotlin/simplerag/ragback/domain/document/repository/DataFileTagRepository.kt index 847d51d..68c9b0f 100644 --- a/src/main/kotlin/simplerag/ragback/domain/document/repository/DataFileTagRepository.kt +++ b/src/main/kotlin/simplerag/ragback/domain/document/repository/DataFileTagRepository.kt @@ -9,12 +9,14 @@ import simplerag.ragback.domain.document.entity.DataFileTag interface DataFileTagRepository : JpaRepository { fun existsByDataFileIdAndTagId(dataFileId: Long, tagId: Long): Boolean - @Query(""" + @Query( + """ SELECT dft FROM DataFileTag dft JOIN FETCH dft.tag t WHERE dft.dataFile = :dataFile - """) + """ + ) fun findTagsByDataFile(@Param("dataFile") dataFile: DataFile): List fun deleteAllByDataFile(dataFile: DataFile) diff --git a/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt b/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt index cdf180a..5bb1671 100644 --- a/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt +++ b/src/main/kotlin/simplerag/ragback/domain/document/service/DataFileService.kt @@ -2,7 +2,6 @@ package simplerag.ragback.domain.document.service import org.springframework.dao.DataIntegrityViolationException import org.springframework.data.domain.PageRequest -import org.springframework.data.domain.Pageable import org.springframework.stereotype.Service import org.springframework.transaction.annotation.Transactional import org.springframework.transaction.support.TransactionSynchronization @@ -22,9 +21,7 @@ import simplerag.ragback.global.util.S3Type import simplerag.ragback.global.util.S3Util import simplerag.ragback.global.util.computeMetricsStreaming import simplerag.ragback.global.util.resolveContentType -import java.time.LocalDateTime import java.util.* -import kotlin.collections.ArrayList @Service class DataFileService( @@ -83,11 +80,10 @@ class DataFileService( val dataSlice = dataFileRepository.findByIdGreaterThanOrderById(cursor, PageRequest.of(0, take)) val dataFileList: MutableList = ArrayList() - dataSlice.forEach{ dataFile -> + dataSlice.forEach { dataFile -> val dataFileTags: List = dataFileTagRepository.findTagsByDataFile(dataFile) - val tagDtos: List = dataFileTags.map{ - dataFileTag -> + val tagDtos: List = dataFileTags.map { dataFileTag -> val tag = dataFileTag.tag TagDTO(tag.id, tag.name) } diff --git a/src/main/kotlin/simplerag/ragback/domain/index/controller/IndexController.kt b/src/main/kotlin/simplerag/ragback/domain/index/controller/IndexController.kt new file mode 100644 index 0000000..72cb889 --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/controller/IndexController.kt @@ -0,0 +1,59 @@ +package simplerag.ragback.domain.index.controller + +import jakarta.validation.Valid +import org.springframework.http.HttpStatus +import org.springframework.validation.annotation.Validated +import org.springframework.web.bind.annotation.* +import simplerag.ragback.domain.index.dto.* +import simplerag.ragback.domain.index.service.IndexService +import simplerag.ragback.global.response.ApiResponse + +@RestController +@RequestMapping("/api/v1/indexes") +@Validated +class IndexController( + private val indexService: IndexService +) { + + @PostMapping + @ResponseStatus(HttpStatus.CREATED) + fun createIndex( + @RequestBody @Valid indexCreateRequest: IndexCreateRequest + ): ApiResponse { + val createdIndex = indexService.createIndex(indexCreateRequest) + return ApiResponse.ok(createdIndex) + } + + @GetMapping + fun getIndexes(): ApiResponse { + val indexPreviewResponseList = indexService.getIndexes() + return ApiResponse.ok(indexPreviewResponseList) + } + + @GetMapping("/{indexId}") + fun getIndex( + @PathVariable indexId: Long + ): ApiResponse { + val indexDetailResponse = indexService.getIndex(indexId) + return ApiResponse.ok(indexDetailResponse) + } + + @PutMapping("/{indexId}") + fun updateIndexes( + @PathVariable indexId: Long, + @RequestBody @Valid indexUpdateRequest: IndexUpdateRequest, + ): ApiResponse { + val indexPreviewResponse = indexService.updateIndex(indexId, indexUpdateRequest) + return ApiResponse.ok(indexPreviewResponse) + } + + @DeleteMapping("/{indexId}") + fun deleteIndex( + @PathVariable indexId: Long + ): ApiResponse { + indexService.deleteIndex(indexId) + return ApiResponse.ok(null, "인덱스가 삭제 되었습니다.") + } + + +} \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/converter/IndexConverter.kt b/src/main/kotlin/simplerag/ragback/domain/index/converter/IndexConverter.kt new file mode 100644 index 0000000..a4f1ca0 --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/converter/IndexConverter.kt @@ -0,0 +1,47 @@ +package simplerag.ragback.domain.index.converter + +import simplerag.ragback.domain.index.dto.IndexCreateRequest +import simplerag.ragback.domain.index.dto.IndexDetailResponse +import simplerag.ragback.domain.index.dto.IndexPreviewResponse +import simplerag.ragback.domain.index.dto.IndexPreviewResponseList +import simplerag.ragback.domain.index.entity.Index + + +fun toIndex(createRequest: IndexCreateRequest): Index { + return Index( + snapshotName = createRequest.snapshotName.trim(), + overlapSize = createRequest.overlapSize, + chunkingSize = createRequest.chunkingSize, + similarityMetric = createRequest.similarityMetric, + topK = createRequest.topK, + embeddingModel = createRequest.embeddingModel, + reranker = createRequest.reranker + ) +} + +fun toIndexPreviewResponseList( + indexes: List +): IndexPreviewResponseList { + val indexList = indexes.map { toIndexPreviewResponse(it) } + return IndexPreviewResponseList(indexList) +} + +fun toIndexPreviewResponse(index: Index): IndexPreviewResponse { + return IndexPreviewResponse( + indexId = index.id, + snapshotName = index.snapshotName, + ) +} + +fun toIndexDetailResponse(index: Index): IndexDetailResponse { + return IndexDetailResponse( + indexId = index.id, + chunkingSize = index.chunkingSize, + overlapSize = index.overlapSize, + similarityMetric = index.similarityMetric, + topK = index.topK, + embeddingModel = index.embeddingModel, + reranker = index.reranker, + snapshotName = index.snapshotName, + ) +} \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt b/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt new file mode 100644 index 0000000..568c8bf --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexRequestDTO.kt @@ -0,0 +1,48 @@ +package simplerag.ragback.domain.index.dto + +import jakarta.validation.constraints.NotBlank +import jakarta.validation.constraints.Positive +import jakarta.validation.constraints.PositiveOrZero +import org.hibernate.validator.constraints.Length +import simplerag.ragback.domain.index.entity.enums.EmbeddingModel +import simplerag.ragback.domain.index.entity.enums.SimilarityMetric + +data class IndexCreateRequest( + @field:Length(max = 255) + @field:NotBlank + val snapshotName: String, + + @field:Positive + val chunkingSize: Int, + + @field:PositiveOrZero + val overlapSize: Int, + + val similarityMetric: SimilarityMetric, + + @field:Positive + val topK: Int, + + val embeddingModel: EmbeddingModel, + + val reranker: Boolean, +) + +data class IndexUpdateRequest( + @field:Length(max = 255) + @field:NotBlank + val snapshotName: String, + + @field:Positive + val chunkingSize: Int, + + @field:PositiveOrZero + val overlapSize: Int, + + val similarityMetric: SimilarityMetric, + + @field:Positive + val topK: Int, + + val reranker: Boolean, +) \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexResponseDTO.kt b/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexResponseDTO.kt new file mode 100644 index 0000000..7ec19ff --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/dto/IndexResponseDTO.kt @@ -0,0 +1,24 @@ +package simplerag.ragback.domain.index.dto + +import simplerag.ragback.domain.index.entity.enums.EmbeddingModel +import simplerag.ragback.domain.index.entity.enums.SimilarityMetric + +data class IndexPreviewResponseList( + val indexDetailResponse: List, +) + +data class IndexPreviewResponse( + var indexId: Long?, + val snapshotName: String +) + +data class IndexDetailResponse( + var indexId: Long?, + val snapshotName: String, + val chunkingSize: Int, + val overlapSize: Int, + val similarityMetric: SimilarityMetric, + val topK: Int, + val embeddingModel: EmbeddingModel, + val reranker: Boolean, +) \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/entity/ChunkEmbedding.kt b/src/main/kotlin/simplerag/ragback/domain/index/entity/ChunkEmbedding.kt index 3c1862f..127770b 100644 --- a/src/main/kotlin/simplerag/ragback/domain/index/entity/ChunkEmbedding.kt +++ b/src/main/kotlin/simplerag/ragback/domain/index/entity/ChunkEmbedding.kt @@ -27,7 +27,7 @@ class ChunkEmbedding( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "chunk_embeddings_id") val id: Long? = null, -): BaseEntity() { +) : BaseEntity() { @get:Transient val embedding: FloatArray get() = _embedding.copyOf() diff --git a/src/main/kotlin/simplerag/ragback/domain/index/entity/DataFileIndex.kt b/src/main/kotlin/simplerag/ragback/domain/index/entity/DataFileIndex.kt index 060722c..2e60163 100644 --- a/src/main/kotlin/simplerag/ragback/domain/index/entity/DataFileIndex.kt +++ b/src/main/kotlin/simplerag/ragback/domain/index/entity/DataFileIndex.kt @@ -19,4 +19,4 @@ class DataFileIndex( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "data_files_indexes_id") val id: Long? = null, -): BaseEntity() \ No newline at end of file +) : BaseEntity() \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/entity/Index.kt b/src/main/kotlin/simplerag/ragback/domain/index/entity/Index.kt index 630783a..2f7c71f 100644 --- a/src/main/kotlin/simplerag/ragback/domain/index/entity/Index.kt +++ b/src/main/kotlin/simplerag/ragback/domain/index/entity/Index.kt @@ -2,6 +2,7 @@ package simplerag.ragback.domain.index.entity import jakarta.persistence.* import jakarta.validation.constraints.Min +import simplerag.ragback.domain.index.dto.IndexUpdateRequest import simplerag.ragback.domain.index.entity.enums.EmbeddingModel import simplerag.ragback.domain.index.entity.enums.SimilarityMetric import simplerag.ragback.global.entity.BaseEntity @@ -11,32 +12,46 @@ import simplerag.ragback.global.entity.BaseEntity class Index( @Column(name = "snapshot_name", length = 255, nullable = false) - val snapshotName: String, + var snapshotName: String, @Column(name = "chunking_size", nullable = false) @Min(1) - val chunkingSize: Int, + var chunkingSize: Int, @Column(name = "overlap_size", nullable = false) @Min(0) - val overlapSize: Int, + var overlapSize: Int, @Column(name = "similarity_metric", nullable = false) @Enumerated(EnumType.STRING) - val similarityMetric: SimilarityMetric, + var similarityMetric: SimilarityMetric, @Column(name = "top_k", nullable = false) @Min(1) - val topK: Int, + var topK: Int, @Column(name = "embedding_model", nullable = false, length = 255) @Enumerated(EnumType.STRING) val embeddingModel: EmbeddingModel, @Column(name = "reranker", nullable = false) - val reranker: Boolean, + var reranker: Boolean, + + @OneToMany(cascade = [CascadeType.ALL], orphanRemoval = true, mappedBy = "index") + val chunkEmbeddings: MutableList = mutableListOf(), @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "indexes_id") - val id: Long? = null, -): BaseEntity() \ No newline at end of file + var id: Long? = null, +) : BaseEntity() { + + fun update(req: IndexUpdateRequest) { + snapshotName = req.snapshotName.trim() + chunkingSize = req.chunkingSize + overlapSize = req.overlapSize + similarityMetric = req.similarityMetric + topK = req.topK + reranker = req.reranker + } + +} \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/repository/ChunkEmbeddingRepository.kt b/src/main/kotlin/simplerag/ragback/domain/index/repository/ChunkEmbeddingRepository.kt new file mode 100644 index 0000000..0060955 --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/repository/ChunkEmbeddingRepository.kt @@ -0,0 +1,6 @@ +package simplerag.ragback.domain.index.repository + +import org.springframework.data.jpa.repository.JpaRepository +import simplerag.ragback.domain.index.entity.ChunkEmbedding + +interface ChunkEmbeddingRepository : JpaRepository \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/repository/IndexRepository.kt b/src/main/kotlin/simplerag/ragback/domain/index/repository/IndexRepository.kt new file mode 100644 index 0000000..60af885 --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/repository/IndexRepository.kt @@ -0,0 +1,9 @@ +package simplerag.ragback.domain.index.repository + +import org.springframework.data.jpa.repository.JpaRepository +import simplerag.ragback.domain.index.entity.Index + +interface IndexRepository : JpaRepository { + + fun findAllByOrderByCreatedAtDesc(): List +} \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt b/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt new file mode 100644 index 0000000..b93ffba --- /dev/null +++ b/src/main/kotlin/simplerag/ragback/domain/index/service/IndexService.kt @@ -0,0 +1,67 @@ +package simplerag.ragback.domain.index.service + +import org.springframework.data.repository.findByIdOrNull +import org.springframework.stereotype.Service +import org.springframework.transaction.annotation.Transactional +import simplerag.ragback.domain.index.converter.toIndex +import simplerag.ragback.domain.index.converter.toIndexDetailResponse +import simplerag.ragback.domain.index.converter.toIndexPreviewResponse +import simplerag.ragback.domain.index.converter.toIndexPreviewResponseList +import simplerag.ragback.domain.index.dto.* +import simplerag.ragback.domain.index.repository.IndexRepository +import simplerag.ragback.global.error.ErrorCode +import simplerag.ragback.global.error.IndexException + +@Service +class IndexService( + private val indexRepository: IndexRepository, +) { + + @Transactional + fun createIndex(indexCreateRequest: IndexCreateRequest): IndexPreviewResponse { + + validateOverlap(indexCreateRequest.overlapSize, indexCreateRequest.chunkingSize) + + val createdIndex = indexRepository.save(toIndex(indexCreateRequest)) + return toIndexPreviewResponse(createdIndex) + } + + @Transactional(readOnly = true) + fun getIndexes(): IndexPreviewResponseList { + val indexes = indexRepository.findAllByOrderByCreatedAtDesc() + return toIndexPreviewResponseList(indexes) + } + + @Transactional(readOnly = true) + fun getIndex(indexId: Long): IndexDetailResponse { + val index = indexRepository.findByIdOrNull(indexId) ?: throw IndexException(ErrorCode.NOT_FOUND) + + return toIndexDetailResponse(index) + } + + @Transactional + fun updateIndex( + indexId: Long, + indexUpdateRequest: IndexUpdateRequest + ): IndexPreviewResponse { + val index = indexRepository.findByIdOrNull(indexId) ?: throw IndexException(ErrorCode.NOT_FOUND) + + validateOverlap(indexUpdateRequest.overlapSize, indexUpdateRequest.chunkingSize) + + index.update(indexUpdateRequest) + + return toIndexPreviewResponse(index) + } + + @Transactional + fun deleteIndex(indexId: Long) { + val index = indexRepository.findByIdOrNull(indexId) ?: throw IndexException(ErrorCode.NOT_FOUND) + + indexRepository.delete(index) + } + + private fun validateOverlap(overlapSize: Int, chunkingSize: Int) { + if (overlapSize >= chunkingSize) throw IndexException(ErrorCode.OVERLAP_OVERFLOW) + } + +} \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/prompt/entity/FewShot.kt b/src/main/kotlin/simplerag/ragback/domain/prompt/entity/FewShot.kt index 270be2f..3a4cec4 100644 --- a/src/main/kotlin/simplerag/ragback/domain/prompt/entity/FewShot.kt +++ b/src/main/kotlin/simplerag/ragback/domain/prompt/entity/FewShot.kt @@ -25,4 +25,4 @@ class FewShot( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "few_shots_id") val id: Long? = null, -): BaseEntity() \ No newline at end of file +) : BaseEntity() \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/domain/prompt/entity/Prompt.kt b/src/main/kotlin/simplerag/ragback/domain/prompt/entity/Prompt.kt index 17a920e..1a787e5 100644 --- a/src/main/kotlin/simplerag/ragback/domain/prompt/entity/Prompt.kt +++ b/src/main/kotlin/simplerag/ragback/domain/prompt/entity/Prompt.kt @@ -22,4 +22,4 @@ class Prompt( @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "prompts_id") val id: Long? = null, -): BaseEntity() \ No newline at end of file +) : BaseEntity() \ No newline at end of file diff --git a/src/main/kotlin/simplerag/ragback/global/error/CustomException.kt b/src/main/kotlin/simplerag/ragback/global/error/CustomException.kt index 723a62c..c824906 100644 --- a/src/main/kotlin/simplerag/ragback/global/error/CustomException.kt +++ b/src/main/kotlin/simplerag/ragback/global/error/CustomException.kt @@ -11,6 +11,11 @@ class S3Exception( override val cause: Throwable? = null, ) : CustomException(errorCode, errorCode.message, cause) +class IndexException( + override val errorCode: ErrorCode, + override val cause: Throwable? = null, +) : CustomException(errorCode, errorCode.message, cause) + class FileException( override val errorCode: ErrorCode, override val message: String, diff --git a/src/main/kotlin/simplerag/ragback/global/error/ErrorCode.kt b/src/main/kotlin/simplerag/ragback/global/error/ErrorCode.kt index f1a2810..16e8c49 100644 --- a/src/main/kotlin/simplerag/ragback/global/error/ErrorCode.kt +++ b/src/main/kotlin/simplerag/ragback/global/error/ErrorCode.kt @@ -21,5 +21,8 @@ enum class ErrorCode( S3_INVALID_URL(HttpStatus.BAD_REQUEST, "S3_004", "유효하지 않은 S3 URL 입니다."), S3_EMPTY_FILE(HttpStatus.BAD_REQUEST, "S3_005", "빈 파일은 업로드할 수 없습니다."), S3_PRESIGN_FAIL(HttpStatus.INTERNAL_SERVER_ERROR, "S3_006", "프리사인 URL 발급 실패"), - S3_UNSUPPORTED_CONTENT_TYPE(HttpStatus.BAD_REQUEST, "S3_007", "지원하지 않는 Content-Type 입니다.") + S3_UNSUPPORTED_CONTENT_TYPE(HttpStatus.BAD_REQUEST, "S3_007", "지원하지 않는 Content-Type 입니다."), + + // index + OVERLAP_OVERFLOW(HttpStatus.BAD_REQUEST, "INDEX_001", "overlap 크기는 chunking 크기를 넘을 수 없습니다.") } diff --git a/src/test/kotlin/simplerag/ragback/domain/document/service/DataFileServiceTest.kt b/src/test/kotlin/simplerag/ragback/domain/document/service/DataFileServiceTest.kt index 3ac7b70..8003c7d 100644 --- a/src/test/kotlin/simplerag/ragback/domain/document/service/DataFileServiceTest.kt +++ b/src/test/kotlin/simplerag/ragback/domain/document/service/DataFileServiceTest.kt @@ -1,6 +1,5 @@ package simplerag.ragback.domain.document.service -import jakarta.annotation.PostConstruct import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.DisplayName @@ -228,12 +227,12 @@ class DataFileServiceTest( dataFileRepository.saveAll( listOf( DataFile( - title = "exists", - type = "text/plain", - sizeBytes = 0, - sha256 = sha1, - fileUrl = "fake://original/exists.txt", - ), + title = "exists", + type = "text/plain", + sizeBytes = 0, + sha256 = sha1, + fileUrl = "fake://original/exists.txt", + ), DataFile( title = "exists2", type = "text/pdf", diff --git a/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt b/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt new file mode 100644 index 0000000..ea26cea --- /dev/null +++ b/src/test/kotlin/simplerag/ragback/domain/index/service/IndexServiceTest.kt @@ -0,0 +1,282 @@ +package simplerag.ragback.domain.index.service + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.test.context.ActiveProfiles +import simplerag.ragback.domain.index.dto.IndexCreateRequest +import simplerag.ragback.domain.index.dto.IndexUpdateRequest +import simplerag.ragback.domain.index.entity.Index +import simplerag.ragback.domain.index.entity.enums.EmbeddingModel +import simplerag.ragback.domain.index.entity.enums.SimilarityMetric +import simplerag.ragback.domain.index.repository.IndexRepository +import simplerag.ragback.global.error.IndexException + +@SpringBootTest +@ActiveProfiles("test") +class IndexServiceTest( + @Autowired private val indexRepository: IndexRepository, + @Autowired private val indexService: IndexService, +) { + + @AfterEach + fun cleanUp() { + indexRepository.deleteAll() + } + + @Test + @DisplayName("인덱스 생성이 정상 작동한다") + fun createIndexTest() { + // given + val indexCreateRequest = + IndexCreateRequest("test", 1, 0, SimilarityMetric.COSINE, 1, EmbeddingModel.TEXT_EMBEDDING_3_LARGE, true) + + // when + val createIndexResponse = indexService.createIndex(indexCreateRequest) + + // then + val indices = indexRepository.findAll() + val index = indices[0] + assertEquals(index.id, createIndexResponse.indexId) + assertEquals(index.snapshotName, createIndexResponse.snapshotName) + } + + @Test + @DisplayName("인덱스 생성 시 overlap 크기가 chunking 크기를 넘어가면 에러가 터진다") + fun createIndexTestWithOverlapSize() { + // given + val indexCreateRequest = + IndexCreateRequest("test", 1, 1, SimilarityMetric.COSINE, 1, EmbeddingModel.TEXT_EMBEDDING_3_LARGE, true) + + // when * then + val message = assertThrows { + indexService.createIndex(indexCreateRequest) + }.message + + assertEquals("overlap 크기는 chunking 크기를 넘을 수 없습니다.", message) + } + + @Test + @DisplayName("인덱스 리스트 조회가 된다") + fun getIndexesTest() { + // given + indexRepository.saveAll( + listOf( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ), + Index( + "test2", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + ) + + // when + val indexes = indexService.getIndexes() + + // then + assertThat(indexes.indexDetailResponse.size).isEqualTo(2) + } + + @Test + @DisplayName("인덱스 상세 조회가 된다") + fun getIndexTest() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + // when + val index = indexService.getIndex(savedIndex.id!!) + + // then + assertThat(index.indexId).isEqualTo(savedIndex.id) + assertThat(index.snapshotName).isEqualTo(savedIndex.snapshotName) + assertThat(index.chunkingSize).isEqualTo(savedIndex.chunkingSize) + assertThat(index.overlapSize).isEqualTo(savedIndex.overlapSize) + assertThat(index.topK).isEqualTo(savedIndex.topK) + assertThat(index.embeddingModel).isEqualTo(savedIndex.embeddingModel) + assertThat(index.similarityMetric).isEqualTo(savedIndex.similarityMetric) + assertThat(index.reranker).isEqualTo(savedIndex.reranker) + } + + @Test + @DisplayName("인덱스 상세 조회 시 없는 인덱스를 조회하면 에러가 터진다.") + fun getIndexTestWithInvalidIndex() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + // when * then + val message = assertThrows { indexService.getIndex(savedIndex.id!! + 1L) }.message + + assertEquals("리소스를 찾을 수 없습니다.", message) + } + + @Test + @DisplayName("인덱스 수정이 잘 된다") + fun updateIndexTest() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + val indexUpdateRequest = IndexUpdateRequest("fixedTest", 2, 1, SimilarityMetric.EUCLIDEAN, 3, false) + + // when + indexService.updateIndex(savedIndex.id!!, indexUpdateRequest) + + // then + val optionalIndex = indexRepository.findById(savedIndex.id!!) + val index = optionalIndex.get() + + assertThat(savedIndex.id).isEqualTo(index.id) + assertThat(indexUpdateRequest.snapshotName).isEqualTo(index.snapshotName) + assertThat(indexUpdateRequest.chunkingSize).isEqualTo(index.chunkingSize) + assertThat(indexUpdateRequest.overlapSize).isEqualTo(index.overlapSize) + assertThat(indexUpdateRequest.topK).isEqualTo(index.topK) + assertThat(savedIndex.embeddingModel).isEqualTo(index.embeddingModel) + assertThat(indexUpdateRequest.similarityMetric).isEqualTo(index.similarityMetric) + assertThat(indexUpdateRequest.reranker).isEqualTo(index.reranker) + } + + @Test + @DisplayName("인덱스 수정 시 없는 인덱스를 조회하면 에러가 터진다.") + fun updateTestWithInvalidIndex() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + val indexUpdateRequest = IndexUpdateRequest("fixedTest", 2, 1, SimilarityMetric.EUCLIDEAN, 3, false) + + // when * then + val message = + assertThrows { indexService.updateIndex(savedIndex.id!! + 1L, indexUpdateRequest) }.message + + assertEquals("리소스를 찾을 수 없습니다.", message) + } + + @Test + @DisplayName("인덱스 수정 시 overlap 크기가 chunking 크기를 넘어가면 에러가 터진다") + fun updateIndexTestWithOverlapSize() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + val indexUpdateRequest = IndexUpdateRequest("fixedTest", 2, 2, SimilarityMetric.EUCLIDEAN, 3, false) + + // when * then + val message = assertThrows { + indexService.updateIndex(savedIndex.id!!, indexUpdateRequest) + }.message + + assertEquals("overlap 크기는 chunking 크기를 넘을 수 없습니다.", message) + } + + @Test + @DisplayName("인덱스 삭제가 잘 된다") + fun deleteIndexTest() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + // when + indexService.deleteIndex(savedIndex.id!!) + + // then + val indexes = indexRepository.findAll() + assertEquals(0, indexes.size) + } + + @Test + @DisplayName("인덱스 삭제 시 없는 인덱스를 조회하면 에러가 터진다.") + fun deleteTestWithInvalidIndex() { + // given + val savedIndex = indexRepository.save( + Index( + "test", + 1, + 0, + SimilarityMetric.COSINE, + 1, + EmbeddingModel.TEXT_EMBEDDING_3_LARGE, + true + ) + ) + + // when * then + val message = assertThrows { indexService.deleteIndex(savedIndex.id!! + 1L) }.message + + assertEquals("리소스를 찾을 수 없습니다.", message) + } + +} \ No newline at end of file