diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py
index 1dd7cb03bafbf..6dae1e36d6fab 100644
--- a/airflow/providers/google/cloud/hooks/automl.py
+++ b/airflow/providers/google/cloud/hooks/automl.py
@@ -15,166 +15,628 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
+"""
+This module contains a Google AutoML hook.
 
-import warnings
+.. spelling:word-list::
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+    PredictResponse
+"""
+from __future__ import annotations
 
+from functools import cached_property
+from typing import TYPE_CHECKING, Sequence
+
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.automl_v1beta1 import (
+    AutoMlClient,
+    BatchPredictInputConfig,
+    BatchPredictOutputConfig,
+    Dataset,
+    ExamplePayload,
+    ImageObjectDetectionModelDeploymentMetadata,
+    InputConfig,
+    Model,
+    PredictionServiceClient,
+    PredictResponse,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+if TYPE_CHECKING:
+    from google.api_core.operation import Operation
+    from google.api_core.retry import Retry
+    from google.cloud.automl_v1beta1.services.auto_ml.pagers import (
+        ListColumnSpecsPager,
+        ListDatasetsPager,
+        ListTableSpecsPager,
+    )
+    from google.protobuf.field_mask_pb2 import FieldMask
 
-class CloudAutoMLHook:
-    """
-    Former Google Cloud AutoML hook.
 
-    Deprecated as AutoML API becomes unusable starting March 31, 2024:
-    https://cloud.google.com/automl/docs
-    """
-    deprecation_warning = (
-        "CloudAutoMLHook has been deprecated, as AutoML API becomes unusable starting "
-        "March 31, 2024, and will be removed in future release. Please use an equivalent "
-        " Vertex AI hook available in"
-        "airflow.providers.google.cloud.hooks.vertex_ai instead."
-    )
+class CloudAutoMLHook(GoogleBaseHook):
+    """
+    Google Cloud AutoML hook.
 
-    method_exception = "This method cannot be used as AutoML API becomes unusable."
+    All the methods in the hook where project_id is used must be called with
+    keyword arguments rather than positional.
+    """
 
-    def __init__(self, **_) -> None:
-        warnings.warn(self.deprecation_warning, AirflowProviderDeprecationWarning)
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        if kwargs.get("delegate_to") is not None:
+            raise RuntimeError(
+                "The `delegate_to` parameter was deprecated and has been removed in this version"
+                " of the Google Provider. You MUST convert it to `impersonation_chain`."
+            )
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+        )
+        self._client: AutoMlClient | None = None
 
     @staticmethod
     def extract_object_id(obj: dict) -> str:
         """Return unique id of the object."""
-        warnings.warn(
-            "'extract_object_id' method is deprecated and will be removed in future release.",
-            AirflowProviderDeprecationWarning,
-        )
         return obj["name"].rpartition("/")[-1]
 
-    def get_conn(self):
-        """
-        Retrieve connection to AutoML (deprecated).
-
-        :raises: AirflowException
-        """
-        raise AirflowException(self.method_exception)
-
-    def wait_for_operation(self, **_):
-        """
-        Wait for long-lasting operation to complete (deprecated).
- - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def prediction_client(self, **_): - """ - Create a PredictionServiceClient (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def create_model(self, **_): - """ - Create a model_id and returns a Model in the `response` field when it completes (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def batch_predict(self, **_): - """ - Perform a batch prediction (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def predict(self, **_): - """ - Perform an online prediction (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def create_dataset(self, **_): - """ - Create a dataset (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def import_data(self, **_): - """ - Import data (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def list_column_specs(self, **_): - """ - List column specs (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def get_model(self, **_): - """ - Get a model (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def delete_model(self, **_): - """ - Delete a model (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def update_dataset(self, **_): - """ - Update a model (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def deploy_model(self, **_): - """ - Deploy a model (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def list_table_specs(self, **_): - """ - List table specs (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def list_datasets(self, **_): - """ - List datasets (deprecated). - - :raises: AirflowException - """ - raise AirflowException(self.method_exception) - - def delete_dataset(self, **_): - """ - Delete a dataset (deprecated). + def get_conn(self) -> AutoMlClient: + """ + Retrieve connection to AutoML. + + :return: Google Cloud AutoML client object. + """ + if self._client is None: + self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) + return self._client + + def wait_for_operation(self, operation: Operation, timeout: float | None = None): + """Wait for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @cached_property + def prediction_client(self) -> PredictionServiceClient: + """ + Creates PredictionServiceClient. + + :return: Google Cloud AutoML PredictionServiceClient client object. + """ + return PredictionServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) + + @GoogleBaseHook.fallback_to_default_project_id + def create_model( + self, + model: dict | Model, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + ) -> Operation: + """ + Create a model_id and returns a Model in the `response` field when it completes. 
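+
+        A minimal model spec (illustrative; the fields shown assume an AutoML Tables model)::
+
+            model = {
+                "display_name": "my_tables_model",
+                "dataset_id": "TBL123456789",
+                "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+            }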
+ + When you create a model, several model evaluations are created for it: + a global evaluation, and one evaluation for each annotation spec. + + :param model: The model_id to create. If a dict is provided, it must be of the same form + as the protobuf message `google.cloud.automl_v1beta1.types.Model` + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + return client.create_model( + request={"parent": parent, "model": model}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) - :raises: AirflowException - """ - raise AirflowException(self.method_exception) + @GoogleBaseHook.fallback_to_default_project_id + def batch_predict( + self, + model_id: str, + input_config: dict | BatchPredictInputConfig, + output_config: dict | BatchPredictOutputConfig, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + params: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Perform a batch prediction and returns a long-running operation object. + + Unlike the online `Predict`, batch prediction result won't be immediately + available in the response. Instead, a long-running operation object is returned. + + :param model_id: Name of the model_id requested to serve the batch prediction. + :param input_config: Required. The input configuration for batch prediction. + If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictInputConfig` + :param output_config: Required. The Configuration specifying where output predictions should be + written. If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig` + :param params: Additional domain-specific parameters for the predictions, any string must be up to + 25000 characters long. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
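+
+        Illustrative GCS-based configs (bucket and object names are placeholders)::
+
+            input_config = {"gcs_source": {"input_uris": ["gs://my-bucket/input.csv"]}}
+            output_config = {"gcs_destination": {"output_uri_prefix": "gs://my-bucket/predictions/"}}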
+ + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.prediction_client + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.batch_predict( + request={ + "name": name, + "input_config": input_config, + "output_config": output_config, + "params": params, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def predict( + self, + model_id: str, + payload: dict | ExamplePayload, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + params: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> PredictResponse: + """ + Perform an online prediction and returns the prediction result in the response. + + :param model_id: Name of the model_id requested to serve the prediction. + :param payload: Required. Payload to perform a prediction on. The payload must match the problem type + that the model_id was trained to solve. If a dict is provided, it must be of + the same form as the protobuf message `google.cloud.automl_v1beta1.types.ExamplePayload` + :param params: Additional domain-specific parameters, any string must be up to 25000 characters long. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance + """ + client = self.prediction_client + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.predict( + request={"name": name, "payload": payload, "params": params}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_dataset( + self, + dataset: dict | Dataset, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Dataset: + """ + Create a dataset. + + :param dataset: The dataset to create. If a dict is provided, it must be of the + same form as the protobuf message Dataset. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types.Dataset` instance. 
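+
+        A minimal dataset spec (illustrative; the metadata key depends on the dataset domain)::
+
+            dataset = {
+                "display_name": "my_dataset",
+                "tables_dataset_metadata": {},
+            }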
+ """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + result = client.create_dataset( + request={"parent": parent, "dataset": dataset}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def import_data( + self, + dataset_id: str, + location: str, + input_config: dict | InputConfig, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Import data into a dataset. For Tables this method can only be called on an empty Dataset. + + :param dataset_id: Name of the AutoML dataset. + :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.import_data( + request={"name": name, "input_config": input_config}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_column_specs( + self, + dataset_id: str, + table_spec_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + field_mask: dict | FieldMask | None = None, + filter_: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListColumnSpecsPager: + """ + List column specs in a table spec. + + :param dataset_id: Name of the AutoML dataset. + :param table_spec_id: table_spec_id for path builder. + :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same + form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask` + :param filter_: Filter expression, see go/filtering. + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types.ColumnSpec` instance. 
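+
+        Illustrative call (IDs are placeholders; ``field_mask`` uses the dict form of ``FieldMask``)::
+
+            specs = hook.list_column_specs(
+                dataset_id="TBL123456789",
+                table_spec_id="1234567890",
+                location="us-central1",
+                field_mask={"paths": ["display_name", "data_type"]},
+            )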
+ """ + client = self.get_conn() + parent = client.table_spec_path( + project=project_id, + location=location, + dataset=dataset_id, + table_spec=table_spec_id, + ) + result = client.list_column_specs( + request={"parent": parent, "field_mask": field_mask, "filter": filter_, "page_size": page_size}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_model( + self, + model_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Model: + """ + Get a AutoML model. + + :param model_id: Name of the model. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types.Model` instance. + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.get_model( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_model( + self, + model_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete a AutoML model. + + :param model_id: Name of the model. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.delete_model( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def update_dataset( + self, + dataset: dict | Dataset, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Dataset: + """ + Update a dataset. + + :param dataset: The dataset which replaces the resource on the server. + If a dict is provided, it must be of the same form as the protobuf message Dataset. + :param update_mask: The update mask applies to the resource. If a dict is provided, it must + be of the same form as the protobuf message FieldMask. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types.Dataset` instance.. + """ + client = self.get_conn() + result = client.update_dataset( + request={"dataset": dataset, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def deploy_model( + self, + model_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + image_detection_metadata: ImageObjectDetectionModelDeploymentMetadata | dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Deploys a model. + + If a model is already deployed, deploying it with the same parameters + has no effect. Deploying with different parameters (as e.g. changing node_number) will + reset the deployment state without pausing the model_id's availability. + + Only applicable for Text Classification, Image Object Detection and Tables; all other + domains manage deployment automatically. + + :param model_id: Name of the model requested to serve the prediction. + :param image_detection_metadata: Model deployment metadata specific to Image Object Detection. + If a dict is provided, it must be of the same form as the protobuf message + ImageObjectDetectionModelDeploymentMetadata + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.deploy_model( + request={ + "name": name, + "image_object_detection_model_deployment_metadata": image_detection_metadata, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def list_table_specs( + self, + dataset_id: str, + location: str, + project_id: str | None = None, + filter_: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListTableSpecsPager: + """ + List table specs in a dataset_id. + + :param dataset_id: Name of the dataset. + :param filter_: Filter expression, see go/filtering. + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: A `google.gax.PageIterator` instance. By default, this + is an iterable of `google.cloud.automl_v1beta1.types.TableSpec` instances. + This object can also be configured to iterate over the pages + of the response through the `options` parameter. + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.list_table_specs( + request={"parent": parent, "filter": filter_, "page_size": page_size}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_datasets( + self, + location: str, + project_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListDatasetsPager: + """ + List datasets in a project. + + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: A `google.gax.PageIterator` instance. By default, this + is an iterable of `google.cloud.automl_v1beta1.types.Dataset` instances. + This object can also be configured to iterate over the pages + of the response through the `options` parameter. + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_datasets( + request={"parent": parent}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_dataset( + self, + dataset_id: str, + location: str, + project_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete a dataset and all of its contents. + + :param dataset_id: ID of dataset to be deleted. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
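+
+        Illustrative usage (IDs are placeholders)::
+
+            operation = hook.delete_dataset(
+                dataset_id="TBL123456789",
+                location="us-central1",
+                project_id="my-project",
+            )
+            hook.wait_for_operation(operation=operation, timeout=300)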
+ + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.delete_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py index b57601c64906c..79561d5b48132 100644 --- a/airflow/providers/google/cloud/links/automl.py +++ b/airflow/providers/google/cloud/links/automl.py @@ -19,28 +19,13 @@ from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.google.cloud.links.base import BaseGoogleLink if TYPE_CHECKING: from airflow.utils.context import Context - -def __getattr__(name: str) -> Any: - warnings.warn( - ( - "AutoML links module have been deprecated and will be removed in the next MAJOR release." - " Please use equivalent Vertex AI links instead" - ), - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - return getattr(__name__, name) - - AUTOML_BASE_LINK = "https://console.cloud.google.com/automl-tables" AUTOML_DATASET_LINK = ( AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/schemav2?project={project_id}" diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index ca32994193fbb..54d31025c0ca2 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -15,16 +15,1259 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use `airflow.providers.google.cloud.vertex_ai.auto_ml` instead.""" +"""This module contains Google AutoML operators.""" from __future__ import annotations +import ast import warnings +from typing import TYPE_CHECKING, Sequence, Tuple -from airflow.exceptions import AirflowProviderDeprecationWarning +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud.automl_v1beta1 import ( + BatchPredictResult, + ColumnSpec, + Dataset, + Model, + PredictResponse, + TableSpec, +) -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.vertex_ai.auto_ml` instead.", - AirflowProviderDeprecationWarning, - stacklevel=2, +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.links.automl import ( + AutoMLDatasetLink, + AutoMLDatasetListLink, + AutoMLModelLink, + AutoMLModelPredictLink, + AutoMLModelTrainLink, ) +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from google.api_core.retry import Retry + + from airflow.utils.context import Context + +MetaData = Sequence[Tuple[str, str]] + + +class AutoMLTrainModelOperator(GoogleCloudBaseOperator): + """ + Creates Google Cloud AutoML model. + + AutoMLTrainModelOperator for text prediction is deprecated. Please use + :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator` + instead. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTrainModelOperator` + + :param model: Model definition. + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "model", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = ( + AutoMLModelTrainLink(), + AutoMLModelLink(), + ) + + def __init__( + self, + *, + model: dict, + location: str, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.model = model + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + # Output warning if running not AutoML Translation prediction job + if "translation_model_metadata" not in self.model: + warnings.warn( + "AutoMLTrainModelOperator for text, image and video prediction is deprecated. " + "All the functionality of legacy " + "AutoML Natural Language, Vision and Video Intelligence and new features are available " + "on the Vertex AI platform. 
" + "Please use `CreateAutoMLTextTrainingJobOperator`, `CreateAutoMLImageTrainingJobOperator` or" + " `CreateAutoMLVideoTrainingJobOperator` from VertexAI.", + AirflowProviderDeprecationWarning, + stacklevel=3, + ) + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating model %s...", self.model["display_name"]) + operation = hook.create_model( + model=self.model, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id) + operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) + result = Model.to_dict(operation_result) + model_id = hook.extract_object_id(result) + self.log.info("Model is created, model_id: %s", model_id) + + self.xcom_push(context, key="model_id", value=model_id) + if project_id: + AutoMLModelLink.persist( + context=context, + task_instance=self, + dataset_id=self.model["dataset_id"] or "-", + model_id=model_id, + project_id=project_id, + ) + return result + + +class AutoMLPredictOperator(GoogleCloudBaseOperator): + """ + Runs prediction operation on Google Cloud AutoML. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLPredictOperator` + + :param model_id: Name of the model requested to serve the batch prediction. + :param payload: Name od the model used for the prediction. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param operation_params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLModelPredictLink(),) + + def __init__( + self, + *, + model_id: str, + location: str, + payload: dict, + operation_params: dict[str, str] | None = None, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.operation_params = operation_params # type: ignore + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.payload = payload + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.predict( + model_id=self.model_id, + payload=self.payload, + location=self.location, + project_id=self.project_id, + params=self.operation_params, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelPredictLink.persist( + context=context, + task_instance=self, + model_id=self.model_id, + project_id=project_id, + ) + return PredictResponse.to_dict(result) + + +class AutoMLBatchPredictOperator(GoogleCloudBaseOperator): + """ + Perform a batch prediction on Google Cloud AutoML. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLBatchPredictOperator` + + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :param location: The location of the project. + :param model_id: Name of the model_id requested to serve the batch prediction. + :param input_config: Required. The input configuration for batch prediction. + If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictInputConfig` + :param output_config: Required. The Configuration specifying where output predictions should be + written. If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig` + :param prediction_params: Additional domain-specific parameters for the predictions, + any string must be up to 25000 characters long. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "model_id", + "input_config", + "output_config", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLModelPredictLink(),) + + def __init__( + self, + *, + model_id: str, + input_config: dict, + output_config: dict, + location: str, + project_id: str | None = None, + prediction_params: dict[str, str] | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.location = location + self.project_id = project_id + self.prediction_params = prediction_params + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.input_config = input_config + self.output_config = output_config + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Fetch batch prediction.") + operation = hook.batch_predict( + model_id=self.model_id, + input_config=self.input_config, + output_config=self.output_config, + project_id=self.project_id, + location=self.location, + params=self.prediction_params, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) + result = BatchPredictResult.to_dict(operation_result) + self.log.info("Batch prediction is ready.") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelPredictLink.persist( + context=context, + task_instance=self, + model_id=self.model_id, + project_id=project_id, + ) + return result + + +class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator): + """ + Creates a Google Cloud AutoML dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLCreateDatasetOperator` + + :param dataset: The dataset to create. If a dict is provided, it must be of the + same form as the protobuf message Dataset. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLDatasetLink(),) + + def __init__( + self, + *, + dataset: dict, + location: str, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset = dataset + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating dataset %s...", self.dataset) + result = hook.create_dataset( + dataset=self.dataset, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = Dataset.to_dict(result) + dataset_id = hook.extract_object_id(result) + self.log.info("Creating completed. Dataset id: %s", dataset_id) + + self.xcom_push(context, key="dataset_id", value=dataset_id) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + ) + return result + + +class AutoMLImportDataOperator(GoogleCloudBaseOperator): + """ + Imports data to a Google Cloud AutoML dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLImportDataOperator` + + :param dataset_id: ID of dataset to be updated. + :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset_id", + "input_config", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLDatasetLink(),) + + def __init__( + self, + *, + dataset_id: str, + location: str, + input_config: dict, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset_id = dataset_id + self.input_config = input_config + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Importing data to dataset...") + operation = hook.import_data( + dataset_id=self.dataset_id, + input_config=self.input_config, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Import is completed") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) + + +class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator): + """ + Lists column specs in a table. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTablesListColumnSpecsOperator` + + :param dataset_id: Name of the dataset. + :param table_spec_id: table_spec_id for path builder. + :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same + form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask` + :param filter_: Filter expression, see go/filtering. + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per page, this determines the maximum number + of resources in a page. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. 
+ :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset_id", + "table_spec_id", + "field_mask", + "filter_", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLDatasetLink(),) + + def __init__( + self, + *, + dataset_id: str, + table_spec_id: str, + location: str, + field_mask: dict | None = None, + filter_: str | None = None, + page_size: int | None = None, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.table_spec_id = table_spec_id + self.field_mask = field_mask + self.filter_ = filter_ + self.page_size = page_size + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Requesting column specs.") + page_iterator = hook.list_column_specs( + dataset_id=self.dataset_id, + table_spec_id=self.table_spec_id, + field_mask=self.field_mask, + filter_=self.filter_, + page_size=self.page_size, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [ColumnSpec.to_dict(spec) for spec in page_iterator] + self.log.info("Columns specs obtained.") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) + return result + + +class AutoMLTablesUpdateDatasetOperator(GoogleCloudBaseOperator): + """ + Updates a dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTablesUpdateDatasetOperator` + + :param dataset: The dataset which replaces the resource on the server. + If a dict is provided, it must be of the same form as the protobuf message Dataset. + :param update_mask: The update mask applies to the resource. If a dict is provided, it must + be of the same form as the protobuf message FieldMask. + :param location: The location of the project. + :param params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset", + "update_mask", + "location", + "impersonation_chain", + ) + operator_extra_links = (AutoMLDatasetLink(),) + + def __init__( + self, + *, + dataset: dict, + location: str, + update_mask: dict | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset = dataset + self.update_mask = update_mask + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Updating AutoML dataset %s.", self.dataset["name"]) + result = hook.update_dataset( + dataset=self.dataset, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Dataset updated.") + project_id = hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=hook.extract_object_id(self.dataset), + project_id=project_id, + ) + return Dataset.to_dict(result) + + +class AutoMLGetModelOperator(GoogleCloudBaseOperator): + """ + Get Google Cloud AutoML model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLGetModelOperator` + + :param model_id: Name of the model requested to serve the prediction. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLModelLink(),) + + def __init__( + self, + *, + model_id: str, + location: str, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.get_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + model = Model.to_dict(result) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelLink.persist( + context=context, + task_instance=self, + dataset_id=model["dataset_id"], + model_id=self.model_id, + project_id=project_id, + ) + return model + + +class AutoMLDeleteModelOperator(GoogleCloudBaseOperator): + """ + Delete Google Cloud AutoML model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLDeleteModelOperator` + + :param model_id: Name of the model requested to serve the prediction. + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
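+
+    Illustrative task definition (IDs are placeholders)::
+
+        delete_model_task = AutoMLDeleteModelOperator(
+            task_id="delete_model_task",
+            model_id="TBL987654321",
+            location="us-central1",
+        )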
+ """ + + template_fields: Sequence[str] = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + model_id: str, + location: str, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + operation = hook.delete_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Deletion is completed") + + +class AutoMLDeployModelOperator(GoogleCloudBaseOperator): + """ + Deploys a model; if a model is already deployed, deploying it with the same parameters has no effect. + + Deploying with different parameters (as e.g. changing node_number) will + reset the deployment state without pausing the model_id's availability. + + Only applicable for Text Classification, Image Object Detection and Tables; all other + domains manage deployment automatically. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLDeployModelOperator` + + :param model_id: Name of the model to be deployed. + :param image_detection_metadata: Model deployment metadata specific to Image Object Detection. + If a dict is provided, it must be of the same form as the protobuf message + ImageObjectDetectionModelDeploymentMetadata + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :param location: The location of the project. + :param params: Additional domain-specific parameters for the predictions. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + model_id: str, + location: str, + project_id: str | None = None, + image_detection_metadata: dict | None = None, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.image_detection_metadata = image_detection_metadata + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Deploying model_id %s", self.model_id) + + operation = hook.deploy_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + image_detection_metadata=self.image_detection_metadata, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Model was deployed successfully.") + + +class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator): + """ + Lists table specs in a dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTablesListTableSpecsOperator` + + :param dataset_id: Name of the dataset. + :param filter_: Filter expression, see go/filtering. + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :param project_id: ID of the Google Cloud project if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "filter_", + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLDatasetLink(),) + + def __init__( + self, + *, + dataset_id: str, + location: str, + page_size: int | None = None, + filter_: str | None = None, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.filter_ = filter_ + self.page_size = page_size + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Requesting table specs for %s.", self.dataset_id) + page_iterator = hook.list_table_specs( + dataset_id=self.dataset_id, + filter_=self.filter_, + page_size=self.page_size, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [TableSpec.to_dict(spec) for spec in page_iterator] + self.log.info(result) + self.log.info("Table specs obtained.") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) + return result + + +class AutoMLListDatasetOperator(GoogleCloudBaseOperator): + """ + Lists AutoML Datasets in project. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLListDatasetOperator` + + :param project_id: ID of the Google Cloud project where datasets are located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "location", + "project_id", + "impersonation_chain", + ) + operator_extra_links = (AutoMLDatasetListLink(),) + + def __init__( + self, + *, + location: str, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Requesting datasets") + page_iterator = hook.list_datasets( + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [Dataset.to_dict(dataset) for dataset in page_iterator] + self.log.info("Datasets obtained.") + + self.xcom_push( + context, + key="dataset_id_list", + value=[hook.extract_object_id(d) for d in result], + ) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id) + return result + + +class AutoMLDeleteDatasetOperator(GoogleCloudBaseOperator): + """ + Deletes a dataset and all of its contents. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLDeleteDatasetOperator` + + :param dataset_id: Name of the dataset_id, list of dataset_id or string of dataset_id + coma separated to be deleted. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "location", + "project_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + dataset_id: str | list[str], + location: str, + project_id: str | None = None, + metadata: MetaData = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset_id = dataset_id + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @staticmethod + def _parse_dataset_id(dataset_id: str | list[str]) -> list[str]: + if not isinstance(dataset_id, str): + return dataset_id + try: + return ast.literal_eval(dataset_id) + except (SyntaxError, ValueError): + return dataset_id.split(",") + + def execute(self, context: Context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + dataset_id_list = self._parse_dataset_id(self.dataset_id) + for dataset_id in dataset_id_list: + self.log.info("Deleting dataset %s", dataset_id) + hook.delete_dataset( + dataset_id=dataset_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Dataset deleted.") diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index b70ff2705a0d7..748cfbb39cc28 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -108,6 +108,7 @@ dependencies: - google-auth>=1.0.0 - google-auth-httplib2>=0.0.1 - google-cloud-aiplatform>=1.42.1 + - google-cloud-automl>=2.12.0 - google-cloud-bigquery-datatransfer>=3.13.0 - google-cloud-bigtable>=2.17.0 - google-cloud-build>=3.22.0 @@ -201,6 +202,8 @@ integrations: tags: [gmp] - integration-name: Google AutoML external-doc-url: https://cloud.google.com/automl/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/automl.rst logo: /integration-logos/gcp/Cloud-AutoML.png tags: [gcp] - integration-name: Google BigQuery Data Transfer Service @@ -529,6 +532,9 @@ operators: - integration-name: Google Cloud Common python-modules: - airflow.providers.google.cloud.operators.cloud_base + - integration-name: Google AutoML + python-modules: + - airflow.providers.google.cloud.operators.automl - integration-name: Google BigQuery python-modules: - airflow.providers.google.cloud.operators.bigquery @@ -1229,6 +1235,11 @@ extra-links: - airflow.providers.google.cloud.links.cloud_build.CloudBuildListLink - airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggersListLink - airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggerDetailsLink + - airflow.providers.google.cloud.links.automl.AutoMLDatasetLink + - airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink + - airflow.providers.google.cloud.links.automl.AutoMLModelLink + - airflow.providers.google.cloud.links.automl.AutoMLModelTrainLink + - airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink - airflow.providers.google.cloud.links.life_sciences.LifeSciencesLink - airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsDetailsLink - airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsListLink diff --git a/docker_tests/test_prod_image.py 
b/docker_tests/test_prod_image.py
index b4e59052ee309..ab35c63bffa53 100644
--- a/docker_tests/test_prod_image.py
+++ b/docker_tests/test_prod_image.py
@@ -131,6 +131,7 @@ def test_pip_dependencies_conflict(self, default_docker_image):
             "googleapiclient",
             "google.auth",
             "google_auth_httplib2",
+            "google.cloud.automl",
             "google.cloud.bigquery_datatransfer",
             "google.cloud.bigtable",
             "google.cloud.container",
diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
new file mode 100644
index 0000000000000..34324ea60d914
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
@@ -0,0 +1,229 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Google Cloud AutoML Operators
+=======================================
+
+The `Google Cloud AutoML <https://cloud.google.com/automl/>`__ service
+makes the power of machine learning available to you even if you have limited knowledge
+of machine learning. You can use AutoML to build on Google's machine learning capabilities
+to create your own custom machine learning models that are tailored to your business needs,
+and then integrate those models into your applications and websites.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+.. _howto/operator:CloudAutoMLDocuments:
+.. _howto/operator:AutoMLCreateDatasetOperator:
+.. _howto/operator:AutoMLImportDataOperator:
+.. _howto/operator:AutoMLTablesUpdateDatasetOperator:
+
+Creating Datasets
+^^^^^^^^^^^^^^^^^
+
+To create a Google AutoML dataset you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`.
+The operator returns the dataset id in :ref:`XCom <concepts:xcom>` under the ``dataset_id`` key.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_create_dataset]
+    :end-before: [END howto_operator_automl_create_dataset]
+
+After creating a dataset you can use it to import some data using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_import_data]
+    :end-before: [END howto_operator_automl_import_data]
+
+To update a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_update_dataset]
+    :end-before: [END howto_operator_automl_update_dataset]
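+
+For instance, a downstream import task might reference the dataset created earlier through
+XCom (an illustrative sketch; the task ids, location and ``IMPORT_INPUT_CONFIG`` value are
+placeholders, not a verbatim excerpt from the system tests):
+
+.. code-block:: python
+
+    import_data = AutoMLImportDataOperator(
+        task_id="import_data",
+        # dataset_id is a templated field, so the id pushed by the create task
+        # can be pulled from XCom at runtime.
+        dataset_id="{{ task_instance.xcom_pull('create_dataset', key='dataset_id') }}",
+        location="us-central1",
+        input_config=IMPORT_INPUT_CONFIG,
+    )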
+.. _howto/operator:AutoMLTablesListTableSpecsOperator:
+.. _howto/operator:AutoMLTablesListColumnSpecsOperator:
+
+Listing Table and Column Specs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To list table specs you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_specs]
+    :end-before: [END howto_operator_automl_specs]
+
+To list column specs you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_column_specs]
+    :end-before: [END howto_operator_automl_column_specs]
+
+.. _howto/operator:AutoMLTrainModelOperator:
+.. _howto/operator:AutoMLGetModelOperator:
+.. _howto/operator:AutoMLDeployModelOperator:
+.. _howto/operator:AutoMLDeleteModelOperator:
+
+Operations On Models
+^^^^^^^^^^^^^^^^^^^^
+
+To create a Google AutoML model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator`.
+The operator will wait for the operation to complete. Additionally, the operator
+returns the id of the model in :ref:`XCom <concepts:xcom>` under the ``model_id`` key.
+
+This operator is deprecated for text, video and vision prediction and will be removed in a
+future release. All the functionality of legacy AutoML Natural Language, Vision and Video
+Intelligence, as well as new features, are available on the Vertex AI platform. Please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator` or
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`.
+
+You can find an example of how to use Vertex AI operators for AutoML Natural Language classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_cloud_create_text_classification_training_job_operator]
+    :end-before: [END howto_cloud_create_text_classification_training_job_operator]
+
+Additionally, you can find an example of how to use Vertex AI operators for AutoML Vision classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_vision_classification.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_cloud_create_image_classification_training_job_operator]
+    :end-before: [END howto_cloud_create_image_classification_training_job_operator]
+
+An example of how to use Vertex AI operators for AutoML Video Intelligence classification can be found here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_video_classification.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_cloud_create_video_classification_training_job_operator]
+    :end-before: [END howto_cloud_create_video_classification_training_job_operator]
+
+When running a Vertex AI operator for training, please ensure that your training data is
+correctly stored in Vertex AI datasets.
To create and import data to the dataset, please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
+and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_create_model]
+    :end-before: [END howto_operator_automl_create_model]
+
+To get an existing model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_get_model]
+    :end-before: [END howto_operator_get_model]
+
+Once a model is created, it can be deployed using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_deploy_model]
+    :end-before: [END howto_operator_deploy_model]
+
+If you wish to delete a model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_automl_delete_model]
+    :end-before: [END howto_operator_automl_delete_model]
+
+.. _howto/operator:AutoMLPredictOperator:
+.. _howto/operator:AutoMLBatchPredictOperator:
+
+Making Predictions
+^^^^^^^^^^^^^^^^^^
+
+To obtain predictions from a Google Cloud AutoML model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator` or
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator`. In the first case
+the model must be deployed.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_prediction]
+    :end-before: [END howto_operator_prediction]
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_batch_prediction]
+    :end-before: [END howto_operator_batch_prediction]
+
+.. _howto/operator:AutoMLListDatasetOperator:
+.. _howto/operator:AutoMLDeleteDatasetOperator:
+
+Listing And Deleting Datasets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can get a list of AutoML datasets using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns
+a list of dataset ids in :ref:`XCom <concepts:xcom>` under the ``dataset_id_list`` key.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_list_dataset]
+    :end-before: [END howto_operator_list_dataset]
+
+To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
+The delete operator also accepts a list or a comma-separated string of dataset ids to be deleted.
+
+..
exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_delete_dataset] + :end-before: [END howto_operator_delete_dataset] + +Reference +^^^^^^^^^ + +For further information, look at: + +* `Client Library Documentation `__ +* `Product Documentation `__ diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index d88b975a87ebc..54a02cc5d598b 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -526,6 +526,7 @@ "google-auth-httplib2>=0.0.1", "google-auth>=1.0.0", "google-cloud-aiplatform>=1.42.1", + "google-cloud-automl>=2.12.0", "google-cloud-batch>=0.13.0", "google-cloud-bigquery-datatransfer>=3.13.0", "google-cloud-bigtable>=2.17.0", diff --git a/scripts/in_container/run_provider_yaml_files_check.py b/scripts/in_container/run_provider_yaml_files_check.py index c343bb2397726..b1608c25bcedd 100755 --- a/scripts/in_container/run_provider_yaml_files_check.py +++ b/scripts/in_container/run_provider_yaml_files_check.py @@ -50,17 +50,10 @@ "airflow.providers.apache.hdfs.hooks.hdfs", "airflow.providers.cncf.kubernetes.triggers.kubernetes_pod", "airflow.providers.cncf.kubernetes.operators.kubernetes_pod", - "airflow.providers.google.cloud.operators.automl", ] KNOWN_DEPRECATED_CLASSES = [ - "airflow.providers.google.cloud.links.automl.AutoMLDatasetLink", - "airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink", - "airflow.providers.google.cloud.links.automl.AutoMLModelLink", - "airflow.providers.google.cloud.links.automl.AutoMLModelListLink", - "airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink", "airflow.providers.google.cloud.links.dataproc.DataprocLink", - "airflow.providers.google.cloud.hooks.automl.CloudAutoMLHook", ] try: diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 773a2fc1c4206..5113ce15f4052 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -429,6 +429,8 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest } ASSETS_NOT_REQUIRED = { + "airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator", + "airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator", "airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator", "airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteDatasetOperator", "airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator", diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml index d271e6077a25b..6d27f4f5388d6 100644 --- a/tests/deprecations_ignore.yml +++ b/tests/deprecations_ignore.yml @@ -779,6 +779,7 @@ - tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py::TestCustomJobWithoutDefaultProjectIdHook::test_get_pipeline_job - tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py::TestCustomJobWithoutDefaultProjectIdHook::test_list_pipeline_jobs - tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryCreateExternalTableOperator::test_execute_with_csv_format +- tests/providers/google/cloud/operators/test_automl.py::TestAutoMLTrainModelOperator::test_execute - tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryCreateExternalTableOperator::test_execute_with_parquet_format - 
tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryOperator::test_bigquery_operator_defaults - tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryOperator::test_bigquery_operator_extra_link_when_missing_job_id diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py index 0f97b91e5a78d..f79dd8b51b73d 100644 --- a/tests/providers/google/cloud/hooks/test_automl.py +++ b/tests/providers/google/cloud/hooks/test_automl.py @@ -17,85 +17,237 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.automl_v1beta1 import AutoMlClient -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.common.consts import CLIENT_INFO +from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id + +CREDENTIALS = "test-creds" +TASK_ID = "test-automl-hook" +GCP_PROJECT_ID = "test-project" +GCP_LOCATION = "test-location" +MODEL_NAME = "test_model" +MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152" +DATASET_ID = "TBL123456789" +MODEL = { + "display_name": MODEL_NAME, + "dataset_id": DATASET_ID, + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, +} + +LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}" +MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}" +DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}" + +INPUT_CONFIG = {"input": "value"} +OUTPUT_CONFIG = {"output": "value"} +PAYLOAD = {"test": "payload"} +DATASET = {"dataset_id": "data"} +MASK = {"field": "mask"} class TestAutoMLHook: - def setup_method(self): - self.hook = CloudAutoMLHook() - - def test_init(self): - with pytest.warns(AirflowProviderDeprecationWarning): - CloudAutoMLHook() - - def test_extract_object_id(self): - with pytest.warns(AirflowProviderDeprecationWarning, match="'extract_object_id'"): - object_id = CloudAutoMLHook.extract_object_id(obj={"name": "x/y"}) - assert object_id == "y" - - def test_get_conn(self): - with pytest.raises(AirflowException): - self.hook.get_conn() - - def test_wait_for_operation(self): - with pytest.raises(AirflowException): - self.hook.wait_for_operation() - - def test_prediction_client(self): - with pytest.raises(AirflowException): - self.hook.prediction_client() - - def test_create_model(self): - with pytest.raises(AirflowException): - self.hook.create_model() - - def test_batch_predict(self): - with pytest.raises(AirflowException): - self.hook.batch_predict() - - def test_predict(self): - with pytest.raises(AirflowException): - self.hook.predict() + def test_delegate_to_runtime_error(self): + with pytest.raises(RuntimeError): + CloudAutoMLHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to") - def test_create_dataset(self): - with pytest.raises(AirflowException): - self.hook.create_dataset() - - def test_import_data(self): - with pytest.raises(AirflowException): - self.hook.import_data() - - def test_list_column_specs(self): - with pytest.raises(AirflowException): - self.hook.list_column_specs() - - def test_get_model(self): - with pytest.raises(AirflowException): - self.hook.get_model() - - def test_delete_model(self): - with pytest.raises(AirflowException): - self.hook.delete_model() - - 
def test_update_dataset(self): - with pytest.raises(AirflowException): - self.hook.update_dataset() - - def test_deploy_model(self): - with pytest.raises(AirflowException): - self.hook.deploy_model() - - def test_list_table_specs(self): - with pytest.raises(AirflowException): - self.hook.list_table_specs() - - def test_list_datasets(self): - with pytest.raises(AirflowException): - self.hook.list_datasets() - - def test_delete_dataset(self): - with pytest.raises(AirflowException): - self.hook.delete_dataset() + def setup_method(self): + with mock.patch( + "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.__init__", + new=mock_base_gcp_hook_no_default_project_id, + ): + self.hook = CloudAutoMLHook() + self.hook.get_credentials = mock.MagicMock(return_value=CREDENTIALS) # type: ignore + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient") + def test_get_conn(self, mock_automl_client): + self.hook.get_conn() + mock_automl_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient") + def test_prediction_client(self, mock_prediction_client): + client = self.hook.prediction_client # noqa: F841 + mock_prediction_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model") + def test_create_model(self, mock_create_model): + self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) + + mock_create_model.assert_called_once_with( + request=dict(parent=LOCATION_PATH, model=MODEL), retry=DEFAULT, timeout=None, metadata=() + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict") + def test_batch_predict(self, mock_batch_predict): + self.hook.batch_predict( + model_id=MODEL_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + input_config=INPUT_CONFIG, + output_config=OUTPUT_CONFIG, + ) + + mock_batch_predict.assert_called_once_with( + request=dict( + name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None + ), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict") + def test_predict(self, mock_predict): + self.hook.predict( + model_id=MODEL_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + payload=PAYLOAD, + ) + + mock_predict.assert_called_once_with( + request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset") + def test_create_dataset(self, mock_create_dataset): + self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) + + mock_create_dataset.assert_called_once_with( + request=dict(parent=LOCATION_PATH, dataset=DATASET), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data") + def test_import_dataset(self, mock_import_data): + self.hook.import_data( + dataset_id=DATASET_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + input_config=INPUT_CONFIG, + ) + + mock_import_data.assert_called_once_with( + request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + 
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs") + def test_list_column_specs(self, mock_list_column_specs): + table_spec = "table_spec_id" + filter_ = "filter" + page_size = 42 + + self.hook.list_column_specs( + dataset_id=DATASET_ID, + table_spec_id=table_spec, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + field_mask=MASK, + filter_=filter_, + page_size=page_size, + ) + + parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec) + mock_list_column_specs.assert_called_once_with( + request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model") + def test_get_model(self, mock_get_model): + self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) + + mock_get_model.assert_called_once_with( + request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=() + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model") + def test_delete_model(self, mock_delete_model): + self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) + + mock_delete_model.assert_called_once_with( + request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=() + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset") + def test_update_dataset(self, mock_update_dataset): + self.hook.update_dataset( + dataset=DATASET, + update_mask=MASK, + ) + + mock_update_dataset.assert_called_once_with( + request=dict(dataset=DATASET, update_mask=MASK), retry=DEFAULT, timeout=None, metadata=() + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model") + def test_deploy_model(self, mock_deploy_model): + image_detection_metadata = {} + + self.hook.deploy_model( + model_id=MODEL_ID, + image_detection_metadata=image_detection_metadata, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + mock_deploy_model.assert_called_once_with( + request=dict( + name=MODEL_PATH, + image_object_detection_model_deployment_metadata=image_detection_metadata, + ), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs") + def test_list_table_specs(self, mock_list_table_specs): + filter_ = "filter" + page_size = 42 + + self.hook.list_table_specs( + dataset_id=DATASET_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + filter_=filter_, + page_size=page_size, + ) + + mock_list_table_specs.assert_called_once_with( + request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets") + def test_list_datasets(self, mock_list_datasets): + self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID) + + mock_list_datasets.assert_called_once_with( + request=dict(parent=LOCATION_PATH), retry=DEFAULT, timeout=None, metadata=() + ) + + @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset") + def test_delete_dataset(self, mock_delete_dataset): + self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) + + mock_delete_dataset.assert_called_once_with( + request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, 
metadata=() + ) diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index 1caf680cd512c..4f00f76a2dbef 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -17,13 +17,643 @@ # under the License. from __future__ import annotations -from importlib import import_module +import copy +from unittest import mock import pytest +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLBatchPredictOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLDeployModelOperator, + AutoMLGetModelOperator, + AutoMLImportDataOperator, + AutoMLListDatasetOperator, + AutoMLPredictOperator, + AutoMLTablesListColumnSpecsOperator, + AutoMLTablesListTableSpecsOperator, + AutoMLTablesUpdateDatasetOperator, + AutoMLTrainModelOperator, +) +from airflow.utils import timezone +CREDENTIALS = "test-creds" +TASK_ID = "test-automl-hook" +GCP_PROJECT_ID = "test-project" +GCP_LOCATION = "test-location" +MODEL_NAME = "test_model" +MODEL_ID = "TBL9195602771183665152" +DATASET_ID = "TBL123456789" +MODEL = { + "display_name": MODEL_NAME, + "dataset_id": DATASET_ID, + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, +} -def test_deprecated_module(): - with pytest.warns(AirflowProviderDeprecationWarning): - import_module("airflow.providers.google.cloud.operators.automl") +LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}" +MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}" +DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}" + +INPUT_CONFIG = {"input": "value"} +OUTPUT_CONFIG = {"output": "value"} +PAYLOAD = {"test": "payload"} +DATASET = {"dataset_id": "data"} +MASK = {"field": "mask"} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +class TestAutoMLTrainModelOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH) + mock_hook.return_value.extract_object_id = extract_object_id + mock_hook.return_value.wait_for_operation.return_value = Model() + op = AutoMLTrainModelOperator( + model=MODEL, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + task_id=TASK_ID, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.create_model.assert_called_once_with( + model=MODEL, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTrainModelOperator, + # Templated fields + model="{{ 'model' }}", + location="{{ 'location' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTrainModelOperator = ti.task + assert task.model == "model" + assert 
task.location == "location" + assert task.impersonation_chain == "impersonation_chain" + + +class TestAutoMLBatchPredictOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult() + mock_hook.return_value.extract_object_id = extract_object_id + mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult() + + op = AutoMLBatchPredictOperator( + model_id=MODEL_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + input_config=INPUT_CONFIG, + output_config=OUTPUT_CONFIG, + task_id=TASK_ID, + prediction_params={}, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.batch_predict.assert_called_once_with( + input_config=INPUT_CONFIG, + location=GCP_LOCATION, + metadata=(), + model_id=MODEL_ID, + output_config=OUTPUT_CONFIG, + params={}, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLBatchPredictOperator, + # Templated fields + model_id="{{ 'model' }}", + input_config="{{ 'input-config' }}", + output_config="{{ 'output-config' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLBatchPredictOperator = ti.task + assert task.model_id == "model" + assert task.input_config == "input-config" + assert task.output_config == "output-config" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLPredictOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + mock_hook.return_value.predict.return_value = PredictResponse() + + op = AutoMLPredictOperator( + model_id=MODEL_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + payload=PAYLOAD, + task_id=TASK_ID, + operation_params={"TEST_KEY": "TEST_VALUE"}, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.predict.assert_called_once_with( + location=GCP_LOCATION, + metadata=(), + model_id=MODEL_ID, + params={"TEST_KEY": "TEST_VALUE"}, + payload=PAYLOAD, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLPredictOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + payload={}, + ) + ti.render_templates() + task: AutoMLPredictOperator = ti.task + assert task.model_id == "model-id" + assert task.project_id == "project-id" + assert task.location == "location" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLCreateImportOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, 
mock_hook): + mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH) + mock_hook.return_value.extract_object_id = extract_object_id + + op = AutoMLCreateDatasetOperator( + dataset=DATASET, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + task_id=TASK_ID, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.create_dataset.assert_called_once_with( + dataset=DATASET, + location=GCP_LOCATION, + metadata=(), + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLCreateDatasetOperator, + # Templated fields + dataset="{{ 'dataset' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLCreateDatasetOperator = ti.task + assert task.dataset == "dataset" + assert task.project_id == "project-id" + assert task.location == "location" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLListColumnsSpecsOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + table_spec = "table_spec_id" + filter_ = "filter" + page_size = 42 + + op = AutoMLTablesListColumnSpecsOperator( + dataset_id=DATASET_ID, + table_spec_id=table_spec, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + field_mask=MASK, + filter_=filter_, + page_size=page_size, + task_id=TASK_ID, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.list_column_specs.assert_called_once_with( + dataset_id=DATASET_ID, + field_mask=MASK, + filter_=filter_, + location=GCP_LOCATION, + metadata=(), + page_size=page_size, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + table_spec_id=table_spec, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTablesListColumnSpecsOperator, + # Templated fields + dataset_id="{{ 'dataset-id' }}", + table_spec_id="{{ 'table-spec-id' }}", + field_mask="{{ 'field-mask' }}", + filter_="{{ 'filter-' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTablesListColumnSpecsOperator = ti.task + assert task.dataset_id == "dataset-id" + assert task.table_spec_id == "table-spec-id" + assert task.field_mask == "field-mask" + assert task.filter_ == "filter-" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLUpdateDatasetOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH) + + dataset = copy.deepcopy(DATASET) + dataset["name"] = DATASET_ID + + op = AutoMLTablesUpdateDatasetOperator( + dataset=dataset, + update_mask=MASK, + location=GCP_LOCATION, + task_id=TASK_ID, + ) + 
op.execute(context=mock.MagicMock()) + mock_hook.return_value.update_dataset.assert_called_once_with( + dataset=dataset, + metadata=(), + retry=DEFAULT, + timeout=None, + update_mask=MASK, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTablesUpdateDatasetOperator, + # Templated fields + dataset="{{ 'dataset' }}", + update_mask="{{ 'update-mask' }}", + location="{{ 'location' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTablesUpdateDatasetOperator = ti.task + assert task.dataset == "dataset" + assert task.update_mask == "update-mask" + assert task.location == "location" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLGetModelOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH) + mock_hook.return_value.extract_object_id = extract_object_id + + op = AutoMLGetModelOperator( + model_id=MODEL_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + task_id=TASK_ID, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.get_model.assert_called_once_with( + location=GCP_LOCATION, + metadata=(), + model_id=MODEL_ID, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLGetModelOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLGetModelOperator = ti.task + assert task.model_id == "model-id" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLDeleteModelOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + op = AutoMLDeleteModelOperator( + model_id=MODEL_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + task_id=TASK_ID, + ) + op.execute(context=None) + mock_hook.return_value.delete_model.assert_called_once_with( + location=GCP_LOCATION, + metadata=(), + model_id=MODEL_ID, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLDeleteModelOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLDeleteModelOperator = ti.task + assert task.model_id == "model-id" + assert task.location == "location" + assert 
task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLDeployModelOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + image_detection_metadata = {} + op = AutoMLDeployModelOperator( + model_id=MODEL_ID, + image_detection_metadata=image_detection_metadata, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + task_id=TASK_ID, + ) + op.execute(context=None) + mock_hook.return_value.deploy_model.assert_called_once_with( + image_detection_metadata={}, + location=GCP_LOCATION, + metadata=(), + model_id=MODEL_ID, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLDeployModelOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLDeployModelOperator = ti.task + assert task.model_id == "model-id" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLDatasetImportOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + op = AutoMLImportDataOperator( + dataset_id=DATASET_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + input_config=INPUT_CONFIG, + task_id=TASK_ID, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.import_data.assert_called_once_with( + input_config=INPUT_CONFIG, + location=GCP_LOCATION, + metadata=(), + dataset_id=DATASET_ID, + project_id=GCP_PROJECT_ID, + retry=DEFAULT, + timeout=None, + ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLImportDataOperator, + # Templated fields + dataset_id="{{ 'dataset-id' }}", + input_config="{{ 'input-config' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLImportDataOperator = ti.task + assert task.dataset_id == "dataset-id" + assert task.input_config == "input-config" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + + +class TestAutoMLTablesListTableSpecsOperator: + @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") + def test_execute(self, mock_hook): + filter_ = "filter" + page_size = 42 + + op = AutoMLTablesListTableSpecsOperator( + dataset_id=DATASET_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + filter_=filter_, + page_size=page_size, + task_id=TASK_ID, + ) + op.execute(context=mock.MagicMock()) + mock_hook.return_value.list_table_specs.assert_called_once_with( + dataset_id=DATASET_ID, + filter_=filter_, + location=GCP_LOCATION, + metadata=(), + page_size=page_size, + project_id=GCP_PROJECT_ID, + 
+
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLTablesListTableSpecsOperator,
+            # Templated fields
+            dataset_id="{{ 'dataset-id' }}",
+            filter_="{{ 'filter-' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLTablesListTableSpecsOperator = ti.task
+        assert task.dataset_id == "dataset-id"
+        assert task.filter_ == "filter-"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDatasetListOperator:
+    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+    def test_execute(self, mock_hook):
+        op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
+        op.execute(context=mock.MagicMock())
+        mock_hook.return_value.list_datasets.assert_called_once_with(
+            location=GCP_LOCATION,
+            metadata=(),
+            project_id=GCP_PROJECT_ID,
+            retry=DEFAULT,
+            timeout=None,
+        )
+
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLListDatasetOperator,
+            # Templated fields
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLListDatasetOperator = ti.task
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDatasetDeleteOperator:
+    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+    def test_execute(self, mock_hook):
+        op = AutoMLDeleteDatasetOperator(
+            dataset_id=DATASET_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT_ID,
+            task_id=TASK_ID,
+        )
+        op.execute(context=None)
+        mock_hook.return_value.delete_dataset.assert_called_once_with(
+            location=GCP_LOCATION,
+            dataset_id=DATASET_ID,
+            metadata=(),
+            project_id=GCP_PROJECT_ID,
+            retry=DEFAULT,
+            timeout=None,
+        )
+
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLDeleteDatasetOperator,
+            # Templated fields
+            dataset_id="{{ 'dataset-id' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLDeleteDatasetOperator = ti.task
+        assert task.dataset_id == "dataset-id"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
diff --git a/tests/system/providers/google/cloud/automl/example_automl_dataset.py b/tests/system/providers/google/cloud/automl/example_automl_dataset.py
new file mode 100644
index 0000000000000..1c2691657e19b
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_dataset.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google AutoML service testing dataset operations.
+"""
+
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+    AutoMLCreateDatasetOperator,
+    AutoMLDeleteDatasetOperator,
+    AutoMLImportDataOperator,
+    AutoMLListDatasetOperator,
+    AutoMLTablesListColumnSpecsOperator,
+    AutoMLTablesListTableSpecsOperator,
+    AutoMLTablesUpdateDatasetOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+    GCSCreateBucketOperator,
+    GCSDeleteBucketOperator,
+    GCSSynchronizeBucketsOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "example_automl_dataset"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_AUTOML_LOCATION = "us-central1"
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+
+DATASET_NAME = f"ds_tabular_{ENV_ID}".replace("-", "_")
+DATASET = {
+    "display_name": DATASET_NAME,
+    "tables_dataset_metadata": {"target_column_spec_id": ""},
+}
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/tabular-classification.csv"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
+    """Return the object ID of the column spec whose display name matches ``column_name``."""
+    for column in columns_specs:
+        if column["display_name"] == column_name:
+            return extract_object_id(column)
+    raise ValueError(f"Unknown target column: {column_name}")
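+
+# For illustration (hypothetical values): a column spec pulled from XCom renders to a dict like
+#   {"display_name": "Class", "name": "projects/p/locations/us-central1/datasets/d/tableSpecs/t/columnSpecs/123"},
+# for which this helper returns "123", the object ID expected in
+# tables_dataset_metadata["target_column_spec_id"].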
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "automl", "dataset"],
+    user_defined_macros={
+        "get_target_column_spec": get_target_column_spec,
+        "target": "Class",
+        "extract_object_id": extract_object_id,
+    },
+) as dag:
+    create_bucket = GCSCreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        storage_class="REGIONAL",
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    move_dataset_file = GCSSynchronizeBucketsOperator(
+        task_id="move_dataset_to_bucket",
+        source_bucket=RESOURCE_DATA_BUCKET,
+        source_object="automl/datasets/tabular",
+        destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+        destination_object="automl",
+        recursive=True,
+    )
+
+    # [START howto_operator_automl_create_dataset]
+    create_dataset = AutoMLCreateDatasetOperator(
+        task_id="create_dataset",
+        dataset=DATASET,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    dataset_id = create_dataset.output["dataset_id"]
+    # [END howto_operator_automl_create_dataset]
+
+    # [START howto_operator_automl_import_data]
+    import_dataset = AutoMLImportDataOperator(
+        task_id="import_dataset",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        input_config=IMPORT_INPUT_CONFIG,
+    )
+    # [END howto_operator_automl_import_data]
+
+    # [START howto_operator_automl_specs]
+    list_tables_spec = AutoMLTablesListTableSpecsOperator(
+        task_id="list_tables_spec",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_automl_specs]
+
+    # [START howto_operator_automl_column_specs]
+    list_columns_spec = AutoMLTablesListColumnSpecsOperator(
+        task_id="list_columns_spec",
+        dataset_id=dataset_id,
+        table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec')[0]) }}",
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_automl_column_specs]
+
+    # [START howto_operator_automl_update_dataset]
+    update = deepcopy(DATASET)
+    update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}'
+    update["tables_dataset_metadata"][  # type: ignore
+        "target_column_spec_id"
+    ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec'), target) }}"
+
+    update_dataset = AutoMLTablesUpdateDatasetOperator(
+        task_id="update_dataset",
+        dataset=update,
+        location=GCP_AUTOML_LOCATION,
+    )
+    # [END howto_operator_automl_update_dataset]
+
+    # [START howto_operator_list_dataset]
+    list_datasets = AutoMLListDatasetOperator(
+        task_id="list_datasets",
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_list_dataset]
+
+    # [START howto_operator_delete_dataset]
+    delete_dataset = AutoMLDeleteDatasetOperator(
+        task_id="delete_dataset",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_delete_dataset]
+
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
+    )
+
+    (
+        # TEST SETUP
+        [create_bucket >> move_dataset_file, create_dataset]
+        # TEST BODY
+        >> import_dataset
+        >> list_tables_spec
+        >> list_columns_spec
+        >> update_dataset
+        >> list_datasets
+        # TEST TEARDOWN
+        >> delete_dataset
+        >> delete_bucket
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_model.py b/tests/system/providers/google/cloud/automl/example_automl_model.py
new file mode 100644
index 0000000000000..59ec91c8790d5
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_model.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google AutoML service testing model operations.
+"""
+
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from datetime import datetime
+
+from google.protobuf.struct_pb2 import Value
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+    AutoMLBatchPredictOperator,
+    AutoMLCreateDatasetOperator,
+    AutoMLDeleteDatasetOperator,
+    AutoMLDeleteModelOperator,
+    AutoMLDeployModelOperator,
+    AutoMLGetModelOperator,
+    AutoMLImportDataOperator,
+    AutoMLPredictOperator,
+    AutoMLTablesListColumnSpecsOperator,
+    AutoMLTablesListTableSpecsOperator,
+    AutoMLTablesUpdateDatasetOperator,
+    AutoMLTrainModelOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+    GCSCreateBucketOperator,
+    GCSDeleteBucketOperator,
+    GCSSynchronizeBucketsOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "example_automl_model"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_AUTOML_LOCATION = "us-central1"
+
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+
+DATASET_NAME = f"md_tabular_{ENV_ID}".replace("-", "_")
+DATASET = {
+    "display_name": DATASET_NAME,
+    "tables_dataset_metadata": {"target_column_spec_id": ""},
+}
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing-split.csv"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+IMPORT_OUTPUT_CONFIG = {
+    "gcs_destination": {"output_uri_prefix": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl"}
+}
+
+MODEL_NAME = f"md_tabular_{ENV_ID}".replace("-", "_")
+MODEL = {
+    "display_name": MODEL_NAME,
+    "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+}
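+# NOTE: the train budget is expressed in thousandths of a node hour, so 1000 caps
+# training at a single node hour (the documented minimum for AutoML Tables).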
+
+PREDICT_VALUES = [
+    Value(string_value="TRAINING"),
+    Value(string_value="51"),
+    Value(string_value="blue-collar"),
+    Value(string_value="married"),
+    Value(string_value="primary"),
+    Value(string_value="no"),
+    Value(string_value="620"),
+    Value(string_value="yes"),
+    Value(string_value="yes"),
+    Value(string_value="cellular"),
+    Value(string_value="29"),
+    Value(string_value="jul"),
+    Value(string_value="88"),
+    Value(string_value="10"),
+    Value(string_value="-1"),
+    Value(string_value="0"),
+    Value(string_value="unknown"),
+]
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
+    """Return the object ID of the column spec whose display name matches ``column_name``."""
+    for column in columns_specs:
+        if column["display_name"] == column_name:
+            return extract_object_id(column)
+    raise ValueError(f"Unknown target column: {column_name}")
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    user_defined_macros={
+        "get_target_column_spec": get_target_column_spec,
+        "target": "Deposit",
+        "extract_object_id": extract_object_id,
+    },
+    tags=["example", "automl", "model"],
+) as dag:
+    create_bucket = GCSCreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        storage_class="REGIONAL",
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    move_dataset_file = GCSSynchronizeBucketsOperator(
+        task_id="move_data_to_bucket",
+        source_bucket=RESOURCE_DATA_BUCKET,
+        source_object="automl/datasets/model",
+        destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+        destination_object="automl",
+        recursive=True,
+    )
+
+    create_dataset = AutoMLCreateDatasetOperator(
+        task_id="create_dataset",
+        dataset=DATASET,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+
+    dataset_id = create_dataset.output["dataset_id"]
+    MODEL["dataset_id"] = dataset_id
+
+    import_dataset = AutoMLImportDataOperator(
+        task_id="import_dataset",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        input_config=IMPORT_INPUT_CONFIG,
+    )
+
+    list_tables_spec = AutoMLTablesListTableSpecsOperator(
+        task_id="list_tables_spec",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+
+    list_columns_spec = AutoMLTablesListColumnSpecsOperator(
+        task_id="list_columns_spec",
+        dataset_id=dataset_id,
+        table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec')[0]) }}",
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+
+    update = deepcopy(DATASET)
+    update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}'
+    update["tables_dataset_metadata"][  # type: ignore
+        "target_column_spec_id"
+    ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec'), target) }}"
+
+    update_dataset = AutoMLTablesUpdateDatasetOperator(
+        task_id="update_dataset",
+        dataset=update,
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    # [START howto_operator_automl_create_model]
+    create_model = AutoMLTrainModelOperator(
+        task_id="create_model",
+        model=MODEL,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    model_id = create_model.output["model_id"]
+    # [END howto_operator_automl_create_model]
+
+    # [START howto_operator_get_model]
+    get_model = AutoMLGetModelOperator(
+        task_id="get_model",
+        model_id=model_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_get_model]
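+
+    # NOTE: model_id is an XComArg rather than a literal string: it resolves at run time
+    # to the value the create_model task pushes to XCom under the "model_id" key, so the
+    # tasks below can be declared before any model actually exists.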
+
+    # [START howto_operator_deploy_model]
+    deploy_model = AutoMLDeployModelOperator(
+        task_id="deploy_model",
+        model_id=model_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_deploy_model]
+
+    # [START howto_operator_prediction]
+    predict_task = AutoMLPredictOperator(
+        task_id="predict_task",
+        model_id=model_id,
+        payload={
+            "row": {
+                "values": PREDICT_VALUES,
+            }
+        },
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_prediction]
+
+    # [START howto_operator_batch_prediction]
+    batch_predict_task = AutoMLBatchPredictOperator(
+        task_id="batch_predict_task",
+        model_id=model_id,
+        input_config=IMPORT_INPUT_CONFIG,
+        output_config=IMPORT_OUTPUT_CONFIG,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_batch_prediction]
+
+    # [START howto_operator_automl_delete_model]
+    delete_model = AutoMLDeleteModelOperator(
+        task_id="delete_model",
+        model_id=model_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_automl_delete_model]
+
+    delete_dataset = AutoMLDeleteDatasetOperator(
+        task_id="delete_dataset",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
+    )
+
+    (
+        # TEST SETUP
+        [create_bucket >> move_dataset_file, create_dataset]
+        >> import_dataset
+        >> list_tables_spec
+        >> list_columns_spec
+        >> update_dataset
+        # TEST BODY
+        >> create_model
+        >> get_model
+        >> deploy_model
+        >> predict_task
+        >> batch_predict_task
+        # TEST TEARDOWN
+        >> delete_model
+        >> delete_dataset
+        >> delete_bucket
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
index 7305123cb0164..9ef04db81882b 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
@@ -28,6 +28,7 @@
 from google.protobuf.struct_pb2 import Value
 
 from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.gcs import (
     GCSCreateBucketOperator,
     GCSDeleteBucketOperator,
@@ -70,6 +71,7 @@
         "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
     },
 ]
+extract_object_id = CloudAutoMLHook.extract_object_id
 
 # Example DAG for AutoML Natural Language Text Classification
 with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
index 916fc25877c9c..8f8564f62c209 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
@@ -28,6 +28,7 @@
 from google.protobuf.struct_pb2 import Value
 
 from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.gcs import (
     GCSCreateBucketOperator,
     GCSDeleteBucketOperator,
@@ -69,11 +70,7 @@
     },
 ]
-
-def extract_object_id(obj: dict) -> str:
-    """Returns unique id of the object."""
-    return obj["name"].rpartition("/")[-1]
-
+extract_object_id = CloudAutoMLHook.extract_object_id
 
 # Example DAG for AutoML Natural Language Entities Extraction
 with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
index 0e641e1b05feb..94f349c6c3702 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
@@ -28,6 +28,7 @@
 from google.protobuf.struct_pb2 import Value
 
 from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.gcs import (
     GCSCreateBucketOperator,
     GCSDeleteBucketOperator,
@@ -70,11 +71,7 @@
     },
 ]
-
-def extract_object_id(obj: dict) -> str:
-    """Returns unique id of the object."""
-    return obj["name"].rpartition("/")[-1]
-
+extract_object_id = CloudAutoMLHook.extract_object_id
 
 # Example DAG for AutoML Natural Language Text Sentiment
 with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_translation.py b/tests/system/providers/google/cloud/automl/example_automl_translation.py
new file mode 100644
index 0000000000000..ba36f556c427d
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_translation.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google AutoML translation services.
+"""
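+# Background (hedged): an AutoML Translation dataset is imported from a CSV that lists
+# TSV files of tab-separated sentence pairs, which is why paired en-es.csv/en-es.tsv
+# objects are staged in the sample bucket below.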
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from typing import cast
+
+# mypy cannot yet resolve the storage module: https://github.com/googleapis/python-storage/issues/393
+from google.cloud import storage  # type: ignore[attr-defined]
+
+from airflow.decorators import task
+from airflow.models.dag import DAG
+from airflow.models.xcom_arg import XComArg
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+    AutoMLCreateDatasetOperator,
+    AutoMLDeleteDatasetOperator,
+    AutoMLDeleteModelOperator,
+    AutoMLImportDataOperator,
+    AutoMLTrainModelOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "example_automl_translate"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+GCP_AUTOML_LOCATION = "us-central1"
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+
+
+MODEL_NAME = "translate_test_model"
+MODEL = {
+    "display_name": MODEL_NAME,
+    "translation_model_metadata": {},
+}
+
+DATASET_NAME = f"ds_translate_{ENV_ID}".replace("-", "_")
+DATASET = {
+    "display_name": DATASET_NAME,
+    "translation_dataset_metadata": {
+        "source_language_code": "en",
+        "target_language_code": "es",
+    },
+}
+
+CSV_FILE_NAME = "en-es.csv"
+TSV_FILE_NAME = "en-es.tsv"
+GCS_FILE_PATH = f"automl/datasets/translate/{CSV_FILE_NAME}"
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/{CSV_FILE_NAME}"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+# Example DAG for AutoML Translation
+with DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    user_defined_macros={"extract_object_id": extract_object_id},
+    tags=["example", "automl", "translate"],
+) as dag:
+    create_bucket = GCSCreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        storage_class="REGIONAL",
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    @task
+    def upload_csv_file_to_gcs():
+        # download the source CSV into memory
+        storage_client = storage.Client()
+        bucket = storage_client.bucket(RESOURCE_DATA_BUCKET)
+        blob = bucket.blob(GCS_FILE_PATH)
+        contents = blob.download_as_bytes().decode()
+
+        # point the CSV at this test's bucket instead of the placeholder "template-bucket"
+        updated_contents = contents.replace("template-bucket", DATA_SAMPLE_GCS_BUCKET_NAME)
+
+        # upload the updated content to the sample bucket
+        destination_bucket = storage_client.bucket(DATA_SAMPLE_GCS_BUCKET_NAME)
+        destination_blob = destination_bucket.blob(f"automl/{CSV_FILE_NAME}")
+        destination_blob.upload_from_string(updated_contents)
+
+    upload_csv_file_to_gcs_task = upload_csv_file_to_gcs()
+
+    copy_dataset_file = GCSToGCSOperator(
+        task_id="copy_dataset_file",
+        source_bucket=RESOURCE_DATA_BUCKET,
+        source_object=f"automl/datasets/translate/{TSV_FILE_NAME}",
+        destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+        destination_object=f"automl/{TSV_FILE_NAME}",
+    )
+
+    create_dataset = AutoMLCreateDatasetOperator(
+        task_id="create_dataset", dataset=DATASET, location=GCP_AUTOML_LOCATION
+    )
+
+    dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
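+    # NOTE: the cast() above exists only to satisfy static type checkers; at run time
+    # the XComArg renders to the actual dataset ID pushed by the create_dataset task.
+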
+    import_dataset = AutoMLImportDataOperator(
+        task_id="import_dataset",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        input_config=IMPORT_INPUT_CONFIG,
+    )
+
+    MODEL["dataset_id"] = dataset_id
+
+    create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
+    model_id = cast(str, XComArg(create_model, key="model_id"))
+
+    delete_model = AutoMLDeleteModelOperator(
+        task_id="delete_model",
+        model_id=model_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+
+    delete_dataset = AutoMLDeleteDatasetOperator(
+        task_id="delete_dataset",
+        dataset_id=dataset_id,
+        location=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        # TEST SETUP
+        [create_bucket >> upload_csv_file_to_gcs_task >> copy_dataset_file]
+        # TEST BODY
+        >> create_dataset
+        >> import_dataset
+        >> create_model
+        # TEST TEARDOWN
+        >> delete_dataset
+        >> delete_model
+        >> delete_bucket
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)