diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index b266bcef0fb..415d20562ec 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -51,8 +51,15 @@ def __init__(self, problem_id: str, problem_content: str): class UpdateEmbeddingDatasetIdArgs: - def __init__(self, source_id_list: List[str], target_dataset_id: str): - self.source_id_list = source_id_list + def __init__(self, paragraph_id_list: List[str], target_dataset_id: str): + self.paragraph_id_list = paragraph_id_list + self.target_dataset_id = target_dataset_id + + +class UpdateEmbeddingDocumentIdArgs: + def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str): + self.paragraph_id_list = paragraph_id_list + self.target_document_id = target_document_id self.target_dataset_id = target_dataset_id @@ -213,13 +220,23 @@ def update_problem(args: UpdateProblemArgs): @staticmethod def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs): - VectorStore.get_embedding_vector().update_by_source_ids(args.source_id_list, - {'dataset_id': args.target_dataset_id}) + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'dataset_id': args.target_dataset_id}) + + @staticmethod + def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs): + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'document_id': args.target_document_id, + 'dataset_id': args.target_dataset_id}) @staticmethod def delete_embedding_by_source_ids(source_ids: List[str]): VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM) + @staticmethod + def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids) + @staticmethod def delete_embedding_by_dataset_id_list(source_ids: List[str]): VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 919172d139d..3da185d77b3 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -166,10 +166,12 @@ def migrate(self, with_valid=True): meta={}) else: document_list.update(dataset_id=target_dataset_id) - paragraph_list.update(dataset_id=target_dataset_id) + # 修改向量信息 ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( - [problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list], + [paragraph.id for paragraph in paragraph_list], target_dataset_id)) + # 修改段落信息 + paragraph_list.update(dataset_id=target_dataset_id) @staticmethod def get_target_dataset_problem(target_dataset_id: str, diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 0fc20744e75..589710ce09d 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -15,13 +15,13 @@ from rest_framework import serializers from common.db.search import page_search -from common.event.listener_manage import ListenerManagement +from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs, UpdateEmbeddingDatasetIdArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.field_message import ErrMessage from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping -from dataset.serializers.common_serializers import update_document_char_length +from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from embedding.models import SourceType @@ -272,6 +272,167 @@ def get_request_params_api(): description='问题id') ] + class Batch(serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) + + @transaction.atomic + def batch_delete(self, instance: Dict, with_valid=True): + if with_valid: + BatchSerializer(data=instance).is_valid(model=Paragraph, raise_exception=True) + self.is_valid(raise_exception=True) + paragraph_id_list = instance.get("id_list") + QuerySet(Paragraph).filter(id__in=paragraph_id_list).delete() + QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete() + # 删除向量库 + ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list) + return True + + class Migrate(ApiMixin, serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) + target_dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("目标知识库id")) + target_document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("目标文档id")) + paragraph_id_list = serializers.ListField(required=True, error_messages=ErrMessage.char("段落列表"), + child=serializers.UUIDField(required=True, + error_messages=ErrMessage.uuid("段落id"))) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + document_list = QuerySet(Document).filter( + id__in=[self.data.get('document_id'), self.data.get('target_document_id')]) + document_id = self.data.get('document_id') + target_document_id = self.data.get('target_document_id') + if document_id == target_document_id: + raise AppApiException(5000, "需要迁移的文档和目标文档一致") + if len([document for document in document_list if str(document.id) == self.data.get('document_id')]) < 1: + raise AppApiException(5000, f"文档id不存在【{self.data.get('document_id')}】") + if len([document for document in document_list if + str(document.id) == self.data.get('target_document_id')]) < 1: + raise AppApiException(5000, f"目标文档id不存在【{self.data.get('target_document_id')}】") + + @transaction.atomic + def migrate(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + dataset_id = self.data.get('dataset_id') + target_dataset_id = self.data.get('target_dataset_id') + document_id = self.data.get('document_id') + target_document_id = self.data.get('target_document_id') + paragraph_id_list = self.data.get('paragraph_id_list') + paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id=document_id, + id__in=paragraph_id_list) + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list) + # 同数据集迁移 + if target_dataset_id == dataset_id: + if len(problem_paragraph_mapping_list): + problem_paragraph_mapping_list = [ + self.update_problem_paragraph_mapping(target_document_id, + problem_paragraph_mapping) for problem_paragraph_mapping + in + problem_paragraph_mapping_list] + # 修改mapping + QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, + ['document_id']) + # 修改向量段落信息 + ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( + [paragraph.id for paragraph in paragraph_list], + target_document_id, target_dataset_id)) + # 修改段落信息 + paragraph_list.update(document_id=target_document_id) + # 不同数据集迁移 + else: + problem_list = QuerySet(Problem).filter( + id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in + problem_paragraph_mapping_list]) + # 目标数据集问题 + target_problem_list = list( + QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list], + dataset_id=target_dataset_id)) + + target_handle_problem_list = [ + self.get_target_dataset_problem(target_dataset_id, target_document_id, problem_paragraph_mapping, + problem_list, target_problem_list) for + problem_paragraph_mapping + in + problem_paragraph_mapping_list] + + create_problem_list = [problem for problem, is_create in target_handle_problem_list if + is_create is not None and is_create] + # 插入问题 + QuerySet(Problem).bulk_create(create_problem_list) + # 修改mapping + QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, + ['problem_id', 'dataset_id', 'document_id']) + # 修改向量段落信息 + ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( + [paragraph.id for paragraph in paragraph_list], + target_document_id, target_dataset_id)) + # 修改段落信息 + paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id) + + @staticmethod + def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping): + problem_paragraph_mapping.document_id = target_document_id + return problem_paragraph_mapping + + @staticmethod + def get_target_dataset_problem(target_dataset_id: str, + target_document_id: str, + problem_paragraph_mapping, + source_problem_list, + target_problem_list): + source_problem_list = [source_problem for source_problem in source_problem_list if + source_problem.id == problem_paragraph_mapping.problem_id] + problem_paragraph_mapping.dataset_id = target_dataset_id + problem_paragraph_mapping.document_id = target_document_id + if len(source_problem_list) > 0: + problem_content = source_problem_list[-1].content + problem_list = [problem for problem in target_problem_list if problem.content == problem_content] + if len(problem_list) > 0: + problem = problem_list[-1] + problem_paragraph_mapping.problem_id = problem.id + return problem, False + else: + problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content) + target_problem_list.append(problem) + problem_paragraph_mapping.problem_id = problem.id + return problem, True + return None + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='文档id'), + openapi.Parameter(name='document_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='文档id'), + openapi.Parameter(name='target_dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='目标知识库id'), + openapi.Parameter(name='target_document_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='目标知识库id') + ] + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING), + title='段落id列表', + description="段落id列表" + ) + class Operate(ApiMixin, serializers.Serializer): # 段落id paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index be68ccdc4c6..237a81f5976 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -25,6 +25,10 @@ path('dataset//document/migrate/', views.Document.Migrate.as_view()), path('dataset//document//refresh', views.Document.Refresh.as_view()), path('dataset//document//paragraph', views.Paragraph.as_view()), + path( + 'dataset//document//paragraph/migrate/dataset//document/', + views.Paragraph.BatchMigrate.as_view()), + path('dataset//document//paragraph/_batch', views.Paragraph.Batch.as_view()), path('dataset//document//paragraph//', views.Paragraph.Page.as_view(), name='paragraph_page'), path('dataset//document//paragraph/', diff --git a/apps/dataset/views/paragraph.py b/apps/dataset/views/paragraph.py index 3ce6f1114f9..af968b8ab27 100644 --- a/apps/dataset/views/paragraph.py +++ b/apps/dataset/views/paragraph.py @@ -12,9 +12,10 @@ from rest_framework.views import Request from common.auth import TokenAuth, has_permissions -from common.constants.permission_constants import Permission, Group, Operate +from common.constants.permission_constants import Permission, Group, Operate, CompareConstants from common.response import result from common.util.common import query_params_to_single_dict +from dataset.serializers.common_serializers import BatchSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers @@ -168,6 +169,50 @@ def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_ o.is_valid(raise_exception=True) return result.success(o.delete()) + class Batch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="批量删除段落", + operation_id="批量删除段落", + request_body= + BatchSerializer.get_request_body_api(), + manual_parameters=ParagraphSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str, document_id: str): + return result.success(ParagraphSerializers.Batch( + data={"dataset_id": dataset_id, 'document_id': document_id}).batch_delete(request.data)) + + class BatchMigrate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="批量迁移段落", + operation_id="批量迁移段落", + manual_parameters=ParagraphSerializers.Migrate.get_request_params_api(), + request_body=ParagraphSerializers.Migrate.get_request_body_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id')), + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('target_dataset_id')), + compare=CompareConstants.AND + ) + def put(self, request: Request, dataset_id: str, target_dataset_id: str, document_id: str, target_document_id): + return result.success( + ParagraphSerializers.Migrate( + data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id, + 'document_id': document_id, + 'target_document_id': target_document_id, + 'paragraph_id_list': request.data}).migrate()) + class Page(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 496150ddab2..2bfd0e977d5 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -113,7 +113,7 @@ def search(self, query_text, dataset_id_list: list[str], exclude_document_id_lis return result[0] @abstractmethod - def query(self, query_text:str,query_embedding: List[float], dataset_id_list: list[str], + def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, search_mode: SearchMode): @@ -130,6 +130,10 @@ def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): pass + @abstractmethod + def update_by_paragraph_ids(self, paragraph_ids: str, instance: Dict): + pass + @abstractmethod def update_by_source_id(self, source_id: str, instance: Dict): pass @@ -173,3 +177,7 @@ def delete_by_source_ids(self, source_ids: List[str], source_type: str): @abstractmethod def delete_by_paragraph_id(self, paragraph_id: str): pass + + @abstractmethod + def delete_by_paragraph_ids(self, paragraph_ids: List[str]): + pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index d5e5d125eff..5c0d045363b 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -119,6 +119,9 @@ def update_by_source_id(self, source_id: str, instance: Dict): def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance) + def update_by_paragraph_ids(self, paragraph_id: str, instance: Dict): + QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance) + def delete_by_dataset_id(self, dataset_id: str): QuerySet(Embedding).filter(dataset_id=dataset_id).delete() @@ -139,6 +142,9 @@ def delete_by_source_id(self, source_id: str, source_type: str): def delete_by_paragraph_id(self, paragraph_id: str): QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete() + def delete_by_paragraph_ids(self, paragraph_ids: List[str]): + QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete() + class ISearch(ABC): @abstractmethod diff --git a/ui/src/api/paragraph.ts b/ui/src/api/paragraph.ts index fc1169a57f4..675fa6efab9 100644 --- a/ui/src/api/paragraph.ts +++ b/ui/src/api/paragraph.ts @@ -48,6 +48,24 @@ const delParagraph: ( ) } +/** + * 批量删除段落 + * @param 参数 dataset_id, document_id + */ +const delMulParagraph: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, data, loading) => { + return del( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/_batch`, + undefined, + { id_list: data }, + loading + ) +} + /** * 创建段落 * @param 参数 @@ -104,6 +122,33 @@ const putParagraph: ( ) } +/** + * 批量迁移段落 + * @param 参数 dataset_id,target_dataset_id, + */ +const putMigrateMulParagraph: ( + dataset_id: string, + document_id: string, + target_dataset_id: string, + target_document_id: string, + data: any, + loading?: Ref +) => Promise> = ( + dataset_id, + document_id, + target_dataset_id, + target_document_id, + data, + loading +) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/migrate/dataset/${target_dataset_id}/document/${target_document_id}`, + data, + undefined, + loading + ) +} + /** * 问题列表 * @param 参数 dataset_id,document_id,paragraph_id @@ -189,5 +234,7 @@ export default { getProblem, postProblem, disassociationProblem, - associationProblem + associationProblem, + delMulParagraph, + putMigrateMulParagraph } diff --git a/ui/src/styles/app.scss b/ui/src/styles/app.scss index 69b6440b48e..3743c0e7103 100644 --- a/ui/src/styles/app.scss +++ b/ui/src/styles/app.scss @@ -573,3 +573,33 @@ h4 { color: var(--app-text-color); } } + +// card 选中样式 +.selected { + border: 1px solid var(--el-color-primary) !important; + &:before { + content: ''; + position: absolute; + right: 0; + top: 0; + border: 14px solid var(--el-color-primary); + border-bottom-color: transparent; + border-left-color: transparent; + } + + &:after { + content: ''; + width: 3px; + height: 6px; + position: absolute; + right: 5px; + top: 2px; + border: 2px solid #fff; + border-top-color: transparent; + border-left-color: transparent; + transform: rotate(35deg); + } + &:hover { + border: 1px solid var(--el-color-primary); + } +} diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index 8e5a5a4bc51..9b7cffdd1e2 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -20,15 +20,15 @@ v-if="datasetDetail.type === '1'" >同步文档 - 迁移 - 设置 - 删除 + + 迁移 + + + 设置 + + + 删除 + { function openBatchEditDocument() { const arr: string[] = multipleSelection.value.map((v) => v.id) - if (batchEditDocumentDialogRef) { - batchEditDocumentDialogRef?.value?.open(arr) - } + batchEditDocumentDialogRef?.value?.open(arr) } /** diff --git a/ui/src/views/paragraph/component/SelectDocumentDialog.vue b/ui/src/views/paragraph/component/SelectDocumentDialog.vue new file mode 100644 index 00000000000..caebe338ffd --- /dev/null +++ b/ui/src/views/paragraph/component/SelectDocumentDialog.vue @@ -0,0 +1,162 @@ + + + diff --git a/ui/src/views/paragraph/index.vue b/ui/src/views/paragraph/index.vue index 8471f759518..0307ab723a5 100644 --- a/ui/src/views/paragraph/index.vue +++ b/ui/src/views/paragraph/index.vue @@ -8,7 +8,20 @@ >)
- 添加分段 + + 批量选择 + + + 取消选择 + + + 添加分段 +
+ + +
+ + +
+ @@ -90,8 +138,20 @@
+ +
+ + 迁移 + + + + 删除 + + 已选 {{ multipleSelection.length }} 项 +
+