From 06d0d6ab52866c9c944d6c424330867625cfe833 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sun, 18 Jun 2023 12:07:52 +0200 Subject: [PATCH] Provide missing project id and creds for TabularDataset Signed-off-by: Hussein Awala --- .../google/cloud/operators/vertex_ai/auto_ml.py | 7 ++++++- .../google/cloud/operators/test_vertex_ai.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index 80e573a1d663c..a3bcb2158d340 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -352,11 +352,16 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + credentials, _ = self.hook.get_credentials_and_project_id() model, training_id = self.hook.create_auto_ml_tabular_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, - dataset=datasets.TabularDataset(dataset_name=self.dataset_id), + dataset=datasets.TabularDataset( + dataset_name=self.dataset_id, + project=self.project_id, + credentials=credentials, + ), target_column=self.target_column, optimization_prediction_type=self.optimization_prediction_type, optimization_objective=self.optimization_objective, diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index f4b3ad154dda1..f2921942ab41f 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -17,6 +17,7 @@ from __future__ import annotations from unittest import mock +from unittest.mock import MagicMock from google.api_core.gapic_v1.method import DEFAULT from google.api_core.retry import Retry @@ -783,7 +784,12 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator: @mock.patch("google.cloud.aiplatform.datasets.TabularDataset") @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook")) def test_execute(self, mock_hook, mock_dataset): - mock_hook.return_value.create_auto_ml_tabular_training_job.return_value = (None, "training_id") + mock_hook.return_value = MagicMock( + **{ + "create_auto_ml_tabular_training_job.return_value": (None, "training_id"), + "get_credentials_and_project_id.return_value": ("creds", "project_id"), + } + ) op = CreateAutoMLTabularTrainingJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, @@ -798,7 +804,9 @@ def test_execute(self, mock_hook, mock_dataset): ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) - mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID) + mock_dataset.assert_called_once_with( + dataset_name=TEST_DATASET_ID, project=GCP_PROJECT, credentials="creds" + ) mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with( project_id=GCP_PROJECT, region=GCP_LOCATION,