diff --git a/providers/google/src/airflow/providers/google/ads/operators/ads.py b/providers/google/src/airflow/providers/google/ads/operators/ads.py index b8ef110c19d17..f6aaf1df3fda2 100644 --- a/providers/google/src/airflow/providers/google/ads/operators/ads.py +++ b/providers/google/src/airflow/providers/google/ads/operators/ads.py @@ -24,9 +24,9 @@ from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.ads.hooks.ads import GoogleAdsHook from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py b/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py index 591eaf3b4f0da..baa99c5b4aa5d 100644 --- a/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py +++ b/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py @@ -22,9 +22,9 @@ from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.ads.hooks.ads import GoogleAdsHook from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py b/providers/google/src/airflow/providers/google/cloud/links/base.py index b24b780c61318..d543b495ab10c 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/base.py +++ b/providers/google/src/airflow/providers/google/cloud/links/base.py @@ -19,22 +19,23 @@ from typing import TYPE_CHECKING, ClassVar -from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS - -if TYPE_CHECKING: - from airflow.models import BaseOperator - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator - from airflow.sdk import BaseSensorOperator - from airflow.utils.context import Context +from airflow.providers.google.version_compat import ( + AIRFLOW_V_3_0_PLUS, + BaseOperator, + BaseOperatorLink, + BaseSensorOperator, +) if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperatorLink from airflow.sdk.execution_time.xcom import XCom else: - from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] from airflow.models.xcom import XCom # type: ignore[no-redef] +if TYPE_CHECKING: + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + from airflow.utils.context import Context + BASE_LINK = "https://console.cloud.google.com" diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py index 832c66bf132e0..67be746ff4d4a 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py @@ -26,19 +26,20 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink -from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.google.version_compat import ( + AIRFLOW_V_3_0_PLUS, + BaseOperator, + BaseOperatorLink, +) if TYPE_CHECKING: - from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperatorLink from airflow.sdk.execution_time.xcom import XCom else: - from airflow.models import XCom # type: ignore[no-redef] - from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] + from airflow.models.xcom import XCom # type: ignore[no-redef] def __getattr__(name: str) -> Any: @@ -94,16 +95,16 @@ class DataprocLink(BaseOperatorLink): @staticmethod def persist( context: Context, - task_instance, url: str, resource: str, + region: str, + project_id: str, ): - task_instance.xcom_push( - context=context, + context["task_instance"].xcom_push( key=DataprocLink.key, value={ - "region": task_instance.region, - "project_id": task_instance.project_id, + "region": region, + "project_id": project_id, "url": url, "resource": resource, }, @@ -147,14 +148,13 @@ class DataprocListLink(BaseOperatorLink): @staticmethod def persist( context: Context, - task_instance, url: str, + project_id: str, ): - task_instance.xcom_push( - context=context, + context["task_instance"].xcom_push( key=DataprocListLink.key, value={ - "project_id": task_instance.project_id, + "project_id": project_id, "url": url, }, ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/automl.py b/providers/google/src/airflow/providers/google/cloud/operators/automl.py index 648cee403be6a..7d2cd4ee49fa1 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/automl.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/automl.py @@ -163,7 +163,7 @@ def execute(self, context: Context): model_id = hook.extract_object_id(result) self.log.info("Model is created, model_id: %s", model_id) - self.xcom_push(context, key="model_id", value=model_id) + context["task_instance"].xcom_push(key="model_id", value=model_id) if project_id: TranslationLegacyModelLink.persist( context=context, @@ -415,7 +415,7 @@ def execute(self, context: Context): dataset_id = hook.extract_object_id(result) self.log.info("Creating completed. Dataset id: %s", dataset_id) - self.xcom_push(context, key="dataset_id", value=dataset_id) + context["task_instance"].xcom_push(key="dataset_id", value=dataset_id) project_id = self.project_id or hook.project_id if project_id: TranslationLegacyDatasetLink.persist( @@ -1248,8 +1248,7 @@ def execute(self, context: Context): result.append(Dataset.to_dict(dataset)) self.log.info("Datasets obtained.") - self.xcom_push( - context, + context["task_instance"].xcom_push( key="dataset_id_list", value=[hook.extract_object_id(d) for d in result], ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py index 387a2ab55f32c..b06dea7bc0283 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py @@ -141,7 +141,7 @@ def execute(self, context: Context): result = TransferConfig.to_dict(response) self.log.info("Created DTS transfer config %s", get_object_id(result)) - self.xcom_push(context, key="transfer_config_id", value=get_object_id(result)) + context["ti"].xcom_push(key="transfer_config_id", value=get_object_id(result)) # don't push AWS secret in XCOM result.get("params", {}).pop("secret_access_key", None) result.get("params", {}).pop("access_key_id", None) @@ -335,7 +335,7 @@ def execute(self, context: Context): result = StartManualTransferRunsResponse.to_dict(response) run_id = get_object_id(result["runs"][0]) - self.xcom_push(context, key="run_id", value=run_id) + context["ti"].xcom_push(key="run_id", value=run_id) if not self.deferrable: # Save as attribute for further use by OpenLineage diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py index fb11a7276fb01..4e7e2cf6ea9d4 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py @@ -23,7 +23,7 @@ from google.api_core.gapic_v1.method import DEFAULT -from airflow.models import BaseOperator +from airflow.providers.google.version_compat import BaseOperator class GoogleCloudBaseOperator(BaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py index 78d4bc2d6065c..26f1444df36c2 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py @@ -125,7 +125,7 @@ def execute(self, context: Context): location=self.location, ) - self.xcom_push(context, key="id", value=result.id) + context["task_instance"].xcom_push(key="id", value=result.id) project_id = self.project_id or hook.project_id if project_id: CloudBuildLink.persist( @@ -235,7 +235,7 @@ def execute(self, context: Context): metadata=self.metadata, location=self.location, ) - self.xcom_push(context, key="id", value=self.id_) + context["task_instance"].xcom_push(key="id", value=self.id_) if not self.wait: return Build.to_dict( hook.get_build(id_=self.id_, project_id=self.project_id, location=self.location) @@ -358,7 +358,7 @@ def execute(self, context: Context): metadata=self.metadata, location=self.location, ) - self.xcom_push(context, key="id", value=result.id) + context["task_instance"].xcom_push(key="id", value=result.id) project_id = self.project_id or hook.project_id if project_id: CloudBuildTriggerDetailsLink.persist( @@ -854,7 +854,7 @@ def execute(self, context: Context): location=self.location, ) - self.xcom_push(context, key="id", value=result.id) + context["task_instance"].xcom_push(key="id", value=result.id) project_id = self.project_id or hook.project_id if project_id: CloudBuildLink.persist( @@ -944,7 +944,7 @@ def execute(self, context: Context): metadata=self.metadata, location=self.location, ) - self.xcom_push(context, key="id", value=result.id) + context["task_instance"].xcom_push(key="id", value=result.id) project_id = self.project_id or hook.project_id if project_id: CloudBuildLink.persist( @@ -1030,7 +1030,7 @@ def execute(self, context: Context): metadata=self.metadata, location=self.location, ) - self.xcom_push(context, key="id", value=result.id) + context["task_instance"].xcom_push(key="id", value=result.id) project_id = self.project_id or hook.project_id if project_id: CloudBuildTriggerDetailsLink.persist( diff --git a/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py b/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py index 48b444f8b13b7..ba6692cb1e802 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py @@ -163,7 +163,7 @@ def execute(self, context: Context): ) _, _, entry_id = result.name.rpartition("/") self.log.info("Current entry_id ID: %s", entry_id) - self.xcom_push(context, key="entry_id", value=entry_id) + context["ti"].xcom_push(key="entry_id", value=entry_id) DataCatalogEntryLink.persist( context=context, entry_id=self.entry_id, @@ -283,7 +283,7 @@ def execute(self, context: Context): _, _, entry_group_id = result.name.rpartition("/") self.log.info("Current entry group ID: %s", entry_group_id) - self.xcom_push(context, key="entry_group_id", value=entry_group_id) + context["ti"].xcom_push(key="entry_group_id", value=entry_group_id) DataCatalogEntryGroupLink.persist( context=context, entry_group_id=self.entry_group_id, @@ -425,7 +425,7 @@ def execute(self, context: Context): _, _, tag_id = tag.name.rpartition("/") self.log.info("Current Tag ID: %s", tag_id) - self.xcom_push(context, key="tag_id", value=tag_id) + context["ti"].xcom_push(key="tag_id", value=tag_id) DataCatalogEntryLink.persist( context=context, entry_id=self.entry, @@ -542,7 +542,7 @@ def execute(self, context: Context): ) _, _, tag_template = result.name.rpartition("/") self.log.info("Current Tag ID: %s", tag_template) - self.xcom_push(context, key="tag_template_id", value=tag_template) + context["ti"].xcom_push(key="tag_template_id", value=tag_template) DataCatalogTagTemplateLink.persist( context=context, tag_template_id=self.tag_template_id, @@ -668,7 +668,7 @@ def execute(self, context: Context): result = tag_template.fields[self.tag_template_field_id] self.log.info("Current Tag ID: %s", self.tag_template_field_id) - self.xcom_push(context, key="tag_template_field_id", value=self.tag_template_field_id) + context["ti"].xcom_push(key="tag_template_field_id", value=self.tag_template_field_id) DataCatalogTagTemplateLink.persist( context=context, tag_template_id=self.tag_template, diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py b/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py index c61a19b6a3495..88542a8f1e401 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py @@ -409,7 +409,7 @@ def set_current_job(current_job): append_job_name=self.append_job_name, ) job_id = self.hook.extract_job_id(self.job) - self.xcom_push(context, key="job_id", value=job_id) + context["task_instance"].xcom_push(key="job_id", value=job_id) return job_id self.job = self.hook.launch_job_with_template( @@ -446,7 +446,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> str: raise AirflowException(event["message"]) job_id = event["job_id"] - self.xcom_push(context, key="job_id", value=job_id) + context["task_instance"].xcom_push(key="job_id", value=job_id) self.log.info("Task %s completed with response %s", self.task_id, event["message"]) return job_id @@ -609,7 +609,7 @@ def set_current_job(current_job): on_new_job_callback=set_current_job, ) job_id = self.hook.extract_job_id(self.job) - self.xcom_push(context, key="job_id", value=job_id) + context["task_instance"].xcom_push(key="job_id", value=job_id) return self.job self.job = self.hook.launch_job_with_flex_template( @@ -650,7 +650,7 @@ def execute_complete(self, context: Context, event: dict) -> dict[str, str]: job_id = event["job_id"] self.log.info("Task %s completed with response %s", job_id, event["message"]) - self.xcom_push(context, key="job_id", value=job_id) + context["task_instance"].xcom_push(key="job_id", value=job_id) job = self.hook.get_job(job_id=job_id, project_id=self.project_id, location=self.location) return job @@ -807,7 +807,7 @@ def execute_complete(self, context: Context, event: dict) -> dict[str, Any]: raise AirflowException(event["message"]) job = event["job"] self.log.info("Job %s completed with response %s", job["id"], event["message"]) - self.xcom_push(context, key="job_id", value=job["id"]) + context["task_instance"].xcom_push(key="job_id", value=job["id"]) return job @@ -1025,7 +1025,7 @@ def execute(self, context: Context): location=self.location, ) DataflowPipelineLink.persist(context=context) - self.xcom_push(context, key="pipeline_name", value=self.pipeline_name) + context["task_instance"].xcom_push(key="pipeline_name", value=self.pipeline_name) if self.pipeline: if "error" in self.pipeline: raise AirflowException(self.pipeline.get("error").get("message")) @@ -1096,7 +1096,7 @@ def execute(self, context: Context): location=self.location, )["job"] job_id = self.dataflow_hook.extract_job_id(self.job) - self.xcom_push(context, key="job_id", value=job_id) + context["task_instance"].xcom_push(key="job_id", value=job_id) DataflowJobLink.persist( context=context, project_id=self.project_id, region=self.location, job_id=job_id ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py b/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py index ad6a7217e2cc9..45361b364ab75 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py @@ -2533,8 +2533,7 @@ def execute(self, context: Context): metadata=self.metadata, ) self.log.info("EntryGroup on page: %s", entry_group_on_page) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="entry_group_page", value=ListEntryGroupsResponse.to_dict(entry_group_on_page._response), ) @@ -2954,8 +2953,7 @@ def execute(self, context: Context): metadata=self.metadata, ) self.log.info("EntryType on page: %s", entry_type_on_page) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="entry_type_page", value=ListEntryTypesResponse.to_dict(entry_type_on_page._response), ) @@ -3308,8 +3306,7 @@ def execute(self, context: Context): metadata=self.metadata, ) self.log.info("AspectType on page: %s", aspect_type_on_page) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="aspect_type_page", value=ListAspectTypesResponse.to_dict(aspect_type_on_page._response), ) @@ -3803,8 +3800,7 @@ def execute(self, context: Context): metadata=self.metadata, ) self.log.info("Entries on page: %s", entries_on_page) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="entry_page", value=ListEntriesResponse.to_dict(entries_on_page._response), ) @@ -3901,8 +3897,7 @@ def execute(self, context: Context): metadata=self.metadata, ) self.log.info("Entries on page: %s", entries_on_page) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="entry_page", value=SearchEntriesResponse.to_dict(entries_on_page._response), ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py index 3246c5bb6c356..5ed6923821ac8 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py @@ -1353,7 +1353,11 @@ def execute(self, context: Context): self.log.info("Job %s submitted successfully.", job_id) # Save data required for extra links no matter what the job status will be DataprocLink.persist( - context=context, task_instance=self, url=DATAPROC_JOB_LINK_DEPRECATED, resource=job_id + context=context, + url=DATAPROC_JOB_LINK_DEPRECATED, + resource=job_id, + region=self.region, + project_id=self.project_id, ) if self.deferrable: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py index 01340af99c612..25743a13a701f 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -39,8 +39,8 @@ if TYPE_CHECKING: from google.protobuf.field_mask_pb2 import FieldMask - from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.google.version_compat import BaseOperator from airflow.utils.context import Context BASE_LINK = "https://console.cloud.google.com" diff --git a/providers/google/src/airflow/providers/google/cloud/operators/functions.py b/providers/google/src/airflow/providers/google/cloud/operators/functions.py index 77dee5089f2c9..f31f93c27dbc4 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/functions.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/functions.py @@ -488,7 +488,7 @@ def execute(self, context: Context): project_id=self.project_id, ) self.log.info("Function called successfully. Execution id %s", result.get("executionId")) - self.xcom_push(context=context, key="execution_id", value=result.get("executionId")) + context["ti"].xcom_push(key="execution_id", value=result.get("executionId")) project_id = self.project_id or hook.project_id if project_id: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py index 6c42eb8be1d01..c4f7d09e5789c 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py @@ -256,8 +256,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="cluster_page", value=types.ListClustersResponse.to_dict(cluster_list_pager._response), ) @@ -622,8 +621,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="topic_page", value=types.ListTopicsResponse.to_dict(topic_list_pager._response), ) @@ -897,8 +895,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="consumer_group_page", value=types.ListConsumerGroupsResponse.to_dict(consumer_group_list_pager._response), ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/translate.py b/providers/google/src/airflow/providers/google/cloud/operators/translate.py index dd30b4536ae5e..ba6b0c338abde 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/translate.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/translate.py @@ -479,7 +479,7 @@ def execute(self, context: Context) -> str: result = hook.wait_for_operation_result(result_operation) result = type(result).to_dict(result) dataset_id = hook.extract_object_id(result) - self.xcom_push(context, key="dataset_id", value=dataset_id) + context["ti"].xcom_push(key="dataset_id", value=dataset_id) self.log.info("Dataset creation complete. The dataset_id: %s.", dataset_id) project_id = self.project_id or hook.project_id @@ -819,7 +819,7 @@ def execute(self, context: Context) -> str: result = hook.wait_for_operation_result(operation=result_operation) result = type(result).to_dict(result) model_id = hook.extract_object_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) self.log.info("Model creation complete. The model_id: %s.", model_id) project_id = self.project_id or hook.project_id @@ -1406,7 +1406,7 @@ def execute(self, context: Context) -> str: result = type(result).to_dict(result) glossary_id = hook.extract_object_id(result) - self.xcom_push(context, key="glossary_id", value=glossary_id) + context["ti"].xcom_push(key="glossary_id", value=glossary_id) self.log.info("Glossary creation complete. The glossary_id: %s.", glossary_id) return result diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index e1a1f50575ae6..4251b4becaa5a 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -249,11 +249,11 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) + context["ti"].xcom_push(key="training_id", value=training_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result @@ -341,11 +341,11 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) + context["ti"].xcom_push(key="training_id", value=training_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result @@ -464,11 +464,11 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) + context["ti"].xcom_push(key="training_id", value=training_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result @@ -538,11 +538,11 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) + context["ti"].xcom_push(key="training_id", value=training_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py index 156aefdc6fa6b..4aad729a9dff4 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py @@ -269,7 +269,7 @@ def execute(self, context: Context): batch_prediction_job_id = batch_prediction_job.name self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id) - self.xcom_push(context, key="batch_prediction_job_id", value=batch_prediction_job_id) + context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id) VertexAIBatchPredictionJobLink.persist( context=context, batch_prediction_job_id=batch_prediction_job_id, @@ -303,13 +303,11 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, job: dict[str, Any] = event["job"] self.log.info("Batch prediction job %s created and completed successfully.", job["name"]) job_id = self.hook.extract_batch_prediction_job_id(job) - self.xcom_push( - context, + context["ti"].xcom_push( key="batch_prediction_job_id", value=job_id, ) - self.xcom_push( - context, + context["ti"].xcom_push( key="training_conf", value={ "training_conf_id": job_id, diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 9543f27c71924..24a7efaeadc03 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -182,11 +182,11 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, raise AirflowException(event["message"]) training_pipeline = event["job"] custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline) - self.xcom_push(context, key="custom_job_id", value=custom_job_id) + context["ti"].xcom_push(key="custom_job_id", value=custom_job_id) try: model = training_pipeline["model_to_upload"] model_id = self.hook.extract_model_id(model) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) return model except KeyError: @@ -591,12 +591,12 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) - self.xcom_push(context, key="custom_job_id", value=custom_job_id) + context["ti"].xcom_push(key="training_id", value=training_id) + context["ti"].xcom_push(key="custom_job_id", value=custom_job_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result @@ -655,7 +655,7 @@ def invoke_defer(self, context: Context) -> None: ) custom_container_training_job_obj.wait_for_resource_creation() training_pipeline_id: str = custom_container_training_job_obj.name - self.xcom_push(context, key="training_id", value=training_pipeline_id) + context["ti"].xcom_push(key="training_id", value=training_pipeline_id) VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id) self.defer( trigger=CustomContainerTrainingJobTrigger( @@ -1048,12 +1048,12 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) - self.xcom_push(context, key="custom_job_id", value=custom_job_id) + context["ti"].xcom_push(key="training_id", value=training_id) + context["ti"].xcom_push(key="custom_job_id", value=custom_job_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result @@ -1113,7 +1113,7 @@ def invoke_defer(self, context: Context) -> None: ) custom_python_training_job_obj.wait_for_resource_creation() training_pipeline_id: str = custom_python_training_job_obj.name - self.xcom_push(context, key="training_id", value=training_pipeline_id) + context["ti"].xcom_push(key="training_id", value=training_pipeline_id) VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id) self.defer( trigger=CustomPythonPackageTrainingJobTrigger( @@ -1511,12 +1511,12 @@ def execute(self, context: Context): if model: result = Model.to_dict(model) model_id = self.hook.extract_model_id(result) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) else: result = model # type: ignore - self.xcom_push(context, key="training_id", value=training_id) - self.xcom_push(context, key="custom_job_id", value=custom_job_id) + context["ti"].xcom_push(key="training_id", value=training_id) + context["ti"].xcom_push(key="custom_job_id", value=custom_job_id) VertexAITrainingLink.persist(context=context, training_id=training_id) return result @@ -1576,7 +1576,7 @@ def invoke_defer(self, context: Context) -> None: ) custom_training_job_obj.wait_for_resource_creation() training_pipeline_id: str = custom_training_job_obj.name - self.xcom_push(context, key="training_id", value=training_pipeline_id) + context["ti"].xcom_push(key="training_id", value=training_pipeline_id) VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id) self.defer( trigger=CustomTrainingJobTrigger( diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index 594b8e4fcdab8..a2c0c79eb0fd3 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -113,7 +113,7 @@ def execute(self, context: Context): dataset_id = hook.extract_dataset_id(dataset) self.log.info("Dataset was created. Dataset id: %s", dataset_id) - self.xcom_push(context, key="dataset_id", value=dataset_id) + context["ti"].xcom_push(key="dataset_id", value=dataset_id) VertexAIDatasetLink.persist(context=context, dataset_id=dataset_id) return dataset diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py index 32a64e17bcabf..9871cdccdc49a 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py @@ -122,7 +122,7 @@ def execute(self, context: Context): endpoint_id = hook.extract_endpoint_id(endpoint) self.log.info("Endpoint was created. Endpoint ID: %s", endpoint_id) - self.xcom_push(context, key="endpoint_id", value=endpoint_id) + context["ti"].xcom_push(key="endpoint_id", value=endpoint_id) VertexAIEndpointLink.persist(context=context, endpoint_id=endpoint_id) return endpoint @@ -292,7 +292,7 @@ def execute(self, context: Context): deployed_model_id = hook.extract_deployed_model_id(deploy_model) self.log.info("Model was deployed. Deployed Model ID: %s", deployed_model_id) - self.xcom_push(context, key="deployed_model_id", value=deployed_model_id) + context["ti"].xcom_push(key="deployed_model_id", value=deployed_model_id) VertexAIModelLink.persist(context=context, model_id=deployed_model_id) return deploy_model diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 20257dad196c8..2c47b2ee5aff8 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -93,7 +93,7 @@ def execute(self, context: Context): ) self.log.info("Model response: %s", response) - self.xcom_push(context, key="model_response", value=response) + context["ti"].xcom_push(key="model_response", value=response) return response @@ -172,7 +172,7 @@ def execute(self, context: Context): ) self.log.info("Model response: %s", response) - self.xcom_push(context, key="model_response", value=response) + context["ti"].xcom_push(key="model_response", value=response) return response @@ -261,8 +261,8 @@ def execute(self, context: Context): self.log.info("Tuned Model Name: %s", response.tuned_model_name) self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name) - self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name) - self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name) + context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name) + context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name) result = { "tuned_model_name": response.tuned_model_name, @@ -332,8 +332,8 @@ def execute(self, context: Context): self.log.info("Total tokens: %s", response.total_tokens) self.log.info("Total billable characters: %s", response.total_billable_characters) - self.xcom_push(context, key="total_tokens", value=response.total_tokens) - self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters) + context["ti"].xcom_push(key="total_tokens", value=response.total_tokens) + context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters) class RunEvaluationOperator(GoogleCloudBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py index 86d278c9ab1ba..a667778965bc7 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py @@ -257,7 +257,7 @@ def execute(self, context: Context): hyperparameter_tuning_job_id = hyperparameter_tuning_job.name self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id) - self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id) + context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id) VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id) if self.deferrable: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py index d5f9c26e5a078..2b6459d434986 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py @@ -186,7 +186,7 @@ def execute(self, context: Context): ) self.log.info("Model found. Model ID: %s", self.model_id) - self.xcom_push(context, key="model_id", value=self.model_id) + context["ti"].xcom_push(key="model_id", value=self.model_id) VertexAIModelLink.persist(context=context, model_id=self.model_id) return Model.to_dict(model) except NotFound: @@ -453,7 +453,7 @@ def execute(self, context: Context): model_id = hook.extract_model_id(model_resp) self.log.info("Model was uploaded. Model ID: %s", model_id) - self.xcom_push(context, key="model_id", value=model_id) + context["ti"].xcom_push(key="model_id", value=model_id) VertexAIModelLink.persist(context=context, model_id=model_id) return model_resp diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py index d12adaaf27222..875d2eff00e50 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py @@ -195,7 +195,7 @@ def execute(self, context: Context): ) pipeline_job_id = pipeline_job_obj.job_id self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id) - self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id) + context["ti"].xcom_push(key="pipeline_job_id", value=pipeline_job_id) VertexAIPipelineJobLink.persist(context=context, pipeline_id=pipeline_job_id) if self.deferrable: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py index 6368e81eb0170..0cd222ec6d2fa 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py @@ -188,8 +188,7 @@ def execute(self, context: Context): labels=self.labels, ) cluster_id = self.hook.extract_cluster_id(cluster_path) - self.xcom_push( - context=context, + context["ti"].xcom_push( key="cluster_id", value=cluster_id, ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/workflows.py b/providers/google/src/airflow/providers/google/cloud/operators/workflows.py index 81f97d18a6d97..51da294c664ac 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/workflows.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/workflows.py @@ -501,7 +501,7 @@ def execute(self, context: Context): metadata=self.metadata, ) execution_id = execution.name.split("/")[-1] - self.xcom_push(context, key="execution_id", value=execution_id) + context["task_instance"].xcom_push(key="execution_id", value=execution_id) WorkflowsExecutionLink.persist( context=context, diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py index 2b82d5cfff122..f08757f0edd0a 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py @@ -31,12 +31,7 @@ BigQueryTableExistenceTrigger, BigQueryTablePartitionExistenceTrigger, ) -from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseSensorOperator -else: - from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] +from airflow.providers.google.version_compat import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py index 38d131851e2c9..bc71934e9dc87 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py @@ -37,12 +37,7 @@ DataflowJobStatusTrigger, ) from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID -from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseSensorOperator -else: - from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] +from airflow.providers.google.version_compat import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py index 303f0c00761a1..edbfd61b1a2fd 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py @@ -21,8 +21,8 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator try: from airflow.providers.microsoft.azure.hooks.wasb import WasbHook diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py index 9dd3ca6015f47..4b890b6c59d59 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py @@ -23,8 +23,8 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory +from airflow.providers.google.version_compat import BaseOperator try: from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index 603fe3929dae3..d34c71a7dc0fb 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -22,10 +22,10 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 1531807884673..719d565a9fc32 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -27,11 +27,11 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID +from airflow.providers.google.version_compat import BaseOperator from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py index 081bb8df9d63e..dc3ad68fb81ce 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py @@ -23,9 +23,9 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook from airflow.providers.google.cloud.utils.bigquery_get_data import bigquery_get_data +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.providers.common.sql.hooks.sql import DbApiHook diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py index f9d7b977db0c1..bc6e9628b5462 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py @@ -21,9 +21,9 @@ from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.suite.hooks.calendar import GoogleCalendarHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from datetime import datetime diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index f283649faac0d..7c40c8fae5d47 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -31,9 +31,9 @@ from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py index ea43667076651..8814df2c02da3 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py @@ -26,9 +26,9 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.facebook.ads.hooks.ads import FacebookAdsReportingHook from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from facebook_business.adobjects.adsinsights import AdsInsights diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 08a119a4d3a53..f4c47dec9420c 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -38,12 +38,12 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID +from airflow.providers.google.version_compat import BaseOperator from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index 54a8269709a9b..296d216b88135 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -24,8 +24,8 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator WILDCARD = "*" diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py index 70cdf0cdb9baf..b407829cc52a5 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py @@ -20,9 +20,9 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.models.xcom import MAX_XCOM_SIZE from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py index f529cef36133e..7aebbe1b6850c 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py @@ -26,8 +26,8 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator from airflow.providers.sftp.hooks.sftp import SFTPHook WILDCARD = "*" diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py index dc57fd712c568..827f7455910fe 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py @@ -19,9 +19,9 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -99,3 +99,7 @@ def execute(self, context: Context): bucket_name=self.bucket_name, object_name=self.object_name ) as file: gdrive_hook.download_file(file_id=file_metadata["id"], file_handle=file) + + def dry_run(self): + """Perform a dry run of the operator.""" + return None diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py index 22dcc9f67ccb8..12d903ea52f08 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py @@ -19,8 +19,8 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py index 71b2b7020af4c..b06ac8d22e561 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py @@ -22,8 +22,8 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator from airflow.providers.http.hooks.http import HttpHook if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py index b1a143b242f30..b6c183e600597 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py @@ -24,8 +24,8 @@ from glob import glob from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py index b83f09be70b24..5ffe7f59f32cc 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py @@ -21,8 +21,8 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator from airflow.providers.salesforce.hooks.salesforce import SalesforceHook if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py index 1653b84ec2b82..9e53d16f94395 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py @@ -26,8 +26,8 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator from airflow.providers.sftp.hooks.sftp import SFTPHook if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py index 8c681bc5c9db6..6f11b38531f18 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py @@ -21,9 +21,9 @@ from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.suite.hooks.sheets import GSheetsHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -130,5 +130,5 @@ def execute(self, context: Context): gcs_path_to_file = self._upload_data(gcs_hook, sheet_hook, sheet_range, data) destination_array.append(gcs_path_to_file) - self.xcom_push(context, "destination_objects", destination_array) + context["ti"].xcom_push(key="destination_objects", value=destination_array) return destination_array diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py index 6953c70876733..e4ec2d730ef81 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py @@ -30,8 +30,8 @@ import pyarrow as pa import pyarrow.parquet as pq -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.providers.common.compat.openlineage.facet import OutputDataset diff --git a/providers/google/src/airflow/providers/google/firebase/operators/firestore.py b/providers/google/src/airflow/providers/google/firebase/operators/firestore.py index 055bcadb92a41..08ad61c7616eb 100644 --- a/providers/google/src/airflow/providers/google/firebase/operators/firestore.py +++ b/providers/google/src/airflow/providers/google/firebase/operators/firestore.py @@ -20,9 +20,9 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py b/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py index 2d544e89b45bc..77320d5e6bd7a 100644 --- a/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py +++ b/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py @@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.google.leveldb.hooks.leveldb import LevelDBHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py index 9c055889bd751..a783ddf008181 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py @@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, ClassVar +from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator + if TYPE_CHECKING: - from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context -from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS - if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink from airflow.sdk.execution_time.xcom import XCom @@ -64,11 +63,9 @@ class GoogleAnalyticsPropertyLink(GoogleAnalyticsBaseLink): @staticmethod def persist( context: Context, - task_instance: BaseOperator, property_id: str, ): - task_instance.xcom_push( - context, + context["task_instance"].xcom_push( key=GoogleAnalyticsPropertyLink.key, value={"property_id": property_id}, ) diff --git a/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py b/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py index 8f662f9c92d50..4465a6084aa4d 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py @@ -194,7 +194,6 @@ def execute( self.log.info("The Google Analytics property %s was created successfully.", prop.name) GoogleAnalyticsPropertyLink.persist( context=context, - task_instance=self, property_id=prop.name.lstrip("properties/"), ) diff --git a/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py b/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py index 4673f3c53a2ae..e68a286284872 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py @@ -28,9 +28,9 @@ from googleapiclient import http from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.marketing_platform.hooks.campaign_manager import GoogleCampaignManagerHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -237,7 +237,7 @@ def execute(self, context: Context) -> None: mime_type="text/csv", ) - self.xcom_push(context, key="report_name", value=report_name) + context["task_instance"].xcom_push(key="report_name", value=report_name) class GoogleCampaignManagerInsertReportOperator(BaseOperator): @@ -308,7 +308,7 @@ def execute(self, context: Context): self.log.info("Inserting Campaign Manager report.") response = hook.insert_report(profile_id=self.profile_id, report=self.report) report_id = response.get("id") - self.xcom_push(context, key="report_id", value=report_id) + context["task_instance"].xcom_push(key="report_id", value=report_id) self.log.info("Report successfully inserted. Report id: %s", report_id) return response @@ -381,7 +381,7 @@ def execute(self, context: Context): synchronous=self.synchronous, ) file_id = response.get("id") - self.xcom_push(context, key="file_id", value=file_id) + context["task_instance"].xcom_push(key="file_id", value=file_id) self.log.info("Report file id: %s", file_id) return response diff --git a/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py b/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py index a45ca8ae2d22d..0c88db9f94e4a 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py @@ -29,9 +29,9 @@ from urllib.parse import urlsplit from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -99,7 +99,7 @@ def execute(self, context: Context) -> dict: self.log.info("Creating Display & Video 360 query.") response = hook.create_query(query=self.body) query_id = response["queryId"] - self.xcom_push(context, key="query_id", value=query_id) + context["task_instance"].xcom_push(key="query_id", value=query_id) self.log.info("Created query with ID: %s", query_id) return response @@ -295,7 +295,7 @@ def execute(self, context: Context): self.bucket_name, report_name, ) - self.xcom_push(context, key="report_name", value=report_name) + context["task_instance"].xcom_push(key="report_name", value=report_name) class GoogleDisplayVideo360RunQueryOperator(BaseOperator): @@ -360,8 +360,8 @@ def execute(self, context: Context) -> dict: self.parameters, ) response = hook.run_query(query_id=self.query_id, params=self.parameters) - self.xcom_push(context, key="query_id", value=response["key"]["queryId"]) - self.xcom_push(context, key="report_id", value=response["key"]["reportId"]) + context["task_instance"].xcom_push(key="query_id", value=response["key"]["queryId"]) + context["task_instance"].xcom_push(key="report_id", value=response["key"]["reportId"]) return response @@ -564,7 +564,7 @@ def execute(self, context: Context) -> dict[str, Any]: operation = hook.create_sdf_download_operation(body_request=self.body_request) name = operation["name"] - self.xcom_push(context, key="name", value=name) + context["task_instance"].xcom_push(key="name", value=name) self.log.info("Created SDF operation with name: %s", name) return operation diff --git a/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py b/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py index 1c7e120aa37f0..7b22197596c71 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py @@ -23,8 +23,8 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.google.marketing_platform.hooks.search_ads import GoogleSearchAdsReportingHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/suite/operators/sheets.py b/providers/google/src/airflow/providers/google/suite/operators/sheets.py index f6eaaf9d2ba07..2cf942218c675 100644 --- a/providers/google/src/airflow/providers/google/suite/operators/sheets.py +++ b/providers/google/src/airflow/providers/google/suite/operators/sheets.py @@ -19,8 +19,8 @@ from collections.abc import Sequence from typing import Any -from airflow.models import BaseOperator from airflow.providers.google.suite.hooks.sheets import GSheetsHook +from airflow.providers.google.version_compat import BaseOperator class GoogleSheetsCreateSpreadsheetOperator(BaseOperator): @@ -68,6 +68,6 @@ def execute(self, context: Any) -> dict[str, Any]: impersonation_chain=self.impersonation_chain, ) spreadsheet = hook.create_spreadsheet(spreadsheet=self.spreadsheet) - self.xcom_push(context, "spreadsheet_id", spreadsheet["spreadsheetId"]) - self.xcom_push(context, "spreadsheet_url", spreadsheet["spreadsheetUrl"]) + context["task_instance"].xcom_push(key="spreadsheet_id", value=spreadsheet["spreadsheetId"]) + context["task_instance"].xcom_push(key="spreadsheet_url", value=spreadsheet["spreadsheetUrl"]) return spreadsheet diff --git a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py index 2a8728e30ea13..06854bca89d7b 100644 --- a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py +++ b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py @@ -24,9 +24,9 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py index 95b9120df91d3..77f3fc03e4271 100644 --- a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py +++ b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py @@ -21,9 +21,9 @@ from tempfile import NamedTemporaryFile from typing import Any -from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.suite.hooks.sheets import GSheetsHook +from airflow.providers.google.version_compat import BaseOperator class GCSToGoogleSheetsOperator(BaseOperator): diff --git a/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py b/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py index 9a9e831762850..aa712d773da27 100644 --- a/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py +++ b/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py @@ -24,8 +24,8 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowFailException -from airflow.models import BaseOperator from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.providers.google.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/version_compat.py b/providers/google/src/airflow/providers/google/version_compat.py index 48d122b669696..9b604e6b95008 100644 --- a/providers/google/src/airflow/providers/google/version_compat.py +++ b/providers/google/src/airflow/providers/google/version_compat.py @@ -33,3 +33,31 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) + +# Version-compatible imports +# BaseOperator: Use 3.1+ due to xcom_push method missing in SDK BaseOperator 3.0.x +# This is needed for DecoratedOperator compatibility +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models import BaseOperator + +# Other SDK components: Available since 3.0+ +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import ( + BaseOperatorLink, + BaseSensorOperator, + ) +else: + from airflow.models import BaseOperatorLink # type: ignore[no-redef] + from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] + +# Explicitly export these imports to protect them from being removed by linters +__all__ = [ + "AIRFLOW_V_3_0_PLUS", + "AIRFLOW_V_3_1_PLUS", + "BaseOperator", + "BaseSensorOperator", + "BaseOperatorLink", +] diff --git a/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py b/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py index b7ca33cae727b..66602204431c0 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py +++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py @@ -426,7 +426,7 @@ def test_async_create_build_fires_correct_trigger_should_execute_successfully( ) with pytest.raises(TaskDeferred) as exc: - ti.task.execute({"ti": ti}) + ti.task.execute({"ti": ti, "task_instance": ti}) assert isinstance(exc.value.trigger, CloudBuildCreateBuildTrigger), ( "Trigger is not a CloudBuildCreateBuildTrigger" diff --git a/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py b/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py index 0026dd1f50111..da5b420edb374 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py +++ b/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py @@ -121,6 +121,7 @@ } TEST_TAG_TEMPLATE: TagTemplate = TagTemplate(name=TEST_TAG_TEMPLATE_PATH) TEST_TAG_TEMPLATE_DICT: dict = { + "dataplex_transfer_status": 0, "display_name": "", "fields": {}, "is_publicly_readable": False, @@ -145,8 +146,7 @@ class TestCloudDataCatalogCreateEntryOperator: "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", **{"return_value.create_entry.return_value": TEST_ENTRY}, ) - @mock.patch(BASE_PATH.format("CloudDataCatalogCreateEntryOperator.xcom_push")) - def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: + def test_assert_valid_hook_call(self, mock_hook) -> None: with pytest.warns(AirflowProviderDeprecationWarning): task = CloudDataCatalogCreateEntryOperator( task_id="task_id", @@ -161,8 +161,9 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - context = mock.MagicMock() - result = task.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = task.execute(context=mock_context) # type: ignore[arg-type] mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, @@ -177,8 +178,7 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) - mock_xcom.assert_called_with( - context, + mock_ti.xcom_push.assert_any_call( key="entry_id", value=TEST_ENTRY_ID, ) @@ -186,8 +186,7 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: assert result == TEST_ENTRY_DICT @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook") - @mock.patch(BASE_PATH.format("CloudDataCatalogCreateEntryOperator.xcom_push")) - def test_assert_valid_hook_call_when_exists(self, mock_xcom, mock_hook) -> None: + def test_assert_valid_hook_call_when_exists(self, mock_hook) -> None: mock_hook.return_value.create_entry.side_effect = AlreadyExists(message="message") mock_hook.return_value.get_entry.return_value = TEST_ENTRY with pytest.warns(AirflowProviderDeprecationWarning): @@ -204,8 +203,9 @@ def test_assert_valid_hook_call_when_exists(self, mock_xcom, mock_hook) -> None: gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - context = mock.MagicMock() - result = task.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = task.execute(context=mock_context) # type: ignore[arg-type] mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, @@ -229,8 +229,7 @@ def test_assert_valid_hook_call_when_exists(self, mock_xcom, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) - mock_xcom.assert_called_with( - context, + mock_ti.xcom_push.assert_any_call( key="entry_id", value=TEST_ENTRY_ID, ) @@ -242,8 +241,7 @@ class TestCloudDataCatalogCreateEntryGroupOperator: "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", **{"return_value.create_entry_group.return_value": TEST_ENTRY_GROUP}, ) - @mock.patch(BASE_PATH.format("CloudDataCatalogCreateEntryGroupOperator.xcom_push")) - def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: + def test_assert_valid_hook_call(self, mock_hook) -> None: with pytest.warns(AirflowProviderDeprecationWarning): task = CloudDataCatalogCreateEntryGroupOperator( task_id="task_id", @@ -257,8 +255,9 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - context = mock.MagicMock() - result = task.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = task.execute(context=mock_context) # type: ignore[arg-type] mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, @@ -272,8 +271,7 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) - mock_xcom.assert_called_with( - context, + mock_ti.xcom_push.assert_any_call( key="entry_group_id", value=TEST_ENTRY_GROUP_ID, ) @@ -285,8 +283,7 @@ class TestCloudDataCatalogCreateTagOperator: "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", **{"return_value.create_tag.return_value": TEST_TAG}, ) - @mock.patch(BASE_PATH.format("CloudDataCatalogCreateTagOperator.xcom_push")) - def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: + def test_assert_valid_hook_call(self, mock_hook) -> None: with pytest.warns(AirflowProviderDeprecationWarning): task = CloudDataCatalogCreateTagOperator( task_id="task_id", @@ -302,8 +299,9 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - context = mock.MagicMock() - result = task.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = task.execute(context=mock_context) # type: ignore[arg-type] mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, @@ -319,8 +317,7 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) - mock_xcom.assert_called_with( - context, + mock_ti.xcom_push.assert_any_call( key="tag_id", value=TEST_TAG_ID, ) @@ -332,8 +329,7 @@ class TestCloudDataCatalogCreateTagTemplateOperator: "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", **{"return_value.create_tag_template.return_value": TEST_TAG_TEMPLATE}, ) - @mock.patch(BASE_PATH.format("CloudDataCatalogCreateTagTemplateOperator.xcom_push")) - def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: + def test_assert_valid_hook_call(self, mock_hook) -> None: with pytest.warns(AirflowProviderDeprecationWarning): task = CloudDataCatalogCreateTagTemplateOperator( task_id="task_id", @@ -347,8 +343,9 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - context = mock.MagicMock() - result = task.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = task.execute(context=mock_context) # type: ignore[arg-type] mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, @@ -362,12 +359,11 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) - mock_xcom.assert_called_with( - context, + mock_ti.xcom_push.assert_any_call( key="tag_template_id", value=TEST_TAG_TEMPLATE_ID, ) - assert result == {**result, **TEST_TAG_TEMPLATE_DICT} + assert result == TEST_TAG_TEMPLATE_DICT class TestCloudDataCatalogCreateTagTemplateFieldOperator: @@ -375,8 +371,7 @@ class TestCloudDataCatalogCreateTagTemplateFieldOperator: "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", **{"return_value.create_tag_template_field.return_value": TEST_TAG_TEMPLATE_FIELD}, # type: ignore ) - @mock.patch(BASE_PATH.format("CloudDataCatalogCreateTagTemplateFieldOperator.xcom_push")) - def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: + def test_assert_valid_hook_call(self, mock_hook) -> None: with pytest.warns(AirflowProviderDeprecationWarning): task = CloudDataCatalogCreateTagTemplateFieldOperator( task_id="task_id", @@ -391,8 +386,9 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - context = mock.MagicMock() - result = task.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = task.execute(context=mock_context) # type: ignore[arg-type] mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, @@ -407,12 +403,11 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) - mock_xcom.assert_called_with( - context, + mock_ti.xcom_push.assert_any_call( key="tag_template_field_id", value=TEST_TAG_TEMPLATE_FIELD_ID, ) - assert result == {**result, **TEST_TAG_TEMPLATE_FIELD_DICT} + assert result == TEST_TAG_TEMPLATE_FIELD_DICT class TestCloudDataCatalogDeleteEntryOperator: diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataflow.py b/providers/google/tests/unit/google/cloud/operators/test_dataflow.py index a9bc77b02b6d3..c7e7622b65fbd 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataflow.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataflow.py @@ -161,11 +161,11 @@ def deferrable_operator(self): cancel_timeout=CANCEL_TIMEOUT, ) - @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push") @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") - def test_execute(self, hook_mock, mock_xcom_push, sync_operator): + def test_execute(self, hook_mock, sync_operator): start_template_hook = hook_mock.return_value.start_template_dataflow - sync_operator.execute(None) + mock_context = {"task_instance": mock.MagicMock()} + sync_operator.execute(mock_context) assert hook_mock.called expected_options = { "project": "test", @@ -231,9 +231,8 @@ def test_validation_deferrable_params_raises_error(self): DataflowTemplatedJobStartOperator(**init_kwargs) @pytest.mark.db_test - @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push") @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow") - def test_start_with_custom_region(self, dataflow_mock, mock_xcom_push): + def test_start_with_custom_region(self, dataflow_mock): init_kwargs = { "task_id": TASK_ID, "template": TEMPLATE, @@ -245,16 +244,16 @@ def test_start_with_custom_region(self, dataflow_mock, mock_xcom_push): "cancel_timeout": CANCEL_TIMEOUT, } operator = DataflowTemplatedJobStartOperator(**init_kwargs) - operator.execute(None) + mock_context = {"task_instance": mock.MagicMock()} + operator.execute(mock_context) assert dataflow_mock.called _, kwargs = dataflow_mock.call_args_list[0] assert kwargs["variables"]["region"] == TEST_REGION assert kwargs["location"] == DEFAULT_DATAFLOW_LOCATION @pytest.mark.db_test - @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push") @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow") - def test_start_with_location(self, dataflow_mock, mock_xcom_push): + def test_start_with_location(self, dataflow_mock): init_kwargs = { "task_id": TASK_ID, "template": TEMPLATE, @@ -264,7 +263,8 @@ def test_start_with_location(self, dataflow_mock, mock_xcom_push): "cancel_timeout": CANCEL_TIMEOUT, } operator = DataflowTemplatedJobStartOperator(**init_kwargs) - operator.execute(None) + mock_context = {"task_instance": mock.MagicMock()} + operator.execute(mock_context) assert dataflow_mock.called _, kwargs = dataflow_mock.call_args_list[0] assert not kwargs["variables"] @@ -409,19 +409,18 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr ) mock_defer_method.assert_called_once() - @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push") @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") - def test_execute_complete_success(self, mock_hook, mock_xcom_push, deferrable_operator): + def test_execute_complete_success(self, mock_hook, deferrable_operator): expected_result = {"id": JOB_ID} + mock_context = {"task_instance": mock.MagicMock()} actual_result = deferrable_operator.execute_complete( - context=None, + context=mock_context, event={ "status": "success", "message": "Batch job completed.", "job": expected_result, }, ) - mock_xcom_push.assert_called_with(None, key="job_id", value=JOB_ID) assert actual_result == expected_result def test_execute_complete_error_status_raises_exception(self, deferrable_operator): @@ -449,7 +448,8 @@ def test_exec_job_id(self, dataflow_mock): Test DataflowHook is created and the right args are passed to cancel_job. """ cancel_job_hook = dataflow_mock.return_value.cancel_job - self.dataflow.execute(None) + mock_context = {"task_instance": mock.MagicMock()} + self.dataflow.execute(mock_context) assert dataflow_mock.called cancel_job_hook.assert_called_once_with( job_name=None, @@ -473,7 +473,8 @@ def test_exec_job_name_prefix(self, dataflow_mock): """ is_job_running_hook = dataflow_mock.return_value.is_job_dataflow_running cancel_job_hook = dataflow_mock.return_value.cancel_job - self.dataflow.execute(None) + mock_context = {"task_instance": mock.MagicMock()} + self.dataflow.execute(mock_context) assert dataflow_mock.called is_job_running_hook.assert_called_once_with( name=JOB_NAME, diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index 00923cb590a1f..91c2c0a351262 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -69,11 +69,11 @@ DataprocSubmitTrigger, ) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME +from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.timezone import datetime from tests_common.test_utils.db import clear_db_runs, clear_db_xcom -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.sdk.execution_time.comms import XComResult diff --git a/providers/google/tests/unit/google/cloud/operators/test_functions.py b/providers/google/tests/unit/google/cloud/operators/test_functions.py index b075236589364..47b3e4ebde680 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_functions.py +++ b/providers/google/tests/unit/google/cloud/operators/test_functions.py @@ -693,9 +693,8 @@ def test_non_404_gcf_error_bubbled_up(self, mock_hook): class TestGcfFunctionInvokeOperator: - @mock.patch("airflow.providers.google.cloud.operators.functions.GoogleCloudBaseOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook") - def test_execute(self, mock_gcf_hook, mock_xcom): + def test_execute(self, mock_gcf_hook): exec_id = "exec_id" mock_gcf_hook.return_value.call_function.return_value = {"executionId": exec_id} @@ -715,8 +714,9 @@ def test_execute(self, mock_gcf_hook, mock_xcom): gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, ) - context = mock.MagicMock() - op.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + op.execute(mock_context) mock_gcf_hook.assert_called_once_with( api_version=api_version, @@ -728,8 +728,7 @@ def test_execute(self, mock_gcf_hook, mock_xcom): function_id=function_id, input_data=payload, location=GCP_LOCATION, project_id=GCP_PROJECT_ID ) - mock_xcom.assert_called_with( - context=context, + mock_ti.xcom_push.assert_any_call( key="execution_id", value=exec_id, ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_translate.py b/providers/google/tests/unit/google/cloud/operators/test_translate.py index 957bd3ddd1b5b..4e20d8ee2ace3 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_translate.py +++ b/providers/google/tests/unit/google/cloud/operators/test_translate.py @@ -214,9 +214,8 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist): class TestTranslateDatasetCreate: @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") - @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") - def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): + def test_minimal_green_path(self, mock_hook, mock_link_persist): DS_CREATION_RESULT_SAMPLE = { "display_name": "", "example_count": 0, @@ -249,8 +248,9 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): timeout=TIMEOUT_VALUE, retry=None, ) - context = mock.MagicMock() - result = op.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = op.execute(context=mock_context) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -263,9 +263,9 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): retry=None, metadata=(), ) - mock_xcom_push.assert_called_once_with(context, key="dataset_id", value=DATASET_ID) + mock_ti.xcom_push.assert_any_call(key="dataset_id", value=DATASET_ID) mock_link_persist.assert_called_once_with( - context=context, + context=mock_context, dataset_id=DATASET_ID, location=LOCATION, project_id=PROJECT_ID, @@ -402,9 +402,8 @@ def test_minimal_green_path(self, mock_hook): class TestTranslateModelCreate: @mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelLink.persist") - @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") - def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): + def test_minimal_green_path(self, mock_hook, mock_link_persist): MODEL_DISPLAY_NAME = "model_display_name_01" MODEL_CREATION_RESULT_SAMPLE = { "display_name": MODEL_DISPLAY_NAME, @@ -435,8 +434,9 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): timeout=TIMEOUT_VALUE, retry=None, ) - context = mock.MagicMock() - result = op.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = op.execute(context=mock_context) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -450,9 +450,9 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): retry=None, metadata=(), ) - mock_xcom_push.assert_called_once_with(context, key="model_id", value=MODEL_ID) + mock_ti.xcom_push.assert_any_call(key="model_id", value=MODEL_ID) mock_link_persist.assert_called_once_with( - context=context, + context=mock_context, model_id=MODEL_ID, project_id=PROJECT_ID, dataset_id=DATASET_ID, @@ -711,11 +711,8 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist): class TestTranslateGlossaryCreate: - @mock.patch( - "airflow.providers.google.cloud.operators.translate.TranslateCreateGlossaryOperator.xcom_push" - ) @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") - def test_minimal_green_path(self, mock_hook, mock_xcom_push): + def test_minimal_green_path(self, mock_hook): GLOSSARY_CREATION_RESULT = { "name": f"projects/{PROJECT_ID}/locations/{LOCATION}/glossaries/{GLOSSARY_ID}", "display_name": f"{GLOSSARY_ID}", @@ -746,8 +743,9 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push): timeout=TIMEOUT_VALUE, retry=None, ) - context = mock.MagicMock() - result = op.execute(context=context) + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + result = op.execute(context=mock_context) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, @@ -764,7 +762,7 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push): retry=None, metadata=(), ) - mock_xcom_push.assert_called_once_with(context, key="glossary_id", value=GLOSSARY_ID) + mock_ti.xcom_push.assert_any_call(key="glossary_id", value=GLOSSARY_ID) assert result == GLOSSARY_CREATION_RESULT diff --git a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py index 67d11428cb76b..460e8b24847c3 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py +++ b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py @@ -439,14 +439,11 @@ def test_execute_enters_deferred_state(self, mock_hook): ) @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist")) - @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.xcom_push")) @mock.patch( VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook.extract_model_id") ) @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook")) - def test_execute_complete_success( - self, mock_hook, mock_hook_extract_model_id, mock_xcom_push, mock_link_persist - ): + def test_execute_complete_success(self, mock_hook, mock_hook_extract_model_id, mock_link_persist): task = CreateCustomContainerTrainingJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, @@ -471,16 +468,20 @@ def test_execute_complete_success( ) expected_result = TEST_TRAINING_PIPELINE_DATA["model_to_upload"] mock_hook_extract_model_id.return_value = "test-model" + + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + actual_result = task.execute_complete( - context=None, + context=mock_context, event={ "status": "success", "message": "", "job": TEST_TRAINING_PIPELINE_DATA, }, ) - mock_xcom_push.assert_called_with(None, key="model_id", value="test-model") - mock_link_persist.assert_called_once_with(context=None, model_id="test-model") + mock_ti.xcom_push.assert_any_call(key="model_id", value="test-model") + mock_link_persist.assert_called_once_with(context=mock_context, model_id="test-model") assert actual_result == expected_result def test_execute_complete_error_status_raises_exception(self): @@ -510,7 +511,6 @@ def test_execute_complete_error_status_raises_exception(self): task.execute_complete(context=None, event={"status": "error", "message": "test message"}) @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist")) - @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.xcom_push")) @mock.patch( VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook.extract_model_id") ) @@ -519,7 +519,6 @@ def test_execute_complete_no_model_produced( self, mock_hook, hook_extract_model_id, - mock_xcom_push, mock_link_persist, ): task = CreateCustomContainerTrainingJobOperator( @@ -544,11 +543,16 @@ def test_execute_complete_no_model_produced( ) expected_result = None hook_extract_model_id.return_value = None + + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + actual_result = task.execute_complete( - context=None, + context=mock_context, event={"status": "success", "message": "", "job": TEST_TRAINING_PIPELINE_DATA_NO_MODEL}, ) - mock_xcom_push.assert_called_once() + # When no model is produced, xcom_push should still be called but with None value + mock_ti.xcom_push.assert_called_once() mock_link_persist.assert_not_called() assert actual_result == expected_result @@ -765,7 +769,6 @@ def test_execute_enters_deferred_state(self, mock_hook): ) @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist")) - @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.xcom_push")) @mock.patch( VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook.extract_model_id") ) @@ -774,7 +777,6 @@ def test_execute_complete_success( self, mock_hook, hook_extract_model_id, - mock_xcom_push, mock_link_persist, ): task = CreateCustomPythonPackageTrainingJobOperator( @@ -802,16 +804,20 @@ def test_execute_complete_success( ) expected_result = TEST_TRAINING_PIPELINE_DATA["model_to_upload"] hook_extract_model_id.return_value = "test-model" + + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + actual_result = task.execute_complete( - context=None, + context=mock_context, event={ "status": "success", "message": "", "job": TEST_TRAINING_PIPELINE_DATA, }, ) - mock_xcom_push.assert_called_with(None, key="model_id", value="test-model") - mock_link_persist.assert_called_once_with(context=None, model_id="test-model") + mock_ti.xcom_push.assert_any_call(key="model_id", value="test-model") + mock_link_persist.assert_called_once_with(context=mock_context, model_id="test-model") assert actual_result == expected_result def test_execute_complete_error_status_raises_exception(self): @@ -842,7 +848,6 @@ def test_execute_complete_error_status_raises_exception(self): task.execute_complete(context=None, event={"status": "error", "message": "test message"}) @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist")) - @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.xcom_push")) @mock.patch( VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook.extract_model_id") ) @@ -851,7 +856,6 @@ def test_execute_complete_no_model_produced( self, mock_hook, hook_extract_model_id, - mock_xcom_push, mock_link_persist, ): task = CreateCustomPythonPackageTrainingJobOperator( @@ -875,12 +879,15 @@ def test_execute_complete_no_model_produced( project_id=GCP_PROJECT, deferrable=True, ) + + mock_ti = mock.MagicMock() + expected_result = None actual_result = task.execute_complete( - context=None, + context={"ti": mock_ti}, event={"status": "success", "message": "", "job": TEST_TRAINING_PIPELINE_DATA_NO_MODEL}, ) - mock_xcom_push.assert_called_once() + mock_ti.xcom_push.assert_called_once() mock_link_persist.assert_not_called() assert actual_result == expected_result @@ -1076,14 +1083,12 @@ def test_execute_enters_deferred_state(self, mock_hook): ) @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist")) - @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.xcom_push")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook.extract_model_id")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook")) def test_execute_complete_success( self, mock_hook, hook_extract_model_id, - mock_xcom_push, mock_link_persist, ): task = CreateCustomTrainingJobOperator( @@ -1104,16 +1109,18 @@ def test_execute_complete_success( ) expected_result = TEST_TRAINING_PIPELINE_DATA["model_to_upload"] hook_extract_model_id.return_value = "test-model" + + mock_ti = mock.MagicMock() actual_result = task.execute_complete( - context=None, + context={"ti": mock_ti}, event={ "status": "success", "message": "", "job": TEST_TRAINING_PIPELINE_DATA, }, ) - mock_xcom_push.assert_called_with(None, key="model_id", value="test-model") - mock_link_persist.assert_called_once_with(context=None, model_id="test-model") + mock_ti.xcom_push.assert_called_with(key="model_id", value="test-model") + mock_link_persist.assert_called_once_with(context={"ti": mock_ti}, model_id="test-model") assert actual_result == expected_result def test_execute_complete_error_status_raises_exception(self): @@ -1127,7 +1134,6 @@ def test_execute_complete_error_status_raises_exception(self): args=PYTHON_PACKAGE_CMDARGS, container_uri=CONTAINER_URI, model_serving_container_image_uri=CONTAINER_URI, - requirements=[], replica_count=1, region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -1137,14 +1143,12 @@ def test_execute_complete_error_status_raises_exception(self): task.execute_complete(context=None, event={"status": "error", "message": "test message"}) @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist")) - @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.xcom_push")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook.extract_model_id")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook")) def test_execute_complete_no_model_produced( self, mock_hook, hook_extract_model_id, - mock_xcom_push, mock_link_persist, ): task = CreateCustomTrainingJobOperator( @@ -1164,10 +1168,14 @@ def test_execute_complete_no_model_produced( ) expected_result = None hook_extract_model_id.return_value = None + + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + actual_result = task.execute_complete( - context=None, event={"status": "success", "message": "", "job": {}} + context=mock_context, event={"status": "success", "message": "", "job": {}} ) - mock_xcom_push.assert_called_once() + mock_ti.xcom_push.assert_called_once() mock_link_persist.assert_not_called() assert actual_result == expected_result @@ -2115,10 +2123,10 @@ def test_execute_deferrable(self, mock_hook, mock_link_persist): assert exception_info.value.trigger.poll_interval == 10 assert exception_info.value.trigger.impersonation_chain == IMPERSONATION_CHAIN - @mock.patch(VERTEX_AI_PATH.format("batch_prediction_job.CreateBatchPredictionJobOperator.xcom_push")) @mock.patch(VERTEX_AI_PATH.format("batch_prediction_job.BatchPredictionJobHook")) - def test_execute_complete(self, mock_hook, mock_xcom_push): - context = mock.MagicMock() + def test_execute_complete(self, mock_hook): + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} mock_job = {"name": TEST_JOB_DISPLAY_NAME} event = { "status": "success", @@ -2139,14 +2147,13 @@ def test_execute_complete(self, mock_hook, mock_xcom_push): create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT, batch_size=TEST_BATCH_SIZE, ) - execute_complete_result = op.execute_complete(context=context, event=event) + execute_complete_result = op.execute_complete(context=mock_context, event=event) mock_hook.return_value.extract_batch_prediction_job_id.assert_called_once_with(mock_job) - mock_xcom_push.assert_has_calls( + mock_ti.xcom_push.assert_has_calls( [ - call(context, key="batch_prediction_job_id", value=TEST_BATCH_PREDICTION_JOB_ID), + call(key="batch_prediction_job_id", value=TEST_BATCH_PREDICTION_JOB_ID), call( - context, key="training_conf", value={ "training_conf_id": TEST_BATCH_PREDICTION_JOB_ID, @@ -2969,9 +2976,8 @@ def test_execute_enters_deferred_state(self, mock_hook): task.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()}) assert isinstance(exc.value.trigger, RunPipelineJobTrigger), "Trigger is not a RunPipelineJobTrigger" - @mock.patch(VERTEX_AI_PATH.format("pipeline_job.RunPipelineJobOperator.xcom_push")) @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook")) - def test_execute_complete_success(self, mock_hook, mock_xcom_push): + def test_execute_complete_success(self, mock_hook): task = RunPipelineJobOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, @@ -2987,9 +2993,12 @@ def test_execute_complete_success(self, mock_hook, mock_xcom_push): "name": f"projects/{GCP_PROJECT}/locations/{GCP_LOCATION}/pipelineJobs/{TEST_PIPELINE_JOB_ID}", } mock_hook.return_value.exists.return_value = False - mock_xcom_push.return_value = None + + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + actual_result = task.execute_complete( - context=None, event={"status": "success", "message": "", "job": expected_pipeline_job} + context=mock_context, event={"status": "success", "message": "", "job": expected_pipeline_job} ) assert actual_result == expected_result diff --git a/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py b/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py index 797203ae818a6..5e4446a6c2034 100644 --- a/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py +++ b/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py @@ -78,14 +78,15 @@ def test_upload_data(self, mock_tempfile, mock_writer): @mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GCSHook") @mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GSheetsHook") - @mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator.xcom_push") @mock.patch( "airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator._upload_data" ) - def test_execute(self, mock_upload_data, mock_xcom, mock_sheet_hook, mock_gcs_hook): - context = {} + def test_execute(self, mock_upload_data, mock_sheet_hook, mock_gcs_hook): + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} data = ["data1", "data2"] mock_sheet_hook.return_value.get_sheet_titles.return_value = RANGES + mock_sheet_hook.return_value.get_values.side_effect = data mock_upload_data.side_effect = [PATH, PATH] op = GoogleSheetsToGCSOperator( @@ -97,7 +98,7 @@ def test_execute(self, mock_upload_data, mock_xcom, mock_sheet_hook, mock_gcs_ho gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context) + op.execute(mock_context) mock_sheet_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, @@ -115,9 +116,12 @@ def test_execute(self, mock_upload_data, mock_xcom, mock_sheet_hook, mock_gcs_ho calls = [mock.call(spreadsheet_id=SPREADSHEET_ID, range_=r) for r in RANGES] mock_sheet_hook.return_value.get_values.assert_has_calls(calls) - calls = [mock.call(mock_gcs_hook, mock_sheet_hook, r, v) for r, v in zip(RANGES, data)] - mock_upload_data.assert_called() + calls = [ + mock.call(mock_gcs_hook.return_value, mock_sheet_hook.return_value, r, v) + for r, v in zip(RANGES, data) + ] + mock_upload_data.assert_has_calls(calls) actual_call_count = mock_upload_data.call_count assert len(RANGES) == actual_call_count - mock_xcom.assert_called_once_with(context, "destination_objects", [PATH, PATH]) + mock_ti.xcom_push.assert_called_once_with(key="destination_objects", value=[PATH, PATH]) diff --git a/providers/google/tests/unit/google/cloud/utils/airflow_util.py b/providers/google/tests/unit/google/cloud/utils/airflow_util.py index 891a00780b594..3e0b14cb0a567 100644 --- a/providers/google/tests/unit/google/cloud/utils/airflow_util.py +++ b/providers/google/tests/unit/google/cloud/utils/airflow_util.py @@ -28,7 +28,7 @@ from airflow.utils.types import DagRunType if TYPE_CHECKING: - from airflow.models.baseoperator import BaseOperator + from airflow.providers.google.version_compat import BaseOperator def get_dag_run(dag_id: str = "test_dag_id", run_id: str = "test_dag_id") -> DagRun: diff --git a/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py b/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py index bb015c9be2485..9a1db76e79b7e 100644 --- a/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py +++ b/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py @@ -55,17 +55,15 @@ def test_get_link_not_found(self, mock_xcom): assert url == "" def test_persist(self): - mock_context = mock.MagicMock() mock_task_instance = mock.MagicMock() + mock_context = {"task_instance": mock_task_instance} GoogleAnalyticsPropertyLink.persist( context=mock_context, - task_instance=mock_task_instance, property_id=TEST_PROPERTY_ID, ) mock_task_instance.xcom_push.assert_called_once_with( - mock_context, key=GoogleAnalyticsPropertyLink.key, value={"property_id": TEST_PROPERTY_ID}, ) diff --git a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py index cb2af61d9d297..f2240b79aa586 100644 --- a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py +++ b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py @@ -102,13 +102,8 @@ def teardown_method(self): ) @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.GCSHook") @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.BaseOperator") - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerDownloadReportOperator.xcom_push" - ) def test_execute( self, - xcom_mock, mock_base_op, gcs_hook_mock, hook_mock, @@ -120,6 +115,9 @@ def test_execute( True, ) tempfile_mock.NamedTemporaryFile.return_value.__enter__.return_value.name = TEMP_FILE_NAME + + mock_context = {"task_instance": mock.Mock()} + op = GoogleCampaignManagerDownloadReportOperator( profile_id=PROFILE_ID, report_id=REPORT_ID, @@ -129,7 +127,7 @@ def test_execute( api_version=API_VERSION, task_id="test_task", ) - op.execute(context=None) + op.execute(context=mock_context) hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, @@ -149,7 +147,9 @@ def test_execute( filename=TEMP_FILE_NAME, mime_type="text/csv", ) - xcom_mock.assert_called_once_with(None, key="report_name", value=REPORT_NAME + ".gz") + mock_context["task_instance"].xcom_push.assert_called_once_with( + key="report_name", value=REPORT_NAME + ".gz" + ) @pytest.mark.parametrize( "test_bucket_name", @@ -214,13 +214,11 @@ class TestGoogleCampaignManagerInsertReportOperator: "airflow.providers.google.marketing_platform.operators.campaign_manager.GoogleCampaignManagerHook" ) @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.BaseOperator") - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerInsertReportOperator.xcom_push" - ) - def test_execute(self, xcom_mock, mock_base_op, hook_mock): + def test_execute(self, mock_base_op, hook_mock): report = {"report": "test"} + mock_context = {"task_instance": mock.Mock()} + hook_mock.return_value.insert_report.return_value = {"id": REPORT_ID} op = GoogleCampaignManagerInsertReportOperator( @@ -229,14 +227,14 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): api_version=API_VERSION, task_id="test_task", ) - op.execute(context=None) + op.execute(context=mock_context) hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.insert_report.assert_called_once_with(profile_id=PROFILE_ID, report=report) - xcom_mock.assert_called_once_with(None, key="report_id", value=REPORT_ID) + mock_context["task_instance"].xcom_push.assert_called_once_with(key="report_id", value=REPORT_ID) def test_prepare_template(self): report = {"key": "value"} @@ -260,13 +258,11 @@ class TestGoogleCampaignManagerRunReportOperator: "airflow.providers.google.marketing_platform.operators.campaign_manager.GoogleCampaignManagerHook" ) @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.BaseOperator") - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerRunReportOperator.xcom_push" - ) - def test_execute(self, xcom_mock, mock_base_op, hook_mock): + def test_execute(self, mock_base_op, hook_mock): synchronous = True + mock_context = {"task_instance": mock.Mock()} + hook_mock.return_value.run_report.return_value = {"id": FILE_ID} op = GoogleCampaignManagerRunReportOperator( @@ -276,7 +272,7 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): api_version=API_VERSION, task_id="test_task", ) - op.execute(context=None) + op.execute(context=mock_context) hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, @@ -285,7 +281,7 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): hook_mock.return_value.run_report.assert_called_once_with( profile_id=PROFILE_ID, report_id=REPORT_ID, synchronous=synchronous ) - xcom_mock.assert_called_once_with(None, key="file_id", value=FILE_ID) + mock_context["task_instance"].xcom_push.assert_called_once_with(key="file_id", value=FILE_ID) class TestGoogleCampaignManagerBatchInsertConversionsOperator: diff --git a/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py b/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py index 8596743dd7d6b..db361c80f4a64 100644 --- a/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py +++ b/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py @@ -83,10 +83,6 @@ def teardown_method(self): @mock.patch("airflow.providers.google.marketing_platform.operators.display_video.shutil") @mock.patch("airflow.providers.google.marketing_platform.operators.display_video.urllib.request") @mock.patch("airflow.providers.google.marketing_platform.operators.display_video.tempfile") - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360DownloadReportV2Operator.xcom_push" - ) @mock.patch("airflow.providers.google.marketing_platform.operators.display_video.GCSHook") @mock.patch( "airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook" @@ -95,7 +91,6 @@ def test_execute( self, mock_hook, mock_gcs_hook, - mock_xcom, mock_temp, mock_request, mock_shutil, @@ -109,6 +104,9 @@ def test_execute( "googleCloudStoragePath": file_path, } } + # Create mock context with task_instance + mock_context = {"task_instance": mock.Mock()} + op = GoogleDisplayVideo360DownloadReportV2Operator( query_id=QUERY_ID, report_id=REPORT_ID, @@ -118,9 +116,9 @@ def test_execute( ) if should_except: with pytest.raises(AirflowException): - op.execute(context=None) + op.execute(context=mock_context) return - op.execute(context=None) + op.execute(context=mock_context) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version="v2", @@ -139,7 +137,9 @@ def test_execute( mime_type="text/csv", object_name=REPORT_NAME + ".gz", ) - mock_xcom.assert_called_once_with(None, key="report_name", value=REPORT_NAME + ".gz") + mock_context["task_instance"].xcom_push.assert_called_once_with( + key="report_name", value=REPORT_NAME + ".gz" + ) @pytest.mark.parametrize( "test_bucket_name", @@ -199,15 +199,15 @@ def f(): class TestGoogleDisplayVideo360RunQueryOperator: - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360RunQueryOperator.xcom_push" - ) @mock.patch( "airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook" ) - def test_execute(self, hook_mock, mock_xcom): + def test_execute(self, hook_mock): parameters = {"param": "test"} + + # Create mock context with task_instance + mock_context = {"task_instance": mock.Mock()} + hook_mock.return_value.run_query.return_value = { "key": { "queryId": QUERY_ID, @@ -220,15 +220,15 @@ def test_execute(self, hook_mock, mock_xcom): api_version=API_VERSION, task_id="test_task", ) - op.execute(context=None) + op.execute(context=mock_context) hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, impersonation_chain=None, ) - mock_xcom.assert_any_call(None, key="query_id", value=QUERY_ID) - mock_xcom.assert_any_call(None, key="report_id", value=REPORT_ID) + mock_context["task_instance"].xcom_push.assert_any_call(key="query_id", value=QUERY_ID) + mock_context["task_instance"].xcom_push.assert_any_call(key="report_id", value=REPORT_ID) hook_mock.return_value.run_query.assert_called_once_with(query_id=QUERY_ID, params=parameters) @@ -388,20 +388,20 @@ def test_execute(self, mock_temp, gcs_mock_hook, mock_hook): class TestGoogleDisplayVideo360CreateSDFDownloadTaskOperator: - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360CreateSDFDownloadTaskOperator.xcom_push" - ) @mock.patch( "airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook" ) - def test_execute(self, mock_hook, xcom_mock): + def test_execute(self, mock_hook): body_request = { "version": "1", "id": "id", "filter": {"id": []}, } test_name = "test_task" + + # Create mock context with task_instance + mock_context = {"task_instance": mock.Mock()} + mock_hook.return_value.create_sdf_download_operation.return_value = {"name": test_name} op = GoogleDisplayVideo360CreateSDFDownloadTaskOperator( @@ -411,7 +411,7 @@ def test_execute(self, mock_hook, xcom_mock): task_id="test_task", ) - op.execute(context=None) + op.execute(context=mock_context) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, @@ -422,29 +422,29 @@ def test_execute(self, mock_hook, xcom_mock): mock_hook.return_value.create_sdf_download_operation.assert_called_once_with( body_request=body_request ) - xcom_mock.assert_called_once_with(None, key="name", value=test_name) + mock_context["task_instance"].xcom_push.assert_called_once_with(key="name", value=test_name) class TestGoogleDisplayVideo360CreateQueryOperator: - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360CreateQueryOperator.xcom_push" - ) @mock.patch( "airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook" ) - def test_execute(self, hook_mock, xcom_mock): + def test_execute(self, hook_mock): body = {"body": "test"} + + # Create mock context with task_instance + mock_context = {"task_instance": mock.Mock()} + hook_mock.return_value.create_query.return_value = {"queryId": QUERY_ID} op = GoogleDisplayVideo360CreateQueryOperator(body=body, task_id="test_task") - op.execute(context=None) + op.execute(context=mock_context) hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, api_version="v2", impersonation_chain=None, ) hook_mock.return_value.create_query.assert_called_once_with(query=body) - xcom_mock.assert_called_once_with(None, key="query_id", value=QUERY_ID) + mock_context["task_instance"].xcom_push.assert_called_once_with(key="query_id", value=QUERY_ID) def test_prepare_template(self): body = {"key": "value"} diff --git a/providers/google/tests/unit/google/suite/operators/test_sheets.py b/providers/google/tests/unit/google/suite/operators/test_sheets.py index 6e04a2b9fe56d..1d0fc216c0308 100644 --- a/providers/google/tests/unit/google/suite/operators/test_sheets.py +++ b/providers/google/tests/unit/google/suite/operators/test_sheets.py @@ -27,11 +27,9 @@ class TestGoogleSheetsCreateSpreadsheet: @mock.patch("airflow.providers.google.suite.operators.sheets.GSheetsHook") - @mock.patch( - "airflow.providers.google.suite.operators.sheets.GoogleSheetsCreateSpreadsheetOperator.xcom_push" - ) - def test_execute(self, mock_xcom, mock_hook): - context = {} + def test_execute(self, mock_hook): + mock_task_instance = mock.MagicMock() + context = {"task_instance": mock_task_instance} spreadsheet = mock.MagicMock() mock_hook.return_value.create_spreadsheet.return_value = { "spreadsheetId": SPREADSHEET_ID, @@ -44,5 +42,10 @@ def test_execute(self, mock_xcom, mock_hook): mock_hook.return_value.create_spreadsheet.assert_called_once_with(spreadsheet=spreadsheet) + # Verify xcom_push was called with correct arguments + assert mock_task_instance.xcom_push.call_count == 2 + mock_task_instance.xcom_push.assert_any_call(key="spreadsheet_id", value=SPREADSHEET_ID) + mock_task_instance.xcom_push.assert_any_call(key="spreadsheet_url", value=SPREADSHEET_URL) + assert op_execute_result["spreadsheetId"] == "1234567890" assert op_execute_result["spreadsheetUrl"] == "https://example/sheets"