Merged
25 changes: 21 additions & 4 deletions apps/common/event/listener_manage.py
@@ -51,8 +51,15 @@ def __init__(self, problem_id: str, problem_content: str):


class UpdateEmbeddingDatasetIdArgs:
def __init__(self, source_id_list: List[str], target_dataset_id: str):
self.source_id_list = source_id_list
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
self.paragraph_id_list = paragraph_id_list
self.target_dataset_id = target_dataset_id


class UpdateEmbeddingDocumentIdArgs:
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str):
self.paragraph_id_list = paragraph_id_list
self.target_document_id = target_document_id
self.target_dataset_id = target_dataset_id


@@ -213,13 +220,23 @@ def update_problem(args: UpdateProblemArgs):

@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
VectorStore.get_embedding_vector().update_by_source_ids(args.source_id_list,
{'dataset_id': args.target_dataset_id})
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id})

@staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'document_id': args.target_document_id,
'dataset_id': args.target_dataset_id})

@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)

@staticmethod
def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids)

@staticmethod
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
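# Usage sketch of the new batch hooks (UUIDs are placeholders, not values from this PR):
#   ids = ['<paragraph-uuid-1>', '<paragraph-uuid-2>']
#   ListenerManagement.update_embedding_document_id(
#       UpdateEmbeddingDocumentIdArgs(ids, '<target-document-uuid>', '<target-dataset-uuid>'))
#   ListenerManagement.delete_embedding_by_paragraph_ids(ids)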
6 changes: 4 additions & 2 deletions apps/dataset/serializers/document_serializers.py
@@ -166,10 +166,12 @@ def migrate(self, with_valid=True):
meta={})
else:
document_list.update(dataset_id=target_dataset_id)
paragraph_list.update(dataset_id=target_dataset_id)
# update the vector store records
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
[problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list],
[paragraph.id for paragraph in paragraph_list],
target_dataset_id))
# update the paragraph rows last: paragraph_list is still filtered by the source
# dataset_id, so the vector update above must read the paragraph ids before this runs
paragraph_list.update(dataset_id=target_dataset_id)

@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
165 changes: 163 additions & 2 deletions apps/dataset/serializers/paragraph_serializers.py
@@ -15,13 +15,13 @@
from rest_framework import serializers

from common.db.search import page_search
from common.event.listener_manage import ListenerManagement
from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs, UpdateEmbeddingDatasetIdArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType

@@ -272,6 +272,167 @@ def get_request_params_api():
description='问题id')
]

class Batch(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

@transaction.atomic
def batch_delete(self, instance: Dict, with_valid=True):
if with_valid:
BatchSerializer(data=instance).is_valid(model=Paragraph, raise_exception=True)
self.is_valid(raise_exception=True)
paragraph_id_list = instance.get("id_list")
QuerySet(Paragraph).filter(id__in=paragraph_id_list).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete()
# delete the embedding rows from the vector store
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list)
return True
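# Invocation sketch (hypothetical UUIDs), mirroring the Batch view in
# apps/dataset/views/paragraph.py:
#   ParagraphSerializers.Batch(
#       data={'dataset_id': '<dataset-uuid>', 'document_id': '<document-uuid>'}
#   ).batch_delete({'id_list': ['<paragraph-uuid>']})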

class Migrate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("dataset id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("document id"))
target_dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("target dataset id"))
target_document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("target document id"))
paragraph_id_list = serializers.ListField(required=True, error_messages=ErrMessage.char("paragraph id list"),
                                          child=serializers.UUIDField(required=True,
                                                                      error_messages=ErrMessage.uuid("paragraph id")))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
document_list = QuerySet(Document).filter(
id__in=[self.data.get('document_id'), self.data.get('target_document_id')])
document_id = self.data.get('document_id')
target_document_id = self.data.get('target_document_id')
if document_id == target_document_id:
raise AppApiException(5000, "需要迁移的文档和目标文档一致")
if len([document for document in document_list if str(document.id) == self.data.get('document_id')]) < 1:
raise AppApiException(5000, f"文档id不存在【{self.data.get('document_id')}】")
if len([document for document in document_list if
str(document.id) == self.data.get('target_document_id')]) < 1:
raise AppApiException(5000, f"目标文档id不存在【{self.data.get('target_document_id')}】")

@transaction.atomic
def migrate(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
target_dataset_id = self.data.get('target_dataset_id')
document_id = self.data.get('document_id')
target_document_id = self.data.get('target_document_id')
paragraph_id_list = self.data.get('paragraph_id_list')
paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id=document_id,
id__in=paragraph_id_list)
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
# migration within the same dataset
if target_dataset_id == dataset_id:
if len(problem_paragraph_mapping_list):
problem_paragraph_mapping_list = [
self.update_problem_paragraph_mapping(target_document_id,
problem_paragraph_mapping) for problem_paragraph_mapping
in
problem_paragraph_mapping_list]
# update the mappings
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['document_id'])
# update the embedding rows in the vector store
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
# update the paragraph rows (kept after the vector update, which reads ids from this queryset)
paragraph_list.update(document_id=target_document_id)
# migration across datasets
else:
problem_list = QuerySet(Problem).filter(
id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
problem_paragraph_mapping_list])
# problems that already exist in the target dataset
target_problem_list = list(
QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
dataset_id=target_dataset_id))

target_handle_problem_list = [
self.get_target_dataset_problem(target_dataset_id, target_document_id, problem_paragraph_mapping,
problem_list, target_problem_list) for
problem_paragraph_mapping
in
problem_paragraph_mapping_list]

create_problem_list = [problem for problem, is_create in target_handle_problem_list if
is_create is not None and is_create]
# insert the newly created problems
QuerySet(Problem).bulk_create(create_problem_list)
# update the mappings
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['problem_id', 'dataset_id', 'document_id'])
# update the embedding rows in the vector store
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
# update the paragraph rows
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
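# Worked example of the cross-dataset branch (hypothetical data): a paragraph whose
# problem content is "How do I reset my password?" moves from dataset A to dataset B.
# If B already holds a problem with identical content it is reused ((problem, False));
# otherwise a new Problem is built and queued for bulk_create ((problem, True)). The
# mapping, the embedding rows and the paragraph row are then re-pointed to B and the
# target document.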

@staticmethod
def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping):
problem_paragraph_mapping.document_id = target_document_id
return problem_paragraph_mapping

@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
target_document_id: str,
problem_paragraph_mapping,
source_problem_list,
target_problem_list):
source_problem_list = [source_problem for source_problem in source_problem_list if
source_problem.id == problem_paragraph_mapping.problem_id]
problem_paragraph_mapping.dataset_id = target_dataset_id
problem_paragraph_mapping.document_id = target_document_id
if len(source_problem_list) > 0:
problem_content = source_problem_list[-1].content
problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
if len(problem_list) > 0:
problem = problem_list[-1]
problem_paragraph_mapping.problem_id = problem.id
return problem, False
else:
problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content)
target_problem_list.append(problem)
problem_paragraph_mapping.problem_id = problem.id
return problem, True
return None, None  # no matching source problem; keep the tuple shape the caller unpacks
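# Return contract: (existing_problem, False) if the content already exists in the
# target dataset; (new_problem, True) if one was built for bulk_create; and
# (None, None) when the mapping has no matching source problem, which the caller's
# "is_create is not None" check filters out.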

@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='dataset id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='document id'),
openapi.Parameter(name='target_dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='target dataset id'),
openapi.Parameter(name='target_document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='target document id')
]

@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title='paragraph id list',
description="paragraph id list"
)

class Operate(ApiMixin, serializers.Serializer):
# paragraph id
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
4 changes: 4 additions & 0 deletions apps/dataset/urls.py
@@ -25,6 +25,10 @@
path('dataset/<str:dataset_id>/document/migrate/<str:target_dataset_id>', views.Document.Migrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path(
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/migrate/dataset/<str:target_dataset_id>/document/<str:target_document_id>',
views.Paragraph.BatchMigrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/_batch', views.Paragraph.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
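For illustration, hypothetical client calls for the two new routes (the base URL, auth header, and all ids are placeholder assumptions, not values from this PR):

import requests

BASE = 'http://localhost:8080/api'          # assumed deployment prefix
HEADERS = {'AUTHORIZATION': '<api-token>'}  # assumed auth scheme

# batch delete: the body is an object carrying an id_list of paragraph UUIDs
requests.delete(f'{BASE}/dataset/<dataset_id>/document/<document_id>/paragraph/_batch',
                json={'id_list': ['<paragraph-uuid>']}, headers=HEADERS)

# batch migrate: the body is a bare JSON array of paragraph UUIDs
requests.put(f'{BASE}/dataset/<dataset_id>/document/<document_id>/paragraph/migrate/'
             f'dataset/<target_dataset_id>/document/<target_document_id>',
             json=['<paragraph-uuid>'], headers=HEADERS)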
47 changes: 46 additions & 1 deletion apps/dataset/views/paragraph.py
@@ -12,9 +12,10 @@
from rest_framework.views import Request

from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.constants.permission_constants import Permission, Group, Operate, CompareConstants
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.common_serializers import BatchSerializer
from dataset.serializers.paragraph_serializers import ParagraphSerializers


@@ -168,6 +169,50 @@ def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_
o.is_valid(raise_exception=True)
return result.success(o.delete())

class Batch(APIView):
authentication_classes = [TokenAuth]

@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="批量删除段落",
operation_id="批量删除段落",
request_body=
BatchSerializer.get_request_body_api(),
manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def delete(self, request: Request, dataset_id: str, document_id: str):
return result.success(ParagraphSerializers.Batch(
data={"dataset_id": dataset_id, 'document_id': document_id}).batch_delete(request.data))

class BatchMigrate(APIView):
authentication_classes = [TokenAuth]

@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="批量迁移段落",
operation_id="批量迁移段落",
manual_parameters=ParagraphSerializers.Migrate.get_request_params_api(),
request_body=ParagraphSerializers.Migrate.get_request_body_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('target_dataset_id')),
compare=CompareConstants.AND
)
def put(self, request: Request, dataset_id: str, target_dataset_id: str, document_id: str, target_document_id: str):
return result.success(
ParagraphSerializers.Migrate(
data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id,
'document_id': document_id,
'target_document_id': target_document_id,
'paragraph_id_list': request.data}).migrate())

class Page(APIView):
authentication_classes = [TokenAuth]

10 changes: 9 additions & 1 deletion apps/embedding/vector/base_vector.py
@@ -113,7 +113,7 @@ def search(self, query_text, dataset_id_list: list[str], exclude_document_id_lis
return result[0]

@abstractmethod
def query(self, query_text:str,query_embedding: List[float], dataset_id_list: list[str],
def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str],
exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
search_mode: SearchMode):
@@ -130,6 +130,10 @@ def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list:
def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
pass

@abstractmethod
def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
pass

@abstractmethod
def update_by_source_id(self, source_id: str, instance: Dict):
pass
@@ -173,3 +177,7 @@ def delete_by_source_ids(self, source_ids: List[str], source_type: str):
@abstractmethod
def delete_by_paragraph_id(self, paragraph_id: str):
pass

@abstractmethod
def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
pass
6 changes: 6 additions & 0 deletions apps/embedding/vector/pg_vector.py
@@ -119,6 +119,9 @@ def update_by_source_id(self, source_id: str, instance: Dict):
def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
    QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).update(**instance)

def delete_by_dataset_id(self, dataset_id: str):
QuerySet(Embedding).filter(dataset_id=dataset_id).delete()

@@ -139,6 +142,9 @@ def delete_by_source_id(self, source_id: str, source_type: str):
def delete_by_paragraph_id(self, paragraph_id: str):
QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
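# Both helpers above compile to a single set-based SQL statement via the ORM,
# roughly (table and column names assumed from the Embedding model):
#   UPDATE embedding SET document_id = ..., dataset_id = ... WHERE paragraph_id IN (...)
#   DELETE FROM embedding WHERE paragraph_id IN (...)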


class ISearch(ABC):
@abstractmethod