From 12824a9c765fe9b3fb8609630b3d3a0c65ead32d Mon Sep 17 00:00:00 2001 From: romsharon98 Date: Sun, 21 Jan 2024 18:51:10 +0200 Subject: [PATCH] fix templating field to super constructor --- .pre-commit-config.yaml | 1 - .../cloud/operators/vertex_ai/auto_ml.py | 28 ++++++++++++++++--- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf1f8218cfabd..bdff078f65823 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -310,7 +310,6 @@ repos: ^airflow\/providers\/databricks\/operators\/databricks\.py$| ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service\.py$| ^airflow\/providers\/google\/cloud\/transfers\/bigquery_to_mysql\.py$| - ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$| ^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$| ^airflow\/providers\/google\/cloud\/operators\/compute\.py$| ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/custom_job\.py$| diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index 465677a81b3b1..c815fa49bbdb4 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -133,9 +133,14 @@ def __init__( quantiles: list[float] | None = None, validation_options: str | None = None, budget_milli_node_hours: int = 1000, + region: str, + impersonation_chain: str | Sequence[str] | None = None, + parent_model: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__( + region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs + ) self.dataset_id = dataset_id self.target_column = target_column self.time_column = time_column @@ -252,9 +257,14 @@ def __init__( test_filter_split: str | None = None, budget_milli_node_hours: int | None = None, disable_early_stopping: bool = False, + region: str, + impersonation_chain: str | Sequence[str] | None = None, + parent_model: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__( + region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs + ) self.dataset_id = dataset_id self.prediction_type = prediction_type self.multi_label = multi_label @@ -345,9 +355,14 @@ def __init__( export_evaluated_data_items: bool = False, export_evaluated_data_items_bigquery_destination_uri: str | None = None, export_evaluated_data_items_override_destination: bool = False, + region: str, + impersonation_chain: str | Sequence[str] | None = None, + parent_model: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__( + region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs + ) self.dataset_id = dataset_id self.target_column = target_column self.optimization_prediction_type = optimization_prediction_type @@ -529,9 +544,14 @@ def __init__( model_type: str = "CLOUD", training_filter_split: str | None = None, test_filter_split: str | None = None, + region: str, + impersonation_chain: str | Sequence[str] | None = None, + parent_model: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__( + region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs + ) self.dataset_id = dataset_id self.prediction_type = prediction_type self.model_type = model_type