From de930a43014cc64e018d432d0c397e0376cf4bac Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 12 Jun 2024 17:00:12 +0800 Subject: [PATCH 01/52] feat(providers/weaviate): update min weaviate-client version to 4.4.0 --- airflow/providers/weaviate/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/weaviate/provider.yaml b/airflow/providers/weaviate/provider.yaml index b060c7407b704..577fce2688a70 100644 --- a/airflow/providers/weaviate/provider.yaml +++ b/airflow/providers/weaviate/provider.yaml @@ -50,7 +50,7 @@ integrations: dependencies: - apache-airflow>=2.7.0 - httpx>=0.25.0 - - weaviate-client>=3.24.2 + - weaviate-client>=4.4.0 # In pandas 2.2 minimal version of the sqlalchemy is 2.0 # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies # However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723 diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index a794cfc25ef2f..f4760e5a0aeaa 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1319,7 +1319,7 @@ "httpx>=0.25.0", "pandas>=1.5.3,<2.2;python_version<\"3.9\"", "pandas>=2.1.2,<2.2;python_version>=\"3.9\"", - "weaviate-client>=3.24.2" + "weaviate-client>=4.4.0" ], "devel-deps": [], "plugins": [], From 88fcf40ddb685884c5b9be7ee67f5c2380b6fd17 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 12 Jun 2024 19:05:39 +0800 Subject: [PATCH 02/52] feat(providers/weaviate): update airflow connection to v4 style --- airflow/providers/weaviate/hooks/weaviate.py | 71 ++++++++++++-------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 56c7f666330f9..f87bfb03d910e 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ 
b/airflow/providers/weaviate/hooks/weaviate.py @@ -26,8 +26,8 @@ import weaviate.exceptions from deprecated import deprecated from tenacity import Retrying, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt -from weaviate import Client as WeaviateClient -from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword +from weaviate import WeaviateClient +from weaviate.auth import Auth from weaviate.data.replication import ConsistencyLevel from weaviate.exceptions import ObjectAlreadyExistsException from weaviate.util import generate_uuid5 @@ -39,8 +39,11 @@ from typing import Callable, Collection, Literal import pandas as pd + from weaviate.auth import AuthCredentials from weaviate.types import UUID + from airflow.models.connection import Connection + ExitingSchemaOptions = Literal["replace", "fail", "ignore"] HTTP_RETRY_STATUS_CODE = [429, 500, 503, 504] @@ -86,19 +89,20 @@ def __init__( @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to connection form.""" - from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext - from wtforms import PasswordField + from wtforms import PasswordField, StringField return { "token": PasswordField(lazy_gettext("Weaviate API Key"), widget=BS3PasswordFieldWidget()), + "grpc_host": StringField(lazy_gettext("gRPC host"), widget=BS3TextFieldWidget()), + "grpc_port": StringField(lazy_gettext("gRPC port"), widget=BS3TextFieldWidget()), } @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field behaviour.""" return { - "hidden_fields": ["port", "schema"], "relabeling": { "login": "OIDC Username", "password": "OIDC Password", @@ -107,34 +111,47 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: def get_conn(self) -> WeaviateClient: conn = 
self.get_connection(self.conn_id) - url = conn.host - username = conn.login or "" - password = conn.password or "" + http_host = conn.host + http_port = conn.port or 8080 + extras = conn.extra_dejson + additional_headers = extras.pop("additional_headers", {}) + grpc_host = extras.pop("grpc_host", http_host) + grpc_port = extras.pop("grpc_port", 50051) + + return weaviate.connect_to_custom( + http_host=http_host, + http_port=http_port, + http_secure=False, + grpc_host=grpc_host, + grpc_port=grpc_port, + grpc_secure=False, + headers=additional_headers, + auth_credentials=self._extract_auth_credentials(conn), + ) + + def _extract_auth_credentials(self, conn: Connection) -> AuthCredentials: extras = conn.extra_dejson - access_token = extras.get("access_token", None) - refresh_token = extras.get("refresh_token", None) - expires_in = extras.get("expires_in", 60) # previously token was used as api_key(backwards compatibility) api_key = extras.get("api_key", None) or extras.get("token", None) - client_secret = extras.get("client_secret", None) - additional_headers = extras.pop("additional_headers", {}) - scope = extras.get("scope", None) or extras.get("oidc_scope", None) if api_key: - auth_client_secret: AuthApiKey | AuthBearerToken | AuthClientCredentials | AuthClientPassword = ( - AuthApiKey(api_key) - ) - elif access_token: - auth_client_secret = AuthBearerToken( - access_token, expires_in=expires_in, refresh_token=refresh_token + return Auth.api_key(api_key=api_key) + + access_token = extras.get("access_token", None) + if access_token: + refresh_token = extras.get("refresh_token", None) + expires_in = extras.get("expires_in", 60) + return Auth.bearer_token( + access_token=access_token, expires_in=expires_in, refresh_token=refresh_token ) - elif client_secret: - auth_client_secret = AuthClientCredentials(client_secret=client_secret, scope=scope) - else: - auth_client_secret = AuthClientPassword(username=username, password=password, scope=scope) - return WeaviateClient( 
- url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers - ) + client_secret = extras.get("client_secret", None) + if client_secret: + scope = extras.get("scope", None) or extras.get("oidc_scope", None) + return Auth.client_credentials(client_secret=client_secret, scope=scope) + + username = conn.login or "" + password = conn.password or "" + return Auth.client_password(username=username, password=password, scope=scope) @cached_property def conn(self) -> WeaviateClient: From d0a47baa86cd680ab36c1e9da1bf38d5125814cd Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 13 Jun 2024 15:48:35 +0800 Subject: [PATCH 03/52] feat(providers/weaviate): add http_secure and grpc_secure --- airflow/providers/weaviate/hooks/weaviate.py | 27 ++++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index f87bfb03d910e..f53affe3f4f0b 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -91,18 +91,23 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext - from wtforms import PasswordField, StringField + from wtforms import BooleanField, PasswordField, StringField return { + "http_secure": BooleanField(lazy_gettext("Use https"), default=False), "token": PasswordField(lazy_gettext("Weaviate API Key"), widget=BS3PasswordFieldWidget()), "grpc_host": StringField(lazy_gettext("gRPC host"), widget=BS3TextFieldWidget()), "grpc_port": StringField(lazy_gettext("gRPC port"), widget=BS3TextFieldWidget()), + "grcp_secure": BooleanField( + lazy_gettext("Use a secure channel for the underlying gRPC API"), default=False + ), } @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field 
behaviour.""" return { + "hidden_fields": ["schema"], "relabeling": { "login": "OIDC Username", "password": "OIDC Password", @@ -111,21 +116,15 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: def get_conn(self) -> WeaviateClient: conn = self.get_connection(self.conn_id) - http_host = conn.host - http_port = conn.port or 8080 extras = conn.extra_dejson - additional_headers = extras.pop("additional_headers", {}) - grpc_host = extras.pop("grpc_host", http_host) - grpc_port = extras.pop("grpc_port", 50051) - return weaviate.connect_to_custom( - http_host=http_host, - http_port=http_port, - http_secure=False, - grpc_host=grpc_host, - grpc_port=grpc_port, - grpc_secure=False, - headers=additional_headers, + http_host=conn.host, + http_port=conn.port or 8080, + http_secure=extras.pop("http_secure", False), + grpc_host=extras.pop("grpc_host", conn.host), + grpc_port=extras.pop("grpc_port", 50051), + grpc_secure=extras.pop("grcp_secure", False), + headers=extras.pop("additional_headers", {}), auth_credentials=self._extract_auth_credentials(conn), ) From 41f25d5fcdf24a4350cb839d9780fd7e900cfde1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 13 Jun 2024 16:14:01 +0800 Subject: [PATCH 04/52] feat(providers/weaviate): migrate test_connections to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index f53affe3f4f0b..34066e61728f7 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -169,7 +169,7 @@ def get_client(self) -> WeaviateClient: def test_connection(self) -> tuple[bool, str]: try: client = self.conn - client.schema.get() + client.collections.list_all() return True, "Connection established!" 
except Exception as e: self.log.error("Error testing Weaviate connection: %s", e) From c366d5474845c4f5fb4594ec092e74073863ac9e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 13 Jun 2024 16:31:31 +0800 Subject: [PATCH 05/52] feat(providers/weaviate): migrate create_class to create_collection --- airflow/providers/weaviate/hooks/weaviate.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 34066e61728f7..1e8ef9d2fb2f1 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -40,6 +40,8 @@ import pandas as pd from weaviate.auth import AuthCredentials + from weaviate.collections.classes.internal import References + from weaviate.collections.classes.types import Properties from weaviate.types import UUID from airflow.models.connection import Connection @@ -175,10 +177,10 @@ def test_connection(self) -> tuple[bool, str]: self.log.error("Error testing Weaviate connection: %s", e) return False, str(e) - def create_class(self, class_json: dict[str, Any]) -> None: - """Create a new class.""" + def create_collection(self, name: str, **kwargs) -> Collection[Properties, References]: + """Create a new collection.""" client = self.conn - client.schema.create_class(class_json) + return client.collections.create(name=name, **kwargs) @retry( reraise=True, From cf282b08ac3f864ebe81b9f97f0ebce764fa191a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 13 Jun 2024 16:47:43 +0800 Subject: [PATCH 06/52] feat(providers/weaviate): migrate get_schema to get_collection_configuraiton --- airflow/providers/weaviate/hooks/weaviate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 1e8ef9d2fb2f1..442f5a0496197 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ 
b/airflow/providers/weaviate/hooks/weaviate.py @@ -40,6 +40,7 @@ import pandas as pd from weaviate.auth import AuthCredentials + from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple from weaviate.collections.classes.internal import References from weaviate.collections.classes.types import Properties from weaviate.types import UUID @@ -223,14 +224,13 @@ def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) ), ) - def get_schema(self, class_name: str | None = None): - """Get the schema from Weaviate. + def get_collection_configuraiton(self, collection_name: str) -> CollectionConfig | CollectionConfigSimple: + """Get the collection configuration from Weaviate. - :param class_name: The class for which to return the schema. If NOT provided the whole schema is - returned, otherwise only the schema of this class is returned. By default None. + :param collection_name: The collection for which to return the collection configuration. """ client = self.get_client() - return client.schema.get(class_name) + return client.collections.get(collection_name).config.get() def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") -> list[str] | None: """Delete all or specific classes if class_names are provided. @@ -241,6 +241,7 @@ def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") - :return: if `if_error=continue` return list of classes which we failed to delete. if `if_error=stop` returns None. 
""" + # TODO: migrate to v4 API client = self.get_client() class_names = [class_names] if class_names and isinstance(class_names, str) else class_names From 63dd7a68bbbda2fda3b6564d44cea7b43778e50e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 13 Jun 2024 17:11:45 +0800 Subject: [PATCH 07/52] feat(providers/weaviate): migrate delete_classes to delete_collections --- airflow/providers/weaviate/hooks/weaviate.py | 29 +++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 442f5a0496197..850914b532060 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -232,21 +232,24 @@ def get_collection_configuraiton(self, collection_name: str) -> CollectionConfig client = self.get_client() return client.collections.get(collection_name).config.get() - def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") -> list[str] | None: - """Delete all or specific classes if class_names are provided. + def delete_collections( + self, collection_names: list[str] | str, if_error: str = "stop" + ) -> list[str] | None: + """Delete all or specific collections if collection_names are provided. - :param class_names: list of class names to be deleted. - :param if_error: define the actions to be taken if there is an error while deleting a class, possible + :param collection_names: list of collection names to be deleted. + :param if_error: define the actions to be taken if there is an error while deleting a collection, possible options are `stop` and `continue` - :return: if `if_error=continue` return list of classes which we failed to delete. + :return: if `if_error=continue` return list of collections which we failed to delete. if `if_error=stop` returns None. 
""" - # TODO: migrate to v4 API client = self.get_client() - class_names = [class_names] if class_names and isinstance(class_names, str) else class_names + collection_names = ( + [collection_names] if collection_names and isinstance(collection_names, str) else collection_names + ) - failed_class_list = [] - for class_name in class_names: + failed_collection_list = [] + for collection_name in collection_names: try: for attempt in Retrying( stop=stop_after_attempt(3), @@ -256,17 +259,17 @@ def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") - ), ): with attempt: - print(attempt) - client.schema.delete_class(class_name) + self.log.info(attempt) + client.collections.delete(collection_name) except Exception as e: if if_error == "continue": self.log.error(e) - failed_class_list.append(class_name) + failed_collection_list.append(collection_name) elif if_error == "stop": raise e if if_error == "continue": - return failed_class_list + return failed_collection_list return None def delete_all_schema(self): From 814ca574c5297dd6416817a70e8ae651580841a5 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 15:32:17 +0800 Subject: [PATCH 08/52] feat(providers/weaviate): migrate query_with_vector to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 850914b532060..841ee12cafdc1 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -41,8 +41,8 @@ import pandas as pd from weaviate.auth import AuthCredentials from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple - from weaviate.collections.classes.internal import References - from weaviate.collections.classes.types import Properties + from weaviate.collections.classes.internal import QuerySearchReturnType + from 
weaviate.collections.classes.types import Properties, References, TProperties, TReferences from weaviate.types import UUID from airflow.models.connection import Connection @@ -512,27 +512,25 @@ def _process_batch_errors( def query_with_vector( self, embeddings: list[float], - class_name: str, + collection_name: str, *properties: list[str], certainty: float = 0.7, limit: int = 1, - ) -> dict[str, dict[Any, Any]]: + ) -> QuerySearchReturnType[Properties, References, TProperties, TReferences]: """ Query weaviate database with near vectors. This method uses a vector search using a Get query. we are using a with_near_vector to provide - weaviate with a query with vector itself. This is needed for query a Weaviate class with a custom, + weaviate with a query with vector itself. This is needed for query a Weaviate class with a custom, external vectorizer. Weaviate then converts this into a vector through the inference API (OpenAI in this particular example) and uses that vector as the basis for a vector search. 
""" client = self.conn - results: dict[str, dict[Any, Any]] = ( - client.query.get(class_name, properties[0]) - .with_near_vector({"vector": embeddings, "certainty": certainty}) - .with_limit(limit) - .do() + collection = client.collections.get(collection_name) + response = collection.query.near_vector( + near_vector=embeddings, certainty=certainty, limit=limit, return_properties=properties ) - return results + return response def query_without_vector( self, search_text: str, class_name: str, *properties: list[str], limit: int = 1 From b505d2a63e193341ca673a212418b4c5403020f0 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 15:41:31 +0800 Subject: [PATCH 09/52] feat(providers/weaviate): migrate query_with_text to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 841ee12cafdc1..534229da431ab 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -533,8 +533,8 @@ def query_with_vector( return response def query_without_vector( - self, search_text: str, class_name: str, *properties: list[str], limit: int = 1 - ) -> dict[str, dict[Any, Any]]: + self, search_text: str, collection_name: str, *properties: list[str], limit: int = 1 + ): """ Query using near text. @@ -543,13 +543,9 @@ def query_without_vector( API (OpenAI in this particular example) and uses that vector as the basis for a vector search. 
""" client = self.conn - results: dict[str, dict[Any, Any]] = ( - client.query.get(class_name, properties[0]) - .with_near_text({"concepts": [search_text]}) - .with_limit(limit) - .do() - ) - return results + collection = client.collections.get(collection_name) + response = collection.query.near_text(query=search_text, limit=limit, return_properties=properties) + return response def create_object( self, data_object: dict | str, class_name: str, **kwargs From 35aadc658cbca3866f82e792901e0c9ba40643fe Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 15:46:08 +0800 Subject: [PATCH 10/52] feat(providers/weaviate): migrate create_object to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 534229da431ab..abc006c4e0017 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -547,20 +547,19 @@ def query_without_vector( response = collection.query.near_text(query=search_text, limit=limit, return_properties=properties) return response - def create_object( - self, data_object: dict | str, class_name: str, **kwargs - ) -> str | dict[str, Any] | None: + def create_object(self, data_object: dict | str, collection_name: str, **kwargs) -> UUID | None: """Create a new object. :param data_object: Object to be added. If type is str it should be either a URL or a file. - :param class_name: Class name associated with the object given. + :param collection_name: Colletion name associated with the object given. 
:param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() """ client = self.conn + collection = client.collections.get(collection_name) # generate deterministic uuid if not provided uuid = kwargs.pop("uuid", generate_uuid5(data_object)) try: - return client.data_object.create(data_object, class_name, uuid=uuid, **kwargs) + return collection.data.insert(properties=data_object, uuid=uuid, **kwargs) except ObjectAlreadyExistsException: self.log.warning("Object with the UUID %s already exists", uuid) return None From f1cda43f1621292240f9c011b6047eab82ee2fbc Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 15:54:18 +0800 Subject: [PATCH 11/52] feat(providers/weaviate): migrate get_object to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index abc006c4e0017..4ef5438eb978a 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -41,7 +41,7 @@ import pandas as pd from weaviate.auth import AuthCredentials from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple - from weaviate.collections.classes.internal import QuerySearchReturnType + from weaviate.collections.classes.internal import QueryReturnType, QuerySearchReturnType from weaviate.collections.classes.types import Properties, References, TProperties, TReferences from weaviate.types import UUID @@ -551,7 +551,7 @@ def create_object(self, data_object: dict | str, collection_name: str, **kwargs) """Create a new object. :param data_object: Object to be added. If type is str it should be either a URL or a file. - :param collection_name: Colletion name associated with the object given. + :param collection_name: Collection name associated with the object given. 
:param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() """ client = self.conn @@ -603,14 +603,17 @@ def get_or_create_object( ) return obj - def get_object(self, **kwargs) -> dict[str, Any] | None: + def get_object( + self, collection_name: str, **kwargs + ) -> QueryReturnType[Properties, References, TProperties, TReferences] | None: """Get objects or an object from weaviate. - :param kwargs: parameters to be passed to weaviate_client.data_object.get() or - weaviate_client.data_object.get_by_id() + :param kwargs: parameters to be passed to collection.data.fetch_object_by_id() or + collection.data.fetch_objects() """ client = self.conn - return client.data_object.get(**kwargs) + collection = client.collections.get(collection_name) + return collection.query.fetch_objects(**kwargs) def get_all_objects( self, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs From 95f9db4f4e922c6690bcabbea1c67135dea32d50 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 15:56:46 +0800 Subject: [PATCH 12/52] feat(providers/weaviate): migrate delete_object to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 4ef5438eb978a..c51f740d9adc1 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -640,14 +640,15 @@ def get_all_objects( return pandas.DataFrame(all_objects) return all_objects - def delete_object(self, uuid: UUID | str, **kwargs) -> None: + def delete_object(self, collection_name: str, uuid: UUID | str) -> bool: """Delete an object from weaviate. + :param collection_name: Collection name associated with the object given. 
:param uuid: uuid of the object to be deleted - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.delete() """ client = self.conn - client.data_object.delete(uuid, **kwargs) + collection = client.collections.get(collection_name) + return collection.data.delete_by_id(uuid=uuid) def update_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None: """Update an object in weaviate. From c6b3e68c2f4b81b06147ca36cad04953c546a44e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 16:09:26 +0800 Subject: [PATCH 13/52] feat(providers/weaviate): migrate update_object to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index c51f740d9adc1..2fc8f7ebbbb65 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -650,18 +650,19 @@ def delete_object(self, collection_name: str, uuid: UUID | str) -> bool: collection = client.collections.get(collection_name) return collection.data.delete_by_id(uuid=uuid) - def update_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None: + def update_object( + self, collection_name: str, uuid: UUID | str, properties: Properties | None = None, **kwargs + ) -> None: """Update an object in weaviate. - :param data_object: The object states the fields that should be updated. Fields not specified in the - 'data_object' remain unchanged. Fields that are None will not be changed. - If type is str it should be either an URL or a file. - :param class_name: Class name associated with the object given. + :param collection_name: Collection name associated with the object given. 
:param uuid: uuid of the object to be updated - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.update() + :param properties: The properties of the object. + :param kwargs: Optional parameters to be passed to collection.data.update() """ client = self.conn - client.data_object.update(data_object, class_name, uuid, **kwargs) + collection = client.collections.get(collection_name) + collection.data.update(uuid=uuid, properties=properties, **kwargs) def replace_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None: """Replace an object in weaviate. From 52495de851287bb1c8d2608532bc995cea9c193e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 21:54:39 +0800 Subject: [PATCH 14/52] feat(providers/weaviate): migrate replace_object to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 25 +++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 2fc8f7ebbbb65..6eac426854511 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -41,7 +41,7 @@ import pandas as pd from weaviate.auth import AuthCredentials from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple - from weaviate.collections.classes.internal import QueryReturnType, QuerySearchReturnType + from weaviate.collections.classes.internal import QueryReturnType, QuerySearchReturnType, ReferenceInputs from weaviate.collections.classes.types import Properties, References, TProperties, TReferences from weaviate.types import UUID @@ -664,17 +664,24 @@ def update_object( collection = client.collections.get(collection_name) collection.data.update(uuid=uuid, properties=properties, **kwargs) - def replace_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None: + def replace_object( + self, + 
collection_name: str, + uuid: UUID | str, + properties: Properties, + references: ReferenceInputs | None = None, + **kwargs, + ) -> None: """Replace an object in weaviate. - :param data_object: The object states the fields that should be updated. Fields not specified in the - 'data_object' will be set to None. If type is str it should be either an URL or a file. - :param class_name: Class name associated with the object given. - :param uuid: uuid of the object to be replaced - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.replace() + :param collection_name: Collection name associated with the object given. + :param uuid: uuid of the object to be updated + :param properties: The properties of the object. + :param references: Any references to other objects in Weaviate. + :param kwargs: Optional parameters to be passed to collection.data.replace() """ - client = self.conn - client.data_object.replace(data_object, class_name, uuid, **kwargs) + collection = self.get_collection(collection_name) + collection.data.replace(uuid=uuid, properties=properties, references=references, **kwargs) def validate_object(self, data_object: dict | str, class_name: str, **kwargs): """Validate an object in weaviate. 
From 2ade84d1600a655fea073653cee359a6f78b50eb Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 21:55:39 +0800 Subject: [PATCH 15/52] feat(providers/weaviate): migrate object_exists to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 6eac426854511..a01f3e82f9490 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -693,14 +693,15 @@ def validate_object(self, data_object: dict | str, class_name: str, **kwargs): client = self.conn client.data_object.validate(data_object, class_name, **kwargs) - def object_exists(self, uuid: str | UUID, **kwargs) -> bool: + def object_exists(self, collection_name: str, uuid: str | UUID) -> bool: """Check if an object exists in weaviate. + :param collection_name: Collection name associated with the object given. :param uuid: The UUID of the object that may or may not exist within Weaviate. - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.exists() """ client = self.conn - return client.data_object.exists(uuid, **kwargs) + collection = client.collections.get(collection_name) + return collection.data.exists(uuid=uuid) def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per_object: int = 5): """Delete multiple objects. 
From 5b84c853cd4d1e238330d072a83036dd95042c29 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 16:19:29 +0800 Subject: [PATCH 16/52] feat(providers/weaviate): migrate _generate_uuids to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index a01f3e82f9490..46513507aff8b 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -736,7 +736,7 @@ def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per def _generate_uuids( self, df: pd.DataFrame, - class_name: str, + collection_name: str, unique_columns: list[str], vector_column: str | None = None, uuid_column: str | None = None, @@ -748,7 +748,7 @@ def _generate_uuids( The function can potentially ingest the same data multiple times with different UUIDs. :param df: A dataframe with data to generate a UUID from. - :param class_name: The name of the class use as part of the uuid namespace. + :param collection_name: The name of the collection use as part of the uuid namespace. :param uuid_column: Name of the column to create. Default is 'id'. :param unique_columns: A list of columns to use for UUID generation. By default, all columns except vector_column will be used. 
@@ -779,7 +779,7 @@ def _generate_uuids( df[uuid_column] = ( df[unique_columns] .drop(columns=[vector_column], inplace=False, errors="ignore") - .apply(lambda row: generate_uuid5(identifier=row.to_dict(), namespace=class_name), axis=1) + .apply(lambda row: generate_uuid5(identifier=row.to_dict(), namespace=collection_name), axis=1) ) return df, uuid_column From 84a1ce30db99dd94366e103ba8983a337ed484ff Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 16:32:18 +0800 Subject: [PATCH 17/52] refactor(providers/weaviate): extract common get collection logic --- airflow/providers/weaviate/hooks/weaviate.py | 23 +++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 46513507aff8b..03491e18081e8 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -183,6 +183,14 @@ def create_collection(self, name: str, **kwargs) -> Collection[Properties, Refer client = self.conn return client.collections.create(name=name, **kwargs) + def get_collection(self, name: str) -> Collection[Properties, References]: + """Get a collection by name. + + :param name: The name of the collection to get. + """ + client = self.conn + return client.collections.get(name) + @retry( reraise=True, stop=stop_after_attempt(3), @@ -554,8 +562,7 @@ def create_object(self, data_object: dict | str, collection_name: str, **kwargs) :param collection_name: Collection name associated with the object given. 
:param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() """ - client = self.conn - collection = client.collections.get(collection_name) + collection = self.get_collection(collection_name) # generate deterministic uuid if not provided uuid = kwargs.pop("uuid", generate_uuid5(data_object)) try: @@ -611,8 +618,7 @@ def get_object( :param kwargs: parameters to be passed to collection.data.fetch_object_by_id() or collection.data.fetch_objects() """ - client = self.conn - collection = client.collections.get(collection_name) + collection = self.get_collection(collection_name) return collection.query.fetch_objects(**kwargs) def get_all_objects( @@ -646,8 +652,7 @@ def delete_object(self, collection_name: str, uuid: UUID | str) -> bool: :param collection_name: Collection name associated with the object given. :param uuid: uuid of the object to be deleted """ - client = self.conn - collection = client.collections.get(collection_name) + collection = self.get_collection(collection_name) return collection.data.delete_by_id(uuid=uuid) def update_object( @@ -660,8 +665,7 @@ def update_object( :param properties: The properties of the object. :param kwargs: Optional parameters to be passed to collection.data.update() """ - client = self.conn - collection = client.collections.get(collection_name) + collection = self.get_collection(collection_name) collection.data.update(uuid=uuid, properties=properties, **kwargs) def replace_object( @@ -699,8 +703,7 @@ def object_exists(self, collection_name: str, uuid: str | UUID) -> bool: :param collection_name: Collection name associated with the object given. :param uuid: The UUID of the object that may or may not exist within Weaviate. 
""" - client = self.conn - collection = client.collections.get(collection_name) + collection = self.get_collection(collection_name) return collection.data.exists(uuid=uuid) def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per_object: int = 5): From 45dd131ece09b16affa19608cc045ceccc2e62d8 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 17:05:14 +0800 Subject: [PATCH 18/52] feat(providers/weaviate): migrate batch_data to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 76 +++----------------- 1 file changed, 9 insertions(+), 67 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 03491e18081e8..b524f4ebfef03 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -412,84 +412,27 @@ def check_subset_of_schema(self, classes_objects: list) -> bool: def batch_data( self, - class_name: str, + collection_name: str, data: list[dict[str, Any]] | pd.DataFrame | None, - batch_config_params: dict[str, Any] | None = None, vector_col: str = "Vector", uuid_col: str = "id", retry_attempts_per_object: int = 5, tenant: str | None = None, - ) -> list: + ) -> None: """ Add multiple objects or object references at once into weaviate. - :param class_name: The name of the class that objects belongs to. + :param collection_name: The name of the collection that objects belongs to. :param data: list or dataframe of objects we want to add. - :param batch_config_params: dict of batch configuration option. - .. seealso:: `batch_config_params options `__ :param vector_col: name of the column containing the vector. :param uuid_col: Name of the column containing the UUID. :param retry_attempts_per_object: number of time to try in case of failure before giving up. :param tenant: The tenant to which the object will be added. 
""" converted_data = self._convert_dataframe_to_list(data) - total_results = 0 - error_results = 0 - insertion_errors: list = [] - - def _process_batch_errors( - results: list, - verbose: bool = True, - ) -> None: - """ - Process the results from insert or delete batch operation and collects any errors. - - :param results: Results from the batch operation. - :param verbose: Flag to enable verbose logging. - """ - nonlocal total_results - nonlocal error_results - total_batch_results = len(results) - error_batch_results = 0 - for item in results: - if "errors" in item["result"]: - error_batch_results = error_batch_results + 1 - item_error = {"uuid": item["id"], "errors": item["result"]["errors"]} - if verbose: - self.log.info( - "Error occurred in batch process for %s with error %s", - item["id"], - item["result"]["errors"], - ) - insertion_errors.append(item_error) - if verbose: - total_results = total_results + (total_batch_results - error_batch_results) - error_results = error_results + error_batch_results - - self.log.info( - "Total Objects %s / Objects %s successfully inserted and Objects %s had errors.", - len(converted_data), - total_results, - error_results, - ) - client = self.conn - if not batch_config_params: - batch_config_params = {} - - # configuration for context manager for __exit__ method to callback on errors for weaviate - # batch ingestion. 
- if not batch_config_params.get("callback"): - batch_config_params["callback"] = _process_batch_errors - - if not batch_config_params.get("timeout_retries"): - batch_config_params["timeout_retries"] = 5 - - if not batch_config_params.get("connection_error_retries"): - batch_config_params["connection_error_retries"] = 5 - - client.batch.configure(**batch_config_params) - with client.batch as batch: + collection = self.get_collection(collection_name) + with collection.batch.dynamic() as batch: # Batch import all data for data_obj in converted_data: for attempt in Retrying( @@ -507,15 +450,14 @@ def _process_batch_errors( attempt.retry_state.attempt_number, uuid, ) - batch.add_data_object( - data_object=data_obj, - class_name=class_name, - vector=vector, + batch.add_object( + collection=collection_name, + properties=data_obj, uuid=uuid, + vector=vector, tenant=tenant, ) self.log.debug("Inserted object with uuid: %s into batch", uuid) - return insertion_errors def query_with_vector( self, From 0f1d956c902d90bad1a9cbc1bd8d97fe1e7cb6f8 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 17:20:49 +0800 Subject: [PATCH 19/52] feat(providers/weaviate): migrate get_or_create_object to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 31 +++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index b524f4ebfef03..9acf58ee5ee4c 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -515,40 +515,29 @@ def create_object(self, data_object: dict | str, collection_name: str, **kwargs) def get_or_create_object( self, + collection_name, data_object: dict | str | None = None, - class_name: str | None = None, vector: Sequence | None = None, - consistency_level: ConsistencyLevel | None = None, - tenant: str | None = None, **kwargs, - ) -> str | dict[str, Any] | None: + ) -> 
QueryReturnType[Properties, References, TProperties, TReferences] | None | UUID: """Get or Create a new object. - Returns the object if already exists + Returns the object if already exists, return UUID if not + :param collection_name: Collection name associated with the object given.. :param data_object: Object to be added. If type is str it should be either a URL or a file. This is required to create a new object. - :param class_name: Class name associated with the object given. This is required to create a new object. :param vector: Vector associated with the object given. This argument is only used when creating object. - :param consistency_level: Consistency level to be used. Applies to both create and get operations. - :param tenant: Tenant to be used. Applies to both create and get operations. - :param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() and - weaviate_client.data_object.get() + :param kwargs: parameters to be passed to collection.data.fetch_object_by_id() or + collection.data.fetch_objects() """ - obj = self.get_object( - class_name=class_name, consistency_level=consistency_level, tenant=tenant, **kwargs - ) + obj = self.get_object(collection_name=collection_name, **kwargs) if not obj: - if not (data_object and class_name): - raise ValueError("data_object and class_name are required to create a new object") + if not (data_object and collection_name): + raise ValueError("data_object and collection are required to create a new object") uuid = kwargs.pop("uuid", generate_uuid5(data_object)) return self.create_object( - data_object, - class_name, - vector=vector, - uuid=uuid, - consistency_level=consistency_level, - tenant=tenant, + data_object=data_object, collection_name=collection_name, uuid=uuid, vector=vector, **kwargs ) return obj From 9539b464aee42a8a0d3158d5040b1a07d04634a6 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 17:29:58 +0800 Subject: [PATCH 20/52] feat(providers/weaviate): migrate 
_delete_objects to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 9acf58ee5ee4c..64de4a5bd31de 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -637,13 +637,13 @@ def object_exists(self, collection_name: str, uuid: str | UUID) -> bool: collection = self.get_collection(collection_name) return collection.data.exists(uuid=uuid) - def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per_object: int = 5): + def _delete_objects(self, uuids: Collection, collection_name: str, retry_attempts_per_object: int = 5): """Delete multiple objects. Helper function for `create_or_replace_objects()` to delete multiple objects. :param uuids: Collection of uuids. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param collection_name: Name of the collection in Weaviate schema where data is to be ingested. :param retry_attempts_per_object: number of times to try in case of failure before giving up. 
""" for uuid in uuids: @@ -656,7 +656,7 @@ def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per ): with attempt: try: - self.delete_object(uuid=uuid, class_name=class_name) + self.delete_object(uuid=uuid, collection_name=collection_name) self.log.debug("Deleted object with uuid %s", uuid) except weaviate.exceptions.UnexpectedStatusCodeException as e: if e.status_code == 404: From 52af6ca22cd3cc09f94db1c785c5871663307cc4 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 17:54:15 +0800 Subject: [PATCH 21/52] feat(providers/weaviate): migrate _delete_all_documents_objects to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 61 ++++++-------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 64de4a5bd31de..ae7a7c069661b 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -28,7 +28,7 @@ from tenacity import Retrying, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt from weaviate import WeaviateClient from weaviate.auth import Auth -from weaviate.data.replication import ConsistencyLevel +from weaviate.classes.query import Filter from weaviate.exceptions import ObjectAlreadyExistsException from weaviate.util import generate_uuid5 @@ -822,66 +822,39 @@ def _delete_all_documents_objects( self, document_keys: list[str], document_column: str, - class_name: str, + collection_name: str, total_objects_count: int = 1, batch_delete_error: list | None = None, - tenant: str | None = None, - batch_config_params: dict[str, Any] | None = None, verbose: bool = False, - ): + ) -> list: """Delete all object that belong to list of documents. :param document_keys: list of unique documents identifiers. :param document_column: Column in DataFrame that identifying source document. 
- :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param collection_name: Name of the collection in Weaviate schema where data is to be ingested. :param total_objects_count: total number of objects to delete, needed as max limit on one delete query is 10,000, if we have more objects to delete we need to run query multiple times. :param batch_delete_error: list to hold errors while inserting. - :param tenant: The tenant to which the object will be added. - :param batch_config_params: Additional parameters for Weaviate batch configuration. :param verbose: Flag to enable verbose output during the ingestion process. """ batch_delete_error = batch_delete_error or [] - if not batch_config_params: - batch_config_params = {} - # This limit is imposed by Weavaite database MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS = 10000 - self.conn.batch.configure(**batch_config_params) - with self.conn.batch as batch: - # ConsistencyLevel.ALL is essential here to guarantee complete deletion of objects - # across all nodes. Maintaining this level ensures data integrity, preventing - # irrelevant objects from providing misleading context for LLM models. 
- batch.consistency_level = ConsistencyLevel.ALL - while total_objects_count > 0: - document_objects = batch.delete_objects( - class_name=class_name, - where={ - "operator": "Or", - "operands": [ - { - "path": [document_column], - "operator": "Equal", - "valueText": key, - } - for key in document_keys - ], - }, - output="verbose", - dry_run=False, - tenant=tenant, - ) - total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS - matched_objects = document_objects["results"]["matches"] - batch_delete_error = [ - {"uuid": obj["id"]} - for obj in document_objects["results"]["objects"] - if "error" in obj["status"] - ] - if verbose: - self.log.info("Deleted %s Objects", matched_objects) + collection = self.get_collection(collection_name) + document_objects = collection.data.delete_many( + where=Filter.any_of([Filter.by_property(document_column).equal(key) for key in document_keys]), + verbase=verbose, + dru_run=False, + ) + total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS + matched_objects = document_objects["results"]["matches"] + batch_delete_error = [ + {"uuid": obj["id"]} for obj in document_objects["results"]["objects"] if "error" in obj["status"] + ] + if verbose: + self.log.info("Deleted %s Objects", matched_objects) return batch_delete_error From 780305b8342331de785b2167e41ed538c1f17a6b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:06:28 +0800 Subject: [PATCH 22/52] feat(providers/weaviate): migrate _get_documents_to_uuid_map to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 29 ++++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index ae7a7c069661b..5539639a2e1a8 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -723,7 +723,7 @@ def _get_documents_to_uuid_map( data: pd.DataFrame, 
document_column: str, uuid_column: str, - class_name: str, + collection_name: str, offset: int = 0, limit: int = 2000, ) -> dict[str, set]: @@ -731,7 +731,7 @@ def _get_documents_to_uuid_map( :param data: A single pandas DataFrame. :param document_column: The name of the property to query. - :param class_name: The name of the class to query. + :param collection_name: The name of the collection to query. :param uuid_column: The name of the column containing the UUID. :param offset: pagination parameter to indicate the which object to start fetching data. :param limit: pagination param to indicate the number of records to fetch from start object. @@ -739,22 +739,15 @@ def _get_documents_to_uuid_map( documents_to_uuid: dict = {} document_keys = set(data[document_column]) while True: - data_objects = ( - self.conn.query.get(properties=[document_column], class_name=class_name) - .with_additional([uuid_column]) - .with_where( - { - "operator": "Or", - "operands": [ - {"valueText": key, "path": document_column, "operator": "Equal"} - for key in document_keys - ], - } - ) - .with_offset(offset) - .with_limit(limit) - .do()["data"]["Get"][class_name] - ) + collection = self.get_collection(collection_name) + data_objects = collection.query.fetech_objects( + filters=Filter.any_of( + [Filter.by_property(document_column).equal(key) for key in document_keys] + ), + return_properties=[document_column], + limit=limit, + offset=offset, + )["data"]["Get"][collection] if len(data_objects) == 0: break offset = offset + limit From 5e37b00482597d27fb9a9d1cf406c77c0e6bf40c Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:07:41 +0800 Subject: [PATCH 23/52] feat(providers/weaviate): migrate _get_segregated_documents to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 5539639a2e1a8..0578e63b225c4 
100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -776,21 +776,24 @@ def _prepare_document_to_uuid_map( return grouped_key_to_set def _get_segregated_documents( - self, data: pd.DataFrame, document_column: str, class_name: str, uuid_column: str + self, data: pd.DataFrame, document_column: str, collection_name: str, uuid_column: str ) -> tuple[dict[str, set], set, set, set]: """ Segregate documents into changed, unchanged and new document, when compared to Weaviate db. :param data: A single pandas DataFrame. :param document_column: The name of the property to query. - :param class_name: The name of the class to query. + :param collection_name: The name of the collection to query. :param uuid_column: The name of the column containing the UUID. """ changed_documents = set() unchanged_docs = set() new_documents = set() existing_documents_to_uuid = self._get_documents_to_uuid_map( - data=data, uuid_column=uuid_column, document_column=document_column, class_name=class_name + data=data, + uuid_column=uuid_column, + document_column=document_column, + collection_name=collection_name, ) input_documents_to_uuid = self._prepare_document_to_uuid_map( From 1c4f0ecabd7c1a8c8602fa65d31cc330186e1ee1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:08:42 +0800 Subject: [PATCH 24/52] feat(providers/weaviate): migrate create_or_replace_document_objects to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 33 +++++++++----------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 0578e63b225c4..4bbd1a7eb58ec 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -857,12 +857,11 @@ def _delete_all_documents_objects( def create_or_replace_document_objects( self, data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame], - class_name: 
str, + collection_name: str, document_column: str, existing: str = "skip", uuid_column: str | None = None, vector_column: str = "Vector", - batch_config_params: dict | None = None, tenant: str | None = None, verbose: bool = False, ): @@ -887,21 +886,20 @@ def create_or_replace_document_objects( error: raise an error if an object belonging to a existing document is tried to be created. :param data: A single pandas DataFrame or a list of dicts to be ingested. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param colleciton_name: Name of the collection in Weaviate schema where data is to be ingested. :param existing: Strategy for handling existing data: 'skip', or 'replace'. Default is 'skip'. :param document_column: Column in DataFrame that identifying source document. :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated. :param vector_column: Column with embedding vectors for pre-embedded data. - :param batch_config_params: Additional parameters for Weaviate batch configuration. :param tenant: The tenant to which the object will be added. :param verbose: Flag to enable verbose output during the ingestion process. :return: list of UUID which failed to create """ - import pandas as pd - if existing not in ["skip", "replace", "error"]: raise ValueError("Invalid parameter for 'existing'. 
Choices are 'skip', 'replace', 'error'.") + import pandas as pd + if len(data) == 0: return [] @@ -925,7 +923,7 @@ def create_or_replace_document_objects( uuid_column, ) = self._generate_uuids( df=data, - class_name=class_name, + collection_name=collection_name, unique_columns=unique_columns, vector_column=vector_column, uuid_column=uuid_column, @@ -947,7 +945,7 @@ def create_or_replace_document_objects( data=data, document_column=document_column, uuid_column=uuid_column, - class_name=class_name, + collection_name=collection_name, ) if verbose: self.log.info( @@ -984,12 +982,10 @@ def create_or_replace_document_objects( ) batch_delete_error = self._delete_all_documents_objects( document_keys=list(changed_documents), - total_objects_count=total_objects_count, document_column=document_column, - class_name=class_name, + collection_name=collection_name, + total_objects_count=total_objects_count, batch_delete_error=batch_delete_error, - tenant=tenant, - batch_config_params=batch_config_params, verbose=verbose, ) data = data[data[document_column].isin(new_documents.union(changed_documents))] @@ -998,9 +994,8 @@ def create_or_replace_document_objects( insertion_errors: list = [] if data.shape[0]: insertion_errors = self.batch_data( - class_name=class_name, + collection_name=collection_name, data=data, - batch_config_params=batch_config_params, vector_col=vector_column, uuid_col=uuid_column, tenant=tenant, @@ -1012,13 +1007,15 @@ def create_or_replace_document_objects( self.log.info("Failed to delete %s objects.", len(insertion_errors)) # Rollback object that were not created properly self._delete_objects( - [item["uuid"] for item in insertion_errors + batch_delete_error], class_name=class_name + [item["uuid"] for item in insertion_errors + batch_delete_error], + collection_name=collection_name, ) if verbose: + collection = self.get_collection(collection_name) self.log.info( - "Total objects in class %s : %s ", - class_name, - 
self.conn.query.aggregate(class_name).with_meta_count().do(), + "Total objects in collection %s : %s ", + collection_name, + collection.aggregate.over_all(total_count=True), ) return insertion_errors, batch_delete_error From 0d442a021e85fc3e53bdfb6fc2e2d7bfaf06adff Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 14 Jun 2024 17:28:43 +0800 Subject: [PATCH 25/52] feat(providers/weaviate): remove validate_object --- airflow/providers/weaviate/hooks/weaviate.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 4bbd1a7eb58ec..538895701809d 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -618,16 +618,6 @@ def replace_object( collection = self.get_collection(collection_name) collection.data.replace(uuid=uuid, properties=properties, references=references, **kwargs) - def validate_object(self, data_object: dict | str, class_name: str, **kwargs): - """Validate an object in weaviate. - - :param data_object: The object to be validated. If type is str it should be either an URL or a file. - :param class_name: Class name associated with the object given. - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.validate() - """ - client = self.conn - client.data_object.validate(data_object, class_name, **kwargs) - def object_exists(self, collection_name: str, uuid: str | UUID) -> bool: """Check if an object exists in weaviate. 
From a600004effa244aaef3c892f1f51780b60763c12 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 13 Jun 2024 15:50:39 +0800 Subject: [PATCH 26/52] refactor(providers/weaviate): remove unused retry_status_codes --- airflow/providers/weaviate/hooks/weaviate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 538895701809d..f0ceffca4cab7 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -82,7 +82,6 @@ class WeaviateHook(BaseHook): def __init__( self, conn_id: str = default_conn_name, - retry_status_codes: list[int] | None = None, *args: Any, **kwargs: Any, ) -> None: From db5bfc7b6598d359683daca946ea57374bee6845 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 19:09:17 +0800 Subject: [PATCH 27/52] refactor(providers/weaviate): remove deprecated get_client --- airflow/providers/weaviate/hooks/weaviate.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index f0ceffca4cab7..49c605d0054ed 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -24,7 +24,6 @@ import requests import weaviate.exceptions -from deprecated import deprecated from tenacity import Retrying, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt from weaviate import WeaviateClient from weaviate.auth import Auth @@ -32,7 +31,6 @@ from weaviate.exceptions import ObjectAlreadyExistsException from weaviate.util import generate_uuid5 -from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook if TYPE_CHECKING: @@ -159,15 +157,6 @@ def conn(self) -> WeaviateClient: """Returns a Weaviate client.""" return self.get_conn() - @deprecated( - reason="The `get_client` method has been renamed to 
`get_conn`", - category=AirflowProviderDeprecationWarning, - ) - def get_client(self) -> WeaviateClient: - """Return a Weaviate client.""" - # Keeping this for backwards compatibility - return self.conn - def test_connection(self) -> tuple[bool, str]: try: client = self.conn @@ -236,7 +225,7 @@ def get_collection_configuraiton(self, collection_name: str) -> CollectionConfig :param collection_name: The collection for which to return the collection configuration. """ - client = self.get_client() + client = self.get_conn() return client.collections.get(collection_name).config.get() def delete_collections( @@ -250,7 +239,7 @@ def delete_collections( :return: if `if_error=continue` return list of collections which we failed to delete. if `if_error=stop` returns None. """ - client = self.get_client() + client = self.get_conn() collection_names = ( [collection_names] if collection_names and isinstance(collection_names, str) else collection_names ) From 38494171cb3f5c6c70fdc7164904113f5aba2d30 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:25:43 +0800 Subject: [PATCH 28/52] feat(providers/weaviate): migrate update_config to v4 API --- airflow/providers/weaviate/hooks/weaviate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 49c605d0054ed..7cc34bd36dc67 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -273,10 +273,10 @@ def delete_all_schema(self): client = self.get_client() return client.schema.delete_all() - def update_config(self, class_name: str, config: dict): - """Update a schema configuration for a specific class.""" - client = self.get_client() - client.schema.update_config(class_name=class_name, config=config) + def update_config(self, collection_name: str, **kwargs) -> None: + """Update the collection definition.""" + collection = 
self.get_collection(collection_name) + collection.config.update(**kwargs) def create_or_replace_classes( self, schema_json: dict[str, Any] | str, existing: ExitingSchemaOptions = "ignore" From 9f3473510b0e996bbb80c1c59459bd92faefdc4e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:32:26 +0800 Subject: [PATCH 29/52] feat(providers/weaviate): remove create_schema and delete_all_schema as there's no v4 API counterpart --- airflow/providers/weaviate/hooks/weaviate.py | 24 -------------------- 1 file changed, 24 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 7cc34bd36dc67..e0543f83196df 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -179,25 +179,6 @@ def get_collection(self, name: str) -> Collection[Properties, References]: client = self.conn return client.collections.get(name) - @retry( - reraise=True, - stop=stop_after_attempt(3), - retry=( - retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) - | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) - ), - ) - def create_schema(self, schema_json: dict[str, Any] | str) -> None: - """ - Create a new Schema. - - Instead of adding classes one by one , you can upload a full schema in JSON format at once. - - :param schema_json: Schema as a Python dict or the path to a JSON file, or the URL of a JSON file. - """ - client = self.conn - client.schema.create(schema_json) - @staticmethod def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) -> list[dict[str, Any]]: """Convert dataframe to list of dicts. 
@@ -268,11 +249,6 @@ def delete_collections( return failed_collection_list return None - def delete_all_schema(self): - """Remove the entire schema from the Weaviate instance and all data associated with it.""" - client = self.get_client() - return client.schema.delete_all() - def update_config(self, collection_name: str, **kwargs) -> None: """Update the collection definition.""" collection = self.get_collection(collection_name) From fd595acb4ac00caef40c2bc59094a852eeaec706 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:35:23 +0800 Subject: [PATCH 30/52] feat(providers/weaviate): remove create_or_replace_classes as there's no v4 API counterpart --- airflow/providers/weaviate/hooks/weaviate.py | 32 -------------------- 1 file changed, 32 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index e0543f83196df..29d51356bbf49 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -254,38 +254,6 @@ def update_config(self, collection_name: str, **kwargs) -> None: collection = self.get_collection(collection_name) collection.config.update(**kwargs) - def create_or_replace_classes( - self, schema_json: dict[str, Any] | str, existing: ExitingSchemaOptions = "ignore" - ): - """ - Create or replace the classes in schema of Weaviate database. - - :param schema_json: Json containing the schema. Format {"class_name": "class_dict"} - .. seealso:: `example of class_dict `_. - :param existing: Options to handle the case when the classes exist, possible options - 'replace', 'fail', 'ignore'. 
- """ - existing_schema_options = ["replace", "fail", "ignore"] - if existing not in existing_schema_options: - raise ValueError(f"Param 'existing' should be one of the {existing_schema_options} values.") - if isinstance(schema_json, str): - schema_json = cast(dict, json.load(open(schema_json))) - set__exiting_classes = {class_object["class"] for class_object in self.get_schema()["classes"]} - set__to_be_added_classes = {key for key, _ in schema_json.items()} - intersection_classes = set__exiting_classes.intersection(set__to_be_added_classes) - classes_to_create = set() - if existing == "fail" and intersection_classes: - raise ValueError(f"Trying to create class {intersection_classes} but this class already exists.") - elif existing == "ignore": - classes_to_create = set__to_be_added_classes - set__exiting_classes - elif existing == "replace": - error_list = self.delete_classes(class_names=list(intersection_classes)) - if error_list: - raise ValueError(error_list) - classes_to_create = intersection_classes.union(set__to_be_added_classes) - classes_to_create_list = [schema_json[item] for item in sorted(list(classes_to_create))] - self.create_schema({"classes": classes_to_create_list}) - def _compare_schema_subset(self, subset_object: Any, superset_object: Any) -> bool: """ Recursively check if requested subset_object is a subset of the superset_object. 
From 060c92ee3d773692a985d472ace8a6d7032f33ef Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:40:14 +0800 Subject: [PATCH 31/52] feat(providers/weaviate): remove _compare_schema_subset, _convert_properties_to_dict, check_subset_of_schema as there's no v4 API counterpart --- airflow/providers/weaviate/hooks/weaviate.py | 88 -------------------- 1 file changed, 88 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 29d51356bbf49..d705b464b3a0c 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -254,94 +254,6 @@ def update_config(self, collection_name: str, **kwargs) -> None: collection = self.get_collection(collection_name) collection.config.update(**kwargs) - def _compare_schema_subset(self, subset_object: Any, superset_object: Any) -> bool: - """ - Recursively check if requested subset_object is a subset of the superset_object. - - Example 1: - superset_object = {"a": {"b": [1, 2, 3], "c": "d"}} - subset_object = {"a": {"c": "d"}} - _compare_schema_subset(subset_object, superset_object) # will result in True - - superset_object = {"a": {"b": [1, 2, 3], "c": "d"}} - subset_object = {"a": {"d": "e"}} - _compare_schema_subset(subset_object, superset_object) # will result in False - - :param subset_object: The object to be checked - :param superset_object: The object to check against - """ - # Direct equality check - if subset_object == superset_object: - return True - - # Type mismatch early return - if type(subset_object) != type(superset_object): - return False - - # Dictionary comparison - if isinstance(subset_object, dict): - for k, v in subset_object.items(): - if (k not in superset_object) or (not self._compare_schema_subset(v, superset_object[k])): - return False - return True - - # List or Tuple comparison - if isinstance(subset_object, (list, tuple)): - for sub, sup in zip(subset_object, superset_object): - if 
len(subset_object) > len(superset_object) or not self._compare_schema_subset(sub, sup): - return False - return True - - # Default case for non-matching types or unsupported types - return False - - @staticmethod - def _convert_properties_to_dict(classes_objects, key_property: str = "name"): - """ - Convert list of class properties into dict by using a `key_property` as key. - - This is done to avoid class properties comparison as list of properties. - - Case 1: - A = [1, 2, 3] - B = [1, 2] - When comparing list we check for the length, but it's not suitable for subset check. - - Case 2: - A = [1, 2, 3] - B = [1, 3, 2] - When we compare two lists, we compare item 1 of list A with item 1 of list B and - pass if the two are same, but there can be scenarios when the properties are not in same order. - """ - for cls in classes_objects: - cls["properties"] = {p[key_property]: p for p in cls["properties"]} - return classes_objects - - def check_subset_of_schema(self, classes_objects: list) -> bool: - """Check if the class_objects is a subset of existing schema. - - Note - weaviate client's `contains()` don't handle the class properties mismatch, if you want to - compare `Class A` with `Class B` they must have exactly same properties. If `Class A` has fewer - numbers of properties than Class B, `contains()` will result in False. - - .. seealso:: `contains `_. - """ - # When the class properties are not in same order or not the same length. We convert them to dicts - # with property `name` as the key. This way we ensure, the subset is checked. 
- - classes_objects = self._convert_properties_to_dict(classes_objects) - exiting_classes_list = self._convert_properties_to_dict(self.get_schema()["classes"]) - - exiting_classes = {cls["class"]: cls for cls in exiting_classes_list} - exiting_classes_set = set(exiting_classes.keys()) - input_classes_set = {cls["class"] for cls in classes_objects} - if not input_classes_set.issubset(exiting_classes_set): - return False - for cls in classes_objects: - if not self._compare_schema_subset(cls, exiting_classes[cls["class"]]): - return False - return True - def batch_data( self, collection_name: str, From 32f8505d3ab850538d7d3acc3002b1d3866a2830 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:44:44 +0800 Subject: [PATCH 32/52] feat(providers/weaviate): migrate WeaviateIngestOperator to v4 API --- airflow/providers/weaviate/operators/weaviate.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py index 8a26ee5bfbed8..0b7d3aa2c700e 100644 --- a/airflow/providers/weaviate/operators/weaviate.py +++ b/airflow/providers/weaviate/operators/weaviate.py @@ -43,11 +43,10 @@ class WeaviateIngestOperator(BaseOperator): custom vectors and store them in the Weaviate class. :param conn_id: The Weaviate connection. - :param class_name: The Weaviate class to be used for storing the data objects into. + :param collection: The Weaviate collection to be used for storing the data objects into. :param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate embeddings on (or provides custom vectors) and store them in the Weaviate class. :param vector_col: key/column name in which the vectors are stored. - :param batch_params: Additional parameters for Weaviate batch configuration. :param hook_params: Optional config params to be passed to the underlying hook. Should match the desired hook constructor params. 
:param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on @@ -59,25 +58,23 @@ class WeaviateIngestOperator(BaseOperator): def __init__( self, conn_id: str, - class_name: str, + collection_name: str, input_data: list[dict[str, Any]] | pd.DataFrame | None = None, vector_col: str = "Vector", uuid_column: str = "id", tenant: str | None = None, - batch_params: dict | None = None, hook_params: dict | None = None, input_json: list[dict[str, Any]] | pd.DataFrame | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.class_name = class_name + self.collection_name = collection_name self.conn_id = conn_id self.vector_col = vector_col self.input_json = input_json self.uuid_column = uuid_column self.tenant = tenant self.input_data = input_data - self.batch_params = batch_params or {} self.hook_params = hook_params or {} if (self.input_data is None) and (input_json is not None): @@ -100,9 +97,8 @@ def execute(self, context: Context) -> list: self.log.debug("Input data: %s", self.input_data) insertion_errors: list = [] self.hook.batch_data( - class_name=self.class_name, + collection_name=self.collection_name, data=self.input_data, - batch_config_params=self.batch_params, vector_col=self.vector_col, uuid_col=self.uuid_column, tenant=self.tenant, From bdff77f9279b6489261f0a8f032056e7db131bbd Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 17 Jun 2024 22:45:36 +0800 Subject: [PATCH 33/52] feat(providers/weaviate): migrate WeaviateDocumentIngestOperator to v4 API --- airflow/providers/weaviate/operators/weaviate.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py index 0b7d3aa2c700e..af2a25efa5f1c 100644 --- a/airflow/providers/weaviate/operators/weaviate.py +++ b/airflow/providers/weaviate/operators/weaviate.py @@ -128,12 +128,11 @@ class WeaviateDocumentIngestOperator(BaseOperator): error: 
raise an error if an object belonging to a existing document is tried to be created. :param data: A single pandas DataFrame or a list of dicts to be ingested. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param collection_name: Name of the collection in Weaviate schema where data is to be ingested. :param existing: Strategy for handling existing data: 'skip', or 'replace'. Default is 'skip'. :param document_column: Column in DataFrame that identifying source document. :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated. :param vector_column: Column with embedding vectors for pre-embedded data. - :param batch_config_params: Additional parameters for Weaviate batch configuration. :param tenant: The tenant to which the object will be added. :param verbose: Flag to enable verbose output during the ingestion process. :param hook_params: Optional config params to be passed to the underlying hook. @@ -146,12 +145,11 @@ def __init__( self, conn_id: str, input_data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame], - class_name: str, + collection: str, document_column: str, existing: str = "skip", uuid_column: str = "id", vector_col: str = "Vector", - batch_config_params: dict | None = None, tenant: str | None = None, verbose: bool = False, hook_params: dict | None = None, @@ -160,12 +158,11 @@ def __init__( super().__init__(**kwargs) self.conn_id = conn_id self.input_data = input_data - self.class_name = class_name + self.collection = collection self.document_column = document_column self.existing = existing self.uuid_column = uuid_column self.vector_col = vector_col - self.batch_config_params = batch_config_params self.tenant = tenant self.verbose = verbose self.hook_params = hook_params or {} @@ -184,12 +181,11 @@ def execute(self, context: Context) -> list: self.log.debug("Total input objects : %s", len(self.input_data)) insertion_errors = 
self.hook.create_or_replace_document_objects( data=self.input_data, - class_name=self.class_name, + collection_name=self.collection_name, document_column=self.document_column, existing=self.existing, uuid_column=self.uuid_column, vector_column=self.vector_col, - batch_config_params=self.batch_config_params, tenant=self.tenant, verbose=self.verbose, ) From d341df77908d15366b1995711820b572dcfcb2a8 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 18 Jun 2024 15:53:46 +0800 Subject: [PATCH 34/52] fix(providers/weaviate): fix connect default port --- airflow/providers/weaviate/hooks/weaviate.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index d705b464b3a0c..c83503e8111cf 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, cast import requests +import weaviate import weaviate.exceptions from tenacity import Retrying, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt from weaviate import WeaviateClient @@ -117,13 +118,15 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: def get_conn(self) -> WeaviateClient: conn = self.get_connection(self.conn_id) extras = conn.extra_dejson + http_secure = extras.pop("http_secure", False) + grpc_secure = extras.pop("grcp_secure", False) return weaviate.connect_to_custom( http_host=conn.host, - http_port=conn.port or 8080, - http_secure=extras.pop("http_secure", False), + http_port=conn.port or 443 if http_secure else 80, + http_secure=http_secure, grpc_host=extras.pop("grpc_host", conn.host), - grpc_port=extras.pop("grpc_port", 50051), - grpc_secure=extras.pop("grcp_secure", False), + grpc_port=extras.pop("grpc_port", 443 if grpc_secure else 80), + grpc_secure=grpc_secure, headers=extras.pop("additional_headers", {}), 
auth_credentials=self._extract_auth_credentials(conn), ) @@ -143,9 +146,9 @@ def _extract_auth_credentials(self, conn: Connection) -> AuthCredentials: access_token=access_token, expires_in=expires_in, refresh_token=refresh_token ) + scope = extras.get("scope", None) or extras.get("oidc_scope", None) client_secret = extras.get("client_secret", None) if client_secret: - scope = extras.get("scope", None) or extras.get("oidc_scope", None) return Auth.client_credentials(client_secret=client_secret, scope=scope) username = conn.login or "" From de825ba513b3962b0f5fbe711fa322843ecc397e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 18 Jun 2024 16:15:37 +0800 Subject: [PATCH 35/52] refactor(provider/weaviate): rename update_config as update_collection_configuration for consistency --- airflow/providers/weaviate/hooks/weaviate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index c83503e8111cf..becc0d0cf08a2 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -252,8 +252,8 @@ def delete_collections( return failed_collection_list return None - def update_config(self, collection_name: str, **kwargs) -> None: - """Update the collection definition.""" + def update_collection_configuration(self, collection_name: str, **kwargs) -> None: + """Update the collection configuration.""" collection = self.get_collection(collection_name) collection.config.update(**kwargs) From d7e51ec2ea46b3f55e76f0bbfbe6a08545d90a12 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 18 Jun 2024 16:26:52 +0800 Subject: [PATCH 36/52] fix(providers/weaviate): fix batch_data wrong parameter used --- airflow/providers/weaviate/hooks/weaviate.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 
becc0d0cf08a2..19b5b83137ee4 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -264,7 +264,7 @@ def batch_data( vector_col: str = "Vector", uuid_col: str = "id", retry_attempts_per_object: int = 5, - tenant: str | None = None, + references: ReferenceInputs | None = None, ) -> None: """ Add multiple objects or object references at once into weaviate. @@ -274,7 +274,7 @@ def batch_data( :param vector_col: name of the column containing the vector. :param uuid_col: Name of the column containing the UUID. :param retry_attempts_per_object: number of time to try in case of failure before giving up. - :param tenant: The tenant to which the object will be added. + :param references: The references of the object to be added as a dictionary. Use `wvc.Reference.to` to create the correct values in the dict. """ converted_data = self._convert_dataframe_to_list(data) @@ -298,11 +298,10 @@ def batch_data( uuid, ) batch.add_object( - collection=collection_name, properties=data_obj, + references=references, uuid=uuid, vector=vector, - tenant=tenant, ) self.log.debug("Inserted object with uuid: %s into batch", uuid) @@ -699,7 +698,6 @@ def create_or_replace_document_objects( existing: str = "skip", uuid_column: str | None = None, vector_column: str = "Vector", - tenant: str | None = None, verbose: bool = False, ): """ @@ -728,7 +726,6 @@ def create_or_replace_document_objects( :param document_column: Column in DataFrame that identifying source document. :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated. :param vector_column: Column with embedding vectors for pre-embedded data. - :param tenant: The tenant to which the object will be added. :param verbose: Flag to enable verbose output during the ingestion process. 
:return: list of UUID which failed to create """ @@ -835,7 +832,6 @@ def create_or_replace_document_objects( data=data, vector_col=vector_column, uuid_col=uuid_column, - tenant=tenant, ) if insertion_errors or batch_delete_error: if insertion_errors: From 49e87d8a7446ea455707dd55cf6562b53d90b54a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 18 Jun 2024 17:36:45 +0800 Subject: [PATCH 37/52] fix(providers/weaviate): fix wrong v4 API calls in get_all_objects --- airflow/providers/weaviate/hooks/weaviate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 19b5b83137ee4..784551dec373a 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -399,7 +399,7 @@ def get_object( return collection.query.fetch_objects(**kwargs) def get_all_objects( - self, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs + self, collection_name: str, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs ) -> list[dict[str, Any]] | pd.DataFrame: """Get all objects from weaviate. 
@@ -412,11 +412,11 @@ def get_all_objects( all_objects = [] after = kwargs.pop("after", after) while True: - results = self.get_object(after=after, **kwargs) or {} - if not results.get("objects"): + results = self.get_object(collection_name=collection_name, after=after, **kwargs) + if not results or not results.objects: break - all_objects.extend(results["objects"]) - after = results["objects"][-1]["id"] + all_objects.extend(results.objects) + after = results.objects[-1].uuid if as_dataframe: import pandas From b07725ef36f1adab8aeb2f575c5221c7d1287621 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 18 Jun 2024 17:51:50 +0800 Subject: [PATCH 38/52] fix(providers/weaviate): fix get_all_object with as_dataframe set to True --- airflow/providers/weaviate/hooks/weaviate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 784551dec373a..32f4a521fed41 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -420,6 +420,9 @@ def get_all_objects( if as_dataframe: import pandas + # '_WeaviateUUIDInt' object has no attribute 'is_safe' which causes error + for obj in all_objects: + obj.uuid = str(obj.uuid) return pandas.DataFrame(all_objects) return all_objects From 391a01f8fca0e3adff6783c3ef1dc1c6c2406eb0 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 19 Jun 2024 17:12:09 +0800 Subject: [PATCH 39/52] fix(providers/weaviate): fix create_or_replace_document_objects --- airflow/providers/weaviate/hooks/weaviate.py | 79 +++++++++++--------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 32f4a521fed41..07b8ab9a9f464 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -579,22 +579,32 @@ def _get_documents_to_uuid_map( document_keys = 
set(data[document_column]) while True: collection = self.get_collection(collection_name) - data_objects = collection.query.fetech_objects( + data_objects = collection.query.fetch_objects( filters=Filter.any_of( [Filter.by_property(document_column).equal(key) for key in document_keys] ), return_properties=[document_column], limit=limit, offset=offset, - )["data"]["Get"][collection] - if len(data_objects) == 0: + ) + if len(data_objects.objects) == 0: break offset = offset + limit + + if uuid_column in data_objects.objects[0].properties: + data = [obj.properties for obj in data_objects.objects] + else: + data = [] + for obj in data_objects.objects: + row = obj.properties + row[uuid_column] = str(obj.uuid) + data.append(row) + documents_to_uuid.update( self._prepare_document_to_uuid_map( - data=data_objects, + data=data, group_key=document_column, - get_value=lambda x: x["_additional"][uuid_column], + get_value=lambda x: x[uuid_column], ) ) return documents_to_uuid @@ -640,16 +650,15 @@ def _get_segregated_documents( group_key=document_column, get_value=lambda x: x[uuid_column], ) - # segregate documents into changed, unchanged and non-existing documents. for doc_url, doc_set in input_documents_to_uuid.items(): if doc_url in existing_documents_to_uuid: if existing_documents_to_uuid[doc_url] != doc_set: - changed_documents.add(doc_url) + changed_documents.add(str(doc_url)) else: - unchanged_docs.add(doc_url) + unchanged_docs.add(str(doc_url)) else: - new_documents.add(doc_url) + new_documents.add(str(doc_url)) return existing_documents_to_uuid, changed_documents, unchanged_docs, new_documents @@ -661,7 +670,7 @@ def _delete_all_documents_objects( total_objects_count: int = 1, batch_delete_error: list | None = None, verbose: bool = False, - ) -> list: + ) -> list[dict[str, UUID | str]]: """Delete all object that belong to list of documents. :param document_keys: list of unique documents identifiers. 
@@ -678,16 +687,19 @@ def _delete_all_documents_objects( MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS = 10000 collection = self.get_collection(collection_name) - document_objects = collection.data.delete_many( + delete_many_return = collection.data.delete_many( where=Filter.any_of([Filter.by_property(document_column).equal(key) for key in document_keys]), - verbase=verbose, - dru_run=False, + verbose=verbose, + dry_run=False, ) total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS - matched_objects = document_objects["results"]["matches"] - batch_delete_error = [ - {"uuid": obj["id"]} for obj in document_objects["results"]["objects"] if "error" in obj["status"] - ] + matched_objects = delete_many_return.matches + if delete_many_return.failed > 0: + batch_delete_error = [ + {"uuid": obj.uuid, "error": obj.error} + for obj in delete_many_return.objects + if obj.error is not None + ] if verbose: self.log.info("Deleted %s Objects", matched_objects) @@ -702,7 +714,7 @@ def create_or_replace_document_objects( uuid_column: str | None = None, vector_column: str = "Vector", verbose: bool = False, - ): + ) -> list[dict[str, UUID | str] | None]: """ create or replace objects belonging to documents. 
@@ -795,7 +807,6 @@ def create_or_replace_document_objects( self.log.info( "Changed document: %s has %s objects.", document, len(documents_to_uuid_map[document]) ) - self.log.info("Non-existing document: %s", ", ".join(new_documents)) if existing == "error" and len(changed_documents): @@ -817,33 +828,31 @@ def create_or_replace_document_objects( total_objects_count, changed_documents, ) - batch_delete_error = self._delete_all_documents_objects( - document_keys=list(changed_documents), - document_column=document_column, - collection_name=collection_name, - total_objects_count=total_objects_count, - batch_delete_error=batch_delete_error, - verbose=verbose, - ) + if list(changed_documents): + batch_delete_error = self._delete_all_documents_objects( + document_keys=list(changed_documents), + document_column=document_column, + collection_name=collection_name, + total_objects_count=total_objects_count, + batch_delete_error=batch_delete_error, + verbose=verbose, + ) data = data[data[document_column].isin(new_documents.union(changed_documents))] self.log.info("Batch inserting %s objects for non-existing and changed documents.", data.shape[0]) - insertion_errors: list = [] if data.shape[0]: - insertion_errors = self.batch_data( + self.batch_data( collection_name=collection_name, data=data, vector_col=vector_column, uuid_col=uuid_column, ) - if insertion_errors or batch_delete_error: - if insertion_errors: - self.log.info("Failed to insert %s objects.", len(insertion_errors)) + if batch_delete_error: if batch_delete_error: - self.log.info("Failed to delete %s objects.", len(insertion_errors)) + self.log.info("Failed to delete %s objects.", len(batch_delete_error)) # Rollback object that were not created properly self._delete_objects( - [item["uuid"] for item in insertion_errors + batch_delete_error], + [item["uuid"] for item in batch_delete_error], collection_name=collection_name, ) @@ -854,4 +863,4 @@ def create_or_replace_document_objects( collection_name, 
collection.aggregate.over_all(total_count=True), ) - return insertion_errors, batch_delete_error + return batch_delete_error From 87b9125b534ae1dd8e61241c6603c524ea74f432 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 19 Jun 2024 18:56:35 +0800 Subject: [PATCH 40/52] refactor(providers/weaviate): rename query_without_vector as query_with_text --- airflow/providers/weaviate/hooks/weaviate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 07b8ab9a9f464..611c7fc599111 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -328,9 +328,7 @@ def query_with_vector( ) return response - def query_without_vector( - self, search_text: str, collection_name: str, *properties: list[str], limit: int = 1 - ): + def query_with_text(self, search_text: str, collection_name: str, *properties: list[str], limit: int = 1): """ Query using near text. 
From aabb499298a60234eddfb83325b0eedceeb57214 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 19 Jun 2024 19:01:24 +0800 Subject: [PATCH 41/52] feat(providers/weaviate): migrate operators to v4 API hook --- airflow/providers/weaviate/operators/weaviate.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py index af2a25efa5f1c..fad06e94e621d 100644 --- a/airflow/providers/weaviate/operators/weaviate.py +++ b/airflow/providers/weaviate/operators/weaviate.py @@ -93,17 +93,14 @@ def hook(self) -> WeaviateHook: """Return an instance of the WeaviateHook.""" return WeaviateHook(conn_id=self.conn_id, **self.hook_params) - def execute(self, context: Context) -> list: + def execute(self, context: Context) -> None: self.log.debug("Input data: %s", self.input_data) - insertion_errors: list = [] self.hook.batch_data( collection_name=self.collection_name, data=self.input_data, vector_col=self.vector_col, uuid_col=self.uuid_column, - tenant=self.tenant, ) - return insertion_errors class WeaviateDocumentIngestOperator(BaseOperator): @@ -145,7 +142,7 @@ def __init__( self, conn_id: str, input_data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame], - collection: str, + collection_name: str, document_column: str, existing: str = "skip", uuid_column: str = "id", @@ -158,7 +155,7 @@ def __init__( super().__init__(**kwargs) self.conn_id = conn_id self.input_data = input_data - self.collection = collection + self.collection_name = collection_name self.document_column = document_column self.existing = existing self.uuid_column = uuid_column @@ -179,14 +176,13 @@ def execute(self, context: Context) -> list: :return: List of UUID which failed to create """ self.log.debug("Total input objects : %s", len(self.input_data)) - insertion_errors = self.hook.create_or_replace_document_objects( + batch_delete_error = self.hook.create_or_replace_document_objects( 
data=self.input_data, collection_name=self.collection_name, document_column=self.document_column, existing=self.existing, uuid_column=self.uuid_column, vector_column=self.vector_col, - tenant=self.tenant, verbose=self.verbose, ) - return insertion_errors + return batch_delete_error From 86fcd03e0c58ab237a09e6c3f8cf90009e9f3e75 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 19 Jun 2024 19:02:14 +0800 Subject: [PATCH 42/52] refactor(providers/weavite): group similar methods together --- airflow/providers/weaviate/hooks/weaviate.py | 60 ++++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 611c7fc599111..2d2134ed764a2 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -182,36 +182,6 @@ def get_collection(self, name: str) -> Collection[Properties, References]: client = self.conn return client.collections.get(name) - @staticmethod - def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) -> list[dict[str, Any]]: - """Convert dataframe to list of dicts. - - In scenario where Pandas isn't installed and we pass data as a list of dictionaries, importing - Pandas will fail, which is invalid. This function handles this scenario. - """ - with contextlib.suppress(ImportError): - import pandas - - if isinstance(data, pandas.DataFrame): - data = json.loads(data.to_json(orient="records")) - return cast(List[Dict[str, Any]], data) - - @retry( - reraise=True, - stop=stop_after_attempt(3), - retry=( - retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) - | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) - ), - ) - def get_collection_configuraiton(self, collection_name: str) -> CollectionConfig | CollectionConfigSimple: - """Get the collection configuration from Weaviate. 
- - :param collection_name: The collection for which to return the collection configuration. - """ - client = self.get_conn() - return client.collections.get(collection_name).config.get() - def delete_collections( self, collection_names: list[str] | str, if_error: str = "stop" ) -> list[str] | None: @@ -252,11 +222,41 @@ def delete_collections( return failed_collection_list return None + @retry( + reraise=True, + stop=stop_after_attempt(3), + retry=( + retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) + | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) + ), + ) + def get_collection_configuraiton(self, collection_name: str) -> CollectionConfig | CollectionConfigSimple: + """Get the collection configuration from Weaviate. + + :param collection_name: The collection for which to return the collection configuration. + """ + client = self.get_conn() + return client.collections.get(collection_name).config.get() + def update_collection_configuration(self, collection_name: str, **kwargs) -> None: """Update the collection configuration.""" collection = self.get_collection(collection_name) collection.config.update(**kwargs) + @staticmethod + def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) -> list[dict[str, Any]]: + """Convert dataframe to list of dicts. + + In scenario where Pandas isn't installed and we pass data as a list of dictionaries, importing + Pandas will fail, which is invalid. This function handles this scenario. 
+ """ + with contextlib.suppress(ImportError): + import pandas + + if isinstance(data, pandas.DataFrame): + data = json.loads(data.to_json(orient="records")) + return cast(List[Dict[str, Any]], data) + def batch_data( self, collection_name: str, From 93abecd0bd40d2f64057520731ec00198f57f61d Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 20 Jun 2024 10:05:29 +0800 Subject: [PATCH 43/52] test(providers/weaviate): update unit tests for operators --- .../weaviate/operators/test_weaviate.py | 34 ++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py index e147743d74cfe..8060fdf023116 100644 --- a/tests/providers/weaviate/operators/test_weaviate.py +++ b/tests/providers/weaviate/operators/test_weaviate.py @@ -37,15 +37,14 @@ def operator(self): return WeaviateIngestOperator( task_id="weaviate_task", conn_id="weaviate_conn", - class_name="my_class", + collection_name="my_collection", input_data=[{"data": "sample_data"}], ) def test_constructor(self, operator): assert operator.conn_id == "weaviate_conn" - assert operator.class_name == "my_class" + assert operator.collection_name == "my_collection" assert operator.input_data == [{"data": "sample_data"}] - assert operator.batch_params == {} assert operator.hook_params == {} @patch("airflow.providers.weaviate.operators.weaviate.WeaviateIngestOperator.log") @@ -57,21 +56,18 @@ def test_execute_with_input_json(self, mock_log, operator): operator = WeaviateIngestOperator( task_id="weaviate_task", conn_id="weaviate_conn", - class_name="my_class", + collection_name="my_collection", input_json=[{"data": "sample_data"}], ) - operator.hook.batch_data = MagicMock() operator.execute(context=None) operator.hook.batch_data.assert_called_once_with( - class_name="my_class", + collection_name="my_collection", data=[{"data": "sample_data"}], - batch_config_params={}, vector_col="Vector", uuid_col="id", - 
tenant=None, ) mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}]) @@ -82,12 +78,10 @@ def test_execute_with_input_data(self, mock_log, operator): operator.execute(context=None) operator.hook.batch_data.assert_called_once_with( - class_name="my_class", + collection_name="my_collection", data=[{"data": "sample_data"}], - batch_config_params={}, vector_col="Vector", uuid_col="id", - tenant=None, ) mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}]) @@ -99,7 +93,7 @@ def test_templates(self, create_task_instance_of_operator): dag_id=dag_id, task_id="task-id", conn_id="weaviate_conn", - class_name="my_class", + collection_name="my_collection", input_json="{{ dag.dag_id }}", input_data="{{ dag.dag_id }}", ) @@ -114,8 +108,7 @@ def test_partial_batch_hook_params(self, dag_maker, session): WeaviateIngestOperator.partial( task_id="fake-task-id", conn_id="weaviate_conn", - class_name="FooBar", - batch_params={"spam": "egg"}, + collection_name="FooBar", hook_params={"baz": "biz"}, ).expand(input_data=[{}, {}]) @@ -124,7 +117,6 @@ def test_partial_batch_hook_params(self, dag_maker, session): with set_current_task_instance_session(session=session): for ti in tis: ti.render_templates() - assert ti.task.batch_params == {"spam": "egg"} assert ti.task.hook_params == {"baz": "biz"} @@ -135,23 +127,21 @@ def operator(self): task_id="weaviate_task", conn_id="weaviate_conn", input_data=[{"data": "sample_data"}], - class_name="my_class", + collection_name="my_collection", document_column="docLink", existing="skip", uuid_column="id", vector_col="vector", - batch_config_params={"size": 1000}, ) def test_constructor(self, operator): assert operator.conn_id == "weaviate_conn" assert operator.input_data == [{"data": "sample_data"}] - assert operator.class_name == "my_class" + assert operator.collection_name == "my_collection" assert operator.document_column == "docLink" assert operator.existing == "skip" assert 
operator.uuid_column == "id" assert operator.vector_col == "vector" - assert operator.batch_config_params == {"size": 1000} assert operator.hook_params == {} @patch("airflow.providers.weaviate.operators.weaviate.WeaviateDocumentIngestOperator.log") @@ -162,13 +152,11 @@ def test_execute_with_input_json(self, mock_log, operator): operator.hook.create_or_replace_document_objects.assert_called_once_with( data=[{"data": "sample_data"}], - class_name="my_class", + collection_name="my_collection", document_column="docLink", existing="skip", uuid_column="id", vector_column="vector", - batch_config_params={"size": 1000}, - tenant=None, verbose=False, ) mock_log.debug.assert_called_once_with("Total input objects : %s", len([{"data": "sample_data"}])) @@ -179,7 +167,7 @@ def test_partial_hook_params(self, dag_maker, session): WeaviateDocumentIngestOperator.partial( task_id="fake-task-id", conn_id="weaviate_conn", - class_name="FooBar", + collection_name="FooBar", document_column="spam-egg", hook_params={"baz": "biz"}, ).expand(input_data=[{}, {}]) From 756dc67313596002ef191d6292bb60ae4f8d9e26 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 20 Jun 2024 14:46:26 +0800 Subject: [PATCH 44/52] test(providers/weaviate): fix hooks tests due to API migration --- .../providers/weaviate/hooks/test_weaviate.py | 547 ++++++++---------- 1 file changed, 227 insertions(+), 320 deletions(-) diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 650f938dba6b6..6c1b769b6914d 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -48,28 +48,42 @@ def weaviate_hook(): @pytest.fixture def mock_auth_api_key(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.api_key") as m: yield m @pytest.fixture def mock_auth_bearer_token(): - with 
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.bearer_token") as m: yield m @pytest.fixture def mock_auth_client_credentials(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.client_credentials") as m: yield m @pytest.fixture def mock_auth_client_password(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.client_password") as m: yield m +class MockFetchObjectReturn: + def __init__(self, *, objects): + self.objects = objects + + +class MockObject: + def __init__(self, *, properties: dict, uuid: str) -> None: + self.properties = properties + self.uuid = uuid + + def __eq__(self, other: MockObject) -> bool: + return self.properties == other.properties and self.uuid == other.uuid + + class TestWeaviateHook: """ Test the WeaviateHook Hook. 
@@ -86,103 +100,148 @@ def setup_method(self, monkeypatch): self.scope = "scope1 scope2" self.client_password = "client_password" self.client_bearer_token = "client_bearer_token" - self.host = "http://localhost:8080" + self.host = "localhost" + self.port = 8000 + self.grpc_host = "localhost" + self.grpc_port = 50051 conns = ( Connection( conn_id=self.weaviate_api_key1, host=self.host, + port=self.port, conn_type="weaviate", - extra={"api_key": self.api_key}, + extra={"api_key": self.api_key, "grpc_host": self.grpc_host, "grpc_port": self.grpc_port}, ), Connection( conn_id=self.weaviate_api_key2, host=self.host, + port=self.port, conn_type="weaviate", - extra={"token": self.api_key}, + extra={"token": self.api_key, "grpc_host": self.grpc_host, "grpc_port": self.grpc_port}, ), Connection( conn_id=self.weaviate_client_credentials, host=self.host, + port=self.port, conn_type="weaviate", - extra={"client_secret": self.client_secret, "scope": self.scope}, + extra={ + "client_secret": self.client_secret, + "scope": self.scope, + "grpc_host": self.grpc_host, + "grpc_port": self.grpc_port, + }, ), Connection( conn_id=self.client_password, host=self.host, + port=self.port, conn_type="weaviate", login="login", password="password", + extra={"grpc_host": self.grpc_host, "grpc_port": self.grpc_port}, ), Connection( conn_id=self.client_bearer_token, host=self.host, + port=self.port, conn_type="weaviate", extra={ "access_token": self.client_bearer_token, "expires_in": 30, "refresh_token": "refresh_token", + "grpc_host": self.grpc_host, + "grpc_port": self.grpc_port, }, ), ) for conn in conns: monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri()) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_api_key_in_extra(self, mock_client, mock_auth_api_key): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_api_key_in_extra(self, mock_connect_to_custom, 
mock_auth_api_key): hook = WeaviateHook(conn_id=self.weaviate_api_key1) hook.get_conn() - mock_auth_api_key.assert_called_once_with(self.api_key) - mock_client.assert_called_once_with( - url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + mock_auth_api_key.assert_called_once_with(api_key=self.api_key) + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_api_key(api_key=self.api_key), + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_token_in_extra(self, mock_client, mock_auth_api_key): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_token_in_extra(self, mock_connect_to_custom, mock_auth_api_key): # when token is passed in extra hook = WeaviateHook(conn_id=self.weaviate_api_key2) hook.get_conn() - mock_auth_api_key.assert_called_once_with(self.api_key) - mock_client.assert_called_once_with( - url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + mock_auth_api_key.assert_called_once_with(api_key=self.api_key) + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_api_key(api_key=self.api_key), + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_access_token_in_extra(self, mock_connect_to_custom, mock_auth_bearer_token): hook = WeaviateHook(conn_id=self.client_bearer_token) hook.get_conn() 
mock_auth_bearer_token.assert_called_once_with( - self.client_bearer_token, expires_in=30, refresh_token="refresh_token" + access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token" ) - mock_client.assert_called_once_with( - url=self.host, - auth_client_secret=mock_auth_bearer_token( + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_bearer_token( access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token" ), - additional_headers={}, + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_client_secret_in_extra(self, mock_connect_to_custom, mock_auth_client_credentials): hook = WeaviateHook(conn_id=self.weaviate_client_credentials) hook.get_conn() mock_auth_client_credentials.assert_called_once_with( client_secret=self.client_secret, scope=self.scope ) - mock_client.assert_called_once_with( - url=self.host, - auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope), - additional_headers={}, + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope), + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_client_password_in_extra(self, mock_client, mock_auth_client_password): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_client_password_in_extra(self, 
mock_connect_to_custom, mock_auth_client_password): hook = WeaviateHook(conn_id=self.client_password) hook.get_conn() mock_auth_client_password.assert_called_once_with(username="login", password="password", scope=None) - mock_client.assert_called_once_with( - url=self.host, - auth_client_secret=mock_auth_client_password(username="login", password="password", scope=None), - additional_headers={}, + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_client_password(username="login", password="password", scope=None), + headers={}, ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5") @@ -190,12 +249,14 @@ def test_create_object(self, mock_gen_uuid, weaviate_hook): """ Test the create_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + return_value = weaviate_hook.create_object({"name": "Test"}, "TestCollection") + mock_gen_uuid.assert_called_once() - mock_client.data_object.create.assert_called_once_with( - {"name": "Test"}, "TestClass", uuid=mock_gen_uuid.return_value + mock_collection.data.insert.assert_called_once_with( + properties={"name": "Test"}, uuid=mock_gen_uuid.return_value ) assert return_value @@ -203,107 +264,126 @@ def test_create_object_already_exists_return_none(self, weaviate_hook): """ Test the create_object method of WeaviateHook. 
""" - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - mock_client.data_object.create.side_effect = ObjectAlreadyExistsException - return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + mock_collection.data.insert.side_effect = ObjectAlreadyExistsException + + return_value = weaviate_hook.create_object({"name": "Test"}, "TestCollection") + assert return_value is None def test_get_object(self, weaviate_hook): """ Test the get_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.get_object(class_name="TestClass", uuid="uuid") - mock_client.data_object.get.assert_called_once_with(class_name="TestClass", uuid="uuid") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.get_object(collection_name="TestCollection", uuid="uuid") + + mock_collection.query.fetch_objects.assert_called_once_with(uuid="uuid") def test_get_of_get_or_create_object(self, weaviate_hook): """ Test the get part of get_or_create_object method of WeaviateHook. 
""" - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass") - mock_client.data_object.get.assert_called_once_with( - class_name="TestClass", - consistency_level=None, - tenant=None, - ) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.get_or_create_object(data_object={"name": "Test"}, collection_name="TestCollection") + + mock_collection.query.fetch_objects.assert_called_once_with() @mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5") def test_create_of_get_or_create_object(self, mock_gen_uuid, weaviate_hook): """ Test the create part of get_or_create_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) weaviate_hook.get_object = MagicMock(return_value=None) mock_create_object = MagicMock() weaviate_hook.create_object = mock_create_object - weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass") + + weaviate_hook.get_or_create_object(data_object={"name": "Test"}, collection_name="TestCollection") + mock_create_object.assert_called_once_with( - {"name": "Test"}, - "TestClass", + data_object={"name": "Test"}, + collection_name="TestCollection", uuid=mock_gen_uuid.return_value, - consistency_level=None, - tenant=None, vector=None, ) def test_create_of_get_or_create_object_raises_valueerror(self, weaviate_hook): """ - Test that if data_object is None or class_name is None, ValueError is raised. + Test that if data_object is None or collection_name is None, ValueError is raised. 
""" - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) weaviate_hook.get_object = MagicMock(return_value=None) mock_create_object = MagicMock() + weaviate_hook.create_object = mock_create_object + with pytest.raises(ValueError): - weaviate_hook.get_or_create_object(data_object=None, class_name="TestClass") + weaviate_hook.get_or_create_object(data_object=None, collection_name="TestCollection") with pytest.raises(ValueError): - weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name=None) + weaviate_hook.get_or_create_object(data_object={"name": "Test"}, collection_name=None) def test_get_all_objects(self, weaviate_hook): """ Test the get_all_objects method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) objects = [ - {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]}, - {"deprecations": None, "objects": []}, + MockFetchObjectReturn( + objects=[ + MockObject(properties={"name": "Test1", "id": 2}, uuid="u1"), + MockObject(properties={"name": "Test2", "id": 3}, uuid="u2"), + ] + ), + MockFetchObjectReturn(objects=[]), ] mock_get_object = MagicMock() weaviate_hook.get_object = mock_get_object mock_get_object.side_effect = objects - return_value = weaviate_hook.get_all_objects(class_name="TestClass") + return_value = weaviate_hook.get_all_objects(collection_name="TestCollection") + assert weaviate_hook.get_object.call_args_list == [ - mock.call(after=None, class_name="TestClass"), - mock.call(after=3, class_name="TestClass"), + mock.call(after=None, collection_name="TestCollection"), + mock.call(after="u2", collection_name="TestCollection"), + ] + assert return_value == [ + 
MockObject(properties={"name": "Test1", "id": 2}, uuid="u1"), + MockObject(properties={"name": "Test2", "id": 3}, uuid="u2"), ] - assert return_value == [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}] def test_get_all_objects_returns_dataframe(self, weaviate_hook): """ Test the get_all_objects method of WeaviateHook can return a dataframe. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) objects = [ - {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]}, - {"deprecations": None, "objects": []}, + MockFetchObjectReturn( + objects=[ + MockObject(properties={"name": "Test1", "id": 2}, uuid="u1"), + MockObject(properties={"name": "Test2", "id": 3}, uuid="u2"), + ] + ), + MockFetchObjectReturn(objects=[]), ] mock_get_object = MagicMock() weaviate_hook.get_object = mock_get_object mock_get_object.side_effect = objects - return_value = weaviate_hook.get_all_objects(class_name="TestClass", as_dataframe=True) + return_value = weaviate_hook.get_all_objects(collection_name="TestCollection", as_dataframe=True) + assert weaviate_hook.get_object.call_args_list == [ - mock.call(after=None, class_name="TestClass"), - mock.call(after=3, class_name="TestClass"), + mock.call(after=None, collection_name="TestCollection"), + mock.call(after="u2", collection_name="TestCollection"), ] import pandas @@ -313,59 +393,56 @@ def test_delete_object(self, weaviate_hook): """ Test the delete_object method of WeaviateHook. 
""" - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.delete_object(uuid="uuid", class_name="TestClass") - mock_client.data_object.delete.assert_called_once_with("uuid", class_name="TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.delete_object(collection_name="TestCollection", uuid="uuid") + + mock_collection.data.delete_by_id.assert_called_once_with(uuid="uuid") def test_update_object(self, weaviate_hook): """ Test the update_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + weaviate_hook.update_object( - uuid="uuid", class_name="TestClass", data_object={"name": "Test"}, tenant="2d" - ) - mock_client.data_object.update.assert_called_once_with( - {"name": "Test"}, "TestClass", "uuid", tenant="2d" + uuid="uuid", collection_name="TestCollection", properties={"name": "Test"} ) - def test_validate_object(self, weaviate_hook): - """ - Test the validate_object method of WeaviateHook. - """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.validate_object(class_name="TestClass", data_object={"name": "Test"}, uuid="2d") - mock_client.data_object.validate.assert_called_once_with({"name": "Test"}, "TestClass", uuid="2d") + mock_collection.data.update.assert_called_once_with(properties={"name": "Test"}, uuid="uuid") def test_replace_object(self, weaviate_hook): """ Test the replace_object method of WeaviateHook. 
""" - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + weaviate_hook.replace_object( - uuid="uuid", class_name="TestClass", data_object={"name": "Test"}, tenant="2d" + uuid="uuid", collection_name="TestCollection", properties={"name": "Test"} ) - mock_client.data_object.replace.assert_called_once_with( - {"name": "Test"}, "TestClass", "uuid", tenant="2d" + + mock_collection.data.replace.assert_called_once_with( + properties={"name": "Test"}, uuid="uuid", references=None ) def test_object_exists(self, weaviate_hook): """ Test the object_exists method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.object_exists(class_name="TestClass", uuid="2d") - mock_client.data_object.exists.assert_called_once_with("2d", class_name="TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.object_exists(collection_name="TestCollection", uuid="2d") + mock_collection.data.exists.assert_called_once_with(uuid="2d") -def test_create_class(weaviate_hook): + +def test_create_collection(weaviate_hook): """ - Test the create_class method of WeaviateHook. + Test the create_collection method of WeaviateHook. """ # Mock the Weaviate Client mock_client = MagicMock() @@ -373,40 +450,15 @@ def test_create_class(weaviate_hook): # Define test class JSON test_class_json = { - "class": "TestClass", + "class": "TestCollection", "description": "Test class for unit testing", } - # Test the create_class method - weaviate_hook.create_class(test_class_json) - - # Assert that the create_class method was called with the correct arguments - mock_client.schema.create_class.assert_called_once_with(test_class_json) - - -def test_create_schema(weaviate_hook): - """ - Test the create_schema method of WeaviateHook. 
- """ - # Mock the Weaviate Client - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - - # Define test schema JSON - test_schema_json = { - "classes": [ - { - "class": "TestClass", - "description": "Test class for unit testing", - } - ] - } - - # Test the create_schema method - weaviate_hook.create_schema(test_schema_json) + # Test the create_collection method + weaviate_hook.create_collection(test_class_json) - # Assert that the create_schema method was called with the correct arguments - mock_client.schema.create.assert_called_once_with(test_schema_json) + # Assert that the create_collection method was called with the correct arguments + mock_client.schema.create_collection.assert_called_once_with(test_class_json) @pytest.mark.parametrize( @@ -426,10 +478,10 @@ def test_batch_data(data, expected_length, weaviate_hook): weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test data - test_class_name = "TestClass" + test_collection_name = "TestCollection" # Test the batch_data method - weaviate_hook.batch_data(test_class_name, data) + weaviate_hook.batch_data(test_collection_name, data) # Assert that the batch_data method was called with the correct arguments mock_client.batch.configure.assert_called_once() @@ -447,7 +499,7 @@ def test_batch_data_retry(get_conn, weaviate_hook): error.response = response side_effect = [None, error, None, error, None] get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect - weaviate_hook.batch_data("TestClass", data) + weaviate_hook.batch_data("TestCollection", data) assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect) @@ -461,25 +513,25 @@ def test_batch_data_retry(get_conn, weaviate_hook): ], ids=["ignore", "replace", "fail", "invalid_option"], ) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_classes") 
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_collections") @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema") @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") def test_upsert_schema_scenarios( - get_schema, create_schema, delete_classes, get_schema_value, existing, expected_value, weaviate_hook + get_schema, create_schema, delete_collections, get_schema_value, existing, expected_value, weaviate_hook ): schema_json = { "B": {"class": "B"}, "C": {"class": "C"}, } with ExitStack() as stack: - delete_classes.return_value = None + delete_collections.return_value = None if existing in ["fail", "invalid_option"]: stack.enter_context(pytest.raises(ValueError)) get_schema.return_value = get_schema_value weaviate_hook.create_or_replace_classes(schema_json=schema_json, existing=existing) create_schema.assert_called_once_with({"classes": expected_value}) if existing == "replace": - delete_classes.assert_called_once_with(class_names=["B"]) + delete_collections.assert_called_once_with(collection_names=["B"]) @mock.patch("builtins.open") @@ -497,180 +549,37 @@ def test_upsert_schema_json_file_param(get_schema, create_schema, load, open, we create_schema.assert_called_once_with({"classes": [{"class": "C"}]}) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client") -def test_delete_classes(get_client, weaviate_hook): - class_names = ["class_a", "class_b"] - get_client.return_value.schema.delete_class.side_effect = [ +@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") +def test_delete_collections(get_conn, weaviate_hook): + collection_names = ["class_a", "class_b"] + get_conn.return_value.schema.delete_collection.side_effect = [ weaviate.UnexpectedStatusCodeException("something failed", requests.Response()), None, ] - error_list = weaviate_hook.delete_classes(class_names, if_error="continue") + error_list = 
weaviate_hook.delete_collections(collection_names, if_error="continue") assert error_list == ["class_a"] - get_client.return_value.schema.delete_class.side_effect = weaviate.UnexpectedStatusCodeException( + get_conn.return_value.schema.delete_collection.side_effect = weaviate.UnexpectedStatusCodeException( "something failed", requests.Response() ) with pytest.raises(weaviate.UnexpectedStatusCodeException): - weaviate_hook.delete_classes("class_a", if_error="stop") + weaviate_hook.delete_collections("class_a", if_error="stop") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client") -def test_http_errors_of_delete_classes(get_client, weaviate_hook): - class_names = ["class_a", "class_b"] +@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") +def test_http_errors_of_delete_collections(get_conn, weaviate_hook): + collection_names = ["class_a", "class_b"] resp = requests.Response() resp.status_code = 429 - get_client.return_value.schema.delete_class.side_effect = [ + get_conn.return_value.schema.delete_collection.side_effect = [ requests.exceptions.HTTPError(response=resp), None, requests.exceptions.ConnectionError, None, ] - error_list = weaviate_hook.delete_classes(class_names, if_error="continue") + error_list = weaviate_hook.delete_collections(collection_names, if_error="continue") assert error_list == [] - assert get_client.return_value.schema.delete_class.call_count == 4 - - -@pytest.mark.parametrize( - argnames=["classes_to_test", "expected_result"], - argvalues=[ - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - ], - }, - ], - True, - ), - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - ], - }, - ], - True, - ), - ( - [ - { - "class": "Author", 
- "description": "Authors info", - "properties": [ - { - "name": "invalid_property", - "description": "Last name of the author", - "dataType": ["text"], - } - ], - }, - ], - False, - ), - ( - [ - { - "class": "invalid_class", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - ], - }, - ], - False, - ), - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - { - "name": "name", - "description": "Name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - ], - }, - ], - True, - ), - ], - ids=( - "property_level_check", - "class_level_check", - "invalid_property", - "invalid_class", - "swapped_properties", - ), -) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") -def test_contains_schema(get_schema, classes_to_test, expected_result, weaviate_hook): - get_schema.return_value = { - "classes": [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "name", - "description": "Name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - ], - }, - { - "class": "Article", - "description": "An article written by an Author", - "properties": [ - { - "name": "name", - "description": "Name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - ], - }, - ] - } - assert weaviate_hook.check_subset_of_schema(classes_to_test) == expected_result + assert get_conn.return_value.schema.delete_collection.call_count == 4 @mock.patch("weaviate.util.generate_uuid5") @@ -678,7 +587,7 
@@ def test___generate_uuids(generate_uuid5, weaviate_hook): df = pd.DataFrame.from_dict({"name": ["ross", "bob"], "age": ["12", "22"], "gender": ["m", "m"]}) with pytest.raises(ValueError, match=r"Columns last_name don't exist in dataframe"): weaviate_hook._generate_uuids( - df=df, class_name="test", unique_columns=["name", "age", "gender", "last_name"] + df=df, collection_name="test", unique_columns=["name", "age", "gender", "last_name"] ) df = pd.DataFrame.from_dict( @@ -687,14 +596,14 @@ def test___generate_uuids(generate_uuid5, weaviate_hook): with pytest.raises( ValueError, match=r"Property 'id' already in dataset. Consider renaming or specify 'uuid_column'" ): - weaviate_hook._generate_uuids(df=df, class_name="test", unique_columns=["name", "age", "gender"]) + weaviate_hook._generate_uuids(df=df, collection_name="test", unique_columns=["name", "age", "gender"]) with pytest.raises( ValueError, match=r"Property age already in dataset. Consider renaming or specify a different 'uuid_column'.", ): weaviate_hook._generate_uuids( - df=df, uuid_column="age", class_name="test", unique_columns=["name", "age", "gender"] + df=df, uuid_column="age", collection_name="test", unique_columns=["name", "age", "gender"] ) @@ -712,7 +621,7 @@ def test__delete_objects(delete_object, weaviate_hook): ) delete_object.side_effect = [not_found_exception, None, http_429_exception, http_429_exception, None] - weaviate_hook._delete_objects(uuids=["1", "2", "3"], class_name="test") + weaviate_hook._delete_objects(uuids=["1", "2", "3"], collection_name="test") assert delete_object.call_count == 5 @@ -750,7 +659,7 @@ def test___get_segregated_documents(_get_documents_to_uuid_map, _prepare_documen data=pd.DataFrame(), document_column="doc_key", uuid_column="id", - class_name="doc", + collection_name="doc", ) assert changed_documents == {"abc.doc"} assert unchanged_docs == {"xyz.doc"} @@ -776,7 +685,7 @@ def test_error_option_of_create_or_replace_document_objects( _generate_uuids.return_value 
= (df, "id") with pytest.raises(ValueError, match="Documents abc.xml already exists. You can either skip or replace"): weaviate_hook.create_or_replace_document_objects( - data=df, document_column="doc", class_name="test", existing="error" + data=df, document_column="doc", collection_name="test", existing="error" ) @@ -797,7 +706,7 @@ def test_skip_option_of_create_or_replace_document_objects( } ) - class_name = "test" + collection_name = "test" documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = ( {}, {"abc.xml"}, @@ -813,7 +722,7 @@ def test_skip_option_of_create_or_replace_document_objects( _generate_uuids.return_value = (df, "id") weaviate_hook.create_or_replace_document_objects( - data=df, class_name=class_name, existing="skip", document_column="doc" + data=df, collection_name=collection_name, existing="skip", document_column="doc" ) pd.testing.assert_frame_equal( @@ -838,7 +747,7 @@ def test_replace_option_of_create_or_replace_document_objects( } ) - class_name = "test" + collection_name = "test" documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = ( {"abc.xml": {"uuid"}}, {"abc.xml"}, @@ -854,16 +763,14 @@ def test_replace_option_of_create_or_replace_document_objects( ) _generate_uuids.return_value = (df, "id") weaviate_hook.create_or_replace_document_objects( - data=df, class_name=class_name, existing="replace", document_column="doc" + data=df, collection_name=collection_name, existing="replace", document_column="doc" ) _delete_all_documents_objects.assert_called_with( document_keys=list(changed_documents), total_objects_count=1, document_column="doc", - class_name="test", + collection_name="test", batch_delete_error=[], - tenant=None, - batch_config_params=None, verbose=False, ) pd.testing.assert_frame_equal( From 6593f171aade9e5c04a29c386fca033ad6c1f081 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 20 Jun 2024 15:56:38 +0800 Subject: [PATCH 45/52] test(providers/weaviate): migrate system tests to v4 
API --- .../weaviate/example_weaviate_cohere.py | 27 ++-- .../example_weaviate_dynamic_mapping_dag.py | 29 ++-- .../weaviate/example_weaviate_openai.py | 31 ++-- .../weaviate/example_weaviate_operator.py | 142 +++++++----------- .../weaviate/example_weaviate_using_hook.py | 76 ++++------ .../example_weaviate_vectorizer_dag.py | 37 +++-- ...example_weaviate_without_vectorizer_dag.py | 32 ++-- 7 files changed, 159 insertions(+), 215 deletions(-) diff --git a/tests/system/providers/weaviate/example_weaviate_cohere.py b/tests/system/providers/weaviate/example_weaviate_cohere.py index f2a190bfcf2ef..8415bcecc7ac8 100644 --- a/tests/system/providers/weaviate/example_weaviate_cohere.py +++ b/tests/system/providers/weaviate/example_weaviate_cohere.py @@ -36,19 +36,15 @@ def example_weaviate_cohere(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "Weaviate_example_class", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # Collection definition object. Weaviate's autoschema feature will infer properties when importing. 
+ weaviate_hook.create_collection(name="Weaviate_example_collection", vectorizer_config=None) @setup @task @@ -60,6 +56,7 @@ def get_data_to_embed(): return [[item["Question"]] for item in data] data_to_embed = get_data_to_embed() + embed_data = CohereEmbeddingOperator.partial( task_id="embedding_using_xcom_data", ).expand(input_text=data_to_embed["return_value"]) @@ -81,7 +78,7 @@ def update_vector_data_in_json(**kwargs): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name="Weaviate_example_class", + collection_name="Weaviate_example_collection", input_data=update_vector_data_in_json["return_value"], ) @@ -92,24 +89,24 @@ def update_vector_data_in_json(**kwargs): @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collections(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. 
- weaviate_hook.delete_classes(["Weaviate_example_class"]) + weaviate_hook.delete_collections(["Weaviate_example_collection"]) ( - create_weaviate_class() + create_weaviate_collection() >> embed_data >> update_vector_data_in_json >> perform_ingestion >> embed_query - >> delete_weaviate_class() + >> delete_weaviate_collections() ) diff --git a/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py b/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py index 4b29eb1998b80..a52c0e52625ce 100644 --- a/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py +++ b/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py @@ -17,6 +17,7 @@ from __future__ import annotations import pendulum +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, setup, task, teardown from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator @@ -34,19 +35,15 @@ def example_weaviate_dynamic_mapping_dag(): @setup @task - def create_weaviate_class(data): + def create_weaviate_collection(data): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": data[0], - "vectorizer": data[1], - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing.
+ weaviate_hook.create_collection(data[0], vectorizer_config=data[1]) @setup @task @@ -64,28 +61,30 @@ def get_data_to_ingest(): task_id="perform_ingestion", conn_id="weaviate_default", ).expand( - class_name=["example1", "example2"], + collection_name=["example1", "example2"], input_data=get_data_to_ingest["return_value"], ) @teardown @task - def delete_weaviate_class(class_name): + def delete_weaviate_collection(collection_name): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes([class_name]) + weaviate_hook.delete_collections([collection_name]) ( - create_weaviate_class.expand(data=[["example1", "none"], ["example2", "text2vec-openai"]]) + create_weaviate_collection.expand( + data=[["example1", None], ["example2", Configure.Vectorizer.text2vec_openai()]] + ) >> perform_ingestion - >> delete_weaviate_class.expand(class_name=["example1", "example2"]) + >> delete_weaviate_collection.expand(collection_name=["example1", "example2"]) ) diff --git a/tests/system/providers/weaviate/example_weaviate_openai.py b/tests/system/providers/weaviate/example_weaviate_openai.py index 40304236818f0..97d002285d8cc 100644 --- a/tests/system/providers/weaviate/example_weaviate_openai.py +++ b/tests/system/providers/weaviate/example_weaviate_openai.py @@ -40,17 +40,13 @@ def example_weaviate_openai(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data.
""" weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "Weaviate_example_class", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection("Weaviate_example_collection", vectorizer_config=None) @setup @task @@ -59,6 +55,7 @@ def get_data_to_embed(): return [item["Question"] for item in data] data_to_embed = get_data_to_embed() + embed_data = OpenAIEmbeddingOperator.partial( task_id="embedding_using_xcom_data", conn_id="openai_default", @@ -79,7 +76,7 @@ def update_vector_data_in_json(**kwargs): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name="Weaviate_example_class", + collection_name="Weaviate_example_collection", input_data=update_vector_data_in_json["return_value"], ) @@ -96,31 +93,31 @@ def query_weaviate(**kwargs): query_vector = ti.xcom_pull(task_ids="embed_query", key="return_value") weaviate_hook = WeaviateHook() properties = ["question", "answer", "category"] - response = weaviate_hook.query_with_vector(query_vector, "Weaviate_example_class", *properties) + response = weaviate_hook.query_with_vector(query_vector, "Weaviate_example_collection", *properties) assert ( "In 1953 Watson & Crick built a model" - in response["data"]["Get"]["Weaviate_example_class"][0]["question"] + in response.objects[0].properties["question"] ) @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collection(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing.
- weaviate_hook.delete_classes(["Weaviate_example_class"]) + weaviate_hook.delete_collections(["Weaviate_example_collection"]) ( - create_weaviate_class() + create_weaviate_collection() >> embed_data >> update_vector_data_in_json >> perform_ingestion >> embed_query >> query_weaviate() - >> delete_weaviate_class() + >> delete_weaviate_collection() ) diff --git a/tests/system/providers/weaviate/example_weaviate_operator.py b/tests/system/providers/weaviate/example_weaviate_operator.py index ae4097006e0c1..b92538059e556 100644 --- a/tests/system/providers/weaviate/example_weaviate_operator.py +++ b/tests/system/providers/weaviate/example_weaviate_operator.py @@ -17,6 +17,8 @@ from __future__ import annotations import pendulum +from weaviate.classes.config import DataType, Property +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, task, teardown from airflow.providers.weaviate.operators.weaviate import ( @@ -95,22 +97,18 @@ def example_weaviate_using_operator(): Example Weaviate DAG demonstrating usage of the operator. """ - # Example tasks to create a Weaviate class without vectorizers, store data with custom vectors in XCOM, and call + # Example tasks to create a Weaviate collection without vectorizers, store data with custom vectors in XCOM, and call # WeaviateIngestOperator to ingest data with those custom vectors. @task() - def create_class_without_vectorizer(): + def create_collection_without_vectorizer(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. 
- class_obj = { - "class": "QuestionWithoutVectorizerUsingOperator", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection("QuestionWithoutVectorizerUsingOperator") @task(trigger_rule="all_done") def store_data_with_vectors_in_xcom(): @@ -120,7 +118,7 @@ def store_data_with_vectors_in_xcom(): batch_data_with_vectors_xcom_data = WeaviateIngestOperator( task_id="batch_data_with_vectors_xcom_data", conn_id="weaviate_default", - class_name="QuestionWithoutVectorizerUsingOperator", + collection_name="QuestionWithoutVectorizerUsingOperator", input_data=store_data_with_vectors_in_xcom(), trigger_rule="all_done", ) @@ -130,82 +128,52 @@ def store_data_with_vectors_in_xcom(): batch_data_with_vectors_callable_data = WeaviateIngestOperator( task_id="batch_data_with_vectors_callable_data", conn_id="weaviate_default", - class_name="QuestionWithoutVectorizerUsingOperator", + collection_name="QuestionWithoutVectorizerUsingOperator", input_data=get_data_with_vectors(), trigger_rule="all_done", ) # [END howto_operator_weaviate_embedding_and_ingest_callable_data_with_vectors] - # Example tasks to create class with OpenAI vectorizer, store data without vectors in XCOM, and call + # Example tasks to create collection with OpenAI vectorizer, store data without vectors in XCOM, and call # WeaviateIngestOperator to ingest data by internally generating OpenAI vectors while ingesting. @task() - def create_class_with_vectorizer(): + def create_collection_with_vectorizer(): """ - Example task to create class with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. + Example task to create collection with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. 
""" from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - class_obj = { - "class": "QuestionWithOpenAIVectorizerUsingOperator", - "description": "Information from a Jeopardy! question", # description of the class - "properties": [ - { - "dataType": ["text"], - "description": "The question", - "name": "question", - }, - { - "dataType": ["text"], - "description": "The answer", - "name": "answer", - }, - { - "dataType": ["text"], - "description": "The category", - "name": "category", - }, + weaviate_hook.create_collection( + "QuestionWithOpenAIVectorizerUsingOperator", + description="Information from a Jeopardy! question", + properties=[ + Property(name="question", description="The question", data_type=DataType.TEXT), + Property(name="answer", description="The answer", data_type=DataType.TEXT), + Property(name="category", description="The category", data_type=DataType.TEXT), ], - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @task() - def create_class_for_doc_data_with_vectorizer(): + def create_collection_for_doc_data_with_vectorizer(): """ - Example task to create class with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. + Example task to create collection with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - class_obj = { - "class": "QuestionWithOpenAIVectorizerUsingOperatorDocs", - "description": "Information from a Jeopardy! 
question", # description of the class - "properties": [ - { - "dataType": ["text"], - "description": "The question", - "name": "question", - }, - { - "dataType": ["text"], - "description": "The answer", - "name": "answer", - }, - { - "dataType": ["text"], - "description": "The category", - "name": "category", - }, - { - "dataType": ["text"], - "description": "URL for source document", - "name": "docLink", - }, + weaviate_hook.create_collection( + "QuestionWithOpenAIVectorizerUsingOperatorDocs", + description="Information from a Jeopardy! question", + properties=[ + Property(name="question", description="The question", data_type=DataType.TEXT), + Property(name="answer", description="The answer", data_type=DataType.TEXT), + Property(name="category", description="The category", data_type=DataType.TEXT), + Property(name="docLink", description="URL for source document", data_type=DataType.TEXT), ], - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @task(trigger_rule="all_done") def store_data_without_vectors_in_xcom(): @@ -224,13 +192,14 @@ def store_doc_data_without_vectors_in_xcom(): return data xcom_data_without_vectors = store_data_without_vectors_in_xcom() + xcom_doc_data_without_vectors = store_doc_data_without_vectors_in_xcom() # [START howto_operator_weaviate_ingest_xcom_data_without_vectors] batch_data_without_vectors_xcom_data = WeaviateIngestOperator( task_id="batch_data_without_vectors_xcom_data", conn_id="weaviate_default", - class_name="QuestionWithOpenAIVectorizerUsingOperator", + collection_name="QuestionWithOpenAIVectorizerUsingOperator", input_data=xcom_data_without_vectors["return_value"], trigger_rule="all_done", ) @@ -240,7 +209,7 @@ def store_doc_data_without_vectors_in_xcom(): batch_data_without_vectors_callable_data = WeaviateIngestOperator( task_id="batch_data_without_vectors_callable_data", conn_id="weaviate_default", - 
class_name="QuestionWithOpenAIVectorizerUsingOperator", + collection_name="QuestionWithOpenAIVectorizerUsingOperator", input_data=get_data_without_vectors(), trigger_rule="all_done", ) @@ -251,24 +220,23 @@ def store_doc_data_without_vectors_in_xcom(): existing="replace", document_column="docLink", conn_id="weaviate_default", - class_name="QuestionWithOpenAIVectorizerUsingOperatorDocs", - batch_config_params={"batch_size": 1000}, + collection_name="QuestionWithOpenAIVectorizerUsingOperatorDocs", input_data=xcom_doc_data_without_vectors["return_value"], trigger_rule="all_done", ) @teardown @task - def delete_weaviate_class_Vector(): + def delete_weaviate_collection_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes( + weaviate_hook.delete_collections( [ "QuestionWithOpenAIVectorizerUsingOperator", ] @@ -276,16 +244,16 @@ def delete_weaviate_class_Vector(): @teardown @task - def delete_weaviate_class_without_Vector(): + def delete_weaviate_collection_without_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. 
- weaviate_hook.delete_classes( + weaviate_hook.delete_collections( [ "QuestionWithoutVectorizerUsingOperator", ] @@ -293,34 +261,34 @@ def delete_weaviate_class_without_Vector(): @teardown @task - def delete_weaviate_docs_class_without_Vector(): + def delete_weaviate_docs_collection_without_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["QuestionWithOpenAIVectorizerUsingOperatorDocs"]) + weaviate_hook.delete_collections(["QuestionWithOpenAIVectorizerUsingOperatorDocs"]) ( - create_class_without_vectorizer() + create_collection_without_vectorizer() >> [batch_data_with_vectors_xcom_data, batch_data_with_vectors_callable_data] - >> delete_weaviate_class_without_Vector() + >> delete_weaviate_collection_without_vector() ) ( - create_class_for_doc_data_with_vectorizer() + create_collection_for_doc_data_with_vectorizer() >> [create_or_replace_document_objects_without_vectors] - >> delete_weaviate_docs_class_without_Vector() + >> delete_weaviate_docs_collection_without_vector() ) ( - create_class_with_vectorizer() + create_collection_with_vectorizer() >> [ batch_data_without_vectors_xcom_data, batch_data_without_vectors_callable_data, ] - >> delete_weaviate_class_Vector() + >> delete_weaviate_collection_vector() ) diff --git a/tests/system/providers/weaviate/example_weaviate_using_hook.py b/tests/system/providers/weaviate/example_weaviate_using_hook.py index 3cb185007f2c6..11d57bfbef2a9 100644 --- a/tests/system/providers/weaviate/example_weaviate_using_hook.py +++ b/tests/system/providers/weaviate/example_weaviate_using_hook.py @@ -17,6 +17,8 @@ from __future__ import annotations import 
pendulum +from weaviate.classes.config import DataType, Property +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, task, teardown @@ -31,51 +33,37 @@ def example_weaviate_dag_using_hook(): """Example Weaviate DAG demonstrating usage of the hook.""" @task() - def create_class_with_vectorizer(): + def create_collection_with_vectorizer(): """ - Example task to create class with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. + Example task to create collection with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - class_obj = { - "class": "QuestionWithOpenAIVectorizerUsingHook", - "description": "Information from a Jeopardy! question", # description of the class - "properties": [ - { - "dataType": ["text"], - "description": "The question", - "name": "question", - }, - { - "dataType": ["text"], - "description": "The answer", - "name": "answer", - }, - { - "dataType": ["text"], - "description": "The category", - "name": "category", - }, + weaviate_hook.create_collection( + "QuestionWithOpenAIVectorizerUsingHook", + description="Information from a Jeopardy! question", + properties=[ + Property(name="question", description="The question", data_type=DataType.TEXT), + Property(name="answer", description="The answer", data_type=DataType.TEXT), + Property(name="category", description="The category", data_type=DataType.TEXT), ], - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @task() - def create_class_without_vectorizer(): + def create_collection_without_vectorizer(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. 
You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "QuestionWithoutVectorizerUsingHook", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection( + "QuestionWithoutVectorizerUsingHook", + vectorizer_config=None, + ) @task(trigger_rule="all_done") def store_data_without_vectors_in_xcom(): @@ -109,42 +97,42 @@ def batch_data_with_vectors(data: list): @teardown @task - def delete_weaviate_class_Vector(): + def delete_weaviate_collection_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["QuestionWithOpenAIVectorizerUsingHook"]) + weaviate_hook.delete_collections(["QuestionWithOpenAIVectorizerUsingHook"]) @teardown @task - def delete_weaviate_class_without_Vector(): + def delete_weaviate_collection_without_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. 
- weaviate_hook.delete_classes(["QuestionWithoutVectorizerUsingHook"]) + weaviate_hook.delete_collections(["QuestionWithoutVectorizerUsingHook"]) data_with_vectors = store_data_with_vectors_in_xcom() ( - create_class_without_vectorizer() + create_collection_without_vectorizer() >> batch_data_with_vectors(data_with_vectors["return_value"]) - >> delete_weaviate_class_Vector() + >> delete_weaviate_collection_vector() ) data_without_vectors = store_data_without_vectors_in_xcom() ( - create_class_with_vectorizer() + create_collection_with_vectorizer() >> batch_data_without_vectors(data_without_vectors["return_value"]) - >> delete_weaviate_class_without_Vector() + >> delete_weaviate_collection_without_vector() ) diff --git a/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py b/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py index 77a6c6c1cd676..c78e431396d8c 100644 --- a/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py +++ b/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py @@ -17,11 +17,12 @@ from __future__ import annotations import pendulum +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, setup, task, teardown from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator -class_name = "Weaviate_with_vectorizer_example_class" +collection_name = "Weaviate_with_vectorizer_example_collection" @dag( @@ -37,19 +38,18 @@ def example_weaviate_vectorizer_dag(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. 
Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": class_name, - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection( + collection_name, + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @setup @task @@ -65,7 +65,7 @@ def get_data_to_ingest(): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name=class_name, + collection_name=collection_name, input_data=data_to_ingest["return_value"], ) @@ -75,26 +75,25 @@ def query_weaviate(): weaviate_hook = WeaviateHook() properties = ["question", "answer", "category"] - response = weaviate_hook.query_without_vector( - "biology", "Weaviate_with_vectorizer_example_class", *properties + response = weaviate_hook.query_with_text( + "biology", "Weaviate_with_vectorizer_example_collection", *properties ) - assert "In 1953 Watson & Crick built a model" in response["data"]["Get"][class_name][0]["question"] + assert "In 1953 Watson & Crick built a model" in response.objects[0].properties["question"] @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collection(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. 
- weaviate_hook.delete_classes([class_name]) + weaviate_hook.delete_collections([collection_name]) - delete_weaviate_class = delete_weaviate_class() - create_weaviate_class() >> perform_ingestion >> query_weaviate() >> delete_weaviate_class + create_weaviate_collection() >> perform_ingestion >> query_weaviate() >> delete_weaviate_collection() example_weaviate_vectorizer_dag() diff --git a/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py b/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py index c73aa7f2d067b..a8430661afa7e 100644 --- a/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py +++ b/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py @@ -22,7 +22,7 @@ from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator -class_name = "Weaviate_example_without_vectorizer_class" +collection_name = "Weaviate_example_without_vectorizer_collection" @dag( @@ -38,19 +38,15 @@ def example_weaviate_without_vectorizer_dag(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": class_name, - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. 
+ weaviate_hook.create_collection(collection_name, vectorizer_config=None) @setup @task @@ -66,7 +62,7 @@ def get_data_without_vectors(): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name=class_name, + collection_name=collection_name, input_data=data_to_ingest["return_value"], ) @@ -86,29 +82,29 @@ def query_weaviate(**kwargs): weaviate_hook = WeaviateHook() properties = ["question", "answer", "category"] response = weaviate_hook.query_with_vector( - query_vector, "Weaviate_example_without_vectorizer_class", *properties + query_vector, "Weaviate_example_without_vectorizer_collection", *properties ) - assert "In 1953 Watson & Crick built a model" in response["data"]["Get"][class_name][0]["question"] + assert "In 1953 Watson & Crick built a model" in response.objects[0].properties["question"] @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collection(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. 
- weaviate_hook.delete_classes([class_name]) + weaviate_hook.delete_collections([collection_name]) ( - create_weaviate_class() + create_weaviate_collection() >> perform_ingestion >> embedd_query >> query_weaviate() - >> delete_weaviate_class() + >> delete_weaviate_collection() ) From 95d159c54acf95841e77564c56700c9c3ada720b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 20 Jun 2024 16:16:14 +0800 Subject: [PATCH 46/52] docs(providers/weaviate): update doc for v4 API migration --- .../connections.rst | 20 ++++- .../index.rst | 2 +- .../providers/weaviate/hooks/test_weaviate.py | 85 ++++--------------- 3 files changed, 36 insertions(+), 71 deletions(-) diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index 081fe14d92acb..4721d5e747eb6 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -31,7 +31,7 @@ Configuring the Connection -------------------------- Host (required) - Host URL to connect to the Weaviate cluster. + The host to use for the Weaviate cluster REST and GraphQL API calls. OIDC Username (optional) Username for the OIDC user when OIDC option is to be used for authentication. @@ -39,6 +39,9 @@ OIDC Username (optional) OIDC Password (optional) Password for the OIDC user when OIDC option is to be used for authentication. +Port (optional) + The port to use for the Weaviate cluster REST and GraphQL API calls. + Extra (optional) Specify the extra parameters (as json dictionary) that can be used in the connection. All parameters are optional. @@ -48,11 +51,24 @@ Extra (optional) * If you'd like to use Vectorizers for your class, configure the API keys to use the corresponding embedding API. The extras accepts a key ``additional_headers`` containing the dictionary of API keys for the embedding API authentication. 
They are mentioned in a section here: - `addtional_headers `__ + `Third party API keys `__ Weaviate API Token (optional) Specify your Weaviate API Key to connect when API Key option is to be used for authentication. +Use https (optional) + Whether to use https for the Weaviate cluster REST and GraphQL API calls. + +gRPT host (optional) + The host to use for the Weaviate cluster gRPC API. + +gRPC port (optional) + The port to use for the Weaviate cluster gRPC API. + +Use a secure channel for the underlying gRPC API (optional) + Whether to use a secure channel for the Weaviate cluster gRPC API. + + Supported Authentication Methods -------------------------------- * API Key Authentication: This method uses the Weaviate API Key to authenticate the connection. You can either have the diff --git a/docs/apache-airflow-providers-weaviate/index.rst b/docs/apache-airflow-providers-weaviate/index.rst index 52b662b262fea..36d0a5ed4e523 100644 --- a/docs/apache-airflow-providers-weaviate/index.rst +++ b/docs/apache-airflow-providers-weaviate/index.rst @@ -104,7 +104,7 @@ PIP package Version required =================== ========================================= ``apache-airflow`` ``>=2.7.0`` ``httpx`` ``>=0.25.0`` -``weaviate-client`` ``>=3.24.2`` +``weaviate-client`` ``>=4.4.0`` ``pandas`` ``>=2.1.2,<2.2; python_version >= "3.9"`` ``pandas`` ``>=1.5.3,<2.2; python_version < "3.9"`` =================== ========================================= diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 6c1b769b6914d..1b733400f660b 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -from contextlib import ExitStack from unittest import mock from unittest.mock import MagicMock, Mock @@ -448,17 +447,13 @@ def test_create_collection(weaviate_hook): mock_client = MagicMock() weaviate_hook.get_conn = MagicMock(return_value=mock_client) - # Define test class JSON - test_class_json = { - "class": "TestCollection", - "description": "Test class for unit testing", - } - # Test the create_collection method - weaviate_hook.create_collection(test_class_json) + weaviate_hook.create_collection("TestCollection", description="Test class for unit testing") # Assert that the create_collection method was called with the correct arguments - mock_client.schema.create_collection.assert_called_once_with(test_class_json) + mock_client.collections.create.assert_called_once_with( + name="TestCollection", description="Test class for unit testing" + ) @pytest.mark.parametrize( @@ -473,9 +468,9 @@ def test_batch_data(data, expected_length, weaviate_hook): """ Test the batch_data method of WeaviateHook. 
""" - # Mock the Weaviate Client - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + # Mock the Weaviate Collection + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) # Define test data test_collection_name = "TestCollection" @@ -484,9 +479,9 @@ def test_batch_data(data, expected_length, weaviate_hook): weaviate_hook.batch_data(test_collection_name, data) # Assert that the batch_data method was called with the correct arguments - mock_client.batch.configure.assert_called_once() - mock_batch_context = mock_client.batch.__enter__.return_value - assert mock_batch_context.add_data_object.call_count == expected_length + mock_collection.batch.dynamic.assert_called_once() + mock_batch_context = mock_collection.batch.dynamic.__enter__.return_value + assert mock_batch_context.add_object.call_count == expected_length @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") @@ -503,63 +498,17 @@ def test_batch_data_retry(get_conn, weaviate_hook): assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect) -@pytest.mark.parametrize( - argnames=["get_schema_value", "existing", "expected_value"], - argvalues=[ - ({"classes": [{"class": "A"}, {"class": "B"}]}, "ignore", [{"class": "C"}]), - ({"classes": [{"class": "A"}, {"class": "B"}]}, "replace", [{"class": "B"}, {"class": "C"}]), - ({"classes": [{"class": "A"}, {"class": "B"}]}, "fail", {}), - ({"classes": [{"class": "A"}, {"class": "B"}]}, "invalid_option", {}), - ], - ids=["ignore", "replace", "fail", "invalid_option"], -) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_collections") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") -def test_upsert_schema_scenarios( - get_schema, create_schema, delete_collections, 
get_schema_value, existing, expected_value, weaviate_hook -): - schema_json = { - "B": {"class": "B"}, - "C": {"class": "C"}, - } - with ExitStack() as stack: - delete_collections.return_value = None - if existing in ["fail", "invalid_option"]: - stack.enter_context(pytest.raises(ValueError)) - get_schema.return_value = get_schema_value - weaviate_hook.create_or_replace_classes(schema_json=schema_json, existing=existing) - create_schema.assert_called_once_with({"classes": expected_value}) - if existing == "replace": - delete_collections.assert_called_once_with(collection_names=["B"]) - - -@mock.patch("builtins.open") -@mock.patch("json.load") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") -def test_upsert_schema_json_file_param(get_schema, create_schema, load, open, weaviate_hook): - """Test if schema_json is path to a json file""" - get_schema.return_value = {"classes": [{"class": "A"}, {"class": "B"}]} - load.return_value = { - "B": {"class": "B"}, - "C": {"class": "C"}, - } - weaviate_hook.create_or_replace_classes(schema_json="/tmp/some_temp_file.json", existing="ignore") - create_schema.assert_called_once_with({"classes": [{"class": "C"}]}) - - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") def test_delete_collections(get_conn, weaviate_hook): - collection_names = ["class_a", "class_b"] - get_conn.return_value.schema.delete_collection.side_effect = [ + collection_names = ["collection_a", "collection_b"] + get_conn.return_value.collections.delete.side_effect = [ weaviate.UnexpectedStatusCodeException("something failed", requests.Response()), None, ] error_list = weaviate_hook.delete_collections(collection_names, if_error="continue") - assert error_list == ["class_a"] + assert error_list == ["collection_a"] - get_conn.return_value.schema.delete_collection.side_effect = weaviate.UnexpectedStatusCodeException( + 
get_conn.return_value.collections.delete.side_effect = weaviate.UnexpectedStatusCodeException( "something failed", requests.Response() ) with pytest.raises(weaviate.UnexpectedStatusCodeException): @@ -568,10 +517,10 @@ def test_delete_collections(get_conn, weaviate_hook): @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") def test_http_errors_of_delete_collections(get_conn, weaviate_hook): - collection_names = ["class_a", "class_b"] + collection_names = ["collection_a", "collection_b"] resp = requests.Response() resp.status_code = 429 - get_conn.return_value.schema.delete_collection.side_effect = [ + get_conn.return_value.collections.delete.side_effect = [ requests.exceptions.HTTPError(response=resp), None, requests.exceptions.ConnectionError, @@ -579,7 +528,7 @@ def test_http_errors_of_delete_collections(get_conn, weaviate_hook): ] error_list = weaviate_hook.delete_collections(collection_names, if_error="continue") assert error_list == [] - assert get_conn.return_value.schema.delete_collection.call_count == 4 + assert get_conn.return_value.collections.delete.call_count == 4 @mock.patch("weaviate.util.generate_uuid5") From f8fa89b52c16860ab7ac94ed1682c078edabbd28 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 20 Jun 2024 17:48:50 +0800 Subject: [PATCH 47/52] test(providers/weaviate): fix hooks tests due to API migration --- .../providers/weaviate/hooks/test_weaviate.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 1b733400f660b..d6821ee886d2e 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -478,24 +478,30 @@ def test_batch_data(data, expected_length, weaviate_hook): # Test the batch_data method weaviate_hook.batch_data(test_collection_name, data) - # Assert that the batch_data method was called with the correct arguments - 
mock_collection.batch.dynamic.assert_called_once() - mock_batch_context = mock_collection.batch.dynamic.__enter__.return_value + mock_batch_context = mock_collection.batch.dynamic.return_value.__enter__.return_value assert mock_batch_context.add_object.call_count == expected_length -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") -def test_batch_data_retry(get_conn, weaviate_hook): +def test_batch_data_retry(weaviate_hook): """Test to ensure retrying working as expected""" + # Mock the Weaviate Collection + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + data = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}] response = requests.Response() response.status_code = 429 error = requests.exceptions.HTTPError() error.response = response side_effect = [None, error, None, error, None] - get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect + + mock_collection.batch.dynamic.return_value.__enter__.return_value.add_object.side_effect = side_effect + weaviate_hook.batch_data("TestCollection", data) - assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect) + + assert mock_collection.batch.dynamic.return_value.__enter__.return_value.add_object.call_count == len( + side_effect + ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") From 8a2af9c6447871bef240eb32e3fd1e92764a703d Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 20 Jun 2024 17:04:29 +0800 Subject: [PATCH 48/52] style(providers/weaviate): fix mypy warnings --- airflow/providers/weaviate/hooks/weaviate.py | 95 +++++++++++-------- .../providers/weaviate/operators/weaviate.py | 3 +- .../connections.rst | 2 +- .../providers/weaviate/hooks/test_weaviate.py | 10 +- 4 files changed, 68 insertions(+), 42 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py 
b/airflow/providers/weaviate/hooks/weaviate.py index 2d2134ed764a2..997b0a0221122 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -20,7 +20,7 @@ import contextlib import json from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, List, Sequence, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, cast import requests import weaviate @@ -35,13 +35,19 @@ from airflow.hooks.base import BaseHook if TYPE_CHECKING: - from typing import Callable, Collection, Literal + from typing import Callable, Literal import pandas as pd from weaviate.auth import AuthCredentials + from weaviate.collections import Collection from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple - from weaviate.collections.classes.internal import QueryReturnType, QuerySearchReturnType, ReferenceInputs - from weaviate.collections.classes.types import Properties, References, TProperties, TReferences + from weaviate.collections.classes.internal import ( + Object, + QueryReturnType, + QuerySearchReturnType, + ReferenceInputs, + ) + from weaviate.collections.classes.types import Properties from weaviate.types import UUID from airflow.models.connection import Connection @@ -169,12 +175,12 @@ def test_connection(self) -> tuple[bool, str]: self.log.error("Error testing Weaviate connection: %s", e) return False, str(e) - def create_collection(self, name: str, **kwargs) -> Collection[Properties, References]: + def create_collection(self, name: str, **kwargs) -> Collection: """Create a new collection.""" client = self.conn return client.collections.create(name=name, **kwargs) - def get_collection(self, name: str) -> Collection[Properties, References]: + def get_collection(self, name: str) -> Collection: """Get a collection by name. :param name: The name of the collection to get. 
@@ -309,10 +315,11 @@ def query_with_vector( self, embeddings: list[float], collection_name: str, - *properties: list[str], + properties: list[str], certainty: float = 0.7, limit: int = 1, - ) -> QuerySearchReturnType[Properties, References, TProperties, TReferences]: + **kwargs, + ) -> QuerySearchReturnType: """ Query weaviate database with near vectors. @@ -324,11 +331,13 @@ def query_with_vector( client = self.conn collection = client.collections.get(collection_name) response = collection.query.near_vector( - near_vector=embeddings, certainty=certainty, limit=limit, return_properties=properties + near_vector=embeddings, certainty=certainty, limit=limit, return_properties=properties, **kwargs ) return response - def query_with_text(self, search_text: str, collection_name: str, *properties: list[str], limit: int = 1): + def query_with_text( + self, search_text: str, collection_name: str, properties: list[str], limit: int = 1, **kwargs + ) -> QuerySearchReturnType: """ Query using near text. @@ -338,10 +347,12 @@ def query_with_text(self, search_text: str, collection_name: str, *properties: l """ client = self.conn collection = client.collections.get(collection_name) - response = collection.query.near_text(query=search_text, limit=limit, return_properties=properties) + response = collection.query.near_text( + query=search_text, limit=limit, return_properties=properties, **kwargs + ) return response - def create_object(self, data_object: dict | str, collection_name: str, **kwargs) -> UUID | None: + def create_object(self, data_object: dict, collection_name: str, **kwargs) -> UUID | None: """Create a new object. :param data_object: Object to be added. If type is str it should be either a URL or a file. 
@@ -360,17 +371,16 @@ def create_object(self, data_object: dict | str, collection_name: str, **kwargs) def get_or_create_object( self, collection_name, - data_object: dict | str | None = None, + data_object: dict, vector: Sequence | None = None, **kwargs, - ) -> QueryReturnType[Properties, References, TProperties, TReferences] | None | UUID: + ) -> QueryReturnType | UUID | None: """Get or Create a new object. Returns the object if already exists, return UUID if not :param collection_name: Collection name associated with the object given.. - :param data_object: Object to be added. If type is str it should be either a URL or a file. This is required - to create a new object. + :param data_object: Object to be added. :param vector: Vector associated with the object given. This argument is only used when creating object. :param kwargs: parameters to be passed to collection.data.fetch_object_by_id() or collection.data.fetch_objects() @@ -385,20 +395,17 @@ def get_or_create_object( ) return obj - def get_object( - self, collection_name: str, **kwargs - ) -> QueryReturnType[Properties, References, TProperties, TReferences] | None: + def get_object(self, collection_name: str, **kwargs) -> QueryReturnType: """Get objects or an object from weaviate. - :param kwargs: parameters to be passed to collection.data.fetch_object_by_id() or - collection.data.fetch_objects() + :param kwargs: parameters to be passed to collection.query.fetch_objects() """ collection = self.get_collection(collection_name) return collection.query.fetch_objects(**kwargs) def get_all_objects( self, collection_name: str, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs - ) -> list[dict[str, Any]] | pd.DataFrame: + ) -> list[Object] | pd.DataFrame: """Get all objects from weaviate. if after is provided, it will be used as the starting point for the listing. 
@@ -407,7 +414,7 @@ def get_all_objects( :param as_dataframe: if True, returns a pandas dataframe :param kwargs: parameters to be passed to weaviate_client.data_object.get() """ - all_objects = [] + all_objects: list[Object] = [] after = kwargs.pop("after", after) while True: results = self.get_object(collection_name=collection_name, after=after, **kwargs) @@ -419,9 +426,19 @@ def get_all_objects( import pandas # '_WeaviateUUIDInt' object has no attribute 'is_safe' which causes error - for obj in all_objects: - obj.uuid = str(obj.uuid) - return pandas.DataFrame(all_objects) + return pandas.DataFrame( + [ + { + "collection": obj.collection, + "metadata": obj.metadata, + "properties": obj.properties, + "references": obj.references, + "uuid": str(obj.uuid), + "vector": obj.vector, + } + for obj in all_objects + ] + ) return all_objects def delete_object(self, collection_name: str, uuid: UUID | str) -> bool: @@ -474,7 +491,9 @@ def object_exists(self, collection_name: str, uuid: str | UUID) -> bool: collection = self.get_collection(collection_name) return collection.data.exists(uuid=uuid) - def _delete_objects(self, uuids: Collection, collection_name: str, retry_attempts_per_object: int = 5): + def _delete_objects( + self, uuids: list[UUID], collection_name: str, retry_attempts_per_object: int = 5 + ) -> None: """Delete multiple objects. Helper function for `create_or_replace_objects()` to delete multiple objects. 
@@ -590,17 +609,17 @@ def _get_documents_to_uuid_map( offset = offset + limit if uuid_column in data_objects.objects[0].properties: - data = [obj.properties for obj in data_objects.objects] + data_object_properties = [obj.properties for obj in data_objects.objects] else: - data = [] + data_object_properties = [] for obj in data_objects.objects: - row = obj.properties + row = dict(obj.properties) row[uuid_column] = str(obj.uuid) - data.append(row) + data_object_properties.append(row) documents_to_uuid.update( self._prepare_document_to_uuid_map( - data=data, + data=data_object_properties, group_key=document_column, get_value=lambda x: x[uuid_column], ) @@ -609,7 +628,7 @@ def _get_documents_to_uuid_map( @staticmethod def _prepare_document_to_uuid_map( - data: list[dict], group_key: str, get_value: Callable[[dict], str] + data: Sequence[Mapping], group_key: str, get_value: Callable[[Mapping], str] ) -> dict[str, set]: """Prepare the map of grouped_key to set.""" grouped_key_to_set: dict = {} @@ -666,9 +685,9 @@ def _delete_all_documents_objects( document_column: str, collection_name: str, total_objects_count: int = 1, - batch_delete_error: list | None = None, + batch_delete_error: Sequence | None = None, verbose: bool = False, - ) -> list[dict[str, UUID | str]]: + ) -> Sequence[dict[str, UUID | str]]: """Delete all object that belong to list of documents. :param document_keys: list of unique documents identifiers. 
@@ -692,7 +711,7 @@ def _delete_all_documents_objects( ) total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS matched_objects = delete_many_return.matches - if delete_many_return.failed > 0: + if delete_many_return.failed > 0 and delete_many_return.objects: batch_delete_error = [ {"uuid": obj.uuid, "error": obj.error} for obj in delete_many_return.objects @@ -712,7 +731,7 @@ def create_or_replace_document_objects( uuid_column: str | None = None, vector_column: str = "Vector", verbose: bool = False, - ) -> list[dict[str, UUID | str] | None]: + ) -> Sequence[dict[str, UUID | str] | None]: """ create or replace objects belonging to documents. @@ -782,7 +801,7 @@ def create_or_replace_document_objects( if verbose: self.log.info("%s objects remain after deduplication.", data.shape[0]) - batch_delete_error: list = [] + batch_delete_error: Sequence[dict[str, UUID | str]] = [] ( documents_to_uuid_map, changed_documents, diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py index fad06e94e621d..e2f9ac0d701f1 100644 --- a/airflow/providers/weaviate/operators/weaviate.py +++ b/airflow/providers/weaviate/operators/weaviate.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: import pandas as pd + from weaviate.types import UUID from airflow.utils.context import Context @@ -169,7 +170,7 @@ def hook(self) -> WeaviateHook: """Return an instance of the WeaviateHook.""" return WeaviateHook(conn_id=self.conn_id, **self.hook_params) - def execute(self, context: Context) -> list: + def execute(self, context: Context) -> Sequence[dict[str, UUID | str] | None]: """ Create or replace objects belonging to documents. 
diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index 4721d5e747eb6..3e6868007ffa7 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -59,7 +59,7 @@ Weaviate API Token (optional) Use https (optional) Whether to use https for the Weaviate cluster REST and GraphQL API calls. -gRPT host (optional) +gRPC host (optional) The host to use for the Weaviate cluster gRPC API. gRPC port (optional) diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index d6821ee886d2e..48abfb3ffee8a 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -78,8 +78,14 @@ class MockObject: def __init__(self, *, properties: dict, uuid: str) -> None: self.properties = properties self.uuid = uuid - - def __eq__(self, other: MockObject) -> bool: + self.collection = "collection" + self.metadata = "metadata" + self.references = "references" + self.vector = "vector" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MockObject): + return False return self.properties == other.properties and self.uuid == other.uuid From 7e3a8ae6f9caa3c77474da962188b33e2871f0fb Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 25 Jun 2024 16:06:03 +0800 Subject: [PATCH 49/52] docs(providers/weaviate): update changelog --- airflow/providers/weaviate/CHANGELOG.rst | 45 ++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/airflow/providers/weaviate/CHANGELOG.rst b/airflow/providers/weaviate/CHANGELOG.rst index 6634ff7400746..deaa58975648c 100644 --- a/airflow/providers/weaviate/CHANGELOG.rst +++ b/airflow/providers/weaviate/CHANGELOG.rst @@ -20,6 +20,51 @@ Changelog --------- +2.0.0 +...... + + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + * We bumped the minimum version of weaviate-client to 4.4.0. 
Many of the concepts and methods have been changed. We suggest you read `Migrate from v3 to v4 `_ before you upgrade to this version + +* Add columns ``Port``, ``gRPC host``, ``gRPC port`` and ``Use https``, ``Use a secure channel for the underlying gRPC API`` options to the Weaviate connection. The default values from Airflow providers may not be suitable for using Weaviate correctly, so we recommend explicitly specifying these values. +* Update ``WeaviateIngestOperator`` and ``WeaviateDocumentIngestOperator`` to use ``WeaviateHook`` with ``weaviate-client`` v4 API. + * Argument ``class_name`` is named as ``collection_name`` + * Argument ``batch_params`` is removed. +* Update ``WeaviateHook`` to utilize ``weaviate-client`` v4 API. The implemetation has been extensively changed. We recommend reading `Migrate from v3 to v4 `_ to understand the changes on the Weaviate side before using the updated ``WeaviateHook``. + * Migrate the following public methods + * ``test_connections`` + * ``query_with_vector`` + * ``create_object`` + * ``get_object`` + * ``delete_object`` + * ``update_object`` + * ``replace_object`` + * ``object_exists`` + * ``batch_data`` + * ``get_or_create_object`` + * ``create_or_replace_document_objects`` + + * Rename the following public methods + * ``update_schema`` to ``update_collection_configuration`` + * ``create_class`` to ``create_collection`` + * ``get_schema`` to ``get_collection_configuraiton`` + * ``delete_classes`` to ``delete_collections`` + * ``query_without_vector`` to ``query_with_text`` + + * Remove the following public methods + * ``validate_object`` + * ``update_schema`` + * ``create_schema`` + * ``delete_all_schema`` + * ``check_subset_of_schema`` + + * Remove deprecated method ``WeaviateHook.get_client`` + * Remove unused argument ``retry_status_codes`` in ``WeaviateHook.__init__`` + 1.4.2 ..... 
From e881d701a14083f769aa8b922392c115204aeb05 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 25 Jun 2024 16:06:53 +0800 Subject: [PATCH 50/52] docs(providers/weaviate): add more detail description to host --- airflow/providers/weaviate/CHANGELOG.rst | 2 +- docs/apache-airflow-providers-weaviate/connections.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/weaviate/CHANGELOG.rst b/airflow/providers/weaviate/CHANGELOG.rst index deaa58975648c..e13737b53657f 100644 --- a/airflow/providers/weaviate/CHANGELOG.rst +++ b/airflow/providers/weaviate/CHANGELOG.rst @@ -34,7 +34,7 @@ Breaking changes * Update ``WeaviateIngestOperator`` and ``WeaviateDocumentIngestOperator`` to use ``WeaviateHook`` with ``weaviate-client`` v4 API. * Argument ``class_name`` is named as ``collection_name`` * Argument ``batch_params`` is removed. -* Update ``WeaviateHook`` to utilize ``weaviate-client`` v4 API. The implemetation has been extensively changed. We recommend reading `Migrate from v3 to v4 `_ to understand the changes on the Weaviate side before using the updated ``WeaviateHook``. +* Update ``WeaviateHook`` to utilize ``weaviate-client`` v4 API. The implementation has been extensively changed. We recommend reading `Migrate from v3 to v4 `_ to understand the changes on the Weaviate side before using the updated ``WeaviateHook``. * Migrate the following public methods * ``test_connections`` * ``query_with_vector`` diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index 3e6868007ffa7..428d058d3e96a 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -31,7 +31,7 @@ Configuring the Connection -------------------------- Host (required) - The host to use for the Weaviate cluster REST and GraphQL API calls. + The host to use for the Weaviate cluster REST and GraphQL API calls. 
DO NOT include the scheme (i.e., http or https). OIDC Username (optional) Username for the OIDC user when OIDC option is to be used for authentication. From a2224e7f3160e8606127e9ffb6f6798a92cff7c3 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 25 Jun 2024 16:36:25 +0800 Subject: [PATCH 51/52] docs(providers/weaviate): fix changelog rst format --- airflow/providers/weaviate/CHANGELOG.rst | 38 ++++-------------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/airflow/providers/weaviate/CHANGELOG.rst b/airflow/providers/weaviate/CHANGELOG.rst index e13737b53657f..42796ffdd125c 100644 --- a/airflow/providers/weaviate/CHANGELOG.rst +++ b/airflow/providers/weaviate/CHANGELOG.rst @@ -31,39 +31,13 @@ Breaking changes * We bumped the minimum version of weaviate-client to 4.4.0. Many of the concepts and methods have been changed. We suggest you read `Migrate from v3 to v4 `_ before you upgrade to this version * Add columns ``Port``, ``gRPC host``, ``gRPC port`` and ``Use https``, ``Use a secure channel for the underlying gRPC API`` options to the Weaviate connection. The default values from Airflow providers may not be suitable for using Weaviate correctly, so we recommend explicitly specifying these values. -* Update ``WeaviateIngestOperator`` and ``WeaviateDocumentIngestOperator`` to use ``WeaviateHook`` with ``weaviate-client`` v4 API. - * Argument ``class_name`` is named as ``collection_name`` - * Argument ``batch_params`` is removed. +* Update ``WeaviateIngestOperator`` and ``WeaviateDocumentIngestOperator`` to use ``WeaviateHook`` with ``weaviate-client`` v4 API. The major changes are changing argument ``class_name`` to ``collection_name`` and removing ``batch_params``. * Update ``WeaviateHook`` to utilize ``weaviate-client`` v4 API. The implementation has been extensively changed. We recommend reading `Migrate from v3 to v4 `_ to understand the changes on the Weaviate side before using the updated ``WeaviateHook``. 
- * Migrate the following public methods - * ``test_connections`` - * ``query_with_vector`` - * ``create_object`` - * ``get_object`` - * ``delete_object`` - * ``update_object`` - * ``replace_object`` - * ``object_exists`` - * ``batch_data`` - * ``get_or_create_object`` - * ``create_or_replace_document_objects`` - - * Rename the following public methods - * ``update_schema`` to ``update_collection_configuration`` - * ``create_class`` to ``create_collection`` - * ``get_schema`` to ``get_collection_configuraiton`` - * ``delete_classes`` to ``delete_collections`` - * ``query_without_vector`` to ``query_with_text`` - - * Remove the following public methods - * ``validate_object`` - * ``update_schema`` - * ``create_schema`` - * ``delete_all_schema`` - * ``check_subset_of_schema`` - - * Remove deprecated method ``WeaviateHook.get_client`` - * Remove unused argument ``retry_status_codes`` in ``WeaviateHook.__init__`` +* Migrate the following ``WeaviateHook`` public methods to v4 API: ``test_connection``, ``query_with_vector``, ``create_object``, ``get_object``, ``delete_object``, ``update_object``, ``replace_object``, ``object_exists``, ``batch_data``, ``get_or_create_object``, ``create_or_replace_document_objects`` +* Rename ``WeaviateHook`` public methods ``update_schema`` as ``update_collection_configuration``, ``create_class`` as ``create_collection``, ``get_schema`` as ``get_collection_configuration``, ``delete_classes`` as ``delete_collections`` and ``query_without_vector`` as ``query_with_text``. +* Remove the following ``WeaviateHook`` public methods: ``validate_object``, ``update_schema``, ``create_schema``, ``delete_all_schema``, ``check_subset_of_schema`` +* Remove deprecated method ``WeaviateHook.get_client`` +* Remove unused argument ``retry_status_codes`` in ``WeaviateHook.__init__`` 1.4.2 ..... 
From 8f6f4aed5a74afa6a83fe1cbdf12f120d240cd06 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 25 Jun 2024 16:37:07 +0800 Subject: [PATCH 52/52] build(providers/weaviate): add 2.0.0 to provider metadata --- airflow/providers/weaviate/CHANGELOG.rst | 3 --- airflow/providers/weaviate/provider.yaml | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/airflow/providers/weaviate/CHANGELOG.rst b/airflow/providers/weaviate/CHANGELOG.rst index 42796ffdd125c..01301a7883a77 100644 --- a/airflow/providers/weaviate/CHANGELOG.rst +++ b/airflow/providers/weaviate/CHANGELOG.rst @@ -39,9 +39,6 @@ Breaking changes * Remove deprecated method ``WeaviateHook.get_client`` * Remove unused argument ``retry_status_codes`` in ``WeaviateHook.__init__`` -1.4.2 -..... - Misc ~~~~ diff --git a/airflow/providers/weaviate/provider.yaml b/airflow/providers/weaviate/provider.yaml index 577fce2688a70..47e2fbd386e3d 100644 --- a/airflow/providers/weaviate/provider.yaml +++ b/airflow/providers/weaviate/provider.yaml @@ -28,7 +28,7 @@ source-date-epoch: 1718605569 # note that those versions are maintained by release manager - do not update them manually versions: - - 1.4.2 + - 2.0.0 - 1.4.1 - 1.4.0 - 1.3.4