diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py
index 1dd7cb03bafbf..6dae1e36d6fab 100644
--- a/airflow/providers/google/cloud/hooks/automl.py
+++ b/airflow/providers/google/cloud/hooks/automl.py
@@ -15,166 +15,628 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
+"""
+This module contains a Google AutoML hook.
-import warnings
+.. spelling:word-list::
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+ PredictResponse
+"""
+from __future__ import annotations
-class CloudAutoMLHook:
- """
- Former Google Cloud AutoML hook.
+from functools import cached_property
+from typing import TYPE_CHECKING, Sequence
+
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.automl_v1beta1 import (
+ AutoMlClient,
+ BatchPredictInputConfig,
+ BatchPredictOutputConfig,
+ Dataset,
+ ExamplePayload,
+ ImageObjectDetectionModelDeploymentMetadata,
+ InputConfig,
+ Model,
+ PredictionServiceClient,
+ PredictResponse,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+if TYPE_CHECKING:
+ from google.api_core.operation import Operation
+ from google.api_core.retry import Retry
+ from google.cloud.automl_v1beta1.services.auto_ml.pagers import (
+ ListColumnSpecsPager,
+ ListDatasetsPager,
+ ListTableSpecsPager,
+ )
+ from google.protobuf.field_mask_pb2 import FieldMask
- Deprecated as AutoML API becomes unusable starting March 31, 2024:
- https://cloud.google.com/automl/docs
- """
- deprecation_warning = (
- "CloudAutoMLHook has been deprecated, as AutoML API becomes unusable starting "
- "March 31, 2024, and will be removed in future release. Please use an equivalent "
- " Vertex AI hook available in"
- "airflow.providers.google.cloud.hooks.vertex_ai instead."
- )
+class CloudAutoMLHook(GoogleBaseHook):
+ """
+ Google Cloud AutoML hook.
- method_exception = "This method cannot be used as AutoML API becomes unusable."
+    All the methods in the hook that use project_id must be called with
+    keyword arguments rather than positional.
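+
+    A minimal usage sketch (the location and project values below are illustrative):
+
+    .. code-block:: python
+
+        hook = CloudAutoMLHook()
+        # "my-project" is an illustrative project ID, not a default of this hook
+        datasets = hook.list_datasets(location="us-central1", project_id="my-project")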
+ """
- def __init__(self, **_) -> None:
- warnings.warn(self.deprecation_warning, AirflowProviderDeprecationWarning)
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ if kwargs.get("delegate_to") is not None:
+ raise RuntimeError(
+ "The `delegate_to` parameter has been deprecated before and finally removed in this version"
+ " of Google Provider. You MUST convert it to `impersonate_chain`"
+ )
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ )
+ self._client: AutoMlClient | None = None
@staticmethod
def extract_object_id(obj: dict) -> str:
"""Return unique id of the object."""
- warnings.warn(
- "'extract_object_id' method is deprecated and will be removed in future release.",
- AirflowProviderDeprecationWarning,
- )
return obj["name"].rpartition("/")[-1]
- def get_conn(self):
- """
- Retrieve connection to AutoML (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def wait_for_operation(self, **_):
- """
- Wait for long-lasting operation to complete (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def prediction_client(self, **_):
- """
- Create a PredictionServiceClient (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def create_model(self, **_):
- """
- Create a model_id and returns a Model in the `response` field when it completes (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def batch_predict(self, **_):
- """
- Perform a batch prediction (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def predict(self, **_):
- """
- Perform an online prediction (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def create_dataset(self, **_):
- """
- Create a dataset (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def import_data(self, **_):
- """
- Import data (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def list_column_specs(self, **_):
- """
- List column specs (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def get_model(self, **_):
- """
- Get a model (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def delete_model(self, **_):
- """
- Delete a model (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def update_dataset(self, **_):
- """
- Update a model (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def deploy_model(self, **_):
- """
- Deploy a model (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def list_table_specs(self, **_):
- """
- List table specs (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def list_datasets(self, **_):
- """
- List datasets (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def delete_dataset(self, **_):
- """
- Delete a dataset (deprecated).
+ def get_conn(self) -> AutoMlClient:
+ """
+ Retrieve connection to AutoML.
+
+ :return: Google Cloud AutoML client object.
+ """
+ if self._client is None:
+ self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+ return self._client
+
+ def wait_for_operation(self, operation: Operation, timeout: float | None = None):
+ """Wait for long-lasting operation to complete."""
+ try:
+ return operation.result(timeout=timeout)
+ except Exception:
+ error = operation.exception(timeout=timeout)
+ raise AirflowException(error)
+
+ @cached_property
+ def prediction_client(self) -> PredictionServiceClient:
+ """
+        Create a PredictionServiceClient.
+
+ :return: Google Cloud AutoML PredictionServiceClient client object.
+ """
+ return PredictionServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_model(
+ self,
+ model: dict | Model,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ retry: Retry | _MethodDefault = DEFAULT,
+ ) -> Operation:
+ """
+        Create a model and return a Model in the `response` field when it completes.
+
+ When you create a model, several model evaluations are created for it:
+ a global evaluation, and one evaluation for each annotation spec.
+
+        :param model: The model to create. If a dict is provided, it must be of the same form
+            as the protobuf message `google.cloud.automl_v1beta1.types.Model`
+        :param project_id: ID of the Google Cloud project where the model will be created.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests
+ will not be retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete.
+ Note that if `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.api_core.operation.Operation` instance
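+
+        A sketch of a dict-form model (a Translation model is assumed; IDs are illustrative):
+
+        .. code-block:: python
+
+            operation = hook.create_model(
+                model={
+                    "display_name": "my_model",
+                    "dataset_id": "TRL1234567890",  # illustrative dataset ID
+                    "translation_model_metadata": {},
+                },
+                location="us-central1",
+            )
+            model = hook.wait_for_operation(operation, timeout=3600)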
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}"
+ return client.create_model(
+ request={"parent": parent, "model": model},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
+ @GoogleBaseHook.fallback_to_default_project_id
+ def batch_predict(
+ self,
+ model_id: str,
+ input_config: dict | BatchPredictInputConfig,
+ output_config: dict | BatchPredictOutputConfig,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ params: dict[str, str] | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Perform a batch prediction and return a long-running operation object.
+
+        Unlike the online `Predict`, the batch prediction result won't be immediately
+        available in the response. Instead, a long-running operation object is returned.
+
+        :param model_id: Name of the model requested to serve the batch prediction.
+ :param input_config: Required. The input configuration for batch prediction.
+ If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictInputConfig`
+ :param output_config: Required. The Configuration specifying where output predictions should be
+ written. If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig`
+ :param params: Additional domain-specific parameters for the predictions, any string must be up to
+ 25000 characters long.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.api_core.operation.Operation` instance
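+
+        A sketch of a dict-form request (the model ID and GCS URIs are illustrative):
+
+        .. code-block:: python
+
+            operation = hook.batch_predict(
+                model_id="TBL1234567890",  # illustrative model ID
+                location="us-central1",
+                input_config={"gcs_source": {"input_uris": ["gs://my-bucket/input.csv"]}},
+                output_config={"gcs_destination": {"output_uri_prefix": "gs://my-bucket/output"}},
+            )
+            result = hook.wait_for_operation(operation, timeout=3600)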
+ """
+ client = self.prediction_client
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.batch_predict(
+ request={
+ "name": name,
+ "input_config": input_config,
+ "output_config": output_config,
+ "params": params,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def predict(
+ self,
+ model_id: str,
+ payload: dict | ExamplePayload,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ params: dict[str, str] | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> PredictResponse:
+ """
+        Perform an online prediction and return the prediction result in the response.
+
+        :param model_id: Name of the model requested to serve the prediction.
+        :param payload: Required. Payload to perform a prediction on. The payload must match the
+            problem type that the model was trained to solve. If a dict is provided, it must be of
+            the same form as the protobuf message `google.cloud.automl_v1beta1.types.ExamplePayload`
+ :param params: Additional domain-specific parameters, any string must be up to 25000 characters long.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
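+
+        A sketch for a text payload (the model ID and snippet content are illustrative):
+
+        .. code-block:: python
+
+            response = hook.predict(
+                model_id="TCN1234567890",  # illustrative model ID
+                location="us-central1",
+                payload={"text_snippet": {"content": "Some text to classify.", "mime_type": "text/plain"}},
+            )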
+ """
+ client = self.prediction_client
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.predict(
+ request={"name": name, "payload": payload, "params": params},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_dataset(
+ self,
+ dataset: dict | Dataset,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Dataset:
+ """
+ Create a dataset.
+
+ :param dataset: The dataset to create. If a dict is provided, it must be of the
+ same form as the protobuf message Dataset.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
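+
+        A sketch of a dict-form dataset (a Translation dataset is assumed; values are illustrative):
+
+        .. code-block:: python
+
+            dataset = hook.create_dataset(
+                dataset={
+                    "display_name": "my_dataset",  # illustrative name
+                    "translation_dataset_metadata": {
+                        "source_language_code": "en",
+                        "target_language_code": "es",
+                    },
+                },
+                location="us-central1",
+            )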
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}"
+ result = client.create_dataset(
+ request={"parent": parent, "dataset": dataset},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def import_data(
+ self,
+ dataset_id: str,
+ location: str,
+ input_config: dict | InputConfig,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Import data into a dataset. For Tables, this method can only be called on an empty Dataset.
+
+ :param dataset_id: Name of the AutoML dataset.
+ :param input_config: The desired input location and its domain specific semantics, if any.
+ If a dict is provided, it must be of the same form as the protobuf message InputConfig.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.api_core.operation.Operation` instance
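+
+        A sketch of a dict-form input config (the dataset ID and URI are illustrative):
+
+        .. code-block:: python
+
+            operation = hook.import_data(
+                dataset_id="TRL1234567890",  # illustrative dataset ID
+                location="us-central1",
+                input_config={"gcs_source": {"input_uris": ["gs://my-bucket/import.csv"]}},
+            )
+            hook.wait_for_operation(operation, timeout=3600)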
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+ result = client.import_data(
+ request={"name": name, "input_config": input_config},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_column_specs(
+ self,
+ dataset_id: str,
+ table_spec_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ field_mask: dict | FieldMask | None = None,
+ filter_: str | None = None,
+ page_size: int | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListColumnSpecsPager:
+ """
+ List column specs in a table spec.
+
+ :param dataset_id: Name of the AutoML dataset.
+ :param table_spec_id: table_spec_id for path builder.
+ :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same
+ form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask`
+ :param filter_: Filter expression, see go/filtering.
+ :param page_size: The maximum number of resources contained in the
+ underlying API response. If page streaming is performed per
+ resource, this parameter does not affect the return value. If page
+ streaming is performed per-page, this determines the maximum number
+ of resources in a page.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: A pager over `google.cloud.automl_v1beta1.types.ColumnSpec` instances.
+ """
+ client = self.get_conn()
+ parent = client.table_spec_path(
+ project=project_id,
+ location=location,
+ dataset=dataset_id,
+ table_spec=table_spec_id,
+ )
+ result = client.list_column_specs(
+ request={"parent": parent, "field_mask": field_mask, "filter": filter_, "page_size": page_size},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def get_model(
+ self,
+ model_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Model:
+ """
+        Get an AutoML model.
+
+ :param model_id: Name of the model.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types.Model` instance.
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.get_model(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_model(
+ self,
+ model_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Delete an AutoML model.
+
+ :param model_id: Name of the model.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.api_core.operation.Operation` instance.
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.delete_model(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ def update_dataset(
+ self,
+ dataset: dict | Dataset,
+ update_mask: dict | FieldMask | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Dataset:
+ """
+ Update a dataset.
+
+ :param dataset: The dataset which replaces the resource on the server.
+ If a dict is provided, it must be of the same form as the protobuf message Dataset.
+ :param update_mask: The update mask applies to the resource. If a dict is provided, it must
+ be of the same form as the protobuf message FieldMask.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
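+
+        A sketch of renaming a dataset (the resource name is illustrative):
+
+        .. code-block:: python
+
+            updated = hook.update_dataset(
+                dataset={
+                    # illustrative full resource name of an existing dataset
+                    "name": "projects/my-project/locations/us-central1/datasets/TRL1234567890",
+                    "display_name": "my_dataset_renamed",
+                },
+                update_mask={"paths": ["display_name"]},
+            )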
+ """
+ client = self.get_conn()
+ result = client.update_dataset(
+ request={"dataset": dataset, "update_mask": update_mask},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def deploy_model(
+ self,
+ model_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ image_detection_metadata: ImageObjectDetectionModelDeploymentMetadata | dict | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Deploy a model.
+
+        If a model is already deployed, deploying it with the same parameters
+        has no effect. Deploying with different parameters (e.g. changing node_number)
+        will reset the deployment state without pausing the model's availability.
+
+ Only applicable for Text Classification, Image Object Detection and Tables; all other
+ domains manage deployment automatically.
+
+ :param model_id: Name of the model requested to serve the prediction.
+ :param image_detection_metadata: Model deployment metadata specific to Image Object Detection.
+ If a dict is provided, it must be of the same form as the protobuf message
+ ImageObjectDetectionModelDeploymentMetadata
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.api_core.operation.Operation` instance.
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.deploy_model(
+ request={
+ "name": name,
+ "image_object_detection_model_deployment_metadata": image_detection_metadata,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_table_specs(
+        self,
+        dataset_id: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+ filter_: str | None = None,
+ page_size: int | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListTableSpecsPager:
+ """
+        List table specs in a dataset.
+
+ :param dataset_id: Name of the dataset.
+ :param filter_: Filter expression, see go/filtering.
+ :param page_size: The maximum number of resources contained in the
+ underlying API response. If page streaming is performed per
+ resource, this parameter does not affect the return value. If page
+ streaming is performed per-page, this determines the maximum number
+ of resources in a page.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: A pager over `google.cloud.automl_v1beta1.types.TableSpec`
+            instances; iterating it yields all results, fetching further pages
+            from the API as needed.
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+ result = client.list_table_specs(
+ request={"parent": parent, "filter": filter_, "page_size": page_size},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_datasets(
+ self,
+ location: str,
+ project_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListDatasetsPager:
+ """
+ List datasets in a project.
+
+        :param project_id: ID of the Google Cloud project where the datasets are located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: A pager over `google.cloud.automl_v1beta1.types.Dataset`
+            instances; iterating it yields all results, fetching further pages
+            from the API as needed.
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}"
+ result = client.list_datasets(
+ request={"parent": parent},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_dataset(
+ self,
+ dataset_id: str,
+ location: str,
+ project_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+ Delete a dataset and all of its contents.
+
+ :param dataset_id: ID of dataset to be deleted.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.api_core.operation.Operation` instance
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+ result = client.delete_dataset(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py
index b57601c64906c..79561d5b48132 100644
--- a/airflow/providers/google/cloud/links/automl.py
+++ b/airflow/providers/google/cloud/links/automl.py
@@ -19,28 +19,13 @@
from __future__ import annotations
-import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
-from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.links.base import BaseGoogleLink
if TYPE_CHECKING:
from airflow.utils.context import Context
-
-def __getattr__(name: str) -> Any:
- warnings.warn(
- (
- "AutoML links module have been deprecated and will be removed in the next MAJOR release."
- " Please use equivalent Vertex AI links instead"
- ),
- AirflowProviderDeprecationWarning,
- stacklevel=2,
- )
- return getattr(__name__, name)
-
-
AUTOML_BASE_LINK = "https://console.cloud.google.com/automl-tables"
AUTOML_DATASET_LINK = (
AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/schemav2?project={project_id}"
diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py
index ca32994193fbb..54d31025c0ca2 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -15,16 +15,1259 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use `airflow.providers.google.cloud.vertex_ai.auto_ml` instead."""
+"""This module contains Google AutoML operators."""
from __future__ import annotations
+import ast
import warnings
+from typing import TYPE_CHECKING, Sequence, Tuple
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.automl_v1beta1 import (
+ BatchPredictResult,
+ ColumnSpec,
+ Dataset,
+ Model,
+ PredictResponse,
+ TableSpec,
+)
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.vertex_ai.auto_ml` instead.",
- AirflowProviderDeprecationWarning,
- stacklevel=2,
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.links.automl import (
+ AutoMLDatasetLink,
+ AutoMLDatasetListLink,
+ AutoMLModelLink,
+ AutoMLModelPredictLink,
+ AutoMLModelTrainLink,
)
+from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+
+if TYPE_CHECKING:
+ from google.api_core.retry import Retry
+
+ from airflow.utils.context import Context
+
+MetaData = Sequence[Tuple[str, str]]
+
+
+class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
+ """
+    Creates a Google Cloud AutoML model.
+
+    AutoMLTrainModelOperator for text training is deprecated. Please use
+ :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
+ instead.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLTrainModelOperator`
+
+ :param model: Model definition.
+    :param project_id: ID of the Google Cloud project where the model will be created.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
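+
+    A minimal task sketch (a Translation model is assumed; IDs are illustrative):
+
+    .. code-block:: python
+
+        create_model = AutoMLTrainModelOperator(
+            task_id="create_model",
+            model={
+                "display_name": "my_model",
+                "dataset_id": "TRL1234567890",  # illustrative dataset ID
+                "translation_model_metadata": {},
+            },
+            location="us-central1",
+        )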
+ """
+
+ template_fields: Sequence[str] = (
+ "model",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (
+ AutoMLModelTrainLink(),
+ AutoMLModelLink(),
+ )
+
+ def __init__(
+ self,
+ *,
+ model: dict,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.model = model
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+        # Warn if this is not an AutoML Translation training job
+ if "translation_model_metadata" not in self.model:
+ warnings.warn(
+ "AutoMLTrainModelOperator for text, image and video prediction is deprecated. "
+ "All the functionality of legacy "
+ "AutoML Natural Language, Vision and Video Intelligence and new features are available "
+ "on the Vertex AI platform. "
+ "Please use `CreateAutoMLTextTrainingJobOperator`, `CreateAutoMLImageTrainingJobOperator` or"
+ " `CreateAutoMLVideoTrainingJobOperator` from VertexAI.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=3,
+ )
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Creating model %s...", self.model["display_name"])
+ operation = hook.create_model(
+ model=self.model,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id)
+ operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ result = Model.to_dict(operation_result)
+ model_id = hook.extract_object_id(result)
+ self.log.info("Model is created, model_id: %s", model_id)
+
+ self.xcom_push(context, key="model_id", value=model_id)
+ if project_id:
+ AutoMLModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.model["dataset_id"] or "-",
+ model_id=model_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLPredictOperator(GoogleCloudBaseOperator):
+ """
+ Runs prediction operation on Google Cloud AutoML.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLPredictOperator`
+
+    :param model_id: Name of the model requested to serve the prediction.
+    :param payload: Required. Payload to perform a prediction on. The payload must match the
+        problem type that the model was trained to solve.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param operation_params: Additional domain-specific parameters for the predictions.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
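+
+    A minimal task sketch (the model ID and payload content are illustrative):
+
+    .. code-block:: python
+
+        predict = AutoMLPredictOperator(
+            task_id="predict",
+            model_id="TCN1234567890",  # illustrative model ID
+            location="us-central1",
+            payload={"text_snippet": {"content": "Some text to classify.", "mime_type": "text/plain"}},
+        )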
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLModelPredictLink(),)
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ payload: dict,
+ operation_params: dict[str, str] | None = None,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.operation_params = operation_params # type: ignore
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.payload = payload
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ result = hook.predict(
+ model_id=self.model_id,
+ payload=self.payload,
+ location=self.location,
+ project_id=self.project_id,
+ params=self.operation_params,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelPredictLink.persist(
+ context=context,
+ task_instance=self,
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return PredictResponse.to_dict(result)
+
+
+class AutoMLBatchPredictOperator(GoogleCloudBaseOperator):
+ """
+ Perform a batch prediction on Google Cloud AutoML.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLBatchPredictOperator`
+
+    :param model_id: Name of the model requested to serve the batch prediction.
+ :param input_config: Required. The input configuration for batch prediction.
+ If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictInputConfig`
+ :param output_config: Required. The Configuration specifying where output predictions should be
+ written. If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig`
+ :param prediction_params: Additional domain-specific parameters for the predictions,
+ any string must be up to 25000 characters long.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
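+
+    A minimal task sketch (the model ID and GCS URIs are illustrative):
+
+    .. code-block:: python
+
+        batch_predict = AutoMLBatchPredictOperator(
+            task_id="batch_predict",
+            model_id="TBL1234567890",  # illustrative model ID
+            location="us-central1",
+            input_config={"gcs_source": {"input_uris": ["gs://my-bucket/input.csv"]}},
+            output_config={"gcs_destination": {"output_uri_prefix": "gs://my-bucket/output"}},
+        )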
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "input_config",
+ "output_config",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLModelPredictLink(),)
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ input_config: dict,
+ output_config: dict,
+ location: str,
+ project_id: str | None = None,
+ prediction_params: dict[str, str] | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.location = location
+ self.project_id = project_id
+ self.prediction_params = prediction_params
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.input_config = input_config
+ self.output_config = output_config
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Fetch batch prediction.")
+ operation = hook.batch_predict(
+ model_id=self.model_id,
+ input_config=self.input_config,
+ output_config=self.output_config,
+ project_id=self.project_id,
+ location=self.location,
+ params=self.prediction_params,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ result = BatchPredictResult.to_dict(operation_result)
+ self.log.info("Batch prediction is ready.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelPredictLink.persist(
+ context=context,
+ task_instance=self,
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator):
+ """
+ Creates a Google Cloud AutoML dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLCreateDatasetOperator`
+
+ :param dataset: The dataset to create. If a dict is provided, it must be of the
+ same form as the protobuf message Dataset.
+    :param project_id: ID of the Google Cloud project where the dataset is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
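+
+    A minimal task sketch (a Translation dataset is assumed; values are illustrative):
+
+    .. code-block:: python
+
+        create_dataset = AutoMLCreateDatasetOperator(
+            task_id="create_dataset",
+            dataset={
+                "display_name": "my_dataset",  # illustrative name
+                "translation_dataset_metadata": {
+                    "source_language_code": "en",
+                    "target_language_code": "es",
+                },
+            },
+            location="us-central1",
+        )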
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset: dict,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset = dataset
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Creating dataset %s...", self.dataset)
+ result = hook.create_dataset(
+ dataset=self.dataset,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result = Dataset.to_dict(result)
+ dataset_id = hook.extract_object_id(result)
+ self.log.info("Creating completed. Dataset id: %s", dataset_id)
+
+ self.xcom_push(context, key="dataset_id", value=dataset_id)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=dataset_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLImportDataOperator(GoogleCloudBaseOperator):
+ """
+ Imports data to a Google Cloud AutoML dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLImportDataOperator`
+
+    :param dataset_id: ID of the dataset into which data will be imported.
+ :param input_config: The desired input location and its domain specific semantics, if any.
+ If a dict is provided, it must be of the same form as the protobuf message InputConfig.
+    :param project_id: ID of the Google Cloud project where the dataset is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
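+
+    A minimal task sketch (the dataset ID and URI are illustrative):
+
+    .. code-block:: python
+
+        import_data = AutoMLImportDataOperator(
+            task_id="import_data",
+            dataset_id="TRL1234567890",  # illustrative dataset ID
+            location="us-central1",
+            input_config={"gcs_source": {"input_uris": ["gs://my-bucket/import.csv"]}},
+        )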
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "input_config",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ location: str,
+ input_config: dict,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset_id = dataset_id
+ self.input_config = input_config
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Importing data to dataset...")
+ operation = hook.import_data(
+ dataset_id=self.dataset_id,
+ input_config=self.input_config,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ self.log.info("Import is completed")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
+
+
+class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator):
+ """
+ Lists column specs in a table.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLTablesListColumnSpecsOperator`
+
+ :param dataset_id: Name of the dataset.
+ :param table_spec_id: table_spec_id for path builder.
+ :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same
+ form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask`
+ :param filter_: Filter expression, see go/filtering.
+ :param page_size: The maximum number of resources contained in the
+ underlying API response. If page streaming is performed per
+ resource, this parameter does not affect the return value. If page
+ streaming is performed per page, this determines the maximum number
+ of resources in a page.
+    :param project_id: ID of the Google Cloud project where the dataset is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
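+
+    A minimal task sketch (the dataset and table spec IDs are illustrative):
+
+    .. code-block:: python
+
+        list_column_specs = AutoMLTablesListColumnSpecsOperator(
+            task_id="list_column_specs",
+            dataset_id="TBL1234567890",  # illustrative dataset ID
+            table_spec_id="1234567890",  # illustrative table spec ID
+            location="us-central1",
+        )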
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "table_spec_id",
+ "field_mask",
+ "filter_",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ table_spec_id: str,
+ location: str,
+ field_mask: dict | None = None,
+ filter_: str | None = None,
+ page_size: int | None = None,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.table_spec_id = table_spec_id
+ self.field_mask = field_mask
+ self.filter_ = filter_
+ self.page_size = page_size
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Requesting column specs.")
+ page_iterator = hook.list_column_specs(
+ dataset_id=self.dataset_id,
+ table_spec_id=self.table_spec_id,
+ field_mask=self.field_mask,
+ filter_=self.filter_,
+ page_size=self.page_size,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
+ self.log.info("Columns specs obtained.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLTablesUpdateDatasetOperator(GoogleCloudBaseOperator):
+ """
+ Updates a dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLTablesUpdateDatasetOperator`
+
+ :param dataset: The dataset which replaces the resource on the server.
+ If a dict is provided, it must be of the same form as the protobuf message Dataset.
+ :param update_mask: The update mask applies to the resource. If a dict is provided, it must
+ be of the same form as the protobuf message FieldMask.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
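+
+    A minimal task sketch (the resource name is illustrative):
+
+    .. code-block:: python
+
+        update_dataset = AutoMLTablesUpdateDatasetOperator(
+            task_id="update_dataset",
+            dataset={
+                # illustrative full resource name of an existing dataset
+                "name": "projects/my-project/locations/us-central1/datasets/TBL1234567890",
+                "display_name": "my_dataset_renamed",
+            },
+            update_mask={"paths": ["display_name"]},
+            location="us-central1",
+        )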
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset",
+ "update_mask",
+ "location",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset: dict,
+ location: str,
+ update_mask: dict | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset = dataset
+ self.update_mask = update_mask
+ self.location = location
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Updating AutoML dataset %s.", self.dataset["name"])
+ result = hook.update_dataset(
+ dataset=self.dataset,
+ update_mask=self.update_mask,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Dataset updated.")
+ project_id = hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=hook.extract_object_id(self.dataset),
+ project_id=project_id,
+ )
+ return Dataset.to_dict(result)
+
+
+class AutoMLGetModelOperator(GoogleCloudBaseOperator):
+ """
+    Get a Google Cloud AutoML model.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLGetModelOperator`
+
+    :param model_id: Name of the model to be retrieved.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
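+
+    A minimal task sketch (the model ID is illustrative):
+
+    .. code-block:: python
+
+        get_model = AutoMLGetModelOperator(
+            task_id="get_model",
+            model_id="TRL1234567890",  # illustrative model ID
+            location="us-central1",
+        )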
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ result = hook.get_model(
+ model_id=self.model_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ model = Model.to_dict(result)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=model["dataset_id"],
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return model
+
+
+class AutoMLDeleteModelOperator(GoogleCloudBaseOperator):
+ """
+ Delete Google Cloud AutoML model.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLDeleteModelOperator`
+
+    :param model_id: Name of the model to delete.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+    :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ operation = hook.delete_model(
+ model_id=self.model_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ self.log.info("Deletion is completed")
+
+
+class AutoMLDeployModelOperator(GoogleCloudBaseOperator):
+ """
+    Deploys a model; if a model is already deployed, deploying it with the same parameters has no effect.
+
+    Deploying a model with different parameters (e.g. changing ``node_number``) resets
+    the deployment state without pausing the model's availability.
+
+ Only applicable for Text Classification, Image Object Detection and Tables; all other
+ domains manage deployment automatically.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLDeployModelOperator`
+
+ :param model_id: Name of the model to be deployed.
+ :param image_detection_metadata: Model deployment metadata specific to Image Object Detection.
+ If a dict is provided, it must be of the same form as the protobuf message
+ ImageObjectDetectionModelDeploymentMetadata
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+    :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ project_id: str | None = None,
+ image_detection_metadata: dict | None = None,
+        metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.image_detection_metadata = image_detection_metadata
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Deploying model_id %s", self.model_id)
+
+ operation = hook.deploy_model(
+ model_id=self.model_id,
+ location=self.location,
+ project_id=self.project_id,
+ image_detection_metadata=self.image_detection_metadata,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ self.log.info("Model was deployed successfully.")
+
+
+class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator):
+ """
+ Lists table specs in a dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLTablesListTableSpecsOperator`
+
+ :param dataset_id: Name of the dataset.
+    :param filter_: An expression for filtering the results of the request.
+ :param page_size: The maximum number of resources contained in the
+ underlying API response. If page streaming is performed per
+ resource, this parameter does not affect the return value. If page
+ streaming is performed per-page, this determines the maximum number
+ of resources in a page.
+    :param project_id: ID of the Google Cloud project where the dataset is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "filter_",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ location: str,
+ page_size: int | None = None,
+ filter_: str | None = None,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.filter_ = filter_
+ self.page_size = page_size
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Requesting table specs for %s.", self.dataset_id)
+ page_iterator = hook.list_table_specs(
+ dataset_id=self.dataset_id,
+ filter_=self.filter_,
+ page_size=self.page_size,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result = [TableSpec.to_dict(spec) for spec in page_iterator]
+ self.log.info(result)
+ self.log.info("Table specs obtained.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLListDatasetOperator(GoogleCloudBaseOperator):
+ """
+    Lists AutoML datasets in a project.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLListDatasetOperator`
+
+    :param project_id: ID of the Google Cloud project where the datasets are located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetListLink(),)
+
+ def __init__(
+ self,
+ *,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Requesting datasets")
+ page_iterator = hook.list_datasets(
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result = [Dataset.to_dict(dataset) for dataset in page_iterator]
+ self.log.info("Datasets obtained.")
+
+ self.xcom_push(
+ context,
+ key="dataset_id_list",
+ value=[hook.extract_object_id(d) for d in result],
+ )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id)
+ return result
+
+
+class AutoMLDeleteDatasetOperator(GoogleCloudBaseOperator):
+ """
+ Deletes a dataset and all of its contents.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLDeleteDatasetOperator`
+
+    :param dataset_id: ID of the dataset to be deleted. A list of dataset ids or a
+        comma-separated string of dataset ids can also be passed to delete multiple datasets.
+    :param project_id: ID of the Google Cloud project where the dataset is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str | list[str],
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset_id = dataset_id
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+    @staticmethod
+    def _parse_dataset_id(dataset_id: str | list[str]) -> list[str]:
+        # Already a list of dataset ids, nothing to parse.
+        if not isinstance(dataset_id, str):
+            return dataset_id
+        try:
+            # A string holding a Python list literal, e.g. "['id1', 'id2']".
+            return ast.literal_eval(dataset_id)
+        except (SyntaxError, ValueError):
+            # Otherwise treat the string as comma-separated ids, e.g. "id1,id2".
+            return dataset_id.split(",")
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ dataset_id_list = self._parse_dataset_id(self.dataset_id)
+ for dataset_id in dataset_id_list:
+ self.log.info("Deleting dataset %s", dataset_id)
+ hook.delete_dataset(
+ dataset_id=dataset_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Dataset deleted.")
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index b70ff2705a0d7..748cfbb39cc28 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -108,6 +108,7 @@ dependencies:
- google-auth>=1.0.0
- google-auth-httplib2>=0.0.1
- google-cloud-aiplatform>=1.42.1
+ - google-cloud-automl>=2.12.0
- google-cloud-bigquery-datatransfer>=3.13.0
- google-cloud-bigtable>=2.17.0
- google-cloud-build>=3.22.0
@@ -201,6 +202,8 @@ integrations:
tags: [gmp]
- integration-name: Google AutoML
external-doc-url: https://cloud.google.com/automl/
+ how-to-guide:
+ - /docs/apache-airflow-providers-google/operators/cloud/automl.rst
logo: /integration-logos/gcp/Cloud-AutoML.png
tags: [gcp]
- integration-name: Google BigQuery Data Transfer Service
@@ -529,6 +532,9 @@ operators:
- integration-name: Google Cloud Common
python-modules:
- airflow.providers.google.cloud.operators.cloud_base
+ - integration-name: Google AutoML
+ python-modules:
+ - airflow.providers.google.cloud.operators.automl
- integration-name: Google BigQuery
python-modules:
- airflow.providers.google.cloud.operators.bigquery
@@ -1229,6 +1235,11 @@ extra-links:
- airflow.providers.google.cloud.links.cloud_build.CloudBuildListLink
- airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggersListLink
- airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggerDetailsLink
+ - airflow.providers.google.cloud.links.automl.AutoMLDatasetLink
+ - airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelTrainLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink
- airflow.providers.google.cloud.links.life_sciences.LifeSciencesLink
- airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsDetailsLink
- airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsListLink
diff --git a/docker_tests/test_prod_image.py b/docker_tests/test_prod_image.py
index b4e59052ee309..ab35c63bffa53 100644
--- a/docker_tests/test_prod_image.py
+++ b/docker_tests/test_prod_image.py
@@ -131,6 +131,7 @@ def test_pip_dependencies_conflict(self, default_docker_image):
"googleapiclient",
"google.auth",
"google_auth_httplib2",
+ "google.cloud.automl",
"google.cloud.bigquery_datatransfer",
"google.cloud.bigtable",
"google.cloud.container",
diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
new file mode 100644
index 0000000000000..34324ea60d914
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
@@ -0,0 +1,229 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Google Cloud AutoML Operators
+=======================================
+
+The `Google Cloud AutoML <https://cloud.google.com/automl/>`__
+makes the power of machine learning available to you even if you have limited knowledge
+of machine learning. You can use AutoML to build on Google's machine learning capabilities
+to create your own custom machine learning models that are tailored to your business needs,
+and then integrate those models into your applications and web sites.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+.. _howto/operator:CloudAutoMLDocuments:
+.. _howto/operator:AutoMLCreateDatasetOperator:
+.. _howto/operator:AutoMLImportDataOperator:
+.. _howto/operator:AutoMLTablesUpdateDatasetOperator:
+
+Creating Datasets
+^^^^^^^^^^^^^^^^^
+
+To create a Google AutoML dataset you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`.
+The operator returns the dataset id in :ref:`XCom <concepts:xcom>` under the ``dataset_id`` key.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_create_dataset]
+ :end-before: [END howto_operator_automl_create_dataset]
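+
+Downstream tasks can read the created dataset id back from XCom. A minimal sketch,
+assuming a preceding ``AutoMLCreateDatasetOperator`` task with ``task_id="create_dataset"``
+(the task id is illustrative only):
+
+.. code-block:: python
+
+    from airflow.decorators import task
+
+
+    @task
+    def use_dataset_id(**context):
+        # The create operator pushes the extracted dataset id under the "dataset_id" key.
+        dataset_id = context["ti"].xcom_pull(task_ids="create_dataset", key="dataset_id")
+        print(f"Created dataset: {dataset_id}")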
+
+After creating a dataset you can use it to import some data using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_import_data]
+ :end-before: [END howto_operator_automl_import_data]
+
+To update a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_update_dataset]
+ :end-before: [END howto_operator_automl_update_dataset]
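+
+The ``dataset`` argument is a dict in ``Dataset`` form and must contain the resource
+``name`` of the dataset to update, together with the fields being changed; ``update_mask``
+selects which fields are applied. A minimal sketch (all ids and values are placeholders):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.automl import AutoMLTablesUpdateDatasetOperator
+
+    update_dataset = AutoMLTablesUpdateDatasetOperator(
+        task_id="update_dataset",
+        dataset={
+            # Full resource name of the dataset to update (placeholder ids).
+            "name": "projects/your-project/locations/us-central1/datasets/TBL123",
+            "tables_dataset_metadata": {"target_column_spec_id": "your-column-id"},
+        },
+        update_mask={"paths": ["tables_dataset_metadata.target_column_spec_id"]},
+        location="us-central1",
+    )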
+
+.. _howto/operator:AutoMLTablesListTableSpecsOperator:
+.. _howto/operator:AutoMLTablesListColumnSpecsOperator:
+
+Listing Table And Column Specs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To list table specs you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_specs]
+ :end-before: [END howto_operator_automl_specs]
+
+To list column specs you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_column_specs]
+ :end-before: [END howto_operator_automl_column_specs]
+
+.. _howto/operator:AutoMLTrainModelOperator:
+.. _howto/operator:AutoMLGetModelOperator:
+.. _howto/operator:AutoMLDeployModelOperator:
+.. _howto/operator:AutoMLDeleteModelOperator:
+
+Operations On Models
+^^^^^^^^^^^^^^^^^^^^
+
+To create a Google AutoML model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator`.
+The operator will wait for the operation to complete. Additionally, the operator
+returns the model id in :ref:`XCom <concepts:xcom>` under the ``model_id`` key.
+
+This operator is deprecated for text, video and vision model training and will be removed in a future release.
+All the functionality of legacy AutoML Natural Language, Vision and Video Intelligence, as well as
+new features, is available on the Vertex AI platform. Please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator` or
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`.
+
+You can find an example of how to use Vertex AI operators for AutoML Natural Language classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_cloud_create_text_classification_training_job_operator]
+ :end-before: [END howto_cloud_create_text_classification_training_job_operator]
+
+Additionally, you can find an example of how to use Vertex AI operators for AutoML Vision classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_vision_classification.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_cloud_create_image_classification_training_job_operator]
+ :end-before: [END howto_cloud_create_image_classification_training_job_operator]
+
+An example of how to use Vertex AI operators for AutoML Video Intelligence classification can be found here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_video_classification.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_cloud_create_video_classification_training_job_operator]
+ :end-before: [END howto_cloud_create_video_classification_training_job_operator]
+
+When using Vertex AI operators for training, please ensure that your data is correctly stored in Vertex AI
+datasets. To create a dataset and import data into it, please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
+and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
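+
+A minimal sketch of creating a Vertex AI dataset (the ids, display name and metadata
+schema below are placeholder assumptions; the schema URI depends on your data type):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.vertex_ai.dataset import CreateDatasetOperator
+
+    create_vertex_dataset = CreateDatasetOperator(
+        task_id="create_vertex_dataset",
+        project_id="your-project",
+        region="us-central1",
+        dataset={
+            "display_name": "your-dataset",
+            # Placeholder schema for text data; pick the schema matching your data type.
+            "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml",
+        },
+    )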
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_create_model]
+ :end-before: [END howto_operator_automl_create_model]
+
+To get an existing model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_get_model]
+ :end-before: [END howto_operator_get_model]
+
+Once a model is created, it can be deployed using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_deploy_model]
+ :end-before: [END howto_operator_deploy_model]
+
+If you wish to delete a model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_automl_delete_model]
+ :end-before: [END howto_operator_automl_delete_model]
+
+.. _howto/operator:AutoMLPredictOperator:
+.. _howto/operator:AutoMLBatchPredictOperator:
+
+Making Predictions
+^^^^^^^^^^^^^^^^^^
+
+To obtain predictions from a Google Cloud AutoML model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator` or
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator`. In the first case
+the model must be deployed.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_prediction]
+ :end-before: [END howto_operator_prediction]
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_batch_prediction]
+ :end-before: [END howto_operator_batch_prediction]
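+
+Since online predictions are only served by a deployed model, the deployment task is
+typically placed upstream of the prediction task. A minimal DAG sketch (the model id,
+location and payload below are placeholders):
+
+.. code-block:: python
+
+    from datetime import datetime
+
+    from airflow.models.dag import DAG
+    from airflow.providers.google.cloud.operators.automl import (
+        AutoMLDeployModelOperator,
+        AutoMLPredictOperator,
+    )
+
+    with DAG(dag_id="example_automl_predict", start_date=datetime(2024, 1, 1), schedule=None):
+        deploy_model = AutoMLDeployModelOperator(
+            task_id="deploy_model",
+            model_id="your-model-id",
+            location="us-central1",
+        )
+        predict = AutoMLPredictOperator(
+            task_id="predict",
+            model_id="your-model-id",
+            location="us-central1",
+            payload={"row": {"values": ["1", "2"]}},
+        )
+        # The model must be deployed before online predictions can be served.
+        deploy_model >> predict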
+
+.. _howto/operator:AutoMLListDatasetOperator:
+.. _howto/operator:AutoMLDeleteDatasetOperator:
+
+Listing And Deleting Datasets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can get a list of AutoML datasets using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns a list
+of dataset ids in :ref:`XCom <concepts:xcom>` under the ``dataset_id_list`` key.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_list_dataset]
+ :end-before: [END howto_operator_list_dataset]
+
+To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
+The operator also accepts a list of dataset ids or a comma-separated string of dataset ids
+to delete several datasets at once, as sketched after the example below.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_delete_dataset]
+ :end-before: [END howto_operator_delete_dataset]
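+
+Because the operator parses either a list, a string holding a Python list literal, or a
+comma-separated string of ids, the following forms are equivalent (ids are placeholders):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.automl import AutoMLDeleteDatasetOperator
+
+    # Deletes the same two datasets as a list...
+    delete_from_list = AutoMLDeleteDatasetOperator(
+        task_id="delete_from_list",
+        dataset_id=["TBL1", "TBL2"],
+        location="us-central1",
+    )
+    # ...and as a comma-separated string.
+    delete_from_string = AutoMLDeleteDatasetOperator(
+        task_id="delete_from_string",
+        dataset_id="TBL1,TBL2",
+        location="us-central1",
+    )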
+
+Reference
+^^^^^^^^^
+
+For further information, look at:
+
+* `Client Library Documentation <https://googleapis.dev/python/automl/latest/index.html>`__
+* `Product Documentation <https://cloud.google.com/automl/docs>`__
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index d88b975a87ebc..54a02cc5d598b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -526,6 +526,7 @@
"google-auth-httplib2>=0.0.1",
"google-auth>=1.0.0",
"google-cloud-aiplatform>=1.42.1",
+ "google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
"google-cloud-bigtable>=2.17.0",
diff --git a/scripts/in_container/run_provider_yaml_files_check.py b/scripts/in_container/run_provider_yaml_files_check.py
index c343bb2397726..b1608c25bcedd 100755
--- a/scripts/in_container/run_provider_yaml_files_check.py
+++ b/scripts/in_container/run_provider_yaml_files_check.py
@@ -50,17 +50,10 @@
"airflow.providers.apache.hdfs.hooks.hdfs",
"airflow.providers.cncf.kubernetes.triggers.kubernetes_pod",
"airflow.providers.cncf.kubernetes.operators.kubernetes_pod",
- "airflow.providers.google.cloud.operators.automl",
]
KNOWN_DEPRECATED_CLASSES = [
- "airflow.providers.google.cloud.links.automl.AutoMLDatasetLink",
- "airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink",
- "airflow.providers.google.cloud.links.automl.AutoMLModelLink",
- "airflow.providers.google.cloud.links.automl.AutoMLModelListLink",
- "airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink",
"airflow.providers.google.cloud.links.dataproc.DataprocLink",
- "airflow.providers.google.cloud.hooks.automl.CloudAutoMLHook",
]
try:
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 773a2fc1c4206..5113ce15f4052 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -429,6 +429,8 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
}
ASSETS_NOT_REQUIRED = {
+ "airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator",
"airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator",
"airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteDatasetOperator",
"airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator",
diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml
index d271e6077a25b..6d27f4f5388d6 100644
--- a/tests/deprecations_ignore.yml
+++ b/tests/deprecations_ignore.yml
@@ -779,6 +779,7 @@
- tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py::TestCustomJobWithoutDefaultProjectIdHook::test_get_pipeline_job
- tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py::TestCustomJobWithoutDefaultProjectIdHook::test_list_pipeline_jobs
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryCreateExternalTableOperator::test_execute_with_csv_format
+- tests/providers/google/cloud/operators/test_automl.py::TestAutoMLTrainModelOperator::test_execute
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryCreateExternalTableOperator::test_execute_with_parquet_format
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryOperator::test_bigquery_operator_defaults
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryOperator::test_bigquery_operator_extra_link_when_missing_job_id
diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py
index 0f97b91e5a78d..f79dd8b51b73d 100644
--- a/tests/providers/google/cloud/hooks/test_automl.py
+++ b/tests/providers/google/cloud/hooks/test_automl.py
@@ -17,85 +17,237 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
import pytest
+from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.automl_v1beta1 import AutoMlClient
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.common.consts import CLIENT_INFO
+from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id
+
+CREDENTIALS = "test-creds"
+TASK_ID = "test-automl-hook"
+GCP_PROJECT_ID = "test-project"
+GCP_LOCATION = "test-location"
+MODEL_NAME = "test_model"
+MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
+DATASET_ID = "TBL123456789"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "dataset_id": DATASET_ID,
+ "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+}
+
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
+
+INPUT_CONFIG = {"input": "value"}
+OUTPUT_CONFIG = {"output": "value"}
+PAYLOAD = {"test": "payload"}
+DATASET = {"dataset_id": "data"}
+MASK = {"field": "mask"}
class TestAutoMLHook:
- def setup_method(self):
- self.hook = CloudAutoMLHook()
-
- def test_init(self):
- with pytest.warns(AirflowProviderDeprecationWarning):
- CloudAutoMLHook()
-
- def test_extract_object_id(self):
- with pytest.warns(AirflowProviderDeprecationWarning, match="'extract_object_id'"):
- object_id = CloudAutoMLHook.extract_object_id(obj={"name": "x/y"})
- assert object_id == "y"
-
- def test_get_conn(self):
- with pytest.raises(AirflowException):
- self.hook.get_conn()
-
- def test_wait_for_operation(self):
- with pytest.raises(AirflowException):
- self.hook.wait_for_operation()
-
- def test_prediction_client(self):
- with pytest.raises(AirflowException):
- self.hook.prediction_client()
-
- def test_create_model(self):
- with pytest.raises(AirflowException):
- self.hook.create_model()
-
- def test_batch_predict(self):
- with pytest.raises(AirflowException):
- self.hook.batch_predict()
-
- def test_predict(self):
- with pytest.raises(AirflowException):
- self.hook.predict()
+ def test_delegate_to_runtime_error(self):
+ with pytest.raises(RuntimeError):
+ CloudAutoMLHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to")
- def test_create_dataset(self):
- with pytest.raises(AirflowException):
- self.hook.create_dataset()
-
- def test_import_data(self):
- with pytest.raises(AirflowException):
- self.hook.import_data()
-
- def test_list_column_specs(self):
- with pytest.raises(AirflowException):
- self.hook.list_column_specs()
-
- def test_get_model(self):
- with pytest.raises(AirflowException):
- self.hook.get_model()
-
- def test_delete_model(self):
- with pytest.raises(AirflowException):
- self.hook.delete_model()
-
- def test_update_dataset(self):
- with pytest.raises(AirflowException):
- self.hook.update_dataset()
-
- def test_deploy_model(self):
- with pytest.raises(AirflowException):
- self.hook.deploy_model()
-
- def test_list_table_specs(self):
- with pytest.raises(AirflowException):
- self.hook.list_table_specs()
-
- def test_list_datasets(self):
- with pytest.raises(AirflowException):
- self.hook.list_datasets()
-
- def test_delete_dataset(self):
- with pytest.raises(AirflowException):
- self.hook.delete_dataset()
+ def setup_method(self):
+ with mock.patch(
+ "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_no_default_project_id,
+ ):
+ self.hook = CloudAutoMLHook()
+ self.hook.get_credentials = mock.MagicMock(return_value=CREDENTIALS) # type: ignore
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient")
+ def test_get_conn(self, mock_automl_client):
+ self.hook.get_conn()
+ mock_automl_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO)
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient")
+ def test_prediction_client(self, mock_prediction_client):
+ client = self.hook.prediction_client # noqa: F841
+ mock_prediction_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO)
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model")
+ def test_create_model(self, mock_create_model):
+ self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_create_model.assert_called_once_with(
+ request=dict(parent=LOCATION_PATH, model=MODEL), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
+ def test_batch_predict(self, mock_batch_predict):
+ self.hook.batch_predict(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ output_config=OUTPUT_CONFIG,
+ )
+
+ mock_batch_predict.assert_called_once_with(
+ request=dict(
+ name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
+ def test_predict(self, mock_predict):
+ self.hook.predict(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ payload=PAYLOAD,
+ )
+
+ mock_predict.assert_called_once_with(
+ request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
+ def test_create_dataset(self, mock_create_dataset):
+ self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_create_dataset.assert_called_once_with(
+ request=dict(parent=LOCATION_PATH, dataset=DATASET),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
+ def test_import_dataset(self, mock_import_data):
+ self.hook.import_data(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ )
+
+ mock_import_data.assert_called_once_with(
+ request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
+ def test_list_column_specs(self, mock_list_column_specs):
+ table_spec = "table_spec_id"
+ filter_ = "filter"
+ page_size = 42
+
+ self.hook.list_column_specs(
+ dataset_id=DATASET_ID,
+ table_spec_id=table_spec,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ field_mask=MASK,
+ filter_=filter_,
+ page_size=page_size,
+ )
+
+ parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec)
+ mock_list_column_specs.assert_called_once_with(
+ request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
+ def test_get_model(self, mock_get_model):
+ self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_get_model.assert_called_once_with(
+ request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
+ def test_delete_model(self, mock_delete_model):
+ self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_delete_model.assert_called_once_with(
+ request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
+ def test_update_dataset(self, mock_update_dataset):
+ self.hook.update_dataset(
+ dataset=DATASET,
+ update_mask=MASK,
+ )
+
+ mock_update_dataset.assert_called_once_with(
+ request=dict(dataset=DATASET, update_mask=MASK), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model")
+ def test_deploy_model(self, mock_deploy_model):
+ image_detection_metadata = {}
+
+ self.hook.deploy_model(
+ model_id=MODEL_ID,
+ image_detection_metadata=image_detection_metadata,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ mock_deploy_model.assert_called_once_with(
+ request=dict(
+ name=MODEL_PATH,
+ image_object_detection_model_deployment_metadata=image_detection_metadata,
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs")
+ def test_list_table_specs(self, mock_list_table_specs):
+ filter_ = "filter"
+ page_size = 42
+
+ self.hook.list_table_specs(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ filter_=filter_,
+ page_size=page_size,
+ )
+
+ mock_list_table_specs.assert_called_once_with(
+ request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
+ def test_list_datasets(self, mock_list_datasets):
+ self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_list_datasets.assert_called_once_with(
+ request=dict(parent=LOCATION_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
+ def test_delete_dataset(self, mock_delete_dataset):
+ self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_delete_dataset.assert_called_once_with(
+ request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 1caf680cd512c..4f00f76a2dbef 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -17,13 +17,643 @@
# under the License.
from __future__ import annotations
-from importlib import import_module
+import copy
+from unittest import mock
import pytest
+from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLBatchPredictOperator,
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLDeleteModelOperator,
+ AutoMLDeployModelOperator,
+ AutoMLGetModelOperator,
+ AutoMLImportDataOperator,
+ AutoMLListDatasetOperator,
+ AutoMLPredictOperator,
+ AutoMLTablesListColumnSpecsOperator,
+ AutoMLTablesListTableSpecsOperator,
+ AutoMLTablesUpdateDatasetOperator,
+ AutoMLTrainModelOperator,
+)
+from airflow.utils import timezone
+CREDENTIALS = "test-creds"
+TASK_ID = "test-automl-hook"
+GCP_PROJECT_ID = "test-project"
+GCP_LOCATION = "test-location"
+MODEL_NAME = "test_model"
+MODEL_ID = "TBL9195602771183665152"
+DATASET_ID = "TBL123456789"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "dataset_id": DATASET_ID,
+ "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+}
-def test_deprecated_module():
- with pytest.warns(AirflowProviderDeprecationWarning):
- import_module("airflow.providers.google.cloud.operators.automl")
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
+
+INPUT_CONFIG = {"input": "value"}
+OUTPUT_CONFIG = {"output": "value"}
+PAYLOAD = {"test": "payload"}
+DATASET = {"dataset_id": "data"}
+MASK = {"field": "mask"}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+class TestAutoMLTrainModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
+ mock_hook.return_value.extract_object_id = extract_object_id
+ mock_hook.return_value.wait_for_operation.return_value = Model()
+ op = AutoMLTrainModelOperator(
+ model=MODEL,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.create_model.assert_called_once_with(
+ model=MODEL,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLTrainModelOperator,
+ # Templated fields
+ model="{{ 'model' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation_chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLTrainModelOperator = ti.task
+ assert task.model == "model"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation_chain"
+
+
+class TestAutoMLBatchPredictOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
+ mock_hook.return_value.extract_object_id = extract_object_id
+ mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
+
+ op = AutoMLBatchPredictOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ output_config=OUTPUT_CONFIG,
+ task_id=TASK_ID,
+ prediction_params={},
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.batch_predict.assert_called_once_with(
+ input_config=INPUT_CONFIG,
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ output_config=OUTPUT_CONFIG,
+ params={},
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLBatchPredictOperator,
+ # Templated fields
+ model_id="{{ 'model' }}",
+ input_config="{{ 'input-config' }}",
+ output_config="{{ 'output-config' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLBatchPredictOperator = ti.task
+ assert task.model_id == "model"
+ assert task.input_config == "input-config"
+ assert task.output_config == "output-config"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLPredictOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.predict.return_value = PredictResponse()
+
+ op = AutoMLPredictOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ payload=PAYLOAD,
+ task_id=TASK_ID,
+ operation_params={"TEST_KEY": "TEST_VALUE"},
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.predict.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ params={"TEST_KEY": "TEST_VALUE"},
+ payload=PAYLOAD,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLPredictOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ payload={},
+ )
+ ti.render_templates()
+ task: AutoMLPredictOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.project_id == "project-id"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLCreateImportOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
+ mock_hook.return_value.extract_object_id = extract_object_id
+
+ op = AutoMLCreateDatasetOperator(
+ dataset=DATASET,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.create_dataset.assert_called_once_with(
+ dataset=DATASET,
+ location=GCP_LOCATION,
+ metadata=(),
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLCreateDatasetOperator,
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLCreateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.project_id == "project-id"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLListColumnsSpecsOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ table_spec = "table_spec_id"
+ filter_ = "filter"
+ page_size = 42
+
+ op = AutoMLTablesListColumnSpecsOperator(
+ dataset_id=DATASET_ID,
+ table_spec_id=table_spec,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ field_mask=MASK,
+ filter_=filter_,
+ page_size=page_size,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.list_column_specs.assert_called_once_with(
+ dataset_id=DATASET_ID,
+ field_mask=MASK,
+ filter_=filter_,
+ location=GCP_LOCATION,
+ metadata=(),
+ page_size=page_size,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ table_spec_id=table_spec,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLTablesListColumnSpecsOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ table_spec_id="{{ 'table-spec-id' }}",
+ field_mask="{{ 'field-mask' }}",
+ filter_="{{ 'filter-' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLTablesListColumnSpecsOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.table_spec_id == "table-spec-id"
+ assert task.field_mask == "field-mask"
+ assert task.filter_ == "filter-"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLUpdateDatasetOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
+
+ dataset = copy.deepcopy(DATASET)
+ dataset["name"] = DATASET_ID
+
+ op = AutoMLTablesUpdateDatasetOperator(
+ dataset=dataset,
+ update_mask=MASK,
+ location=GCP_LOCATION,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.update_dataset.assert_called_once_with(
+ dataset=dataset,
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ update_mask=MASK,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLTablesUpdateDatasetOperator,
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ update_mask="{{ 'update-mask' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLTablesUpdateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.update_mask == "update-mask"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLGetModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
+ mock_hook.return_value.extract_object_id = extract_object_id
+
+ op = AutoMLGetModelOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.get_model.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLGetModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLGetModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDeleteModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLDeleteModelOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=None)
+ mock_hook.return_value.delete_model.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLDeleteModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLDeleteModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDeployModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ image_detection_metadata = {}
+ op = AutoMLDeployModelOperator(
+ model_id=MODEL_ID,
+ image_detection_metadata=image_detection_metadata,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=None)
+ mock_hook.return_value.deploy_model.assert_called_once_with(
+ image_detection_metadata={},
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLDeployModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLDeployModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDatasetImportOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLImportDataOperator(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.import_data.assert_called_once_with(
+ input_config=INPUT_CONFIG,
+ location=GCP_LOCATION,
+ metadata=(),
+ dataset_id=DATASET_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLImportDataOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ input_config="{{ 'input-config' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLImportDataOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.input_config == "input-config"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLTablesListTableSpecsOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ filter_ = "filter"
+ page_size = 42
+
+ op = AutoMLTablesListTableSpecsOperator(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ filter_=filter_,
+ page_size=page_size,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.list_table_specs.assert_called_once_with(
+ dataset_id=DATASET_ID,
+ filter_=filter_,
+ location=GCP_LOCATION,
+ metadata=(),
+ page_size=page_size,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLTablesListTableSpecsOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ filter_="{{ 'filter-' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLTablesListTableSpecsOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.filter_ == "filter-"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDatasetListOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.list_datasets.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLListDatasetOperator,
+ # Templated fields
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLListDatasetOperator = ti.task
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDatasetDeleteOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLDeleteDatasetOperator(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=None)
+ mock_hook.return_value.delete_dataset.assert_called_once_with(
+ location=GCP_LOCATION,
+ dataset_id=DATASET_ID,
+ metadata=(),
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLDeleteDatasetOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLDeleteDatasetOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
diff --git a/tests/system/providers/google/cloud/automl/example_automl_dataset.py b/tests/system/providers/google/cloud/automl/example_automl_dataset.py
new file mode 100644
index 0000000000000..1c2691657e19b
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_dataset.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG that exercises Google AutoML dataset operations.
+"""
+
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLImportDataOperator,
+ AutoMLListDatasetOperator,
+ AutoMLTablesListColumnSpecsOperator,
+ AutoMLTablesListTableSpecsOperator,
+ AutoMLTablesUpdateDatasetOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+ GCSCreateBucketOperator,
+ GCSDeleteBucketOperator,
+ GCSSynchronizeBucketsOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "example_automl_dataset"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_AUTOML_LOCATION = "us-central1"
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+
+DATASET_NAME = f"ds_tabular_{ENV_ID}".replace("-", "_")
+DATASET = {
+ "display_name": DATASET_NAME,
+ "tables_dataset_metadata": {"target_column_spec_id": ""},
+}
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/tabular-classification.csv"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
+ """
+ Using column name returns spec of the column.
+ """
+ for column in columns_specs:
+ if column["display_name"] == column_name:
+ return extract_object_id(column)
+ raise Exception(f"Unknown target column: {column_name}")
+
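+# For illustration: a column spec dict returned by
+# AutoMLTablesListColumnSpecsOperator looks roughly like
+#   {"name": "projects/.../columnSpecs/123", "display_name": "Class"},
+# so extract_object_id() yields the trailing "123".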
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "automl", "dataset"],
+ user_defined_macros={
+ "get_target_column_spec": get_target_column_spec,
+ "target": "Class",
+ "extract_object_id": extract_object_id,
+ },
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ move_dataset_file = GCSSynchronizeBucketsOperator(
+ task_id="move_dataset_to_bucket",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object="automl/datasets/tabular",
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object="automl",
+ recursive=True,
+ )
+
+ # [START howto_operator_automl_create_dataset]
+ create_dataset = AutoMLCreateDatasetOperator(
+ task_id="create_dataset",
+ dataset=DATASET,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ dataset_id = create_dataset.output["dataset_id"]
+ # [END howto_operator_automl_create_dataset]
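+    # `.output["dataset_id"]` is an XComArg; at render time it resolves to
+    # roughly "{{ task_instance.xcom_pull('create_dataset', key='dataset_id') }}".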
+
+ # [START howto_operator_automl_import_data]
+ import_dataset = AutoMLImportDataOperator(
+ task_id="import_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ input_config=IMPORT_INPUT_CONFIG,
+ )
+ # [END howto_operator_automl_import_data]
+
+ # [START howto_operator_automl_specs]
+ list_tables_spec = AutoMLTablesListTableSpecsOperator(
+ task_id="list_tables_spec",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_specs]
+
+ # [START howto_operator_automl_column_specs]
+ list_columns_spec = AutoMLTablesListColumnSpecsOperator(
+ task_id="list_columns_spec",
+ dataset_id=dataset_id,
+        table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec')[0]) }}",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_column_specs]
+
+ # [START howto_operator_automl_update_dataset]
+ update = deepcopy(DATASET)
+ update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}'
+ update["tables_dataset_metadata"][ # type: ignore
+ "target_column_spec_id"
+ ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'), target) }}"
+
+ update_dataset = AutoMLTablesUpdateDatasetOperator(
+ task_id="update_dataset",
+ dataset=update,
+ location=GCP_AUTOML_LOCATION,
+ )
+ # [END howto_operator_automl_update_dataset]
+
+ # [START howto_operator_list_dataset]
+ list_datasets = AutoMLListDatasetOperator(
+ task_id="list_datasets",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_list_dataset]
+
+ # [START howto_operator_delete_dataset]
+ delete_dataset = AutoMLDeleteDatasetOperator(
+ task_id="delete_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_delete_dataset]
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> move_dataset_file, create_dataset]
+ # TEST BODY
+ >> import_dataset
+ >> list_tables_spec
+ >> list_columns_spec
+ >> update_dataset
+ >> list_datasets
+ # TEST TEARDOWN
+ >> delete_dataset
+ >> delete_bucket
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_model.py b/tests/system/providers/google/cloud/automl/example_automl_model.py
new file mode 100644
index 0000000000000..59ec91c8790d5
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_model.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG that exercises Google AutoML model operations.
+"""
+
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from datetime import datetime
+
+from google.protobuf.struct_pb2 import Value
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLBatchPredictOperator,
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLDeleteModelOperator,
+ AutoMLDeployModelOperator,
+ AutoMLGetModelOperator,
+ AutoMLImportDataOperator,
+ AutoMLPredictOperator,
+ AutoMLTablesListColumnSpecsOperator,
+ AutoMLTablesListTableSpecsOperator,
+ AutoMLTablesUpdateDatasetOperator,
+ AutoMLTrainModelOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+ GCSCreateBucketOperator,
+ GCSDeleteBucketOperator,
+ GCSSynchronizeBucketsOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "example_automl_model"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_AUTOML_LOCATION = "us-central1"
+
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+
+DATASET_NAME = f"md_tabular_{ENV_ID}".replace("-", "_")
+DATASET = {
+ "display_name": DATASET_NAME,
+ "tables_dataset_metadata": {"target_column_spec_id": ""},
+}
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing-split.csv"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+IMPORT_OUTPUT_CONFIG = {
+ "gcs_destination": {"output_uri_prefix": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl"}
+}
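+# Batch prediction writes its results under output_uri_prefix (typically into a
+# timestamped "prediction-..." subfolder), so repeated runs do not collide.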
+
+# The display name embeds the environment id so concurrent test runs do not clash.
+MODEL_NAME = f"md_tabular_{ENV_ID}".replace("-", "_")
+MODEL = {
+ "display_name": MODEL_NAME,
+ "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+}
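+# The training budget is expressed in thousandths of a node hour, so 1000
+# corresponds to a single node hour and keeps this system test cheap.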
+
+PREDICT_VALUES = [
+ Value(string_value="TRAINING"),
+ Value(string_value="51"),
+ Value(string_value="blue-collar"),
+ Value(string_value="married"),
+ Value(string_value="primary"),
+ Value(string_value="no"),
+ Value(string_value="620"),
+ Value(string_value="yes"),
+ Value(string_value="yes"),
+ Value(string_value="cellular"),
+ Value(string_value="29"),
+ Value(string_value="jul"),
+ Value(string_value="88"),
+ Value(string_value="10"),
+ Value(string_value="-1"),
+ Value(string_value="0"),
+ Value(string_value="unknown"),
+]
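+# No column ids accompany these values, so they are matched positionally
+# against the feature columns of the bank-marketing table imported above.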
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
+ """
+ Using column name returns spec of the column.
+ """
+ for column in columns_specs:
+ if column["display_name"] == column_name:
+ return extract_object_id(column)
+    raise ValueError(f"Unknown target column: {column_name}")
+
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ user_defined_macros={
+ "get_target_column_spec": get_target_column_spec,
+ "target": "Deposit",
+ "extract_object_id": extract_object_id,
+ },
+ tags=["example", "automl", "model"],
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ move_dataset_file = GCSSynchronizeBucketsOperator(
+ task_id="move_data_to_bucket",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object="automl/datasets/model",
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object="automl",
+ recursive=True,
+ )
+
+ create_dataset = AutoMLCreateDatasetOperator(
+ task_id="create_dataset",
+ dataset=DATASET,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ dataset_id = create_dataset.output["dataset_id"]
+ MODEL["dataset_id"] = dataset_id
+ import_dataset = AutoMLImportDataOperator(
+ task_id="import_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ input_config=IMPORT_INPUT_CONFIG,
+ )
+
+ list_tables_spec = AutoMLTablesListTableSpecsOperator(
+ task_id="list_tables_spec",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ list_columns_spec = AutoMLTablesListColumnSpecsOperator(
+ task_id="list_columns_spec",
+ dataset_id=dataset_id,
+ table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec')[0]) }}",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ update = deepcopy(DATASET)
+ update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}'
+ update["tables_dataset_metadata"][ # type: ignore
+ "target_column_spec_id"
+ ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec'), target) }}"
+
+ update_dataset = AutoMLTablesUpdateDatasetOperator(
+ task_id="update_dataset",
+ dataset=update,
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ # [START howto_operator_automl_create_model]
+ create_model = AutoMLTrainModelOperator(
+ task_id="create_model",
+ model=MODEL,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ model_id = create_model.output["model_id"]
+ # [END howto_operator_automl_create_model]
+
+ # [START howto_operator_get_model]
+ get_model = AutoMLGetModelOperator(
+ task_id="get_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_get_model]
+
+ # [START howto_operator_deploy_model]
+ deploy_model = AutoMLDeployModelOperator(
+ task_id="deploy_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_deploy_model]
+
+ # [START howto_operator_prediction]
+ predict_task = AutoMLPredictOperator(
+ task_id="predict_task",
+ model_id=model_id,
+ payload={
+ "row": {
+ "values": PREDICT_VALUES,
+ }
+ },
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_prediction]
+
+ # [START howto_operator_batch_prediction]
+ batch_predict_task = AutoMLBatchPredictOperator(
+ task_id="batch_predict_task",
+ model_id=model_id,
+ input_config=IMPORT_INPUT_CONFIG,
+ output_config=IMPORT_OUTPUT_CONFIG,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_batch_prediction]
+
+ # [START howto_operator_automl_delete_model]
+ delete_model = AutoMLDeleteModelOperator(
+ task_id="delete_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_delete_model]
+
+ delete_dataset = AutoMLDeleteDatasetOperator(
+ task_id="delete_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> move_dataset_file, create_dataset]
+ >> import_dataset
+ >> list_tables_spec
+ >> list_columns_spec
+ >> update_dataset
+ # TEST BODY
+ >> create_model
+ >> get_model
+ >> deploy_model
+ >> predict_task
+ >> batch_predict_task
+ # TEST TEARDOWN
+ >> delete_model
+ >> delete_dataset
+ >> delete_bucket
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
index 7305123cb0164..9ef04db81882b 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
@@ -28,6 +28,7 @@
from google.protobuf.struct_pb2 import Value
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
@@ -70,6 +71,7 @@
"gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
},
]
+extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Text Classification
with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
index 916fc25877c9c..8f8564f62c209 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
@@ -28,6 +28,7 @@
from google.protobuf.struct_pb2 import Value
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
@@ -69,11 +70,7 @@
},
]
-
-def extract_object_id(obj: dict) -> str:
- """Returns unique id of the object."""
- return obj["name"].rpartition("/")[-1]
-
+extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Entities Extraction
with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
index 0e641e1b05feb..94f349c6c3702 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
@@ -28,6 +28,7 @@
from google.protobuf.struct_pb2 import Value
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
@@ -70,11 +71,7 @@
},
]
-
-def extract_object_id(obj: dict) -> str:
- """Returns unique id of the object."""
- return obj["name"].rpartition("/")[-1]
-
+extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Text Sentiment
with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_translation.py b/tests/system/providers/google/cloud/automl/example_automl_translation.py
new file mode 100644
index 0000000000000..ba36f556c427d
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_translation.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google AutoML services.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from typing import cast
+
+# mypy cannot yet resolve the storage module's attributes, hence the ignore:
+# https://github.com/googleapis/python-storage/issues/393
+from google.cloud import storage # type: ignore[attr-defined]
+
+from airflow.decorators import task
+from airflow.models.dag import DAG
+from airflow.models.xcom_arg import XComArg
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLDeleteModelOperator,
+ AutoMLImportDataOperator,
+ AutoMLTrainModelOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "example_automl_translate"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+GCP_AUTOML_LOCATION = "us-central1"
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+
+
+MODEL_NAME = "translate_test_model"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "translation_model_metadata": {},
+}
+
+DATASET_NAME = f"ds_translate_{ENV_ID}".replace("-", "_")
+DATASET = {
+ "display_name": DATASET_NAME,
+ "translation_dataset_metadata": {
+ "source_language_code": "en",
+ "target_language_code": "es",
+ },
+}
+
+CSV_FILE_NAME = "en-es.csv"
+TSV_FILE_NAME = "en-es.tsv"
+GCS_FILE_PATH = f"automl/datasets/translate/{CSV_FILE_NAME}"
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/{CSV_FILE_NAME}"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+# Example DAG for AutoML Translation
+with DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ user_defined_macros={"extract_object_id": extract_object_id},
+ tags=["example", "automl", "translate"],
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ @task
+ def upload_csv_file_to_gcs():
+ # download file into memory
+ storage_client = storage.Client()
+ bucket = storage_client.bucket(RESOURCE_DATA_BUCKET)
+ blob = bucket.blob(GCS_FILE_PATH)
+ contents = blob.download_as_string().decode()
+
+        # Replace the "template-bucket" placeholder with this run's bucket name
+        # so the CSV's file references point at the bucket created above
+ updated_contents = contents.replace("template-bucket", DATA_SAMPLE_GCS_BUCKET_NAME)
+
+ # upload updated content to bucket
+ destination_bucket = storage_client.bucket(DATA_SAMPLE_GCS_BUCKET_NAME)
+ destination_blob = destination_bucket.blob(f"automl/{CSV_FILE_NAME}")
+ destination_blob.upload_from_string(updated_contents)
+
+ upload_csv_file_to_gcs_task = upload_csv_file_to_gcs()
+
+ copy_dataset_file = GCSToGCSOperator(
+ task_id="copy_dataset_file",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object=f"automl/datasets/translate/{TSV_FILE_NAME}",
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object=f"automl/{TSV_FILE_NAME}",
+ )
+
+ create_dataset = AutoMLCreateDatasetOperator(
+ task_id="create_dataset", dataset=DATASET, location=GCP_AUTOML_LOCATION
+ )
+
+ dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
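+    # cast() is only for the type checker; at runtime the XComArg resolves
+    # during templating, once create_dataset has pushed its "dataset_id" XCom.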
+
+ import_dataset = AutoMLImportDataOperator(
+ task_id="import_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ input_config=IMPORT_INPUT_CONFIG,
+ )
+
+ MODEL["dataset_id"] = dataset_id
+
+ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
+ model_id = cast(str, XComArg(create_model, key="model_id"))
+
+ delete_model = AutoMLDeleteModelOperator(
+ task_id="delete_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ delete_dataset = AutoMLDeleteDatasetOperator(
+ task_id="delete_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> upload_csv_file_to_gcs_task >> copy_dataset_file]
+ # TEST BODY
+ >> create_dataset
+ >> import_dataset
+ >> create_model
+ # TEST TEARDOWN
+ >> delete_dataset
+ >> delete_model
+ >> delete_bucket
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)