From f1fe9567f68371d8167d5dc6de3d5b681f7e8db4 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Mon, 25 Oct 2021 13:44:38 +0000 Subject: [PATCH 01/20] Create PipelineService client hooks for Vertex AI --- .../providers/google/cloud/hooks/vertex_ai.py | 565 ++++++++++++++++++ .../google/cloud/hooks/test_vertex_ai.py | 437 ++++++++++++++ 2 files changed, 1002 insertions(+) create mode 100644 airflow/providers/google/cloud/hooks/vertex_ai.py create mode 100644 tests/providers/google/cloud/hooks/test_vertex_ai.py diff --git a/airflow/providers/google/cloud/hooks/vertex_ai.py b/airflow/providers/google/cloud/hooks/vertex_ai.py new file mode 100644 index 0000000000000..c405549e4f58a --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai.py @@ -0,0 +1,565 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a Google Cloud Vertex AI hook.""" + +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.aiplatform_v1 import PipelineServiceClient +from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ( + ListPipelineJobsPager, ListTrainingPipelinesPager) +from google.cloud.aiplatform_v1.types import ( + PipelineJob, TrainingPipeline) + + +class VertexAIHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI APIs.""" + + def get_pipeline_service_client( + self, + region: Optional[str] = None, + ) -> PipelineServiceClient: + """Returns PipelineServiceClient.""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return PipelineServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def wait_for_operation(self, timeout: float, operation: Operation): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Cancels a PipelineJob. Starts asynchronous cancellation on the PipelineJob. The server makes a best + effort to cancel the pipeline, but success is not guaranteed. 
Clients can use + [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other + methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a + pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: The name of the PipelineJob to cancel. + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + client.cancel_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes + a best effort to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline] + or other methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a + pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with + a [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The name of the TrainingPipeline to cancel. + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + client.cancel_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: PipelineJob, + pipeline_job_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> PipelineJob: + """ + Creates a PipelineJob. A PipelineJob will run immediately when created. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: Required. The PipelineJob to create. + :type pipeline_job: google.cloud.aiplatform_v1.types.PipelineJob + :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of + the PipelineJob name. If not provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. + :type pipeline_job_id: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_pipeline_job( + request={ + 'parent': parent, + 'pipeline_job': pipeline_job, + 'pipeline_job_id': pipeline_job_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: TrainingPipeline, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TrainingPipeline: + """ + Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The TrainingPipeline to create. + :type training_pipeline: google.cloud.aiplatform_v1.types.TrainingPipeline + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_training_pipeline( + request={ + 'parent': parent, + 'training_pipeline': training_pipeline, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.delete_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a TrainingPipeline. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted. + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.delete_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> PipelineJob: + """ + Gets a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: Required. The name of the PipelineJob resource. 
+ :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.get_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TrainingPipeline: + """ + Gets a TrainingPipeline. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The name of the TrainingPipeline resource. + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.get_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_pipeline_jobs( + self, + project_id: str, + region: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListPipelineJobsPager: + """ + Lists PipelineJobs in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: Optional. Lists the PipelineJobs that match the filter expression. The + following fields are supported: + + - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. + - ``display_name``: Supports ``=``, ``!=`` comparisons, and + ``:`` wildcard. + - ``pipeline_job_user_id``: Supports ``=``, ``!=`` + comparisons, and ``:`` wildcard. for example, can check + if pipeline's display_name contains *step* by doing + display_name:"*step*" + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``labels``: Supports key-value equality and key presence. 
+ + Filter expressions can be combined together using logical + operators (``AND`` & ``OR``). For example: + ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. + + The syntax to define filter expression is based on + https://google.aip.dev/160. + + Examples: + + - ``create_time>"2021-05-18T00:00:00Z" OR update_time>"2020-05-18T00:00:00Z"`` + PipelineJobs created or updated after 2020-05-18 00:00:00 + UTC. + - ``labels.env = "prod"`` PipelineJobs with label "env" set + to "prod". + :type filter: str + :param page_size: Optional. The standard list page size. + :type page_size: int + :param page_token: Optional. The standard list page token. Typically obtained via + [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] + of the previous + [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] + call. + :type page_token: str + :param order_by: Optional. A comma-separated list of fields to order by. The default + sort order is in ascending order. Use "desc" after a field + name for descending. You can have multiple order_by fields + provided e.g. "create_time desc, end_time", "end_time, + start_time, update_time" For example, using "create_time + desc, end_time" will order results by create time in + descending order, and if there are multiple jobs having the + same create time, order them by the end time in ascending + order. if order_by is not specified, it will order by + default order is create time in descending order. Supported + fields: + + - ``create_time`` + - ``update_time`` + - ``end_time`` + - ``start_time`` + :type order_by: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_pipeline_jobs( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_training_pipelines( + self, + project_id: str, + region: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListTrainingPipelinesPager: + """ + Lists TrainingPipelines in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :type filter: str + :param page_size: Optional. The standard list page size. 
+ :type page_size: int + :param page_token: Optional. The standard list page token. Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :type page_token: str + :param read_mask: Optional. Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_training_pipelines( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/tests/providers/google/cloud/hooks/test_vertex_ai.py b/tests/providers/google/cloud/hooks/test_vertex_ai.py new file mode 100644 index 0000000000000..329228723cbc6 --- /dev/null +++ b/tests/providers/google/cloud/hooks/test_vertex_ai.py @@ -0,0 +1,437 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from typing import Dict, Optional, Sequence, Tuple, Union +from unittest import TestCase, mock + +from airflow import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai import VertexAIHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry + +from tests.providers.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_REGION: str = "test-region" +TEST_PROJECT_ID: str = "test-project-id" +TEST_PIPELINE_JOB: dict = {} +TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id" +TEST_TRAINING_PIPELINE: dict = {} +TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline" + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +VERTEX_AI_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.{}" + +class TestVertexAIWithDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), + new=mock_base_gcp_hook_default_project_id + ): + self.hook = VertexAIHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_cancel_pipeline_job(self, mock_client) -> None: + self.hook.cancel_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_cancel_training_pipeline(self, mock_client) -> None: + self.hook.cancel_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_create_pipeline_job(self, mock_client) -> None: + self.hook.create_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_pipeline_job.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_create_training_pipeline(self, mock_client) -> None: + self.hook.create_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE, + 
) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_training_pipeline.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + training_pipeline=TEST_TRAINING_PIPELINE, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_delete_pipeline_job(self, mock_client) -> None: + self.hook.delete_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_delete_training_pipeline(self, mock_client) -> None: + self.hook.delete_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_get_pipeline_job(self, mock_client) -> None: + self.hook.get_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_get_training_pipeline(self, mock_client) -> None: + self.hook.get_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_list_pipeline_jobs(self, mock_client) -> None: + self.hook.list_pipeline_jobs( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_pipeline_jobs.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + 
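+                # Options not passed to the hook (paging, filter, order_by) are expected to be forwarded as None.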
page_token=None, + filter=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_list_training_pipelines(self, mock_client) -> None: + self.hook.list_training_pipelines( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_training_pipelines.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + read_mask=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + +class TestVertexAIWithoutDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), + new=mock_base_gcp_hook_no_default_project_id + ): + self.hook = VertexAIHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_cancel_pipeline_job(self, mock_client) -> None: + self.hook.cancel_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_cancel_training_pipeline(self, mock_client) -> None: + self.hook.cancel_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.cancel_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_create_pipeline_job(self, mock_client) -> None: + self.hook.create_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_pipeline_job.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_create_training_pipeline(self, mock_client) -> None: + self.hook.create_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE, + ) + 
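+        # The hook is expected to build the client for the requested region and pass the
+        # training pipeline request through unchanged, as the assertions below verify.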
mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_training_pipeline.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + training_pipeline=TEST_TRAINING_PIPELINE, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_delete_pipeline_job(self, mock_client) -> None: + self.hook.delete_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_delete_training_pipeline(self, mock_client) -> None: + self.hook.delete_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_get_pipeline_job(self, mock_client) -> None: + self.hook.get_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_get_training_pipeline(self, mock_client) -> None: + self.hook.get_training_pipeline( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + training_pipeline=TEST_TRAINING_PIPELINE_NAME, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_training_pipeline.assert_called_once_with( + request=dict( + name=mock_client.return_value.training_pipeline_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_list_pipeline_jobs(self, mock_client) -> None: + self.hook.list_pipeline_jobs( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_pipeline_jobs.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + 
page_token=None, + filter=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + def test_list_training_pipelines(self, mock_client) -> None: + self.hook.list_training_pipelines( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_training_pipelines.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + read_mask=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) From 6ae69c3912f6e40e22ca5d05f84308c2506af179 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Thu, 4 Nov 2021 14:24:54 +0000 Subject: [PATCH 02/20] Implement CreateCustomContainerTrainingJob operator --- .../cloud/example_dags/example_vertex_ai.py | 128 ++ .../google/cloud/operators/vertex_ai.py | 1415 +++++++++++++++++ airflow/providers/google/provider.yaml | 11 + .../operators/cloud/vertex_ai.rst | 16 + .../cloud/operators/test_vertex_ai_system.py | 40 + .../google/cloud/utils/gcp_authenticator.py | 1 + 6 files changed, 1611 insertions(+) create mode 100644 airflow/providers/google/cloud/example_dags/example_vertex_ai.py create mode 100644 airflow/providers/google/cloud/operators/vertex_ai.py create mode 100644 docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst create mode 100644 tests/providers/google/cloud/operators/test_vertex_ai_system.py diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py new file mode 100644 index 0000000000000..6792171edbd95 --- /dev/null +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that demonstrates operators for the Google Vertex AI service in the Google +Cloud Platform. + +This DAG relies on the following OS environment variables: + +* GCP_BUCKET_NAME - Google Cloud Storage bucket where the file exists. 
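+* GCP_PROJECT_ID - Google Cloud project ID used by the example training jobs.
+* GCP_LOCATION - Google Cloud region in which the example training jobs run.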
+""" +import os +from random import randint +from uuid import uuid4 + +from airflow import models +from airflow.providers.google.cloud.operators.vertex_ai import ( + VertexAICreateCustomContainerTrainingJobOperator, + VertexAICreateCustomPythonPackageTrainingJobOperator, + VertexAICreateCustomTrainingJobOperator, +) +from airflow.utils.dates import days_ago + +# from google.cloud import aiplatform + + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") +REGION = os.environ.get("GCP_LOCATION", "europe-west1") +BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests") + +STAGING_BUCKET = f"gs://{BUCKET}" +DISPLAY_NAME = str(uuid4()) # Create random display name +DISPLAY_NAME_2 = str(uuid4()) +DISPLAY_NAME_3 = str(uuid4()) +DISPLAY_NAME_4 = str(uuid4()) +ARGS = ["--tfds", "tf_flowers:3.*.*"] +CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" +RESOURCE_ID = str(randint(10000000, 99999999)) # Create random resource ID +REPLICA_COUNT = 1 +MACHINE_TYPE = "n1-standard-4" +ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" +ACCELERATOR_COUNT = 0 +TRAINING_FRACTION_SPLIT = 0.7 +TEST_FRACTION_SPLIT = 0.15 +VALIDATION_FRACTION_SPLIT = 0.15 +# This example uses an ImageDataset, but you can use another type +# DATASET = aiplatform.ImageDataset(RESOURCE_ID) if RESOURCE_ID else None +COMMAND = ['python3', 'run_script.py'] +COMMAND_2 = ['echo', 'Hello World'] +PYTHON_PACKAGE_GCS_URI = "gs://bucket3/custom-training-python-package/my_app/trainer-0.1.tar.gz" +PYTHON_MODULE_NAME = "trainer.task" + + +with models.DAG( + "example_gcp_vertex_ai", + start_date=days_ago(1), + schedule_interval="@once", +) as dag: + # [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator] + create_custom_container_training_job = VertexAICreateCustomContainerTrainingJobOperator( + task_id="custom_container_task", + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + # dataset=DATASET, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator] + + # [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] + create_custom_python_package_training_job = VertexAICreateCustomPythonPackageTrainingJobOperator( + task_id="python_package_task", + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME_3, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + args=ARGS, + model_serving_container_image_uri=CONTAINER_URI, + # dataset=DATASET, + model_display_name=DISPLAY_NAME_4, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] + + # [START how_to_cloud_vertex_ai_create_custom_training_job_operator] + create_custom_training_job = 
VertexAICreateCustomTrainingJobOperator( + task_id="custom_task", + # TODO: add parameters from example + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_custom_training_job_operator] diff --git a/airflow/providers/google/cloud/operators/vertex_ai.py b/airflow/providers/google/cloud/operators/vertex_ai.py new file mode 100644 index 0000000000000..23825ff5be84b --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai.py @@ -0,0 +1,1415 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains Google Vertex AI operators.""" + +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core.retry import Retry +from google.cloud.aiplatform import datasets, initializer, schema, utils +from google.cloud.aiplatform.utils import _timestamped_gcs_dir, worker_spec_utils +from google.cloud.aiplatform_v1.types import ( + BigQueryDestination, + CancelPipelineJobRequest, + CancelTrainingPipelineRequest, + CreatePipelineJobRequest, + CreateTrainingPipelineRequest, + DeletePipelineJobRequest, + DeleteTrainingPipelineRequest, + EnvVar, + FilterSplit, + FractionSplit, + GcsDestination, + GetPipelineJobRequest, + GetTrainingPipelineRequest, + InputDataConfig, + ListPipelineJobsRequest, + ListTrainingPipelinesRequest, + Model, + ModelContainerSpec, + PipelineJob, + Port, + PredefinedSplit, + PredictSchemata, + TimestampSplit, + TrainingPipeline, +) + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.vertex_ai import VertexAIHook + + +class VertexAITrainingJobBaseOperator(BaseOperator): + """The base class for operators that launch job on VertexAI.""" + + def __init__( + self, + *, + region: str = None, + project_id: str, + display_name: str, + # START Run param + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: 
Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + # END Run param + # START Custom + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # END Custom + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.display_name = display_name + # START Run param + self.dataset = dataset + self.annotation_schema_uri = annotation_schema_uri + self.model_display_name = model_display_name + self.model_labels = model_labels + self.base_output_dir = base_output_dir + self.service_account = service_account + self.network = network + self.bigquery_destination = bigquery_destination + self.args = args + self.environment_variables = environment_variables + self.replica_count = replica_count + self.machine_type = machine_type + self.accelerator_type = accelerator_type + self.accelerator_count = accelerator_count + self.boot_disk_type = boot_disk_type + self.boot_disk_size_gb = boot_disk_size_gb + self.training_fraction_split = training_fraction_split + self.validation_fraction_split = validation_fraction_split + self.test_fraction_split = test_fraction_split + self.training_filter_split = training_filter_split + self.validation_filter_split = validation_filter_split + self.test_filter_split = test_filter_split + self.predefined_split_column_name = predefined_split_column_name + self.timestamp_split_column_name = timestamp_split_column_name + self.tensorboard = tensorboard + self.sync = sync + # TODO: add optional and important parameters + # END Run param + # START Custom + self._container_uri = container_uri + + model_predict_schemata = None + if any( + [ + model_instance_schema_uri, + model_parameters_schema_uri, + model_prediction_schema_uri, + ] + ): + model_predict_schemata = PredictSchemata( + instance_schema_uri=model_instance_schema_uri, + parameters_schema_uri=model_parameters_schema_uri, + prediction_schema_uri=model_prediction_schema_uri, + ) + + # Create the container spec + env = None + ports = None + + if model_serving_container_environment_variables: + env = [ + EnvVar(name=str(key), value=str(value)) + for key, value in model_serving_container_environment_variables.items() + ] + + if model_serving_container_ports: + ports = [Port(container_port=port) for port in model_serving_container_ports] + + container_spec = 
ModelContainerSpec( + image_uri=model_serving_container_image_uri, + command=model_serving_container_command, + args=model_serving_container_args, + env=env, + ports=ports, + predict_route=model_serving_container_predict_route, + health_route=model_serving_container_health_route, + ) + + self._model_encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=model_encryption_spec_key_name + ) + + self._managed_model = Model( + description=model_description, + predict_schemata=model_predict_schemata, + container_spec=container_spec, + encryption_spec=self._model_encryption_spec, + ) + + self.labels = labels + self._training_encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=training_encryption_spec_key_name + ) + + self._staging_bucket = staging_bucket or initializer.global_config.staging_bucket + # END Custom + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook = VertexAIHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain) + self.training_pipeline: Optional[TrainingPipeline] = None + self.worker_pool_specs: worker_spec_utils._DistributedTrainingSpec = None + self.managed_model: Optional[Model] = None + + def _prepare_and_validate_run( + self, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + ) -> Tuple[worker_spec_utils._DistributedTrainingSpec, Optional[Model]]: + """Create worker pool specs and managed model as well validating the + run. + + Args: + model_display_name (str): + If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + model_labels (Dict[str, str]): + Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + boot_disk_type (str): + Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + boot_disk_size_gb (int): + Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + Returns: + Worker pools specs and managed model for run. 
+ + Raises: + RuntimeError: If Training job has already been run or model_display_name was + provided but required arguments were not provided in constructor. + """ + # TODO: Maybe not need + # if self._is_waiting_to_run(): + # raise RuntimeError("Custom Training is already scheduled to run.") + + # TODO: Maybe not need + # if self._has_run: + # raise RuntimeError("Custom Training has already run.") + + # if args needed for model is incomplete + if model_display_name and not self._managed_model.container_spec.image_uri: + raise RuntimeError( + """model_display_name was provided but + model_serving_container_image_uri was not provided when this + custom pipeline was constructed. + """ + ) + + if self._managed_model.container_spec.image_uri: + model_display_name = model_display_name or self._display_name + "-model" + + # validates args and will raise + worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + ).pool_specs + + managed_model = self._managed_model + if model_display_name: + utils.validate_display_name(model_display_name) + managed_model.display_name = model_display_name + if model_labels: + utils.validate_labels(model_labels) + managed_model.labels = model_labels + else: + managed_model.labels = self.labels + else: + managed_model = None + + return worker_pool_specs, managed_model + + def _prepare_training_task_inputs_and_output_dir( + self, + worker_pool_specs: worker_spec_utils._DistributedTrainingSpec, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + tensorboard: Optional[str] = None, + ) -> Tuple[Dict, str]: + """Prepares training task inputs and output directory for custom job. + + Args: + worker_pools_spec (worker_spec_utils._DistributedTrainingSpec): + Worker pools pecs required to run job. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + tensorboard (str): + Optional. The name of a Vertex AI + [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] + resource to which this CustomJob will upload Tensorboard + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + The training script should write Tensorboard to following Vertex AI environment + variable: + + AIP_TENSORBOARD_LOG_DIR + + `service_account` is required with provided `tensorboard`. + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + Returns: + Training task inputs and Output directory for custom job. 
+ """ + # default directory if not given + base_output_dir = base_output_dir or _timestamped_gcs_dir( + self._staging_bucket, "aiplatform-custom-training" + ) + + self.log.info(f"Training Output directory:\n{base_output_dir} ") + + training_task_inputs = { + "worker_pool_specs": worker_pool_specs, + "base_output_directory": {"output_uri_prefix": base_output_dir}, + } + + if service_account: + training_task_inputs["service_account"] = service_account + if network: + training_task_inputs["network"] = network + if tensorboard: + training_task_inputs["tensorboard"] = tensorboard + + return training_task_inputs, base_output_dir + + def _create_input_data_config( + self, + dataset: Optional[datasets._Dataset] = None, + annotation_schema_uri: Optional[str] = None, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + gcs_destination_uri_prefix: Optional[str] = None, + bigquery_destination: Optional[str] = None, + ) -> Optional[InputDataConfig]: + """Constructs a input data config to pass to the training pipeline. + + Args: + dataset (datasets._Dataset): + The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + The schema files that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + training_fraction_split (float): + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. 
If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + This parameter must be used with training_fraction_split, + validation_fraction_split and test_fraction_split. + gcs_destination_uri_prefix (str): + Optional. The Google Cloud Storage location. + + The Vertex AI environment variables representing Google + Cloud Storage data URIs will always be represented in the + Google Cloud Storage wildcard format to support sharded + data. + + - AIP_DATA_FORMAT = "jsonl". + - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" + - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" + - AIP_TEST_DATA_URI = "gcs_destination/test-*". + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + Raises: + ValueError: When more than 1 type of split configuration is passed or when + the split configuration passed is incompatible with the dataset schema. 
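The split arguments above map onto exactly one of four mutually exclusive messages (fraction, filter, predefined column, timestamp column); a minimal sketch of the fraction-based variant, using illustrative values:

    from google.cloud.aiplatform_v1.types import FractionSplit

    # 80/10/10 split of the dataset. The other variants populate FilterSplit,
    # PredefinedSplit or TimestampSplit instead; mixing more than one kind
    # raises ValueError in the helper below.
    fraction_split = FractionSplit(
        training_fraction=0.8,
        validation_fraction=0.1,
        test_fraction=0.1,
    )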
+ """ + input_data_config = None + if dataset: + # Initialize all possible splits + filter_split = None + predefined_split = None + timestamp_split = None + fraction_split = None + + # Create filter split + if any( + [ + training_filter_split is not None, + validation_filter_split is not None, + test_filter_split is not None, + ] + ): + if all( + [ + training_filter_split is not None, + validation_filter_split is not None, + test_filter_split is not None, + ] + ): + filter_split = FilterSplit( + training_filter=training_filter_split, + validation_filter=validation_filter_split, + test_filter=test_filter_split, + ) + else: + raise ValueError("All filter splits must be passed together or not at all") + + # Create predefined split + if predefined_split_column_name: + predefined_split = PredefinedSplit(key=predefined_split_column_name) + + # Create timestamp split or fraction split + if timestamp_split_column_name: + timestamp_split = TimestampSplit( + training_fraction=training_fraction_split, + validation_fraction=validation_fraction_split, + test_fraction=test_fraction_split, + key=timestamp_split_column_name, + ) + elif any( + [ + training_fraction_split is not None, + validation_fraction_split is not None, + test_fraction_split is not None, + ] + ): + fraction_split = FractionSplit( + training_fraction=training_fraction_split, + validation_fraction=validation_fraction_split, + test_fraction=test_fraction_split, + ) + + splits = [ + split + for split in [ + filter_split, + predefined_split, + timestamp_split_column_name, + fraction_split, + ] + if split is not None + ] + + # Fallback to fraction split if nothing else is specified + if len(splits) == 0: + self.log.info("No dataset split provided. The service will use a default split.") + elif len(splits) > 1: + raise ValueError( + """Can only specify one of: + 1. training_filter_split, validation_filter_split, test_filter_split + 2. predefined_split_column_name + 3. timestamp_split_column_name, training_fraction_split, validation_fraction_split, + test_fraction_split + 4. 
training_fraction_split, validation_fraction_split, test_fraction_split""" + ) + + # create GCS destination + gcs_destination = None + if gcs_destination_uri_prefix: + gcs_destination = GcsDestination(output_uri_prefix=gcs_destination_uri_prefix) + + # TODO(b/177416223) validate managed BQ dataset is passed in + bigquery_destination_proto = None + if bigquery_destination: + bigquery_destination_proto = BigQueryDestination(output_uri=bigquery_destination) + + # create input data config + input_data_config = InputDataConfig( + fraction_split=fraction_split, + filter_split=filter_split, + predefined_split=predefined_split, + timestamp_split=timestamp_split, + dataset_id=dataset.name, + annotation_schema_uri=annotation_schema_uri, + gcs_destination=gcs_destination, + bigquery_destination=bigquery_destination_proto, + ) + + return input_data_config + + def _get_model(self, training_pipeline): + # TODO: implement logic for extract model from training_pipeline object + pass + + def execute(self, context): + (training_task_inputs, base_output_dir,) = self._prepare_training_task_inputs_and_output_dir( + worker_pool_specs=self.worker_pool_specs, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + tensorboard=self.tensorboard, + ) + + input_data_config = self._create_input_data_config( + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + gcs_destination_uri_prefix=base_output_dir, + bigquery_destination=self.bigquery_destination, + ) + + # create training pipeline configuration object + training_pipeline = TrainingPipeline( + display_name=self.display_name, + training_task_definition=schema.training_job.definition.custom_task, # TODO: different for automl + training_task_inputs=training_task_inputs, # Required + model_to_upload=self.managed_model, # Optional + input_data_config=input_data_config, # Optional + labels=self.labels, # Optional + encryption_spec=self._training_encryption_spec, # Optional + ) + + self.training_pipeline = self.hook.create_training_pipeline( + project_id=self.project_id, + region=self.region, + training_pipeline=training_pipeline, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + model = self._get_model(self.training_pipeline) + + return model + + def on_kill(self) -> None: + """ + Callback called when the operator is killed. + Cancel any running job. 
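Pulled out of the operator, the submission performed in execute() boils down to the sketch below. Project, region and display name are placeholders, the hook import path is the one introduced by this patch, and the `schema` constant is assumed to come from the same aiplatform import the operator module uses.

    from airflow.providers.google.cloud.hooks.vertex_ai import VertexAIHook
    from google.cloud.aiplatform import schema  # assumed source of the constant used above
    from google.cloud.aiplatform_v1.types import TrainingPipeline

    training_pipeline = TrainingPipeline(
        display_name="example-custom-training",
        training_task_definition=schema.training_job.definition.custom_task,
        training_task_inputs=training_task_inputs,  # built by the helper above
        input_data_config=input_data_config,        # optional, only when a dataset is used
    )
    hook = VertexAIHook(gcp_conn_id="google_cloud_default")
    result = hook.create_training_pipeline(
        project_id="example-project",
        region="us-central1",
        training_pipeline=training_pipeline,
    )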
+ """ + if self.training_pipeline: + self.hook.cancel_training_pipeline( + project_id=self.project_id, + region=self.region, + training_pipeline=self.training_pipeline.name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class VertexAICreateCustomContainerTrainingJobOperator(VertexAITrainingJobBaseOperator): + """Create Custom Container Training job""" + + template_fields = [ + 'region', + 'impersonation_chain', + ] + + def __init__( + self, + *, + command: Sequence[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._command = command + + def execute(self, context): + self.worker_pool_specs, self.managed_model = self._prepare_and_validate_run( + model_display_name=self.model_display_name, + model_labels=self.model_labels, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_count=self.accelerator_count, + accelerator_type=self.accelerator_type, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + ) + + for spec in self.worker_pool_specs: + spec["containerSpec"] = {"imageUri": self._container_uri} + + if self._command: + spec["containerSpec"]["command"] = self._command + + if self.args: + spec["containerSpec"]["args"] = self.args + + if self.environment_variables: + spec["containerSpec"]["env"] = [ + {"name": key, "value": value} for key, value in self.environment_variables.items() + ] + + super().execute(context) + + +class VertexAICreateCustomPythonPackageTrainingJobOperator(VertexAITrainingJobBaseOperator): + """Create Custom Python Package Training job""" + + template_fields = [ + 'region', + 'impersonation_chain', + ] + + def __init__( + self, + python_package_gcs_uri: str, + python_module_name: str, + ) -> None: + self._package_gcs_uri = python_package_gcs_uri + self._python_module = python_module_name + + def execute(self, context): + self.worker_pool_specs, self.managed_model = self._prepare_and_validate_run( + model_display_name=self.model_display_name, + model_labels=self.model_labels, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_count=self.accelerator_count, + accelerator_type=self.accelerator_type, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + ) + + for spec in self.worker_pool_specs: + spec["python_package_spec"] = { + "executor_image_uri": self._container_uri, + "python_module": self._python_module, + "package_uris": [self._package_gcs_uri], + } + + if self.args: + spec["python_package_spec"]["args"] = self.args + + if self.environment_variables: + spec["python_package_spec"]["env"] = [ + {"name": key, "value": value} for key, value in self.environment_variables.items() + ] + + super().execute(context) + + +class VertexAICreateCustomTrainingJobOperator(VertexAITrainingJobBaseOperator): + """Create Custom Training job""" + + def __init__( + self, + display_name: str, + script_path: str, + ) -> None: + pass + + def execute(self, context): + self.worker_pool_specs, self.managed_model = self._prepare_and_validate_run( + model_display_name=self.model_display_name, + model_labels=self.model_labels, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_count=self.accelerator_count, + accelerator_type=self.accelerator_type, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + ) + super().execute(context) + + +class VertexAICancelPipelineJobOperator(BaseOperator): + """ + Cancels a PipelineJob. 
Starts asynchronous cancellation on the PipelineJob. The server makes a best effort + to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other + methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a pipeline + with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. + + :param request: The request object. Request message for + [PipelineService.CancelPipelineJob][google.cloud.aiplatform.v1.PipelineService.CancelPipelineJob]. + :type request: Union[google.cloud.aiplatform_v1.types.CancelPipelineJobRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param pipeline_job: TODO: Fill description + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[CancelPipelineJobRequest, Dict], + location: str, + pipeline_job: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.pipeline_job = pipeline_job + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.cancel_pipeline_job( + request=self.request, + location=self.location, + pipeline_job=self.pipeline_job, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAICancelTrainingPipelineOperator(BaseOperator): + """ + Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a + best effort to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline] or + other methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a + pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. + + :param request: The request object. Request message for [PipelineService.CancelTrainingPipeline][google.c + loud.aiplatform.v1.PipelineService.CancelTrainingPipeline]. 
+ :type request: Union[google.cloud.aiplatform_v1.types.CancelTrainingPipelineRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param training_pipeline: TODO: Fill description + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[CancelTrainingPipelineRequest, Dict], + location: str, + training_pipeline: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.training_pipeline = training_pipeline + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.cancel_training_pipeline( + request=self.request, + location=self.location, + training_pipeline=self.training_pipeline, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAICreatePipelineJobOperator(BaseOperator): + """ + Creates a PipelineJob. A PipelineJob will run immediately when created. + + :param request: The request object. Request message for + [PipelineService.CreatePipelineJob][google.cloud.aiplatform.v1.PipelineService.CreatePipelineJob]. + :type request: Union[google.cloud.aiplatform_v1.types.CreatePipelineJobRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param pipeline_job: Required. The PipelineJob to create. This corresponds to the ``pipeline_job`` field + on the ``request`` instance; if ``request`` is provided, this should not be set. + :type pipeline_job: google.cloud.aiplatform_v1.types.PipelineJob + :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of the + PipelineJob name. If not provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. + + This corresponds to the ``pipeline_job_id`` field on the ``request`` instance; if ``request`` is + provided, this should not be set. + :type pipeline_job_id: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
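As a hypothetical usage sketch, the cancel operator above can be wired into a DAG as follows; the import path matches the module registered in provider.yaml below, and the resource name, IDs and region are placeholders.

    from google.api_core.retry import Retry
    from airflow.providers.google.cloud.operators.vertex_ai import (
        VertexAICancelTrainingPipelineOperator,
    )

    cancel_training_pipeline = VertexAICancelTrainingPipelineOperator(
        task_id="cancel_training_pipeline",
        request={
            "name": "projects/example-project/locations/us-central1/trainingPipelines/1234567890"
        },
        location="us-central1",
        training_pipeline="1234567890",
        retry=Retry(),
        timeout=600.0,
        metadata=(),
        project_id="example-project",
    )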
+ :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[CreatePipelineJobRequest, Dict], + location: str, + pipeline_job: PipelineJob, + pipeline_job_id: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.pipeline_job = pipeline_job + self.pipeline_job_id = pipeline_job_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.create_pipeline_job( + request=self.request, + location=self.location, + pipeline_job=self.pipeline_job, + pipeline_job_id=self.pipeline_job_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAICreateTrainingPipelineOperator(BaseOperator): + """ + Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. + + :param request: The request object. Request message for [PipelineService.CreateTrainingPipeline][google.c + loud.aiplatform.v1.PipelineService.CreateTrainingPipeline]. + :type request: Union[google.cloud.aiplatform_v1.types.CreateTrainingPipelineRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param training_pipeline: Required. The TrainingPipeline to create. + + This corresponds to the ``training_pipeline`` field on the ``request`` instance; if ``request`` is + provided, this should not be set. + :type training_pipeline: google.cloud.aiplatform_v1.types.TrainingPipeline + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[CreateTrainingPipelineRequest, Dict], + location: str, + training_pipeline: TrainingPipeline, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.training_pipeline = training_pipeline + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.create_training_pipeline( + request=self.request, + location=self.location, + training_pipeline=self.training_pipeline, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAIDeletePipelineJobOperator(BaseOperator): + """ + Deletes a PipelineJob. + + :param request: The request object. Request message for + [PipelineService.DeletePipelineJob][google.cloud.aiplatform.v1.PipelineService.DeletePipelineJob]. 
+ :type request: Union[google.cloud.aiplatform_v1.types.DeletePipelineJobRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param pipeline_job: TODO: Fill description + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[DeletePipelineJobRequest, Dict], + location: str, + pipeline_job: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.pipeline_job = pipeline_job + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_pipeline_job( + request=self.request, + location=self.location, + pipeline_job=self.pipeline_job, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAIDeleteTrainingPipelineOperator(BaseOperator): + """ + Deletes a TrainingPipeline. + + :param request: The request object. Request message for [PipelineService.DeleteTrainingPipeline][google.c + loud.aiplatform.v1.PipelineService.DeleteTrainingPipeline]. + :type request: Union[google.cloud.aiplatform_v1.types.DeleteTrainingPipelineRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param training_pipeline: TODO: Fill description + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[DeleteTrainingPipelineRequest, Dict], + location: str, + training_pipeline: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.training_pipeline = training_pipeline + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_training_pipeline( + request=self.request, + location=self.location, + training_pipeline=self.training_pipeline, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAIGetPipelineJobOperator(BaseOperator): + """ + Gets a PipelineJob. + + :param request: The request object. 
Request message for + [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob]. + :type request: Union[google.cloud.aiplatform_v1.types.GetPipelineJobRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param pipeline_job: TODO: Fill description + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[GetPipelineJobRequest, Dict], + location: str, + pipeline_job: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.pipeline_job = pipeline_job + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.get_pipeline_job( + request=self.request, + location=self.location, + pipeline_job=self.pipeline_job, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAIGetTrainingPipelineOperator(BaseOperator): + """ + Gets a TrainingPipeline. + + :param request: The request object. Request message for + [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline]. + :type request: Union[google.cloud.aiplatform_v1.types.GetTrainingPipelineRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param training_pipeline: TODO: Fill description + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[GetTrainingPipelineRequest, Dict], + location: str, + training_pipeline: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.training_pipeline = training_pipeline + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.get_training_pipeline( + request=self.request, + location=self.location, + training_pipeline=self.training_pipeline, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAIListPipelineJobsOperator(BaseOperator): + """ + Lists PipelineJobs in a Location. 
+ + :param request: The request object. Request message for + [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs]. + :type request: Union[google.cloud.aiplatform_v1.types.ListPipelineJobsRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[ListPipelineJobsRequest, Dict], + location: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.list_pipeline_jobs( + request=self.request, + location=self.location, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) + + +class VertexAIListTrainingPipelinesOperator(BaseOperator): + """ + Lists TrainingPipelines in a Location. + + :param request: The request object. Request message for + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelin + es]. + :type request: Union[google.cloud.aiplatform_v1.types.ListTrainingPipelinesRequest, Dict] + :param location: TODO: Fill description + :type location: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
+ :type metadata: Sequence[Tuple[str, str]] + :param project_id: TODO: Fill description + :type project_id: str + :param gcp_conn_id: + :type gcp_conn_id: str + """ + + def __init__( + self, + request: Union[ListTrainingPipelinesRequest, Dict], + location: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.request = request + self.location = location + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) + hook.list_training_pipelines( + request=self.request, + location=self.location, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + project_id=self.project_id, + ) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 5e9b00d267d57..089ac62b7664b 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -337,6 +337,11 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-google/operators/leveldb/leveldb.rst tags: [google] + - integration-name: Google Vertex AI + external-doc-url: https://cloud.google.com/vertex-ai + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst + tags: [gcp] operators: - integration-name: Google Ads @@ -464,6 +469,9 @@ operators: - integration-name: Google LevelDB python-modules: - airflow.providers.google.leveldb.operators.leveldb + - integration-name: Google Vertex AI + python-modules: + - airflow.providers.google.cloud.operators.vertex_ai sensors: - integration-name: Google BigQuery @@ -658,6 +666,9 @@ hooks: - integration-name: Google LevelDB python-modules: - airflow.providers.google.leveldb.hooks.leveldb + - integration-name: Google Vertex AI + python-modules: + - airflow.providers.google.cloud.hooks.vertex_ai transfers: - source-integration-name: Presto diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst new file mode 100644 index 0000000000000..106592bd11775 --- /dev/null +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -0,0 +1,16 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
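Outside of the list operators above, the hook added in this patch can also be called directly; a small sketch of listing training pipelines with the filter syntax documented on the hook (project and region are placeholders, and the returned pager fetches further pages lazily as it is iterated):

    from airflow.providers.google.cloud.hooks.vertex_ai import VertexAIHook

    hook = VertexAIHook(gcp_conn_id="google_cloud_default")
    pipelines = hook.list_training_pipelines(
        project_id="example-project",
        region="us-central1",
        filter='state="PIPELINE_STATE_RUNNING"',
    )
    for training_pipeline in pipelines:
        print(training_pipeline.name, training_pipeline.state)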
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py new file mode 100644 index 0000000000000..69e8527a750b7 --- /dev/null +++ b/tests/providers/google/cloud/operators/test_vertex_ai_system.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from airflow.providers.google.cloud.example_dags.example_vertex_ai import BUCKET, REGION +from tests.providers.google.cloud.utils.gcp_authenticator import GCP_VERTEX_AI_KEY +from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context + + +@pytest.mark.backend("mysql", "postgres") +@pytest.mark.credential_file(GCP_VERTEX_AI_KEY) +class DataprocExampleDagsTest(GoogleSystemTest): + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def setUp(self): + super().setUp() + self.create_gcs_bucket(BUCKET, REGION) + + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def tearDown(self): + self.delete_gcs_bucket(BUCKET) + super().tearDown() + + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def test_run_example_dag(self): + self.run_dag(dag_id="example_gcp_vertex_ai", dag_folder=CLOUD_DAG_FOLDER) diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py index 1a58b0d63928d..f6236b2813304 100644 --- a/tests/providers/google/cloud/utils/gcp_authenticator.py +++ b/tests/providers/google/cloud/utils/gcp_authenticator.py @@ -54,6 +54,7 @@ GCP_SPANNER_KEY = 'gcp_spanner.json' GCP_STACKDRIVER = 'gcp_stackdriver.json' GCP_TASKS_KEY = 'gcp_tasks.json' +GCP_VERTEX_AI_KEY = 'gcp_vertex_ai.json' GCP_WORKFLOWS_KEY = "gcp_workflows.json" GMP_KEY = 'gmp.json' G_FIREBASE_KEY = 'g_firebase.json' From 4371b14034103778afc18b0ac8cbb00d6f9e72d6 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Wed, 10 Nov 2021 16:19:42 +0000 Subject: [PATCH 03/20] Create CustomJob hooks for VertexAI --- .../providers/google/cloud/hooks/vertex_ai.py | 565 ---- .../google/cloud/hooks/vertex_ai/__init__.py | 16 + .../cloud/hooks/vertex_ai/custom_job.py | 2379 +++++++++++++++++ .../google/cloud/hooks/test_vertex_ai.py | 114 +- 4 files changed, 2462 insertions(+), 612 deletions(-) delete mode 100644 airflow/providers/google/cloud/hooks/vertex_ai.py create mode 100644 airflow/providers/google/cloud/hooks/vertex_ai/__init__.py create mode 100644 airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py diff --git a/airflow/providers/google/cloud/hooks/vertex_ai.py b/airflow/providers/google/cloud/hooks/vertex_ai.py deleted file mode 100644 index c405549e4f58a..0000000000000 --- a/airflow/providers/google/cloud/hooks/vertex_ai.py +++ /dev/null @@ -1,565 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license 
agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""This module contains a Google Cloud Vertex AI hook.""" - -from typing import Dict, Optional, Sequence, Tuple, Union - -from airflow import AirflowException -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook - -from google.api_core.operation import Operation -from google.api_core.retry import Retry -from google.cloud.aiplatform_v1 import PipelineServiceClient -from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ( - ListPipelineJobsPager, ListTrainingPipelinesPager) -from google.cloud.aiplatform_v1.types import ( - PipelineJob, TrainingPipeline) - - -class VertexAIHook(GoogleBaseHook): - """Hook for Google Cloud Vertex AI APIs.""" - - def get_pipeline_service_client( - self, - region: Optional[str] = None, - ) -> PipelineServiceClient: - """Returns PipelineServiceClient.""" - client_options = None - if region and region != 'global': - client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} - - return PipelineServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options - ) - - def wait_for_operation(self, timeout: float, operation: Operation): - """Waits for long-lasting operation to complete.""" - try: - return operation.result(timeout=timeout) - except Exception: - error = operation.exception(timeout=timeout) - raise AirflowException(error) - - @GoogleBaseHook.fallback_to_default_project_id - def cancel_pipeline_job( - self, - project_id: str, - region: str, - pipeline_job: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> None: - """ - Cancels a PipelineJob. Starts asynchronous cancellation on the PipelineJob. The server makes a best - effort to cancel the pipeline, but success is not guaranteed. Clients can use - [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other - methods to check whether the cancellation succeeded or whether the pipeline completed despite - cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a - pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a - [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and - [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param pipeline_job: The name of the PipelineJob to cancel. - :type pipeline_job: str - :param retry: Designation of what errors, if any, should be retried. 
- :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - name = client.pipeline_job_path(project_id, region, pipeline_job) - - client.cancel_pipeline_job( - request={ - 'name': name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - @GoogleBaseHook.fallback_to_default_project_id - def cancel_training_pipeline( - self, - project_id: str, - region: str, - training_pipeline: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> None: - """ - Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes - a best effort to cancel the pipeline, but success is not guaranteed. Clients can use - [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline] - or other methods to check whether the cancellation succeeded or whether the pipeline completed despite - cancellation. On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a - pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with - a [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and - [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param training_pipeline: Required. The name of the TrainingPipeline to cancel. - :type training_pipeline: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - name = client.training_pipeline_path(project_id, region, training_pipeline) - - client.cancel_training_pipeline( - request={ - 'name': name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - @GoogleBaseHook.fallback_to_default_project_id - def create_pipeline_job( - self, - project_id: str, - region: str, - pipeline_job: PipelineJob, - pipeline_job_id: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> PipelineJob: - """ - Creates a PipelineJob. A PipelineJob will run immediately when created. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param pipeline_job: Required. The PipelineJob to create. - :type pipeline_job: google.cloud.aiplatform_v1.types.PipelineJob - :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of - the PipelineJob name. If not provided, an ID will be automatically generated. - - This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. 
- :type pipeline_job_id: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - parent = client.common_location_path(project_id, region) - - result = client.create_pipeline_job( - request={ - 'parent': parent, - 'pipeline_job': pipeline_job, - 'pipeline_job_id': pipeline_job_id, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def create_training_pipeline( - self, - project_id: str, - region: str, - training_pipeline: TrainingPipeline, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> TrainingPipeline: - """ - Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param training_pipeline: Required. The TrainingPipeline to create. - :type training_pipeline: google.cloud.aiplatform_v1.types.TrainingPipeline - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - parent = client.common_location_path(project_id, region) - - result = client.create_training_pipeline( - request={ - 'parent': parent, - 'training_pipeline': training_pipeline, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def delete_pipeline_job( - self, - project_id: str, - region: str, - pipeline_job: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> Operation: - """ - Deletes a PipelineJob. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. - :type pipeline_job: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. 
- :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - name = client.pipeline_job_path(project_id, region, pipeline_job) - - result = client.delete_pipeline_job( - request={ - 'name': name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def delete_training_pipeline( - self, - project_id: str, - region: str, - training_pipeline: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> Operation: - """ - Deletes a TrainingPipeline. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted. - :type training_pipeline: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - name = client.training_pipeline_path(project_id, region, training_pipeline) - - result = client.delete_training_pipeline( - request={ - 'name': name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def get_pipeline_job( - self, - project_id: str, - region: str, - pipeline_job: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> PipelineJob: - """ - Gets a PipelineJob. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param pipeline_job: Required. The name of the PipelineJob resource. - :type pipeline_job: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - name = client.pipeline_job_path(project_id, region, pipeline_job) - - result = client.get_pipeline_job( - request={ - 'name': name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def get_training_pipeline( - self, - project_id: str, - region: str, - training_pipeline: str, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> TrainingPipeline: - """ - Gets a TrainingPipeline. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param training_pipeline: Required. The name of the TrainingPipeline resource. 
- :type training_pipeline: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - name = client.training_pipeline_path(project_id, region, training_pipeline) - - result = client.get_training_pipeline( - request={ - 'name': name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def list_pipeline_jobs( - self, - project_id: str, - region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - order_by: Optional[str] = None, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> ListPipelineJobsPager: - """ - Lists PipelineJobs in a Location. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param filter: Optional. Lists the PipelineJobs that match the filter expression. The - following fields are supported: - - - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. - - ``display_name``: Supports ``=``, ``!=`` comparisons, and - ``:`` wildcard. - - ``pipeline_job_user_id``: Supports ``=``, ``!=`` - comparisons, and ``:`` wildcard. for example, can check - if pipeline's display_name contains *step* by doing - display_name:"*step*" - - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, - ``<=``, and ``>=`` comparisons. Values must be in RFC - 3339 format. - - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, - ``<=``, and ``>=`` comparisons. Values must be in RFC - 3339 format. - - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, - ``<=``, and ``>=`` comparisons. Values must be in RFC - 3339 format. - - ``labels``: Supports key-value equality and key presence. - - Filter expressions can be combined together using logical - operators (``AND`` & ``OR``). For example: - ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. - - The syntax to define filter expression is based on - https://google.aip.dev/160. - - Examples: - - - ``create_time>"2021-05-18T00:00:00Z" OR update_time>"2020-05-18T00:00:00Z"`` - PipelineJobs created or updated after 2020-05-18 00:00:00 - UTC. - - ``labels.env = "prod"`` PipelineJobs with label "env" set - to "prod". - :type filter: str - :param page_size: Optional. The standard list page size. - :type page_size: int - :param page_token: Optional. The standard list page token. Typically obtained via - [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] - of the previous - [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] - call. - :type page_token: str - :param order_by: Optional. A comma-separated list of fields to order by. The default - sort order is in ascending order. Use "desc" after a field - name for descending. You can have multiple order_by fields - provided e.g. 
"create_time desc, end_time", "end_time, - start_time, update_time" For example, using "create_time - desc, end_time" will order results by create time in - descending order, and if there are multiple jobs having the - same create time, order them by the end time in ascending - order. if order_by is not specified, it will order by - default order is create time in descending order. Supported - fields: - - - ``create_time`` - - ``update_time`` - - ``end_time`` - - ``start_time`` - :type order_by: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - parent = client.common_location_path(project_id, region) - - result = client.list_pipeline_jobs( - request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'order_by': order_by, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def list_training_pipelines( - self, - project_id: str, - region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> ListTrainingPipelinesPager: - """ - Lists TrainingPipelines in a Location. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str - :param filter: Optional. The standard list filter. Supported fields: - - - ``display_name`` supports = and !=. - - - ``state`` supports = and !=. - - Some examples of using the filter are: - - - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` - - - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` - - - ``NOT display_name="my_pipeline"`` - - - ``state="PIPELINE_STATE_FAILED"`` - :type filter: str - :param page_size: Optional. The standard list page size. - :type page_size: int - :param page_token: Optional. The standard list page token. Typically obtained via - [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] - of the previous - [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] - call. - :type page_token: str - :param read_mask: Optional. Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. 
- :type metadata: Sequence[Tuple[str, str]] - """ - client = self.get_pipeline_service_client(region) - parent = client.common_location_path(project_id, region) - - result = client.list_training_pipelines( - request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'read_mask': read_mask, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py b/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py new file mode 100644 index 0000000000000..78f7682cff6d6 --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -0,0 +1,2379 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +"""This module contains a Google Cloud Vertex AI hook.""" + +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.aiplatform import ( + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + CustomTrainingJob, + datasets, +) +from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient +from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager +from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ( + ListPipelineJobsPager, + ListTrainingPipelinesPager, +) +from google.cloud.aiplatform_v1.types import CustomJob, Model, PipelineJob, TrainingPipeline + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + + +class CustomJobHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI Custom Job APIs.""" + + def get_pipeline_service_client( + self, + region: Optional[str] = None, + ) -> PipelineServiceClient: + """Returns PipelineServiceClient.""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return PipelineServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def get_job_service_client( + self, + region: Optional[str] = None, + ) -> JobServiceClient: + """Returns JobServiceClient""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return JobServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def get_custom_container_training_job( + self, + display_name: str, + container_uri: str, + command: Sequence[str] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ) -> CustomContainerTrainingJob: + """Returns CustomContainerTrainingJob object""" + return CustomContainerTrainingJob( + display_name=display_name, + container_uri=container_uri, + command=command, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + 
model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + project=project, + location=location, + credentials=self._get_credentials(), + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + def get_custom_python_package_training_job( + self, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Returns CustomPythonPackageTrainingJob object""" + return CustomPythonPackageTrainingJob( + display_name=display_name, + container_uri=container_uri, + python_package_gcs_uri=python_package_gcs_uri, + python_module_name=python_module_name, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + project=project, + location=location, + credentials=self._get_credentials(), + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + def get_custom_training_job( + self, + display_name: str, + script_path: str, + container_uri: str, + requirements: Optional[Sequence[str]] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + 
model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Returns CustomTrainingJob object""" + return CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + requirements=requirements, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + project=project, + location=location, + credentials=self._get_credentials(), + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + def wait_for_operation(self, timeout: float, operation: Operation): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + def _run_job( + self, + job: Union[ + CustomTrainingJob, + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + ], + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> Model: + """Run Job for training pipeline""" + self.log.info("START RUN JOB") + model = job.run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + 
network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + self.log.info(f"END RUN JOB. {model}") + model.wait() + self.log.info("STOP WAIT") + return model + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Cancels a PipelineJob. Starts asynchronous cancellation on the PipelineJob. The server makes a best + effort to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other + methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a + pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: The name of the PipelineJob to cancel. + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + client.cancel_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes + a best effort to cancel the pipeline, but success is not guaranteed. Clients can use + [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline] + or other methods to check whether the cancellation succeeded or whether the pipeline completed despite + cancellation. 
On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a + pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with + a [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The name of the TrainingPipeline to cancel. + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + client.cancel_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_custom_job( + self, + project_id: str, + region: str, + custom_job: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + ) -> None: + """ + Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort + to cancel the job, but success is not guaranteed. Clients can use + [JobService.GetCustomJob][google.cloud.aiplatform.v1.JobService.GetCustomJob] or other methods to + check whether the cancellation succeeded or whether the job completed despite cancellation. On + successful cancellation, the CustomJob is not deleted; instead it becomes a job with a + [CustomJob.error][google.cloud.aiplatform.v1.CustomJob.error] value with a + [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and + [CustomJob.state][google.cloud.aiplatform.v1.CustomJob.state] is set to ``CANCELLED``. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param custom_job: Required. The name of the CustomJob to cancel. + :type custom_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_job_service_client(region) + name = JobServiceClient.custom_job_path(project_id, region, custom_job) + + client.cancel_custom_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: PipelineJob, + pipeline_job_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> PipelineJob: + """ + Creates a PipelineJob. A PipelineJob will run immediately when created. 
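+
+        A minimal, illustrative call (``my-project``, ``us-central1`` and ``my-pipeline-job`` are
+        placeholder values; ``pipeline_job`` is assumed to be a
+        ``google.cloud.aiplatform_v1.types.PipelineJob`` prepared by the caller)::
+
+            hook = CustomJobHook()
+            result = hook.create_pipeline_job(
+                project_id='my-project',
+                region='us-central1',
+                pipeline_job=pipeline_job,
+                pipeline_job_id='my-pipeline-job',
+            )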
+ + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: Required. The PipelineJob to create. + :type pipeline_job: google.cloud.aiplatform_v1.types.PipelineJob + :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of + the PipelineJob name. If not provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. + :type pipeline_job_id: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_pipeline_job( + request={ + 'parent': parent, + 'pipeline_job': pipeline_job, + 'pipeline_job_id': pipeline_job_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: TrainingPipeline, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TrainingPipeline: + """ + Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The TrainingPipeline to create. + :type training_pipeline: google.cloud.aiplatform_v1.types.TrainingPipeline + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_training_pipeline( + request={ + 'parent': parent, + 'training_pipeline': training_pipeline, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_job( + self, + project_id: str, + region: str, + custom_job: CustomJob, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + ) -> CustomJob: + """ + Creates a CustomJob. A created CustomJob right away will be attempted to be run. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param custom_job: Required. The CustomJob to create. This corresponds to the ``custom_job`` field on + the ``request`` instance; if ``request`` is provided, this should not be set. 
+ :type custom_job: google.cloud.aiplatform_v1.types.CustomJob + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_job_service_client(region) + parent = JobServiceClient.common_location_path(project_id, region) + + result = client.create_custom_job( + request={ + 'parent': parent, + 'custom_job': custom_job, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_container_training_job( + self, + project_id: str, + region: str, + display_name: str, + container_uri: str, + command: Sequence[str] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> Model: + """ + Create Custom Container Training Job + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :type display_name: str + :param command: The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + :type command: Sequence[str] + :param container_uri: Required: Uri of the training container image in the GCR. 
+ :type container_uri: str + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :type model_serving_container_image_uri: str + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :type model_serving_container_predict_route: str + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :type model_serving_container_health_route: str + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :type model_serving_container_command: Sequence[str] + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :type model_serving_container_args: Sequence[str] + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :type model_serving_container_environment_variables: Dict[str, str] + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :type model_serving_container_ports: Sequence[int] + :param model_description: The description of the Model. + :type model_description: str + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. 
+ :type model_instance_schema_uri: str + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :type model_parameters_schema_uri: str + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_prediction_schema_uri: str + :param project: Project to run training in. Overrides project set in aiplatform.init. + :type project: str + :param location: Location to run training in. Overrides location set in aiplatform.init. + :type location: str + :param credentials: Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + :type credentials: auth_credentials.Credentials + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type labels: Dict[str, str] + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + :type training_encryption_spec_key_name: Optional[str] + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + :type model_encryption_spec_key_name: Optional[str] + :param staging_bucket: Bucket used to stage source and training artifacts. 
Overrides
+            staging_bucket set in aiplatform.init.
+        :type staging_bucket: str
+
+        :param dataset: The Vertex AI dataset to fit this training against. The custom training script
+            should retrieve the dataset through the passed-in environment variable URIs:
+
+            os.environ["AIP_TRAINING_DATA_URI"]
+            os.environ["AIP_VALIDATION_DATA_URI"]
+            os.environ["AIP_TEST_DATA_URI"]
+
+            Additionally the dataset format is passed in as:
+
+            os.environ["AIP_DATA_FORMAT"]
+        :type dataset: Union[
+            datasets.ImageDataset,
+            datasets.TabularDataset,
+            datasets.TextDataset,
+            datasets.VideoDataset,
+        ]
+        :param annotation_schema_uri: Google Cloud Storage URI that points to a YAML file describing
+            the annotation schema. The schema is defined as an OpenAPI 3.0.2
+            [Schema Object]
+            (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
+            The schema files that can be used here are found in
+            gs://google-cloud-aiplatform/schema/dataset/annotation/,
+            note that the chosen schema must be consistent with
+            ``metadata``
+            of the Dataset specified by
+            ``dataset_id``.
+
+            Only Annotations that both match this schema and belong to
+            DataItems not ignored by the split method are used in
+            respectively training, validation or test role, depending on
+            the role of the DataItem they are on.
+
+            When used in conjunction with
+            ``annotations_filter``,
+            the Annotations used for training are filtered by both
+            ``annotations_filter``
+            and
+            ``annotation_schema_uri``.
+        :type annotation_schema_uri: str
+        :param model_display_name: If the script produces a managed Vertex AI Model, the display name of
+            the Model. The name can be up to 128 characters long and can consist
+            of any UTF-8 characters.
+
+            If not provided upon creation, the job's display_name is used.
+        :type model_display_name: str
+        :param model_labels: Optional. The labels with user-defined metadata to
+            organize your Models.
+            Label keys and values can be no longer than 64
+            characters (Unicode codepoints), can only
+            contain lowercase letters, numeric characters,
+            underscores and dashes. International characters
+            are allowed.
+            See https://goo.gl/xmQnxf for more information
+            and examples of labels.
+        :type model_labels: Dict[str, str]
+        :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
+            staging directory will be used.
+
+            Vertex AI sets the following environment variables when it runs your training code:
+
+            - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
+              i.e. /model/
+            - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
+              i.e. /checkpoints/
+            - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
+              logs, i.e. /logs/
+
+        :type base_output_dir: str
+        :param service_account: Specifies the service account for workload run-as account.
+            Users submitting jobs must have act-as permission on this run-as account.
+        :type service_account: str
+        :param network: The full name of the Compute Engine network to which the job
+            should be peered. For example, projects/12345/global/networks/myVPC.
+            Private services access must already be configured for the network.
+            If left unspecified, the job is not peered with any network.
+        :type network: str
+        :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
+            The BigQuery project location where the training data is to
+            be written to.
In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :type bigquery_destination: str + :param args: Command line arguments to be passed to the Python script. + :type args: List[Unions[str, int, float]] + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + + environment_variables = { + 'MY_KEY': 'MY_VALUE' + } + :type environment_variables: Dict[str, str] + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :type replica_count: int + :param machine_type: The type of machine to use for training. + :type machine_type: str + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :type accelerator_type: str + :param accelerator_count: The number of accelerators to attach to a worker replica. + :type accelerator_count: int + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :type boot_disk_type: str + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :type boot_disk_size_gb: int + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :type training_fraction_split: float + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :type validation_fraction_split: float + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :type test_fraction_split: float + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type training_filter_split: str + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. 
If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :type validation_filter_split: str
+        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to test the Model. A filter with same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :type test_filter_split: str
+        :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The value of the key (either the label's value or
+            value in the column) must be one of {``training``,
+            ``validation``, ``test``}, and it defines to which set the
+            given piece of data is assigned. If for a piece of data the
+            key is not present or has an invalid value, that piece is
+            ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :type predefined_split_column_name: str
+        :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The values of the key (the values in the column) must be
+            in RFC 3339 `date-time` format, where
+            `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+            piece of data the key is not present or has an invalid value,
+            that piece is ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :type timestamp_split_column_name: str
+        :param tensorboard: Optional. The name of a Vertex AI
+            [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+            resource to which this CustomJob will upload Tensorboard
+            logs. Format:
+            ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+            The training script should write Tensorboard logs to the following Vertex AI environment
+            variable:
+
+            AIP_TENSORBOARD_LOG_DIR
+
+            A `service_account` is required when `tensorboard` is provided.
+            For more information on configuring your service account please visit:
+            https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        :type tensorboard: str
+        :param sync: Whether to execute this method synchronously. If False, this method
+            will be executed in a concurrent Future and any downstream object will
+            be immediately returned and synced when the Future has completed.
+ :type sync: bool + """ + job = self.get_custom_container_training_job( + project=project_id, + location=region, + display_name=display_name, + container_uri=container_uri, + command=command, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + model = self._run_job( + job=job, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + + return model + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_python_package_training_job( + self, + project_id: str, + region: str, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + 
base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> Model: + """ + Create Custom Python Package Training Job + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :type display_name: str + :param python_package_gcs_uri: Required: GCS location of the training python package. + :type python_package_gcs_uri: str + :param python_module_name: Required: The module name of the training python package. + :type python_module_name: str + :param container_uri: Required: Uri of the training container image in the GCR. + :type container_uri: str + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :type model_serving_container_image_uri: str + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :type model_serving_container_predict_route: str + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :type model_serving_container_health_route: str + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :type model_serving_container_command: Sequence[str] + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. 
+ :type model_serving_container_args: Sequence[str] + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :type model_serving_container_environment_variables: Dict[str, str] + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :type model_serving_container_ports: Sequence[int] + :param model_description: The description of the Model. + :type model_description: str + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_instance_schema_uri: str + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :type model_parameters_schema_uri: str + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_prediction_schema_uri: str + :param project: Project to run training in. Overrides project set in aiplatform.init. + :type project: str + :param location: Location to run training in. Overrides location set in aiplatform.init. + :type location: str + :param credentials: Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + :type credentials: auth_credentials.Credentials + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. 
+        Label keys and values can be no longer than 64
+        characters (Unicode codepoints), can only
+        contain lowercase letters, numeric characters,
+        underscores and dashes. International characters
+        are allowed.
+        See https://goo.gl/xmQnxf for more information
+        and examples of labels.
+    :type labels: Dict[str, str]
+    :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+        managed encryption key used to protect the training pipeline. Has the
+        form:
+        ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+        The key needs to be in the same region as where the compute
+        resource is created.
+
+        If set, this TrainingPipeline will be secured by this key.
+
+        Note: Model trained by this TrainingPipeline is also secured
+        by this key if ``model_to_upload`` is not set separately.
+
+        Overrides encryption_spec_key_name set in aiplatform.init.
+    :type training_encryption_spec_key_name: Optional[str]
+    :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+        managed encryption key used to protect the model. Has the
+        form:
+        ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+        The key needs to be in the same region as where the compute
+        resource is created.
+
+        If set, the trained Model will be secured by this key.
+
+        Overrides encryption_spec_key_name set in aiplatform.init.
+    :type model_encryption_spec_key_name: Optional[str]
+    :param staging_bucket: Bucket used to stage source and training artifacts. Overrides
+        staging_bucket set in aiplatform.init.
+    :type staging_bucket: str
+
+    :param dataset: The Vertex AI Dataset to fit this training against. The custom training script should
+        retrieve datasets through the environment variable URIs passed in:
+
+        os.environ["AIP_TRAINING_DATA_URI"]
+        os.environ["AIP_VALIDATION_DATA_URI"]
+        os.environ["AIP_TEST_DATA_URI"]
+
+        Additionally, the dataset format is passed in as:
+
+        os.environ["AIP_DATA_FORMAT"]
+    :type dataset: Union[
+        datasets.ImageDataset,
+        datasets.TabularDataset,
+        datasets.TextDataset,
+        datasets.VideoDataset,
+    ]
+    :param annotation_schema_uri: Google Cloud Storage URI pointing to a YAML file describing the
+        annotation schema. The schema is defined as an OpenAPI 3.0.2
+        [Schema Object]
+        (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
+        The schema files that can be used here are found in
+        gs://google-cloud-aiplatform/schema/dataset/annotation/,
+        note that the chosen schema must be consistent with
+        ``metadata``
+        of the Dataset specified by
+        ``dataset_id``.
+
+        Only Annotations that both match this schema and belong to
+        DataItems not ignored by the split method are used, respectively,
+        in the training, validation or test role, depending on
+        the role of the DataItem they are on.
+
+        When used in conjunction with
+        ``annotations_filter``,
+        the Annotations used for training are filtered by both
+        ``annotations_filter``
+        and
+        ``annotation_schema_uri``.
+    :type annotation_schema_uri: str
+    :param model_display_name: If the script produces a managed Vertex AI Model, the display name of
+        the Model. The name can be up to 128 characters long and can consist
+        of any UTF-8 characters.
+
+        If not provided upon creation, the job's display_name is used.
+    :type model_display_name: str
+    :param model_labels: Optional. The labels with user-defined metadata to
+        organize your Models.
+        Label keys and values can be no longer than 64
+        characters (Unicode codepoints), can only
+        contain lowercase letters, numeric characters,
+        underscores and dashes. International characters
+        are allowed.
+        See https://goo.gl/xmQnxf for more information
+        and examples of labels.
+    :type model_labels: Dict[str, str]
+    :param base_output_dir: GCS output directory of the job. If not provided, a timestamped directory in the
+        staging directory will be used.
+
+        Vertex AI sets the following environment variables when it runs your training code:
+
+        - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
+          i.e. /model/
+        - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
+          i.e. /checkpoints/
+        - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
+          logs, i.e. /logs/
+
+    :type base_output_dir: str
+    :param service_account: Specifies the service account for the workload run-as account.
+        Users submitting jobs must have act-as permission on this run-as account.
+    :type service_account: str
+    :param network: The full name of the Compute Engine network to which the job
+        should be peered. For example, projects/12345/global/networks/myVPC.
+        Private services access must already be configured for the network.
+        If left unspecified, the job is not peered with any network.
+    :type network: str
+    :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
+        The BigQuery project location where the training data is to
+        be written to. In the given project a new dataset is created
+        with name
+        ``dataset___``
+        where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+        training input data will be written into that dataset. In
+        the dataset three tables will be created, ``training``,
+        ``validation`` and ``test``.
+
+        - AIP_DATA_FORMAT = "bigquery".
+        - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
+        - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+        - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+    :type bigquery_destination: str
+    :param args: Command line arguments to be passed to the Python script.
+    :type args: List[Union[str, int, float]]
+    :param environment_variables: Environment variables to be passed to the container.
+        Should be a dictionary where keys are environment variable names
+        and values are environment variable values for those names.
+        At most 10 environment variables can be specified.
+        The name of each environment variable must be unique.
+
+        environment_variables = {
+            'MY_KEY': 'MY_VALUE'
+        }
+    :type environment_variables: Dict[str, str]
+    :param replica_count: The number of worker replicas. If replica_count = 1 then one chief
+        replica will be provisioned. If replica_count > 1 the remainder will be
+        provisioned as a worker replica pool.
+    :type replica_count: int
+    :param machine_type: The type of machine to use for training.
+    :type machine_type: str
+    :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
+        NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
+        NVIDIA_TESLA_T4
+    :type accelerator_type: str
+    :param accelerator_count: The number of accelerators to attach to a worker replica.
+    :type accelerator_count: int
+    :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
+        Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+        `pd-standard` (Persistent Disk Hard Disk Drive).
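For illustration, a training script launched by this job might read the directories and data URIs above from its environment; a minimal sketch using only the variable names documented here:

    import os

    # Cloud Storage locations and data references injected by Vertex AI.
    model_dir = os.environ.get("AIP_MODEL_DIR", "/tmp/model")
    train_uri = os.environ.get("AIP_TRAINING_DATA_URI")
    data_format = os.environ.get("AIP_DATA_FORMAT")  # e.g. "bigquery"

    # ... train the model, then write artifacts under model_dir so Vertex AI can pick them up.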
:type boot_disk_type: str
+    :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
+        The boot disk size must be within the range of [100, 64000].
+    :type boot_disk_size_gb: int
+    :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+        the Model. This is ignored if Dataset is not provided.
+    :type training_fraction_split: float
+    :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+        validate the Model. This is ignored if Dataset is not provided.
+    :type validation_fraction_split: float
+    :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+        the Model. This is ignored if Dataset is not provided.
+    :type test_fraction_split: float
+    :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+        this filter are used to train the Model. A filter with the same syntax
+        as the one used in DatasetService.ListDataItems may be used. If a
+        single DataItem is matched by more than one of the FilterSplit filters,
+        then it is assigned to the first set that applies to it in the training,
+        validation, test order. This is ignored if Dataset is not provided.
+    :type training_filter_split: str
+    :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+        this filter are used to validate the Model. A filter with the same syntax
+        as the one used in DatasetService.ListDataItems may be used. If a
+        single DataItem is matched by more than one of the FilterSplit filters,
+        then it is assigned to the first set that applies to it in the training,
+        validation, test order. This is ignored if Dataset is not provided.
+    :type validation_filter_split: str
+    :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+        this filter are used to test the Model. A filter with the same syntax
+        as the one used in DatasetService.ListDataItems may be used. If a
+        single DataItem is matched by more than one of the FilterSplit filters,
+        then it is assigned to the first set that applies to it in the training,
+        validation, test order. This is ignored if Dataset is not provided.
+    :type test_filter_split: str
+    :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+        columns. The value of the key (either the label's value or
+        value in the column) must be one of {``training``,
+        ``validation``, ``test``}, and it defines to which set the
+        given piece of data is assigned. If for a piece of data the
+        key is not present or has an invalid value, that piece is
+        ignored by the pipeline.
+
+        Supported only for tabular and time series Datasets.
+    :type predefined_split_column_name: str
+    :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+        columns. The values of the key (the values in
+        the column) must be in RFC 3339 `date-time` format, where
+        `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+        piece of data the key is not present or has an invalid value,
+        that piece is ignored by the pipeline.
+
+        Supported only for tabular and time series Datasets.
+    :type timestamp_split_column_name: str
+    :param tensorboard: Optional. The name of a Vertex AI
+        [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+        resource to which this CustomJob will upload Tensorboard
+        logs.
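As a rough sketch, the split arguments above can be supplied either as fractions or via a predefined split column; the values here are illustrative, and only one style would normally be used:

    # Fraction-based split, applied when a Dataset is provided.
    split_kwargs = {
        "training_fraction_split": 0.8,
        "validation_fraction_split": 0.1,
        "test_fraction_split": 0.1,
    }

    # Alternatively, for tabular/time series Datasets, a column whose values are
    # "training", "validation" or "test" can drive the split:
    # split_kwargs = {"predefined_split_column_name": "data_split"}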
Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + The training script should write Tensorboard to following Vertex AI environment + variable: + + AIP_TENSORBOARD_LOG_DIR + + `service_account` is required with provided `tensorboard`. + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :type tensorboard: str + :param sync: Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :type sync: bool + """ + job = self.get_custom_python_package_training_job( + project=project_id, + location=region, + display_name=display_name, + python_package_gcs_uri=python_package_gcs_uri, + python_module_name=python_module_name, + container_uri=container_uri, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + model = self._run_job( + job=job, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + + return model + + @GoogleBaseHook.fallback_to_default_project_id + def create_custom_training_job( + self, + project_id: str, + region: str, + display_name: str, + script_path: str, + container_uri: str, + requirements: Optional[Sequence[str]] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + 
model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + ) -> Model: + """ + Create Custom Training Job + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :type display_name: str + :param script_path: Required. Local path to training script. + :type script_path: str + :param container_uri: Required: Uri of the training container image in the GCR. + :type container_uri: str + :param requirements: List of python packages dependencies of script. + :type requirements: Sequence[str] + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :type model_serving_container_image_uri: str + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :type model_serving_container_predict_route: str + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :type model_serving_container_health_route: str + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). 
Escaped references will never be + expanded, regardless of whether the variable exists or not. + :type model_serving_container_command: Sequence[str] + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :type model_serving_container_args: Sequence[str] + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :type model_serving_container_environment_variables: Dict[str, str] + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :type model_serving_container_ports: Sequence[int] + :param model_description: The description of the Model. + :type model_description: str + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_instance_schema_uri: str + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :type model_parameters_schema_uri: str + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. 
The output URI will point to a location + where the user only has a read access. + :type model_prediction_schema_uri: str + :param project: Project to run training in. Overrides project set in aiplatform.init. + :type project: str + :param location: Location to run training in. Overrides location set in aiplatform.init. + :type location: str + :param credentials: Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + :type credentials: auth_credentials.Credentials + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type labels: Dict[str, str] + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + :type training_encryption_spec_key_name: Optional[str] + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + :type model_encryption_spec_key_name: Optional[str] + :param staging_bucket: Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. + :type staging_bucket: str + :param dataset: Vertex AI to fit this training against. Custom training script should + retrieve datasets through passed in environment variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + :type dataset: Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + The schema files that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. 
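A minimal sketch of how this method might be invoked once the hook is available; the connection ID, project, image URI and file names are placeholders, and the hook class is imported under the name used by the tests later in this patch series:

    from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

    hook = CustomJobHook(gcp_conn_id="google_cloud_default")   # assumed connection ID
    model = hook.create_custom_training_job(
        project_id="my-project",                               # placeholder
        region="us-central1",
        display_name="train-demo",
        script_path="dags/scripts/train.py",                   # local training script
        container_uri="gcr.io/my-project/training:latest",     # placeholder image
        requirements=["scikit-learn", "pandas"],
        replica_count=1,
        sync=True,
    )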
+ + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :type annotation_schema_uri: str + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :type model_display_name: str + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type model_labels: Dict[str, str] + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + + :type base_output_dir: str + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :type service_account: str + :param network: The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :type network: str + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :type bigquery_destination: str + :param args: Command line arguments to be passed to the Python script. + :type args: List[Unions[str, int, float]] + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + + environment_variables = { + 'MY_KEY': 'MY_VALUE' + } + :type environment_variables: Dict[str, str] + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. 
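Building on the description above, the pass-through arguments and environment variables might look like this; the values are illustrative:

    # Mixed str/int/float values are allowed for args; they are forwarded to the script.
    args = ["--epochs", 10, "--learning-rate", 0.01]

    # At most 10 entries; each name must be unique.
    environment_variables = {
        "MY_KEY": "MY_VALUE",
    }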
+ :type replica_count: int + :param machine_type: The type of machine to use for training. + :type machine_type: str + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :type accelerator_type: str + :param accelerator_count: The number of accelerators to attach to a worker replica. + :type accelerator_count: int + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :type boot_disk_type: str + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :type boot_disk_size_gb: int + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :type training_fraction_split: float + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :type validation_fraction_split: float + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :type test_fraction_split: float + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type training_filter_split: str + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type validation_filter_split: str + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type test_filter_split: str + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type predefined_split_column_name: str + :param timestamp_split_column_name : Optional. The key is a name of one of the Dataset's data + columns. 
The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type timestamp_split_column_name: str + :param tensorboard: Optional. The name of a Vertex AI + [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] + resource to which this CustomJob will upload Tensorboard + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + The training script should write Tensorboard to following Vertex AI environment + variable: + + AIP_TENSORBOARD_LOG_DIR + + `service_account` is required with provided `tensorboard`. + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :type tensorboard: str + :param sync: Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :type sync: bool + """ + job = self.get_custom_training_job( + project=project_id, + location=region, + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + requirements=requirements, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + model = self._run_job( + job=job, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + sync=sync, + ) + + return model + + @GoogleBaseHook.fallback_to_default_project_id + def delete_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + 
retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.delete_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a TrainingPipeline. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted. + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.delete_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_custom_job( + self, + project_id: str, + region: str, + custom_job: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + ) -> Operation: + """ + Deletes a CustomJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param custom_job: Required. The name of the CustomJob to delete. + :type custom_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
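Since these delete methods return a long-running Operation, a caller would typically block on its result; a rough sketch with placeholder identifiers:

    # `hook` is a CustomJobHook constructed as in the earlier sketch.
    operation = hook.delete_training_pipeline(
        project_id="my-project",           # placeholder
        region="us-central1",
        training_pipeline="1234567890",    # TrainingPipeline resource ID
    )
    operation.result(timeout=300)          # block until the long-running delete completes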
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_job_service_client(region) + name = JobServiceClient.custom_job_path(project_id, region, custom_job) + + result = client.delete_custom_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> PipelineJob: + """ + Gets a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param pipeline_job: Required. The name of the PipelineJob resource. + :type pipeline_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.get_pipeline_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_training_pipeline( + self, + project_id: str, + region: str, + training_pipeline: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TrainingPipeline: + """ + Gets a TrainingPipeline. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param training_pipeline: Required. The name of the TrainingPipeline resource. + :type training_pipeline: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.get_training_pipeline( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_custom_job( + self, + project_id: str, + region: str, + custom_job: str, + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + ) -> CustomJob: + """ + Gets a CustomJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param custom_job: Required. The name of the CustomJob to get. + :type custom_job: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. 
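For completeness, a rough sketch of fetching a training pipeline and inspecting its state; identifiers are placeholders:

    # `hook` is a CustomJobHook constructed as in the earlier sketches.
    training_pipeline = hook.get_training_pipeline(
        project_id="my-project",           # placeholder
        region="us-central1",
        training_pipeline="1234567890",
    )
    print(training_pipeline.state)         # a PipelineState value, e.g. PIPELINE_STATE_RUNNING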
+ :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_job_service_client(region) + name = JobServiceClient.custom_job_path(project_id, region, custom_job) + + result = client.get_custom_job( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_pipeline_jobs( + self, + project_id: str, + region: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListPipelineJobsPager: + """ + Lists PipelineJobs in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: Optional. Lists the PipelineJobs that match the filter expression. The + following fields are supported: + + - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. + - ``display_name``: Supports ``=``, ``!=`` comparisons, and + ``:`` wildcard. + - ``pipeline_job_user_id``: Supports ``=``, ``!=`` + comparisons, and ``:`` wildcard. for example, can check + if pipeline's display_name contains *step* by doing + display_name:"*step*" + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``labels``: Supports key-value equality and key presence. + + Filter expressions can be combined together using logical + operators (``AND`` & ``OR``). For example: + ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. + + The syntax to define filter expression is based on + https://google.aip.dev/160. + :type filter: str + :param page_size: Optional. The standard list page size. + :type page_size: int + :param page_token: Optional. The standard list page token. Typically obtained via + [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] + of the previous + [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] + call. + :type page_token: str + :param order_by: Optional. A comma-separated list of fields to order by. The default + sort order is in ascending order. Use "desc" after a field + name for descending. You can have multiple order_by fields + provided e.g. "create_time desc, end_time", "end_time, + start_time, update_time" For example, using "create_time + desc, end_time" will order results by create time in + descending order, and if there are multiple jobs having the + same create time, order them by the end time in ascending + order. if order_by is not specified, it will order by + default order is create time in descending order. Supported + fields: + + - ``create_time`` + - ``update_time`` + - ``end_time`` + - ``start_time`` + :type order_by: str + :param retry: Designation of what errors, if any, should be retried. 
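The filter and order_by syntax described above can be combined when listing jobs; a short sketch with placeholder values, iterating the returned pager:

    # `hook` is a CustomJobHook constructed as in the earlier sketches.
    pager = hook.list_pipeline_jobs(
        project_id="my-project",           # placeholder
        region="us-central1",
        filter='display_name:"*step*" AND create_time>"2021-01-01T00:00:00Z"',
        order_by="create_time desc",
    )
    for pipeline_job in pager:             # the pager is iterable over PipelineJob messages
        print(pipeline_job.name)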
+ :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_pipeline_jobs( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_training_pipelines( + self, + project_id: str, + region: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = None, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListTrainingPipelinesPager: + """ + Lists TrainingPipelines in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :type filter: str + :param page_size: Optional. The standard list page size. + :type page_size: int + :param page_token: Optional. The standard list page token. Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :type page_token: str + :param read_mask: Optional. Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_training_pipelines( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_custom_jobs( + self, + project_id: str, + region: str, + page_size: Optional[int], + page_token: Optional[str], + filter: Optional[str], + read_mask: Optional[str], + retry: Retry, + timeout: float, + metadata: Sequence[Tuple[str, str]], + ) -> ListCustomJobsPager: + """ + Lists CustomJobs in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. 
The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :type filter: str + :param page_size: Optional. The standard list page size. + :type page_size: int + :param page_token: Optional. The standard list page token. Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :type page_token: str + :param read_mask: Optional. Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_job_service_client(region) + parent = JobServiceClient.common_location_path(project_id, region) + + result = client.list_custom_jobs( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + 'filter': filter, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/tests/providers/google/cloud/hooks/test_vertex_ai.py b/tests/providers/google/cloud/hooks/test_vertex_ai.py index 329228723cbc6..4fc21f93b401e 100644 --- a/tests/providers/google/cloud/hooks/test_vertex_ai.py +++ b/tests/providers/google/cloud/hooks/test_vertex_ai.py @@ -17,38 +17,34 @@ # under the License. 
# -from typing import Dict, Optional, Sequence, Tuple, Union from unittest import TestCase, mock -from airflow import AirflowException -from airflow.providers.google.cloud.hooks.vertex_ai import VertexAIHook -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook -from google.api_core.retry import Retry - +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook from tests.providers.google.cloud.utils.base_gcp_mock import ( mock_base_gcp_hook_default_project_id, - mock_base_gcp_hook_no_default_project_id) + mock_base_gcp_hook_no_default_project_id, +) TEST_GCP_CONN_ID: str = "test-gcp-conn-id" TEST_REGION: str = "test-region" TEST_PROJECT_ID: str = "test-project-id" -TEST_PIPELINE_JOB: dict = {} +TEST_PIPELINE_JOB: dict = {} TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id" TEST_TRAINING_PIPELINE: dict = {} TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline" BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" -VERTEX_AI_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.{}" +VERTEX_AI_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.custom_job.{}" + class TestVertexAIWithDefaultProjectIdHook(TestCase): def setUp(self): with mock.patch( - BASE_STRING.format("GoogleBaseHook.__init__"), - new=mock_base_gcp_hook_default_project_id + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id ): - self.hook = VertexAIHook(gcp_conn_id=TEST_GCP_CONN_ID) + self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_pipeline_job(self, mock_client) -> None: self.hook.cancel_pipeline_job( project_id=TEST_PROJECT_ID, @@ -64,9 +60,11 @@ def test_cancel_pipeline_job(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_training_pipeline(self, mock_client) -> None: self.hook.cancel_training_pipeline( project_id=TEST_PROJECT_ID, @@ -82,9 +80,11 @@ def test_cancel_training_pipeline(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_pipeline_job(self, mock_client) -> None: self.hook.create_pipeline_job( project_id=TEST_PROJECT_ID, @@ -105,7 +105,7 @@ def test_create_pipeline_job(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_training_pipeline(self, mock_client) -> None: 
self.hook.create_training_pipeline( project_id=TEST_PROJECT_ID, @@ -124,7 +124,7 @@ def test_create_training_pipeline(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_pipeline_job(self, mock_client) -> None: self.hook.delete_pipeline_job( project_id=TEST_PROJECT_ID, @@ -140,9 +140,11 @@ def test_delete_pipeline_job(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_training_pipeline(self, mock_client) -> None: self.hook.delete_training_pipeline( project_id=TEST_PROJECT_ID, @@ -158,9 +160,11 @@ def test_delete_training_pipeline(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_pipeline_job(self, mock_client) -> None: self.hook.get_pipeline_job( project_id=TEST_PROJECT_ID, @@ -176,9 +180,11 @@ def test_get_pipeline_job(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_training_pipeline(self, mock_client) -> None: self.hook.get_training_pipeline( project_id=TEST_PROJECT_ID, @@ -194,9 +200,11 @@ def test_get_training_pipeline(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_pipeline_jobs(self, mock_client) -> None: self.hook.list_pipeline_jobs( project_id=TEST_PROJECT_ID, @@ -217,7 +225,7 @@ def test_list_pipeline_jobs(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_training_pipelines(self, mock_client) -> None: 
self.hook.list_training_pipelines( project_id=TEST_PROJECT_ID, @@ -238,15 +246,15 @@ def test_list_training_pipelines(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + class TestVertexAIWithoutDefaultProjectIdHook(TestCase): def setUp(self): with mock.patch( - BASE_STRING.format("GoogleBaseHook.__init__"), - new=mock_base_gcp_hook_no_default_project_id + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id ): - self.hook = VertexAIHook(gcp_conn_id=TEST_GCP_CONN_ID) + self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_pipeline_job(self, mock_client) -> None: self.hook.cancel_pipeline_job( project_id=TEST_PROJECT_ID, @@ -262,9 +270,11 @@ def test_cancel_pipeline_job(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_training_pipeline(self, mock_client) -> None: self.hook.cancel_training_pipeline( project_id=TEST_PROJECT_ID, @@ -280,9 +290,11 @@ def test_cancel_training_pipeline(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_pipeline_job(self, mock_client) -> None: self.hook.create_pipeline_job( project_id=TEST_PROJECT_ID, @@ -303,7 +315,7 @@ def test_create_pipeline_job(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_training_pipeline(self, mock_client) -> None: self.hook.create_training_pipeline( project_id=TEST_PROJECT_ID, @@ -322,7 +334,7 @@ def test_create_training_pipeline(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_pipeline_job(self, mock_client) -> None: self.hook.delete_pipeline_job( project_id=TEST_PROJECT_ID, @@ -338,9 +350,11 @@ def test_delete_pipeline_job(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) - 
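Because the hunks above only show the changed context around each assertion, the overall shape of one of these hook tests is hard to read from the diff alone. The following is a minimal sketch (not part of the patch) of how one such test reads once assembled, assuming the same class context as above (the TEST_* constants, the mocked GoogleBaseHook.__init__, and the patched get_pipeline_service_client); the exact request payload asserted against cancel_pipeline_job is an assumption based on the call pattern visible in the fragments, not a line taken from the patch.

    @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client"))
    def test_cancel_pipeline_job_sketch(self, mock_client) -> None:
        # Exercise the hook; the PipelineServiceClient is mocked, so no API call is made.
        self.hook.cancel_pipeline_job(
            project_id=TEST_PROJECT_ID,
            region=TEST_REGION,
            pipeline_job=TEST_PIPELINE_JOB_ID,
        )
        # The hook is expected to resolve the fully qualified PipelineJob resource name.
        mock_client.return_value.pipeline_job_path.assert_called_once_with(
            TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
        )
        # ...and to forward that name to the client's cancel RPC (payload below is assumed).
        mock_client.return_value.cancel_pipeline_job.assert_called_once_with(
            request={"name": mock_client.return_value.pipeline_job_path.return_value},
            retry=None,
            timeout=None,
            metadata=None,
        )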
@mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_training_pipeline(self, mock_client) -> None: self.hook.delete_training_pipeline( project_id=TEST_PROJECT_ID, @@ -356,9 +370,11 @@ def test_delete_training_pipeline(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_pipeline_job(self, mock_client) -> None: self.hook.get_pipeline_job( project_id=TEST_PROJECT_ID, @@ -374,9 +390,11 @@ def test_get_pipeline_job(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.pipeline_job_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_training_pipeline(self, mock_client) -> None: self.hook.get_training_pipeline( project_id=TEST_PROJECT_ID, @@ -392,9 +410,11 @@ def test_get_training_pipeline(self, mock_client) -> None: retry=None, timeout=None, ) - mock_client.return_value.training_pipeline_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME) + mock_client.return_value.training_pipeline_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME + ) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_pipeline_jobs(self, mock_client) -> None: self.hook.list_pipeline_jobs( project_id=TEST_PROJECT_ID, @@ -415,7 +435,7 @@ def test_list_pipeline_jobs(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("VertexAIHook.get_pipeline_service_client")) + @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_training_pipelines(self, mock_client) -> None: self.hook.list_training_pipelines( project_id=TEST_PROJECT_ID, From c7310c8610d2600f59f1b5c51e71935cb788f6bf Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Wed, 10 Nov 2021 16:31:29 +0000 Subject: [PATCH 04/20] Create CustomJob operators for VertexAI --- .../google/cloud/operators/vertex_ai.py | 1415 ----------------- .../cloud/operators/vertex_ai/__init__.py | 16 + .../cloud/operators/vertex_ai/custom_job.py | 477 ++++++ .../google/cloud/operators/test_vertex_ai.py | 361 +++++ 4 files changed, 854 insertions(+), 1415 deletions(-) delete mode 100644 airflow/providers/google/cloud/operators/vertex_ai.py create mode 100644 airflow/providers/google/cloud/operators/vertex_ai/__init__.py create mode 100644 airflow/providers/google/cloud/operators/vertex_ai/custom_job.py create mode 100644 tests/providers/google/cloud/operators/test_vertex_ai.py diff --git 
a/airflow/providers/google/cloud/operators/vertex_ai.py b/airflow/providers/google/cloud/operators/vertex_ai.py deleted file mode 100644 index 23825ff5be84b..0000000000000 --- a/airflow/providers/google/cloud/operators/vertex_ai.py +++ /dev/null @@ -1,1415 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""This module contains Google Vertex AI operators.""" - -from typing import Dict, List, Optional, Sequence, Tuple, Union - -from google.api_core.retry import Retry -from google.cloud.aiplatform import datasets, initializer, schema, utils -from google.cloud.aiplatform.utils import _timestamped_gcs_dir, worker_spec_utils -from google.cloud.aiplatform_v1.types import ( - BigQueryDestination, - CancelPipelineJobRequest, - CancelTrainingPipelineRequest, - CreatePipelineJobRequest, - CreateTrainingPipelineRequest, - DeletePipelineJobRequest, - DeleteTrainingPipelineRequest, - EnvVar, - FilterSplit, - FractionSplit, - GcsDestination, - GetPipelineJobRequest, - GetTrainingPipelineRequest, - InputDataConfig, - ListPipelineJobsRequest, - ListTrainingPipelinesRequest, - Model, - ModelContainerSpec, - PipelineJob, - Port, - PredefinedSplit, - PredictSchemata, - TimestampSplit, - TrainingPipeline, -) - -from airflow.models import BaseOperator -from airflow.providers.google.cloud.hooks.vertex_ai import VertexAIHook - - -class VertexAITrainingJobBaseOperator(BaseOperator): - """The base class for operators that launch job on VertexAI.""" - - def __init__( - self, - *, - region: str = None, - project_id: str, - display_name: str, - # START Run param - dataset: Optional[ - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ] = None, - annotation_schema_uri: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - bigquery_destination: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - environment_variables: Optional[Dict[str, str]] = None, - replica_count: int = 1, - machine_type: str = "n1-standard-4", - accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", - accelerator_count: int = 0, - boot_disk_type: str = "pd-ssd", - boot_disk_size_gb: int = 100, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - tensorboard: Optional[str] = None, - sync=True, - # 
END Run param - # START Custom - container_uri: str, - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, - # END Custom - retry: Optional[Retry] = None, - timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = "", - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.region = region - self.project_id = project_id - self.display_name = display_name - # START Run param - self.dataset = dataset - self.annotation_schema_uri = annotation_schema_uri - self.model_display_name = model_display_name - self.model_labels = model_labels - self.base_output_dir = base_output_dir - self.service_account = service_account - self.network = network - self.bigquery_destination = bigquery_destination - self.args = args - self.environment_variables = environment_variables - self.replica_count = replica_count - self.machine_type = machine_type - self.accelerator_type = accelerator_type - self.accelerator_count = accelerator_count - self.boot_disk_type = boot_disk_type - self.boot_disk_size_gb = boot_disk_size_gb - self.training_fraction_split = training_fraction_split - self.validation_fraction_split = validation_fraction_split - self.test_fraction_split = test_fraction_split - self.training_filter_split = training_filter_split - self.validation_filter_split = validation_filter_split - self.test_filter_split = test_filter_split - self.predefined_split_column_name = predefined_split_column_name - self.timestamp_split_column_name = timestamp_split_column_name - self.tensorboard = tensorboard - self.sync = sync - # TODO: add optional and important parameters - # END Run param - # START Custom - self._container_uri = container_uri - - model_predict_schemata = None - if any( - [ - model_instance_schema_uri, - model_parameters_schema_uri, - model_prediction_schema_uri, - ] - ): - model_predict_schemata = PredictSchemata( - instance_schema_uri=model_instance_schema_uri, - parameters_schema_uri=model_parameters_schema_uri, - prediction_schema_uri=model_prediction_schema_uri, - ) - - # Create the container spec - env = None - ports = None - - if model_serving_container_environment_variables: - env = [ - EnvVar(name=str(key), value=str(value)) - for key, value in model_serving_container_environment_variables.items() - ] - - if model_serving_container_ports: - ports = [Port(container_port=port) for port in model_serving_container_ports] - - container_spec = ModelContainerSpec( - image_uri=model_serving_container_image_uri, - command=model_serving_container_command, - args=model_serving_container_args, - env=env, - ports=ports, - predict_route=model_serving_container_predict_route, - 
health_route=model_serving_container_health_route, - ) - - self._model_encryption_spec = initializer.global_config.get_encryption_spec( - encryption_spec_key_name=model_encryption_spec_key_name - ) - - self._managed_model = Model( - description=model_description, - predict_schemata=model_predict_schemata, - container_spec=container_spec, - encryption_spec=self._model_encryption_spec, - ) - - self.labels = labels - self._training_encryption_spec = initializer.global_config.get_encryption_spec( - encryption_spec_key_name=training_encryption_spec_key_name - ) - - self._staging_bucket = staging_bucket or initializer.global_config.staging_bucket - # END Custom - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.hook = VertexAIHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain) - self.training_pipeline: Optional[TrainingPipeline] = None - self.worker_pool_specs: worker_spec_utils._DistributedTrainingSpec = None - self.managed_model: Optional[Model] = None - - def _prepare_and_validate_run( - self, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - replica_count: int = 1, - machine_type: str = "n1-standard-4", - accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", - accelerator_count: int = 0, - boot_disk_type: str = "pd-ssd", - boot_disk_size_gb: int = 100, - ) -> Tuple[worker_spec_utils._DistributedTrainingSpec, Optional[Model]]: - """Create worker pool specs and managed model as well validating the - run. - - Args: - model_display_name (str): - If the script produces a managed Vertex AI Model. The display name of - the Model. The name can be up to 128 characters long and can be consist - of any UTF-8 characters. - - If not provided upon creation, the job's display_name is used. - model_labels (Dict[str, str]): - Optional. The labels with user-defined metadata to - organize your Models. - Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only - contain lowercase letters, numeric characters, - underscores and dashes. International characters - are allowed. - See https://goo.gl/xmQnxf for more information - and examples of labels. - replica_count (int): - The number of worker replicas. If replica count = 1 then one chief - replica will be provisioned. If replica_count > 1 the remainder will be - provisioned as a worker replica pool. - machine_type (str): - The type of machine to use for training. - accelerator_type (str): - Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, - NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4 - accelerator_count (int): - The number of accelerators to attach to a worker replica. - boot_disk_type (str): - Type of the boot disk, default is `pd-ssd`. - Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or - `pd-standard` (Persistent Disk Hard Disk Drive). - boot_disk_size_gb (int): - Size in GB of the boot disk, default is 100GB. - boot disk size must be within the range of [100, 64000]. - Returns: - Worker pools specs and managed model for run. - - Raises: - RuntimeError: If Training job has already been run or model_display_name was - provided but required arguments were not provided in constructor. 
- """ - # TODO: Maybe not need - # if self._is_waiting_to_run(): - # raise RuntimeError("Custom Training is already scheduled to run.") - - # TODO: Maybe not need - # if self._has_run: - # raise RuntimeError("Custom Training has already run.") - - # if args needed for model is incomplete - if model_display_name and not self._managed_model.container_spec.image_uri: - raise RuntimeError( - """model_display_name was provided but - model_serving_container_image_uri was not provided when this - custom pipeline was constructed. - """ - ) - - if self._managed_model.container_spec.image_uri: - model_display_name = model_display_name or self._display_name + "-model" - - # validates args and will raise - worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( - replica_count=replica_count, - machine_type=machine_type, - accelerator_count=accelerator_count, - accelerator_type=accelerator_type, - boot_disk_type=boot_disk_type, - boot_disk_size_gb=boot_disk_size_gb, - ).pool_specs - - managed_model = self._managed_model - if model_display_name: - utils.validate_display_name(model_display_name) - managed_model.display_name = model_display_name - if model_labels: - utils.validate_labels(model_labels) - managed_model.labels = model_labels - else: - managed_model.labels = self.labels - else: - managed_model = None - - return worker_pool_specs, managed_model - - def _prepare_training_task_inputs_and_output_dir( - self, - worker_pool_specs: worker_spec_utils._DistributedTrainingSpec, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - tensorboard: Optional[str] = None, - ) -> Tuple[Dict, str]: - """Prepares training task inputs and output directory for custom job. - - Args: - worker_pools_spec (worker_spec_utils._DistributedTrainingSpec): - Worker pools pecs required to run job. - base_output_dir (str): - GCS output directory of job. If not provided a - timestamped directory in the staging directory will be used. - service_account (str): - Specifies the service account for workload run-as account. - Users submitting jobs must have act-as permission on this run-as account. - network (str): - The full name of the Compute Engine network to which the job - should be peered. For example, projects/12345/global/networks/myVPC. - Private services access must already be configured for the network. - If left unspecified, the job is not peered with any network. - tensorboard (str): - Optional. The name of a Vertex AI - [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] - resource to which this CustomJob will upload Tensorboard - logs. Format: - ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` - - The training script should write Tensorboard to following Vertex AI environment - variable: - - AIP_TENSORBOARD_LOG_DIR - - `service_account` is required with provided `tensorboard`. - For more information on configuring your service account please visit: - https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - Returns: - Training task inputs and Output directory for custom job. 
- """ - # default directory if not given - base_output_dir = base_output_dir or _timestamped_gcs_dir( - self._staging_bucket, "aiplatform-custom-training" - ) - - self.log.info(f"Training Output directory:\n{base_output_dir} ") - - training_task_inputs = { - "worker_pool_specs": worker_pool_specs, - "base_output_directory": {"output_uri_prefix": base_output_dir}, - } - - if service_account: - training_task_inputs["service_account"] = service_account - if network: - training_task_inputs["network"] = network - if tensorboard: - training_task_inputs["tensorboard"] = tensorboard - - return training_task_inputs, base_output_dir - - def _create_input_data_config( - self, - dataset: Optional[datasets._Dataset] = None, - annotation_schema_uri: Optional[str] = None, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - gcs_destination_uri_prefix: Optional[str] = None, - bigquery_destination: Optional[str] = None, - ) -> Optional[InputDataConfig]: - """Constructs a input data config to pass to the training pipeline. - - Args: - dataset (datasets._Dataset): - The dataset within the same Project from which data will be used to train the Model. The - Dataset must use schema compatible with Model being trained, - and what is compatible should be described in the used - TrainingPipeline's [training_task_definition] - [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. - For tabular Datasets, all their data is exported to - training, to pick and choose from. - annotation_schema_uri (str): - Google Cloud Storage URI points to a YAML file describing - annotation schema. The schema is defined as an OpenAPI 3.0.2 - [Schema Object] - (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) - The schema files that can be used here are found in - gs://google-cloud-aiplatform/schema/dataset/annotation/, - note that the chosen schema must be consistent with - ``metadata`` - of the Dataset specified by - ``dataset_id``. - - Only Annotations that both match this schema and belong to - DataItems not ignored by the split method are used in - respectively training, validation or test role, depending on - the role of the DataItem they are on. - - When used in conjunction with - ``annotations_filter``, - the Annotations used for training are filtered by both - ``annotations_filter`` - and - ``annotation_schema_uri``. - training_fraction_split (float): - Optional. The fraction of the input data that is to be used to train - the Model. This is ignored if Dataset is not provided. - validation_fraction_split (float): - Optional. The fraction of the input data that is to be used to validate - the Model. This is ignored if Dataset is not provided. - test_fraction_split (float): - Optional. The fraction of the input data that is to be used to evaluate - the Model. This is ignored if Dataset is not provided. - training_filter_split (str): - Optional. A filter on DataItems of the Dataset. DataItems that match - this filter are used to train the Model. A filter with same syntax - as the one used in DatasetService.ListDataItems may be used. 
If a - single DataItem is matched by more than one of the FilterSplit filters, - then it is assigned to the first set that applies to it in the training, - validation, test order. This is ignored if Dataset is not provided. - validation_filter_split (str): - Optional. A filter on DataItems of the Dataset. DataItems that match - this filter are used to validate the Model. A filter with same syntax - as the one used in DatasetService.ListDataItems may be used. If a - single DataItem is matched by more than one of the FilterSplit filters, - then it is assigned to the first set that applies to it in the training, - validation, test order. This is ignored if Dataset is not provided. - test_filter_split (str): - Optional. A filter on DataItems of the Dataset. DataItems that match - this filter are used to test the Model. A filter with same syntax - as the one used in DatasetService.ListDataItems may be used. If a - single DataItem is matched by more than one of the FilterSplit filters, - then it is assigned to the first set that applies to it in the training, - validation, test order. This is ignored if Dataset is not provided. - predefined_split_column_name (str): - Optional. The key is a name of one of the Dataset's data - columns. The value of the key (either the label's value or - value in the column) must be one of {``training``, - ``validation``, ``test``}, and it defines to which set the - given piece of data is assigned. If for a piece of data the - key is not present or has an invalid value, that piece is - ignored by the pipeline. - - Supported only for tabular and time series Datasets. - timestamp_split_column_name (str): - Optional. The key is a name of one of the Dataset's data - columns. The value of the key values of the key (the values in - the column) must be in RFC 3339 `date-time` format, where - `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a - piece of data the key is not present or has an invalid value, - that piece is ignored by the pipeline. - - Supported only for tabular and time series Datasets. - This parameter must be used with training_fraction_split, - validation_fraction_split and test_fraction_split. - gcs_destination_uri_prefix (str): - Optional. The Google Cloud Storage location. - - The Vertex AI environment variables representing Google - Cloud Storage data URIs will always be represented in the - Google Cloud Storage wildcard format to support sharded - data. - - - AIP_DATA_FORMAT = "jsonl". - - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" - - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" - - AIP_TEST_DATA_URI = "gcs_destination/test-*". - bigquery_destination (str): - The BigQuery project location where the training data is to - be written to. In the given project a new dataset is created - with name - ``dataset___`` - where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All - training input data will be written into that dataset. In - the dataset three tables will be created, ``training``, - ``validation`` and ``test``. - - - AIP_DATA_FORMAT = "bigquery". - - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - Raises: - ValueError: When more than 1 type of split configuration is passed or when - the split configuration passed is incompatible with the dataset schema. 
- """ - input_data_config = None - if dataset: - # Initialize all possible splits - filter_split = None - predefined_split = None - timestamp_split = None - fraction_split = None - - # Create filter split - if any( - [ - training_filter_split is not None, - validation_filter_split is not None, - test_filter_split is not None, - ] - ): - if all( - [ - training_filter_split is not None, - validation_filter_split is not None, - test_filter_split is not None, - ] - ): - filter_split = FilterSplit( - training_filter=training_filter_split, - validation_filter=validation_filter_split, - test_filter=test_filter_split, - ) - else: - raise ValueError("All filter splits must be passed together or not at all") - - # Create predefined split - if predefined_split_column_name: - predefined_split = PredefinedSplit(key=predefined_split_column_name) - - # Create timestamp split or fraction split - if timestamp_split_column_name: - timestamp_split = TimestampSplit( - training_fraction=training_fraction_split, - validation_fraction=validation_fraction_split, - test_fraction=test_fraction_split, - key=timestamp_split_column_name, - ) - elif any( - [ - training_fraction_split is not None, - validation_fraction_split is not None, - test_fraction_split is not None, - ] - ): - fraction_split = FractionSplit( - training_fraction=training_fraction_split, - validation_fraction=validation_fraction_split, - test_fraction=test_fraction_split, - ) - - splits = [ - split - for split in [ - filter_split, - predefined_split, - timestamp_split_column_name, - fraction_split, - ] - if split is not None - ] - - # Fallback to fraction split if nothing else is specified - if len(splits) == 0: - self.log.info("No dataset split provided. The service will use a default split.") - elif len(splits) > 1: - raise ValueError( - """Can only specify one of: - 1. training_filter_split, validation_filter_split, test_filter_split - 2. predefined_split_column_name - 3. timestamp_split_column_name, training_fraction_split, validation_fraction_split, - test_fraction_split - 4. 
training_fraction_split, validation_fraction_split, test_fraction_split""" - ) - - # create GCS destination - gcs_destination = None - if gcs_destination_uri_prefix: - gcs_destination = GcsDestination(output_uri_prefix=gcs_destination_uri_prefix) - - # TODO(b/177416223) validate managed BQ dataset is passed in - bigquery_destination_proto = None - if bigquery_destination: - bigquery_destination_proto = BigQueryDestination(output_uri=bigquery_destination) - - # create input data config - input_data_config = InputDataConfig( - fraction_split=fraction_split, - filter_split=filter_split, - predefined_split=predefined_split, - timestamp_split=timestamp_split, - dataset_id=dataset.name, - annotation_schema_uri=annotation_schema_uri, - gcs_destination=gcs_destination, - bigquery_destination=bigquery_destination_proto, - ) - - return input_data_config - - def _get_model(self, training_pipeline): - # TODO: implement logic for extract model from training_pipeline object - pass - - def execute(self, context): - (training_task_inputs, base_output_dir,) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs=self.worker_pool_specs, - base_output_dir=self.base_output_dir, - service_account=self.service_account, - network=self.network, - tensorboard=self.tensorboard, - ) - - input_data_config = self._create_input_data_config( - dataset=self.dataset, - annotation_schema_uri=self.annotation_schema_uri, - training_fraction_split=self.training_fraction_split, - validation_fraction_split=self.validation_fraction_split, - test_fraction_split=self.test_fraction_split, - training_filter_split=self.training_filter_split, - validation_filter_split=self.validation_filter_split, - test_filter_split=self.test_filter_split, - predefined_split_column_name=self.predefined_split_column_name, - timestamp_split_column_name=self.timestamp_split_column_name, - gcs_destination_uri_prefix=base_output_dir, - bigquery_destination=self.bigquery_destination, - ) - - # create training pipeline configuration object - training_pipeline = TrainingPipeline( - display_name=self.display_name, - training_task_definition=schema.training_job.definition.custom_task, # TODO: different for automl - training_task_inputs=training_task_inputs, # Required - model_to_upload=self.managed_model, # Optional - input_data_config=input_data_config, # Optional - labels=self.labels, # Optional - encryption_spec=self._training_encryption_spec, # Optional - ) - - self.training_pipeline = self.hook.create_training_pipeline( - project_id=self.project_id, - region=self.region, - training_pipeline=training_pipeline, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - model = self._get_model(self.training_pipeline) - - return model - - def on_kill(self) -> None: - """ - Callback called when the operator is killed. - Cancel any running job. 
- """ - if self.training_pipeline: - self.hook.cancel_training_pipeline( - project_id=self.project_id, - region=self.region, - training_pipeline=self.training_pipeline.name, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - - -class VertexAICreateCustomContainerTrainingJobOperator(VertexAITrainingJobBaseOperator): - """Create Custom Container Training job""" - - template_fields = [ - 'region', - 'impersonation_chain', - ] - - def __init__( - self, - *, - command: Sequence[str] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self._command = command - - def execute(self, context): - self.worker_pool_specs, self.managed_model = self._prepare_and_validate_run( - model_display_name=self.model_display_name, - model_labels=self.model_labels, - replica_count=self.replica_count, - machine_type=self.machine_type, - accelerator_count=self.accelerator_count, - accelerator_type=self.accelerator_type, - boot_disk_type=self.boot_disk_type, - boot_disk_size_gb=self.boot_disk_size_gb, - ) - - for spec in self.worker_pool_specs: - spec["containerSpec"] = {"imageUri": self._container_uri} - - if self._command: - spec["containerSpec"]["command"] = self._command - - if self.args: - spec["containerSpec"]["args"] = self.args - - if self.environment_variables: - spec["containerSpec"]["env"] = [ - {"name": key, "value": value} for key, value in self.environment_variables.items() - ] - - super().execute(context) - - -class VertexAICreateCustomPythonPackageTrainingJobOperator(VertexAITrainingJobBaseOperator): - """Create Custom Python Package Training job""" - - template_fields = [ - 'region', - 'impersonation_chain', - ] - - def __init__( - self, - python_package_gcs_uri: str, - python_module_name: str, - ) -> None: - self._package_gcs_uri = python_package_gcs_uri - self._python_module = python_module_name - - def execute(self, context): - self.worker_pool_specs, self.managed_model = self._prepare_and_validate_run( - model_display_name=self.model_display_name, - model_labels=self.model_labels, - replica_count=self.replica_count, - machine_type=self.machine_type, - accelerator_count=self.accelerator_count, - accelerator_type=self.accelerator_type, - boot_disk_type=self.boot_disk_type, - boot_disk_size_gb=self.boot_disk_size_gb, - ) - - for spec in self.worker_pool_specs: - spec["python_package_spec"] = { - "executor_image_uri": self._container_uri, - "python_module": self._python_module, - "package_uris": [self._package_gcs_uri], - } - - if self.args: - spec["python_package_spec"]["args"] = self.args - - if self.environment_variables: - spec["python_package_spec"]["env"] = [ - {"name": key, "value": value} for key, value in self.environment_variables.items() - ] - - super().execute(context) - - -class VertexAICreateCustomTrainingJobOperator(VertexAITrainingJobBaseOperator): - """Create Custom Training job""" - - def __init__( - self, - display_name: str, - script_path: str, - ) -> None: - pass - - def execute(self, context): - self.worker_pool_specs, self.managed_model = self._prepare_and_validate_run( - model_display_name=self.model_display_name, - model_labels=self.model_labels, - replica_count=self.replica_count, - machine_type=self.machine_type, - accelerator_count=self.accelerator_count, - accelerator_type=self.accelerator_type, - boot_disk_type=self.boot_disk_type, - boot_disk_size_gb=self.boot_disk_size_gb, - ) - super().execute(context) - - -class VertexAICancelPipelineJobOperator(BaseOperator): - """ - Cancels a PipelineJob. 
Starts asynchronous cancellation on the PipelineJob. The server makes a best effort - to cancel the pipeline, but success is not guaranteed. Clients can use - [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other - methods to check whether the cancellation succeeded or whether the pipeline completed despite - cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a pipeline - with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a - [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and - [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. - - :param request: The request object. Request message for - [PipelineService.CancelPipelineJob][google.cloud.aiplatform.v1.PipelineService.CancelPipelineJob]. - :type request: Union[google.cloud.aiplatform_v1.types.CancelPipelineJobRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param pipeline_job: TODO: Fill description - :type pipeline_job: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[CancelPipelineJobRequest, Dict], - location: str, - pipeline_job: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.pipeline_job = pipeline_job - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.cancel_pipeline_job( - request=self.request, - location=self.location, - pipeline_job=self.pipeline_job, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAICancelTrainingPipelineOperator(BaseOperator): - """ - Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a - best effort to cancel the pipeline, but success is not guaranteed. Clients can use - [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline] or - other methods to check whether the cancellation succeeded or whether the pipeline completed despite - cancellation. On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a - pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with a - [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and - [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. - - :param request: The request object. Request message for [PipelineService.CancelTrainingPipeline][google.c - loud.aiplatform.v1.PipelineService.CancelTrainingPipeline]. 
- :type request: Union[google.cloud.aiplatform_v1.types.CancelTrainingPipelineRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param training_pipeline: TODO: Fill description - :type training_pipeline: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[CancelTrainingPipelineRequest, Dict], - location: str, - training_pipeline: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.training_pipeline = training_pipeline - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.cancel_training_pipeline( - request=self.request, - location=self.location, - training_pipeline=self.training_pipeline, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAICreatePipelineJobOperator(BaseOperator): - """ - Creates a PipelineJob. A PipelineJob will run immediately when created. - - :param request: The request object. Request message for - [PipelineService.CreatePipelineJob][google.cloud.aiplatform.v1.PipelineService.CreatePipelineJob]. - :type request: Union[google.cloud.aiplatform_v1.types.CreatePipelineJobRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param pipeline_job: Required. The PipelineJob to create. This corresponds to the ``pipeline_job`` field - on the ``request`` instance; if ``request`` is provided, this should not be set. - :type pipeline_job: google.cloud.aiplatform_v1.types.PipelineJob - :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of the - PipelineJob name. If not provided, an ID will be automatically generated. - - This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. - - This corresponds to the ``pipeline_job_id`` field on the ``request`` instance; if ``request`` is - provided, this should not be set. - :type pipeline_job_id: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. 
- :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[CreatePipelineJobRequest, Dict], - location: str, - pipeline_job: PipelineJob, - pipeline_job_id: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.pipeline_job = pipeline_job - self.pipeline_job_id = pipeline_job_id - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.create_pipeline_job( - request=self.request, - location=self.location, - pipeline_job=self.pipeline_job, - pipeline_job_id=self.pipeline_job_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAICreateTrainingPipelineOperator(BaseOperator): - """ - Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. - - :param request: The request object. Request message for [PipelineService.CreateTrainingPipeline][google.c - loud.aiplatform.v1.PipelineService.CreateTrainingPipeline]. - :type request: Union[google.cloud.aiplatform_v1.types.CreateTrainingPipelineRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param training_pipeline: Required. The TrainingPipeline to create. - - This corresponds to the ``training_pipeline`` field on the ``request`` instance; if ``request`` is - provided, this should not be set. - :type training_pipeline: google.cloud.aiplatform_v1.types.TrainingPipeline - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[CreateTrainingPipelineRequest, Dict], - location: str, - training_pipeline: TrainingPipeline, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.training_pipeline = training_pipeline - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.create_training_pipeline( - request=self.request, - location=self.location, - training_pipeline=self.training_pipeline, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAIDeletePipelineJobOperator(BaseOperator): - """ - Deletes a PipelineJob. - - :param request: The request object. Request message for - [PipelineService.DeletePipelineJob][google.cloud.aiplatform.v1.PipelineService.DeletePipelineJob]. 
- :type request: Union[google.cloud.aiplatform_v1.types.DeletePipelineJobRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param pipeline_job: TODO: Fill description - :type pipeline_job: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[DeletePipelineJobRequest, Dict], - location: str, - pipeline_job: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.pipeline_job = pipeline_job - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.delete_pipeline_job( - request=self.request, - location=self.location, - pipeline_job=self.pipeline_job, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAIDeleteTrainingPipelineOperator(BaseOperator): - """ - Deletes a TrainingPipeline. - - :param request: The request object. Request message for [PipelineService.DeleteTrainingPipeline][google.c - loud.aiplatform.v1.PipelineService.DeleteTrainingPipeline]. - :type request: Union[google.cloud.aiplatform_v1.types.DeleteTrainingPipelineRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param training_pipeline: TODO: Fill description - :type training_pipeline: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[DeleteTrainingPipelineRequest, Dict], - location: str, - training_pipeline: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.training_pipeline = training_pipeline - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.delete_training_pipeline( - request=self.request, - location=self.location, - training_pipeline=self.training_pipeline, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAIGetPipelineJobOperator(BaseOperator): - """ - Gets a PipelineJob. - - :param request: The request object. 
Request message for - [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob]. - :type request: Union[google.cloud.aiplatform_v1.types.GetPipelineJobRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param pipeline_job: TODO: Fill description - :type pipeline_job: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[GetPipelineJobRequest, Dict], - location: str, - pipeline_job: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.pipeline_job = pipeline_job - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.get_pipeline_job( - request=self.request, - location=self.location, - pipeline_job=self.pipeline_job, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAIGetTrainingPipelineOperator(BaseOperator): - """ - Gets a TrainingPipeline. - - :param request: The request object. Request message for - [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline]. - :type request: Union[google.cloud.aiplatform_v1.types.GetTrainingPipelineRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param training_pipeline: TODO: Fill description - :type training_pipeline: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[GetTrainingPipelineRequest, Dict], - location: str, - training_pipeline: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.training_pipeline = training_pipeline - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.get_training_pipeline( - request=self.request, - location=self.location, - training_pipeline=self.training_pipeline, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAIListPipelineJobsOperator(BaseOperator): - """ - Lists PipelineJobs in a Location. 
- - :param request: The request object. Request message for - [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs]. - :type request: Union[google.cloud.aiplatform_v1.types.ListPipelineJobsRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[ListPipelineJobsRequest, Dict], - location: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.list_pipeline_jobs( - request=self.request, - location=self.location, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) - - -class VertexAIListTrainingPipelinesOperator(BaseOperator): - """ - Lists TrainingPipelines in a Location. - - :param request: The request object. Request message for - [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelin - es]. - :type request: Union[google.cloud.aiplatform_v1.types.ListTrainingPipelinesRequest, Dict] - :param location: TODO: Fill description - :type location: str - :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry - :param timeout: The timeout for this request. - :type timeout: float - :param metadata: Strings which should be sent along with the request as metadata. 
- :type metadata: Sequence[Tuple[str, str]] - :param project_id: TODO: Fill description - :type project_id: str - :param gcp_conn_id: - :type gcp_conn_id: str - """ - - def __init__( - self, - request: Union[ListTrainingPipelinesRequest, Dict], - location: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], - project_id: str = None, - gcp_conn_id: str = "google_cloud_default", - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.request = request - self.location = location - self.retry = retry - self.timeout = timeout - self.metadata = metadata - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - - def execute(self, context: Dict): - hook = VertexAIHook(gcp_conn_id=self.gcp_conn_id) - hook.list_training_pipelines( - request=self.request, - location=self.location, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - project_id=self.project_id, - ) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/__init__.py b/airflow/providers/google/cloud/operators/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py new file mode 100644 index 0000000000000..05a3d756d8a08 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -0,0 +1,477 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
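The restructuring in this patch replaces the single operators/vertex_ai.py module deleted above with a vertex_ai/ package, mirroring what the updated tests already show for the hooks. A short sketch of the resulting import paths for downstream code follows; the hook path is taken from the updated tests, while the operator-side path is inferred from the new file layout (the concrete operator class names are defined later in the patch and are not assumed here).

# Hook import after the refactor, as used by the updated tests above:
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

# Operators move analogously: the flat module operators/vertex_ai.py is removed,
# and custom-job operators are now defined in operators/vertex_ai/custom_job.py,
# i.e. importable from:
#     airflow.providers.google.cloud.operators.vertex_ai.custom_job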
+# +"""This module contains Google Vertex AI operators.""" + +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core.retry import Retry +from google.cloud.aiplatform import datasets + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook + + +class _CustomTrainingJobBaseOperator(BaseOperator): + """The base class for operators that launch Custom jobs on VertexAI.""" + + def __init__( + self, + *, + project_id: str, + region: str, + display_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + # RUN + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + model_labels: Optional[Dict[str, str]] = None, + base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + tensorboard: Optional[str] = None, + sync=True, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.display_name = display_name + # START Custom + self.container_uri = container_uri + self.model_serving_container_image_uri = model_serving_container_image_uri + self.model_serving_container_predict_route = model_serving_container_predict_route + self.model_serving_container_health_route = model_serving_container_health_route + self.model_serving_container_command = model_serving_container_command + self.model_serving_container_args = model_serving_container_args + self.model_serving_container_environment_variables = model_serving_container_environment_variables + self.model_serving_container_ports = model_serving_container_ports + 
self.model_description = model_description + self.model_instance_schema_uri = model_instance_schema_uri + self.model_parameters_schema_uri = model_parameters_schema_uri + self.model_prediction_schema_uri = model_prediction_schema_uri + self.labels = labels + self.training_encryption_spec_key_name = training_encryption_spec_key_name + self.model_encryption_spec_key_name = model_encryption_spec_key_name + self.staging_bucket = staging_bucket + # END Custom + # START Run param + self.dataset = dataset + self.annotation_schema_uri = annotation_schema_uri + self.model_display_name = model_display_name + self.model_labels = model_labels + self.base_output_dir = base_output_dir + self.service_account = service_account + self.network = network + self.bigquery_destination = bigquery_destination + self.args = args + self.environment_variables = environment_variables + self.replica_count = replica_count + self.machine_type = machine_type + self.accelerator_type = accelerator_type + self.accelerator_count = accelerator_count + self.boot_disk_type = boot_disk_type + self.boot_disk_size_gb = boot_disk_size_gb + self.training_fraction_split = training_fraction_split + self.validation_fraction_split = validation_fraction_split + self.test_fraction_split = test_fraction_split + self.training_filter_split = training_filter_split + self.validation_filter_split = validation_filter_split + self.test_filter_split = test_filter_split + self.predefined_split_column_name = predefined_split_column_name + self.timestamp_split_column_name = timestamp_split_column_name + self.tensorboard = tensorboard + self.sync = sync + # END Run param + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook = CustomJobHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain) + + def on_kill(self) -> None: + """ + Callback called when the operator is killed. + Cancel any running job. 
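+        Both the training pipeline and its backing custom job are cancelled on a best-effort basis.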
+ """ + self.hook.cancel_training_pipeline( + project_id=self.project_id, + region=self.region, + training_pipeline=self.display_name, + ) + self.hook.cancel_custom_job( + project_id=self.project_id, region=self.region, custom_job=f"{self.display_name}-custom-job" + ) + + +class CreateCustomContainerTrainingJobOperator(_CustomTrainingJobBaseOperator): + """Create Custom Container Training job""" + + template_fields = [ + 'region', + 'command', + 'impersonation_chain', + ] + + def __init__( + self, + *, + command: Sequence[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.command = command + + def execute(self, context): + model = self.hook.create_custom_container_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + container_uri=self.container_uri, + command=self.command, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + sync=True, + ) + return model + + +class CreateCustomPythonPackageTrainingJobOperator(_CustomTrainingJobBaseOperator): + """Create Custom Python Package Training job""" + + template_fields = [ + 'region', + 'impersonation_chain', + ] + + def __init__( + self, + *, + python_package_gcs_uri: str, + python_module_name: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.python_package_gcs_uri = python_package_gcs_uri + self.python_module_name = python_module_name + + def execute(self, context): + model = self.hook.create_custom_python_package_training_job( + project_id=self.project_id, + 
region=self.region, + display_name=self.display_name, + python_package_gcs_uri=self.python_package_gcs_uri, + python_module_name=self.python_module_name, + container_uri=self.container_uri, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + sync=True, + ) + return model + + +class CreateCustomTrainingJobOperator(_CustomTrainingJobBaseOperator): + """Create Custom Training job""" + + template_fields = [ + 'region', + 'script_path', + 'requirements', + 'impersonation_chain', + ] + + def __init__( + self, + *, + script_path: str, + requirements: Optional[Sequence[str]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.requirements = requirements + self.script_path = script_path + + def execute(self, context): + model = self.hook.create_custom_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + script_path=self.script_path, + container_uri=self.container_uri, + requirements=self.requirements, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + 
model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=self.dataset, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + sync=True, + ) + return model + + +class DeleteCustomTrainingJobOperator(BaseOperator): + """Deletes a CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob.""" + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + training_pipeline: str, + region: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.training_pipeline = training_pipeline + self.region = region + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CustomJobHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Deleting custom training job: %s", self.training_pipeline) + hook.delete_training_pipeline( + training_pipeline=self.training_pipeline, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.delete_custom_job( + custom_job=f"{self.training_pipeline}-custom-job", + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Custom training job deleted.") + + +class ListCustomTrainingJobOperator(BaseOperator): + """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location.""" + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + filter: Optional[str] = 
None, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.page_size = page_size + self.page_token = page_token + self.filter = filter + self.read_mask = read_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CustomJobHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook.list_training_pipelines( + region=self.region, + project_id=self.project_id, + page_size=self.page_size, + page_token=self.page_token, + filter=self.filter, + read_mask=self.read_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py new file mode 100644 index 0000000000000..34a3afcf95a6e --- /dev/null +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -0,0 +1,361 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
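The custom-job operators defined above are thin wrappers around ``CustomJobHook``, and the tests in this file assert exactly that delegation. For reference, a minimal DAG sketch wiring two of the operators together; every identifier below (project, region, bucket, image, pipeline ID) is illustrative and not taken from this patch:

from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomContainerTrainingJobOperator,
    DeleteCustomTrainingJobOperator,
)

with models.DAG(
    "example_vertex_ai_custom_job",          # illustrative DAG id
    schedule_interval="@once",
    start_date=datetime(2021, 1, 1),
) as dag:
    create_job = CreateCustomContainerTrainingJobOperator(
        task_id="create_custom_container_training_job",
        project_id="my-project",                           # illustrative
        region="us-central1",                              # illustrative
        display_name="train-housing-container",            # illustrative
        container_uri="gcr.io/my-project/trainer:latest",  # illustrative training image
        command=["python3", "task.py"],                    # illustrative entry point
        staging_bucket="gs://my-vertex-ai-bucket",         # illustrative
    )

    delete_job = DeleteCustomTrainingJobOperator(
        task_id="delete_custom_training_job",
        training_pipeline="1234567890",                    # illustrative TrainingPipeline ID
        project_id="my-project",
        region="us-central1",
    )

    # The delete step assumes the pipeline ID is already known; it is not pushed by the create task.
    create_job >> delete_job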
+ +from unittest import mock + +from google.api_core.retry import Retry + +from airflow.providers.google.cloud.operators.vertex_ai.custom_job import ( + CreateCustomContainerTrainingJobOperator, + CreateCustomPythonPackageTrainingJobOperator, + CreateCustomTrainingJobOperator, + DeleteCustomTrainingJobOperator, + ListCustomTrainingJobOperator, +) + +VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}" +TIMEOUT = 120 +RETRY = mock.MagicMock(Retry) +METADATA = [("key", "value")] + +TASK_ID = "test_task_id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "test-location" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +STAGING_BUCKET = "gs://test-vertex-ai-bucket" +DISPLAY_NAME = "display_name_1" # Create random display name +DISPLAY_NAME_2 = "display_nmae_2" +ARGS = ["--tfds", "tf_flowers:3.*.*"] +CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" +REPLICA_COUNT = 1 +MACHINE_TYPE = "n1-standard-4" +ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" +ACCELERATOR_COUNT = 0 +TRAINING_FRACTION_SPLIT = 0.7 +TEST_FRACTION_SPLIT = 0.15 +VALIDATION_FRACTION_SPLIT = 0.15 +COMMAND_2 = ['echo', 'Hello World'] + +TEST_API_ENDPOINT: str = "test-api-endpoint" +TEST_PIPELINE_JOB: str = "test-pipeline-job" +TEST_TRAINING_PIPELINE: str = "test-training-pipeline" +TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id" + +PYTHON_PACKAGE = "/files/trainer-0.1.tar.gz" +PYTHON_PACKAGE_CMDARGS = "test-python-cmd" +PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz" +PYTHON_MODULE_NAME = "trainer.task" + + +class TestVertexAICreateCustomContainerTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = CreateCustomContainerTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_custom_container_training_job.assert_called_once_with( + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + dataset=None, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model_serving_container_predict_route=None, + model_serving_container_health_route=None, + model_serving_container_command=None, + model_serving_container_args=None, + model_serving_container_environment_variables=None, + model_serving_container_ports=None, + model_description=None, + model_instance_schema_uri=None, + 
model_parameters_schema_uri=None, + model_prediction_schema_uri=None, + labels=None, + training_encryption_spec_key_name=None, + model_encryption_spec_key_name=None, + # RUN + annotation_schema_uri=None, + model_labels=None, + base_output_dir=None, + service_account=None, + network=None, + bigquery_destination=None, + environment_variables=None, + boot_disk_type='pd-ssd', + boot_disk_size_gb=100, + training_filter_split=None, + validation_filter_split=None, + test_filter_split=None, + predefined_split_column_name=None, + timestamp_split_column_name=None, + tensorboard=None, + sync=True, + ) + + +class TestVertexAICreateCustomPythonPackageTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = CreateCustomPythonPackageTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + args=ARGS, + model_serving_container_image_uri=CONTAINER_URI, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with( + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + dataset=None, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model_serving_container_predict_route=None, + model_serving_container_health_route=None, + model_serving_container_command=None, + model_serving_container_args=None, + model_serving_container_environment_variables=None, + model_serving_container_ports=None, + model_description=None, + model_instance_schema_uri=None, + model_parameters_schema_uri=None, + model_prediction_schema_uri=None, + labels=None, + training_encryption_spec_key_name=None, + model_encryption_spec_key_name=None, + # RUN + annotation_schema_uri=None, + model_labels=None, + base_output_dir=None, + service_account=None, + network=None, + bigquery_destination=None, + environment_variables=None, + boot_disk_type='pd-ssd', + boot_disk_size_gb=100, + training_filter_split=None, + validation_filter_split=None, + test_filter_split=None, + predefined_split_column_name=None, + timestamp_split_column_name=None, + tensorboard=None, + sync=True, + ) + + +class TestVertexAICreateCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = CreateCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + 
impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + script_path=PYTHON_PACKAGE, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + requirements=[], + replica_count=1, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_custom_training_job.assert_called_once_with( + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + script_path=PYTHON_PACKAGE, + requirements=[], + dataset=None, + model_display_name=None, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=None, + validation_fraction_split=None, + test_fraction_split=None, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model_serving_container_predict_route=None, + model_serving_container_health_route=None, + model_serving_container_command=None, + model_serving_container_args=None, + model_serving_container_environment_variables=None, + model_serving_container_ports=None, + model_description=None, + model_instance_schema_uri=None, + model_parameters_schema_uri=None, + model_prediction_schema_uri=None, + labels=None, + training_encryption_spec_key_name=None, + model_encryption_spec_key_name=None, + # RUN + annotation_schema_uri=None, + model_labels=None, + base_output_dir=None, + service_account=None, + network=None, + bigquery_destination=None, + environment_variables=None, + boot_disk_type='pd-ssd', + boot_disk_size_gb=100, + training_filter_split=None, + validation_filter_split=None, + test_filter_split=None, + predefined_split_column_name=None, + timestamp_split_column_name=None, + tensorboard=None, + sync=True, + ) + + +class TestVertexAIDeleteCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + op = DeleteCustomTrainingJobOperator( + task_id=TASK_ID, + training_pipeline=DISPLAY_NAME, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.delete_training_pipeline.assert_called_once_with( + training_pipeline=DISPLAY_NAME, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + mock_hook.return_value.delete_custom_job.assert_called_once_with( + custom_job=f"{DISPLAY_NAME}-custom-job", + region=GCP_LOCATION, + project_id=GCP_PROJECT, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIListCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) + def test_execute(self, mock_hook): + page_token = "page_token" + page_size = 42 + filter = "filter" + read_mask = "read_mask" + + op = ListCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + page_size=page_size, + page_token=page_token, + filter=filter, + read_mask=read_mask, + retry=RETRY, + timeout=TIMEOUT, + 
metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.list_training_pipelines.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + page_size=page_size, + page_token=page_token, + filter=filter, + read_mask=read_mask, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) From 22787435e6d560ce5c37ca1573eba4375b81b4dd Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Tue, 16 Nov 2021 13:12:23 +0000 Subject: [PATCH 05/20] Create Dataset hooks for Vertex AI service --- .../cloud/hooks/vertex_ai/custom_job.py | 30 +- .../google/cloud/hooks/vertex_ai/dataset.py | 535 ++++++++++++++++++ .../google/cloud/hooks/vertex_ai/__init__.py | 16 + .../test_custom_job.py} | 46 +- .../cloud/hooks/vertex_ai/test_dataset.py | 494 ++++++++++++++++ 5 files changed, 1083 insertions(+), 38 deletions(-) create mode 100644 airflow/providers/google/cloud/hooks/vertex_ai/dataset.py create mode 100644 tests/providers/google/cloud/hooks/vertex_ai/__init__.py rename tests/providers/google/cloud/hooks/{test_vertex_ai.py => vertex_ai/test_custom_job.py} (89%) create mode 100644 tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index 78f7682cff6d6..b8f608d9d4926 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -397,9 +397,9 @@ def cancel_custom_job( project_id: str, region: str, custom_job: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> None: """ Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort @@ -529,9 +529,9 @@ def create_custom_job( project_id: str, region: str, custom_job: CustomJob, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> CustomJob: """ Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -1991,9 +1991,9 @@ def delete_custom_job( project_id: str, region: str, custom_job: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Operation: """ Deletes a CustomJob. @@ -2108,9 +2108,9 @@ def get_custom_job( project_id: str, region: str, custom_job: str, - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> CustomJob: """ Gets a CustomJob. @@ -2317,9 +2317,9 @@ def list_custom_jobs( page_token: Optional[str], filter: Optional[str], read_mask: Optional[str], - retry: Retry, - timeout: float, - metadata: Sequence[Tuple[str, str]], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> ListCustomJobsPager: """ Lists CustomJobs in a Location. 
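The signature changes above give the CustomJobHook CRUD methods sensible defaults, so callers no longer need to spell out retry and timeout. A minimal sketch of a direct lookup under those defaults; the project, region and job ID are illustrative, and metadata is passed as an empty tuple, which is the value the underlying API client ultimately expects:

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

# retry and timeout are omitted and fall back to their new None defaults.
job = hook.get_custom_job(
    project_id="my-project",    # illustrative
    region="us-central1",       # illustrative
    custom_job="1234567890",    # illustrative CustomJob ID
    metadata=(),
)
print(job.display_name, job.state)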
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py new file mode 100644 index 0000000000000..b8e96c771ff20 --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py @@ -0,0 +1,535 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a Google Cloud Vertex AI hook.""" + +from typing import Optional, Sequence, Tuple + +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.aiplatform_v1 import DatasetServiceClient +from google.cloud.aiplatform_v1.services.dataset_service.pagers import ( + ListAnnotationsPager, + ListDataItemsPager, + ListDatasetsPager, +) +from google.cloud.aiplatform_v1.types import AnnotationSpec, Dataset, ExportDataConfig, ImportDataConfig +from google.protobuf.field_mask_pb2 import FieldMask + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + + +class DatasetHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI Dataset APIs.""" + + def get_dataset_service_client(self, region: Optional[str] = None) -> DatasetServiceClient: + """Returns DatasetServiceClient.""" + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'} + + return DatasetServiceClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def wait_for_operation(self, timeout: float, operation: Operation): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @GoogleBaseHook.fallback_to_default_project_id + def create_dataset( + self, + project_id: str, + region: str, + dataset: Dataset, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Creates a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The Dataset to create. + :type dataset: google.cloud.aiplatform_v1.types.Dataset + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_dataset( + request={ + 'parent': parent, + 'dataset': dataset, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_dataset( + self, + project_id: str, + region: str, + dataset: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The ID of the Dataset to delete. + :type dataset: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.delete_dataset( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def export_data( + self, + project_id: str, + region: str, + dataset: str, + export_config: ExportDataConfig, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Exports data from a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The ID of the Dataset to export. + :type dataset: str + :param export_config: Required. The desired output location. + :type export_config: google.cloud.aiplatform_v1.types.ExportDataConfig + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.export_data( + request={ + 'name': name, + 'export_config': export_config, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_annotation_spec( + self, + project_id: str, + region: str, + dataset: str, + annotation_spec: str, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> AnnotationSpec: + """ + Gets an AnnotationSpec. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. 
The ID of the Dataset. + :type dataset: str + :param annotation_spec: The ID of the AnnotationSpec resource. + :type annotation_spec: str + :param read_mask: Optional. Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + name = client.annotation_spec_path(project_id, region, dataset, annotation_spec) + + result = client.get_annotation_spec( + request={ + 'name': name, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_dataset( + self, + project_id: str, + region: str, + dataset: str, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Dataset: + """ + Gets a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The ID of the Dataset to export. + :type dataset: str + :param read_mask: Optional. Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.get_dataset( + request={ + 'name': name, + 'read_mask': read_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def import_data( + self, + project_id: str, + region: str, + dataset: str, + import_configs: Sequence[ImportDataConfig], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Imports data into a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The ID of the Dataset to import. + :type dataset: str + :param import_configs: Required. The desired input locations. The contents of all input locations + will be imported in one batch. + :type import_configs: Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig] + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + name = client.dataset_path(project_id, region, dataset) + + result = client.import_data( + request={ + 'name': name, + 'import_configs': import_configs, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_annotations( + self, + project_id: str, + region: str, + dataset: str, + data_item: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListAnnotationsPager: + """ + Lists Annotations belongs to a dataitem + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The ID of the Dataset. + :type dataset: str + :param data_item: Required. The ID of the DataItem to list Annotations from. + :type data_item: str + :param filter: The standard list filter. + :type filter: str + :param page_size: The standard list page size. + :type page_size: int + :param page_token: The standard list page token. + :type page_token: str + :param read_mask: Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :type order_by: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + parent = client.data_item_path(project_id, region, dataset, data_item) + + result = client.list_annotations( + request={ + 'parent': parent, + 'filter': filter, + 'page_size': page_size, + 'page_token': page_token, + 'read_mask': read_mask, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_data_items( + self, + project_id: str, + region: str, + dataset: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListDataItemsPager: + """ + Lists DataItems in a Dataset. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The ID of the Dataset. + :type dataset: str + :param filter: The standard list filter. + :type filter: str + :param page_size: The standard list page size. + :type page_size: int + :param page_token: The standard list page token. + :type page_token: str + :param read_mask: Mask specifying which fields to read. 
+ :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :type order_by: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + parent = client.dataset_path(project_id, region, dataset) + + result = client.list_data_items( + request={ + 'parent': parent, + 'filter': filter, + 'page_size': page_size, + 'page_token': page_token, + 'read_mask': read_mask, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_datasets( + self, + project_id: str, + region: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListDatasetsPager: + """ + Lists Datasets in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: The standard list filter. + :type filter: str + :param page_size: The standard list page size. + :type page_size: int + :param page_token: The standard list page token. + :type page_token: str + :param read_mask: Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :type order_by: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_datasets( + request={ + 'parent': parent, + 'filter': filter, + 'page_size': page_size, + 'page_token': page_token, + 'read_mask': read_mask, + 'order_by': order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def update_dataset( + self, + region: str, + dataset: Dataset, + update_mask: FieldMask, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Dataset: + """ + Updates a Dataset. + + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The Dataset which replaces the resource on the server. + :type dataset: google.cloud.aiplatform_v1.types.Dataset + :param update_mask: Required. The update mask applies to the resource. For the ``FieldMask`` + definition, see [google.protobuf.FieldMask][google.protobuf.FieldMask]. 
+ Updatable fields: + - ``display_name`` + - ``description`` + - ``labels`` + :type update_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_dataset_service_client(region) + + result = client.update_dataset( + request={ + 'dataset': dataset, + 'update_mask': update_mask, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/tests/providers/google/cloud/hooks/vertex_ai/__init__.py b/tests/providers/google/cloud/hooks/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/google/cloud/hooks/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/google/cloud/hooks/test_vertex_ai.py b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py similarity index 89% rename from tests/providers/google/cloud/hooks/test_vertex_ai.py rename to tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py index 4fc21f93b401e..2b12fbf799e06 100644 --- a/tests/providers/google/cloud/hooks/test_vertex_ai.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py @@ -34,17 +34,17 @@ TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline" BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" -VERTEX_AI_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.custom_job.{}" +CUSTOM_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.custom_job.{}" -class TestVertexAIWithDefaultProjectIdHook(TestCase): +class TestCustomJobWithDefaultProjectIdHook(TestCase): def setUp(self): with mock.patch( BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id ): self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_pipeline_job(self, mock_client) -> None: self.hook.cancel_pipeline_job( project_id=TEST_PROJECT_ID, @@ -64,7 +64,7 @@ def test_cancel_pipeline_job(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_training_pipeline(self, mock_client) -> None: self.hook.cancel_training_pipeline( project_id=TEST_PROJECT_ID, @@ -84,7 
+84,7 @@ def test_cancel_training_pipeline(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_pipeline_job(self, mock_client) -> None: self.hook.create_pipeline_job( project_id=TEST_PROJECT_ID, @@ -105,7 +105,7 @@ def test_create_pipeline_job(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_training_pipeline(self, mock_client) -> None: self.hook.create_training_pipeline( project_id=TEST_PROJECT_ID, @@ -124,7 +124,7 @@ def test_create_training_pipeline(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_pipeline_job(self, mock_client) -> None: self.hook.delete_pipeline_job( project_id=TEST_PROJECT_ID, @@ -144,7 +144,7 @@ def test_delete_pipeline_job(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_training_pipeline(self, mock_client) -> None: self.hook.delete_training_pipeline( project_id=TEST_PROJECT_ID, @@ -164,7 +164,7 @@ def test_delete_training_pipeline(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_pipeline_job(self, mock_client) -> None: self.hook.get_pipeline_job( project_id=TEST_PROJECT_ID, @@ -184,7 +184,7 @@ def test_get_pipeline_job(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_training_pipeline(self, mock_client) -> None: self.hook.get_training_pipeline( project_id=TEST_PROJECT_ID, @@ -204,7 +204,7 @@ def test_get_training_pipeline(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_pipeline_jobs(self, mock_client) -> None: self.hook.list_pipeline_jobs( project_id=TEST_PROJECT_ID, @@ -225,7 +225,7 @@ def test_list_pipeline_jobs(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_training_pipelines(self, mock_client) -> None: self.hook.list_training_pipelines( project_id=TEST_PROJECT_ID, @@ -247,14 +247,14 @@ def 
test_list_training_pipelines(self, mock_client) -> None: mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) -class TestVertexAIWithoutDefaultProjectIdHook(TestCase): +class TestCustomJobWithoutDefaultProjectIdHook(TestCase): def setUp(self): with mock.patch( BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id ): self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_pipeline_job(self, mock_client) -> None: self.hook.cancel_pipeline_job( project_id=TEST_PROJECT_ID, @@ -274,7 +274,7 @@ def test_cancel_pipeline_job(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_cancel_training_pipeline(self, mock_client) -> None: self.hook.cancel_training_pipeline( project_id=TEST_PROJECT_ID, @@ -294,7 +294,7 @@ def test_cancel_training_pipeline(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_pipeline_job(self, mock_client) -> None: self.hook.create_pipeline_job( project_id=TEST_PROJECT_ID, @@ -315,7 +315,7 @@ def test_create_pipeline_job(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_create_training_pipeline(self, mock_client) -> None: self.hook.create_training_pipeline( project_id=TEST_PROJECT_ID, @@ -334,7 +334,7 @@ def test_create_training_pipeline(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_pipeline_job(self, mock_client) -> None: self.hook.delete_pipeline_job( project_id=TEST_PROJECT_ID, @@ -354,7 +354,7 @@ def test_delete_pipeline_job(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_delete_training_pipeline(self, mock_client) -> None: self.hook.delete_training_pipeline( project_id=TEST_PROJECT_ID, @@ -374,7 +374,7 @@ def test_delete_training_pipeline(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_pipeline_job(self, mock_client) -> None: self.hook.get_pipeline_job( project_id=TEST_PROJECT_ID, @@ -394,7 +394,7 @@ def test_get_pipeline_job(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID ) - 
@mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_get_training_pipeline(self, mock_client) -> None: self.hook.get_training_pipeline( project_id=TEST_PROJECT_ID, @@ -414,7 +414,7 @@ def test_get_training_pipeline(self, mock_client) -> None: TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME ) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_pipeline_jobs(self, mock_client) -> None: self.hook.list_pipeline_jobs( project_id=TEST_PROJECT_ID, @@ -435,7 +435,7 @@ def test_list_pipeline_jobs(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) - @mock.patch(VERTEX_AI_STRING.format("CustomJobHook.get_pipeline_service_client")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client")) def test_list_training_pipelines(self, mock_client) -> None: self.hook.list_training_pipelines( project_id=TEST_PROJECT_ID, diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py new file mode 100644 index 0000000000000..76efb94172f93 --- /dev/null +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py @@ -0,0 +1,494 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from unittest import TestCase, mock + +from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook +from tests.providers.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, +) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_REGION: str = "test-region" +TEST_PROJECT_ID: str = "test-project-id" +TEST_PIPELINE_JOB: dict = {} +TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id" +TEST_TRAINING_PIPELINE: dict = {} +TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline" +TEST_DATASET: dict = {} +TEST_DATASET_ID: str = "test-dataset-id" +TEST_EXPORT_CONFIG: dict = {} +TEST_ANNOTATION_SPEC: str = "test-annotation-spec" +TEST_IMPORT_CONFIGS: dict = {} +TEST_DATA_ITEM: str = "test-data-item" +TEST_UPDATE_MASK: dict = {} + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +DATASET_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.dataset.{}" + + +class TestVertexAIWithDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = DatasetHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_create_dataset(self, mock_client) -> None: + self.hook.create_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_dataset.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + dataset=TEST_DATASET, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_delete_dataset(self, mock_client) -> None: + self.hook.delete_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_export_data(self, mock_client) -> None: + self.hook.export_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.export_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + export_config=TEST_EXPORT_CONFIG, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_annotation_spec(self, mock_client) -> None: + self.hook.get_annotation_spec( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + annotation_spec=TEST_ANNOTATION_SPEC, + ) + mock_client.assert_called_once_with(TEST_REGION) + 
mock_client.return_value.get_annotation_spec.assert_called_once_with( + request=dict( + name=mock_client.return_value.annotation_spec_path.return_value, + read_mask=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.annotation_spec_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_ANNOTATION_SPEC + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_dataset(self, mock_client) -> None: + self.hook.get_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + read_mask=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_import_data(self, mock_client) -> None: + self.hook.import_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIGS, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.import_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + import_configs=TEST_IMPORT_CONFIGS, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_annotations(self, mock_client) -> None: + self.hook.list_annotations( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + data_item=TEST_DATA_ITEM, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_annotations.assert_called_once_with( + request=dict( + parent=mock_client.return_value.data_item_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.data_item_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_DATA_ITEM + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_data_items(self, mock_client) -> None: + self.hook.list_data_items( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_data_items.assert_called_once_with( + request=dict( + parent=mock_client.return_value.dataset_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_datasets(self, mock_client) -> None: + self.hook.list_datasets( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_datasets.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + 
filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_update_dataset(self, mock_client) -> None: + self.hook.update_dataset( + region=TEST_REGION, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.update_dataset.assert_called_once_with( + request=dict( + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ), + metadata=None, + retry=None, + timeout=None, + ) + + +class TestVertexAIWithoutDefaultProjectIdHook(TestCase): + def setUp(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id + ): + self.hook = DatasetHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_create_dataset(self, mock_client) -> None: + self.hook.create_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_dataset.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + dataset=TEST_DATASET, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_delete_dataset(self, mock_client) -> None: + self.hook.delete_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_export_data(self, mock_client) -> None: + self.hook.export_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.export_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + export_config=TEST_EXPORT_CONFIG, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_annotation_spec(self, mock_client) -> None: + self.hook.get_annotation_spec( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + annotation_spec=TEST_ANNOTATION_SPEC, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_annotation_spec.assert_called_once_with( + request=dict( + name=mock_client.return_value.annotation_spec_path.return_value, + read_mask=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.annotation_spec_path.assert_called_once_with( + 
TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_ANNOTATION_SPEC + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_get_dataset(self, mock_client) -> None: + self.hook.get_dataset( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_dataset.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + read_mask=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_import_data(self, mock_client) -> None: + self.hook.import_data( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIGS, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.import_data.assert_called_once_with( + request=dict( + name=mock_client.return_value.dataset_path.return_value, + import_configs=TEST_IMPORT_CONFIGS, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_annotations(self, mock_client) -> None: + self.hook.list_annotations( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + data_item=TEST_DATA_ITEM, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_annotations.assert_called_once_with( + request=dict( + parent=mock_client.return_value.data_item_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.data_item_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_DATA_ITEM + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_data_items(self, mock_client) -> None: + self.hook.list_data_items( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + dataset=TEST_DATASET_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_data_items.assert_called_once_with( + request=dict( + parent=mock_client.return_value.dataset_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) + + @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_list_datasets(self, mock_client) -> None: + self.hook.list_datasets( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_datasets.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + filter=None, + page_size=None, + page_token=None, + read_mask=None, + order_by=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + 
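For orientation, the client calls mocked in these tests correspond roughly to the direct hook usage below. This is a minimal sketch and not part of the patch: the project id, region, display name, and timeout are placeholders, and a real run needs valid Google Cloud credentials behind the chosen connection id.

from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
from google.cloud.aiplatform_v1.types import Dataset

hook = DatasetHook(gcp_conn_id="google_cloud_default")
# create_dataset returns a long-running operation; wait_for_operation blocks
# until it resolves and returns the created Dataset proto.
operation = hook.create_dataset(
    project_id="my-project",  # placeholder
    region="us-central1",     # placeholder
    dataset={
        "display_name": "example-image-dataset",
        "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
        "metadata": "example-image-dataset",
    },
)
result = hook.wait_for_operation(timeout=1800, operation=operation)  # timeout is a placeholder
# The dataset id is the trailing segment of the resource name, i.e.
# projects/<project>/locations/<region>/datasets/<dataset-id>.
dataset_id = Dataset.to_dict(result)["name"].rpartition("/")[-1]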
@mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) + def test_update_dataset(self, mock_client) -> None: + self.hook.update_dataset( + region=TEST_REGION, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.update_dataset.assert_called_once_with( + request=dict( + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + ), + metadata=None, + retry=None, + timeout=None, + ) From 48dbcdcfdd20a6b3242167143ce6d44b2eb23b1c Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Wed, 24 Nov 2021 09:04:54 +0000 Subject: [PATCH 06/20] Create Datasets operators fot Vertex AI --- .../cloud/example_dags/example_vertex_ai.py | 171 +++++- .../cloud/operators/vertex_ai/dataset.py | 487 ++++++++++++++++++ .../google/cloud/operators/test_vertex_ai.py | 210 ++++++++ .../cloud/operators/test_vertex_ai_system.py | 10 +- 4 files changed, 861 insertions(+), 17 deletions(-) create mode 100644 airflow/providers/google/cloud/operators/vertex_ai/dataset.py diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py index 6792171edbd95..2388b2b204d0e 100644 --- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -29,10 +29,18 @@ from uuid import uuid4 from airflow import models -from airflow.providers.google.cloud.operators.vertex_ai import ( - VertexAICreateCustomContainerTrainingJobOperator, - VertexAICreateCustomPythonPackageTrainingJobOperator, - VertexAICreateCustomTrainingJobOperator, +from airflow.providers.google.cloud.operators.vertex_ai.custom_job import ( + CreateCustomContainerTrainingJobOperator, + CreateCustomPythonPackageTrainingJobOperator, + CreateCustomTrainingJobOperator, +) +from airflow.providers.google.cloud.operators.vertex_ai.dataset import ( + CreateDatasetOperator, + DeleteDatasetOperator, + ExportDataOperator, + ImportDataOperator, + ListDatasetsOperator, + UpdateDatasetOperator, ) from airflow.utils.dates import days_ago @@ -40,10 +48,11 @@ PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") -REGION = os.environ.get("GCP_LOCATION", "europe-west1") +REGION = os.environ.get("GCP_LOCATION", "us-central1") BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests") STAGING_BUCKET = f"gs://{BUCKET}" +BASE_OUTPUT_DIR = f"{STAGING_BUCKET}/models" DISPLAY_NAME = str(uuid4()) # Create random display name DISPLAY_NAME_2 = str(uuid4()) DISPLAY_NAME_3 = str(uuid4()) @@ -62,17 +71,61 @@ # DATASET = aiplatform.ImageDataset(RESOURCE_ID) if RESOURCE_ID else None COMMAND = ['python3', 'run_script.py'] COMMAND_2 = ['echo', 'Hello World'] -PYTHON_PACKAGE_GCS_URI = "gs://bucket3/custom-training-python-package/my_app/trainer-0.1.tar.gz" +GCS_DESTINATION = f"gs://{BUCKET}/output-dir/" +PYTHON_PACKAGE = "/files/trainer-0.1.tar.gz" +PYTHON_PACKAGE_CMDARGS = f"--model-dir={GCS_DESTINATION}" +PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz" PYTHON_MODULE_NAME = "trainer.task" +IMAGE_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml", + "metadata": "test-image-dataset", +} +TABULAR_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml", + "metadata": "test-tabular-dataset", +} +TEXT_DATASET = { + "display_name": str(uuid4()), + 
"metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml", + "metadata": "test-text-dataset", +} +VIDEO_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml", + "metadata": "test-video-dataset", +} +TIME_SERIES_DATASET = { + "display_name": str(uuid4()), + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml", + "metadata": "test-video-dataset", +} +DATASET_ID = "3255741890774958080" +TEST_EXPORT_CONFIG = {"gcs_destination": {"output_uri_prefix": "gs://test-vertex-ai-bucket/exports"}} +TEST_IMPORT_CONFIG = [ + { + "data_item_labels": { + "test-labels-name": "test-labels-value", + }, + "import_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml", + "gcs_source": { + "uris": ["gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"] + }, + }, +] +TEST_UPDATE_MASK = { + "paths": "display_name", +} with models.DAG( - "example_gcp_vertex_ai", + "example_gcp_vertex_ai_custom_jobs", start_date=days_ago(1), schedule_interval="@once", -) as dag: +) as custom_jobs_dag: # [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator] - create_custom_container_training_job = VertexAICreateCustomContainerTrainingJobOperator( + create_custom_container_training_job = CreateCustomContainerTrainingJobOperator( task_id="custom_container_task", staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -80,7 +133,6 @@ container_uri=CONTAINER_URI, model_serving_container_image_uri=CONTAINER_URI, command=COMMAND_2, - # dataset=DATASET, model_display_name=DISPLAY_NAME_2, replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, @@ -91,11 +143,12 @@ test_fraction_split=TEST_FRACTION_SPLIT, region=REGION, project_id=PROJECT_ID, + base_output_dir=BASE_OUTPUT_DIR, ) # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator] # [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] - create_custom_python_package_training_job = VertexAICreateCustomPythonPackageTrainingJobOperator( + create_custom_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator( task_id="python_package_task", staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME_3, @@ -104,7 +157,6 @@ container_uri=CONTAINER_URI, args=ARGS, model_serving_container_image_uri=CONTAINER_URI, - # dataset=DATASET, model_display_name=DISPLAY_NAME_4, replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, @@ -119,10 +171,101 @@ # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] # [START how_to_cloud_vertex_ai_create_custom_training_job_operator] - create_custom_training_job = VertexAICreateCustomTrainingJobOperator( + create_custom_training_job = CreateCustomTrainingJobOperator( task_id="custom_task", - # TODO: add parameters from example + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + script_path=PYTHON_PACKAGE, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + requirements=[], + replica_count=1, region=REGION, project_id=PROJECT_ID, ) # [END how_to_cloud_vertex_ai_create_custom_training_job_operator] + +with models.DAG( + "example_gcp_vertex_ai_dataset", + start_date=days_ago(1), + schedule_interval="@once", +) as dataset_dag: + # [START how_to_cloud_vertex_ai_create_dataset_operator] + create_image_dataset_job = 
CreateDatasetOperator( + task_id="image_dataset", + dataset=IMAGE_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_tabular_dataset_job = CreateDatasetOperator( + task_id="tabular_dataset", + dataset=TABULAR_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_text_dataset_job = CreateDatasetOperator( + task_id="text_dataset", + dataset=TEXT_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_video_dataset_job = CreateDatasetOperator( + task_id="video_dataset", + dataset=VIDEO_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + create_time_series_dataset_job = CreateDatasetOperator( + task_id="time_series_dataset", + dataset=TIME_SERIES_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_create_dataset_operator] + + # # [START how_to_cloud_vertex_ai_delete_dataset_operator] + delete_dataset_job = DeleteDatasetOperator( + task_id="delete_dataset", + dataset_id=DATASET_ID, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_delete_dataset_operator] + + # [START how_to_cloud_vertex_ai_export_data_operator] + export_data_job = ExportDataOperator( + task_id="export_data", + dataset_id="7732319920381231104", + region=REGION, + project_id=PROJECT_ID, + export_config=TEST_EXPORT_CONFIG, + ) + # [END how_to_cloud_vertex_ai_export_datas_operator] + + # [START how_to_cloud_vertex_ai_import_data_operator] + import_data_job = ImportDataOperator( + task_id="import_data", + dataset_id="7732319920381231104", + region=REGION, + project_id=PROJECT_ID, + import_configs=TEST_IMPORT_CONFIG, + ) + # [END how_to_cloud_vertex_ai_import_data_operator] + + # [START how_to_cloud_vertex_ai_list_dataset_operator] + list_dataset_job = ListDatasetsOperator( + task_id="list_dataset", + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_list_dataset_operator] + + # [START how_to_cloud_vertex_ai_update_dataset_operator] + update_dataset_job = UpdateDatasetOperator( + task_id="update_dataset", + region=REGION, + dataset=TEXT_DATASET, + update_mask=TEST_UPDATE_MASK, + ) + # [END how_to_cloud_vertex_ai_update_dataset_operator] diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py new file mode 100644 index 0000000000000..5355e49078744 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
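Because CreateDatasetOperator.execute returns Dataset.to_dict(result), the created dataset (including its resource name) is available to downstream tasks via XCom. The snippet below is an illustrative sketch only, not part of this patch: the report_dataset task, its callable, and the dependency line are hypothetical additions inside the example_gcp_vertex_ai_dataset DAG shown above.

from airflow.operators.python import PythonOperator

def _report_dataset(ti):
    # Pull the dict returned by CreateDatasetOperator for task_id "image_dataset";
    # the resource name has the form projects/<project>/locations/<region>/datasets/<id>.
    dataset = ti.xcom_pull(task_ids="image_dataset")
    print("created dataset id:", dataset["name"].rpartition("/")[-1])

report_dataset = PythonOperator(
    task_id="report_dataset",
    python_callable=_report_dataset,
)
create_image_dataset_job >> report_dataset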
+# +"""This module contains Google Vertex AI operators.""" + +from typing import Dict, Optional, Sequence, Tuple, Union + +from google.api_core.exceptions import NotFound +from google.api_core.retry import Retry +from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig +from google.protobuf.field_mask_pb2 import FieldMask + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook + + +class CreateDatasetOperator(BaseOperator): + """ + Creates a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param dataset: Required. The Dataset to create. This corresponds to the ``dataset`` field on the + ``request`` instance; if ``request`` is provided, this should not be set. + :type dataset: google.cloud.aiplatform_v1.types.Dataset + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset: Dataset, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset = dataset + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + + self.log.info("Creating dataset") + operation = hook.create_dataset( + project_id=self.project_id, + region=self.region, + dataset=self.dataset, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = hook.wait_for_operation(self.timeout, operation) + self.log.info("Dataset was created.") + + return Dataset.to_dict(result) + + +class DeleteDatasetOperator(BaseOperator): + """ + Deletes a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. 
+ :type region: str + :param dataset_id: Required. The ID of the Dataset to delete. + :type dataset_id: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + + try: + self.log.info("Deleting dataset: %s", self.dataset_id) + operation = hook.delete_dataset( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataset was deleted.") + except NotFound: + self.log.info("The Dataset ID %s does not exist.", self.dataset_id) + + +class ExportDataOperator(BaseOperator): + """ + Exports data from a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param dataset_id: Required. The ID of the Dataset to delete. + :type dataset_id: str + :param export_config: Required. The desired output location. + :type export_config: google.cloud.aiplatform_v1.types.ExportDataConfig + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + export_config: ExportDataConfig, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.export_config = export_config + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + + self.log.info("Exporting data: %s", self.dataset_id) + operation = hook.export_data( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + export_config=self.export_config, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Export was done successfully") + + +class ImportDataOperator(BaseOperator): + """ + Imports data into a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param dataset_id: Required. The ID of the Dataset to delete. + :type dataset_id: str + :param import_configs: Required. The desired input locations. The contents of all input locations will be + imported in one batch. + :type import_configs: Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig] + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + import_configs: Sequence[ImportDataConfig], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.import_configs = import_configs + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + + self.log.info("Importing data") + operation = hook.import_data( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + import_configs=self.import_configs, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Import was done successfully") + + +class ListDatasetsOperator(BaseOperator): + """ + Lists Datasets in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: The standard list filter. + :type filter: str + :param page_size: The standard list page size. + :type page_size: int + :param page_token: The standard list page token. + :type page_token: str + :param read_mask: Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" + after a field name for descending. + :type order_by: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: str, + filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + read_mask: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.filter = filter + self.page_size = page_size + self.page_token = page_token + self.read_mask = read_mask + self.order_by = order_by + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + results = hook.list_datasets( + project_id=self.project_id, + region=self.region, + filter=self.filter, + page_size=self.page_size, + page_token=self.page_token, + read_mask=self.read_mask, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [Dataset.to_dict(result) for result in results] + + +class UpdateDatasetOperator(BaseOperator): + """ + Updates a Dataset. + + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param dataset: Required. The Dataset which replaces the resource on the server. + :type dataset: google.cloud.aiplatform_v1.types.Dataset + :param update_mask: Required. The update mask applies to the resource. For the ``FieldMask`` definition, + see [google.protobuf.FieldMask][google.protobuf.FieldMask]. Updatable fields: + - ``display_name`` + - ``description`` + - ``labels`` + :type update_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "impersonation_chain") + + def __init__( + self, + *, + region: str, + dataset: Dataset, + update_mask: FieldMask, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.dataset = dataset + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Updating dataset") + result = hook.update_dataset( + region=self.region, + dataset=self.dataset, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Dataset was updated") + return Dataset.to_dict(result) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 34a3afcf95a6e..8fb73133df39a 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -26,6 +26,14 @@ DeleteCustomTrainingJobOperator, ListCustomTrainingJobOperator, ) +from airflow.providers.google.cloud.operators.vertex_ai.dataset import ( + CreateDatasetOperator, + DeleteDatasetOperator, + ExportDataOperator, + ImportDataOperator, + ListDatasetsOperator, + UpdateDatasetOperator, +) VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}" TIMEOUT = 120 @@ -61,6 +69,28 @@ PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz" PYTHON_MODULE_NAME = "trainer.task" +TEST_DATASET = { + "display_name": "test-dataset-name", + "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml", + "metadata": "test-image-dataset", +} +TEST_DATASET_ID = "test-dataset-id" +TEST_EXPORT_CONFIG = { + "annotationsFilter": "test-filter", + "gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"}, +} +TEST_IMPORT_CONFIG = [ + { + "data_item_labels": { + "test-labels-name": "test-labels-value", + }, + "import_schema_uri": "test-shema-uri", + "gcs_source": {"uris": ['test-string']}, + }, + {}, +] +TEST_UPDATE_MASK = "test-update-mask" + class TestVertexAICreateCustomContainerTrainingJobOperator: @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) @@ -359,3 +389,183 @@ def test_execute(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, ) + + +class TestVertexAICreateDatasetOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = CreateDatasetOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_dataset.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET, + retry=RETRY, + 
timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIDeleteDatasetOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = DeleteDatasetOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.delete_dataset.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIExportDataOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ExportDataOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.export_data.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + export_config=TEST_EXPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIImportDataOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ImportDataOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.import_data.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + import_configs=TEST_IMPORT_CONFIG, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIListDatasetsOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + page_token = "page_token" + page_size = 42 + filter = "filter" + read_mask = "read_mask" + order_by = "order_by" + + op = ListDatasetsOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + filter=filter, + page_size=page_size, + page_token=page_token, + read_mask=read_mask, + order_by=order_by, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.list_datasets.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + filter=filter, + page_size=page_size, + 
page_token=page_token, + read_mask=read_mask, + order_by=order_by, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestVertexAIUpdateDatasetOperator: + @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = UpdateDatasetOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_dataset.assert_called_once_with( + region=GCP_LOCATION, + dataset=TEST_DATASET, + update_mask=TEST_UPDATE_MASK, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py index 69e8527a750b7..2c641fa1c16eb 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai_system.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai_system.py @@ -24,7 +24,7 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_VERTEX_AI_KEY) -class DataprocExampleDagsTest(GoogleSystemTest): +class VertexAIExampleDagsTest(GoogleSystemTest): @provide_gcp_context(GCP_VERTEX_AI_KEY) def setUp(self): super().setUp() @@ -36,5 +36,9 @@ def tearDown(self): super().tearDown() @provide_gcp_context(GCP_VERTEX_AI_KEY) - def test_run_example_dag(self): - self.run_dag(dag_id="example_gcp_vertex_ai", dag_folder=CLOUD_DAG_FOLDER) + def test_run_custom_jobs_example_dag(self): + self.run_dag(dag_id="example_gcp_vertex_ai_custom_jobs", dag_folder=CLOUD_DAG_FOLDER) + + @provide_gcp_context(GCP_VERTEX_AI_KEY) + def test_run_dataset_example_dag(self): + self.run_dag(dag_id="example_gcp_vertex_ai_dataset", dag_folder=CLOUD_DAG_FOLDER) From 5c21363517ad211b27adb64d1b8229b9de80848a Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Thu, 2 Dec 2021 15:06:13 +0000 Subject: [PATCH 07/20] Create system tests fot Vertex AI --- .../cloud/example_dags/example_vertex_ai.py | 102 +++++----- .../cloud/hooks/vertex_ai/custom_job.py | 8 +- .../google/cloud/hooks/vertex_ai/dataset.py | 14 +- .../cloud/operators/vertex_ai/custom_job.py | 110 +++++++++-- .../cloud/operators/vertex_ai/dataset.py | 177 +++++++++++++++++- airflow/providers/google/provider.yaml | 2 + .../cloud/hooks/vertex_ai/test_dataset.py | 10 + .../google/cloud/operators/test_vertex_ai.py | 16 +- .../cloud/operators/test_vertex_ai_system.py | 3 - 9 files changed, 360 insertions(+), 82 deletions(-) diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py index 2388b2b204d0e..5499c04678fa6 100644 --- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -25,7 +25,6 @@ * GCP_BUCKET_NAME - Google Cloud Storage bucket where the file exists. 
""" import os -from random import randint from uuid import uuid4 from airflow import models @@ -38,28 +37,22 @@ CreateDatasetOperator, DeleteDatasetOperator, ExportDataOperator, + GetDatasetOperator, ImportDataOperator, ListDatasetsOperator, UpdateDatasetOperator, ) from airflow.utils.dates import days_ago -# from google.cloud import aiplatform - - PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") REGION = os.environ.get("GCP_LOCATION", "us-central1") BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests") STAGING_BUCKET = f"gs://{BUCKET}" -BASE_OUTPUT_DIR = f"{STAGING_BUCKET}/models" DISPLAY_NAME = str(uuid4()) # Create random display name -DISPLAY_NAME_2 = str(uuid4()) -DISPLAY_NAME_3 = str(uuid4()) -DISPLAY_NAME_4 = str(uuid4()) -ARGS = ["--tfds", "tf_flowers:3.*.*"] CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" -RESOURCE_ID = str(randint(10000000, 99999999)) # Create random resource ID +CUSTOM_CONTAINER_URI = os.environ.get("CUSTOM_CONTAINER_URI", "path_to_container_with_model") +MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" REPLICA_COUNT = 1 MACHINE_TYPE = "n1-standard-4" ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" @@ -67,15 +60,11 @@ TRAINING_FRACTION_SPLIT = 0.7 TEST_FRACTION_SPLIT = 0.15 VALIDATION_FRACTION_SPLIT = 0.15 -# This example uses an ImageDataset, but you can use another type -# DATASET = aiplatform.ImageDataset(RESOURCE_ID) if RESOURCE_ID else None -COMMAND = ['python3', 'run_script.py'] -COMMAND_2 = ['echo', 'Hello World'] -GCS_DESTINATION = f"gs://{BUCKET}/output-dir/" -PYTHON_PACKAGE = "/files/trainer-0.1.tar.gz" -PYTHON_PACKAGE_CMDARGS = f"--model-dir={GCS_DESTINATION}" -PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz" -PYTHON_MODULE_NAME = "trainer.task" + +PYTHON_PACKAGE_GCS_URI = os.environ.get("PYTHON_PACKAGE_GSC_URI", "path_to_test_model_in_arch") +PYTHON_MODULE_NAME = "aiplatform_custom_trainer_script.task" + +LOCAL_TRAINING_SCRIPT_PATH = os.environ.get("LOCAL_TRAINING_SCRIPT_PATH", "path_to_training_script") IMAGE_DATASET = { "display_name": str(uuid4()), @@ -102,22 +91,23 @@ "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml", "metadata": "test-video-dataset", } -DATASET_ID = "3255741890774958080" +DATASET_ID = os.environ.get("DATASET_ID", "test-dataset-id") TEST_EXPORT_CONFIG = {"gcs_destination": {"output_uri_prefix": "gs://test-vertex-ai-bucket/exports"}} TEST_IMPORT_CONFIG = [ { "data_item_labels": { "test-labels-name": "test-labels-value", }, - "import_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml", + "import_schema_uri": ( + "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml" + ), "gcs_source": { "uris": ["gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"] }, }, ] -TEST_UPDATE_MASK = { - "paths": "display_name", -} +DATASET_TO_UPDATE = {"display_name": "test-name"} +TEST_UPDATE_MASK = {"paths": ["displayName"]} with models.DAG( "example_gcp_vertex_ai_custom_jobs", @@ -128,12 +118,13 @@ create_custom_container_training_job = CreateCustomContainerTrainingJobOperator( task_id="custom_container_task", staging_bucket=STAGING_BUCKET, - display_name=DISPLAY_NAME, - args=ARGS, - container_uri=CONTAINER_URI, - model_serving_container_image_uri=CONTAINER_URI, - command=COMMAND_2, - model_display_name=DISPLAY_NAME_2, + 
display_name=f"train-housing-container-{DISPLAY_NAME}", + container_uri=CUSTOM_CONTAINER_URI, + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=DATASET_ID, + command=["python3", "task.py"], + model_display_name=f"container-housing-model-{DISPLAY_NAME}", replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, accelerator_type=ACCELERATOR_TYPE, @@ -143,7 +134,6 @@ test_fraction_split=TEST_FRACTION_SPLIT, region=REGION, project_id=PROJECT_ID, - base_output_dir=BASE_OUTPUT_DIR, ) # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator] @@ -151,13 +141,14 @@ create_custom_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator( task_id="python_package_task", staging_bucket=STAGING_BUCKET, - display_name=DISPLAY_NAME_3, + display_name=f"train-housing-py-package-{DISPLAY_NAME}", python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, python_module_name=PYTHON_MODULE_NAME, container_uri=CONTAINER_URI, - args=ARGS, - model_serving_container_image_uri=CONTAINER_URI, - model_display_name=DISPLAY_NAME_4, + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=DATASET_ID, + model_display_name=f"py-package-housing-model-{DISPLAY_NAME}", replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, accelerator_type=ACCELERATOR_TYPE, @@ -174,13 +165,16 @@ create_custom_training_job = CreateCustomTrainingJobOperator( task_id="custom_task", staging_bucket=STAGING_BUCKET, - display_name=DISPLAY_NAME, - script_path=PYTHON_PACKAGE, - args=PYTHON_PACKAGE_CMDARGS, + display_name=f"train-housing-custom-{DISPLAY_NAME}", + script_path=LOCAL_TRAINING_SCRIPT_PATH, container_uri=CONTAINER_URI, - model_serving_container_image_uri=CONTAINER_URI, - requirements=[], + requirements=["gcsfs==0.7.1"], + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=DATASET_ID, replica_count=1, + model_display_name=f"custom-housing-model-{DISPLAY_NAME}", + sync=False, region=REGION, project_id=PROJECT_ID, ) @@ -224,19 +218,28 @@ ) # [END how_to_cloud_vertex_ai_create_dataset_operator] - # # [START how_to_cloud_vertex_ai_delete_dataset_operator] + # [START how_to_cloud_vertex_ai_delete_dataset_operator] delete_dataset_job = DeleteDatasetOperator( task_id="delete_dataset", - dataset_id=DATASET_ID, + dataset_id=create_text_dataset_job.output['dataset_id'], region=REGION, project_id=PROJECT_ID, ) # [END how_to_cloud_vertex_ai_delete_dataset_operator] + # [START how_to_cloud_vertex_ai_get_dataset_operator] + get_dataset = GetDatasetOperator( + task_id="get_dataset", + project_id=PROJECT_ID, + region=REGION, + dataset_id=create_tabular_dataset_job.output['dataset_id'], + ) + # [END how_to_cloud_vertex_ai_get_dataset_operator] + # [START how_to_cloud_vertex_ai_export_data_operator] export_data_job = ExportDataOperator( task_id="export_data", - dataset_id="7732319920381231104", + dataset_id=create_image_dataset_job.output['dataset_id'], region=REGION, project_id=PROJECT_ID, export_config=TEST_EXPORT_CONFIG, @@ -246,7 +249,7 @@ # [START how_to_cloud_vertex_ai_import_data_operator] import_data_job = ImportDataOperator( task_id="import_data", - dataset_id="7732319920381231104", + dataset_id=create_image_dataset_job.output['dataset_id'], region=REGION, project_id=PROJECT_ID, import_configs=TEST_IMPORT_CONFIG, @@ -264,8 +267,17 @@ # [START how_to_cloud_vertex_ai_update_dataset_operator] update_dataset_job = UpdateDatasetOperator( task_id="update_dataset", + project_id=PROJECT_ID, region=REGION, - 
dataset=TEXT_DATASET, + dataset_id=create_video_dataset_job.output['dataset_id'], + dataset=DATASET_TO_UPDATE, update_mask=TEST_UPDATE_MASK, ) # [END how_to_cloud_vertex_ai_update_dataset_operator] + + create_time_series_dataset_job + create_text_dataset_job >> delete_dataset_job + create_tabular_dataset_job >> get_dataset + create_image_dataset_job >> import_data_job >> export_data_job + create_video_dataset_job >> update_dataset_job + list_dataset_job diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index b8f608d9d4926..9477ac44401e2 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -217,6 +217,11 @@ def get_custom_training_job( staging_bucket=staging_bucket, ) + @staticmethod + def extract_model_id(obj: Dict) -> str: + """Returns unique id of the Model.""" + return obj["name"].rpartition("/")[-1] + def wait_for_operation(self, timeout: float, operation: Operation): """Waits for long-lasting operation to complete.""" try: @@ -267,7 +272,6 @@ def _run_job( sync=True, ) -> Model: """Run Job for training pipeline""" - self.log.info("START RUN JOB") model = job.run( dataset=dataset, annotation_schema_uri=annotation_schema_uri, @@ -296,9 +300,7 @@ def _run_job( tensorboard=tensorboard, sync=sync, ) - self.log.info(f"END RUN JOB. {model}") model.wait() - self.log.info("STOP WAIT") return model @GoogleBaseHook.fallback_to_default_project_id diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py index b8e96c771ff20..f624a06bdda17 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py @@ -18,7 +18,7 @@ # """This module contains a Google Cloud Vertex AI hook.""" -from typing import Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple from google.api_core.operation import Operation from google.api_core.retry import Retry @@ -56,6 +56,11 @@ def wait_for_operation(self, timeout: float, operation: Operation): error = operation.exception(timeout=timeout) raise AirflowException(error) + @staticmethod + def extract_dataset_id(obj: Dict) -> str: + """Returns unique id of the dataset.""" + return obj["name"].rpartition("/")[-1] + @GoogleBaseHook.fallback_to_default_project_id def create_dataset( self, @@ -493,7 +498,9 @@ def list_datasets( def update_dataset( self, + project_id: str, region: str, + dataset_id: str, dataset: Dataset, update_mask: FieldMask, retry: Optional[Retry] = None, @@ -503,8 +510,12 @@ def update_dataset( """ Updates a Dataset. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. :type region: str + :param dataset_id: Required. The ID of the Dataset. + :type dataset_id: str :param dataset: Required. The Dataset which replaces the resource on the server. :type dataset: google.cloud.aiplatform_v1.types.Dataset :param update_mask: Required. The update mask applies to the resource. 
For the ``FieldMask`` @@ -522,6 +533,7 @@ def update_dataset( :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) + dataset["name"] = client.dataset_path(project_id, region, dataset_id) result = client.update_dataset( request={ diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 05a3d756d8a08..e411d426cfe72 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -21,11 +21,55 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union from google.api_core.retry import Retry -from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.models import Model +from google.cloud.aiplatform_v1.types.dataset import Dataset -from airflow.models import BaseOperator +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook +VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" +VERTEX_AI_MODEL_LINK = ( + VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}" +) +VERTEX_AI_TRAINING_PIPELINES_LINK = VERTEX_AI_BASE_LINK + "/training/training-pipelines?project={project_id}" + + +class VertexAIModelLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Model link""" + + name = "Vertex AI Model" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + model_conf = ti.xcom_pull(task_ids=operator.task_id, key="model_conf") + return ( + VERTEX_AI_MODEL_LINK.format( + region=model_conf["region"], + model_id=model_conf["model_id"], + project_id=model_conf["project_id"], + ) + if model_conf + else "" + ) + + +class VertexAITrainingPipelinesLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Training Pipelines link""" + + name = "Vertex AI Training Pipelines" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + project_id = ti.xcom_pull(task_ids=operator.task_id, key="project_id") + return ( + VERTEX_AI_TRAINING_PIPELINES_LINK.format( + project_id=project_id, + ) + if project_id + else "" + ) + class _CustomTrainingJobBaseOperator(BaseOperator): """The base class for operators that launch Custom jobs on VertexAI.""" @@ -53,14 +97,7 @@ def __init__( model_encryption_spec_key_name: Optional[str] = None, staging_bucket: Optional[str] = None, # RUN - dataset: Optional[ - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ] = None, + dataset_id: Optional[str] = None, annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, @@ -113,7 +150,7 @@ def __init__( self.staging_bucket = staging_bucket # END Custom # START Run param - self.dataset = dataset + self.dataset = Dataset(name=dataset_id) if dataset_id else None self.annotation_schema_uri = annotation_schema_uri self.model_display_name = model_display_name self.model_labels = model_labels @@ -167,6 +204,7 @@ class CreateCustomContainerTrainingJobOperator(_CustomTrainingJobBaseOperator): 'command', 'impersonation_chain', ] + operator_extra_links = (VertexAIModelLink(),) def __init__( self, @@ -227,7 +265,17 @@ def execute(self, context): tensorboard=self.tensorboard, sync=True, ) 
- return model + model_id = self.hook.extract_model_id(model) + self.xcom_push( + context, + key="model_conf", + value={ + "model_id": model_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return Model.to_dict(model) class CreateCustomPythonPackageTrainingJobOperator(_CustomTrainingJobBaseOperator): @@ -237,6 +285,7 @@ class CreateCustomPythonPackageTrainingJobOperator(_CustomTrainingJobBaseOperato 'region', 'impersonation_chain', ] + operator_extra_links = (VertexAIModelLink(),) def __init__( self, @@ -300,7 +349,18 @@ def execute(self, context): tensorboard=self.tensorboard, sync=True, ) - return model + + model_id = self.hook.extract_model_id(model) + self.xcom_push( + context, + key="model_conf", + value={ + "model_id": model_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return Model.to_dict(model) class CreateCustomTrainingJobOperator(_CustomTrainingJobBaseOperator): @@ -312,6 +372,7 @@ class CreateCustomTrainingJobOperator(_CustomTrainingJobBaseOperator): 'requirements', 'impersonation_chain', ] + operator_extra_links = (VertexAIModelLink(),) def __init__( self, @@ -375,7 +436,18 @@ def execute(self, context): tensorboard=self.tensorboard, sync=True, ) - return model + + model_id = self.hook.extract_model_id(model) + self.xcom_push( + context, + key="model_conf", + value={ + "model_id": model_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return Model.to_dict(model) class DeleteCustomTrainingJobOperator(BaseOperator): @@ -431,7 +503,14 @@ def execute(self, context: Dict): class ListCustomTrainingJobOperator(BaseOperator): """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location.""" - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = [ + "region", + "project_id", + "impersonation_chain", + ] + operator_extra_links = [ + VertexAITrainingPipelinesLink(), + ] def __init__( self, @@ -475,3 +554,4 @@ def execute(self, context: Dict): timeout=self.timeout, metadata=self.metadata, ) + self.xcom_push(context, key="project_id", value=self.project_id) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index 5355e49078744..7ec2fdeea2ca9 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -25,9 +25,52 @@ from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig from google.protobuf.field_mask_pb2 import FieldMask -from airflow.models import BaseOperator +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook +VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" +VERTEX_AI_DATASET_LINK = ( + VERTEX_AI_BASE_LINK + "/locations/{region}/datasets/{dataset_id}/analyze?project={project_id}" +) +VERTEX_AI_DATASET_LIST_LINK = VERTEX_AI_BASE_LINK + "/datasets?project={project_id}" + + +class VertexAIDatasetLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Dataset link""" + + name = "Dataset" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + dataset_conf = ti.xcom_pull(task_ids=operator.task_id, key="dataset_conf") + return ( + VERTEX_AI_DATASET_LINK.format( + region=dataset_conf["region"], + model_id=dataset_conf["model_id"], + 
project_id=dataset_conf["project_id"], + ) + if dataset_conf + else "" + ) + + +class VertexAIDatasetListLink(BaseOperatorLink): + """Helper class for constructing Vertex AI Datasets Link""" + + name = "Dataset List" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + project_id = ti.xcom_pull(task_ids=operator.task_id, key="project_id") + return ( + VERTEX_AI_DATASET_LIST_LINK.format( + project_id=project_id, + ) + if project_id + else "" + ) + class CreateDatasetOperator(BaseOperator): """ @@ -60,6 +103,7 @@ class CreateDatasetOperator(BaseOperator): """ template_fields = ("region", "project_id", "impersonation_chain") + operator_extra_links = (VertexAIDatasetLink(),) def __init__( self, @@ -97,9 +141,108 @@ def execute(self, context: Dict): metadata=self.metadata, ) result = hook.wait_for_operation(self.timeout, operation) - self.log.info("Dataset was created.") - return Dataset.to_dict(result) + dataset = Dataset.to_dict(result) + dataset_id = hook.extract_dataset_id(dataset) + self.log.info("Dataset was created. Dataset id: %s", dataset_id) + + self.xcom_push(context, key="dataset_id", value=dataset_id) + self.xcom_push( + context, + key="dataset_conf", + value={ + "dataset_id": dataset_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return dataset + + +class GetDatasetOperator(BaseOperator): + """ + Get a Dataset. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param dataset_id: Required. The ID of the Dataset to get. + :type dataset_id: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") + operator_extra_links = (VertexAIDatasetLink(),) + + def __init__( + self, + *, + region: str, + project_id: str, + dataset_id: str, + read_mask: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> Dataset: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.dataset_id = dataset_id + self.read_mask = read_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + + try: + self.log.info("Get dataset: %s", self.dataset_id) + dataset_obj = hook.get_dataset( + project_id=self.project_id, + region=self.region, + dataset=self.dataset_id, + read_mask=self.read_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push( + context, + key="dataset_conf", + value={ + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "region": self.region, + }, + ) + self.log.info("Dataset was gotten.") + return Dataset.to_dict(dataset_obj) + except NotFound: + self.log.info("The Dataset ID %s does not exist.", self.dataset_id) class DeleteDatasetOperator(BaseOperator): @@ -131,7 +274,7 @@ class DeleteDatasetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") def __init__( self, @@ -206,7 +349,7 @@ class ExportDataOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") def __init__( self, @@ -282,7 +425,7 @@ class ImportDataOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") def __init__( self, @@ -312,7 +455,7 @@ def __init__( def execute(self, context: Dict): hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - self.log.info("Importing data") + self.log.info("Importing data: %s", self.dataset_id) operation = hook.import_data( project_id=self.project_id, region=self.region, @@ -365,6 +508,7 @@ class ListDatasetsOperator(BaseOperator): """ template_fields = ("region", "project_id", "impersonation_chain") + operator_extra_links = (VertexAIDatasetListLink(),) def __init__( self, @@ -411,6 +555,11 @@ def execute(self, context: Dict): timeout=self.timeout, metadata=self.metadata, ) + self.xcom_push( + context, + key="project_id", + value=self.project_id, + ) return [Dataset.to_dict(result) for result in results] @@ -418,8 +567,12 @@ class UpdateDatasetOperator(BaseOperator): """ Updates a Dataset. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str :param region: Required. 
The ID of the Google Cloud region that the service belongs to. :type region: str + :param dataset_id: Required. The ID of the Dataset to update. + :type dataset_id: str :param dataset: Required. The Dataset which replaces the resource on the server. :type dataset: google.cloud.aiplatform_v1.types.Dataset :param update_mask: Required. The update mask applies to the resource. For the ``FieldMask`` definition, @@ -447,12 +600,14 @@ class UpdateDatasetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("region", "impersonation_chain") + template_fields = ("region", "dataset_id", "project_id", "impersonation_chain") def __init__( self, *, + project_id: str, region: str, + dataset_id: str, dataset: Dataset, update_mask: FieldMask, retry: Optional[Retry] = None, @@ -463,7 +618,9 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) + self.project_id = project_id self.region = region + self.dataset_id = dataset_id self.dataset = dataset self.update_mask = update_mask self.retry = retry @@ -474,9 +631,11 @@ def __init__( def execute(self, context): hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - self.log.info("Updating dataset") + self.log.info("Updating dataset: %s", self.dataset_id) result = hook.update_dataset( + project_id=self.project_id, region=self.region, + dataset_id=self.dataset_id, dataset=self.dataset, update_mask=self.update_mask, retry=self.retry, diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 089ac62b7664b..8fae678cecef5 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -816,6 +816,8 @@ extra-links: - airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink - airflow.providers.google.cloud.operators.dataproc.DataprocJobLink - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink + - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAIModelLink + - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAITrainingPipelinesLink additional-extras: apache.beam: apache-beam[gcp] diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py index 76efb94172f93..5a2908768016f 100644 --- a/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py @@ -253,7 +253,9 @@ def test_list_datasets(self, mock_client) -> None: @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) def test_update_dataset(self, mock_client) -> None: self.hook.update_dataset( + project_id=TEST_PROJECT_ID, region=TEST_REGION, + dataset_id=TEST_DATASET_ID, dataset=TEST_DATASET, update_mask=TEST_UPDATE_MASK, ) @@ -267,6 +269,9 @@ def test_update_dataset(self, mock_client) -> None: retry=None, timeout=None, ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) class TestVertexAIWithoutDefaultProjectIdHook(TestCase): @@ -478,7 +483,9 @@ def test_list_datasets(self, mock_client) -> None: @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client")) def test_update_dataset(self, mock_client) -> None: self.hook.update_dataset( + project_id=TEST_PROJECT_ID, region=TEST_REGION, + dataset_id=TEST_DATASET_ID, dataset=TEST_DATASET, update_mask=TEST_UPDATE_MASK, ) @@ -492,3 +499,6 @@ def test_update_dataset(self, 
mock_client) -> None: retry=None, timeout=None, ) + mock_client.return_value.dataset_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 8fb73133df39a..a40091a1f3920 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -116,7 +116,7 @@ def test_execute(self, mock_hook): region=GCP_LOCATION, project_id=GCP_PROJECT, ) - op.execute(context={}) + op.execute(context={'ti': mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_custom_container_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, @@ -194,7 +194,7 @@ def test_execute(self, mock_hook): region=GCP_LOCATION, project_id=GCP_PROJECT, ) - op.execute(context={}) + op.execute(context={'ti': mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, @@ -266,7 +266,7 @@ def test_execute(self, mock_hook): region=GCP_LOCATION, project_id=GCP_PROJECT, ) - op.execute(context={}) + op.execute(context={'ti': mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_custom_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, @@ -376,7 +376,7 @@ def test_execute(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, ) - op.execute(context={}) + op.execute(context={'ti': mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.list_training_pipelines.assert_called_once_with( region=GCP_LOCATION, @@ -406,7 +406,7 @@ def test_execute(self, mock_hook, to_dict_mock): timeout=TIMEOUT, metadata=METADATA, ) - op.execute(context={}) + op.execute(context={'ti': mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_dataset.assert_called_once_with( region=GCP_LOCATION, @@ -528,7 +528,7 @@ def test_execute(self, mock_hook, to_dict_mock): timeout=TIMEOUT, metadata=METADATA, ) - op.execute(context={}) + op.execute(context={'ti': mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.list_datasets.assert_called_once_with( region=GCP_LOCATION, @@ -552,7 +552,9 @@ def test_execute(self, mock_hook, to_dict_mock): task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + project_id=GCP_PROJECT, region=GCP_LOCATION, + dataset_id=TEST_DATASET_ID, dataset=TEST_DATASET, update_mask=TEST_UPDATE_MASK, retry=RETRY, @@ -562,7 +564,9 @@ def test_execute(self, mock_hook, to_dict_mock): op.execute(context={}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.update_dataset.assert_called_once_with( + project_id=GCP_PROJECT, region=GCP_LOCATION, + dataset_id=TEST_DATASET_ID, dataset=TEST_DATASET, update_mask=TEST_UPDATE_MASK, retry=RETRY, diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py index 
2c641fa1c16eb..84b84c33200c1 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai_system.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai_system.py @@ -17,7 +17,6 @@ # under the License. import pytest -from airflow.providers.google.cloud.example_dags.example_vertex_ai import BUCKET, REGION from tests.providers.google.cloud.utils.gcp_authenticator import GCP_VERTEX_AI_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -28,11 +27,9 @@ class VertexAIExampleDagsTest(GoogleSystemTest): @provide_gcp_context(GCP_VERTEX_AI_KEY) def setUp(self): super().setUp() - self.create_gcs_bucket(BUCKET, REGION) @provide_gcp_context(GCP_VERTEX_AI_KEY) def tearDown(self): - self.delete_gcs_bucket(BUCKET) super().tearDown() @provide_gcp_context(GCP_VERTEX_AI_KEY) From 420955a9f8e497f5c71a558bd4eae61a96239fb9 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Fri, 3 Dec 2021 09:12:02 +0000 Subject: [PATCH 08/20] Add links for Vertex AI operators --- .../cloud/operators/vertex_ai/custom_job.py | 16 ++++++++++------ .../google/cloud/operators/vertex_ai/dataset.py | 2 +- airflow/providers/google/provider.yaml | 2 ++ setup.py | 1 + 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index e411d426cfe72..2a8414b0f6ddd 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -265,7 +265,9 @@ def execute(self, context): tensorboard=self.tensorboard, sync=True, ) - model_id = self.hook.extract_model_id(model) + + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) self.xcom_push( context, key="model_conf", @@ -275,7 +277,7 @@ def execute(self, context): "project_id": self.project_id, }, ) - return Model.to_dict(model) + return result class CreateCustomPythonPackageTrainingJobOperator(_CustomTrainingJobBaseOperator): @@ -350,7 +352,8 @@ def execute(self, context): sync=True, ) - model_id = self.hook.extract_model_id(model) + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) self.xcom_push( context, key="model_conf", @@ -360,7 +363,7 @@ def execute(self, context): "project_id": self.project_id, }, ) - return Model.to_dict(model) + return result class CreateCustomTrainingJobOperator(_CustomTrainingJobBaseOperator): @@ -437,7 +440,8 @@ def execute(self, context): sync=True, ) - model_id = self.hook.extract_model_id(model) + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) self.xcom_push( context, key="model_conf", @@ -447,7 +451,7 @@ def execute(self, context): "project_id": self.project_id, }, ) - return Model.to_dict(model) + return result class DeleteCustomTrainingJobOperator(BaseOperator): diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index 7ec2fdeea2ca9..b2b04cd7efe73 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -47,7 +47,7 @@ def get_link(self, operator, dttm): return ( VERTEX_AI_DATASET_LINK.format( region=dataset_conf["region"], - model_id=dataset_conf["model_id"], + dataset_id=dataset_conf["dataset_id"], project_id=dataset_conf["project_id"], ) if dataset_conf diff --git a/airflow/providers/google/provider.yaml 
b/airflow/providers/google/provider.yaml index 8fae678cecef5..2f392aa1d8d41 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -818,6 +818,8 @@ extra-links: - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAIModelLink - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAITrainingPipelinesLink + - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetLink + - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetListLink additional-extras: apache.beam: apache-beam[gcp] diff --git a/setup.py b/setup.py index c9d4d0ba87d8d..45d402db59c2a 100644 --- a/setup.py +++ b/setup.py @@ -298,6 +298,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version # https://github.com/googleapis/google-cloud-python/issues/10566 'google-auth>=1.0.0,<3.0.0', 'google-auth-httplib2>=0.0.1', + 'google-cloud-aiplatform>=1.7.1,<2.0.0', 'google-cloud-automl>=2.1.0,<3.0.0', 'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0', 'google-cloud-bigtable>=1.0.0,<2.0.0', From 0e0ab9d48c0d64c1b4a67c1389a4f6226895725e Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Mon, 6 Dec 2021 16:20:13 +0000 Subject: [PATCH 09/20] Add how-to documentation for Vertex AI operators --- .../cloud/example_dags/example_vertex_ai.py | 23 +++ .../cloud/hooks/vertex_ai/custom_job.py | 31 +++- .../cloud/operators/vertex_ai/custom_job.py | 66 ++++---- .../apache-airflow-providers-google/index.rst | 1 + .../operators/cloud/vertex_ai.rst | 157 ++++++++++++++++++ .../google/cloud/operators/test_vertex_ai.py | 10 +- 6 files changed, 249 insertions(+), 39 deletions(-) diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py index 5499c04678fa6..a552f0fd6a61a 100644 --- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -32,6 +32,8 @@ CreateCustomContainerTrainingJobOperator, CreateCustomPythonPackageTrainingJobOperator, CreateCustomTrainingJobOperator, + DeleteCustomTrainingJobOperator, + ListCustomTrainingJobOperator, ) from airflow.providers.google.cloud.operators.vertex_ai.dataset import ( CreateDatasetOperator, @@ -66,6 +68,9 @@ LOCAL_TRAINING_SCRIPT_PATH = os.environ.get("LOCAL_TRAINING_SCRIPT_PATH", "path_to_training_script") +TRAINING_PIPELINE_ID = "test-training-pipeline-id" +CUSTOM_JOB_ID = "test-custom-job-id" + IMAGE_DATASET = { "display_name": str(uuid4()), "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml", @@ -180,6 +185,24 @@ ) # [END how_to_cloud_vertex_ai_create_custom_training_job_operator] + # [START how_to_cloud_vertex_ai_delete_custom_training_job_operator] + delete_custom_training_job = DeleteCustomTrainingJobOperator( + task_id="delete_custom_training_job", + training_pipeline_id=TRAINING_PIPELINE_ID, + custom_job_id=CUSTOM_JOB_ID, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_delete_custom_training_job_operator] + + # [START how_to_cloud_vertex_ai_list_custom_training_job_operator] + list_custom_training_job = ListCustomTrainingJobOperator( + task_id="list_custom_training_job", + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_list_custom_training_job_operator] + with models.DAG( "example_gcp_vertex_ai_dataset", 
start_date=days_ago(1), diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index 9477ac44401e2..25088ac57d0b9 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -43,6 +43,19 @@ class CustomJobHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Custom Job APIs.""" + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._job = None + def get_pipeline_service_client( self, region: Optional[str] = None, @@ -230,6 +243,10 @@ def wait_for_operation(self, timeout: float, operation: Operation): error = operation.exception(timeout=timeout) raise AirflowException(error) + def cancel_job(self) -> None: + """Cancel Job for training pipeline""" + self._job.cancel() + def _run_job( self, job: Union[ @@ -957,7 +974,7 @@ def create_custom_container_training_job( be immediately returned and synced when the Future has completed. :type sync: bool """ - job = self.get_custom_container_training_job( + self._job = self.get_custom_container_training_job( project=project_id, location=region, display_name=display_name, @@ -981,7 +998,7 @@ def create_custom_container_training_job( ) model = self._run_job( - job=job, + job=self._job, dataset=dataset, annotation_schema_uri=annotation_schema_uri, model_display_name=model_display_name, @@ -1405,7 +1422,7 @@ def create_custom_python_package_training_job( be immediately returned and synced when the Future has completed. :type sync: bool """ - job = self.get_custom_python_package_training_job( + self._job = self.get_custom_python_package_training_job( project=project_id, location=region, display_name=display_name, @@ -1430,7 +1447,7 @@ def create_custom_python_package_training_job( ) model = self._run_job( - job=job, + job=self._job, dataset=dataset, annotation_schema_uri=annotation_schema_uri, model_display_name=model_display_name, @@ -1853,7 +1870,7 @@ def create_custom_training_job( be immediately returned and synced when the Future has completed. 
:type sync: bool """ - job = self.get_custom_training_job( + self._job = self.get_custom_training_job( project=project_id, location=region, display_name=display_name, @@ -1878,7 +1895,7 @@ def create_custom_training_job( ) model = self._run_job( - job=job, + job=self._job, dataset=dataset, annotation_schema_uri=annotation_schema_uri, model_display_name=model_display_name, @@ -2014,7 +2031,7 @@ def delete_custom_job( :type metadata: Sequence[Tuple[str, str]] """ client = self.get_job_service_client(region) - name = JobServiceClient.custom_job_path(project_id, region, custom_job) + name = client.custom_job_path(project_id, region, custom_job) result = client.delete_custom_job( request={ diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 2a8414b0f6ddd..bc35d5f01fa2a 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -20,9 +20,11 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union +from google.api_core.exceptions import NotFound from google.api_core.retry import Retry from google.cloud.aiplatform.models import Model from google.cloud.aiplatform_v1.types.dataset import Dataset +from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance @@ -186,14 +188,7 @@ def on_kill(self) -> None: Callback called when the operator is killed. Cancel any running job. """ - self.hook.cancel_training_pipeline( - project_id=self.project_id, - region=self.region, - training_pipeline=self.display_name, - ) - self.hook.cancel_custom_job( - project_id=self.project_id, region=self.region, custom_job=f"{self.display_name}-custom-job" - ) + self.hook.cancel_job() class CreateCustomContainerTrainingJobOperator(_CustomTrainingJobBaseOperator): @@ -462,7 +457,8 @@ class DeleteCustomTrainingJobOperator(BaseOperator): def __init__( self, *, - training_pipeline: str, + training_pipeline_id: str, + custom_job_id: str, region: str, project_id: str, retry: Optional[Retry] = None, @@ -473,7 +469,8 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self.training_pipeline = training_pipeline + self.training_pipeline = training_pipeline_id + self.custom_job = custom_job_id self.region = region self.project_id = project_id self.retry = retry @@ -484,24 +481,34 @@ def __init__( def execute(self, context: Dict): hook = CustomJobHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - self.log.info("Deleting custom training job: %s", self.training_pipeline) - hook.delete_training_pipeline( - training_pipeline=self.training_pipeline, - region=self.region, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.delete_custom_job( - custom_job=f"{self.training_pipeline}-custom-job", - region=self.region, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - self.log.info("Custom training job deleted.") + try: + self.log.info("Deleting custom training pipeline: %s", self.training_pipeline) + training_pipeline_operation = hook.delete_training_pipeline( + training_pipeline=self.training_pipeline, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + 
hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation) + self.log.info("Training pipeline was deleted.") + except NotFound: + self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline) + try: + self.log.info("Deleting custom job: %s", self.custom_job) + custom_job_operation = hook.delete_custom_job( + custom_job=self.custom_job, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=custom_job_operation) + self.log.info("Custom job was deleted.") + except NotFound: + self.log.info("The Custom Job ID %s does not exist.", self.custom_job) class ListCustomTrainingJobOperator(BaseOperator): @@ -547,7 +554,7 @@ def __init__( def execute(self, context: Dict): hook = CustomJobHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - hook.list_training_pipelines( + results = hook.list_training_pipelines( region=self.region, project_id=self.project_id, page_size=self.page_size, @@ -559,3 +566,4 @@ def execute(self, context: Dict): metadata=self.metadata, ) self.xcom_push(context, key="project_id", value=self.project_id) + return [TrainingPipeline.to_dict(result) for result in results] diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst index d5a2f61cff597..4c6dc4eef6f43 100644 --- a/docs/apache-airflow-providers-google/index.rst +++ b/docs/apache-airflow-providers-google/index.rst @@ -96,6 +96,7 @@ PIP package Version required ``google-api-python-client`` ``>=1.6.0,<2.0.0`` ``google-auth-httplib2`` ``>=0.0.1`` ``google-auth`` ``>=1.0.0,<3.0.0`` +``google-cloud-aiplatform`` ``>=1.7.1,<2.0.0`` ``google-cloud-automl`` ``>=2.1.0,<3.0.0`` ``google-cloud-bigquery-datatransfer`` ``>=3.0.0,<4.0.0`` ``google-cloud-bigtable`` ``>=1.0.0,<2.0.0`` diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst index 106592bd11775..92c22af0d4aef 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -14,3 +14,160 @@ KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +Google Cloud VertexAI Operators +======================================= + +The `Google Cloud VertexAI `__ +brings AutoML and AI Platform together into a unified API, client library, and user +interface. AutoML lets you train models on image, tabular, text, and video datasets +without writing code, while training in AI Platform lets you run custom training code. +With Vertex AI, both AutoML training and custom training are available options. +Whichever option you choose for training, you can save models, deploy models, and +request predictions with Vertex AI. + +Creating Datasets +^^^^^^^^^^^^^^^^^ + +To create a Google VertexAI dataset you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`. +The operator returns dataset id in :ref:`XCom ` under ``dataset_id`` key. + +.. 
exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_dataset_operator]
+
+After creating a dataset you can use it to import some data using
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_import_data_operator]
+    :end-before: [END how_to_cloud_vertex_ai_import_data_operator]
+
+To export a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ExportDataOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_export_data_operator]
+    :end-before: [END how_to_cloud_vertex_ai_export_data_operator]
+
+To delete a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_delete_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_delete_dataset_operator]
+
+To get a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.GetDatasetOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_get_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_get_dataset_operator]
+
+To get a list of datasets you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_list_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_list_dataset_operator]
+
+To update a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_update_dataset_operator]
+    :end-before: [END how_to_cloud_vertex_ai_update_dataset_operator]
+
+Creating Training Jobs
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+To create a Google Vertex AI training job you can use one of three operators:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.
+Each of them waits for the operation to complete, and each returns the model trained by the submitted job.
+
+Preparation step
+
+For each operator you must first create a dataset, then pass its id to the operator's ``dataset_id`` parameter, as shown in the sketch below.
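+
+The following is a minimal sketch of that wiring, not part of the bundled example DAG. It assumes it runs inside a
+DAG definition such as ``example_gcp_vertex_ai_custom_jobs`` and that the ``PROJECT_ID``, ``REGION``,
+``STAGING_BUCKET``, ``IMAGE_DATASET``, ``CUSTOM_CONTAINER_URI`` and ``MODEL_SERVING_CONTAINER_URI`` variables are
+defined as in the example DAG above; the dataset id pushed to XCom by
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator` is passed to the
+training operator through the ``dataset_id`` parameter:
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
+        CreateCustomContainerTrainingJobOperator,
+    )
+    from airflow.providers.google.cloud.operators.vertex_ai.dataset import CreateDatasetOperator
+
+    # Create the dataset first; the operator pushes its id to XCom under the "dataset_id" key.
+    create_image_dataset = CreateDatasetOperator(
+        task_id="create_image_dataset",
+        project_id=PROJECT_ID,
+        region=REGION,
+        dataset=IMAGE_DATASET,
+    )
+
+    train_model = CreateCustomContainerTrainingJobOperator(
+        task_id="train_model",
+        project_id=PROJECT_ID,
+        region=REGION,
+        staging_bucket=STAGING_BUCKET,
+        display_name="train-custom-container-model",
+        container_uri=CUSTOM_CONTAINER_URI,
+        model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
+        # dataset id produced by the task above, pulled via XCom at runtime
+        dataset_id=create_image_dataset.output["dataset_id"],
+        replica_count=1,
+        machine_type="n1-standard-4",
+    )
+
+    create_image_dataset >> train_model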
+
+How to run Container Training Job
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`
+
+Before running this job you should build a Docker image with your training script inside. Documentation on how to
+create such an image can be found here: https://cloud.google.com/vertex-ai/docs/training/create-custom-container
+After that, put the link to the image in the ``container_uri`` parameter. You can also set the command to be
+executed in the container created from this image via the ``command`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
+
+How to run Python Package Training Job
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`
+
+Before running this job you should create a Python package with your training script inside. Documentation on how
+to create one can be found here: https://cloud.google.com/vertex-ai/docs/training/create-python-pre-built-container
+Then put the link to the package in the ``python_package_gcs_uri`` parameter and set the ``python_module_name``
+parameter to the module that runs your training task.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
+
+How to run Training Job
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.
+
+For this job, set the ``script_path`` parameter to the path of your local training script.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_operator]
+
+You can get a list of Training Jobs using
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.ListCustomTrainingJobOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_list_custom_training_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_list_custom_training_job_operator]
+
+If you wish to delete a Custom Training Job you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.DeleteCustomTrainingJobOperator`.
+
+.. 
exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_delete_custom_training_job_operator] + :end-before: [END how_to_cloud_vertex_ai_delete_custom_training_job_operator] + +Reference +^^^^^^^^^ + +For further information, look at: + +* `Client Library Documentation `__ +* `Product Documentation `__ diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index a40091a1f3920..3d15b7564e542 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -69,6 +69,9 @@ PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz" PYTHON_MODULE_NAME = "trainer.task" +TRAINING_PIPELINE_ID = "test-training-pipeline-id" +CUSTOM_JOB_ID = "test-custom-job-id" + TEST_DATASET = { "display_name": "test-dataset-name", "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml", @@ -325,7 +328,8 @@ class TestVertexAIDeleteCustomTrainingJobOperator: def test_execute(self, mock_hook): op = DeleteCustomTrainingJobOperator( task_id=TASK_ID, - training_pipeline=DISPLAY_NAME, + training_pipeline_id=TRAINING_PIPELINE_ID, + custom_job_id=CUSTOM_JOB_ID, region=GCP_LOCATION, project_id=GCP_PROJECT, retry=RETRY, @@ -337,7 +341,7 @@ def test_execute(self, mock_hook): op.execute(context={}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.delete_training_pipeline.assert_called_once_with( - training_pipeline=DISPLAY_NAME, + training_pipeline=TRAINING_PIPELINE_ID, region=GCP_LOCATION, project_id=GCP_PROJECT, retry=RETRY, @@ -345,7 +349,7 @@ def test_execute(self, mock_hook): metadata=METADATA, ) mock_hook.return_value.delete_custom_job.assert_called_once_with( - custom_job=f"{DISPLAY_NAME}-custom-job", + custom_job=CUSTOM_JOB_ID, region=GCP_LOCATION, project_id=GCP_PROJECT, retry=RETRY, From 790babf170a3ab16f4b0283b65420a50d67aa929 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Tue, 7 Dec 2021 16:03:27 +0000 Subject: [PATCH 10/20] Change example_dags --- .../cloud/example_dags/example_vertex_ai.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py index a552f0fd6a61a..04e03c22cc738 100644 --- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -22,9 +22,15 @@ This DAG relies on the following OS environment variables: -* GCP_BUCKET_NAME - Google Cloud Storage bucket where the file exists. +* GCP_VERTEX_AI_BUCKET - Google Cloud Storage bucket where the model will be saved +after training process was finished. +* CUSTOM_CONTAINER_URI - path to container with model. +* PYTHON_PACKAGE_GSC_URI - path to test model in archive. +* LOCAL_TRAINING_SCRIPT_PATH - path to local training script. +* DATASET_ID - ID of dataset which will be used in training process. 
""" import os +from datetime import datetime from uuid import uuid4 from airflow import models @@ -44,7 +50,6 @@ ListDatasetsOperator, UpdateDatasetOperator, ) -from airflow.utils.dates import days_ago PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") REGION = os.environ.get("GCP_LOCATION", "us-central1") @@ -116,8 +121,9 @@ with models.DAG( "example_gcp_vertex_ai_custom_jobs", - start_date=days_ago(1), schedule_interval="@once", + start_date=datetime(2021, 1, 1), + catchup=False, ) as custom_jobs_dag: # [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator] create_custom_container_training_job = CreateCustomContainerTrainingJobOperator( @@ -205,8 +211,9 @@ with models.DAG( "example_gcp_vertex_ai_dataset", - start_date=days_ago(1), schedule_interval="@once", + start_date=datetime(2021, 1, 1), + catchup=False, ) as dataset_dag: # [START how_to_cloud_vertex_ai_create_dataset_operator] create_image_dataset_job = CreateDatasetOperator( From e592fcbbb4ae40d1fdac5b42e206c03f11cf7380 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Thu, 9 Dec 2021 10:42:18 +0000 Subject: [PATCH 11/20] Change pre-commit static check script --- airflow/providers/google/provider.yaml | 6 ++++-- .../ci/pre_commit/pre_commit_check_provider_yaml_files.py | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 2f392aa1d8d41..8dd83a2a0403b 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -471,7 +471,8 @@ operators: - airflow.providers.google.leveldb.operators.leveldb - integration-name: Google Vertex AI python-modules: - - airflow.providers.google.cloud.operators.vertex_ai + - airflow.providers.google.cloud.operators.vertex_ai.dataset + - airflow.providers.google.cloud.operators.vertex_ai.custom_job sensors: - integration-name: Google BigQuery @@ -668,7 +669,8 @@ hooks: - airflow.providers.google.leveldb.hooks.leveldb - integration-name: Google Vertex AI python-modules: - - airflow.providers.google.cloud.hooks.vertex_ai + - airflow.providers.google.cloud.hooks.vertex_ai.dataset + - airflow.providers.google.cloud.hooks.vertex_ai.custom_job transfers: - source-integration-name: Presto diff --git a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py index ec614a5e466b3..11d9eafb22d7c 100755 --- a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py +++ b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py @@ -151,7 +151,10 @@ def parse_module_data(provider_data, resource_type, yaml_file_path): package_dir = ROOT_DIR + "/" + os.path.dirname(yaml_file_path) provider_package = os.path.dirname(yaml_file_path).replace(os.sep, ".") py_files = chain( - glob(f"{package_dir}/**/{resource_type}/*.py"), glob(f"{package_dir}/{resource_type}/*.py") + glob(f"{package_dir}/**/{resource_type}/*.py"), + glob(f"{package_dir}/{resource_type}/*.py"), + glob(f"{package_dir}/**/{resource_type}/**/*.py"), + glob(f"{package_dir}/{resource_type}/**/*.py"), ) expected_modules = {_filepath_to_module(f) for f in py_files if not f.endswith("/__init__.py")} resource_data = provider_data.get(resource_type, []) From 3ec11b2e0899326af8ea9a43a92f29b293a50d89 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Tue, 14 Dec 2021 15:29:59 +0000 Subject: [PATCH 12/20] Fix documentation build --- .../cloud/example_dags/example_vertex_ai.py | 2 +- .../cloud/hooks/vertex_ai/custom_job.py | 200 
++++-------------- .../google/cloud/hooks/vertex_ai/dataset.py | 9 +- .../cloud/operators/vertex_ai/dataset.py | 6 +- 4 files changed, 43 insertions(+), 174 deletions(-) diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py index 04e03c22cc738..5a459e77d7be3 100644 --- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -274,7 +274,7 @@ project_id=PROJECT_ID, export_config=TEST_EXPORT_CONFIG, ) - # [END how_to_cloud_vertex_ai_export_datas_operator] + # [END how_to_cloud_vertex_ai_export_data_operator] # [START how_to_cloud_vertex_ai_import_data_operator] import_data_job = ImportDataOperator( diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index 25088ac57d0b9..1c755a206e2b9 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -737,17 +737,14 @@ def create_custom_container_training_job( one given on input. The output URI will point to a location where the user only has a read access. :type model_prediction_schema_uri: str - :param project: Project to run training in. Overrides project set in aiplatform.init. - :type project: str - :param location: Location to run training in. Overrides location set in aiplatform.init. - :type location: str - :param credentials: Custom credentials to use to run call training service. Overrides - credentials set in aiplatform.init. - :type credentials: auth_credentials.Credentials + :param project_id: Project to run training in. + :type project_id: str + :param region: Location to run training in. + :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only + characters, can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. @@ -765,8 +762,6 @@ def create_custom_container_training_job( Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - - Overrides encryption_spec_key_name set in aiplatform.init. :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the @@ -776,39 +771,16 @@ def create_custom_container_training_job( resource is created. If set, the trained Model will be secured by this key. - - Overrides encryption_spec_key_name set in aiplatform.init. :type model_encryption_spec_key_name: Optional[str] - :param staging_bucket: Bucket used to stage source and training artifacts. Overrides - staging_bucket set in aiplatform.init. + :param staging_bucket: Bucket used to stage source and training artifacts. :type staging_bucket: str - - :param dataset: Vertex AI to fit this training against. 
Custom training script should - retrieve datasets through passed in environment variables uris: - - os.environ["AIP_TRAINING_DATA_URI"] - os.environ["AIP_VALIDATION_DATA_URI"] - os.environ["AIP_TEST_DATA_URI"] - - Additionally the dataset format is passed in as: - - os.environ["AIP_DATA_FORMAT"] - :type dataset: Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] + :param dataset: Vertex AI to fit this training against. + :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, + datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) - The schema files that can be used here are found in - gs://google-cloud-aiplatform/schema/dataset/annotation/, - note that the chosen schema must be consistent with - ``metadata`` - of the Dataset specified by - ``dataset_id``. Only Annotations that both match this schema and belong to DataItems not ignored by the split method are used in @@ -831,7 +803,7 @@ def create_custom_container_training_job( :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only + characters, can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. @@ -855,7 +827,7 @@ def create_custom_container_training_job( Users submitting jobs must have act-as permission on this run-as account. :type service_account: str :param network: The full name of the Compute Engine network to which the job - should be peered. For example, projects/12345/global/networks/myVPC. + should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. :type network: str @@ -881,10 +853,6 @@ def create_custom_container_training_job( and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - - environment_variables = { - 'MY_KEY': 'MY_VALUE' - } :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be @@ -945,7 +913,7 @@ def create_custom_container_training_job( Supported only for tabular and time series Datasets. :type predefined_split_column_name: str - :param timestamp_split_column_name : Optional. The key is a name of one of the Dataset's data + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a @@ -954,18 +922,9 @@ def create_custom_container_training_job( Supported only for tabular and time series Datasets. :type timestamp_split_column_name: str - :param tensorboard: Optional. The name of a Vertex AI - [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] - resource to which this CustomJob will upload Tensorboard + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. 
Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` - - The training script should write Tensorboard to following Vertex AI environment - variable: - - AIP_TENSORBOARD_LOG_DIR - - `service_account` is required with provided `tensorboard`. For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str @@ -1185,17 +1144,14 @@ def create_custom_python_package_training_job( one given on input. The output URI will point to a location where the user only has a read access. :type model_prediction_schema_uri: str - :param project: Project to run training in. Overrides project set in aiplatform.init. - :type project: str - :param location: Location to run training in. Overrides location set in aiplatform.init. - :type location: str - :param credentials: Custom credentials to use to run call training service. Overrides - credentials set in aiplatform.init. - :type credentials: auth_credentials.Credentials + :param project_id: Project to run training in. + :type project_id: str + :param region: Location to run training in. + :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only + characters, can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. @@ -1213,8 +1169,6 @@ def create_custom_python_package_training_job( Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - - Overrides encryption_spec_key_name set in aiplatform.init. :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the @@ -1224,39 +1178,16 @@ def create_custom_python_package_training_job( resource is created. If set, the trained Model will be secured by this key. - - Overrides encryption_spec_key_name set in aiplatform.init. :type model_encryption_spec_key_name: Optional[str] - :param staging_bucket: Bucket used to stage source and training artifacts. Overrides - staging_bucket set in aiplatform.init. + :param staging_bucket: Bucket used to stage source and training artifacts. :type staging_bucket: str - - :param dataset: Vertex AI to fit this training against. Custom training script should - retrieve datasets through passed in environment variables uris: - - os.environ["AIP_TRAINING_DATA_URI"] - os.environ["AIP_VALIDATION_DATA_URI"] - os.environ["AIP_TEST_DATA_URI"] - - Additionally the dataset format is passed in as: - - os.environ["AIP_DATA_FORMAT"] - :type dataset: Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] + :param dataset: Vertex AI to fit this training against. + :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, + datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. 
The schema is defined as an OpenAPI 3.0.2 [Schema Object] (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) - The schema files that can be used here are found in - gs://google-cloud-aiplatform/schema/dataset/annotation/, - note that the chosen schema must be consistent with - ``metadata`` - of the Dataset specified by - ``dataset_id``. Only Annotations that both match this schema and belong to DataItems not ignored by the split method are used in @@ -1279,7 +1210,7 @@ def create_custom_python_package_training_job( :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only + characters, can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. @@ -1303,7 +1234,7 @@ def create_custom_python_package_training_job( Users submitting jobs must have act-as permission on this run-as account. :type service_account: str :param network: The full name of the Compute Engine network to which the job - should be peered. For example, projects/12345/global/networks/myVPC. + should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. :type network: str @@ -1329,10 +1260,6 @@ def create_custom_python_package_training_job( and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - - environment_variables = { - 'MY_KEY': 'MY_VALUE' - } :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be @@ -1393,7 +1320,7 @@ def create_custom_python_package_training_job( Supported only for tabular and time series Datasets. :type predefined_split_column_name: str - :param timestamp_split_column_name : Optional. The key is a name of one of the Dataset's data + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a @@ -1402,18 +1329,9 @@ def create_custom_python_package_training_job( Supported only for tabular and time series Datasets. :type timestamp_split_column_name: str - :param tensorboard: Optional. The name of a Vertex AI - [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] - resource to which this CustomJob will upload Tensorboard + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` - - The training script should write Tensorboard to following Vertex AI environment - variable: - - AIP_TENSORBOARD_LOG_DIR - - `service_account` is required with provided `tensorboard`. For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str @@ -1634,17 +1552,14 @@ def create_custom_training_job( one given on input. The output URI will point to a location where the user only has a read access. :type model_prediction_schema_uri: str - :param project: Project to run training in. Overrides project set in aiplatform.init. 
- :type project: str - :param location: Location to run training in. Overrides location set in aiplatform.init. - :type location: str - :param credentials: Custom credentials to use to run call training service. Overrides - credentials set in aiplatform.init. - :type credentials: auth_credentials.Credentials + :param project_id: Project to run training in. + :type project_id: str + :param region: Location to run training in. + :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only + characters, can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. @@ -1662,8 +1577,6 @@ def create_custom_training_job( Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - - Overrides encryption_spec_key_name set in aiplatform.init. :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the @@ -1673,38 +1586,16 @@ def create_custom_training_job( resource is created. If set, the trained Model will be secured by this key. - - Overrides encryption_spec_key_name set in aiplatform.init. :type model_encryption_spec_key_name: Optional[str] - :param staging_bucket: Bucket used to stage source and training artifacts. Overrides - staging_bucket set in aiplatform.init. + :param staging_bucket: Bucket used to stage source and training artifacts. :type staging_bucket: str - :param dataset: Vertex AI to fit this training against. Custom training script should - retrieve datasets through passed in environment variables uris: - - os.environ["AIP_TRAINING_DATA_URI"] - os.environ["AIP_VALIDATION_DATA_URI"] - os.environ["AIP_TEST_DATA_URI"] - - Additionally the dataset format is passed in as: - - os.environ["AIP_DATA_FORMAT"] - :type dataset: Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] + :param dataset: Vertex AI to fit this training against. + :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, + datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) - The schema files that can be used here are found in - gs://google-cloud-aiplatform/schema/dataset/annotation/, - note that the chosen schema must be consistent with - ``metadata`` - of the Dataset specified by - ``dataset_id``. Only Annotations that both match this schema and belong to DataItems not ignored by the split method are used in @@ -1727,7 +1618,7 @@ def create_custom_training_job( :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only + characters, can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. @@ -1751,7 +1642,7 @@ def create_custom_training_job( Users submitting jobs must have act-as permission on this run-as account. 
:type service_account: str :param network: The full name of the Compute Engine network to which the job - should be peered. For example, projects/12345/global/networks/myVPC. + should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. :type network: str @@ -1777,10 +1668,6 @@ def create_custom_training_job( and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - - environment_variables = { - 'MY_KEY': 'MY_VALUE' - } :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be @@ -1841,7 +1728,7 @@ def create_custom_training_job( Supported only for tabular and time series Datasets. :type predefined_split_column_name: str - :param timestamp_split_column_name : Optional. The key is a name of one of the Dataset's data + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a @@ -1850,18 +1737,9 @@ def create_custom_training_job( Supported only for tabular and time series Datasets. :type timestamp_split_column_name: str - :param tensorboard: Optional. The name of a Vertex AI - [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] - resource to which this CustomJob will upload Tensorboard + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` - - The training script should write Tensorboard to following Vertex AI environment - variable: - - AIP_TENSORBOARD_LOG_DIR - - `service_account` is required with provided `tensorboard`. For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py index f624a06bdda17..ce613c3de0dcf 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py @@ -333,7 +333,7 @@ def list_annotations( metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> ListAnnotationsPager: """ - Lists Annotations belongs to a dataitem + Lists Annotations belongs to a data item :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :type project_id: str @@ -518,12 +518,7 @@ def update_dataset( :type dataset_id: str :param dataset: Required. The Dataset which replaces the resource on the server. :type dataset: google.cloud.aiplatform_v1.types.Dataset - :param update_mask: Required. The update mask applies to the resource. For the ``FieldMask`` - definition, see [google.protobuf.FieldMask][google.protobuf.FieldMask]. - Updatable fields: - - ``display_name`` - - ``description`` - - ``labels`` + :param update_mask: Required. The update mask applies to the resource. :type update_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. 
:type retry: google.api_core.retry.Retry diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index b2b04cd7efe73..fc43b05a9d766 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -575,11 +575,7 @@ class UpdateDatasetOperator(BaseOperator): :type dataset_id: str :param dataset: Required. The Dataset which replaces the resource on the server. :type dataset: google.cloud.aiplatform_v1.types.Dataset - :param update_mask: Required. The update mask applies to the resource. For the ``FieldMask`` definition, - see [google.protobuf.FieldMask][google.protobuf.FieldMask]. Updatable fields: - - ``display_name`` - - ``description`` - - ``labels`` + :param update_mask: Required. The update mask applies to the resource. :type update_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. :type retry: google.api_core.retry.Retry From a718d1f31f774bd325e5c471c6001d00cac8e0a3 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Tue, 14 Dec 2021 15:56:58 +0000 Subject: [PATCH 13/20] Add delegate_to parameter --- .../cloud/operators/vertex_ai/custom_job.py | 22 ++++- .../cloud/operators/vertex_ai/dataset.py | 84 +++++++++++++++++-- 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index bc35d5f01fa2a..bf5226b2a426c 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -126,6 +126,7 @@ def __init__( tensorboard: Optional[str] = None, sync=True, gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -180,8 +181,11 @@ def __init__( self.sync = sync # END Run param self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook = CustomJobHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain) + self.hook = CustomJobHook( + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain + ) def on_kill(self) -> None: """ @@ -465,6 +469,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -477,10 +482,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = CustomJobHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) try: self.log.info("Deleting custom training pipeline: %s", self.training_pipeline) training_pipeline_operation = hook.delete_training_pipeline( @@ -536,6 +546,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: 
Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -550,10 +561,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = CustomJobHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) results = hook.list_training_pipelines( region=self.region, project_id=self.project_id, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index fc43b05a9d766..3af83dc3ee29e 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -91,6 +91,10 @@ class CreateDatasetOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -115,6 +119,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -126,10 +131,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) self.log.info("Creating dataset") operation = hook.create_dataset( @@ -177,6 +187,10 @@ class GetDatasetOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. 
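For context, the new delegate_to argument is handed to DatasetHook unchanged, so a DAG author can opt into domain-wide delegation per task. A minimal usage sketch under assumed placeholder values (the project ID, region, dataset body and delegated account below are illustrative, not values taken from this change):

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.vertex_ai.dataset import CreateDatasetOperator

    with models.DAG(
        "example_vertex_ai_dataset_delegation",
        schedule_interval="@once",
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ) as dag:
        create_image_dataset = CreateDatasetOperator(
            task_id="create_image_dataset",
            project_id="my-project",      # placeholder project ID
            region="us-central1",         # placeholder region
            dataset={
                "display_name": "my-image-dataset",
                # public image-dataset metadata schema (adjust for other dataset types)
                "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
            },
            gcp_conn_id="google_cloud_default",
            delegate_to="workspace-user@example.com",  # forwarded as-is to DatasetHook
        )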
@@ -202,6 +216,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> Dataset: @@ -214,10 +229,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) try: self.log.info("Get dataset: %s", self.dataset_id) @@ -263,6 +283,10 @@ class DeleteDatasetOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -286,6 +310,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -297,10 +322,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) try: self.log.info("Deleting dataset: %s", self.dataset_id) @@ -338,6 +368,10 @@ class ExportDataOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. 
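ExportDataOperator gets the same treatment. A short sketch of an export task, assuming a placeholder dataset ID and GCS prefix; the export_config dict follows the shape of google.cloud.aiplatform_v1.types.ExportDataConfig:

    from airflow.providers.google.cloud.operators.vertex_ai.dataset import ExportDataOperator

    export_image_dataset = ExportDataOperator(
        task_id="export_image_dataset",
        project_id="my-project",   # placeholder project ID
        region="us-central1",      # placeholder region
        dataset_id="1234567890",   # placeholder numeric dataset resource ID
        export_config={
            "gcs_destination": {"output_uri_prefix": "gs://my-bucket/vertex-ai/exports"},
        },
        delegate_to="workspace-user@example.com",  # forwarded as-is to DatasetHook
    )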
@@ -362,6 +396,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -374,10 +409,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) self.log.info("Exporting data: %s", self.dataset_id) operation = hook.export_data( @@ -414,6 +454,10 @@ class ImportDataOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -438,6 +482,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -450,10 +495,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) self.log.info("Importing data: %s", self.dataset_id) operation = hook.import_data( @@ -496,6 +546,10 @@ class ListDatasetsOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. 
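ImportDataOperator rounds out the same pattern. A sketch with placeholder URIs; each entry of import_configs mirrors google.cloud.aiplatform_v1.types.ImportDataConfig, and the import schema shown (single-label image classification) is only an assumed example of the public ioformat schemas:

    from airflow.providers.google.cloud.operators.vertex_ai.dataset import ImportDataOperator

    import_image_data = ImportDataOperator(
        task_id="import_image_data",
        project_id="my-project",   # placeholder project ID
        region="us-central1",      # placeholder region
        dataset_id="1234567890",   # placeholder numeric dataset resource ID
        import_configs=[
            {
                "gcs_source": {"uris": ["gs://my-bucket/vertex-ai/import/images.jsonl"]},
                "import_schema_uri": (
                    "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
                    "image_classification_single_label_io_format_1.0.0.yaml"
                ),
            }
        ],
        delegate_to="workspace-user@example.com",  # forwarded as-is to DatasetHook
    )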
@@ -524,6 +578,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -539,10 +594,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) results = hook.list_datasets( project_id=self.project_id, region=self.region, @@ -585,6 +645,10 @@ class UpdateDatasetOperator(BaseOperator): :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -610,6 +674,7 @@ def __init__( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = "", gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -623,10 +688,15 @@ def __init__( self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatasetHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DatasetHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) self.log.info("Updating dataset: %s", self.dataset_id) result = hook.update_dataset( project_id=self.project_id, From fc2c912ed8deb7c14df6286c4bf8a4c6dd7ac913 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Wed, 15 Dec 2021 13:06:55 +0000 Subject: [PATCH 14/20] Change __init__ method for CustomJobBase class --- .../google/cloud/operators/vertex_ai/custom_job.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index bf5226b2a426c..9c490d0e6d861 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -183,8 +183,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain + self.hook: Optional[CustomJobHook] = None + + def execute(self, context): self.hook = CustomJobHook( - gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, ) def on_kill(self) -> None: @@ -192,7 +197,8 @@ def on_kill(self) -> None: Callback called when the 
operator is killed. Cancel any running job. """ - self.hook.cancel_job() + if self.hook: + self.hook.cancel_job() class CreateCustomContainerTrainingJobOperator(_CustomTrainingJobBaseOperator): @@ -215,6 +221,7 @@ def __init__( self.command = command def execute(self, context): + super().execute(context) model = self.hook.create_custom_container_training_job( project_id=self.project_id, region=self.region, @@ -300,6 +307,7 @@ def __init__( self.python_module_name = python_module_name def execute(self, context): + super().execute(context) model = self.hook.create_custom_python_package_training_job( project_id=self.project_id, region=self.region, @@ -388,6 +396,7 @@ def __init__( self.script_path = script_path def execute(self, context): + super().execute(context) model = self.hook.create_custom_training_job( project_id=self.project_id, region=self.region, From 6c28ec5b66e3c946df9e45bd44695e3df8418299 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Wed, 15 Dec 2021 13:49:50 +0000 Subject: [PATCH 15/20] Change _CustomTrainingJobBaseOperator to CustomTrainingJobBaseOperator --- .../google/cloud/operators/vertex_ai/custom_job.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 9c490d0e6d861..6a1b20b6a4cec 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -73,7 +73,7 @@ def get_link(self, operator, dttm): ) -class _CustomTrainingJobBaseOperator(BaseOperator): +class CustomTrainingJobBaseOperator(BaseOperator): """The base class for operators that launch Custom jobs on VertexAI.""" def __init__( @@ -201,7 +201,7 @@ def on_kill(self) -> None: self.hook.cancel_job() -class CreateCustomContainerTrainingJobOperator(_CustomTrainingJobBaseOperator): +class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): """Create Custom Container Training job""" template_fields = [ @@ -286,7 +286,7 @@ def execute(self, context): return result -class CreateCustomPythonPackageTrainingJobOperator(_CustomTrainingJobBaseOperator): +class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator): """Create Custom Python Package Training job""" template_fields = [ @@ -373,7 +373,7 @@ def execute(self, context): return result -class CreateCustomTrainingJobOperator(_CustomTrainingJobBaseOperator): +class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): """Create Custom Training job""" template_fields = [ From 8a39d14f63ad2b598bf4138a6e304a78a97e8e22 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Wed, 15 Dec 2021 15:33:15 +0000 Subject: [PATCH 16/20] Fix unit tests --- .../google/cloud/operators/test_vertex_ai.py | 56 +++++++++++++++---- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 3d15b7564e542..ec5a63d47890d 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -44,6 +44,7 @@ GCP_PROJECT = "test-project" GCP_LOCATION = "test-location" GCP_CONN_ID = "test-conn" +DELEGATE_TO = "test-delegate-to" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] STAGING_BUCKET = "gs://test-vertex-ai-bucket" DISPLAY_NAME = "display_name_1" # Create random display name @@ -101,6 +102,7 @@ def 
test_execute(self, mock_hook): op = CreateCustomContainerTrainingJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -120,7 +122,9 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT, ) op.execute(context={'ti': mock.MagicMock()}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.create_custom_container_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -178,6 +182,7 @@ def test_execute(self, mock_hook): op = CreateCustomPythonPackageTrainingJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -198,7 +203,9 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT, ) op.execute(context={'ti': mock.MagicMock()}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -257,6 +264,7 @@ def test_execute(self, mock_hook): op = CreateCustomTrainingJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -270,7 +278,9 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT, ) op.execute(context={'ti': mock.MagicMock()}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.create_custom_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -336,10 +346,13 @@ def test_execute(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.delete_training_pipeline.assert_called_once_with( training_pipeline=TRAINING_PIPELINE_ID, region=GCP_LOCATION, @@ -369,6 +382,7 @@ def test_execute(self, mock_hook): op = ListCustomTrainingJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -381,7 +395,9 @@ def test_execute(self, mock_hook): metadata=METADATA, ) op.execute(context={'ti': mock.MagicMock()}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.list_training_pipelines.assert_called_once_with( 
region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -402,6 +418,7 @@ def test_execute(self, mock_hook, to_dict_mock): op = CreateDatasetOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -411,7 +428,9 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) op.execute(context={'ti': mock.MagicMock()}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.create_dataset.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -429,6 +448,7 @@ def test_execute(self, mock_hook, to_dict_mock): op = DeleteDatasetOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -438,7 +458,9 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.delete_dataset.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -456,6 +478,7 @@ def test_execute(self, mock_hook, to_dict_mock): op = ExportDataOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -466,7 +489,9 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.export_data.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -485,6 +510,7 @@ def test_execute(self, mock_hook, to_dict_mock): op = ImportDataOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -495,7 +521,9 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.import_data.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -520,6 +548,7 @@ def test_execute(self, mock_hook, to_dict_mock): op = ListDatasetsOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -533,7 +562,9 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) op.execute(context={'ti': mock.MagicMock()}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) 
mock_hook.return_value.list_datasets.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -555,6 +586,7 @@ def test_execute(self, mock_hook, to_dict_mock): op = UpdateDatasetOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, + delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, project_id=GCP_PROJECT, region=GCP_LOCATION, @@ -566,7 +598,9 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN + ) mock_hook.return_value.update_dataset.assert_called_once_with( project_id=GCP_PROJECT, region=GCP_LOCATION, From 87e3e602c994c9382842bdbdff53a20d263518c3 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Mon, 17 Jan 2022 14:26:48 +0000 Subject: [PATCH 17/20] Update CustomJobs docstring --- .../cloud/operators/vertex_ai/custom_job.py | 1036 ++++++++++++++++- .../cloud/operators/vertex_ai/dataset.py | 17 +- 2 files changed, 1034 insertions(+), 19 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 6a1b20b6a4cec..3cfa6c80fd882 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -29,6 +29,7 @@ from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook +from airflow.utils.context import Context VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" VERTEX_AI_MODEL_LINK = ( @@ -185,7 +186,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.hook: Optional[CustomJobHook] = None - def execute(self, context): + def execute(self, context: Context): self.hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -202,7 +203,316 @@ def on_kill(self) -> None: class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): - """Create Custom Container Training job""" + """Create Custom Container Training job + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param display_name: Required. The user-defined name of this TrainingPipeline. + :type display_name: str + :param command: The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + :type command: Sequence[str] + :param container_uri: Required: Uri of the training container image in the GCR. + :type container_uri: str + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :type model_serving_container_image_uri: str + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. 
+ :type model_serving_container_predict_route: str + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :type model_serving_container_health_route: str + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :type model_serving_container_command: Sequence[str] + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :type model_serving_container_args: Sequence[str] + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :type model_serving_container_environment_variables: Dict[str, str] + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :type model_serving_container_ports: Sequence[int] + :param model_description: The description of the Model. + :type model_description: str + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_instance_schema_uri: str + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. 
Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :type model_parameters_schema_uri: str + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_prediction_schema_uri: str + :param project_id: Project to run training in. + :type project_id: str + :param region: Location to run training in. + :type region: str + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type labels: Dict[str, str] + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :type training_encryption_spec_key_name: Optional[str] + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :type model_encryption_spec_key_name: Optional[str] + :param staging_bucket: Bucket used to stage source and training artifacts. + :type staging_bucket: str + :param dataset: Vertex AI to fit this training against. + :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, + datasets.VideoDataset,] + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. 
+ :type annotation_schema_uri: str + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :type model_display_name: str + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type model_labels: Dict[str, str] + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + + :type base_output_dir: str + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :type service_account: str + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :type network: str + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :type bigquery_destination: str + :param args: Command line arguments to be passed to the Python script. + :type args: List[Unions[str, int, float]] + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :type environment_variables: Dict[str, str] + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :type replica_count: int + :param machine_type: The type of machine to use for training. + :type machine_type: str + :param accelerator_type: Hardware accelerator type. 
One of ACCELERATOR_TYPE_UNSPECIFIED,
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
+ NVIDIA_TESLA_T4
+ :type accelerator_type: str
+ :param accelerator_count: The number of accelerators to attach to a worker replica.
+ :type accelerator_count: int
+ :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ :type boot_disk_type: str
+ :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
+ Boot disk size must be within the range of [100, 64000].
+ :type boot_disk_size_gb: int
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ :type training_fraction_split: float
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+ validate the Model. This is ignored if Dataset is not provided.
+ :type validation_fraction_split: float
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ :type test_fraction_split: float
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :type training_filter_split: str
+ :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :type validation_filter_split: str
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :type test_filter_split: str
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+ columns. The value of the key (either the label's value or
+ value in the column) must be one of {``training``,
+ ``validation``, ``test``}, and it defines to which set the
+ given piece of data is assigned. If for a piece of data the
+ key is not present or has an invalid value, that piece is
+ ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ :type predefined_split_column_name: str
+ :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z).
If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type timestamp_split_column_name: str + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :type tensorboard: str + :param sync: Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :type sync: bool + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ template_fields = [ 'region', @@ -220,7 +530,7 @@ def __init__( super().__init__(**kwargs) self.command = command - def execute(self, context): + def execute(self, context: Context): super().execute(context) model = self.hook.create_custom_container_training_job( project_id=self.project_id, @@ -287,7 +597,317 @@ def execute(self, context): class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator): - """Create Custom Python Package Training job""" + """Create Custom Python Package Training job + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param display_name: Required. The user-defined name of this TrainingPipeline. + :type display_name: str + :param python_package_gcs_uri: Required: GCS location of the training python package. + :type python_package_gcs_uri: str + :param python_module_name: Required: The module name of the training python package. + :type python_module_name: str + :param container_uri: Required: Uri of the training container image in the GCR. + :type container_uri: str + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :type model_serving_container_image_uri: str + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. 
+ :type model_serving_container_predict_route: str + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :type model_serving_container_health_route: str + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :type model_serving_container_command: Sequence[str] + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :type model_serving_container_args: Sequence[str] + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :type model_serving_container_environment_variables: Dict[str, str] + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :type model_serving_container_ports: Sequence[int] + :param model_description: The description of the Model. + :type model_description: str + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_instance_schema_uri: str + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. 
Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :type model_parameters_schema_uri: str + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_prediction_schema_uri: str + :param project_id: Project to run training in. + :type project_id: str + :param region: Location to run training in. + :type region: str + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type labels: Dict[str, str] + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :type training_encryption_spec_key_name: Optional[str] + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :type model_encryption_spec_key_name: Optional[str] + :param staging_bucket: Bucket used to stage source and training artifacts. + :type staging_bucket: str + :param dataset: Vertex AI to fit this training against. + :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, + datasets.VideoDataset,] + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. 
+ :type annotation_schema_uri: str + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :type model_display_name: str + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type model_labels: Dict[str, str] + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + + :type base_output_dir: str + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :type service_account: str + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :type network: str + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :type bigquery_destination: str + :param args: Command line arguments to be passed to the Python script. + :type args: List[Unions[str, int, float]] + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :type environment_variables: Dict[str, str] + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :type replica_count: int + :param machine_type: The type of machine to use for training. + :type machine_type: str + :param accelerator_type: Hardware accelerator type. 
One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :type accelerator_type: str + :param accelerator_count: The number of accelerators to attach to a worker replica. + :type accelerator_count: int + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :type boot_disk_type: str + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :type boot_disk_size_gb: int + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :type training_fraction_split: float + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :type validation_fraction_split: float + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :type test_fraction_split: float + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type training_filter_split: str + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type validation_filter_split: str + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type test_filter_split: str + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type predefined_split_column_name: str + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). 
If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type timestamp_split_column_name: str + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :type tensorboard: str + :param sync: Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :type sync: bool + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ template_fields = [ 'region', @@ -306,7 +926,7 @@ def __init__( self.python_package_gcs_uri = python_package_gcs_uri self.python_module_name = python_module_name - def execute(self, context): + def execute(self, context: Context): super().execute(context) model = self.hook.create_custom_python_package_training_job( project_id=self.project_id, @@ -374,7 +994,317 @@ def execute(self, context): class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): - """Create Custom Training job""" + """Create Custom Training job + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param display_name: Required. The user-defined name of this TrainingPipeline. + :type display_name: str + :param script_path: Required. Local path to training script. + :type script_path: str + :param container_uri: Required: Uri of the training container image in the GCR. + :type container_uri: str + :param requirements: List of python packages dependencies of script. + :type requirements: Sequence[str] + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :type model_serving_container_image_uri: str + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. 
+ :type model_serving_container_predict_route: str + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :type model_serving_container_health_route: str + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :type model_serving_container_command: Sequence[str] + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :type model_serving_container_args: Sequence[str] + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :type model_serving_container_environment_variables: Dict[str, str] + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :type model_serving_container_ports: Sequence[int] + :param model_description: The description of the Model. + :type model_description: str + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_instance_schema_uri: str + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. 
Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :type model_parameters_schema_uri: str + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :type model_prediction_schema_uri: str + :param project_id: Project to run training in. + :type project_id: str + :param region: Location to run training in. + :type region: str + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type labels: Dict[str, str] + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :type training_encryption_spec_key_name: Optional[str] + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :type model_encryption_spec_key_name: Optional[str] + :param staging_bucket: Bucket used to stage source and training artifacts. + :type staging_bucket: str + :param dataset: Vertex AI to fit this training against. + :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, + datasets.VideoDataset,] + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. 
+ :type annotation_schema_uri: str + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :type model_display_name: str + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :type model_labels: Dict[str, str] + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + + :type base_output_dir: str + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :type service_account: str + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :type network: str + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :type bigquery_destination: str + :param args: Command line arguments to be passed to the Python script. + :type args: List[Unions[str, int, float]] + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :type environment_variables: Dict[str, str] + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :type replica_count: int + :param machine_type: The type of machine to use for training. + :type machine_type: str + :param accelerator_type: Hardware accelerator type. 
One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :type accelerator_type: str + :param accelerator_count: The number of accelerators to attach to a worker replica. + :type accelerator_count: int + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :type boot_disk_type: str + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :type boot_disk_size_gb: int + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :type training_fraction_split: float + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :type validation_fraction_split: float + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :type test_fraction_split: float + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type training_filter_split: str + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type validation_filter_split: str + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :type test_filter_split: str + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type predefined_split_column_name: str + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). 
If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :type timestamp_split_column_name: str + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + :type tensorboard: str + :param sync: Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + :type sync: bool + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ template_fields = [ 'region', @@ -395,7 +1325,7 @@ def __init__( self.requirements = requirements self.script_path = script_path - def execute(self, context): + def execute(self, context: Context): super().execute(context) model = self.hook.create_custom_training_job( project_id=self.project_id, @@ -463,7 +1393,38 @@ def execute(self, context): class DeleteCustomTrainingJobOperator(BaseOperator): - """Deletes a CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob.""" + """Deletes a CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob. + + :param training_pipeline_id: Required. The name of the TrainingPipeline resource to be deleted. + :type training_pipeline_id: str + :param custom_job_id: Required. The name of the CustomJob to delete. + :type custom_job_id: str + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
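For the script-based operator documented above, a similarly condensed sketch (placeholder values, same assumed DAG context as the first example):

from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomTrainingJobOperator,
)

# Uploads a local training script and runs it in the given training container;
# the script path and dependency list are illustrative placeholders.
create_training_job = CreateCustomTrainingJobOperator(
    task_id="create_custom_training_job",
    project_id="my-project",
    region="us-central1",
    display_name="train-model-script",
    script_path="dags/scripts/training_script.py",
    container_uri="gcr.io/my-project/my-training-image:latest",
    requirements=["pandas", "scikit-learn"],
    staging_bucket="gs://my-staging-bucket",
    replica_count=1,
)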
+ :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ template_fields = ("region", "project_id", "impersonation_chain") @@ -494,7 +1455,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -531,7 +1492,60 @@ def execute(self, context: Dict): class ListCustomTrainingJobOperator(BaseOperator): - """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location.""" + """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :type project_id: str + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :type region: str + :param filter: Optional. The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + :type filter: str + :param page_size: Optional. The standard list page size. + :type page_size: int + :param page_token: Optional. The standard list page token. Typically obtained via + [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] + of the previous + [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] + call. + :type page_token: str + :param read_mask: Optional. Mask specifying which fields to read. + :type read_mask: google.protobuf.field_mask_pb2.FieldMask + :param retry: Designation of what errors, if any, should be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The timeout for this request. + :type timeout: float + :param metadata: Strings which should be sent along with the request as metadata. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
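A matching clean-up task for the delete operator documented above could be sketched as follows. Both IDs are placeholders and would typically come from the output of an earlier create task rather than being hard-coded:

from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    DeleteCustomTrainingJobOperator,
)

# Removes the TrainingPipeline and its backing CustomJob; IDs are placeholders.
delete_training_job = DeleteCustomTrainingJobOperator(
    task_id="delete_custom_training_job",
    training_pipeline_id="1234567890123456789",
    custom_job_id="1234567890123456789",
    project_id="my-project",
    region="us-central1",
)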
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ template_fields = [ "region", @@ -573,7 +1587,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index 3af83dc3ee29e..18e1b47c41bba 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -18,7 +18,7 @@ # """This module contains Google Vertex AI operators.""" -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union from google.api_core.exceptions import NotFound from google.api_core.retry import Retry @@ -28,6 +28,7 @@ from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook +from airflow.utils.context import Context VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" VERTEX_AI_DATASET_LINK = ( @@ -134,7 +135,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -232,7 +233,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -325,7 +326,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -412,7 +413,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -498,7 +499,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -597,7 +598,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Dict): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -691,7 +692,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, From 4f325dcb9971d37aa9c2caa61b3d63fa75c9830f Mon 
Sep 17 00:00:00 2001 From: MaksYermak Date: Tue, 18 Jan 2022 09:40:04 +0000 Subject: [PATCH 18/20] Add TYPE_CHECKING for Context --- .../cloud/operators/vertex_ai/custom_job.py | 18 +++++++++-------- .../cloud/operators/vertex_ai/dataset.py | 20 ++++++++++--------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 3cfa6c80fd882..3d167f1957f4d 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -18,7 +18,7 @@ # """This module contains Google Vertex AI operators.""" -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from google.api_core.exceptions import NotFound from google.api_core.retry import Retry @@ -29,7 +29,9 @@ from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook -from airflow.utils.context import Context + +if TYPE_CHECKING: + from airflow.utils.context import Context VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" VERTEX_AI_MODEL_LINK = ( @@ -186,7 +188,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.hook: Optional[CustomJobHook] = None - def execute(self, context: Context): + def execute(self, context: 'Context'): self.hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -530,7 +532,7 @@ def __init__( super().__init__(**kwargs) self.command = command - def execute(self, context: Context): + def execute(self, context: 'Context'): super().execute(context) model = self.hook.create_custom_container_training_job( project_id=self.project_id, @@ -926,7 +928,7 @@ def __init__( self.python_package_gcs_uri = python_package_gcs_uri self.python_module_name = python_module_name - def execute(self, context: Context): + def execute(self, context: 'Context'): super().execute(context) model = self.hook.create_custom_python_package_training_job( project_id=self.project_id, @@ -1325,7 +1327,7 @@ def __init__( self.requirements = requirements self.script_path = script_path - def execute(self, context: Context): + def execute(self, context: 'Context'): super().execute(context) model = self.hook.create_custom_training_job( project_id=self.project_id, @@ -1455,7 +1457,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1587,7 +1589,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index 18e1b47c41bba..f9c917f72402c 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -18,7 +18,7 @@ # """This module contains Google Vertex AI operators.""" -from typing import Optional, Sequence, Tuple, Union +from typing import 
TYPE_CHECKING, Optional, Sequence, Tuple, Union from google.api_core.exceptions import NotFound from google.api_core.retry import Retry @@ -28,7 +28,9 @@ from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook -from airflow.utils.context import Context + +if TYPE_CHECKING: + from airflow.utils.context import Context VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai" VERTEX_AI_DATASET_LINK = ( @@ -135,7 +137,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -233,7 +235,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -326,7 +328,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -413,7 +415,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -499,7 +501,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -598,7 +600,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -692,7 +694,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Context): + def execute(self, context: 'Context'): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, From c7e681228fb671771870a0a8afbf8a236ded4b3c Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Thu, 20 Jan 2022 14:18:24 +0000 Subject: [PATCH 19/20] Change docstring for sync parameter --- .../google/cloud/hooks/vertex_ai/custom_job.py | 15 ++++++++------- .../cloud/operators/vertex_ai/custom_job.py | 8 ++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index 1c755a206e2b9..5f919a8a871f7 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -86,7 +86,7 @@ def get_custom_container_training_job( self, display_name: str, container_uri: str, - command: Sequence[str] = None, + command: Sequence[str] = [], model_serving_container_image_uri: Optional[str] = None, model_serving_container_predict_route: Optional[str] = None, model_serving_container_health_route: Optional[str] = None, @@ -235,7 +235,7 @@ def extract_model_id(obj: Dict) -> str: """Returns unique id 
of the Model.""" return obj["name"].rpartition("/")[-1] - def wait_for_operation(self, timeout: float, operation: Operation): + def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -245,7 +245,8 @@ def wait_for_operation(self, timeout: float, operation: Operation): def cancel_job(self) -> None: """Cancel Job for training pipeline""" - self._job.cancel() + if self._job: + self._job.cancel() def _run_job( self, @@ -590,7 +591,7 @@ def create_custom_container_training_job( region: str, display_name: str, container_uri: str, - command: Sequence[str] = None, + command: Sequence[str] = [], model_serving_container_image_uri: Optional[str] = None, model_serving_container_predict_route: Optional[str] = None, model_serving_container_health_route: Optional[str] = None, @@ -928,7 +929,7 @@ def create_custom_container_training_job( For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str - :param sync: Whether to execute this method synchronously. If False, this method + :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. :type sync: bool @@ -1335,7 +1336,7 @@ def create_custom_python_package_training_job( For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str - :param sync: Whether to execute this method synchronously. If False, this method + :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. :type sync: bool @@ -1743,7 +1744,7 @@ def create_custom_training_job( For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str - :param sync: Whether to execute this method synchronously. If False, this method + :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. :type sync: bool diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 3d167f1957f4d..5fece7608b1d7 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -495,7 +495,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str - :param sync: Whether to execute this method synchronously. If False, this method + :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. 
:type sync: bool @@ -526,7 +526,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): def __init__( self, *, - command: Sequence[str] = None, + command: Sequence[str] = [], **kwargs, ) -> None: super().__init__(**kwargs) @@ -890,7 +890,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str - :param sync: Whether to execute this method synchronously. If False, this method + :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. :type sync: bool @@ -1287,7 +1287,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :type tensorboard: str - :param sync: Whether to execute this method synchronously. If False, this method + :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. :type sync: bool From 50a18bca43665efa55f0232f7362a858d1661b76 Mon Sep 17 00:00:00 2001 From: MaksYermak Date: Thu, 20 Jan 2022 16:24:52 +0000 Subject: [PATCH 20/20] Delete :type from docstring --- .../cloud/hooks/vertex_ai/custom_job.py | 245 ------------------ .../google/cloud/hooks/vertex_ai/dataset.py | 82 ------ .../cloud/operators/vertex_ai/custom_job.py | 183 ------------- .../cloud/operators/vertex_ai/dataset.py | 71 ----- 4 files changed, 581 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index 5f919a8a871f7..fc59753ae074d 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -342,17 +342,11 @@ def cancel_pipeline_job( [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param pipeline_job: The name of the PipelineJob to cancel. - :type pipeline_job: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) name = client.pipeline_job_path(project_id, region, pipeline_job) @@ -387,17 +381,11 @@ def cancel_training_pipeline( [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param training_pipeline: Required. The name of the TrainingPipeline to cancel. 
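# Illustrative DAG sketch, not part of this patch, for the ``command`` argument touched
# above. It assumes CreateCustomContainerTrainingJobOperator forwards the arguments
# documented on the hook method (project_id, region, display_name, container_uri,
# staging_bucket, replica_count); the project, bucket and image names are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomContainerTrainingJobOperator,
)

with DAG(
    dag_id="example_vertex_ai_custom_container_job",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    create_training_job = CreateCustomContainerTrainingJobOperator(
        task_id="create_custom_container_training_job",
        project_id="my-project",
        region="us-central1",
        display_name="train-housing-container",
        container_uri="gcr.io/my-project/custom-training:latest",
        # Overrides the image ENTRYPOINT, as described in the docstrings above.
        command=["python3", "task.py"],
        staging_bucket="gs://my-staging-bucket",
        replica_count=1,
    )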
- :type training_pipeline: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) name = client.training_pipeline_path(project_id, region, training_pipeline) @@ -432,17 +420,11 @@ def cancel_custom_job( [CustomJob.state][google.cloud.aiplatform.v1.CustomJob.state] is set to ``CANCELLED``. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param custom_job: Required. The name of the CustomJob to cancel. - :type custom_job: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_job_service_client(region) name = JobServiceClient.custom_job_path(project_id, region, custom_job) @@ -471,22 +453,15 @@ def create_pipeline_job( Creates a PipelineJob. A PipelineJob will run immediately when created. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param pipeline_job: Required. The PipelineJob to create. - :type pipeline_job: google.cloud.aiplatform_v1.types.PipelineJob :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of the PipelineJob name. If not provided, an ID will be automatically generated. This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. - :type pipeline_job_id: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) parent = client.common_location_path(project_id, region) @@ -517,17 +492,11 @@ def create_training_pipeline( Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param training_pipeline: Required. The TrainingPipeline to create. - :type training_pipeline: google.cloud.aiplatform_v1.types.TrainingPipeline :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) parent = client.common_location_path(project_id, region) @@ -557,18 +526,12 @@ def create_custom_job( Creates a CustomJob. 
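# Illustrative usage sketch, not part of this patch, of the JobService calls documented
# above: create a CustomJob and then request its cancellation. It assumes the hook
# class is named CustomJobHook and that create_custom_job returns the created CustomJob
# message; all identifiers and images are placeholders.
from google.cloud.aiplatform_v1.types import CustomJob

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

custom_job = hook.create_custom_job(
    project_id="my-project",
    region="us-central1",
    custom_job=CustomJob(
        display_name="example-custom-job",
        job_spec={
            "worker_pool_specs": [
                {
                    "machine_spec": {"machine_type": "n1-standard-4"},
                    "replica_count": 1,
                    "container_spec": {"image_uri": "gcr.io/my-project/trainer:latest"},
                }
            ]
        },
    ),
)

# Cancellation is asynchronous and best effort, as the docstrings above explain.
job_id = custom_job.name.rpartition("/")[-1]
hook.cancel_custom_job(project_id="my-project", region="us-central1", custom_job=job_id)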
A created CustomJob right away will be attempted to be run. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param custom_job: Required. The CustomJob to create. This corresponds to the ``custom_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - :type custom_job: google.cloud.aiplatform_v1.types.CustomJob :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_job_service_client(region) parent = JobServiceClient.common_location_path(project_id, region) @@ -646,24 +609,18 @@ def create_custom_container_training_job( Create Custom Container Training Job :param display_name: Required. The user-defined name of this TrainingPipeline. - :type display_name: str :param command: The command to be invoked when the container is started. It overrides the entrypoint instruction in Dockerfile when provided - :type command: Sequence[str] :param container_uri: Required: Uri of the training container image in the GCR. - :type container_uri: str :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI of the Model serving container suitable for serving the model produced by the training script. - :type model_serving_container_image_uri: str :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will be used by Vertex AI. - :type model_serving_container_predict_route: str :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an HTTP path to send health check requests to the container, and which must be supported by it. If not specified a standard HTTP path will be used by AI Platform. - :type model_serving_container_health_route: str :param model_serving_container_command: The command with which the container is run. Not executed within a shell. The Docker image's ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's @@ -671,27 +628,22 @@ def create_custom_container_training_job( input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_command: Sequence[str] :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_args: Sequence[str] :param model_serving_container_environment_variables: The environment variables that are to be present in the container. 
Should be a dictionary where keys are environment variable names and values are environment variable values for those names. - :type model_serving_container_environment_variables: Dict[str, str] :param model_serving_container_ports: Declaration of ports that are exposed by the container. This field is primarily informational, it gives Vertex AI information about the network connections the container uses. Listing or not a port here has no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. - :type model_serving_container_ports: Sequence[int] :param model_description: The description of the Model. - :type model_description: str :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single instance, which are used in @@ -706,7 +658,6 @@ def create_custom_container_training_job( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_instance_schema_uri: str :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the parameters of prediction and explanation via @@ -722,7 +673,6 @@ def create_custom_container_training_job( immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_parameters_schema_uri: str :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single prediction produced by this Model, which are returned via @@ -737,11 +687,8 @@ def create_custom_container_training_job( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_prediction_schema_uri: str :param project_id: Project to run training in. - :type project_id: str :param region: Location to run training in. - :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -751,7 +698,6 @@ def create_custom_container_training_job( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type labels: Dict[str, str] :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the training pipeline. Has the form: @@ -763,7 +709,6 @@ def create_custom_container_training_job( Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the form: @@ -772,12 +717,8 @@ def create_custom_container_training_job( resource is created. If set, the trained Model will be secured by this key. - :type model_encryption_spec_key_name: Optional[str] :param staging_bucket: Bucket used to stage source and training artifacts. - :type staging_bucket: str :param dataset: Vertex AI to fit this training against. 
- :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, - datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] @@ -794,13 +735,11 @@ def create_custom_container_training_job( ``annotations_filter`` and ``annotation_schema_uri``. - :type annotation_schema_uri: str :param model_display_name: If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. If not provided upon creation, the job's display_name is used. - :type model_display_name: str :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -810,7 +749,6 @@ def create_custom_container_training_job( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type model_labels: Dict[str, str] :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. @@ -823,15 +761,12 @@ def create_custom_container_training_job( - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/ - :type base_output_dir: str :param service_account: Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. - :type service_account: str :param network: The full name of the Compute Engine network to which the job should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - :type network: str :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -846,64 +781,49 @@ def create_custom_container_training_job( - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - :type bigquery_destination: str :param args: Command line arguments to be passed to the Python script. - :type args: List[Unions[str, int, float]] :param environment_variables: Environment variables to be passed to the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be provisioned as a worker replica pool. - :type replica_count: int :param machine_type: The type of machine to use for training. - :type machine_type: str :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - :type accelerator_type: str :param accelerator_count: The number of accelerators to attach to a worker replica. - :type accelerator_count: int :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. 
Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or `pd-standard` (Persistent Disk Hard Disk Drive). - :type boot_disk_type: str :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. - :type boot_disk_size_gb: int :param training_fraction_split: Optional. The fraction of the input data that is to be used to train the Model. This is ignored if Dataset is not provided. - :type training_fraction_split: float :param validation_fraction_split: Optional. The fraction of the input data that is to be used to validate the Model. This is ignored if Dataset is not provided. - :type validation_fraction_split: float :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate the Model. This is ignored if Dataset is not provided. - :type test_fraction_split: float :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to train the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type training_filter_split: str :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type validation_filter_split: str :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type test_filter_split: str :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or value in the column) must be one of {``training``, @@ -913,7 +833,6 @@ def create_custom_container_training_job( ignored by the pipeline. Supported only for tabular and time series Datasets. - :type predefined_split_column_name: str :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where @@ -922,17 +841,14 @@ def create_custom_container_training_job( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. - :type timestamp_split_column_name: str :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :type tensorboard: str :param sync: Whether to execute the AI Platform job synchronously. 
If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. - :type sync: bool """ self._job = self.get_custom_container_training_job( project=project_id, @@ -1052,25 +968,18 @@ def create_custom_python_package_training_job( Create Custom Python Package Training Job :param display_name: Required. The user-defined name of this TrainingPipeline. - :type display_name: str :param python_package_gcs_uri: Required: GCS location of the training python package. - :type python_package_gcs_uri: str :param python_module_name: Required: The module name of the training python package. - :type python_module_name: str :param container_uri: Required: Uri of the training container image in the GCR. - :type container_uri: str :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI of the Model serving container suitable for serving the model produced by the training script. - :type model_serving_container_image_uri: str :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will be used by Vertex AI. - :type model_serving_container_predict_route: str :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an HTTP path to send health check requests to the container, and which must be supported by it. If not specified a standard HTTP path will be used by AI Platform. - :type model_serving_container_health_route: str :param model_serving_container_command: The command with which the container is run. Not executed within a shell. The Docker image's ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's @@ -1078,27 +987,22 @@ def create_custom_python_package_training_job( input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_command: Sequence[str] :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_args: Sequence[str] :param model_serving_container_environment_variables: The environment variables that are to be present in the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. - :type model_serving_container_environment_variables: Dict[str, str] :param model_serving_container_ports: Declaration of ports that are exposed by the container. This field is primarily informational, it gives Vertex AI information about the network connections the container uses. Listing or not a port here has no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. 
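# Illustrative usage sketch, not part of this patch, of calling
# create_custom_container_training_job (implemented just above) directly on the hook;
# in a DAG the corresponding operator normally does this for you. It assumes the hook
# class is named CustomJobHook and that the method returns the trained Model when
# sync=True; URIs, images and bucket names are placeholders.
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")
model = hook.create_custom_container_training_job(
    project_id="my-project",
    region="us-central1",
    display_name="train-housing-container",
    container_uri="gcr.io/my-project/custom-training:latest",
    command=["python3", "task.py"],
    args=["--epochs", "10"],
    staging_bucket="gs://my-staging-bucket",
    replica_count=1,
    machine_type="n1-standard-4",
    model_display_name="housing-model",
    sync=True,  # block until the training pipeline finishes, per the docstring above
)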
- :type model_serving_container_ports: Sequence[int] :param model_description: The description of the Model. - :type model_description: str :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single instance, which are used in @@ -1113,7 +1017,6 @@ def create_custom_python_package_training_job( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_instance_schema_uri: str :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the parameters of prediction and explanation via @@ -1129,7 +1032,6 @@ def create_custom_python_package_training_job( immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_parameters_schema_uri: str :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single prediction produced by this Model, which are returned via @@ -1144,11 +1046,8 @@ def create_custom_python_package_training_job( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_prediction_schema_uri: str :param project_id: Project to run training in. - :type project_id: str :param region: Location to run training in. - :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -1158,7 +1057,6 @@ def create_custom_python_package_training_job( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type labels: Dict[str, str] :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the training pipeline. Has the form: @@ -1170,7 +1068,6 @@ def create_custom_python_package_training_job( Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the form: @@ -1179,12 +1076,8 @@ def create_custom_python_package_training_job( resource is created. If set, the trained Model will be secured by this key. - :type model_encryption_spec_key_name: Optional[str] :param staging_bucket: Bucket used to stage source and training artifacts. - :type staging_bucket: str :param dataset: Vertex AI to fit this training against. - :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, - datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] @@ -1201,13 +1094,11 @@ def create_custom_python_package_training_job( ``annotations_filter`` and ``annotation_schema_uri``. - :type annotation_schema_uri: str :param model_display_name: If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. 
If not provided upon creation, the job's display_name is used. - :type model_display_name: str :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -1217,7 +1108,6 @@ def create_custom_python_package_training_job( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type model_labels: Dict[str, str] :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. @@ -1229,16 +1119,12 @@ def create_custom_python_package_training_job( i.e. /checkpoints/ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/ - - :type base_output_dir: str :param service_account: Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. - :type service_account: str :param network: The full name of the Compute Engine network to which the job should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - :type network: str :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -1253,64 +1139,49 @@ def create_custom_python_package_training_job( - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - :type bigquery_destination: str :param args: Command line arguments to be passed to the Python script. - :type args: List[Unions[str, int, float]] :param environment_variables: Environment variables to be passed to the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be provisioned as a worker replica pool. - :type replica_count: int :param machine_type: The type of machine to use for training. - :type machine_type: str :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - :type accelerator_type: str :param accelerator_count: The number of accelerators to attach to a worker replica. - :type accelerator_count: int :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or `pd-standard` (Persistent Disk Hard Disk Drive). - :type boot_disk_type: str :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. - :type boot_disk_size_gb: int :param training_fraction_split: Optional. The fraction of the input data that is to be used to train the Model. This is ignored if Dataset is not provided. - :type training_fraction_split: float :param validation_fraction_split: Optional. The fraction of the input data that is to be used to validate the Model. 
This is ignored if Dataset is not provided. - :type validation_fraction_split: float :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate the Model. This is ignored if Dataset is not provided. - :type test_fraction_split: float :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to train the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type training_filter_split: str :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type validation_filter_split: str :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type test_filter_split: str :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or value in the column) must be one of {``training``, @@ -1320,7 +1191,6 @@ def create_custom_python_package_training_job( ignored by the pipeline. Supported only for tabular and time series Datasets. - :type predefined_split_column_name: str :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where @@ -1329,17 +1199,14 @@ def create_custom_python_package_training_job( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. - :type timestamp_split_column_name: str :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :type tensorboard: str :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. - :type sync: bool """ self._job = self.get_custom_python_package_training_job( project=project_id, @@ -1460,25 +1327,18 @@ def create_custom_training_job( Create Custom Training Job :param display_name: Required. The user-defined name of this TrainingPipeline. - :type display_name: str :param script_path: Required. Local path to training script. - :type script_path: str :param container_uri: Required: Uri of the training container image in the GCR. 
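# Illustrative usage sketch, not part of this patch, of the python-package variant
# documented above. It assumes the hook class is named CustomJobHook; the GCS package
# path, module name and container image are placeholders.
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")
model = hook.create_custom_python_package_training_job(
    project_id="my-project",
    region="us-central1",
    display_name="train-housing-py-package",
    python_package_gcs_uri="gs://my-staging-bucket/trainer-0.1.tar.gz",
    python_module_name="trainer.task",
    container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-7:latest",
    staging_bucket="gs://my-staging-bucket",
    replica_count=1,
    model_display_name="housing-model",
    sync=True,
)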
- :type container_uri: str :param requirements: List of python packages dependencies of script. - :type requirements: Sequence[str] :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI of the Model serving container suitable for serving the model produced by the training script. - :type model_serving_container_image_uri: str :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will be used by Vertex AI. - :type model_serving_container_predict_route: str :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an HTTP path to send health check requests to the container, and which must be supported by it. If not specified a standard HTTP path will be used by AI Platform. - :type model_serving_container_health_route: str :param model_serving_container_command: The command with which the container is run. Not executed within a shell. The Docker image's ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's @@ -1486,27 +1346,22 @@ def create_custom_training_job( input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_command: Sequence[str] :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_args: Sequence[str] :param model_serving_container_environment_variables: The environment variables that are to be present in the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. - :type model_serving_container_environment_variables: Dict[str, str] :param model_serving_container_ports: Declaration of ports that are exposed by the container. This field is primarily informational, it gives Vertex AI information about the network connections the container uses. Listing or not a port here has no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. - :type model_serving_container_ports: Sequence[int] :param model_description: The description of the Model. - :type model_description: str :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single instance, which are used in @@ -1521,7 +1376,6 @@ def create_custom_training_job( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_instance_schema_uri: str :param model_parameters_schema_uri: Optional. 
Points to a YAML file stored on Google Cloud Storage describing the parameters of prediction and explanation via @@ -1537,7 +1391,6 @@ def create_custom_training_job( immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_parameters_schema_uri: str :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single prediction produced by this Model, which are returned via @@ -1552,11 +1405,8 @@ def create_custom_training_job( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_prediction_schema_uri: str :param project_id: Project to run training in. - :type project_id: str :param region: Location to run training in. - :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -1566,7 +1416,6 @@ def create_custom_training_job( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type labels: Dict[str, str] :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the training pipeline. Has the form: @@ -1578,7 +1427,6 @@ def create_custom_training_job( Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the form: @@ -1587,12 +1435,8 @@ def create_custom_training_job( resource is created. If set, the trained Model will be secured by this key. - :type model_encryption_spec_key_name: Optional[str] :param staging_bucket: Bucket used to stage source and training artifacts. - :type staging_bucket: str :param dataset: Vertex AI to fit this training against. - :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, - datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] @@ -1609,13 +1453,11 @@ def create_custom_training_job( ``annotations_filter`` and ``annotation_schema_uri``. - :type annotation_schema_uri: str :param model_display_name: If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. If not provided upon creation, the job's display_name is used. - :type model_display_name: str :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -1625,7 +1467,6 @@ def create_custom_training_job( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type model_labels: Dict[str, str] :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. @@ -1637,16 +1478,12 @@ def create_custom_training_job( i.e. /checkpoints/ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. 
/logs/ - - :type base_output_dir: str :param service_account: Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. - :type service_account: str :param network: The full name of the Compute Engine network to which the job should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - :type network: str :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -1661,64 +1498,49 @@ def create_custom_training_job( - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - :type bigquery_destination: str :param args: Command line arguments to be passed to the Python script. - :type args: List[Unions[str, int, float]] :param environment_variables: Environment variables to be passed to the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be provisioned as a worker replica pool. - :type replica_count: int :param machine_type: The type of machine to use for training. - :type machine_type: str :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - :type accelerator_type: str :param accelerator_count: The number of accelerators to attach to a worker replica. - :type accelerator_count: int :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or `pd-standard` (Persistent Disk Hard Disk Drive). - :type boot_disk_type: str :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. - :type boot_disk_size_gb: int :param training_fraction_split: Optional. The fraction of the input data that is to be used to train the Model. This is ignored if Dataset is not provided. - :type training_fraction_split: float :param validation_fraction_split: Optional. The fraction of the input data that is to be used to validate the Model. This is ignored if Dataset is not provided. - :type validation_fraction_split: float :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate the Model. This is ignored if Dataset is not provided. - :type test_fraction_split: float :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to train the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type training_filter_split: str :param validation_filter_split: Optional. 
A filter on DataItems of the Dataset. DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type validation_filter_split: str :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type test_filter_split: str :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or value in the column) must be one of {``training``, @@ -1728,7 +1550,6 @@ def create_custom_training_job( ignored by the pipeline. Supported only for tabular and time series Datasets. - :type predefined_split_column_name: str :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where @@ -1737,17 +1558,14 @@ def create_custom_training_job( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. - :type timestamp_split_column_name: str :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :type tensorboard: str :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. - :type sync: bool """ self._job = self.get_custom_training_job( project=project_id, @@ -1819,17 +1637,11 @@ def delete_pipeline_job( Deletes a PipelineJob. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. - :type pipeline_job: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) name = client.pipeline_job_path(project_id, region, pipeline_job) @@ -1858,17 +1670,11 @@ def delete_training_pipeline( Deletes a TrainingPipeline. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param training_pipeline: Required. 
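# Illustrative usage sketch, not part of this patch, of the script-based variant
# (script_path plus requirements) whose call appears above. It assumes the hook class
# is named CustomJobHook; the script path, container image and bucket are placeholders.
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")
model = hook.create_custom_training_job(
    project_id="my-project",
    region="us-central1",
    display_name="train-housing-script",
    script_path="dags/scripts/california_housing_training_script.py",
    container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-7:latest",
    requirements=["scikit-learn", "pandas"],
    staging_bucket="gs://my-staging-bucket",
    replica_count=1,
    model_display_name="housing-model",
    sync=True,
)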
The name of the TrainingPipeline resource to be deleted. - :type training_pipeline: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) name = client.training_pipeline_path(project_id, region, training_pipeline) @@ -1897,17 +1703,11 @@ def delete_custom_job( Deletes a CustomJob. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param custom_job: Required. The name of the CustomJob to delete. - :type custom_job: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_job_service_client(region) name = client.custom_job_path(project_id, region, custom_job) @@ -1936,17 +1736,11 @@ def get_pipeline_job( Gets a PipelineJob. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param pipeline_job: Required. The name of the PipelineJob resource. - :type pipeline_job: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) name = client.pipeline_job_path(project_id, region, pipeline_job) @@ -1975,17 +1769,11 @@ def get_training_pipeline( Gets a TrainingPipeline. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param training_pipeline: Required. The name of the TrainingPipeline resource. - :type training_pipeline: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) name = client.training_pipeline_path(project_id, region, training_pipeline) @@ -2014,17 +1802,11 @@ def get_custom_job( Gets a CustomJob. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param custom_job: Required. The name of the CustomJob to get. - :type custom_job: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. 
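# Illustrative usage sketch, not part of this patch, of the get_* calls documented
# around here: inspect the state of a TrainingPipeline and of a CustomJob. It assumes
# the hook class is named CustomJobHook; the resource IDs are placeholders.
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

pipeline = hook.get_training_pipeline(
    project_id="my-project",
    region="us-central1",
    training_pipeline="1234567890",
)
job = hook.get_custom_job(
    project_id="my-project",
    region="us-central1",
    custom_job="9876543210",
)
print(pipeline.state, job.state)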
- :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_job_service_client(region) name = JobServiceClient.custom_job_path(project_id, region, custom_job) @@ -2056,9 +1838,7 @@ def list_pipeline_jobs( Lists PipelineJobs in a Location. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param filter: Optional. Lists the PipelineJobs that match the filter expression. The following fields are supported: @@ -2086,15 +1866,12 @@ def list_pipeline_jobs( The syntax to define filter expression is based on https://google.aip.dev/160. - :type filter: str :param page_size: Optional. The standard list page size. - :type page_size: int :param page_token: Optional. The standard list page token. Typically obtained via [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] of the previous [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] call. - :type page_token: str :param order_by: Optional. A comma-separated list of fields to order by. The default sort order is in ascending order. Use "desc" after a field name for descending. You can have multiple order_by fields @@ -2111,13 +1888,9 @@ def list_pipeline_jobs( - ``update_time`` - ``end_time`` - ``start_time`` - :type order_by: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) parent = client.common_location_path(project_id, region) @@ -2153,9 +1926,7 @@ def list_training_pipelines( Lists TrainingPipelines in a Location. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param filter: Optional. The standard list filter. Supported fields: - ``display_name`` supports = and !=. @@ -2171,23 +1942,16 @@ def list_training_pipelines( - ``NOT display_name="my_pipeline"`` - ``state="PIPELINE_STATE_FAILED"`` - :type filter: str :param page_size: Optional. The standard list page size. - :type page_size: int :param page_token: Optional. The standard list page token. Typically obtained via [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] of the previous [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] call. - :type page_token: str :param read_mask: Optional. Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. 
- :type metadata: Sequence[Tuple[str, str]] """ client = self.get_pipeline_service_client(region) parent = client.common_location_path(project_id, region) @@ -2223,9 +1987,7 @@ def list_custom_jobs( Lists CustomJobs in a Location. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param filter: Optional. The standard list filter. Supported fields: - ``display_name`` supports = and !=. @@ -2241,23 +2003,16 @@ def list_custom_jobs( - ``NOT display_name="my_pipeline"`` - ``state="PIPELINE_STATE_FAILED"`` - :type filter: str :param page_size: Optional. The standard list page size. - :type page_size: int :param page_token: Optional. The standard list page token. Typically obtained via [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] of the previous [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] call. - :type page_token: str :param read_mask: Optional. Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_job_service_client(region) parent = JobServiceClient.common_location_path(project_id, region) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py index ce613c3de0dcf..4a68c343f7349 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py @@ -75,17 +75,11 @@ def create_dataset( Creates a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The Dataset to create. - :type dataset: google.cloud.aiplatform_v1.types.Dataset :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) parent = client.common_location_path(project_id, region) @@ -115,17 +109,11 @@ def delete_dataset( Deletes a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset to delete. - :type dataset: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. 
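# Illustrative usage sketch, not part of this patch, of the list calls documented
# above; the filter syntax follows https://google.aip.dev/160 as the docstrings note,
# and the returned pagers can be iterated directly. It assumes the hook class is named
# CustomJobHook; project and region are placeholders.
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

for pipeline_job in hook.list_pipeline_jobs(
    project_id="my-project",
    region="us-central1",
    filter='create_time>"2021-05-18T00:00:00Z"',
    order_by="create_time desc",
    page_size=50,
):
    print(pipeline_job.name)

for training_pipeline in hook.list_training_pipelines(
    project_id="my-project",
    region="us-central1",
    filter='state="PIPELINE_STATE_FAILED"',
):
    print(training_pipeline.name, training_pipeline.state)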
- :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) name = client.dataset_path(project_id, region, dataset) @@ -155,19 +143,12 @@ def export_data( Exports data from a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset to export. - :type dataset: str :param export_config: Required. The desired output location. - :type export_config: google.cloud.aiplatform_v1.types.ExportDataConfig :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) name = client.dataset_path(project_id, region, dataset) @@ -199,21 +180,13 @@ def get_annotation_spec( Gets an AnnotationSpec. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset. - :type dataset: str :param annotation_spec: The ID of the AnnotationSpec resource. - :type annotation_spec: str :param read_mask: Optional. Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) name = client.annotation_spec_path(project_id, region, dataset, annotation_spec) @@ -244,19 +217,12 @@ def get_dataset( Gets a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset to export. - :type dataset: str :param read_mask: Optional. Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) name = client.dataset_path(project_id, region, dataset) @@ -287,20 +253,13 @@ def import_data( Imports data into a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset to import. - :type dataset: str :param import_configs: Required. The desired input locations. The contents of all input locations will be imported in one batch. 
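Continuing that sketch for the import and export methods above: the config dicts mirror the ImportDataConfig and ExportDataConfig types named in the docstrings, but the nested field names below are assumptions based on those messages, and every URI is a placeholder.

# Import a batch of files from GCS into an existing Dataset.
import_operation = dataset_hook.import_data(
    project_id="my-project",
    region="us-central1",
    dataset="1234567890",
    import_configs=[
        {
            "gcs_source": {"uris": ["gs://my-bucket/import/data.jsonl"]},
            "import_schema_uri": "gs://my-bucket/schema/import_schema.yaml",  # illustrative
        }
    ],
)

# Export the same Dataset to a GCS prefix.
export_operation = dataset_hook.export_data(
    project_id="my-project",
    region="us-central1",
    dataset="1234567890",
    export_config={"gcs_destination": {"output_uri_prefix": "gs://my-bucket/export/"}},
)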
- :type import_configs: Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig] :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) name = client.dataset_path(project_id, region, dataset) @@ -336,30 +295,18 @@ def list_annotations( Lists Annotations that belong to a DataItem. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset. - :type dataset: str :param data_item: Required. The ID of the DataItem to list Annotations from. - :type data_item: str :param filter: The standard list filter. - :type filter: str :param page_size: The standard list page size. - :type page_size: int :param page_token: The standard list page token. - :type page_token: str :param read_mask: Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. - :type order_by: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) parent = client.data_item_path(project_id, region, dataset, data_item) @@ -398,28 +345,17 @@ def list_data_items( Lists DataItems in a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset: Required. The ID of the Dataset. - :type dataset: str :param filter: The standard list filter. - :type filter: str :param page_size: The standard list page size. - :type page_size: int :param page_token: The standard list page token. - :type page_token: str :param read_mask: Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. - :type order_by: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) parent = client.dataset_path(project_id, region, dataset) @@ -457,26 +393,16 @@ def list_datasets( Lists Datasets in a Location. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param filter: The standard list filter.
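The list_* methods above return pagers from google.cloud.aiplatform_v1 that fetch further pages transparently; a short iteration sketch, continuing the assumed DatasetHook instance from the earlier example.

# ListDatasetsPager yields Dataset messages across pages as you iterate.
datasets = dataset_hook.list_datasets(
    project_id="my-project",
    region="us-central1",
    page_size=100,
    order_by="create_time desc",  # "desc" suffix for descending order, as documented above
)
for dataset in datasets:
    print(dataset.name, dataset.display_name)

# Annotations are listed per DataItem, so both the Dataset and DataItem IDs are required.
annotations = dataset_hook.list_annotations(
    project_id="my-project",
    region="us-central1",
    dataset="1234567890",
    data_item="9876543210",  # placeholder DataItem ID
)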
- :type filter: str :param page_size: The standard list page size. - :type page_size: int :param page_token: The standard list page token. - :type page_token: str :param read_mask: Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. - :type order_by: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) parent = client.common_location_path(project_id, region) @@ -511,21 +437,13 @@ def update_dataset( Updates a Dataset. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param dataset_id: Required. The ID of the Dataset. - :type dataset_id: str :param dataset: Required. The Dataset which replaces the resource on the server. - :type dataset: google.cloud.aiplatform_v1.types.Dataset :param update_mask: Required. The update mask applies to the resource. - :type update_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] """ client = self.get_dataset_service_client(region) dataset["name"] = client.dataset_path(project_id, region, dataset_id) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 5fece7608b1d7..875186b750ba2 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -208,28 +208,20 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): """Create Custom Container Training job :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param display_name: Required. The user-defined name of this TrainingPipeline. - :type display_name: str :param command: The command to be invoked when the container is started. It overrides the entrypoint instruction in Dockerfile when provided - :type command: Sequence[str] :param container_uri: Required: Uri of the training container image in the GCR. - :type container_uri: str :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI of the Model serving container suitable for serving the model produced by the training script. - :type model_serving_container_image_uri: str :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will be used by Vertex AI. 
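Because update_dataset pairs the Dataset payload with an explicit FieldMask, a short sketch of the call shape; as the code above shows, the hook fills in the resource name from project_id, region, and dataset_id.

from google.protobuf import field_mask_pb2

# Only fields named in update_mask are changed on the server; IDs are placeholders.
updated_dataset = dataset_hook.update_dataset(
    project_id="my-project",
    region="us-central1",
    dataset_id="1234567890",
    dataset={"display_name": "renamed-dataset"},
    update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
)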
- :type model_serving_container_predict_route: str :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an HTTP path to send health check requests to the container, and which must be supported by it. If not specified a standard HTTP path will be used by AI Platform. - :type model_serving_container_health_route: str :param model_serving_container_command: The command with which the container is run. Not executed within a shell. The Docker image's ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's @@ -237,27 +229,22 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_command: Sequence[str] :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_args: Sequence[str] :param model_serving_container_environment_variables: The environment variables that are to be present in the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. - :type model_serving_container_environment_variables: Dict[str, str] :param model_serving_container_ports: Declaration of ports that are exposed by the container. This field is primarily informational, it gives Vertex AI information about the network connections the container uses. Listing or not a port here has no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. - :type model_serving_container_ports: Sequence[int] :param model_description: The description of the Model. - :type model_description: str :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single instance, which are used in @@ -272,7 +259,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_instance_schema_uri: str :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the parameters of prediction and explanation via @@ -288,7 +274,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_parameters_schema_uri: str :param model_prediction_schema_uri: Optional. 
Points to a YAML file stored on Google Cloud Storage describing the format of a single prediction produced by this Model, which are returned via @@ -303,11 +288,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_prediction_schema_uri: str :param project_id: Project to run training in. - :type project_id: str :param region: Location to run training in. - :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -317,7 +299,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type labels: Dict[str, str] :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the training pipeline. Has the form: @@ -329,7 +310,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the form: @@ -338,12 +318,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): resource is created. If set, the trained Model will be secured by this key. - :type model_encryption_spec_key_name: Optional[str] :param staging_bucket: Bucket used to stage source and training artifacts. - :type staging_bucket: str :param dataset: Vertex AI to fit this training against. - :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, - datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] @@ -360,13 +336,11 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): ``annotations_filter`` and ``annotation_schema_uri``. - :type annotation_schema_uri: str :param model_display_name: If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. If not provided upon creation, the job's display_name is used. - :type model_display_name: str :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -376,7 +350,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type model_labels: Dict[str, str] :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. @@ -388,16 +361,12 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): i.e. /checkpoints/ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/ - - :type base_output_dir: str :param service_account: Specifies the service account for workload run-as account. 
Users submitting jobs must have act-as permission on this run-as account. - :type service_account: str :param network: The full name of the Compute Engine network to which the job should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - :type network: str :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -412,64 +381,49 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - :type bigquery_destination: str :param args: Command line arguments to be passed to the Python script. - :type args: List[Unions[str, int, float]] :param environment_variables: Environment variables to be passed to the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be provisioned as a worker replica pool. - :type replica_count: int :param machine_type: The type of machine to use for training. - :type machine_type: str :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - :type accelerator_type: str :param accelerator_count: The number of accelerators to attach to a worker replica. - :type accelerator_count: int :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or `pd-standard` (Persistent Disk Hard Disk Drive). - :type boot_disk_type: str :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. - :type boot_disk_size_gb: int :param training_fraction_split: Optional. The fraction of the input data that is to be used to train the Model. This is ignored if Dataset is not provided. - :type training_fraction_split: float :param validation_fraction_split: Optional. The fraction of the input data that is to be used to validate the Model. This is ignored if Dataset is not provided. - :type validation_fraction_split: float :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate the Model. This is ignored if Dataset is not provided. - :type test_fraction_split: float :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to train the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type training_filter_split: str :param validation_filter_split: Optional. A filter on DataItems of the Dataset. 
DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type validation_filter_split: str :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type test_filter_split: str :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or value in the column) must be one of {``training``, @@ -479,7 +433,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): ignored by the pipeline. Supported only for tabular and time series Datasets. - :type predefined_split_column_name: str :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where @@ -488,23 +441,18 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. - :type timestamp_split_column_name: str :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :type tensorboard: str :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. - :type sync: bool :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :type gcp_conn_id: str :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have domain-wide delegation enabled. - :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -513,7 +461,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :type impersonation_chain: Union[str, Sequence[str]] """ template_fields = [ @@ -602,29 +549,20 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator """Create Custom Python Package Training job :param project_id: Required. The ID of the Google Cloud project that the service belongs to. 
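To make the long parameter list of CreateCustomContainerTrainingJobOperator (documented above) concrete, a hedged DAG sketch showing a small subset of its arguments; the project, bucket, and image URIs are placeholders, not values from this patch.

from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomContainerTrainingJobOperator,
)

with models.DAG(
    "example_vertex_ai_custom_jobs",
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as dag:
    create_container_training_job = CreateCustomContainerTrainingJobOperator(
        task_id="custom_container_training_job",
        project_id="my-project",
        region="us-central1",
        display_name="train-container",
        container_uri="gcr.io/my-project/my-training-image:latest",
        command=["python3", "task.py"],
        model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest",
        staging_bucket="gs://my-staging-bucket",
        replica_count=1,
        machine_type="n1-standard-4",
    )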
- :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param display_name: Required. The user-defined name of this TrainingPipeline. - :type display_name: str :param python_package_gcs_uri: Required: GCS location of the training python package. - :type python_package_gcs_uri: str :param python_module_name: Required: The module name of the training python package. - :type python_module_name: str :param container_uri: Required: Uri of the training container image in the GCR. - :type container_uri: str :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI of the Model serving container suitable for serving the model produced by the training script. - :type model_serving_container_image_uri: str :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will be used by Vertex AI. - :type model_serving_container_predict_route: str :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an HTTP path to send health check requests to the container, and which must be supported by it. If not specified a standard HTTP path will be used by AI Platform. - :type model_serving_container_health_route: str :param model_serving_container_command: The command with which the container is run. Not executed within a shell. The Docker image's ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's @@ -632,27 +570,22 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_command: Sequence[str] :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_args: Sequence[str] :param model_serving_container_environment_variables: The environment variables that are to be present in the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. - :type model_serving_container_environment_variables: Dict[str, str] :param model_serving_container_ports: Declaration of ports that are exposed by the container. This field is primarily informational, it gives Vertex AI information about the network connections the container uses. Listing or not a port here has no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. - :type model_serving_container_ports: Sequence[int] :param model_description: The description of the Model. - :type model_description: str :param model_instance_schema_uri: Optional. 
Points to a YAML file stored on Google Cloud Storage describing the format of a single instance, which are used in @@ -667,7 +600,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_instance_schema_uri: str :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the parameters of prediction and explanation via @@ -683,7 +615,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_parameters_schema_uri: str :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single prediction produced by this Model, which are returned via @@ -698,11 +629,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_prediction_schema_uri: str :param project_id: Project to run training in. - :type project_id: str :param region: Location to run training in. - :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -712,7 +640,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type labels: Dict[str, str] :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the training pipeline. Has the form: @@ -724,7 +651,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the form: @@ -733,12 +659,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator resource is created. If set, the trained Model will be secured by this key. - :type model_encryption_spec_key_name: Optional[str] :param staging_bucket: Bucket used to stage source and training artifacts. - :type staging_bucket: str :param dataset: Vertex AI to fit this training against. - :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, - datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] @@ -755,13 +677,11 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator ``annotations_filter`` and ``annotation_schema_uri``. - :type annotation_schema_uri: str :param model_display_name: If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. 
If not provided upon creation, the job's display_name is used. - :type model_display_name: str :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -771,7 +691,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type model_labels: Dict[str, str] :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. @@ -783,16 +702,12 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator i.e. /checkpoints/ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/ - - :type base_output_dir: str :param service_account: Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. - :type service_account: str :param network: The full name of the Compute Engine network to which the job should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - :type network: str :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -807,64 +722,49 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - :type bigquery_destination: str :param args: Command line arguments to be passed to the Python script. - :type args: List[Unions[str, int, float]] :param environment_variables: Environment variables to be passed to the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be provisioned as a worker replica pool. - :type replica_count: int :param machine_type: The type of machine to use for training. - :type machine_type: str :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - :type accelerator_type: str :param accelerator_count: The number of accelerators to attach to a worker replica. - :type accelerator_count: int :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or `pd-standard` (Persistent Disk Hard Disk Drive). - :type boot_disk_type: str :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. - :type boot_disk_size_gb: int :param training_fraction_split: Optional. The fraction of the input data that is to be used to train the Model. This is ignored if Dataset is not provided. - :type training_fraction_split: float :param validation_fraction_split: Optional. 
The fraction of the input data that is to be used to validate the Model. This is ignored if Dataset is not provided. - :type validation_fraction_split: float :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate the Model. This is ignored if Dataset is not provided. - :type test_fraction_split: float :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to train the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type training_filter_split: str :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type validation_filter_split: str :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type test_filter_split: str :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or value in the column) must be one of {``training``, @@ -874,7 +774,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator ignored by the pipeline. Supported only for tabular and time series Datasets. - :type predefined_split_column_name: str :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where @@ -883,23 +782,18 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. - :type timestamp_split_column_name: str :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :type tensorboard: str :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. - :type sync: bool :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :type gcp_conn_id: str :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have domain-wide delegation enabled. 
- :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -908,7 +802,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :type impersonation_chain: Union[str, Sequence[str]] """ template_fields = [ @@ -999,29 +892,20 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): """Create Custom Training job :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param display_name: Required. The user-defined name of this TrainingPipeline. - :type display_name: str :param script_path: Required. Local path to training script. - :type script_path: str :param container_uri: Required: Uri of the training container image in the GCR. - :type container_uri: str :param requirements: List of python packages dependencies of script. - :type requirements: Sequence[str] :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI of the Model serving container suitable for serving the model produced by the training script. - :type model_serving_container_image_uri: str :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will be used by Vertex AI. - :type model_serving_container_predict_route: str :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an HTTP path to send health check requests to the container, and which must be supported by it. If not specified a standard HTTP path will be used by AI Platform. - :type model_serving_container_health_route: str :param model_serving_container_command: The command with which the container is run. Not executed within a shell. The Docker image's ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's @@ -1029,27 +913,22 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_command: Sequence[str] :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container's environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. - :type model_serving_container_args: Sequence[str] :param model_serving_container_environment_variables: The environment variables that are to be present in the container. 
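A matching sketch for CreateCustomPythonPackageTrainingJobOperator (documented above), placed inside the DAG context of the previous example and assuming the class is imported from the same custom_job module; the packaged trainer URI and module name are illustrative.

    create_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator(
        task_id="python_package_training_job",
        project_id="my-project",
        region="us-central1",
        display_name="train-python-package",
        python_package_gcs_uri="gs://my-staging-bucket/trainer-0.1.tar.gz",
        python_module_name="trainer.task",
        container_uri="gcr.io/my-project/my-training-image:latest",
        model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest",
        staging_bucket="gs://my-staging-bucket",
        replica_count=1,
        machine_type="n1-standard-4",
    )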
Should be a dictionary where keys are environment variable names and values are environment variable values for those names. - :type model_serving_container_environment_variables: Dict[str, str] :param model_serving_container_ports: Declaration of ports that are exposed by the container. This field is primarily informational, it gives Vertex AI information about the network connections the container uses. Listing or not a port here has no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. - :type model_serving_container_ports: Sequence[int] :param model_description: The description of the Model. - :type model_description: str :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single instance, which are used in @@ -1064,7 +943,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_instance_schema_uri: str :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the parameters of prediction and explanation via @@ -1080,7 +958,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_parameters_schema_uri: str :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud Storage describing the format of a single prediction produced by this Model, which are returned via @@ -1095,11 +972,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - :type model_prediction_schema_uri: str :param project_id: Project to run training in. - :type project_id: str :param region: Location to run training in. - :type region: str :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -1109,7 +983,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type labels: Dict[str, str] :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the training pipeline. Has the form: @@ -1121,7 +994,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload`` is not set separately. - :type training_encryption_spec_key_name: Optional[str] :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the model. Has the form: @@ -1130,12 +1002,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): resource is created. If set, the trained Model will be secured by this key. - :type model_encryption_spec_key_name: Optional[str] :param staging_bucket: Bucket used to stage source and training artifacts. 
- :type staging_bucket: str :param dataset: Vertex AI to fit this training against. - :type dataset: Union[datasets.ImageDataset, datasets.TabularDataset, datasets.TextDataset, - datasets.VideoDataset,] :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 [Schema Object] @@ -1152,13 +1020,11 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): ``annotations_filter`` and ``annotation_schema_uri``. - :type annotation_schema_uri: str :param model_display_name: If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. If not provided upon creation, the job's display_name is used. - :type model_display_name: str :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -1168,7 +1034,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - :type model_labels: Dict[str, str] :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. @@ -1180,16 +1045,12 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): i.e. /checkpoints/ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/ - - :type base_output_dir: str :param service_account: Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. - :type service_account: str :param network: The full name of the Compute Engine network to which the job should be peered. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - :type network: str :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -1204,64 +1065,49 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" - :type bigquery_destination: str :param args: Command line arguments to be passed to the Python script. - :type args: List[Unions[str, int, float]] :param environment_variables: Environment variables to be passed to the container. Should be a dictionary where keys are environment variable names and values are environment variable values for those names. At most 10 environment variables can be specified. The Name of the environment variable must be unique. - :type environment_variables: Dict[str, str] :param replica_count: The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be provisioned as a worker replica pool. - :type replica_count: int :param machine_type: The type of machine to use for training. - :type machine_type: str :param accelerator_type: Hardware accelerator type. 
One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - :type accelerator_type: str :param accelerator_count: The number of accelerators to attach to a worker replica. - :type accelerator_count: int :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or `pd-standard` (Persistent Disk Hard Disk Drive). - :type boot_disk_type: str :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. - :type boot_disk_size_gb: int :param training_fraction_split: Optional. The fraction of the input data that is to be used to train the Model. This is ignored if Dataset is not provided. - :type training_fraction_split: float :param validation_fraction_split: Optional. The fraction of the input data that is to be used to validate the Model. This is ignored if Dataset is not provided. - :type validation_fraction_split: float :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate the Model. This is ignored if Dataset is not provided. - :type test_fraction_split: float :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to train the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type training_filter_split: str :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type validation_filter_split: str :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the FilterSplit filters, then it is assigned to the first set that applies to it in the training, validation, test order. This is ignored if Dataset is not provided. - :type test_filter_split: str :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or value in the column) must be one of {``training``, @@ -1271,7 +1117,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): ignored by the pipeline. Supported only for tabular and time series Datasets. - :type predefined_split_column_name: str :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns. The value of the key values of the key (the values in the column) must be in RFC 3339 `date-time` format, where @@ -1280,23 +1125,18 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. - :type timestamp_split_column_name: str :param tensorboard: Optional. 
The name of a Vertex AI resource to which this CustomJob will upload logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :type tensorboard: str :param sync: Whether to execute the AI Platform job synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. - :type sync: bool :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :type gcp_conn_id: str :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have domain-wide delegation enabled. - :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -1305,7 +1145,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :type impersonation_chain: Union[str, Sequence[str]] """ template_fields = [ @@ -1398,25 +1237,16 @@ class DeleteCustomTrainingJobOperator(BaseOperator): """Deletes a CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob. :param training_pipeline_id: Required. The name of the TrainingPipeline resource to be deleted. - :type training_pipeline_id: str :param custom_job_id: Required. The name of the CustomJob to delete. - :type custom_job_id: str :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :type gcp_conn_id: str :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have domain-wide delegation enabled. - :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -1425,7 +1255,6 @@ class DeleteCustomTrainingJobOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
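And for the script-based CreateCustomTrainingJobOperator documented above, a sketch in the same assumed DAG context; the script path, requirements, and image URIs are placeholders.

    create_custom_training_job = CreateCustomTrainingJobOperator(
        task_id="custom_training_job",
        project_id="my-project",
        region="us-central1",
        display_name="train-script",
        script_path="dags/scripts/train.py",  # local path to the training script
        container_uri="gcr.io/my-project/my-training-image:latest",
        requirements=["scikit-learn", "pandas"],
        model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest",
        staging_bucket="gs://my-staging-bucket",
        replica_count=1,
    )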
- :type impersonation_chain: Union[str, Sequence[str]] """ template_fields = ("region", "project_id", "impersonation_chain") @@ -1497,9 +1326,7 @@ class ListCustomTrainingJobOperator(BaseOperator): """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :type project_id: str :param region: Required. The ID of the Google Cloud region that the service belongs to. - :type region: str :param filter: Optional. The standard list filter. Supported fields: - ``display_name`` supports = and !=. @@ -1515,29 +1342,20 @@ class ListCustomTrainingJobOperator(BaseOperator): - ``NOT display_name="my_pipeline"`` - ``state="PIPELINE_STATE_FAILED"`` - :type filter: str :param page_size: Optional. The standard list page size. - :type page_size: int :param page_token: Optional. The standard list page token. Typically obtained via [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token] of the previous [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines] call. - :type page_token: str :param read_mask: Optional. Mask specifying which fields to read. - :type read_mask: google.protobuf.field_mask_pb2.FieldMask :param retry: Designation of what errors, if any, should be retried. - :type retry: google.api_core.retry.Retry :param timeout: The timeout for this request. - :type timeout: float :param metadata: Strings which should be sent along with the request as metadata. - :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :type gcp_conn_id: str :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have domain-wide delegation enabled. - :type delegate_to: str :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -1546,7 +1364,6 @@ class ListCustomTrainingJobOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :type impersonation_chain: Union[str, Sequence[str]] """ template_fields = [ diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index f9c917f72402c..1def925018aa8 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -80,24 +80,16 @@ class CreateDatasetOperator(BaseOperator): Creates a Dataset. :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. - :type project_id: str :param region: Required. The Cloud Dataproc region in which to handle the request. - :type region: str :param dataset: Required. The Dataset to create. This corresponds to the ``dataset`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - :type dataset: google.cloud.aiplatform_v1.types.Dataset :param retry: Designation of what errors, if any, should be retried. 
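For the delete and list operators documented above, continuing the same DAG sketch; the pipeline and job IDs are placeholders, and the filter string is one of the examples given in the docstring.

    delete_custom_training_job = DeleteCustomTrainingJobOperator(
        task_id="delete_custom_training_job",
        training_pipeline_id="1234567890",
        custom_job_id="1234567890",
        project_id="my-project",
        region="us-central1",
    )

    list_custom_training_jobs = ListCustomTrainingJobOperator(
        task_id="list_custom_training_jobs",
        project_id="my-project",
        region="us-central1",
        filter='state="PIPELINE_STATE_FAILED"',
        page_size=50,
    )

    create_custom_training_job >> delete_custom_training_job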
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
index f9c917f72402c..1def925018aa8 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
@@ -80,24 +80,16 @@ class CreateDatasetOperator(BaseOperator):
     Creates a Dataset.

     :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
-    :type project_id: str
     :param region: Required. The Cloud Dataproc region in which to handle the request.
-    :type region: str
     :param dataset: Required. The Dataset to create. This corresponds to the ``dataset`` field on the
         ``request`` instance; if ``request`` is provided, this should not be set.
-    :type dataset: google.cloud.aiplatform_v1.types.Dataset
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -106,7 +98,6 @@ class CreateDatasetOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "project_id", "impersonation_chain")
@@ -177,23 +168,15 @@ class GetDatasetOperator(BaseOperator):
     Get a Dataset.

     :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
-    :type project_id: str
     :param region: Required. The Cloud Dataproc region in which to handle the request.
-    :type region: str
     :param dataset_id: Required. The ID of the Dataset to get.
-    :type dataset_id: str
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -202,7 +185,6 @@ class GetDatasetOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
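A minimal DAG sketch for the two dataset operators above, assuming the import path follows the file touched by this hunk (airflow/providers/google/cloud/operators/vertex_ai/dataset.py). The Dataset body is passed as a dict that the client library coerces into google.cloud.aiplatform_v1.types.Dataset; the display name, schema URI, project, region, and IDs are placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
        CreateDatasetOperator,
        GetDatasetOperator,
    )

    DATASET = {
        "display_name": "example-dataset",  # placeholder
        "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
    }

    with DAG(
        dag_id="example_vertex_ai_dataset",
        start_date=datetime(2021, 1, 1),
        schedule_interval="@once",
    ) as dag:
        create_dataset = CreateDatasetOperator(
            task_id="create_dataset",
            project_id="my-project",
            region="us-central1",
            dataset=DATASET,
        )

        get_dataset = GetDatasetOperator(
            task_id="get_dataset",
            project_id="my-project",
            region="us-central1",
            dataset_id="1234567890",  # placeholder Dataset ID
        )

        create_dataset >> get_dataset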
@@ -273,23 +255,15 @@ class DeleteDatasetOperator(BaseOperator):
     Deletes a Dataset.

     :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
-    :type project_id: str
     :param region: Required. The Cloud Dataproc region in which to handle the request.
-    :type region: str
     :param dataset_id: Required. The ID of the Dataset to delete.
-    :type dataset_id: str
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -298,7 +272,6 @@ class DeleteDatasetOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
@@ -356,25 +329,16 @@ class ExportDataOperator(BaseOperator):
     Exports data from a Dataset.

     :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
-    :type project_id: str
     :param region: Required. The Cloud Dataproc region in which to handle the request.
-    :type region: str
     :param dataset_id: Required. The ID of the Dataset to delete.
-    :type dataset_id: str
     :param export_config: Required. The desired output location.
-    :type export_config: google.cloud.aiplatform_v1.types.ExportDataConfig
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -383,7 +347,6 @@ class ExportDataOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
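A sketch combining the two operators above: export the data, then delete the Dataset. export_config is given as a dict form of google.cloud.aiplatform_v1.types.ExportDataConfig with a Cloud Storage destination; the bucket name, project, region, and Dataset ID are placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
        DeleteDatasetOperator,
        ExportDataOperator,
    )

    with DAG(
        dag_id="example_vertex_ai_dataset_export",
        start_date=datetime(2021, 1, 1),
        schedule_interval="@once",
    ) as dag:
        export_data = ExportDataOperator(
            task_id="export_data",
            project_id="my-project",
            region="us-central1",
            dataset_id="1234567890",  # placeholder Dataset ID
            export_config={"gcs_destination": {"output_uri_prefix": "gs://my-bucket/exports/"}},
        )

        delete_dataset = DeleteDatasetOperator(
            task_id="delete_dataset",
            project_id="my-project",
            region="us-central1",
            dataset_id="1234567890",
        )

        export_data >> delete_dataset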
@@ -441,26 +404,17 @@ class ImportDataOperator(BaseOperator):
     Imports data into a Dataset.

     :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
-    :type project_id: str
     :param region: Required. The Cloud Dataproc region in which to handle the request.
-    :type region: str
     :param dataset_id: Required. The ID of the Dataset to delete.
-    :type dataset_id: str
     :param import_configs: Required. The desired input locations. The contents of all input locations
         will be imported in one batch.
-    :type import_configs: Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig]
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -469,7 +423,6 @@ class ImportDataOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
@@ -527,32 +480,20 @@ class ListDatasetsOperator(BaseOperator):
     Lists Datasets in a Location.

     :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-    :type project_id: str
     :param region: Required. The ID of the Google Cloud region that the service belongs to.
-    :type region: str
     :param filter: The standard list filter.
-    :type filter: str
     :param page_size: The standard list page size.
-    :type page_size: int
     :param page_token: The standard list page token.
-    :type page_token: str
     :param read_mask: Mask specifying which fields to read.
-    :type read_mask: google.protobuf.field_mask_pb2.FieldMask
     :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc"
         after a field name for descending.
-    :type order_by: str
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -561,7 +502,6 @@ class ListDatasetsOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "project_id", "impersonation_chain")
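A sketch for the two operators above. import_configs is a sequence of dicts coerced into google.cloud.aiplatform_v1.types.ImportDataConfig; the Cloud Storage URI, import schema URI, project, region, and Dataset ID are placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
        ImportDataOperator,
        ListDatasetsOperator,
    )

    IMPORT_CONFIGS = [
        {
            "gcs_source": {"uris": ["gs://my-bucket/import/data.jsonl"]},  # placeholder source
            "import_schema_uri": (
                "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
                "image_classification_single_label_io_format_1.0.0.yaml"
            ),
        }
    ]

    with DAG(
        dag_id="example_vertex_ai_dataset_import",
        start_date=datetime(2021, 1, 1),
        schedule_interval="@once",
    ) as dag:
        import_data = ImportDataOperator(
            task_id="import_data",
            project_id="my-project",
            region="us-central1",
            dataset_id="1234567890",  # placeholder Dataset ID
            import_configs=IMPORT_CONFIGS,
        )

        list_datasets = ListDatasetsOperator(
            task_id="list_datasets",
            project_id="my-project",
            region="us-central1",
            order_by="display_name desc",  # descending sort, see docstring above
            page_size=50,
        )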
@@ -631,27 +571,17 @@ class UpdateDatasetOperator(BaseOperator):
     Updates a Dataset.

     :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-    :type project_id: str
     :param region: Required. The ID of the Google Cloud region that the service belongs to.
-    :type region: str
     :param dataset_id: Required. The ID of the Dataset to update.
-    :type dataset_id: str
     :param dataset: Required. The Dataset which replaces the resource on the server.
-    :type dataset: google.cloud.aiplatform_v1.types.Dataset
     :param update_mask: Required. The update mask applies to the resource.
-    :type update_mask: google.protobuf.field_mask_pb2.FieldMask
     :param retry: Designation of what errors, if any, should be retried.
-    :type retry: google.api_core.retry.Retry
     :param timeout: The timeout for this request.
-    :type timeout: float
     :param metadata: Strings which should be sent along with the request as metadata.
-    :type metadata: Sequence[Tuple[str, str]]
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :type gcp_conn_id: str
     :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any.
         For this to work, the service account making the request must have domain-wide delegation enabled.
-    :type delegate_to: str
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -660,7 +590,6 @@ class UpdateDatasetOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :type impersonation_chain: Union[str, Sequence[str]]
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
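A sketch for UpdateDatasetOperator, assuming dict values are accepted for both the Dataset body and the google.protobuf.field_mask_pb2.FieldMask; only the fields listed in the mask's paths are updated. Project, region, Dataset ID, and display name are placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.vertex_ai.dataset import UpdateDatasetOperator

    with DAG(
        dag_id="example_vertex_ai_dataset_update",
        start_date=datetime(2021, 1, 1),
        schedule_interval="@once",
    ) as dag:
        update_dataset = UpdateDatasetOperator(
            task_id="update_dataset",
            project_id="my-project",
            region="us-central1",
            dataset_id="1234567890",  # placeholder Dataset ID
            dataset={"display_name": "example-dataset-renamed"},
            update_mask={"paths": ["display_name"]},  # only display_name is updated
        )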