diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 470d9bf54f6b1..ccca02d357620 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -183,7 +183,7 @@ def create_schema(self, schema_json: dict[str, Any] | str) -> None: client.schema.create(schema_json) @staticmethod - def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame) -> list[dict[str, Any]]: + def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) -> list[dict[str, Any]]: """Helper function to convert dataframe to list of dicts. In scenario where Pandas isn't installed and we pass data as a list of dictionaries, importing @@ -382,7 +382,7 @@ def check_subset_of_schema(self, classes_objects: list) -> bool: def batch_data( self, class_name: str, - data: list[dict[str, Any]] | pd.DataFrame, + data: list[dict[str, Any]] | pd.DataFrame | None, batch_config_params: dict[str, Any] | None = None, vector_col: str = "Vector", uuid_col: str = "id", @@ -401,7 +401,7 @@ def batch_data( :param retry_attempts_per_object: number of time to try in case of failure before giving up. :param tenant: The tenant to which the object will be added. """ - data = self._convert_dataframe_to_list(data) + converted_data = self._convert_dataframe_to_list(data) total_results = 0 error_results = 0 insertion_errors: list = [] @@ -437,7 +437,7 @@ def _process_batch_errors( self.log.info( "Total Objects %s / Objects %s successfully inserted and Objects %s had errors.", - len(data), + len(converted_data), total_results, error_results, ) @@ -460,7 +460,7 @@ def _process_batch_errors( client.batch.configure(**batch_config_params) with client.batch as batch: # Batch import all data - for index, data_obj in enumerate(data): + for index, data_obj in enumerate(converted_data): for attempt in Retrying( stop=stop_after_attempt(retry_attempts_per_object), retry=( diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py index d4dadf261cc3a..586caaa993b06 100644 --- a/airflow/providers/weaviate/operators/weaviate.py +++ b/airflow/providers/weaviate/operators/weaviate.py @@ -74,9 +74,8 @@ def __init__( self.input_json = input_json self.uuid_column = uuid_column self.tenant = tenant - if input_data is not None: - self.input_data = input_data - elif input_json is not None: + self.input_data = input_data + if (self.input_data is None) and (input_json is not None): warnings.warn( "Passing 'input_json' to WeaviateIngestOperator is deprecated and" " you should use 'input_data' instead", @@ -84,7 +83,7 @@ def __init__( stacklevel=2, ) self.input_data = input_json - else: + elif self.input_data is None and input_json is None: raise TypeError("Either input_json or input_data is required") @cached_property