diff --git a/airflow/providers/weaviate/CHANGELOG.rst b/airflow/providers/weaviate/CHANGELOG.rst index 6634ff7400746..01301a7883a77 100644 --- a/airflow/providers/weaviate/CHANGELOG.rst +++ b/airflow/providers/weaviate/CHANGELOG.rst @@ -20,8 +20,24 @@ Changelog --------- -1.4.2 -..... +2.0.0 +...... + + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + * We bumped the minimum version of weaviate-client to 4.4.0. Many of the concepts and methods have been changed. We suggest you read `Migrate from v3 to v4 `_ before you upgrade to this version + +* Add columns ``Port``, ``gRPC host``, ``gRPC port`` and ``Use https``, ``Use a secure channel for the underlying gRPC API`` options to the Weaviate connection. The default values from Airflow providers may not be suitable for using Weaviate correctly, so we recommend explicitly specifying these values. +* Update ``WeaviateIngestOperator`` and ``WeaviateDocumentIngestOperator`` to use ``WeaviateHook`` with ``weaviate-client`` v4 API. The major changes are changing argument ``class_name`` to ``collection_name`` and removing ``batch_params``. +* Update ``WeaviateHook`` to utilize ``weaviate-client`` v4 API. The implementation has been extensively changed. We recommend reading `Migrate from v3 to v4 `_ to understand the changes on the Weaviate side before using the updated ``WeaviateHook``. +* Migrate the following ``WeaviateHook`` public methods to v4 API: ``test_connections``, ``query_with_vector``, ``create_object``, ``get_object``, ``delete_object``, ``update_object``, ``replace_object``, ``object_exists``, ``batch_data``, ``get_or_create_object``, ``create_or_replace_document_objects`` +* Rename ``WeaviateHook`` public methods ``update_schema`` as ``update_collection_configuration``, ``create_class`` as ``create_collection``, ``get_schema`` as ``get_collection_configuraiton``, ``delete_classes`` as ``delete_collections`` and ``query_without_vector`` as ``query_with_text``. +* Remove the following ``WeaviateHook`` public methods: ``validate_object``, ``update_schema``, ``create_schema``, ``delete_all_schema``, ``check_subset_of_schema`` +* Remove deprecated method ``WeaviateHook.get_client`` +* Remove unused argument ``retry_status_codes`` in ``WeaviateHook.__init__`` Misc ~~~~ diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 56c7f666330f9..997b0a0221122 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -20,27 +20,38 @@ import contextlib import json from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, List, Sequence, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, cast import requests +import weaviate import weaviate.exceptions -from deprecated import deprecated from tenacity import Retrying, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt -from weaviate import Client as WeaviateClient -from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword -from weaviate.data.replication import ConsistencyLevel +from weaviate import WeaviateClient +from weaviate.auth import Auth +from weaviate.classes.query import Filter from weaviate.exceptions import ObjectAlreadyExistsException from weaviate.util import generate_uuid5 -from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook if TYPE_CHECKING: - from typing import Callable, Collection, Literal + from typing import Callable, Literal import pandas as pd + from weaviate.auth import AuthCredentials + from weaviate.collections import Collection + from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple + from weaviate.collections.classes.internal import ( + Object, + QueryReturnType, + QuerySearchReturnType, + ReferenceInputs, + ) + from weaviate.collections.classes.types import Properties from weaviate.types import UUID + from airflow.models.connection import Connection + ExitingSchemaOptions = Literal["replace", "fail", "ignore"] HTTP_RETRY_STATUS_CODE = [429, 500, 503, 504] @@ -76,7 +87,6 @@ class WeaviateHook(BaseHook): def __init__( self, conn_id: str = default_conn_name, - retry_status_codes: list[int] | None = None, *args: Any, **kwargs: Any, ) -> None: @@ -86,19 +96,25 @@ def __init__( @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to connection form.""" - from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext - from wtforms import PasswordField + from wtforms import BooleanField, PasswordField, StringField return { + "http_secure": BooleanField(lazy_gettext("Use https"), default=False), "token": PasswordField(lazy_gettext("Weaviate API Key"), widget=BS3PasswordFieldWidget()), + "grpc_host": StringField(lazy_gettext("gRPC host"), widget=BS3TextFieldWidget()), + "grpc_port": StringField(lazy_gettext("gRPC port"), widget=BS3TextFieldWidget()), + "grcp_secure": BooleanField( + lazy_gettext("Use a secure channel for the underlying gRPC API"), default=False + ), } @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field behaviour.""" return { - "hidden_fields": ["port", "schema"], + "hidden_fields": ["schema"], "relabeling": { "login": "OIDC Username", "password": "OIDC Password", @@ -107,127 +123,89 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: def get_conn(self) -> WeaviateClient: conn = self.get_connection(self.conn_id) - url = conn.host - username = conn.login or "" - password = conn.password or "" extras = conn.extra_dejson - access_token = extras.get("access_token", None) - refresh_token = extras.get("refresh_token", None) - expires_in = extras.get("expires_in", 60) + http_secure = extras.pop("http_secure", False) + grpc_secure = extras.pop("grcp_secure", False) + return weaviate.connect_to_custom( + http_host=conn.host, + http_port=conn.port or 443 if http_secure else 80, + http_secure=http_secure, + grpc_host=extras.pop("grpc_host", conn.host), + grpc_port=extras.pop("grpc_port", 443 if grpc_secure else 80), + grpc_secure=grpc_secure, + headers=extras.pop("additional_headers", {}), + auth_credentials=self._extract_auth_credentials(conn), + ) + + def _extract_auth_credentials(self, conn: Connection) -> AuthCredentials: + extras = conn.extra_dejson # previously token was used as api_key(backwards compatibility) api_key = extras.get("api_key", None) or extras.get("token", None) - client_secret = extras.get("client_secret", None) - additional_headers = extras.pop("additional_headers", {}) - scope = extras.get("scope", None) or extras.get("oidc_scope", None) if api_key: - auth_client_secret: AuthApiKey | AuthBearerToken | AuthClientCredentials | AuthClientPassword = ( - AuthApiKey(api_key) - ) - elif access_token: - auth_client_secret = AuthBearerToken( - access_token, expires_in=expires_in, refresh_token=refresh_token + return Auth.api_key(api_key=api_key) + + access_token = extras.get("access_token", None) + if access_token: + refresh_token = extras.get("refresh_token", None) + expires_in = extras.get("expires_in", 60) + return Auth.bearer_token( + access_token=access_token, expires_in=expires_in, refresh_token=refresh_token ) - elif client_secret: - auth_client_secret = AuthClientCredentials(client_secret=client_secret, scope=scope) - else: - auth_client_secret = AuthClientPassword(username=username, password=password, scope=scope) - return WeaviateClient( - url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers - ) + scope = extras.get("scope", None) or extras.get("oidc_scope", None) + client_secret = extras.get("client_secret", None) + if client_secret: + return Auth.client_credentials(client_secret=client_secret, scope=scope) + + username = conn.login or "" + password = conn.password or "" + return Auth.client_password(username=username, password=password, scope=scope) @cached_property def conn(self) -> WeaviateClient: """Returns a Weaviate client.""" return self.get_conn() - @deprecated( - reason="The `get_client` method has been renamed to `get_conn`", - category=AirflowProviderDeprecationWarning, - ) - def get_client(self) -> WeaviateClient: - """Return a Weaviate client.""" - # Keeping this for backwards compatibility - return self.conn - def test_connection(self) -> tuple[bool, str]: try: client = self.conn - client.schema.get() + client.collections.list_all() return True, "Connection established!" except Exception as e: self.log.error("Error testing Weaviate connection: %s", e) return False, str(e) - def create_class(self, class_json: dict[str, Any]) -> None: - """Create a new class.""" + def create_collection(self, name: str, **kwargs) -> Collection: + """Create a new collection.""" client = self.conn - client.schema.create_class(class_json) - - @retry( - reraise=True, - stop=stop_after_attempt(3), - retry=( - retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) - | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) - ), - ) - def create_schema(self, schema_json: dict[str, Any] | str) -> None: - """ - Create a new Schema. + return client.collections.create(name=name, **kwargs) - Instead of adding classes one by one , you can upload a full schema in JSON format at once. + def get_collection(self, name: str) -> Collection: + """Get a collection by name. - :param schema_json: Schema as a Python dict or the path to a JSON file, or the URL of a JSON file. + :param name: The name of the collection to get. """ client = self.conn - client.schema.create(schema_json) - - @staticmethod - def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) -> list[dict[str, Any]]: - """Convert dataframe to list of dicts. - - In scenario where Pandas isn't installed and we pass data as a list of dictionaries, importing - Pandas will fail, which is invalid. This function handles this scenario. - """ - with contextlib.suppress(ImportError): - import pandas + return client.collections.get(name) - if isinstance(data, pandas.DataFrame): - data = json.loads(data.to_json(orient="records")) - return cast(List[Dict[str, Any]], data) - - @retry( - reraise=True, - stop=stop_after_attempt(3), - retry=( - retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) - | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) - ), - ) - def get_schema(self, class_name: str | None = None): - """Get the schema from Weaviate. - - :param class_name: The class for which to return the schema. If NOT provided the whole schema is - returned, otherwise only the schema of this class is returned. By default None. - """ - client = self.get_client() - return client.schema.get(class_name) + def delete_collections( + self, collection_names: list[str] | str, if_error: str = "stop" + ) -> list[str] | None: + """Delete all or specific collections if collection_names are provided. - def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") -> list[str] | None: - """Delete all or specific classes if class_names are provided. - - :param class_names: list of class names to be deleted. - :param if_error: define the actions to be taken if there is an error while deleting a class, possible + :param collection_names: list of collection names to be deleted. + :param if_error: define the actions to be taken if there is an error while deleting a collection, possible options are `stop` and `continue` - :return: if `if_error=continue` return list of classes which we failed to delete. + :return: if `if_error=continue` return list of collections which we failed to delete. if `if_error=stop` returns None. """ - client = self.get_client() - class_names = [class_names] if class_names and isinstance(class_names, str) else class_names + client = self.get_conn() + collection_names = ( + [collection_names] if collection_names and isinstance(collection_names, str) else collection_names + ) - failed_class_list = [] - for class_name in class_names: + failed_collection_list = [] + for collection_name in collection_names: try: for attempt in Retrying( stop=stop_after_attempt(3), @@ -237,229 +215,77 @@ def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") - ), ): with attempt: - print(attempt) - client.schema.delete_class(class_name) + self.log.info(attempt) + client.collections.delete(collection_name) except Exception as e: if if_error == "continue": self.log.error(e) - failed_class_list.append(class_name) + failed_collection_list.append(collection_name) elif if_error == "stop": raise e if if_error == "continue": - return failed_class_list + return failed_collection_list return None - def delete_all_schema(self): - """Remove the entire schema from the Weaviate instance and all data associated with it.""" - client = self.get_client() - return client.schema.delete_all() - - def update_config(self, class_name: str, config: dict): - """Update a schema configuration for a specific class.""" - client = self.get_client() - client.schema.update_config(class_name=class_name, config=config) - - def create_or_replace_classes( - self, schema_json: dict[str, Any] | str, existing: ExitingSchemaOptions = "ignore" - ): - """ - Create or replace the classes in schema of Weaviate database. + @retry( + reraise=True, + stop=stop_after_attempt(3), + retry=( + retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) + | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) + ), + ) + def get_collection_configuraiton(self, collection_name: str) -> CollectionConfig | CollectionConfigSimple: + """Get the collection configuration from Weaviate. - :param schema_json: Json containing the schema. Format {"class_name": "class_dict"} - .. seealso:: `example of class_dict `_. - :param existing: Options to handle the case when the classes exist, possible options - 'replace', 'fail', 'ignore'. + :param collection_name: The collection for which to return the collection configuration. """ - existing_schema_options = ["replace", "fail", "ignore"] - if existing not in existing_schema_options: - raise ValueError(f"Param 'existing' should be one of the {existing_schema_options} values.") - if isinstance(schema_json, str): - schema_json = cast(dict, json.load(open(schema_json))) - set__exiting_classes = {class_object["class"] for class_object in self.get_schema()["classes"]} - set__to_be_added_classes = {key for key, _ in schema_json.items()} - intersection_classes = set__exiting_classes.intersection(set__to_be_added_classes) - classes_to_create = set() - if existing == "fail" and intersection_classes: - raise ValueError(f"Trying to create class {intersection_classes} but this class already exists.") - elif existing == "ignore": - classes_to_create = set__to_be_added_classes - set__exiting_classes - elif existing == "replace": - error_list = self.delete_classes(class_names=list(intersection_classes)) - if error_list: - raise ValueError(error_list) - classes_to_create = intersection_classes.union(set__to_be_added_classes) - classes_to_create_list = [schema_json[item] for item in sorted(list(classes_to_create))] - self.create_schema({"classes": classes_to_create_list}) - - def _compare_schema_subset(self, subset_object: Any, superset_object: Any) -> bool: - """ - Recursively check if requested subset_object is a subset of the superset_object. - - Example 1: - superset_object = {"a": {"b": [1, 2, 3], "c": "d"}} - subset_object = {"a": {"c": "d"}} - _compare_schema_subset(subset_object, superset_object) # will result in True + client = self.get_conn() + return client.collections.get(collection_name).config.get() - superset_object = {"a": {"b": [1, 2, 3], "c": "d"}} - subset_object = {"a": {"d": "e"}} - _compare_schema_subset(subset_object, superset_object) # will result in False - - :param subset_object: The object to be checked - :param superset_object: The object to check against - """ - # Direct equality check - if subset_object == superset_object: - return True - - # Type mismatch early return - if type(subset_object) != type(superset_object): - return False - - # Dictionary comparison - if isinstance(subset_object, dict): - for k, v in subset_object.items(): - if (k not in superset_object) or (not self._compare_schema_subset(v, superset_object[k])): - return False - return True - - # List or Tuple comparison - if isinstance(subset_object, (list, tuple)): - for sub, sup in zip(subset_object, superset_object): - if len(subset_object) > len(superset_object) or not self._compare_schema_subset(sub, sup): - return False - return True - - # Default case for non-matching types or unsupported types - return False + def update_collection_configuration(self, collection_name: str, **kwargs) -> None: + """Update the collection configuration.""" + collection = self.get_collection(collection_name) + collection.config.update(**kwargs) @staticmethod - def _convert_properties_to_dict(classes_objects, key_property: str = "name"): - """ - Convert list of class properties into dict by using a `key_property` as key. - - This is done to avoid class properties comparison as list of properties. - - Case 1: - A = [1, 2, 3] - B = [1, 2] - When comparing list we check for the length, but it's not suitable for subset check. + def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) -> list[dict[str, Any]]: + """Convert dataframe to list of dicts. - Case 2: - A = [1, 2, 3] - B = [1, 3, 2] - When we compare two lists, we compare item 1 of list A with item 1 of list B and - pass if the two are same, but there can be scenarios when the properties are not in same order. + In scenario where Pandas isn't installed and we pass data as a list of dictionaries, importing + Pandas will fail, which is invalid. This function handles this scenario. """ - for cls in classes_objects: - cls["properties"] = {p[key_property]: p for p in cls["properties"]} - return classes_objects - - def check_subset_of_schema(self, classes_objects: list) -> bool: - """Check if the class_objects is a subset of existing schema. - - Note - weaviate client's `contains()` don't handle the class properties mismatch, if you want to - compare `Class A` with `Class B` they must have exactly same properties. If `Class A` has fewer - numbers of properties than Class B, `contains()` will result in False. + with contextlib.suppress(ImportError): + import pandas - .. seealso:: `contains `_. - """ - # When the class properties are not in same order or not the same length. We convert them to dicts - # with property `name` as the key. This way we ensure, the subset is checked. - - classes_objects = self._convert_properties_to_dict(classes_objects) - exiting_classes_list = self._convert_properties_to_dict(self.get_schema()["classes"]) - - exiting_classes = {cls["class"]: cls for cls in exiting_classes_list} - exiting_classes_set = set(exiting_classes.keys()) - input_classes_set = {cls["class"] for cls in classes_objects} - if not input_classes_set.issubset(exiting_classes_set): - return False - for cls in classes_objects: - if not self._compare_schema_subset(cls, exiting_classes[cls["class"]]): - return False - return True + if isinstance(data, pandas.DataFrame): + data = json.loads(data.to_json(orient="records")) + return cast(List[Dict[str, Any]], data) def batch_data( self, - class_name: str, + collection_name: str, data: list[dict[str, Any]] | pd.DataFrame | None, - batch_config_params: dict[str, Any] | None = None, vector_col: str = "Vector", uuid_col: str = "id", retry_attempts_per_object: int = 5, - tenant: str | None = None, - ) -> list: + references: ReferenceInputs | None = None, + ) -> None: """ Add multiple objects or object references at once into weaviate. - :param class_name: The name of the class that objects belongs to. + :param collection_name: The name of the collection that objects belongs to. :param data: list or dataframe of objects we want to add. - :param batch_config_params: dict of batch configuration option. - .. seealso:: `batch_config_params options `__ :param vector_col: name of the column containing the vector. :param uuid_col: Name of the column containing the UUID. :param retry_attempts_per_object: number of time to try in case of failure before giving up. - :param tenant: The tenant to which the object will be added. + :param references: The references of the object to be added as a dictionary. Use `wvc.Reference.to` to create the correct values in the dict. """ converted_data = self._convert_dataframe_to_list(data) - total_results = 0 - error_results = 0 - insertion_errors: list = [] - - def _process_batch_errors( - results: list, - verbose: bool = True, - ) -> None: - """ - Process the results from insert or delete batch operation and collects any errors. - - :param results: Results from the batch operation. - :param verbose: Flag to enable verbose logging. - """ - nonlocal total_results - nonlocal error_results - total_batch_results = len(results) - error_batch_results = 0 - for item in results: - if "errors" in item["result"]: - error_batch_results = error_batch_results + 1 - item_error = {"uuid": item["id"], "errors": item["result"]["errors"]} - if verbose: - self.log.info( - "Error occurred in batch process for %s with error %s", - item["id"], - item["result"]["errors"], - ) - insertion_errors.append(item_error) - if verbose: - total_results = total_results + (total_batch_results - error_batch_results) - error_results = error_results + error_batch_results - self.log.info( - "Total Objects %s / Objects %s successfully inserted and Objects %s had errors.", - len(converted_data), - total_results, - error_results, - ) - - client = self.conn - if not batch_config_params: - batch_config_params = {} - - # configuration for context manager for __exit__ method to callback on errors for weaviate - # batch ingestion. - if not batch_config_params.get("callback"): - batch_config_params["callback"] = _process_batch_errors - - if not batch_config_params.get("timeout_retries"): - batch_config_params["timeout_retries"] = 5 - - if not batch_config_params.get("connection_error_retries"): - batch_config_params["connection_error_retries"] = 5 - - client.batch.configure(**batch_config_params) - with client.batch as batch: + collection = self.get_collection(collection_name) + with collection.batch.dynamic() as batch: # Batch import all data for data_obj in converted_data: for attempt in Retrying( @@ -477,44 +303,41 @@ def _process_batch_errors( attempt.retry_state.attempt_number, uuid, ) - batch.add_data_object( - data_object=data_obj, - class_name=class_name, - vector=vector, + batch.add_object( + properties=data_obj, + references=references, uuid=uuid, - tenant=tenant, + vector=vector, ) self.log.debug("Inserted object with uuid: %s into batch", uuid) - return insertion_errors def query_with_vector( self, embeddings: list[float], - class_name: str, - *properties: list[str], + collection_name: str, + properties: list[str], certainty: float = 0.7, limit: int = 1, - ) -> dict[str, dict[Any, Any]]: + **kwargs, + ) -> QuerySearchReturnType: """ Query weaviate database with near vectors. This method uses a vector search using a Get query. we are using a with_near_vector to provide - weaviate with a query with vector itself. This is needed for query a Weaviate class with a custom, + weaviate with a query with vector itself. This is needed for query a Weaviate class with a custom, external vectorizer. Weaviate then converts this into a vector through the inference API (OpenAI in this particular example) and uses that vector as the basis for a vector search. """ client = self.conn - results: dict[str, dict[Any, Any]] = ( - client.query.get(class_name, properties[0]) - .with_near_vector({"vector": embeddings, "certainty": certainty}) - .with_limit(limit) - .do() + collection = client.collections.get(collection_name) + response = collection.query.near_vector( + near_vector=embeddings, certainty=certainty, limit=limit, return_properties=properties, **kwargs ) - return results + return response - def query_without_vector( - self, search_text: str, class_name: str, *properties: list[str], limit: int = 1 - ) -> dict[str, dict[Any, Any]]: + def query_with_text( + self, search_text: str, collection_name: str, properties: list[str], limit: int = 1, **kwargs + ) -> QuerySearchReturnType: """ Query using near text. @@ -523,83 +346,66 @@ def query_without_vector( API (OpenAI in this particular example) and uses that vector as the basis for a vector search. """ client = self.conn - results: dict[str, dict[Any, Any]] = ( - client.query.get(class_name, properties[0]) - .with_near_text({"concepts": [search_text]}) - .with_limit(limit) - .do() + collection = client.collections.get(collection_name) + response = collection.query.near_text( + query=search_text, limit=limit, return_properties=properties, **kwargs ) - return results + return response - def create_object( - self, data_object: dict | str, class_name: str, **kwargs - ) -> str | dict[str, Any] | None: + def create_object(self, data_object: dict, collection_name: str, **kwargs) -> UUID | None: """Create a new object. :param data_object: Object to be added. If type is str it should be either a URL or a file. - :param class_name: Class name associated with the object given. + :param collection_name: Collection name associated with the object given. :param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() """ - client = self.conn + collection = self.get_collection(collection_name) # generate deterministic uuid if not provided uuid = kwargs.pop("uuid", generate_uuid5(data_object)) try: - return client.data_object.create(data_object, class_name, uuid=uuid, **kwargs) + return collection.data.insert(properties=data_object, uuid=uuid, **kwargs) except ObjectAlreadyExistsException: self.log.warning("Object with the UUID %s already exists", uuid) return None def get_or_create_object( self, - data_object: dict | str | None = None, - class_name: str | None = None, + collection_name, + data_object: dict, vector: Sequence | None = None, - consistency_level: ConsistencyLevel | None = None, - tenant: str | None = None, **kwargs, - ) -> str | dict[str, Any] | None: + ) -> QueryReturnType | UUID | None: """Get or Create a new object. - Returns the object if already exists + Returns the object if already exists, return UUID if not - :param data_object: Object to be added. If type is str it should be either a URL or a file. This is required - to create a new object. - :param class_name: Class name associated with the object given. This is required to create a new object. + :param collection_name: Collection name associated with the object given.. + :param data_object: Object to be added. :param vector: Vector associated with the object given. This argument is only used when creating object. - :param consistency_level: Consistency level to be used. Applies to both create and get operations. - :param tenant: Tenant to be used. Applies to both create and get operations. - :param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() and - weaviate_client.data_object.get() + :param kwargs: parameters to be passed to collection.data.fetch_object_by_id() or + collection.data.fetch_objects() """ - obj = self.get_object( - class_name=class_name, consistency_level=consistency_level, tenant=tenant, **kwargs - ) + obj = self.get_object(collection_name=collection_name, **kwargs) if not obj: - if not (data_object and class_name): - raise ValueError("data_object and class_name are required to create a new object") + if not (data_object and collection_name): + raise ValueError("data_object and collection are required to create a new object") uuid = kwargs.pop("uuid", generate_uuid5(data_object)) return self.create_object( - data_object, - class_name, - vector=vector, - uuid=uuid, - consistency_level=consistency_level, - tenant=tenant, + data_object=data_object, collection_name=collection_name, uuid=uuid, vector=vector, **kwargs ) return obj - def get_object(self, **kwargs) -> dict[str, Any] | None: + def get_object(self, collection_name: str, **kwargs) -> QueryReturnType: """Get objects or an object from weaviate. - :param kwargs: parameters to be passed to weaviate_client.data_object.get() or - weaviate_client.data_object.get_by_id() + :param kwargs: parameters to be passed to collection.query.fetch_objects() """ - client = self.conn - return client.data_object.get(**kwargs) + collection = self.get_collection(collection_name) + return collection.query.fetch_objects(**kwargs) def get_all_objects( - self, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs - ) -> list[dict[str, Any]] | pd.DataFrame: + self, collection_name: str, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs + ) -> list[Object] | pd.DataFrame: """Get all objects from weaviate. if after is provided, it will be used as the starting point for the listing. @@ -608,80 +414,92 @@ def get_all_objects( :param as_dataframe: if True, returns a pandas dataframe :param kwargs: parameters to be passed to weaviate_client.data_object.get() """ - all_objects = [] + all_objects: list[Object] = [] after = kwargs.pop("after", after) while True: - results = self.get_object(after=after, **kwargs) or {} - if not results.get("objects"): + results = self.get_object(collection_name=collection_name, after=after, **kwargs) + if not results or not results.objects: break - all_objects.extend(results["objects"]) - after = results["objects"][-1]["id"] + all_objects.extend(results.objects) + after = results.objects[-1].uuid if as_dataframe: import pandas - return pandas.DataFrame(all_objects) + # '_WeaviateUUIDInt' object has no attribute 'is_safe' which causes error + return pandas.DataFrame( + [ + { + "collection": obj.collection, + "metadata": obj.metadata, + "properties": obj.properties, + "references": obj.references, + "uuid": str(obj.uuid), + "vector": obj.vector, + } + for obj in all_objects + ] + ) return all_objects - def delete_object(self, uuid: UUID | str, **kwargs) -> None: + def delete_object(self, collection_name: str, uuid: UUID | str) -> bool: """Delete an object from weaviate. + :param collection_name: Collection name associated with the object given. :param uuid: uuid of the object to be deleted - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.delete() """ - client = self.conn - client.data_object.delete(uuid, **kwargs) + collection = self.get_collection(collection_name) + return collection.data.delete_by_id(uuid=uuid) - def update_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None: + def update_object( + self, collection_name: str, uuid: UUID | str, properties: Properties | None = None, **kwargs + ) -> None: """Update an object in weaviate. - :param data_object: The object states the fields that should be updated. Fields not specified in the - 'data_object' remain unchanged. Fields that are None will not be changed. - If type is str it should be either an URL or a file. - :param class_name: Class name associated with the object given. + :param collection_name: Collection name associated with the object given. :param uuid: uuid of the object to be updated - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.update() + :param properties: The properties of the object. + :param kwargs: Optional parameters to be passed to collection.data.update() """ - client = self.conn - client.data_object.update(data_object, class_name, uuid, **kwargs) + collection = self.get_collection(collection_name) + collection.data.update(uuid=uuid, properties=properties, **kwargs) - def replace_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None: + def replace_object( + self, + collection_name: str, + uuid: UUID | str, + properties: Properties, + references: ReferenceInputs | None = None, + **kwargs, + ) -> None: """Replace an object in weaviate. - :param data_object: The object states the fields that should be updated. Fields not specified in the - 'data_object' will be set to None. If type is str it should be either an URL or a file. - :param class_name: Class name associated with the object given. - :param uuid: uuid of the object to be replaced - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.replace() + :param collection_name: Collection name associated with the object given. + :param uuid: uuid of the object to be updated + :param properties: The properties of the object. + :param references: Any references to other objects in Weaviate. + :param kwargs: Optional parameters to be passed to collection.data.replace() """ - client = self.conn - client.data_object.replace(data_object, class_name, uuid, **kwargs) - - def validate_object(self, data_object: dict | str, class_name: str, **kwargs): - """Validate an object in weaviate. + collection = self.get_collection(collection_name) + collection.data.replace(uuid=uuid, properties=properties, references=references, **kwargs) - :param data_object: The object to be validated. If type is str it should be either an URL or a file. - :param class_name: Class name associated with the object given. - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.validate() - """ - client = self.conn - client.data_object.validate(data_object, class_name, **kwargs) - - def object_exists(self, uuid: str | UUID, **kwargs) -> bool: + def object_exists(self, collection_name: str, uuid: str | UUID) -> bool: """Check if an object exists in weaviate. + :param collection_name: Collection name associated with the object given. :param uuid: The UUID of the object that may or may not exist within Weaviate. - :param kwargs: Optional parameters to be passed to weaviate_client.data_object.exists() """ - client = self.conn - return client.data_object.exists(uuid, **kwargs) + collection = self.get_collection(collection_name) + return collection.data.exists(uuid=uuid) - def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per_object: int = 5): + def _delete_objects( + self, uuids: list[UUID], collection_name: str, retry_attempts_per_object: int = 5 + ) -> None: """Delete multiple objects. Helper function for `create_or_replace_objects()` to delete multiple objects. :param uuids: Collection of uuids. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param collection_name: Name of the collection in Weaviate schema where data is to be ingested. :param retry_attempts_per_object: number of times to try in case of failure before giving up. """ for uuid in uuids: @@ -694,7 +512,7 @@ def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per ): with attempt: try: - self.delete_object(uuid=uuid, class_name=class_name) + self.delete_object(uuid=uuid, collection_name=collection_name) self.log.debug("Deleted object with uuid %s", uuid) except weaviate.exceptions.UnexpectedStatusCodeException as e: if e.status_code == 404: @@ -708,7 +526,7 @@ def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per def _generate_uuids( self, df: pd.DataFrame, - class_name: str, + collection_name: str, unique_columns: list[str], vector_column: str | None = None, uuid_column: str | None = None, @@ -720,7 +538,7 @@ def _generate_uuids( The function can potentially ingest the same data multiple times with different UUIDs. :param df: A dataframe with data to generate a UUID from. - :param class_name: The name of the class use as part of the uuid namespace. + :param collection_name: The name of the collection use as part of the uuid namespace. :param uuid_column: Name of the column to create. Default is 'id'. :param unique_columns: A list of columns to use for UUID generation. By default, all columns except vector_column will be used. @@ -751,7 +569,7 @@ def _generate_uuids( df[uuid_column] = ( df[unique_columns] .drop(columns=[vector_column], inplace=False, errors="ignore") - .apply(lambda row: generate_uuid5(identifier=row.to_dict(), namespace=class_name), axis=1) + .apply(lambda row: generate_uuid5(identifier=row.to_dict(), namespace=collection_name), axis=1) ) return df, uuid_column @@ -761,7 +579,7 @@ def _get_documents_to_uuid_map( data: pd.DataFrame, document_column: str, uuid_column: str, - class_name: str, + collection_name: str, offset: int = 0, limit: int = 2000, ) -> dict[str, set]: @@ -769,7 +587,7 @@ def _get_documents_to_uuid_map( :param data: A single pandas DataFrame. :param document_column: The name of the property to query. - :param class_name: The name of the class to query. + :param collection_name: The name of the collection to query. :param uuid_column: The name of the column containing the UUID. :param offset: pagination parameter to indicate the which object to start fetching data. :param limit: pagination param to indicate the number of records to fetch from start object. @@ -777,37 +595,40 @@ def _get_documents_to_uuid_map( documents_to_uuid: dict = {} document_keys = set(data[document_column]) while True: - data_objects = ( - self.conn.query.get(properties=[document_column], class_name=class_name) - .with_additional([uuid_column]) - .with_where( - { - "operator": "Or", - "operands": [ - {"valueText": key, "path": document_column, "operator": "Equal"} - for key in document_keys - ], - } - ) - .with_offset(offset) - .with_limit(limit) - .do()["data"]["Get"][class_name] + collection = self.get_collection(collection_name) + data_objects = collection.query.fetch_objects( + filters=Filter.any_of( + [Filter.by_property(document_column).equal(key) for key in document_keys] + ), + return_properties=[document_column], + limit=limit, + offset=offset, ) - if len(data_objects) == 0: + if len(data_objects.objects) == 0: break offset = offset + limit + + if uuid_column in data_objects.objects[0].properties: + data_object_properties = [obj.properties for obj in data_objects.objects] + else: + data_object_properties = [] + for obj in data_objects.objects: + row = dict(obj.properties) + row[uuid_column] = str(obj.uuid) + data_object_properties.append(row) + documents_to_uuid.update( self._prepare_document_to_uuid_map( - data=data_objects, + data=data_object_properties, group_key=document_column, - get_value=lambda x: x["_additional"][uuid_column], + get_value=lambda x: x[uuid_column], ) ) return documents_to_uuid @staticmethod def _prepare_document_to_uuid_map( - data: list[dict], group_key: str, get_value: Callable[[dict], str] + data: Sequence[Mapping], group_key: str, get_value: Callable[[Mapping], str] ) -> dict[str, set]: """Prepare the map of grouped_key to set.""" grouped_key_to_set: dict = {} @@ -821,21 +642,24 @@ def _prepare_document_to_uuid_map( return grouped_key_to_set def _get_segregated_documents( - self, data: pd.DataFrame, document_column: str, class_name: str, uuid_column: str + self, data: pd.DataFrame, document_column: str, collection_name: str, uuid_column: str ) -> tuple[dict[str, set], set, set, set]: """ Segregate documents into changed, unchanged and new document, when compared to Weaviate db. :param data: A single pandas DataFrame. :param document_column: The name of the property to query. - :param class_name: The name of the class to query. + :param collection_name: The name of the collection to query. :param uuid_column: The name of the column containing the UUID. """ changed_documents = set() unchanged_docs = set() new_documents = set() existing_documents_to_uuid = self._get_documents_to_uuid_map( - data=data, uuid_column=uuid_column, document_column=document_column, class_name=class_name + data=data, + uuid_column=uuid_column, + document_column=document_column, + collection_name=collection_name, ) input_documents_to_uuid = self._prepare_document_to_uuid_map( @@ -843,16 +667,15 @@ def _get_segregated_documents( group_key=document_column, get_value=lambda x: x[uuid_column], ) - # segregate documents into changed, unchanged and non-existing documents. for doc_url, doc_set in input_documents_to_uuid.items(): if doc_url in existing_documents_to_uuid: if existing_documents_to_uuid[doc_url] != doc_set: - changed_documents.add(doc_url) + changed_documents.add(str(doc_url)) else: - unchanged_docs.add(doc_url) + unchanged_docs.add(str(doc_url)) else: - new_documents.add(doc_url) + new_documents.add(str(doc_url)) return existing_documents_to_uuid, changed_documents, unchanged_docs, new_documents @@ -860,81 +683,55 @@ def _delete_all_documents_objects( self, document_keys: list[str], document_column: str, - class_name: str, + collection_name: str, total_objects_count: int = 1, - batch_delete_error: list | None = None, - tenant: str | None = None, - batch_config_params: dict[str, Any] | None = None, + batch_delete_error: Sequence | None = None, verbose: bool = False, - ): + ) -> Sequence[dict[str, UUID | str]]: """Delete all object that belong to list of documents. :param document_keys: list of unique documents identifiers. :param document_column: Column in DataFrame that identifying source document. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param collection_name: Name of the collection in Weaviate schema where data is to be ingested. :param total_objects_count: total number of objects to delete, needed as max limit on one delete query is 10,000, if we have more objects to delete we need to run query multiple times. :param batch_delete_error: list to hold errors while inserting. - :param tenant: The tenant to which the object will be added. - :param batch_config_params: Additional parameters for Weaviate batch configuration. :param verbose: Flag to enable verbose output during the ingestion process. """ batch_delete_error = batch_delete_error or [] - if not batch_config_params: - batch_config_params = {} - # This limit is imposed by Weavaite database MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS = 10000 - self.conn.batch.configure(**batch_config_params) - with self.conn.batch as batch: - # ConsistencyLevel.ALL is essential here to guarantee complete deletion of objects - # across all nodes. Maintaining this level ensures data integrity, preventing - # irrelevant objects from providing misleading context for LLM models. - batch.consistency_level = ConsistencyLevel.ALL - while total_objects_count > 0: - document_objects = batch.delete_objects( - class_name=class_name, - where={ - "operator": "Or", - "operands": [ - { - "path": [document_column], - "operator": "Equal", - "valueText": key, - } - for key in document_keys - ], - }, - output="verbose", - dry_run=False, - tenant=tenant, - ) - total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS - matched_objects = document_objects["results"]["matches"] - batch_delete_error = [ - {"uuid": obj["id"]} - for obj in document_objects["results"]["objects"] - if "error" in obj["status"] - ] - if verbose: - self.log.info("Deleted %s Objects", matched_objects) + collection = self.get_collection(collection_name) + delete_many_return = collection.data.delete_many( + where=Filter.any_of([Filter.by_property(document_column).equal(key) for key in document_keys]), + verbose=verbose, + dry_run=False, + ) + total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS + matched_objects = delete_many_return.matches + if delete_many_return.failed > 0 and delete_many_return.objects: + batch_delete_error = [ + {"uuid": obj.uuid, "error": obj.error} + for obj in delete_many_return.objects + if obj.error is not None + ] + if verbose: + self.log.info("Deleted %s Objects", matched_objects) return batch_delete_error def create_or_replace_document_objects( self, data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame], - class_name: str, + collection_name: str, document_column: str, existing: str = "skip", uuid_column: str | None = None, vector_column: str = "Vector", - batch_config_params: dict | None = None, - tenant: str | None = None, verbose: bool = False, - ): + ) -> Sequence[dict[str, UUID | str] | None]: """ create or replace objects belonging to documents. @@ -956,21 +753,19 @@ def create_or_replace_document_objects( error: raise an error if an object belonging to a existing document is tried to be created. :param data: A single pandas DataFrame or a list of dicts to be ingested. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param colleciton_name: Name of the collection in Weaviate schema where data is to be ingested. :param existing: Strategy for handling existing data: 'skip', or 'replace'. Default is 'skip'. :param document_column: Column in DataFrame that identifying source document. :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated. :param vector_column: Column with embedding vectors for pre-embedded data. - :param batch_config_params: Additional parameters for Weaviate batch configuration. - :param tenant: The tenant to which the object will be added. :param verbose: Flag to enable verbose output during the ingestion process. :return: list of UUID which failed to create """ - import pandas as pd - if existing not in ["skip", "replace", "error"]: raise ValueError("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'error'.") + import pandas as pd + if len(data) == 0: return [] @@ -994,7 +789,7 @@ def create_or_replace_document_objects( uuid_column, ) = self._generate_uuids( df=data, - class_name=class_name, + collection_name=collection_name, unique_columns=unique_columns, vector_column=vector_column, uuid_column=uuid_column, @@ -1006,7 +801,7 @@ def create_or_replace_document_objects( if verbose: self.log.info("%s objects remain after deduplication.", data.shape[0]) - batch_delete_error: list = [] + batch_delete_error: Sequence[dict[str, UUID | str]] = [] ( documents_to_uuid_map, changed_documents, @@ -1016,7 +811,7 @@ def create_or_replace_document_objects( data=data, document_column=document_column, uuid_column=uuid_column, - class_name=class_name, + collection_name=collection_name, ) if verbose: self.log.info( @@ -1029,7 +824,6 @@ def create_or_replace_document_objects( self.log.info( "Changed document: %s has %s objects.", document, len(documents_to_uuid_map[document]) ) - self.log.info("Non-existing document: %s", ", ".join(new_documents)) if existing == "error" and len(changed_documents): @@ -1051,43 +845,39 @@ def create_or_replace_document_objects( total_objects_count, changed_documents, ) - batch_delete_error = self._delete_all_documents_objects( - document_keys=list(changed_documents), - total_objects_count=total_objects_count, - document_column=document_column, - class_name=class_name, - batch_delete_error=batch_delete_error, - tenant=tenant, - batch_config_params=batch_config_params, - verbose=verbose, - ) + if list(changed_documents): + batch_delete_error = self._delete_all_documents_objects( + document_keys=list(changed_documents), + document_column=document_column, + collection_name=collection_name, + total_objects_count=total_objects_count, + batch_delete_error=batch_delete_error, + verbose=verbose, + ) data = data[data[document_column].isin(new_documents.union(changed_documents))] self.log.info("Batch inserting %s objects for non-existing and changed documents.", data.shape[0]) - insertion_errors: list = [] if data.shape[0]: - insertion_errors = self.batch_data( - class_name=class_name, + self.batch_data( + collection_name=collection_name, data=data, - batch_config_params=batch_config_params, vector_col=vector_column, uuid_col=uuid_column, - tenant=tenant, ) - if insertion_errors or batch_delete_error: - if insertion_errors: - self.log.info("Failed to insert %s objects.", len(insertion_errors)) + if batch_delete_error: if batch_delete_error: - self.log.info("Failed to delete %s objects.", len(insertion_errors)) + self.log.info("Failed to delete %s objects.", len(batch_delete_error)) # Rollback object that were not created properly self._delete_objects( - [item["uuid"] for item in insertion_errors + batch_delete_error], class_name=class_name + [item["uuid"] for item in batch_delete_error], + collection_name=collection_name, ) if verbose: + collection = self.get_collection(collection_name) self.log.info( - "Total objects in class %s : %s ", - class_name, - self.conn.query.aggregate(class_name).with_meta_count().do(), + "Total objects in collection %s : %s ", + collection_name, + collection.aggregate.over_all(total_count=True), ) - return insertion_errors, batch_delete_error + return batch_delete_error diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py index 8a26ee5bfbed8..e2f9ac0d701f1 100644 --- a/airflow/providers/weaviate/operators/weaviate.py +++ b/airflow/providers/weaviate/operators/weaviate.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: import pandas as pd + from weaviate.types import UUID from airflow.utils.context import Context @@ -43,11 +44,10 @@ class WeaviateIngestOperator(BaseOperator): custom vectors and store them in the Weaviate class. :param conn_id: The Weaviate connection. - :param class_name: The Weaviate class to be used for storing the data objects into. + :param collection: The Weaviate collection to be used for storing the data objects into. :param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate embeddings on (or provides custom vectors) and store them in the Weaviate class. :param vector_col: key/column name in which the vectors are stored. - :param batch_params: Additional parameters for Weaviate batch configuration. :param hook_params: Optional config params to be passed to the underlying hook. Should match the desired hook constructor params. :param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on @@ -59,25 +59,23 @@ class WeaviateIngestOperator(BaseOperator): def __init__( self, conn_id: str, - class_name: str, + collection_name: str, input_data: list[dict[str, Any]] | pd.DataFrame | None = None, vector_col: str = "Vector", uuid_column: str = "id", tenant: str | None = None, - batch_params: dict | None = None, hook_params: dict | None = None, input_json: list[dict[str, Any]] | pd.DataFrame | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.class_name = class_name + self.collection_name = collection_name self.conn_id = conn_id self.vector_col = vector_col self.input_json = input_json self.uuid_column = uuid_column self.tenant = tenant self.input_data = input_data - self.batch_params = batch_params or {} self.hook_params = hook_params or {} if (self.input_data is None) and (input_json is not None): @@ -96,18 +94,14 @@ def hook(self) -> WeaviateHook: """Return an instance of the WeaviateHook.""" return WeaviateHook(conn_id=self.conn_id, **self.hook_params) - def execute(self, context: Context) -> list: + def execute(self, context: Context) -> None: self.log.debug("Input data: %s", self.input_data) - insertion_errors: list = [] self.hook.batch_data( - class_name=self.class_name, + collection_name=self.collection_name, data=self.input_data, - batch_config_params=self.batch_params, vector_col=self.vector_col, uuid_col=self.uuid_column, - tenant=self.tenant, ) - return insertion_errors class WeaviateDocumentIngestOperator(BaseOperator): @@ -132,12 +126,11 @@ class WeaviateDocumentIngestOperator(BaseOperator): error: raise an error if an object belonging to a existing document is tried to be created. :param data: A single pandas DataFrame or a list of dicts to be ingested. - :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param collection_name: Name of the collection in Weaviate schema where data is to be ingested. :param existing: Strategy for handling existing data: 'skip', or 'replace'. Default is 'skip'. :param document_column: Column in DataFrame that identifying source document. :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated. :param vector_column: Column with embedding vectors for pre-embedded data. - :param batch_config_params: Additional parameters for Weaviate batch configuration. :param tenant: The tenant to which the object will be added. :param verbose: Flag to enable verbose output during the ingestion process. :param hook_params: Optional config params to be passed to the underlying hook. @@ -150,12 +143,11 @@ def __init__( self, conn_id: str, input_data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame], - class_name: str, + collection_name: str, document_column: str, existing: str = "skip", uuid_column: str = "id", vector_col: str = "Vector", - batch_config_params: dict | None = None, tenant: str | None = None, verbose: bool = False, hook_params: dict | None = None, @@ -164,12 +156,11 @@ def __init__( super().__init__(**kwargs) self.conn_id = conn_id self.input_data = input_data - self.class_name = class_name + self.collection_name = collection_name self.document_column = document_column self.existing = existing self.uuid_column = uuid_column self.vector_col = vector_col - self.batch_config_params = batch_config_params self.tenant = tenant self.verbose = verbose self.hook_params = hook_params or {} @@ -179,22 +170,20 @@ def hook(self) -> WeaviateHook: """Return an instance of the WeaviateHook.""" return WeaviateHook(conn_id=self.conn_id, **self.hook_params) - def execute(self, context: Context) -> list: + def execute(self, context: Context) -> Sequence[dict[str, UUID | str] | None]: """ Create or replace objects belonging to documents. :return: List of UUID which failed to create """ self.log.debug("Total input objects : %s", len(self.input_data)) - insertion_errors = self.hook.create_or_replace_document_objects( + batch_delete_error = self.hook.create_or_replace_document_objects( data=self.input_data, - class_name=self.class_name, + collection_name=self.collection_name, document_column=self.document_column, existing=self.existing, uuid_column=self.uuid_column, vector_column=self.vector_col, - batch_config_params=self.batch_config_params, - tenant=self.tenant, verbose=self.verbose, ) - return insertion_errors + return batch_delete_error diff --git a/airflow/providers/weaviate/provider.yaml b/airflow/providers/weaviate/provider.yaml index b060c7407b704..47e2fbd386e3d 100644 --- a/airflow/providers/weaviate/provider.yaml +++ b/airflow/providers/weaviate/provider.yaml @@ -28,7 +28,7 @@ source-date-epoch: 1718605569 # note that those versions are maintained by release manager - do not update them manually versions: - - 1.4.2 + - 2.0.0 - 1.4.1 - 1.4.0 - 1.3.4 @@ -50,7 +50,7 @@ integrations: dependencies: - apache-airflow>=2.7.0 - httpx>=0.25.0 - - weaviate-client>=3.24.2 + - weaviate-client>=4.4.0 # In pandas 2.2 minimal version of the sqlalchemy is 2.0 # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies # However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723 diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index 081fe14d92acb..428d058d3e96a 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -31,7 +31,7 @@ Configuring the Connection -------------------------- Host (required) - Host URL to connect to the Weaviate cluster. + The host to use for the Weaviate cluster REST and GraphQL API calls. DO NOT include the schema (i.e., http or https). OIDC Username (optional) Username for the OIDC user when OIDC option is to be used for authentication. @@ -39,6 +39,9 @@ OIDC Username (optional) OIDC Password (optional) Password for the OIDC user when OIDC option is to be used for authentication. +Port (option) + The port to use for the Weaviate cluster REST and GraphQL API calls. + Extra (optional) Specify the extra parameters (as json dictionary) that can be used in the connection. All parameters are optional. @@ -48,11 +51,24 @@ Extra (optional) * If you'd like to use Vectorizers for your class, configure the API keys to use the corresponding embedding API. The extras accepts a key ``additional_headers`` containing the dictionary of API keys for the embedding API authentication. They are mentioned in a section here: - `addtional_headers `__ + `Third party API keys `__ Weaviate API Token (optional) Specify your Weaviate API Key to connect when API Key option is to be used for authentication. +Use https (optional) + Whether to use https for the Weaviate cluster REST and GraphQL API calls. + +gRPC host (optional) + The host to use for the Weaviate cluster gRPC API. + +gRPC port (optional) + The port to use for the Weaviate cluster gRPC API. + +Use a secure channel for the underlying gRPC API (optional) + Whether to use a secure channel for the the Weaviate cluster gRPC API. + + Supported Authentication Methods -------------------------------- * API Key Authentication: This method uses the Weaviate API Key to authenticate the connection. You can either have the diff --git a/docs/apache-airflow-providers-weaviate/index.rst b/docs/apache-airflow-providers-weaviate/index.rst index 52b662b262fea..36d0a5ed4e523 100644 --- a/docs/apache-airflow-providers-weaviate/index.rst +++ b/docs/apache-airflow-providers-weaviate/index.rst @@ -104,7 +104,7 @@ PIP package Version required =================== ========================================= ``apache-airflow`` ``>=2.7.0`` ``httpx`` ``>=0.25.0`` -``weaviate-client`` ``>=3.24.2`` +``weaviate-client`` ``>=4.4.0`` ``pandas`` ``>=2.1.2,<2.2; python_version >= "3.9"`` ``pandas`` ``>=1.5.3,<2.2; python_version < "3.9"`` =================== ========================================= diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index a794cfc25ef2f..f4760e5a0aeaa 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1319,7 +1319,7 @@ "httpx>=0.25.0", "pandas>=1.5.3,<2.2;python_version<\"3.9\"", "pandas>=2.1.2,<2.2;python_version>=\"3.9\"", - "weaviate-client>=3.24.2" + "weaviate-client>=4.4.0" ], "devel-deps": [], "plugins": [], diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 650f938dba6b6..48abfb3ffee8a 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -from contextlib import ExitStack from unittest import mock from unittest.mock import MagicMock, Mock @@ -48,28 +47,48 @@ def weaviate_hook(): @pytest.fixture def mock_auth_api_key(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.api_key") as m: yield m @pytest.fixture def mock_auth_bearer_token(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.bearer_token") as m: yield m @pytest.fixture def mock_auth_client_credentials(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.client_credentials") as m: yield m @pytest.fixture def mock_auth_client_password(): - with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m: + with mock.patch("airflow.providers.weaviate.hooks.weaviate.Auth.client_password") as m: yield m +class MockFetchObjectReturn: + def __init__(self, *, objects): + self.objects = objects + + +class MockObject: + def __init__(self, *, properties: dict, uuid: str) -> None: + self.properties = properties + self.uuid = uuid + self.collection = "collection" + self.metadata = "metadata" + self.references = "references" + self.vector = "vector" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MockObject): + return False + return self.properties == other.properties and self.uuid == other.uuid + + class TestWeaviateHook: """ Test the WeaviateHook Hook. @@ -86,103 +105,148 @@ def setup_method(self, monkeypatch): self.scope = "scope1 scope2" self.client_password = "client_password" self.client_bearer_token = "client_bearer_token" - self.host = "http://localhost:8080" + self.host = "localhost" + self.port = 8000 + self.grpc_host = "localhost" + self.grpc_port = 50051 conns = ( Connection( conn_id=self.weaviate_api_key1, host=self.host, + port=self.port, conn_type="weaviate", - extra={"api_key": self.api_key}, + extra={"api_key": self.api_key, "grpc_host": self.grpc_host, "grpc_port": self.grpc_port}, ), Connection( conn_id=self.weaviate_api_key2, host=self.host, + port=self.port, conn_type="weaviate", - extra={"token": self.api_key}, + extra={"token": self.api_key, "grpc_host": self.grpc_host, "grpc_port": self.grpc_port}, ), Connection( conn_id=self.weaviate_client_credentials, host=self.host, + port=self.port, conn_type="weaviate", - extra={"client_secret": self.client_secret, "scope": self.scope}, + extra={ + "client_secret": self.client_secret, + "scope": self.scope, + "grpc_host": self.grpc_host, + "grpc_port": self.grpc_port, + }, ), Connection( conn_id=self.client_password, host=self.host, + port=self.port, conn_type="weaviate", login="login", password="password", + extra={"grpc_host": self.grpc_host, "grpc_port": self.grpc_port}, ), Connection( conn_id=self.client_bearer_token, host=self.host, + port=self.port, conn_type="weaviate", extra={ "access_token": self.client_bearer_token, "expires_in": 30, "refresh_token": "refresh_token", + "grpc_host": self.grpc_host, + "grpc_port": self.grpc_port, }, ), ) for conn in conns: monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri()) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_api_key_in_extra(self, mock_client, mock_auth_api_key): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_api_key_in_extra(self, mock_connect_to_custom, mock_auth_api_key): hook = WeaviateHook(conn_id=self.weaviate_api_key1) hook.get_conn() - mock_auth_api_key.assert_called_once_with(self.api_key) - mock_client.assert_called_once_with( - url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + mock_auth_api_key.assert_called_once_with(api_key=self.api_key) + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_api_key(api_key=self.api_key), + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_token_in_extra(self, mock_client, mock_auth_api_key): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_token_in_extra(self, mock_connect_to_custom, mock_auth_api_key): # when token is passed in extra hook = WeaviateHook(conn_id=self.weaviate_api_key2) hook.get_conn() - mock_auth_api_key.assert_called_once_with(self.api_key) - mock_client.assert_called_once_with( - url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + mock_auth_api_key.assert_called_once_with(api_key=self.api_key) + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_api_key(api_key=self.api_key), + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_access_token_in_extra(self, mock_connect_to_custom, mock_auth_bearer_token): hook = WeaviateHook(conn_id=self.client_bearer_token) hook.get_conn() mock_auth_bearer_token.assert_called_once_with( - self.client_bearer_token, expires_in=30, refresh_token="refresh_token" + access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token" ) - mock_client.assert_called_once_with( - url=self.host, - auth_client_secret=mock_auth_bearer_token( + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_bearer_token( access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token" ), - additional_headers={}, + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_client_secret_in_extra(self, mock_connect_to_custom, mock_auth_client_credentials): hook = WeaviateHook(conn_id=self.weaviate_client_credentials) hook.get_conn() mock_auth_client_credentials.assert_called_once_with( client_secret=self.client_secret, scope=self.scope ) - mock_client.assert_called_once_with( - url=self.host, - auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope), - additional_headers={}, + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope), + headers={}, ) - @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_conn_with_client_password_in_extra(self, mock_client, mock_auth_client_password): + @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom") + def test_get_conn_with_client_password_in_extra(self, mock_connect_to_custom, mock_auth_client_password): hook = WeaviateHook(conn_id=self.client_password) hook.get_conn() mock_auth_client_password.assert_called_once_with(username="login", password="password", scope=None) - mock_client.assert_called_once_with( - url=self.host, - auth_client_secret=mock_auth_client_password(username="login", password="password", scope=None), - additional_headers={}, + mock_connect_to_custom.assert_called_once_with( + http_host=self.host, + http_port=80, + http_secure=False, + grpc_host="localhost", + grpc_port=50051, + grpc_secure=False, + auth_credentials=mock_auth_client_password(username="login", password="password", scope=None), + headers={}, ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5") @@ -190,12 +254,14 @@ def test_create_object(self, mock_gen_uuid, weaviate_hook): """ Test the create_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + return_value = weaviate_hook.create_object({"name": "Test"}, "TestCollection") + mock_gen_uuid.assert_called_once() - mock_client.data_object.create.assert_called_once_with( - {"name": "Test"}, "TestClass", uuid=mock_gen_uuid.return_value + mock_collection.data.insert.assert_called_once_with( + properties={"name": "Test"}, uuid=mock_gen_uuid.return_value ) assert return_value @@ -203,107 +269,126 @@ def test_create_object_already_exists_return_none(self, weaviate_hook): """ Test the create_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - mock_client.data_object.create.side_effect = ObjectAlreadyExistsException - return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + mock_collection.data.insert.side_effect = ObjectAlreadyExistsException + + return_value = weaviate_hook.create_object({"name": "Test"}, "TestCollection") + assert return_value is None def test_get_object(self, weaviate_hook): """ Test the get_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.get_object(class_name="TestClass", uuid="uuid") - mock_client.data_object.get.assert_called_once_with(class_name="TestClass", uuid="uuid") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.get_object(collection_name="TestCollection", uuid="uuid") + + mock_collection.query.fetch_objects.assert_called_once_with(uuid="uuid") def test_get_of_get_or_create_object(self, weaviate_hook): """ Test the get part of get_or_create_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass") - mock_client.data_object.get.assert_called_once_with( - class_name="TestClass", - consistency_level=None, - tenant=None, - ) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.get_or_create_object(data_object={"name": "Test"}, collection_name="TestCollection") + + mock_collection.query.fetch_objects.assert_called_once_with() @mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5") def test_create_of_get_or_create_object(self, mock_gen_uuid, weaviate_hook): """ Test the create part of get_or_create_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) weaviate_hook.get_object = MagicMock(return_value=None) mock_create_object = MagicMock() weaviate_hook.create_object = mock_create_object - weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass") + + weaviate_hook.get_or_create_object(data_object={"name": "Test"}, collection_name="TestCollection") + mock_create_object.assert_called_once_with( - {"name": "Test"}, - "TestClass", + data_object={"name": "Test"}, + collection_name="TestCollection", uuid=mock_gen_uuid.return_value, - consistency_level=None, - tenant=None, vector=None, ) def test_create_of_get_or_create_object_raises_valueerror(self, weaviate_hook): """ - Test that if data_object is None or class_name is None, ValueError is raised. + Test that if data_object is None or collection_name is None, ValueError is raised. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) weaviate_hook.get_object = MagicMock(return_value=None) mock_create_object = MagicMock() + weaviate_hook.create_object = mock_create_object + with pytest.raises(ValueError): - weaviate_hook.get_or_create_object(data_object=None, class_name="TestClass") + weaviate_hook.get_or_create_object(data_object=None, collection_name="TestCollection") with pytest.raises(ValueError): - weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name=None) + weaviate_hook.get_or_create_object(data_object={"name": "Test"}, collection_name=None) def test_get_all_objects(self, weaviate_hook): """ Test the get_all_objects method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) objects = [ - {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]}, - {"deprecations": None, "objects": []}, + MockFetchObjectReturn( + objects=[ + MockObject(properties={"name": "Test1", "id": 2}, uuid="u1"), + MockObject(properties={"name": "Test2", "id": 3}, uuid="u2"), + ] + ), + MockFetchObjectReturn(objects=[]), ] mock_get_object = MagicMock() weaviate_hook.get_object = mock_get_object mock_get_object.side_effect = objects - return_value = weaviate_hook.get_all_objects(class_name="TestClass") + return_value = weaviate_hook.get_all_objects(collection_name="TestCollection") + assert weaviate_hook.get_object.call_args_list == [ - mock.call(after=None, class_name="TestClass"), - mock.call(after=3, class_name="TestClass"), + mock.call(after=None, collection_name="TestCollection"), + mock.call(after="u2", collection_name="TestCollection"), + ] + assert return_value == [ + MockObject(properties={"name": "Test1", "id": 2}, uuid="u1"), + MockObject(properties={"name": "Test2", "id": 3}, uuid="u2"), ] - assert return_value == [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}] def test_get_all_objects_returns_dataframe(self, weaviate_hook): """ Test the get_all_objects method of WeaviateHook can return a dataframe. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) objects = [ - {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]}, - {"deprecations": None, "objects": []}, + MockFetchObjectReturn( + objects=[ + MockObject(properties={"name": "Test1", "id": 2}, uuid="u1"), + MockObject(properties={"name": "Test2", "id": 3}, uuid="u2"), + ] + ), + MockFetchObjectReturn(objects=[]), ] mock_get_object = MagicMock() weaviate_hook.get_object = mock_get_object mock_get_object.side_effect = objects - return_value = weaviate_hook.get_all_objects(class_name="TestClass", as_dataframe=True) + return_value = weaviate_hook.get_all_objects(collection_name="TestCollection", as_dataframe=True) + assert weaviate_hook.get_object.call_args_list == [ - mock.call(after=None, class_name="TestClass"), - mock.call(after=3, class_name="TestClass"), + mock.call(after=None, collection_name="TestCollection"), + mock.call(after="u2", collection_name="TestCollection"), ] import pandas @@ -313,100 +398,68 @@ def test_delete_object(self, weaviate_hook): """ Test the delete_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.delete_object(uuid="uuid", class_name="TestClass") - mock_client.data_object.delete.assert_called_once_with("uuid", class_name="TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + weaviate_hook.delete_object(collection_name="TestCollection", uuid="uuid") + + mock_collection.data.delete_by_id.assert_called_once_with(uuid="uuid") def test_update_object(self, weaviate_hook): """ Test the update_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + weaviate_hook.update_object( - uuid="uuid", class_name="TestClass", data_object={"name": "Test"}, tenant="2d" - ) - mock_client.data_object.update.assert_called_once_with( - {"name": "Test"}, "TestClass", "uuid", tenant="2d" + uuid="uuid", collection_name="TestCollection", properties={"name": "Test"} ) - def test_validate_object(self, weaviate_hook): - """ - Test the validate_object method of WeaviateHook. - """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.validate_object(class_name="TestClass", data_object={"name": "Test"}, uuid="2d") - mock_client.data_object.validate.assert_called_once_with({"name": "Test"}, "TestClass", uuid="2d") + mock_collection.data.update.assert_called_once_with(properties={"name": "Test"}, uuid="uuid") def test_replace_object(self, weaviate_hook): """ Test the replace_object method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + weaviate_hook.replace_object( - uuid="uuid", class_name="TestClass", data_object={"name": "Test"}, tenant="2d" + uuid="uuid", collection_name="TestCollection", properties={"name": "Test"} ) - mock_client.data_object.replace.assert_called_once_with( - {"name": "Test"}, "TestClass", "uuid", tenant="2d" + + mock_collection.data.replace.assert_called_once_with( + properties={"name": "Test"}, uuid="uuid", references=None ) def test_object_exists(self, weaviate_hook): """ Test the object_exists method of WeaviateHook. """ - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) - weaviate_hook.object_exists(class_name="TestClass", uuid="2d") - mock_client.data_object.exists.assert_called_once_with("2d", class_name="TestClass") + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + weaviate_hook.object_exists(collection_name="TestCollection", uuid="2d") -def test_create_class(weaviate_hook): - """ - Test the create_class method of WeaviateHook. - """ - # Mock the Weaviate Client - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + mock_collection.data.exists.assert_called_once_with(uuid="2d") - # Define test class JSON - test_class_json = { - "class": "TestClass", - "description": "Test class for unit testing", - } - - # Test the create_class method - weaviate_hook.create_class(test_class_json) - - # Assert that the create_class method was called with the correct arguments - mock_client.schema.create_class.assert_called_once_with(test_class_json) - -def test_create_schema(weaviate_hook): +def test_create_collection(weaviate_hook): """ - Test the create_schema method of WeaviateHook. + Test the create_collection method of WeaviateHook. """ # Mock the Weaviate Client mock_client = MagicMock() weaviate_hook.get_conn = MagicMock(return_value=mock_client) - # Define test schema JSON - test_schema_json = { - "classes": [ - { - "class": "TestClass", - "description": "Test class for unit testing", - } - ] - } - - # Test the create_schema method - weaviate_hook.create_schema(test_schema_json) + # Test the create_collection method + weaviate_hook.create_collection("TestCollection", description="Test class for unit testing") - # Assert that the create_schema method was called with the correct arguments - mock_client.schema.create.assert_called_once_with(test_schema_json) + # Assert that the create_collection method was called with the correct arguments + mock_client.collections.create.assert_called_once_with( + name="TestCollection", description="Test class for unit testing" + ) @pytest.mark.parametrize( @@ -421,256 +474,73 @@ def test_batch_data(data, expected_length, weaviate_hook): """ Test the batch_data method of WeaviateHook. """ - # Mock the Weaviate Client - mock_client = MagicMock() - weaviate_hook.get_conn = MagicMock(return_value=mock_client) + # Mock the Weaviate Collection + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) # Define test data - test_class_name = "TestClass" + test_collection_name = "TestCollection" # Test the batch_data method - weaviate_hook.batch_data(test_class_name, data) + weaviate_hook.batch_data(test_collection_name, data) - # Assert that the batch_data method was called with the correct arguments - mock_client.batch.configure.assert_called_once() - mock_batch_context = mock_client.batch.__enter__.return_value - assert mock_batch_context.add_data_object.call_count == expected_length + mock_batch_context = mock_collection.batch.dynamic.return_value.__enter__.return_value + assert mock_batch_context.add_object.call_count == expected_length -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") -def test_batch_data_retry(get_conn, weaviate_hook): +def test_batch_data_retry(weaviate_hook): """Test to ensure retrying working as expected""" + # Mock the Weaviate Collection + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + data = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}] response = requests.Response() response.status_code = 429 error = requests.exceptions.HTTPError() error.response = response side_effect = [None, error, None, error, None] - get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect - weaviate_hook.batch_data("TestClass", data) - assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect) + mock_collection.batch.dynamic.return_value.__enter__.return_value.add_object.side_effect = side_effect -@pytest.mark.parametrize( - argnames=["get_schema_value", "existing", "expected_value"], - argvalues=[ - ({"classes": [{"class": "A"}, {"class": "B"}]}, "ignore", [{"class": "C"}]), - ({"classes": [{"class": "A"}, {"class": "B"}]}, "replace", [{"class": "B"}, {"class": "C"}]), - ({"classes": [{"class": "A"}, {"class": "B"}]}, "fail", {}), - ({"classes": [{"class": "A"}, {"class": "B"}]}, "invalid_option", {}), - ], - ids=["ignore", "replace", "fail", "invalid_option"], -) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_classes") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") -def test_upsert_schema_scenarios( - get_schema, create_schema, delete_classes, get_schema_value, existing, expected_value, weaviate_hook -): - schema_json = { - "B": {"class": "B"}, - "C": {"class": "C"}, - } - with ExitStack() as stack: - delete_classes.return_value = None - if existing in ["fail", "invalid_option"]: - stack.enter_context(pytest.raises(ValueError)) - get_schema.return_value = get_schema_value - weaviate_hook.create_or_replace_classes(schema_json=schema_json, existing=existing) - create_schema.assert_called_once_with({"classes": expected_value}) - if existing == "replace": - delete_classes.assert_called_once_with(class_names=["B"]) - - -@mock.patch("builtins.open") -@mock.patch("json.load") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") -def test_upsert_schema_json_file_param(get_schema, create_schema, load, open, weaviate_hook): - """Test if schema_json is path to a json file""" - get_schema.return_value = {"classes": [{"class": "A"}, {"class": "B"}]} - load.return_value = { - "B": {"class": "B"}, - "C": {"class": "C"}, - } - weaviate_hook.create_or_replace_classes(schema_json="/tmp/some_temp_file.json", existing="ignore") - create_schema.assert_called_once_with({"classes": [{"class": "C"}]}) + weaviate_hook.batch_data("TestCollection", data) + + assert mock_collection.batch.dynamic.return_value.__enter__.return_value.add_object.call_count == len( + side_effect + ) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client") -def test_delete_classes(get_client, weaviate_hook): - class_names = ["class_a", "class_b"] - get_client.return_value.schema.delete_class.side_effect = [ +@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") +def test_delete_collections(get_conn, weaviate_hook): + collection_names = ["collection_a", "collection_b"] + get_conn.return_value.collections.delete.side_effect = [ weaviate.UnexpectedStatusCodeException("something failed", requests.Response()), None, ] - error_list = weaviate_hook.delete_classes(class_names, if_error="continue") - assert error_list == ["class_a"] + error_list = weaviate_hook.delete_collections(collection_names, if_error="continue") + assert error_list == ["collection_a"] - get_client.return_value.schema.delete_class.side_effect = weaviate.UnexpectedStatusCodeException( + get_conn.return_value.collections.delete.side_effect = weaviate.UnexpectedStatusCodeException( "something failed", requests.Response() ) with pytest.raises(weaviate.UnexpectedStatusCodeException): - weaviate_hook.delete_classes("class_a", if_error="stop") + weaviate_hook.delete_collections("class_a", if_error="stop") -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client") -def test_http_errors_of_delete_classes(get_client, weaviate_hook): - class_names = ["class_a", "class_b"] +@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn") +def test_http_errors_of_delete_collections(get_conn, weaviate_hook): + collection_names = ["collection_a", "collection_b"] resp = requests.Response() resp.status_code = 429 - get_client.return_value.schema.delete_class.side_effect = [ + get_conn.return_value.collections.delete.side_effect = [ requests.exceptions.HTTPError(response=resp), None, requests.exceptions.ConnectionError, None, ] - error_list = weaviate_hook.delete_classes(class_names, if_error="continue") + error_list = weaviate_hook.delete_collections(collection_names, if_error="continue") assert error_list == [] - assert get_client.return_value.schema.delete_class.call_count == 4 - - -@pytest.mark.parametrize( - argnames=["classes_to_test", "expected_result"], - argvalues=[ - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - ], - }, - ], - True, - ), - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - ], - }, - ], - True, - ), - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "invalid_property", - "description": "Last name of the author", - "dataType": ["text"], - } - ], - }, - ], - False, - ), - ( - [ - { - "class": "invalid_class", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - ], - }, - ], - False, - ), - ( - [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - }, - { - "name": "name", - "description": "Name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - ], - }, - ], - True, - ), - ], - ids=( - "property_level_check", - "class_level_check", - "invalid_property", - "invalid_class", - "swapped_properties", - ), -) -@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema") -def test_contains_schema(get_schema, classes_to_test, expected_result, weaviate_hook): - get_schema.return_value = { - "classes": [ - { - "class": "Author", - "description": "Authors info", - "properties": [ - { - "name": "name", - "description": "Name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - ], - }, - { - "class": "Article", - "description": "An article written by an Author", - "properties": [ - { - "name": "name", - "description": "Name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - { - "name": "last_name", - "description": "Last name of the author", - "dataType": ["text"], - "extra_key": "some_value", - }, - ], - }, - ] - } - assert weaviate_hook.check_subset_of_schema(classes_to_test) == expected_result + assert get_conn.return_value.collections.delete.call_count == 4 @mock.patch("weaviate.util.generate_uuid5") @@ -678,7 +548,7 @@ def test___generate_uuids(generate_uuid5, weaviate_hook): df = pd.DataFrame.from_dict({"name": ["ross", "bob"], "age": ["12", "22"], "gender": ["m", "m"]}) with pytest.raises(ValueError, match=r"Columns last_name don't exist in dataframe"): weaviate_hook._generate_uuids( - df=df, class_name="test", unique_columns=["name", "age", "gender", "last_name"] + df=df, collection_name="test", unique_columns=["name", "age", "gender", "last_name"] ) df = pd.DataFrame.from_dict( @@ -687,14 +557,14 @@ def test___generate_uuids(generate_uuid5, weaviate_hook): with pytest.raises( ValueError, match=r"Property 'id' already in dataset. Consider renaming or specify 'uuid_column'" ): - weaviate_hook._generate_uuids(df=df, class_name="test", unique_columns=["name", "age", "gender"]) + weaviate_hook._generate_uuids(df=df, collection_name="test", unique_columns=["name", "age", "gender"]) with pytest.raises( ValueError, match=r"Property age already in dataset. Consider renaming or specify a different 'uuid_column'.", ): weaviate_hook._generate_uuids( - df=df, uuid_column="age", class_name="test", unique_columns=["name", "age", "gender"] + df=df, uuid_column="age", collection_name="test", unique_columns=["name", "age", "gender"] ) @@ -712,7 +582,7 @@ def test__delete_objects(delete_object, weaviate_hook): ) delete_object.side_effect = [not_found_exception, None, http_429_exception, http_429_exception, None] - weaviate_hook._delete_objects(uuids=["1", "2", "3"], class_name="test") + weaviate_hook._delete_objects(uuids=["1", "2", "3"], collection_name="test") assert delete_object.call_count == 5 @@ -750,7 +620,7 @@ def test___get_segregated_documents(_get_documents_to_uuid_map, _prepare_documen data=pd.DataFrame(), document_column="doc_key", uuid_column="id", - class_name="doc", + collection_name="doc", ) assert changed_documents == {"abc.doc"} assert unchanged_docs == {"xyz.doc"} @@ -776,7 +646,7 @@ def test_error_option_of_create_or_replace_document_objects( _generate_uuids.return_value = (df, "id") with pytest.raises(ValueError, match="Documents abc.xml already exists. You can either skip or replace"): weaviate_hook.create_or_replace_document_objects( - data=df, document_column="doc", class_name="test", existing="error" + data=df, document_column="doc", collection_name="test", existing="error" ) @@ -797,7 +667,7 @@ def test_skip_option_of_create_or_replace_document_objects( } ) - class_name = "test" + collection_name = "test" documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = ( {}, {"abc.xml"}, @@ -813,7 +683,7 @@ def test_skip_option_of_create_or_replace_document_objects( _generate_uuids.return_value = (df, "id") weaviate_hook.create_or_replace_document_objects( - data=df, class_name=class_name, existing="skip", document_column="doc" + data=df, collection_name=collection_name, existing="skip", document_column="doc" ) pd.testing.assert_frame_equal( @@ -838,7 +708,7 @@ def test_replace_option_of_create_or_replace_document_objects( } ) - class_name = "test" + collection_name = "test" documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = ( {"abc.xml": {"uuid"}}, {"abc.xml"}, @@ -854,16 +724,14 @@ def test_replace_option_of_create_or_replace_document_objects( ) _generate_uuids.return_value = (df, "id") weaviate_hook.create_or_replace_document_objects( - data=df, class_name=class_name, existing="replace", document_column="doc" + data=df, collection_name=collection_name, existing="replace", document_column="doc" ) _delete_all_documents_objects.assert_called_with( document_keys=list(changed_documents), total_objects_count=1, document_column="doc", - class_name="test", + collection_name="test", batch_delete_error=[], - tenant=None, - batch_config_params=None, verbose=False, ) pd.testing.assert_frame_equal( diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py index e147743d74cfe..8060fdf023116 100644 --- a/tests/providers/weaviate/operators/test_weaviate.py +++ b/tests/providers/weaviate/operators/test_weaviate.py @@ -37,15 +37,14 @@ def operator(self): return WeaviateIngestOperator( task_id="weaviate_task", conn_id="weaviate_conn", - class_name="my_class", + collection_name="my_collection", input_data=[{"data": "sample_data"}], ) def test_constructor(self, operator): assert operator.conn_id == "weaviate_conn" - assert operator.class_name == "my_class" + assert operator.collection_name == "my_collection" assert operator.input_data == [{"data": "sample_data"}] - assert operator.batch_params == {} assert operator.hook_params == {} @patch("airflow.providers.weaviate.operators.weaviate.WeaviateIngestOperator.log") @@ -57,21 +56,18 @@ def test_execute_with_input_json(self, mock_log, operator): operator = WeaviateIngestOperator( task_id="weaviate_task", conn_id="weaviate_conn", - class_name="my_class", + collection_name="my_collection", input_json=[{"data": "sample_data"}], ) - operator.hook.batch_data = MagicMock() operator.execute(context=None) operator.hook.batch_data.assert_called_once_with( - class_name="my_class", + collection_name="my_collection", data=[{"data": "sample_data"}], - batch_config_params={}, vector_col="Vector", uuid_col="id", - tenant=None, ) mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}]) @@ -82,12 +78,10 @@ def test_execute_with_input_data(self, mock_log, operator): operator.execute(context=None) operator.hook.batch_data.assert_called_once_with( - class_name="my_class", + collection_name="my_collection", data=[{"data": "sample_data"}], - batch_config_params={}, vector_col="Vector", uuid_col="id", - tenant=None, ) mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}]) @@ -99,7 +93,7 @@ def test_templates(self, create_task_instance_of_operator): dag_id=dag_id, task_id="task-id", conn_id="weaviate_conn", - class_name="my_class", + collection_name="my_collection", input_json="{{ dag.dag_id }}", input_data="{{ dag.dag_id }}", ) @@ -114,8 +108,7 @@ def test_partial_batch_hook_params(self, dag_maker, session): WeaviateIngestOperator.partial( task_id="fake-task-id", conn_id="weaviate_conn", - class_name="FooBar", - batch_params={"spam": "egg"}, + collection_name="FooBar", hook_params={"baz": "biz"}, ).expand(input_data=[{}, {}]) @@ -124,7 +117,6 @@ def test_partial_batch_hook_params(self, dag_maker, session): with set_current_task_instance_session(session=session): for ti in tis: ti.render_templates() - assert ti.task.batch_params == {"spam": "egg"} assert ti.task.hook_params == {"baz": "biz"} @@ -135,23 +127,21 @@ def operator(self): task_id="weaviate_task", conn_id="weaviate_conn", input_data=[{"data": "sample_data"}], - class_name="my_class", + collection_name="my_collection", document_column="docLink", existing="skip", uuid_column="id", vector_col="vector", - batch_config_params={"size": 1000}, ) def test_constructor(self, operator): assert operator.conn_id == "weaviate_conn" assert operator.input_data == [{"data": "sample_data"}] - assert operator.class_name == "my_class" + assert operator.collection_name == "my_collection" assert operator.document_column == "docLink" assert operator.existing == "skip" assert operator.uuid_column == "id" assert operator.vector_col == "vector" - assert operator.batch_config_params == {"size": 1000} assert operator.hook_params == {} @patch("airflow.providers.weaviate.operators.weaviate.WeaviateDocumentIngestOperator.log") @@ -162,13 +152,11 @@ def test_execute_with_input_json(self, mock_log, operator): operator.hook.create_or_replace_document_objects.assert_called_once_with( data=[{"data": "sample_data"}], - class_name="my_class", + collection_name="my_collection", document_column="docLink", existing="skip", uuid_column="id", vector_column="vector", - batch_config_params={"size": 1000}, - tenant=None, verbose=False, ) mock_log.debug.assert_called_once_with("Total input objects : %s", len([{"data": "sample_data"}])) @@ -179,7 +167,7 @@ def test_partial_hook_params(self, dag_maker, session): WeaviateDocumentIngestOperator.partial( task_id="fake-task-id", conn_id="weaviate_conn", - class_name="FooBar", + collection_name="FooBar", document_column="spam-egg", hook_params={"baz": "biz"}, ).expand(input_data=[{}, {}]) diff --git a/tests/system/providers/weaviate/example_weaviate_cohere.py b/tests/system/providers/weaviate/example_weaviate_cohere.py index f2a190bfcf2ef..8415bcecc7ac8 100644 --- a/tests/system/providers/weaviate/example_weaviate_cohere.py +++ b/tests/system/providers/weaviate/example_weaviate_cohere.py @@ -36,19 +36,15 @@ def example_weaviate_cohere(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "Weaviate_example_class", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # Collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection(name="Weaviate_example_collection", vectorizer_config=None) @setup @task @@ -60,6 +56,7 @@ def get_data_to_embed(): return [[item["Question"]] for item in data] data_to_embed = get_data_to_embed() + embed_data = CohereEmbeddingOperator.partial( task_id="embedding_using_xcom_data", ).expand(input_text=data_to_embed["return_value"]) @@ -81,7 +78,7 @@ def update_vector_data_in_json(**kwargs): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name="Weaviate_example_class", + collection_name="Weaviate_example_collection", input_data=update_vector_data_in_json["return_value"], ) @@ -92,24 +89,24 @@ def update_vector_data_in_json(**kwargs): @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collections(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["Weaviate_example_class"]) + weaviate_hook.delete_collections(["Weaviate_example_collections"]) ( - create_weaviate_class() + create_weaviate_collection() >> embed_data >> update_vector_data_in_json >> perform_ingestion >> embed_query - >> delete_weaviate_class() + >> delete_weaviate_collections() ) diff --git a/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py b/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py index 4b29eb1998b80..a52c0e52625ce 100644 --- a/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py +++ b/tests/system/providers/weaviate/example_weaviate_dynamic_mapping_dag.py @@ -17,6 +17,7 @@ from __future__ import annotations import pendulum +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, setup, task, teardown from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator @@ -34,19 +35,15 @@ def example_weaviate_dynamic_mapping_dag(): @setup @task - def create_weaviate_class(data): + def create_weaviate_collection(data): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": data[0], - "vectorizer": data[1], - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection(data[0], vectorizer_config=data[1]) @setup @task @@ -64,28 +61,30 @@ def get_data_to_ingest(): task_id="perform_ingestion", conn_id="weaviate_default", ).expand( - class_name=["example1", "example2"], + collection_name=["example1", "example2"], input_data=get_data_to_ingest["return_value"], ) @teardown @task - def delete_weaviate_class(class_name): + def delete_weaviate_collection(collection_name): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes([class_name]) + weaviate_hook.delete_collections([collection_name]) ( - create_weaviate_class.expand(data=[["example1", "none"], ["example2", "text2vec-openai"]]) + create_weaviate_collection.expand( + data=[["example1", "none"], ["example2", Configure.Vectorizer.text2vec_openai()]] + ) >> perform_ingestion - >> delete_weaviate_class.expand(class_name=["example1", "example2"]) + >> delete_weaviate_collection.expand(collection_name=["example1", "example2"]) ) diff --git a/tests/system/providers/weaviate/example_weaviate_openai.py b/tests/system/providers/weaviate/example_weaviate_openai.py index 40304236818f0..97d002285d8cc 100644 --- a/tests/system/providers/weaviate/example_weaviate_openai.py +++ b/tests/system/providers/weaviate/example_weaviate_openai.py @@ -40,17 +40,13 @@ def example_weaviate_openai(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "Weaviate_example_class", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection("Weaviate_example_collection", "None") @setup @task @@ -59,6 +55,7 @@ def get_data_to_embed(): return [item["Question"] for item in data] data_to_embed = get_data_to_embed() + embed_data = OpenAIEmbeddingOperator.partial( task_id="embedding_using_xcom_data", conn_id="openai_default", @@ -79,7 +76,7 @@ def update_vector_data_in_json(**kwargs): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name="Weaviate_example_class", + collection_name="Weaviate_example_collection", input_data=update_vector_data_in_json["return_value"], ) @@ -96,31 +93,31 @@ def query_weaviate(**kwargs): query_vector = ti.xcom_pull(task_ids="embed_query", key="return_value") weaviate_hook = WeaviateHook() properties = ["question", "answer", "category"] - response = weaviate_hook.query_with_vector(query_vector, "Weaviate_example_class", *properties) + response = weaviate_hook.query_with_vector(query_vector, "Weaviate_example_collection", *properties) assert ( "In 1953 Watson & Crick built a model" - in response["data"]["Get"]["Weaviate_example_class"][0]["question"] + in response["data"]["Get"]["Weaviate_example_collection"][0]["question"] ) @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collection(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["Weaviate_example_class"]) + weaviate_hook.delete_collections(["Weaviate_example_collection"]) ( - create_weaviate_class() + create_weaviate_collection() >> embed_data >> update_vector_data_in_json >> perform_ingestion >> embed_query >> query_weaviate() - >> delete_weaviate_class() + >> delete_weaviate_collection() ) diff --git a/tests/system/providers/weaviate/example_weaviate_operator.py b/tests/system/providers/weaviate/example_weaviate_operator.py index ae4097006e0c1..b92538059e556 100644 --- a/tests/system/providers/weaviate/example_weaviate_operator.py +++ b/tests/system/providers/weaviate/example_weaviate_operator.py @@ -17,6 +17,8 @@ from __future__ import annotations import pendulum +from weaviate.classes.config import DataType, Property +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, task, teardown from airflow.providers.weaviate.operators.weaviate import ( @@ -95,22 +97,18 @@ def example_weaviate_using_operator(): Example Weaviate DAG demonstrating usage of the operator. """ - # Example tasks to create a Weaviate class without vectorizers, store data with custom vectors in XCOM, and call + # Example tasks to create a Weaviate collection without vectorizers, store data with custom vectors in XCOM, and call # WeaviateIngestOperator to ingest data with those custom vectors. @task() - def create_class_without_vectorizer(): + def create_collection_without_vectorizer(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "QuestionWithoutVectorizerUsingOperator", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection("QuestionWithoutVectorizerUsingOperator") @task(trigger_rule="all_done") def store_data_with_vectors_in_xcom(): @@ -120,7 +118,7 @@ def store_data_with_vectors_in_xcom(): batch_data_with_vectors_xcom_data = WeaviateIngestOperator( task_id="batch_data_with_vectors_xcom_data", conn_id="weaviate_default", - class_name="QuestionWithoutVectorizerUsingOperator", + collection_name="QuestionWithoutVectorizerUsingOperator", input_data=store_data_with_vectors_in_xcom(), trigger_rule="all_done", ) @@ -130,82 +128,52 @@ def store_data_with_vectors_in_xcom(): batch_data_with_vectors_callable_data = WeaviateIngestOperator( task_id="batch_data_with_vectors_callable_data", conn_id="weaviate_default", - class_name="QuestionWithoutVectorizerUsingOperator", + collection_name="QuestionWithoutVectorizerUsingOperator", input_data=get_data_with_vectors(), trigger_rule="all_done", ) # [END howto_operator_weaviate_embedding_and_ingest_callable_data_with_vectors] - # Example tasks to create class with OpenAI vectorizer, store data without vectors in XCOM, and call + # Example tasks to create collection with OpenAI vectorizer, store data without vectors in XCOM, and call # WeaviateIngestOperator to ingest data by internally generating OpenAI vectors while ingesting. @task() - def create_class_with_vectorizer(): + def create_collection_with_vectorizer(): """ - Example task to create class with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. + Example task to create collection with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - class_obj = { - "class": "QuestionWithOpenAIVectorizerUsingOperator", - "description": "Information from a Jeopardy! question", # description of the class - "properties": [ - { - "dataType": ["text"], - "description": "The question", - "name": "question", - }, - { - "dataType": ["text"], - "description": "The answer", - "name": "answer", - }, - { - "dataType": ["text"], - "description": "The category", - "name": "category", - }, + weaviate_hook.create_collection( + "QuestionWithOpenAIVectorizerUsingOperator", + description="Information from a Jeopardy! question", + properties=[ + Property(name="question", description="The question", data_type=DataType.TEXT), + Property(name="answer", description="The answer", data_type=DataType.TEXT), + Property(name="category", description="The category", data_type=DataType.TEXT), ], - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @task() - def create_class_for_doc_data_with_vectorizer(): + def create_collection_for_doc_data_with_vectorizer(): """ - Example task to create class with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. + Example task to create collection with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - class_obj = { - "class": "QuestionWithOpenAIVectorizerUsingOperatorDocs", - "description": "Information from a Jeopardy! question", # description of the class - "properties": [ - { - "dataType": ["text"], - "description": "The question", - "name": "question", - }, - { - "dataType": ["text"], - "description": "The answer", - "name": "answer", - }, - { - "dataType": ["text"], - "description": "The category", - "name": "category", - }, - { - "dataType": ["text"], - "description": "URL for source document", - "name": "docLink", - }, + weaviate_hook.create_collection( + "QuestionWithOpenAIVectorizerUsingOperatorDocs", + description="Information from a Jeopardy! question", + properties=[ + Property(name="question", description="The question", data_type=DataType.TEXT), + Property(name="answer", description="The answer", data_type=DataType.TEXT), + Property(name="category", description="The category", data_type=DataType.TEXT), + Property(name="docLink", description="URL for source document", data_type=DataType.TEXT), ], - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @task(trigger_rule="all_done") def store_data_without_vectors_in_xcom(): @@ -224,13 +192,14 @@ def store_doc_data_without_vectors_in_xcom(): return data xcom_data_without_vectors = store_data_without_vectors_in_xcom() + xcom_doc_data_without_vectors = store_doc_data_without_vectors_in_xcom() # [START howto_operator_weaviate_ingest_xcom_data_without_vectors] batch_data_without_vectors_xcom_data = WeaviateIngestOperator( task_id="batch_data_without_vectors_xcom_data", conn_id="weaviate_default", - class_name="QuestionWithOpenAIVectorizerUsingOperator", + collection_name="QuestionWithOpenAIVectorizerUsingOperator", input_data=xcom_data_without_vectors["return_value"], trigger_rule="all_done", ) @@ -240,7 +209,7 @@ def store_doc_data_without_vectors_in_xcom(): batch_data_without_vectors_callable_data = WeaviateIngestOperator( task_id="batch_data_without_vectors_callable_data", conn_id="weaviate_default", - class_name="QuestionWithOpenAIVectorizerUsingOperator", + collection_name="QuestionWithOpenAIVectorizerUsingOperator", input_data=get_data_without_vectors(), trigger_rule="all_done", ) @@ -251,24 +220,23 @@ def store_doc_data_without_vectors_in_xcom(): existing="replace", document_column="docLink", conn_id="weaviate_default", - class_name="QuestionWithOpenAIVectorizerUsingOperatorDocs", - batch_config_params={"batch_size": 1000}, + collection_name="QuestionWithOpenAIVectorizerUsingOperatorDocs", input_data=xcom_doc_data_without_vectors["return_value"], trigger_rule="all_done", ) @teardown @task - def delete_weaviate_class_Vector(): + def delete_weaviate_collection_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes( + weaviate_hook.delete_collections( [ "QuestionWithOpenAIVectorizerUsingOperator", ] @@ -276,16 +244,16 @@ def delete_weaviate_class_Vector(): @teardown @task - def delete_weaviate_class_without_Vector(): + def delete_weaviate_collection_without_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes( + weaviate_hook.delete_collections( [ "QuestionWithoutVectorizerUsingOperator", ] @@ -293,34 +261,34 @@ def delete_weaviate_class_without_Vector(): @teardown @task - def delete_weaviate_docs_class_without_Vector(): + def delete_weaviate_docs_collection_without_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["QuestionWithOpenAIVectorizerUsingOperatorDocs"]) + weaviate_hook.delete_collections(["QuestionWithOpenAIVectorizerUsingOperatorDocs"]) ( - create_class_without_vectorizer() + create_collection_without_vectorizer() >> [batch_data_with_vectors_xcom_data, batch_data_with_vectors_callable_data] - >> delete_weaviate_class_without_Vector() + >> delete_weaviate_collection_without_vector() ) ( - create_class_for_doc_data_with_vectorizer() + create_collection_for_doc_data_with_vectorizer() >> [create_or_replace_document_objects_without_vectors] - >> delete_weaviate_docs_class_without_Vector() + >> delete_weaviate_docs_collection_without_vector() ) ( - create_class_with_vectorizer() + create_collection_with_vectorizer() >> [ batch_data_without_vectors_xcom_data, batch_data_without_vectors_callable_data, ] - >> delete_weaviate_class_Vector() + >> delete_weaviate_collection_vector() ) diff --git a/tests/system/providers/weaviate/example_weaviate_using_hook.py b/tests/system/providers/weaviate/example_weaviate_using_hook.py index 3cb185007f2c6..11d57bfbef2a9 100644 --- a/tests/system/providers/weaviate/example_weaviate_using_hook.py +++ b/tests/system/providers/weaviate/example_weaviate_using_hook.py @@ -17,6 +17,8 @@ from __future__ import annotations import pendulum +from weaviate.classes.config import DataType, Property +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, task, teardown @@ -31,51 +33,37 @@ def example_weaviate_dag_using_hook(): """Example Weaviate DAG demonstrating usage of the hook.""" @task() - def create_class_with_vectorizer(): + def create_collection_with_vectorizer(): """ - Example task to create class with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. + Example task to create collection with OpenAI Vectorizer responsible for vectorining data using Weaviate cluster. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - class_obj = { - "class": "QuestionWithOpenAIVectorizerUsingHook", - "description": "Information from a Jeopardy! question", # description of the class - "properties": [ - { - "dataType": ["text"], - "description": "The question", - "name": "question", - }, - { - "dataType": ["text"], - "description": "The answer", - "name": "answer", - }, - { - "dataType": ["text"], - "description": "The category", - "name": "category", - }, + weaviate_hook.create_collection( + "QuestionWithOpenAIVectorizerUsingHook", + description="Information from a Jeopardy! question", + properties=[ + Property(name="question", description="The question", data_type=DataType.TEXT), + Property(name="answer", description="The answer", data_type=DataType.TEXT), + Property(name="category", description="The category", data_type=DataType.TEXT), ], - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @task() - def create_class_without_vectorizer(): + def create_collection_without_vectorizer(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": "QuestionWithoutVectorizerUsingHook", - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection( + "QuestionWithoutVectorizerUsingHook", + vectorizer_config=None, + ) @task(trigger_rule="all_done") def store_data_without_vectors_in_xcom(): @@ -109,42 +97,42 @@ def batch_data_with_vectors(data: list): @teardown @task - def delete_weaviate_class_Vector(): + def delete_weaviate_collection_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["QuestionWithOpenAIVectorizerUsingHook"]) + weaviate_hook.delete_collections(["QuestionWithOpenAIVectorizerUsingHook"]) @teardown @task - def delete_weaviate_class_without_Vector(): + def delete_weaviate_collection_without_vector(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes(["QuestionWithoutVectorizerUsingHook"]) + weaviate_hook.delete_collections(["QuestionWithoutVectorizerUsingHook"]) data_with_vectors = store_data_with_vectors_in_xcom() ( - create_class_without_vectorizer() + create_collection_without_vectorizer() >> batch_data_with_vectors(data_with_vectors["return_value"]) - >> delete_weaviate_class_Vector() + >> delete_weaviate_collection_vector() ) data_without_vectors = store_data_without_vectors_in_xcom() ( - create_class_with_vectorizer() + create_collection_with_vectorizer() >> batch_data_without_vectors(data_without_vectors["return_value"]) - >> delete_weaviate_class_without_Vector() + >> delete_weaviate_collection_without_vector() ) diff --git a/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py b/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py index 77a6c6c1cd676..c78e431396d8c 100644 --- a/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py +++ b/tests/system/providers/weaviate/example_weaviate_vectorizer_dag.py @@ -17,11 +17,12 @@ from __future__ import annotations import pendulum +from weaviate.collections.classes.config import Configure from airflow.decorators import dag, setup, task, teardown from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator -class_name = "Weaviate_with_vectorizer_example_class" +collection_name = "Weaviate_with_vectorizer_example_collection" @dag( @@ -37,19 +38,18 @@ def example_weaviate_vectorizer_dag(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": class_name, - "vectorizer": "text2vec-openai", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection( + collection_name, + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + ) @setup @task @@ -65,7 +65,7 @@ def get_data_to_ingest(): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name=class_name, + collection_name=collection_name, input_data=data_to_ingest["return_value"], ) @@ -75,26 +75,25 @@ def query_weaviate(): weaviate_hook = WeaviateHook() properties = ["question", "answer", "category"] - response = weaviate_hook.query_without_vector( - "biology", "Weaviate_with_vectorizer_example_class", *properties + response = weaviate_hook.query_with_text( + "biology", "Weaviate_with_vectorizer_example_collection", *properties ) - assert "In 1953 Watson & Crick built a model" in response["data"]["Get"][class_name][0]["question"] + assert "In 1953 Watson & Crick built a model" in response.objects[0].properties["question"] @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collection(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes([class_name]) + weaviate_hook.delete_collections([collection_name]) - delete_weaviate_class = delete_weaviate_class() - create_weaviate_class() >> perform_ingestion >> query_weaviate() >> delete_weaviate_class + create_weaviate_collection() >> perform_ingestion >> query_weaviate() >> delete_weaviate_collection() example_weaviate_vectorizer_dag() diff --git a/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py b/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py index c73aa7f2d067b..a8430661afa7e 100644 --- a/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py +++ b/tests/system/providers/weaviate/example_weaviate_without_vectorizer_dag.py @@ -22,7 +22,7 @@ from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator -class_name = "Weaviate_example_without_vectorizer_class" +collection_name = "Weaviate_example_without_vectorizer_collection" @dag( @@ -38,19 +38,15 @@ def example_weaviate_without_vectorizer_dag(): @setup @task - def create_weaviate_class(): + def create_weaviate_collection(): """ - Example task to create class without any Vectorizer. You're expected to provide custom vectors for your data. + Example task to create collection without any Vectorizer. You're expected to provide custom vectors for your data. """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. - class_obj = { - "class": class_name, - "vectorizer": "none", - } - weaviate_hook.create_class(class_obj) + # collection definition object. Weaviate's autoschema feature will infer properties when importing. + weaviate_hook.create_collection(collection_name, vectorizer_config=None) @setup @task @@ -66,7 +62,7 @@ def get_data_without_vectors(): perform_ingestion = WeaviateIngestOperator( task_id="perform_ingestion", conn_id="weaviate_default", - class_name=class_name, + collection_name=collection_name, input_data=data_to_ingest["return_value"], ) @@ -86,29 +82,29 @@ def query_weaviate(**kwargs): weaviate_hook = WeaviateHook() properties = ["question", "answer", "category"] response = weaviate_hook.query_with_vector( - query_vector, "Weaviate_example_without_vectorizer_class", *properties + query_vector, "Weaviate_example_without_vectorizer_collection", *properties ) - assert "In 1953 Watson & Crick built a model" in response["data"]["Get"][class_name][0]["question"] + assert "In 1953 Watson & Crick built a model" in response.objects[0].properties["question"] @teardown @task - def delete_weaviate_class(): + def delete_weaviate_collection(): """ - Example task to delete a weaviate class + Example task to delete a weaviate collection """ from airflow.providers.weaviate.hooks.weaviate import WeaviateHook weaviate_hook = WeaviateHook() - # Class definition object. Weaviate's autoschema feature will infer properties when importing. + # collection definition object. Weaviate's autoschema feature will infer properties when importing. - weaviate_hook.delete_classes([class_name]) + weaviate_hook.delete_collections([collection_name]) ( - create_weaviate_class() + create_weaviate_collection() >> perform_ingestion >> embedd_query >> query_weaviate() - >> delete_weaviate_class() + >> delete_weaviate_collection() )