diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py new file mode 100644 index 0000000000000..8c1f0d7437f6a --- /dev/null +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -0,0 +1,315 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that demonstrates operators for the Google Vertex AI service in the Google +Cloud Platform. + +This DAG relies on the following OS environment variables: + +* GCP_VERTEX_AI_BUCKET - Google Cloud Storage bucket where the model will be saved +after the training process is finished. +* CUSTOM_CONTAINER_URI - path to the container with the model. +* PYTHON_PACKAGE_GSC_URI - path to the archived test model. +* LOCAL_TRAINING_SCRIPT_PATH - path to the local training script. +* DATASET_ID - ID of the dataset that will be used in the training process. +""" +import os +from datetime import datetime +from uuid import uuid4 + +from google.protobuf.struct_pb2 import Value + +from airflow import models +from airflow.providers.google.cloud.operators.vertex_ai.custom_job import ( + CreateCustomContainerTrainingJobOperator, + CreateCustomPythonPackageTrainingJobOperator, + CreateCustomTrainingJobOperator, + DeleteCustomTrainingJobOperator, + ListCustomTrainingJobOperator, +) +from airflow.providers.google.cloud.operators.vertex_ai.dataset import ( + CreateDatasetOperator, + DeleteDatasetOperator, + ExportDataOperator, + GetDatasetOperator, + ImportDataOperator, + ListDatasetsOperator, + UpdateDatasetOperator, +) + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") +REGION = os.environ.get("GCP_LOCATION", "us-central1") +BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests") + +STAGING_BUCKET = f"gs://{BUCKET}" +DISPLAY_NAME = str(uuid4())  # Create random display name +CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" +CUSTOM_CONTAINER_URI = os.environ.get("CUSTOM_CONTAINER_URI", "path_to_container_with_model") +MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" +REPLICA_COUNT = 1 +MACHINE_TYPE = "n1-standard-4" +ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" +ACCELERATOR_COUNT = 0 +TRAINING_FRACTION_SPLIT = 0.7 +TEST_FRACTION_SPLIT = 0.15 +VALIDATION_FRACTION_SPLIT = 0.15 + +PYTHON_PACKAGE_GCS_URI = os.environ.get("PYTHON_PACKAGE_GSC_URI", "path_to_test_model_in_arch") +PYTHON_MODULE_NAME = "aiplatform_custom_trainer_script.task" + +LOCAL_TRAINING_SCRIPT_PATH = os.environ.get("LOCAL_TRAINING_SCRIPT_PATH", "path_to_training_script") + +TRAINING_PIPELINE_ID = "test-training-pipeline-id" +CUSTOM_JOB_ID = "test-custom-job-id" + +IMAGE_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri":
"gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml", + "metadata": Value(string_value="test-image-dataset"), +} +TABULAR_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml", + "metadata": Value(string_value="test-tabular-dataset"), +} +TEXT_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml", + "metadata": Value(string_value="test-text-dataset"), +} +VIDEO_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml", + "metadata": Value(string_value="test-video-dataset"), +} +TIME_SERIES_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml", + "metadata": Value(string_value="test-time-series-dataset"), +} +DATASET_ID = os.environ.get("DATASET_ID", "test-dataset-id") +TEST_EXPORT_CONFIG = {"gcs_destination": {"output_uri_prefix": "gs://test-vertex-ai-bucket/exports"}} +TEST_IMPORT_CONFIG = [ + { + "data_item_labels": { + "test-labels-name": "test-labels-value", + }, + "import_schema_uri": ( + "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml" + ), + "gcs_source": { + "uris": ["gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"] + }, + }, +] +DATASET_TO_UPDATE = {"display_name": "test-name"} +TEST_UPDATE_MASK = {"paths": ["displayName"]} + +with models.DAG( + "example_gcp_vertex_ai_custom_jobs", + schedule_interval="@once", + start_date=datetime(2021, 1, 1), + catchup=False, +) as custom_jobs_dag: + # [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator] + create_custom_container_training_job = CreateCustomContainerTrainingJobOperator( + task_id="custom_container_task", + staging_bucket=STAGING_BUCKET, + display_name=f"train-housing-container-{DISPLAY_NAME}", + container_uri=CUSTOM_CONTAINER_URI, + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=DATASET_ID, + command=["python3", "task.py"], + model_display_name=f"container-housing-model-{DISPLAY_NAME}", + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator] + + # [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] + create_custom_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator( + task_id="python_package_task", + staging_bucket=STAGING_BUCKET, + display_name=f"train-housing-py-package-{DISPLAY_NAME}", + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=DATASET_ID, + model_display_name=f"py-package-housing-model-{DISPLAY_NAME}", + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT,
test_fraction_split=TEST_FRACTION_SPLIT, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] + + # [START how_to_cloud_vertex_ai_create_custom_training_job_operator] + create_custom_training_job = CreateCustomTrainingJobOperator( + task_id="custom_task", + staging_bucket=STAGING_BUCKET, + display_name=f"train-housing-custom-{DISPLAY_NAME}", + script_path=LOCAL_TRAINING_SCRIPT_PATH, + container_uri=CONTAINER_URI, + requirements=["gcsfs==0.7.1"], + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=DATASET_ID, + replica_count=1, + model_display_name=f"custom-housing-model-{DISPLAY_NAME}", + sync=False, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_custom_training_job_operator] + + # [START how_to_cloud_vertex_ai_delete_custom_training_job_operator] + delete_custom_training_job = DeleteCustomTrainingJobOperator( + task_id="delete_custom_training_job", + training_pipeline_id=TRAINING_PIPELINE_ID, + custom_job_id=CUSTOM_JOB_ID, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_delete_custom_training_job_operator] + + # [START how_to_cloud_vertex_ai_list_custom_training_job_operator] + list_custom_training_job = ListCustomTrainingJobOperator( + task_id="list_custom_training_job", + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_list_custom_training_job_operator] + +with models.DAG( + "example_gcp_vertex_ai_dataset", + schedule_interval="@once", + start_date=datetime(2021, 1, 1), + catchup=False, +) as dataset_dag: + # [START how_to_cloud_vertex_ai_create_dataset_operator] + create_image_dataset_job = CreateDatasetOperator( + task_id="image_dataset", + dataset=IMAGE_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_tabular_dataset_job = CreateDatasetOperator( + task_id="tabular_dataset", + dataset=TABULAR_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_text_dataset_job = CreateDatasetOperator( + task_id="text_dataset", + dataset=TEXT_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_video_dataset_job = CreateDatasetOperator( + task_id="video_dataset", + dataset=VIDEO_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_time_series_dataset_job = CreateDatasetOperator( + task_id="time_series_dataset", + dataset=TIME_SERIES_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_dataset_operator] + + # [START how_to_cloud_vertex_ai_delete_dataset_operator] + delete_dataset_job = DeleteDatasetOperator( + task_id="delete_dataset", + dataset_id=create_text_dataset_job.output['dataset_id'], + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_delete_dataset_operator] + + # [START how_to_cloud_vertex_ai_get_dataset_operator] + get_dataset = GetDatasetOperator( + task_id="get_dataset", + project_id=PROJECT_ID, + region=REGION, + dataset_id=create_tabular_dataset_job.output['dataset_id'], + ) + # [END how_to_cloud_vertex_ai_get_dataset_operator] + + # [START how_to_cloud_vertex_ai_export_data_operator] + export_data_job = ExportDataOperator( + task_id="export_data", + dataset_id=create_image_dataset_job.output['dataset_id'], + region=REGION, + project_id=PROJECT_ID, + export_config=TEST_EXPORT_CONFIG, + ) + # [END how_to_cloud_vertex_ai_export_data_operator] + + # [START how_to_cloud_vertex_ai_import_data_operator] + import_data_job = ImportDataOperator( + task_id="import_data", + 
dataset_id=create_image_dataset_job.output['dataset_id'], + region=REGION, + project_id=PROJECT_ID, + import_configs=TEST_IMPORT_CONFIG, + ) + # [END how_to_cloud_vertex_ai_import_data_operator] + + # [START how_to_cloud_vertex_ai_list_dataset_operator] + list_dataset_job = ListDatasetsOperator( + task_id="list_dataset", + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_list_dataset_operator] + + # [START how_to_cloud_vertex_ai_update_dataset_operator] + update_dataset_job = UpdateDatasetOperator( + task_id="update_dataset", + project_id=PROJECT_ID, + region=REGION, + dataset_id=create_video_dataset_job.output['dataset_id'], + dataset=DATASET_TO_UPDATE, + update_mask=TEST_UPDATE_MASK, + ) + # [END how_to_cloud_vertex_ai_update_dataset_operator] + + create_time_series_dataset_job + create_text_dataset_job >> delete_dataset_job + create_tabular_dataset_job >> get_dataset + create_image_dataset_job >> import_data_job >> export_data_job + create_video_dataset_job >> update_dataset_job + list_dataset_job diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py b/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py new file mode 100644 index 0000000000000..a0ae9eca016e3 --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -0,0 +1,2051 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
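+# NOTE: a minimal usage sketch for the hook defined in this module (illustrative only;
+# the connection id, project and region below are placeholder values, not defaults):
+#
+#     hook = CustomJobHook(gcp_conn_id="google_cloud_default")
+#     client = hook.get_pipeline_service_client(region="us-central1")
+#     parent = client.common_location_path("my-project", "us-central1")
+#     pipelines = client.list_training_pipelines(request={"parent": parent})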
+# +"""This module contains a Google Cloud Vertex AI hook.""" + +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.aiplatform import ( + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + CustomTrainingJob, + datasets, + models, +) +from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient +from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager +from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ( + ListPipelineJobsPager, + ListTrainingPipelinesPager, +) +from google.cloud.aiplatform_v1.types import CustomJob, PipelineJob, TrainingPipeline + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + + +class CustomJobHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI Custom Job APIs.""" + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._job: Optional[ + Union[ + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + CustomTrainingJob, + ] + ] = None + + def get_pipeline_service_client( + self, + region: Optional[str] = None, + ) -> PipelineServiceClient: + """Returns PipelineServiceClient.""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return PipelineServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def get_job_service_client( + self, + region: Optional[str] = None, + ) -> JobServiceClient: + """Returns JobServiceClient""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return JobServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def get_custom_container_training_job( + self, + display_name: str, + container_uri: str, + command: Sequence[str] = [], + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ) -> CustomContainerTrainingJob: + """Returns CustomContainerTrainingJob object""" + return CustomContainerTrainingJob( + display_name=display_name, + container_uri=container_uri, + command=command, + 
model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + project=project, + location=location, + credentials=self._get_credentials(), + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + def get_custom_python_package_training_job( + self, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Returns CustomPythonPackageTrainingJob object""" + return CustomPythonPackageTrainingJob( + display_name=display_name, + container_uri=container_uri, + python_package_gcs_uri=python_package_gcs_uri, + python_module_name=python_module_name, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + project=project, + location=location, + credentials=self._get_credentials(), + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + def get_custom_training_job( + self, + display_name: str, + script_path: str, + container_uri: str, + requirements: Optional[Sequence[str]] = None, + model_serving_container_image_uri: Optional[str] = None, + 
model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Returns CustomTrainingJob object""" + return CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + requirements=requirements, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + project=project, + location=location, + credentials=self._get_credentials(), + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + @staticmethod + def extract_model_id(obj: Dict) -> str: + """Returns unique id of the Model.""" + return obj["name"].rpartition("/")[-1] + + def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + def cancel_job(self) -> None: + """Cancel Job for training pipeline""" + if self._job: + self._job.cancel() + + def _run_job( + self, + job: Union[ + CustomTrainingJob, + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + ], + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = 
None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> models.Model: + """Run Job for training pipeline""" + model = job.run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + if model: + model.wait() + return model + else: + raise AirflowException("Training did not produce a Managed Model returning None.") + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + """ + Cancels a PipelineJob. Starts asynchronous cancellation on the PipelineJob. The server makes a best + effort to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other + methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a + pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job: The name of the PipelineJob to cancel. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
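+
+        Example (illustrative only; the project, region and job id below are placeholder
+        values, not defaults of this hook)::
+
+            hook = CustomJobHook()
+            hook.cancel_pipeline_job(
+                project_id="my-project",
+                region="us-central1",
+                pipeline_job="my-pipeline-job-id",
+            )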
+ """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + client.cancel_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + """ + Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes + a best effort to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline] + or other methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a + pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with + a [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param training_pipeline: Required. The name of the TrainingPipeline to cancel. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + client.cancel_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_custom_job( + self, + project_id: str, + region: str, + custom_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + """ + Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort + to cancel the job, but success is not guaranteed. Clients can use + [JobService.GetCustomJob][google.cloud.aiplatform.v1.JobService.GetCustomJob] or other methods to + check whether the cancellation succeeded or whether the job completed despite cancellation. On + successful cancellation, the CustomJob is not deleted; instead it becomes a job with a + [CustomJob.error][google.cloud.aiplatform.v1.CustomJob.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [CustomJob.state][google.cloud.aiplatform.v1.CustomJob.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param custom_job: Required. The name of the CustomJob to cancel. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
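+
+        Example of checking the job state after requesting cancellation (illustrative only;
+        the project, region and job id below are placeholder values)::
+
+            client = hook.get_job_service_client(region="us-central1")
+            name = client.custom_job_path("my-project", "us-central1", "my-custom-job-id")
+            job = client.get_custom_job(request={"name": name})
+            # the state is expected to reach JOB_STATE_CANCELLED once cancellation completes
+            print(job.state)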
+ """ + client = self.get_job_service_client(region) + name = JobServiceClient.custom_job_path(project_id, region, custom_job) + + client.cancel_custom_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: PipelineJob, + pipeline_job_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> PipelineJob: + """ + Creates a PipelineJob. A PipelineJob will run immediately when created. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job: Required. The PipelineJob to create. + :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of + the PipelineJob name. If not provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_pipeline_job( + request={ + 'parent': parent, + 'pipeline_job': pipeline_job, + 'pipeline_job_id': pipeline_job_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: TrainingPipeline, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> TrainingPipeline: + """ + Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param training_pipeline: Required. The TrainingPipeline to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_training_pipeline( + request={ + 'parent': parent, + 'training_pipeline': training_pipeline, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_job( + self, + project_id: str, + region: str, + custom_job: CustomJob, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> CustomJob: + """ + Creates a CustomJob. A created CustomJob right away will be attempted to be run. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param custom_job: Required. The CustomJob to create. 
This corresponds to the ``custom_job`` field on + the ``request`` instance; if ``request`` is provided, this should not be set. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_job_service_client(region) + parent = JobServiceClient.common_location_path(project_id, region) + + result = client.create_custom_job( + request={ + 'parent': parent, + 'custom_job': custom_job, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_container_training_job( + self, + project_id: str, + region: str, + display_name: str, + container_uri: str, + command: Sequence[str] = [], + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> models.Model: + """ + Create Custom Container Training Job + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param command: The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + :param container_uri: Required: Uri of the training container image in the GCR. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. 
+ :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. 
+ :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. 
If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. 
A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :param sync: Whether to execute the AI Platform job synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. 
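+
+        Example call (illustrative only; every value below is a placeholder, not a default)::
+
+            hook = CustomJobHook()
+            model = hook.create_custom_container_training_job(
+                project_id="my-project",
+                region="us-central1",
+                display_name="train-housing-container",
+                container_uri="gcr.io/my-project/housing-trainer:latest",
+                command=["python3", "task.py"],
+                model_serving_container_image_uri="gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest",
+                staging_bucket="gs://my-staging-bucket",
+                replica_count=1,
+                sync=True,
+            )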
+ """ + self._job = self.get_custom_container_training_job( + project=project_id, + location=region, + display_name=display_name, + container_uri=container_uri, + command=command, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + if not self._job: + raise AirflowException("CustomJob was not created") + + model = self._run_job( + job=self._job, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + + return model + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_python_package_training_job( + self, + project_id: str, + region: str, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: 
Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> models.Model: + """ + Create Custom Python Package Training Job + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param python_package_gcs_uri: Required: GCS location of the training python package. + :param python_module_name: Required: The module name of the training python package. + :param container_uri: Required: Uri of the training container image in the GCR. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. 
This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. 
Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. 
+ Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. 
The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :param sync: Whether to execute the AI Platform job synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + self._job = self.get_custom_python_package_training_job( + project=project_id, + location=region, + display_name=display_name, + python_package_gcs_uri=python_package_gcs_uri, + python_module_name=python_module_name, + container_uri=container_uri, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + if not self._job: + raise AirflowException("CustomJob was not created") + + model = self._run_job( + job=self._job, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + + return model + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_training_job( + self, + project_id: str, + region: str, + display_name: str, + script_path: str, + container_uri: str, + requirements: Optional[Sequence[str]] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: 
Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> models.Model: + """ + Create Custom Training Job + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param script_path: Required. Local path to training script. + :param container_uri: Required: Uri of the training container image in the GCR. + :param requirements: List of python packages dependencies of script. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). 
Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. 
+ Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. 
+ If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. 
A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+ columns. The value of the key (either the label's value or
+ value in the column) must be one of {``training``,
+ ``validation``, ``test``}, and it defines to which set the
+ given piece of data is assigned. If for a piece of data the
+ key is not present or has an invalid value, that piece is
+ ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+ :param sync: Whether to execute the AI Platform job synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
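+
+ Example (an illustrative sketch only; the project ID, bucket, container images, and script
+ path below are placeholders and must be replaced with real resources):
+
+ .. code-block:: python
+
+     from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
+
+     hook = CustomJobHook()  # uses the default Google Cloud connection
+     model = hook.create_custom_training_job(
+         project_id="example-project",
+         region="us-central1",
+         display_name="train-housing-model",
+         script_path="/files/train.py",
+         container_uri="gcr.io/example-project/trainer:latest",
+         model_serving_container_image_uri="gcr.io/example-project/predictor:latest",
+         staging_bucket="gs://example-staging-bucket",
+     )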
+ """ + self._job = self.get_custom_training_job( + project=project_id, + location=region, + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + requirements=requirements, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + if not self._job: + raise AirflowException("CustomJob was not created") + + model = self._run_job( + job=self._job, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + + return model + + @GoogleBaseHook.fallback_to_default_project_id + def delete_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Deletes a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.delete_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Deletes a TrainingPipeline. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.delete_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_custom_job( + self, + project_id: str, + region: str, + custom_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Deletes a CustomJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param custom_job: Required. The name of the CustomJob to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_job_service_client(region) + name = client.custom_job_path(project_id, region, custom_job) + + result = client.delete_custom_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> PipelineJob: + """ + Gets a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job: Required. The name of the PipelineJob resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.get_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> TrainingPipeline: + """ + Gets a TrainingPipeline. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param training_pipeline: Required. The name of the TrainingPipeline resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.get_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_custom_job( + self, + project_id: str, + region: str, + custom_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> CustomJob: + """ + Gets a CustomJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param custom_job: Required. The name of the CustomJob to get. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_job_service_client(region) + name = JobServiceClient.custom_job_path(project_id, region, custom_job) + + result = client.get_custom_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_pipeline_jobs( + self, + project_id: str, + region: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ListPipelineJobsPager: + """ + Lists PipelineJobs in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: Optional. Lists the PipelineJobs that match the filter expression. The + following fields are supported: + + - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. + - ``display_name``: Supports ``=``, ``!=`` comparisons, and + ``:`` wildcard. + - ``pipeline_job_user_id``: Supports ``=``, ``!=`` + comparisons, and ``:`` wildcard. for example, can check + if pipeline's display_name contains *step* by doing + display_name:"*step*" + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. 
Values must be in RFC + 3339 format. + - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``labels``: Supports key-value equality and key presence. + + Filter expressions can be combined together using logical + operators (``AND`` & ``OR``). For example: + ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. + + The syntax to define filter expression is based on + https://google.aip.dev/160. + :param page_size: Optional. The standard list page size. + :param page_token: Optional. The standard list page token. Typically obtained via + [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] + of the previous + [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] + call. + :param order_by: Optional. A comma-separated list of fields to order by. The default + sort order is in ascending order. Use "desc" after a field + name for descending. You can have multiple order_by fields + provided e.g. "create_time desc, end_time", "end_time, + start_time, update_time" For example, using "create_time + desc, end_time" will order results by create time in + descending order, and if there are multiple jobs having the + same create time, order them by the end time in ascending + order. if order_by is not specified, it will order by + default order is create time in descending order. Supported + fields: + + - ``create_time`` + - ``update_time`` + - ``end_time`` + - ``start_time`` + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_pipeline_jobs( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_training_pipelines( + self, + project_id: str, + region: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ListTrainingPipelinesPager: + """ + Lists TrainingPipelines in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :param page_size: Optional. The standard list page size. + :param page_token: Optional. The standard list page token. 
Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :param read_mask: Optional. Mask specifying which fields to read. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_training_pipelines( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_custom_jobs( + self, + project_id: str, + region: str, + page_size: Optional[int], + page_token: Optional[str], + filter: Optional[str], + read_mask: Optional[str], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ListCustomJobsPager: + """ + Lists CustomJobs in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :param page_size: Optional. The standard list page size. + :param page_token: Optional. The standard list page token. Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :param read_mask: Optional. Mask specifying which fields to read. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_job_service_client(region) + parent = JobServiceClient.common_location_path(project_id, region) + + result = client.list_custom_jobs( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py new file mode 100644 index 0000000000000..351419e77a83e --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py @@ -0,0 +1,460 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a Google Cloud Vertex AI hook.""" + +from typing import Dict, Optional, Sequence, Tuple, Union + +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.aiplatform_v1 import DatasetServiceClient +from google.cloud.aiplatform_v1.services.dataset_service.pagers import ( + ListAnnotationsPager, + ListDataItemsPager, + ListDatasetsPager, +) +from google.cloud.aiplatform_v1.types import AnnotationSpec, Dataset, ExportDataConfig, ImportDataConfig +from google.protobuf.field_mask_pb2 import FieldMask + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + + +class DatasetHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI Dataset APIs.""" + + def get_dataset_service_client(self, region: Optional[str] = None) -> DatasetServiceClient: + """Returns DatasetServiceClient.""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return DatasetServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @staticmethod + def extract_dataset_id(obj: Dict) -> str: + """Returns unique id of the dataset.""" + return obj["name"].rpartition("/")[-1] + + @GoogleBaseHook.fallback_to_default_project_id + def create_dataset( + self, + project_id: str, + region: str, + dataset: Union[Dataset, Dict], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Creates a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The Dataset to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_dataset( + request={ + 'parent': parent, + 'dataset': dataset, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_dataset( + self, + project_id: str, + region: str, + dataset: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Deletes a Dataset. 
+ + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.delete_dataset( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def export_data( + self, + project_id: str, + region: str, + dataset: str, + export_config: Union[ExportDataConfig, Dict], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Exports data from a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset to export. + :param export_config: Required. The desired output location. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.export_data( + request={ + 'name': name, + 'export_config': export_config, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_annotation_spec( + self, + project_id: str, + region: str, + dataset: str, + annotation_spec: str, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> AnnotationSpec: + """ + Gets an AnnotationSpec. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset. + :param annotation_spec: The ID of the AnnotationSpec resource. + :param read_mask: Optional. Mask specifying which fields to read. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + name = client.annotation_spec_path(project_id, region, dataset, annotation_spec) + + result = client.get_annotation_spec( + request={ + 'name': name, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_dataset( + self, + project_id: str, + region: str, + dataset: str, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Dataset: + """ + Gets a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. 
The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset to export. + :param read_mask: Optional. Mask specifying which fields to read. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.get_dataset( + request={ + 'name': name, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def import_data( + self, + project_id: str, + region: str, + dataset: str, + import_configs: Sequence[ImportDataConfig], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Operation: + """ + Imports data into a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset to import. + :param import_configs: Required. The desired input locations. The contents of all input locations + will be imported in one batch. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.import_data( + request={ + 'name': name, + 'import_configs': import_configs, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_annotations( + self, + project_id: str, + region: str, + dataset: str, + data_item: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ListAnnotationsPager: + """ + Lists Annotations belongs to a data item + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset. + :param data_item: Required. The ID of the DataItem to list Annotations from. + :param filter: The standard list filter. + :param page_size: The standard list page size. + :param page_token: The standard list page token. + :param read_mask: Mask specifying which fields to read. + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
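+
+ Example (an illustrative sketch; the dataset and data item IDs below are placeholders):
+
+ .. code-block:: python
+
+     from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
+
+     hook = DatasetHook()
+     for annotation in hook.list_annotations(
+         project_id="example-project",
+         region="us-central1",
+         dataset="1234567890",
+         data_item="1122334455",
+     ):
+         print(annotation.name)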
+ """ + client = self.get_dataset_service_client(region) + parent = client.data_item_path(project_id, region, dataset, data_item) + + result = client.list_annotations( + request={ + 'parent': parent, + 'filter': filter, + 'page_size': page_size, + 'page_token': page_token, + 'read_mask': read_mask, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_data_items( + self, + project_id: str, + region: str, + dataset: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ListDataItemsPager: + """ + Lists DataItems in a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset: Required. The ID of the Dataset. + :param filter: The standard list filter. + :param page_size: The standard list page size. + :param page_token: The standard list page token. + :param read_mask: Mask specifying which fields to read. + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + parent = client.dataset_path(project_id, region, dataset) + + result = client.list_data_items( + request={ + 'parent': parent, + 'filter': filter, + 'page_size': page_size, + 'page_token': page_token, + 'read_mask': read_mask, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_datasets( + self, + project_id: str, + region: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ListDatasetsPager: + """ + Lists Datasets in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: The standard list filter. + :param page_size: The standard list page size. + :param page_token: The standard list page token. + :param read_mask: Mask specifying which fields to read. + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_dataset_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_datasets( + request={ + 'parent': parent, + 'filter': filter, + 'page_size': page_size, + 'page_token': page_token, + 'read_mask': read_mask, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def update_dataset( + self, + project_id: str, + region: str, + dataset_id: str, + dataset: Union[Dataset, Dict], + update_mask: Union[FieldMask, Dict], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Dataset: + """ + Updates a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset_id: Required. The ID of the Dataset. + :param dataset: Required. The Dataset which replaces the resource on the server. + :param update_mask: Required. The update mask applies to the resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataset_service_client(region) + dataset["name"] = client.dataset_path(project_id, region, dataset_id) + + result = client.update_dataset( + request={ + 'dataset': dataset, + 'update_mask': update_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/airflow/providers/google/cloud/operators/vertex_ai/__init__.py b/airflow/providers/google/cloud/operators/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py new file mode 100644 index 0000000000000..822cf12255aa5 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -0,0 +1,1449 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains Google Vertex AI operators.""" + +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core.exceptions import NotFound +from google.api_core.retry import Retry +from google.cloud.aiplatform.models import Model +from google.cloud.aiplatform_v1.types.dataset import Dataset +from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline + +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.xcom import XCom +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + +VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" +VERTEX_AI_MODEL_LINK = ( + VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}" +) +VERTEX_AI_TRAINING_PIPELINES_LINK = VERTEX_AI_BASE_LINK + "/training/training-pipelines?project={project_id}" + + +class VertexAIModelLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Model link""" + + name = "Vertex AI Model" + + def get_link(self, operator, dttm): + model_conf = XCom.get_one( + key='model_conf', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) + return ( + VERTEX_AI_MODEL_LINK.format( + region=model_conf["region"], + model_id=model_conf["model_id"], + project_id=model_conf["project_id"], + ) + if model_conf + else "" + ) + + +class VertexAITrainingPipelinesLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Training Pipelines link""" + + name = "Vertex AI Training Pipelines" + + def get_link(self, operator, dttm): + project_id = XCom.get_one( + key='project_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) + return ( + VERTEX_AI_TRAINING_PIPELINES_LINK.format( + project_id=project_id, + ) + if project_id + else "" + ) + + +class CustomTrainingJobBaseOperator(BaseOperator): + """The base class for operators that launch Custom jobs on VertexAI.""" + + def __init__( + self, + *, + project_id: str, + region: str, + display_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset_id: Optional[str] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: 
Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.display_name = display_name + # START Custom + self.container_uri = container_uri + self.model_serving_container_image_uri = model_serving_container_image_uri + self.model_serving_container_predict_route = model_serving_container_predict_route + self.model_serving_container_health_route = model_serving_container_health_route + self.model_serving_container_command = model_serving_container_command + self.model_serving_container_args = model_serving_container_args + self.model_serving_container_environment_variables = model_serving_container_environment_variables + self.model_serving_container_ports = model_serving_container_ports + self.model_description = model_description + self.model_instance_schema_uri = model_instance_schema_uri + self.model_parameters_schema_uri = model_parameters_schema_uri + self.model_prediction_schema_uri = model_prediction_schema_uri + self.labels = labels + self.training_encryption_spec_key_name = training_encryption_spec_key_name + self.model_encryption_spec_key_name = model_encryption_spec_key_name + self.staging_bucket = staging_bucket + # END Custom + # START Run param + self.dataset = Dataset(name=dataset_id) if dataset_id else None + self.annotation_schema_uri = annotation_schema_uri + self.model_display_name = model_display_name + self.model_labels = model_labels + self.base_output_dir = base_output_dir + self.service_account = service_account + self.network = network + self.bigquery_destination = bigquery_destination + self.args = args + self.environment_variables = environment_variables + self.replica_count = replica_count + self.machine_type = machine_type + self.accelerator_type = accelerator_type + self.accelerator_count = accelerator_count + self.boot_disk_type = boot_disk_type + self.boot_disk_size_gb = boot_disk_size_gb + self.training_fraction_split = training_fraction_split + self.validation_fraction_split = validation_fraction_split + self.test_fraction_split = test_fraction_split + self.training_filter_split = training_filter_split + self.validation_filter_split = validation_filter_split + self.test_filter_split = test_filter_split + self.predefined_split_column_name = predefined_split_column_name + self.timestamp_split_column_name = timestamp_split_column_name + self.tensorboard = tensorboard 
+ self.sync = sync + # END Run param + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + +class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): + """Create Custom Container Training job + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param command: The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + :param container_uri: Required: Uri of the training container image in the GCR. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. 
+ The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. 
The schema is defined as an OpenAPI 3.0.2
+ [Schema Object]
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
+
+ Only Annotations that both match this schema and belong to
+ DataItems not ignored by the split method are used in
+ respectively training, validation or test role, depending on
+ the role of the DataItem they are on.
+
+ When used in conjunction with
+ ``annotations_filter``,
+ the Annotations used for training are filtered by both
+ ``annotations_filter``
+ and
+ ``annotation_schema_uri``.
+ :param model_display_name: If the script produces a managed Vertex AI Model, the display name of
+ the Model. The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters, can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
+ staging directory will be used.
+
+ Vertex AI sets the following environment variables when it runs your training code:
+
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
+ i.e. /model/
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
+ i.e. /checkpoints/
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
+ logs, i.e. /logs/
+ :param service_account: Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ :param network: The full name of the Compute Engine network to which the job
+ should be peered.
+ Private services access must already be configured for the network.
+ If left unspecified, the job is not peered with any network.
+ :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
+ The BigQuery project location where the training data is to
+ be written to. In the given project a new dataset is created
+ with name
+ ``dataset___``
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+ training input data will be written into that dataset. In
+ the dataset three tables will be created, ``training``,
+ ``validation`` and ``test``.
+
+ - AIP_DATA_FORMAT = "bigquery".
+ - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+ :param args: Command line arguments to be passed to the Python script.
+ :param environment_variables: Environment variables to be passed to the container.
+ Should be a dictionary where keys are environment variable names
+ and values are environment variable values for those names.
+ At most 10 environment variables can be specified.
+ The name of each environment variable must be unique.
+ :param replica_count: The number of worker replicas. If replica count = 1 then one chief
+ replica will be provisioned. If replica_count > 1 the remainder will be
+ provisioned as a worker replica pool.
+ :param machine_type: The type of machine to use for training.
+ :param accelerator_type: Hardware accelerator type.
One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. 
Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :param sync: Whether to execute the AI Platform job synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = [ + 'region', + 'command', + 'impersonation_chain', + ] + operator_extra_links = (VertexAIModelLink(),) + + def __init__( + self, + *, + command: Sequence[str] = [], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.command = command + + def execute(self, context: 'Context'): + self.hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + model = self.hook.create_custom_container_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + container_uri=self.container_uri, + command=self.command, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + 
boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + sync=True, + ) + + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + self.xcom_push( + context, + key="model_conf", + value={ + "model_id": model_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return result + + def on_kill(self) -> None: + """ + Callback called when the operator is killed. + Cancel any running job. + """ + if self.hook: + self.hook.cancel_job() + + +class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator): + """Create Custom Python Package Training job + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param python_package_gcs_uri: Required: GCS location of the training python package. + :param python_module_name: Required: The module name of the training python package. + :param container_uri: Required: Uri of the training container image in the GCR. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. 
Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. 
+
+ Note: Model trained by this TrainingPipeline is also secured
+ by this key if ``model_to_upload`` is not set separately.
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, the trained Model will be secured by this key.
+ :param staging_bucket: Bucket used to stage source and training artifacts.
+ :param dataset: The Vertex AI Dataset to fit this training against.
+ :param annotation_schema_uri: Google Cloud Storage URI that points to a YAML file describing
+ annotation schema. The schema is defined as an OpenAPI 3.0.2
+ [Schema Object]
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
+
+ Only Annotations that both match this schema and belong to
+ DataItems not ignored by the split method are used in
+ respectively training, validation or test role, depending on
+ the role of the DataItem they are on.
+
+ When used in conjunction with
+ ``annotations_filter``,
+ the Annotations used for training are filtered by both
+ ``annotations_filter``
+ and
+ ``annotation_schema_uri``.
+ :param model_display_name: If the script produces a managed Vertex AI Model, the display name of
+ the Model. The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters, can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
+ staging directory will be used.
+
+ Vertex AI sets the following environment variables when it runs your training code:
+
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
+ i.e. /model/
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
+ i.e. /checkpoints/
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
+ logs, i.e. /logs/
+ :param service_account: Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ :param network: The full name of the Compute Engine network to which the job
+ should be peered.
+ Private services access must already be configured for the network.
+ If left unspecified, the job is not peered with any network.
+ :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
+ The BigQuery project location where the training data is to
+ be written to. In the given project a new dataset is created
+ with name
+ ``dataset___``
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+ training input data will be written into that dataset. In
+ the dataset three tables will be created, ``training``,
+ ``validation`` and ``test``.
+
+ - AIP_DATA_FORMAT = "bigquery".
+ - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. 
The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :param sync: Whether to execute the AI Platform job synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = [ + 'region', + 'impersonation_chain', + ] + operator_extra_links = (VertexAIModelLink(),) + + def __init__( + self, + *, + python_package_gcs_uri: str, + python_module_name: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.python_package_gcs_uri = python_package_gcs_uri + self.python_module_name = python_module_name + + def execute(self, context: 'Context'): + self.hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + model = self.hook.create_custom_python_package_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + python_package_gcs_uri=self.python_package_gcs_uri, + python_module_name=self.python_module_name, + container_uri=self.container_uri, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + sync=True, + ) + + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + self.xcom_push( + context, + key="model_conf", + value={ + "model_id": model_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return result + + def on_kill(self) -> None: + """ + Callback called when the operator is killed. + Cancel any running job. + """ + if self.hook: + self.hook.cancel_job() + + +class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): + """Create Custom Training job + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. 
+ :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param script_path: Required. Local path to training script. + :param container_uri: Required: Uri of the training container image in the GCR. + :param requirements: List of python packages dependencies of script. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. 
Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. 
+ :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. 
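The ``AIP_*`` variables described under ``base_output_dir`` above are what the script referenced by ``script_path`` is expected to honour at run time. A minimal, framework-agnostic sketch of such a script follows; the fallback paths and the training step are placeholders.

```python
# Sketch only: a skeleton for the file passed as script_path; the training step and
# fallback paths are placeholders, the AIP_* variables are those documented above.
import os


def main() -> None:
    model_dir = os.environ.get("AIP_MODEL_DIR", "/tmp/model")           # model artifacts
    checkpoint_dir = os.environ.get("AIP_CHECKPOINT_DIR", "/tmp/ckpt")  # checkpoints
    log_dir = os.environ.get("AIP_TENSORBOARD_LOG_DIR", "/tmp/logs")    # TensorBoard logs

    # ... build and fit the model here, writing checkpoints and logs to the paths above ...

    # Saving the final artifacts under AIP_MODEL_DIR is what lets Vertex AI upload them
    # as a managed Model when model_display_name is set on the operator.
    print(f"artifacts -> {model_dir}, checkpoints -> {checkpoint_dir}, logs -> {log_dir}")


if __name__ == "__main__":
    main()
```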
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :param sync: Whether to execute the AI Platform job synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. 
For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = [ + 'region', + 'script_path', + 'requirements', + 'impersonation_chain', + ] + operator_extra_links = (VertexAIModelLink(),) + + def __init__( + self, + *, + script_path: str, + requirements: Optional[Sequence[str]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.requirements = requirements + self.script_path = script_path + + def execute(self, context: 'Context'): + self.hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + model = self.hook.create_custom_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + script_path=self.script_path, + container_uri=self.container_uri, + requirements=self.requirements, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + 
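+            # The TensorBoard resource and the sync flag are passed next; sync=True makes
+            # this call block until the training job has finished.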
tensorboard=self.tensorboard, + sync=True, + ) + + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + self.xcom_push( + context, + key="model_conf", + value={ + "model_id": model_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return result + + def on_kill(self) -> None: + """ + Callback called when the operator is killed. + Cancel any running job. + """ + if self.hook: + self.hook.cancel_job() + + +class DeleteCustomTrainingJobOperator(BaseOperator): + """Deletes a CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob. + + :param training_pipeline_id: Required. The name of the TrainingPipeline resource to be deleted. + :param custom_job_id: Required. The name of the CustomJob to delete. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
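+
+    A minimal usage sketch; every ID and name below is an illustrative placeholder:
+
+    .. code-block:: python
+
+        from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
+            DeleteCustomTrainingJobOperator,
+        )
+
+        delete_custom_training_job = DeleteCustomTrainingJobOperator(
+            task_id="delete_custom_training_job",
+            training_pipeline_id="my-training-pipeline-id",  # placeholder
+            custom_job_id="my-custom-job-id",  # placeholder
+            region="us-central1",
+            project_id="my-project-id",  # placeholder
+        )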
+ """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + training_pipeline_id: str, + custom_job_id: str, + region: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.training_pipeline = training_pipeline_id + self.custom_job = custom_job_id + self.region = region + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + try: + self.log.info("Deleting custom training pipeline: %s", self.training_pipeline) + training_pipeline_operation = hook.delete_training_pipeline( + training_pipeline=self.training_pipeline, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation) + self.log.info("Training pipeline was deleted.") + except NotFound: + self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline) + try: + self.log.info("Deleting custom job: %s", self.custom_job) + custom_job_operation = hook.delete_custom_job( + custom_job=self.custom_job, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=custom_job_operation) + self.log.info("Custom job was deleted.") + except NotFound: + self.log.info("The Custom Job ID %s does not exist.", self.custom_job) + + +class ListCustomTrainingJobOperator(BaseOperator): + """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :param page_size: Optional. The standard list page size. + :param page_token: Optional. The standard list page token. Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :param read_mask: Optional. Mask specifying which fields to read. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = [ + "region", + "project_id", + "impersonation_chain", + ] + operator_extra_links = [ + VertexAITrainingPipelinesLink(), + ] + + def __init__( + self, + *, + region: str, + project_id: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.page_size = page_size + self.page_token = page_token + self.filter = filter + self.read_mask = read_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + results = hook.list_training_pipelines( + region=self.region, + project_id=self.project_id, + page_size=self.page_size, + page_token=self.page_token, + filter=self.filter, + read_mask=self.read_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push(context, key="project_id", value=self.project_id) + return [TrainingPipeline.to_dict(result) for result in results] diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py new file mode 100644 index 0000000000000..34fc46b4eced8 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -0,0 +1,646 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains Google Vertex AI operators.""" + +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core.exceptions import NotFound +from google.api_core.retry import Retry +from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig +from google.protobuf.field_mask_pb2 import FieldMask + +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.xcom import XCom +from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + +VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" +VERTEX_AI_DATASET_LINK = ( + VERTEX_AI_BASE_LINK + "/locations/{region}/datasets/{dataset_id}/analyze?project={project_id}" +) +VERTEX_AI_DATASET_LIST_LINK = VERTEX_AI_BASE_LINK + "/datasets?project={project_id}" + + +class VertexAIDatasetLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Dataset link""" + + name = "Dataset" + + def get_link(self, operator, dttm): + dataset_conf = XCom.get_one( + key='dataset_conf', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) + return ( + VERTEX_AI_DATASET_LINK.format( + region=dataset_conf["region"], + dataset_id=dataset_conf["dataset_id"], + project_id=dataset_conf["project_id"], + ) + if dataset_conf + else "" + ) + + +class VertexAIDatasetListLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Datasets Link""" + + name = "Dataset List" + + def get_link(self, operator, dttm): + project_id = XCom.get_one( + key='project_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) + return ( + VERTEX_AI_DATASET_LIST_LINK.format( + project_id=project_id, + ) + if project_id + else "" + ) + + +class CreateDatasetOperator(BaseOperator): + """ + Creates a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param dataset: Required. The Dataset to create. This corresponds to the ``dataset`` field on the + ``request`` instance; if ``request`` is provided, this should not be set. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
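+
+    A minimal usage sketch; the project, region and dataset values below are placeholders:
+
+    .. code-block:: python
+
+        from google.protobuf.struct_pb2 import Value
+
+        from airflow.providers.google.cloud.operators.vertex_ai.dataset import CreateDatasetOperator
+
+        create_image_dataset = CreateDatasetOperator(
+            task_id="create_image_dataset",
+            project_id="my-project-id",  # placeholder
+            region="us-central1",
+            dataset={
+                "display_name": "my-image-dataset",  # placeholder
+                "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
+                "metadata": Value(string_value="my-image-dataset"),
+            },
+        )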
+ """ + + template_fields = ("region", "project_id", "impersonation_chain") + operator_extra_links = (VertexAIDatasetLink(),) + + def __init__( + self, + *, + region: str, + project_id: str, + dataset: Union[Dataset, Dict], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset = dataset + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Creating dataset") + operation = hook.create_dataset( + project_id=self.project_id, + region=self.region, + dataset=self.dataset, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = hook.wait_for_operation(timeout=self.timeout, operation=operation) + + dataset = Dataset.to_dict(result) + dataset_id = hook.extract_dataset_id(dataset) + self.log.info("Dataset was created. Dataset id: %s", dataset_id) + + self.xcom_push(context, key="dataset_id", value=dataset_id) + self.xcom_push( + context, + key="dataset_conf", + value={ + "dataset_id": dataset_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return dataset + + +class GetDatasetOperator(BaseOperator): + """ + Get a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param dataset_id: Required. The ID of the Dataset to get. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") + operator_extra_links = (VertexAIDatasetLink(),) + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.read_mask = read_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + try: + self.log.info("Get dataset: %s", self.dataset_id) + dataset_obj = hook.get_dataset( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + read_mask=self.read_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push( + context, + key="dataset_conf", + value={ + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "region": self.region, + }, + ) + self.log.info("Dataset was gotten.") + return Dataset.to_dict(dataset_obj) + except NotFound: + self.log.info("The Dataset ID %s does not exist.", self.dataset_id) + + +class DeleteDatasetOperator(BaseOperator): + """ + Deletes a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param dataset_id: Required. The ID of the Dataset to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + try: + self.log.info("Deleting dataset: %s", self.dataset_id) + operation = hook.delete_dataset( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataset was deleted.") + except NotFound: + self.log.info("The Dataset ID %s does not exist.", self.dataset_id) + + +class ExportDataOperator(BaseOperator): + """ + Exports data from a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param dataset_id: Required. The ID of the Dataset to delete. + :param export_config: Required. The desired output location. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + export_config: Union[ExportDataConfig, Dict], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.export_config = export_config + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Exporting data: %s", self.dataset_id) + operation = hook.export_data( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + export_config=self.export_config, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Export was done successfully") + + +class ImportDataOperator(BaseOperator): + """ + Imports data into a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param dataset_id: Required. The ID of the Dataset to delete. + :param import_configs: Required. The desired input locations. The contents of all input locations will be + imported in one batch. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + import_configs: Union[Sequence[ImportDataConfig], List], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.import_configs = import_configs + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Importing data: %s", self.dataset_id) + operation = hook.import_data( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + import_configs=self.import_configs, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Import was done successfully") + + +class ListDatasetsOperator(BaseOperator): + """ + Lists Datasets in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: The standard list filter. + :param page_size: The standard list page size. + :param page_token: The standard list page token. + :param read_mask: Mask specifying which fields to read. + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("region", "project_id", "impersonation_chain") + operator_extra_links = (VertexAIDatasetListLink(),) + + def __init__( + self, + *, + region: str, + project_id: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.filter = filter + self.page_size = page_size + self.page_token = page_token + self.read_mask = read_mask + self.order_by = order_by + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + results = hook.list_datasets( + project_id=self.project_id, + region=self.region, + filter=self.filter, + page_size=self.page_size, + page_token=self.page_token, + read_mask=self.read_mask, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push( + context, + key="project_id", + value=self.project_id, + ) + return [Dataset.to_dict(result) for result in results] + + +class UpdateDatasetOperator(BaseOperator): + """ + Updates a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param dataset_id: Required. The ID of the Dataset to update. + :param dataset: Required. The Dataset which replaces the resource on the server. + :param update_mask: Required. The update mask applies to the resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") + + def __init__( + self, + *, + project_id: str, + region: str, + dataset_id: str, + dataset: Union[Dataset, Dict], + update_mask: Union[FieldMask, Dict], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.dataset_id = dataset_id + self.dataset = dataset + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: 'Context'): + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Updating dataset: %s", self.dataset_id) + result = hook.update_dataset( + project_id=self.project_id, + region=self.region, + dataset_id=self.dataset_id, + dataset=self.dataset, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Dataset was updated") + return Dataset.to_dict(result) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 52a3c16a125be..45f9715ca93e9 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -338,6 +338,11 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-google/operators/leveldb/leveldb.rst tags: [google] + - integration-name: Google Vertex AI + external-doc-url: https://cloud.google.com/vertex-ai + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst + tags: [gcp] operators: - integration-name: Google Ads @@ -465,6 +470,10 @@ operators: - integration-name: Google LevelDB python-modules: - airflow.providers.google.leveldb.operators.leveldb + - integration-name: Google Vertex AI + python-modules: + - airflow.providers.google.cloud.operators.vertex_ai.dataset + - airflow.providers.google.cloud.operators.vertex_ai.custom_job sensors: - integration-name: Google BigQuery @@ -659,6 +668,10 @@ hooks: - integration-name: Google LevelDB python-modules: - airflow.providers.google.leveldb.hooks.leveldb + - integration-name: Google Vertex AI + python-modules: + - airflow.providers.google.cloud.hooks.vertex_ai.dataset + - airflow.providers.google.cloud.hooks.vertex_ai.custom_job transfers: - source-integration-name: Presto @@ -806,6 +819,10 @@ extra-links: - airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink - airflow.providers.google.cloud.operators.dataproc.DataprocJobLink - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink + - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAIModelLink + - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAITrainingPipelinesLink + - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetLink + - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetListLink additional-extras: apache.beam: apache-beam[gcp] diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst index bcb2a9a7a4b61..061a2b27d40e7 
100644 --- a/docs/apache-airflow-providers-google/index.rst +++ b/docs/apache-airflow-providers-google/index.rst @@ -96,6 +96,7 @@ PIP package Version required ``google-api-python-client`` ``>=1.6.0,<2.0.0`` ``google-auth-httplib2`` ``>=0.0.1`` ``google-auth`` ``>=1.0.0,<3.0.0`` +``google-cloud-aiplatform`` ``>=1.7.1,<2.0.0`` ``google-cloud-automl`` ``>=2.1.0,<3.0.0`` ``google-cloud-bigquery-datatransfer`` ``>=3.0.0,<4.0.0`` ``google-cloud-bigtable`` ``>=1.0.0,<2.0.0`` diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst new file mode 100644 index 0000000000000..92c22af0d4aef --- /dev/null +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -0,0 +1,173 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Google Cloud VertexAI Operators +======================================= + +The `Google Cloud VertexAI `__ +brings AutoML and AI Platform together into a unified API, client library, and user +interface. AutoML lets you train models on image, tabular, text, and video datasets +without writing code, while training in AI Platform lets you run custom training code. +With Vertex AI, both AutoML training and custom training are available options. +Whichever option you choose for training, you can save models, deploy models, and +request predictions with Vertex AI. + +Creating Datasets +^^^^^^^^^^^^^^^^^ + +To create a Google VertexAI dataset you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`. +The operator returns dataset id in :ref:`XCom ` under ``dataset_id`` key. + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_dataset_operator] + :end-before: [END how_to_cloud_vertex_ai_create_dataset_operator] + +After creating a dataset you can use it to import some data using +:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`. + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_import_data_operator] + :end-before: [END how_to_cloud_vertex_ai_import_data_operator] + +To export dataset you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ExportDataOperator`. + +.. 
exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_export_data_operator]
+    :end-before: [END how_to_cloud_vertex_ai_export_data_operator]
+
+To delete a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_delete_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_delete_dataset_operator]
+
+To get a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.GetDatasetOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_get_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_get_dataset_operator]
+
+To get a list of datasets you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_list_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_list_dataset_operator]
+
+To update a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_update_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_update_dataset_operator]
+
+Creating Training Jobs
+^^^^^^^^^^^^^^^^^^^^^^
+
+To create Google Vertex AI training jobs you have three operators:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator` and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.
+Each of them waits for the operation to complete. The result of each operator is the model
+trained by the user with that operator.
+
+Preparation step
+
+For each operator you must first prepare and create a dataset, then pass its ID to the operator's
+``dataset_id`` parameter.
+
+How to run a Container Training Job with
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`
+
+Before running this job you should create a Docker image with the training script inside.
+Documentation on how to create such an image can be found here:
+https://cloud.google.com/vertex-ai/docs/training/create-custom-container
+After that, put the link to the image in the ``container_uri`` parameter. You can also set the command
+to be executed in the container created from this image via the ``command`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
+
+How to run a Python Package Training Job with
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`
+
+Before running this job you should create a Python package with the training script inside.
+Documentation on how to create one can be found here:
+https://cloud.google.com/vertex-ai/docs/training/create-python-pre-built-container
+Next, put the link to the package in the ``python_package_gcs_uri`` parameter, and set the
+``python_module_name`` parameter to the name of the module that runs your training task.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
+
+How to run a Training Job with
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.
+
+For this job, put the path to your local training script in the ``script_path`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_operator]
+
+You can get a list of Training Jobs using
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.ListCustomTrainingJobOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_list_custom_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_list_custom_training_job_operator]
+
+If you wish to delete a Custom Training Job you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.DeleteCustomTrainingJobOperator`.
+
+..
exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_delete_custom_training_job_operator] + :end-before: [END how_to_cloud_vertex_ai_delete_custom_training_job_operator] + +Reference +^^^^^^^^^ + +For further information, look at: + +* `Client Library Documentation `__ +* `Product Documentation `__ diff --git a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py index ec614a5e466b3..11d9eafb22d7c 100755 --- a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py +++ b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py @@ -151,7 +151,10 @@ def parse_module_data(provider_data, resource_type, yaml_file_path): package_dir = ROOT_DIR + "/" + os.path.dirname(yaml_file_path) provider_package = os.path.dirname(yaml_file_path).replace(os.sep, ".") py_files = chain( - glob(f"{package_dir}/**/{resource_type}/*.py"), glob(f"{package_dir}/{resource_type}/*.py") + glob(f"{package_dir}/**/{resource_type}/*.py"), + glob(f"{package_dir}/{resource_type}/*.py"), + glob(f"{package_dir}/**/{resource_type}/**/*.py"), + glob(f"{package_dir}/{resource_type}/**/*.py"), ) expected_modules = {_filepath_to_module(f) for f in py_files if not f.endswith("/__init__.py")} resource_data = provider_data.get(resource_type, []) diff --git a/setup.py b/setup.py index ad3f19e960735..b04e8b1822b78 100644 --- a/setup.py +++ b/setup.py @@ -307,6 +307,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version # https://github.com/googleapis/google-cloud-python/issues/10566 'google-auth>=1.0.0,<3.0.0', 'google-auth-httplib2>=0.0.1', + 'google-cloud-aiplatform>=1.7.1,<2.0.0', 'google-cloud-automl>=2.1.0,<3.0.0', 'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0', 'google-cloud-bigtable>=1.0.0,<2.0.0', diff --git a/tests/providers/google/cloud/hooks/vertex_ai/__init__.py b/tests/providers/google/cloud/hooks/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/google/cloud/hooks/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py new file mode 100644 index 0000000000000..a78c278ade26e --- /dev/null +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py @@ -0,0 +1,457 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from unittest import TestCase, mock + +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook +from tests.providers.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, +) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_REGION: str = "test-region" +TEST_PROJECT_ID: str = "test-project-id" +TEST_PIPELINE_JOB: dict = {} +TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id" +TEST_TRAINING_PIPELINE: dict = {} +TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline" + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +CUSTOM_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.custom_job.{}" + + +class TestCustomJobWithDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_cancel_pipeline_job(self, mock_client) -> None: + self.hook.cancel_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_cancel_training_pipeline(self, mock_client) -> None: + self.hook.cancel_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_create_pipeline_job(self, mock_client) -> None: + self.hook.create_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_pipeline_job.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ), + metadata=(), + retry=None, + 
timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_create_training_pipeline(self, mock_client) -> None: + self.hook.create_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_training_pipeline.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + training_pipeline=TEST_TRAINING_PIPELINE, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_delete_pipeline_job(self, mock_client) -> None: + self.hook.delete_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_delete_training_pipeline(self, mock_client) -> None: + self.hook.delete_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_get_pipeline_job(self, mock_client) -> None: + self.hook.get_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_get_training_pipeline(self, mock_client) -> None: + self.hook.get_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) + + 
@mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_list_pipeline_jobs(self, mock_client) -> None: + self.hook.list_pipeline_jobs( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_pipeline_jobs.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_list_training_pipelines(self, mock_client) -> None: + self.hook.list_training_pipelines( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_training_pipelines.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + read_mask=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + +class TestCustomJobWithoutDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id + ): + self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_cancel_pipeline_job(self, mock_client) -> None: + self.hook.cancel_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_cancel_training_pipeline(self, mock_client) -> None: + self.hook.cancel_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_create_pipeline_job(self, mock_client) -> None: + self.hook.create_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_pipeline_job.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ), + 
metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_create_training_pipeline(self, mock_client) -> None: + self.hook.create_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_training_pipeline.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + training_pipeline=TEST_TRAINING_PIPELINE, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_delete_pipeline_job(self, mock_client) -> None: + self.hook.delete_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_delete_training_pipeline(self, mock_client) -> None: + self.hook.delete_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_get_pipeline_job(self, mock_client) -> None: + self.hook.get_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_get_training_pipeline(self, mock_client) -> None: + self.hook.get_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) + + 
@mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_list_pipeline_jobs(self, mock_client) -> None: + self.hook.list_pipeline_jobs( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_pipeline_jobs.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) + def test_list_training_pipelines(self, mock_client) -> None: + self.hook.list_training_pipelines( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_training_pipelines.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + read_mask=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py new file mode 100644 index 0000000000000..19c3d40a2afda --- /dev/null +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py @@ -0,0 +1,504 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from unittest import TestCase, mock + +from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook +from tests.providers.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, +) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_REGION: str = "test-region" +TEST_PROJECT_ID: str = "test-project-id" +TEST_PIPELINE_JOB: dict = {} +TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id" +TEST_TRAINING_PIPELINE: dict = {} +TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline" +TEST_DATASET: dict = {} +TEST_DATASET_ID: str = "test-dataset-id" +TEST_EXPORT_CONFIG: dict = {} +TEST_ANNOTATION_SPEC: str = "test-annotation-spec" +TEST_IMPORT_CONFIGS: dict = {} +TEST_DATA_ITEM: str = "test-data-item" +TEST_UPDATE_MASK: dict = {} + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +DATASET_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.dataset.{}" + + +class TestVertexAIWithDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = DatasetHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_create_dataset(self, mock_client) -> None: + self.hook.create_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_dataset.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + dataset=TEST_DATASET, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_delete_dataset(self, mock_client) -> None: + self.hook.delete_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_export_data(self, mock_client) -> None: + self.hook.export_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.export_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + export_config=TEST_EXPORT_CONFIG, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_annotation_spec(self, mock_client) -> None: + self.hook.get_annotation_spec( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + annotation_spec=TEST_ANNOTATION_SPEC, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_annotation_spec.assert_called_once_with( 
+ request=dict( + name=mock_client.return_value.annotation_spec_path.return_value, + read_mask=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.annotation_spec_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_ANNOTATION_SPEC + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_dataset(self, mock_client) -> None: + self.hook.get_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + read_mask=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_import_data(self, mock_client) -> None: + self.hook.import_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIGS, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.import_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + import_configs=TEST_IMPORT_CONFIGS, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_annotations(self, mock_client) -> None: + self.hook.list_annotations( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + data_item=TEST_DATA_ITEM, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_annotations.assert_called_once_with( + request=dict( + parent=mock_client.return_value.data_item_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.data_item_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_DATA_ITEM + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_data_items(self, mock_client) -> None: + self.hook.list_data_items( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_data_items.assert_called_once_with( + request=dict( + parent=mock_client.return_value.dataset_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_datasets(self, mock_client) -> None: + self.hook.list_datasets( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_datasets.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + 
order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_update_dataset(self, mock_client) -> None: + self.hook.update_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset_id=TEST_DATASET_ID, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.update_dataset.assert_called_once_with( + request=dict( + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + +class TestVertexAIWithoutDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id + ): + self.hook = DatasetHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_create_dataset(self, mock_client) -> None: + self.hook.create_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_dataset.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + dataset=TEST_DATASET, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_delete_dataset(self, mock_client) -> None: + self.hook.delete_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_export_data(self, mock_client) -> None: + self.hook.export_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.export_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + export_config=TEST_EXPORT_CONFIG, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_annotation_spec(self, mock_client) -> None: + self.hook.get_annotation_spec( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + annotation_spec=TEST_ANNOTATION_SPEC, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_annotation_spec.assert_called_once_with( + request=dict( + name=mock_client.return_value.annotation_spec_path.return_value, + read_mask=None, + ), + metadata=(), + retry=None, 
+ timeout=None, + ) + mock_client.return_value.annotation_spec_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_ANNOTATION_SPEC + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_dataset(self, mock_client) -> None: + self.hook.get_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + read_mask=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_import_data(self, mock_client) -> None: + self.hook.import_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIGS, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.import_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + import_configs=TEST_IMPORT_CONFIGS, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_annotations(self, mock_client) -> None: + self.hook.list_annotations( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + data_item=TEST_DATA_ITEM, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_annotations.assert_called_once_with( + request=dict( + parent=mock_client.return_value.data_item_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.data_item_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_DATA_ITEM + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_data_items(self, mock_client) -> None: + self.hook.list_data_items( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_data_items.assert_called_once_with( + request=dict( + parent=mock_client.return_value.dataset_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_datasets(self, mock_client) -> None: + self.hook.list_datasets( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_datasets.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=(), + retry=None, + timeout=None, + ) + 
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_update_dataset(self, mock_client) -> None: + self.hook.update_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset_id=TEST_DATASET_ID, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.update_dataset.assert_called_once_with( + request=dict( + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ), + metadata=(), + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py new file mode 100644 index 0000000000000..ec5a63d47890d --- /dev/null +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -0,0 +1,613 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from unittest import mock
+
+from google.api_core.retry import Retry
+
+from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
+    CreateCustomContainerTrainingJobOperator,
+    CreateCustomPythonPackageTrainingJobOperator,
+    CreateCustomTrainingJobOperator,
+    DeleteCustomTrainingJobOperator,
+    ListCustomTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+    CreateDatasetOperator,
+    DeleteDatasetOperator,
+    ExportDataOperator,
+    ImportDataOperator,
+    ListDatasetsOperator,
+    UpdateDatasetOperator,
+)
+
+VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}"
+TIMEOUT = 120
+RETRY = mock.MagicMock(Retry)
+METADATA = [("key", "value")]
+
+TASK_ID = "test_task_id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "test-location"
+GCP_CONN_ID = "test-conn"
+DELEGATE_TO = "test-delegate-to"
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+STAGING_BUCKET = "gs://test-vertex-ai-bucket"
+DISPLAY_NAME = "display_name_1"
+DISPLAY_NAME_2 = "display_name_2"
+ARGS = ["--tfds", "tf_flowers:3.*.*"]
+CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest"
+REPLICA_COUNT = 1
+MACHINE_TYPE = "n1-standard-4"
+ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
+ACCELERATOR_COUNT = 0
+TRAINING_FRACTION_SPLIT = 0.7
+TEST_FRACTION_SPLIT = 0.15
+VALIDATION_FRACTION_SPLIT = 0.15
+COMMAND_2 = ['echo', 'Hello World']
+
+TEST_API_ENDPOINT: str = "test-api-endpoint"
+TEST_PIPELINE_JOB: str = "test-pipeline-job"
+TEST_TRAINING_PIPELINE: str = "test-training-pipeline"
+TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id"
+
+PYTHON_PACKAGE = "/files/trainer-0.1.tar.gz"
+PYTHON_PACKAGE_CMDARGS = "test-python-cmd"
+PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz"
+PYTHON_MODULE_NAME = "trainer.task"
+
+TRAINING_PIPELINE_ID = "test-training-pipeline-id"
+CUSTOM_JOB_ID = "test-custom-job-id"
+
+TEST_DATASET = {
+    "display_name": "test-dataset-name",
+    "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
+    "metadata": "test-image-dataset",
+}
+TEST_DATASET_ID = "test-dataset-id"
+TEST_EXPORT_CONFIG = {
+    "annotationsFilter": "test-filter",
+    "gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"},
+}
+TEST_IMPORT_CONFIG = [
+    {
+        "data_item_labels": {
+            "test-labels-name": "test-labels-value",
+        },
+        "import_schema_uri": "test-schema-uri",
+        "gcs_source": {"uris": ['test-string']},
+    },
+    {},
+]
+TEST_UPDATE_MASK = "test-update-mask"
+
+
+class TestVertexAICreateCustomContainerTrainingJobOperator:
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+    def test_execute(self, mock_hook):
+        op = CreateCustomContainerTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            delegate_to=DELEGATE_TO,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            args=ARGS,
+            container_uri=CONTAINER_URI,
+            model_serving_container_image_uri=CONTAINER_URI,
+            command=COMMAND_2,
+            model_display_name=DISPLAY_NAME_2,
+            replica_count=REPLICA_COUNT,
+            machine_type=MACHINE_TYPE,
+            accelerator_type=ACCELERATOR_TYPE,
+            accelerator_count=ACCELERATOR_COUNT,
+            training_fraction_split=TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=TEST_FRACTION_SPLIT,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+        )
+        op.execute(context={'ti': mock.MagicMock()})
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.create_custom_container_training_job.assert_called_once_with( + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + dataset=None, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model_serving_container_predict_route=None, + model_serving_container_health_route=None, + model_serving_container_command=None, + model_serving_container_args=None, + model_serving_container_environment_variables=None, + model_serving_container_ports=None, + model_description=None, + model_instance_schema_uri=None, + model_parameters_schema_uri=None, + model_prediction_schema_uri=None, + labels=None, + training_encryption_spec_key_name=None, + model_encryption_spec_key_name=None, + # RUN + annotation_schema_uri=None, + model_labels=None, + base_output_dir=None, + service_account=None, + network=None, + bigquery_destination=None, + environment_variables=None, + boot_disk_type='pd-ssd', + boot_disk_size_gb=100, + training_filter_split=None, + validation_filter_split=None, + test_filter_split=None, + predefined_split_column_name=None, + timestamp_split_column_name=None, + tensorboard=None, + sync=True, + ) + + +class TestVertexAICreateCustomPythonPackageTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = CreateCustomPythonPackageTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + args=ARGS, + model_serving_container_image_uri=CONTAINER_URI, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + op.execute(context={'ti': mock.MagicMock()}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with( + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + dataset=None, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model_serving_container_predict_route=None, 
+ model_serving_container_health_route=None, + model_serving_container_command=None, + model_serving_container_args=None, + model_serving_container_environment_variables=None, + model_serving_container_ports=None, + model_description=None, + model_instance_schema_uri=None, + model_parameters_schema_uri=None, + model_prediction_schema_uri=None, + labels=None, + training_encryption_spec_key_name=None, + model_encryption_spec_key_name=None, + # RUN + annotation_schema_uri=None, + model_labels=None, + base_output_dir=None, + service_account=None, + network=None, + bigquery_destination=None, + environment_variables=None, + boot_disk_type='pd-ssd', + boot_disk_size_gb=100, + training_filter_split=None, + validation_filter_split=None, + test_filter_split=None, + predefined_split_column_name=None, + timestamp_split_column_name=None, + tensorboard=None, + sync=True, + ) + + +class TestVertexAICreateCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = CreateCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + script_path=PYTHON_PACKAGE, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + requirements=[], + replica_count=1, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + op.execute(context={'ti': mock.MagicMock()}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.create_custom_training_job.assert_called_once_with( + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + script_path=PYTHON_PACKAGE, + requirements=[], + dataset=None, + model_display_name=None, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=None, + validation_fraction_split=None, + test_fraction_split=None, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model_serving_container_predict_route=None, + model_serving_container_health_route=None, + model_serving_container_command=None, + model_serving_container_args=None, + model_serving_container_environment_variables=None, + model_serving_container_ports=None, + model_description=None, + model_instance_schema_uri=None, + model_parameters_schema_uri=None, + model_prediction_schema_uri=None, + labels=None, + training_encryption_spec_key_name=None, + model_encryption_spec_key_name=None, + # RUN + annotation_schema_uri=None, + model_labels=None, + base_output_dir=None, + service_account=None, + network=None, + bigquery_destination=None, + environment_variables=None, + boot_disk_type='pd-ssd', + boot_disk_size_gb=100, + training_filter_split=None, + validation_filter_split=None, + test_filter_split=None, + predefined_split_column_name=None, + timestamp_split_column_name=None, + tensorboard=None, + sync=True, + ) + + +class TestVertexAIDeleteCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = DeleteCustomTrainingJobOperator( + task_id=TASK_ID, + training_pipeline_id=TRAINING_PIPELINE_ID, + custom_job_id=CUSTOM_JOB_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + 
retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + ) + op.execute(context={}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.delete_training_pipeline.assert_called_once_with( + training_pipeline=TRAINING_PIPELINE_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + mock_hook.return_value.delete_custom_job.assert_called_once_with( + custom_job=CUSTOM_JOB_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIListCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + page_token = "page_token" + page_size = 42 + filter = "filter" + read_mask = "read_mask" + + op = ListCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + page_size=page_size, + page_token=page_token, + filter=filter, + read_mask=read_mask, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={'ti': mock.MagicMock()}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.list_training_pipelines.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + page_size=page_size, + page_token=page_token, + filter=filter, + read_mask=read_mask, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAICreateDatasetOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = CreateDatasetOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={'ti': mock.MagicMock()}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.create_dataset.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIDeleteDatasetOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = DeleteDatasetOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.delete_dataset.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class 
TestVertexAIExportDataOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ExportDataOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.export_data.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIImportDataOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ImportDataOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.import_data.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIListDatasetsOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + page_token = "page_token" + page_size = 42 + filter = "filter" + read_mask = "read_mask" + order_by = "order_by" + + op = ListDatasetsOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + filter=filter, + page_size=page_size, + page_token=page_token, + read_mask=read_mask, + order_by=order_by, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={'ti': mock.MagicMock()}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.list_datasets.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + filter=filter, + page_size=page_size, + page_token=page_token, + read_mask=read_mask, + order_by=order_by, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIUpdateDatasetOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = UpdateDatasetOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, + impersonation_chain=IMPERSONATION_CHAIN, + project_id=GCP_PROJECT, + region=GCP_LOCATION, + dataset_id=TEST_DATASET_ID, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + retry=RETRY, + timeout=TIMEOUT, + 
metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) + mock_hook.return_value.update_dataset.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + dataset_id=TEST_DATASET_ID, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py new file mode 100644 index 0000000000000..84b84c33200c1 --- /dev/null +++ b/tests/providers/google/cloud/operators/test_vertex_ai_system.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tests.providers.google.cloud.utils.gcp_authenticator import GCP_VERTEX_AI_KEY +from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context + + +@pytest.mark.backend("mysql", "postgres") +@pytest.mark.credential_file(GCP_VERTEX_AI_KEY) +class VertexAIExampleDagsTest(GoogleSystemTest): + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def setUp(self): + super().setUp() + + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def tearDown(self): + super().tearDown() + + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def test_run_custom_jobs_example_dag(self): + self.run_dag(dag_id="example_gcp_vertex_ai_custom_jobs", dag_folder=CLOUD_DAG_FOLDER) + + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def test_run_dataset_example_dag(self): + self.run_dag(dag_id="example_gcp_vertex_ai_dataset", dag_folder=CLOUD_DAG_FOLDER) diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py index 1a58b0d63928d..f6236b2813304 100644 --- a/tests/providers/google/cloud/utils/gcp_authenticator.py +++ b/tests/providers/google/cloud/utils/gcp_authenticator.py @@ -54,6 +54,7 @@ GCP_SPANNER_KEY = 'gcp_spanner.json' GCP_STACKDRIVER = 'gcp_stackdriver.json' GCP_TASKS_KEY = 'gcp_tasks.json' +GCP_VERTEX_AI_KEY = 'gcp_vertex_ai.json' GCP_WORKFLOWS_KEY = "gcp_workflows.json" GMP_KEY = 'gmp.json' G_FIREBASE_KEY = 'g_firebase.json'
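The unit tests above pin down the call signatures of the new Vertex AI hooks. As a minimal illustrative sketch only (not taken from this patch): the connection id, project id, region, and dataset payload below are placeholder assumptions, and real use requires valid Google Cloud credentials configured for the Airflow connection.

# Sketch: calling the new DatasetHook directly, with the same keyword
# arguments the mocked tests assert on. All concrete values are placeholders.
from google.protobuf.struct_pb2 import Value

from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook

hook = DatasetHook(gcp_conn_id="google_cloud_default")  # assumed default GCP connection
result = hook.create_dataset(
    project_id="my-gcp-project",  # placeholder project
    region="us-central1",         # placeholder region
    dataset={
        "display_name": "example-image-dataset",
        "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
        "metadata": Value(string_value="example-image-dataset"),
    },
)
# The return value is assumed to be the underlying Vertex AI client response
# (a long-running operation for dataset creation); it is not asserted by the tests above.
print(result)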