From 3d59b646318ca97150e5a9349f5e05a28ae86662 Mon Sep 17 00:00:00 2001 From: e-halan Date: Fri, 12 Apr 2024 15:30:10 +0000 Subject: [PATCH] Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator --- .../providers/google/cloud/hooks/dataflow.py | 177 ++++++++++++++---- .../google/cloud/operators/dataflow.py | 86 +++++---- .../google/cloud/triggers/dataflow.py | 2 +- .../operators/cloud/dataflow.rst | 20 +- .../google/cloud/hooks/test_dataflow.py | 52 +++++ .../google/cloud/operators/test_dataflow.py | 50 +++-- .../dataflow/example_dataflow_template.py | 42 ++++- 7 files changed, 336 insertions(+), 93 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index a9bf802b14beb..59eee63501864 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -41,9 +41,13 @@ MessagesV1Beta3AsyncClient, MetricsV1Beta3AsyncClient, ) -from google.cloud.dataflow_v1beta3.types import GetJobMetricsRequest, JobMessageImportance, JobMetrics +from google.cloud.dataflow_v1beta3.types import ( + GetJobMetricsRequest, + JobMessageImportance, + JobMetrics, +) from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest -from googleapiclient.discovery import build +from googleapiclient.discovery import Resource, build from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args @@ -573,7 +577,7 @@ def __init__( impersonation_chain=impersonation_chain, ) - def get_conn(self) -> build: + def get_conn(self) -> Resource: """Return a Google Cloud Dataflow service object.""" http_authorized = self._authorize() return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False) @@ -653,9 +657,9 @@ def start_template_dataflow( on_new_job_callback: Callable[[dict], None] | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, environment: dict | None = None, - ) -> dict: + ) -> dict[str, str]: """ - Start Dataflow template job. + Launch a Dataflow job with a Classic Template and wait for its completion. :param job_name: The name of the job. :param variables: Map of job runtime environment options. 
@@ -688,26 +692,14 @@ def start_template_dataflow( environment=environment, ) - service = self.get_conn() - - request = ( - service.projects() - .locations() - .templates() - .launch( - projectId=project_id, - location=location, - gcsPath=dataflow_template, - body={ - "jobName": name, - "parameters": parameters, - "environment": environment, - }, - ) + job: dict[str, str] = self.send_launch_template_request( + project_id=project_id, + location=location, + gcs_path=dataflow_template, + job_name=name, + parameters=parameters, + environment=environment, ) - response = request.execute(num_retries=self.num_retries) - - job = response["job"] if on_new_job_id_callback: warnings.warn( @@ -715,7 +707,7 @@ def start_template_dataflow( AirflowProviderDeprecationWarning, stacklevel=3, ) - on_new_job_id_callback(job.get("id")) + on_new_job_id_callback(job["id"]) if on_new_job_callback: on_new_job_callback(job) @@ -734,7 +726,62 @@ def start_template_dataflow( expected_terminal_state=self.expected_terminal_state, ) jobs_controller.wait_for_done() - return response["job"] + return job + + @_fallback_to_location_from_variables + @_fallback_to_project_id_from_variables + @GoogleBaseHook.fallback_to_default_project_id + def launch_job_with_template( + self, + *, + job_name: str, + variables: dict, + parameters: dict, + dataflow_template: str, + project_id: str, + append_job_name: bool = True, + location: str = DEFAULT_DATAFLOW_LOCATION, + environment: dict | None = None, + ) -> dict[str, str]: + """ + Launch a Dataflow job with a Classic Template and exit without waiting for its completion. + + :param job_name: The name of the job. + :param variables: Map of job runtime environment options. + It will update environment argument if passed. + + .. seealso:: + For more information on possible configurations, look at the API documentation + `https://cloud.google.com/dataflow/pipelines/specifying-exec-params + `__ + + :param parameters: Parameters for the template + :param dataflow_template: GCS path to the template. + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param append_job_name: True if unique suffix has to be appended to job name. + :param location: Job location. + + .. 
seealso:: + For more information on possible configurations, look at the API documentation + `https://cloud.google.com/dataflow/pipelines/specifying-exec-params + `__ + :return: the Dataflow job response + """ + name = self.build_dataflow_job_name(job_name, append_job_name) + environment = self._update_environment( + variables=variables, + environment=environment, + ) + job: dict[str, str] = self.send_launch_template_request( + project_id=project_id, + location=location, + gcs_path=dataflow_template, + job_name=name, + parameters=parameters, + environment=environment, + ) + return job def _update_environment(self, variables: dict, environment: dict | None = None) -> dict: environment = environment or {} @@ -770,6 +817,35 @@ def _check_one(key, val): return environment + def send_launch_template_request( + self, + *, + project_id: str, + location: str, + gcs_path: str, + job_name: str, + parameters: dict, + environment: dict, + ) -> dict[str, str]: + service: Resource = self.get_conn() + request = ( + service.projects() + .locations() + .templates() + .launch( + projectId=project_id, + location=location, + gcsPath=gcs_path, + body={ + "jobName": job_name, + "parameters": parameters, + "environment": environment, + }, + ) + ) + response: dict = request.execute(num_retries=self.num_retries) + return response["job"] + @GoogleBaseHook.fallback_to_default_project_id def start_flex_template( self, @@ -778,9 +854,9 @@ def start_flex_template( project_id: str, on_new_job_id_callback: Callable[[str], None] | None = None, on_new_job_callback: Callable[[dict], None] | None = None, - ) -> dict: + ) -> dict[str, str]: """ - Start flex templates with the Dataflow pipeline. + Launch a Dataflow job with a Flex Template and wait for its completion. :param body: The request body. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body @@ -791,15 +867,16 @@ def start_flex_template( :param on_new_job_callback: A callback that is called when a Job is detected. :return: the Job """ - service = self.get_conn() + service: Resource = self.get_conn() request = ( service.projects() .locations() .flexTemplates() .launch(projectId=project_id, body=body, location=location) ) - response = request.execute(num_retries=self.num_retries) + response: dict = request.execute(num_retries=self.num_retries) job = response["job"] + job_id: str = job["id"] if on_new_job_id_callback: warnings.warn( @@ -807,7 +884,7 @@ def start_flex_template( AirflowProviderDeprecationWarning, stacklevel=3, ) - on_new_job_id_callback(job.get("id")) + on_new_job_id_callback(job_id) if on_new_job_callback: on_new_job_callback(job) @@ -815,7 +892,7 @@ def start_flex_template( jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), project_number=project_id, - job_id=job.get("id"), + job_id=job_id, location=location, poll_sleep=self.poll_sleep, num_retries=self.num_retries, @@ -826,6 +903,42 @@ def start_flex_template( return jobs_controller.get_jobs(refresh=True)[0] + @GoogleBaseHook.fallback_to_default_project_id + def launch_job_with_flex_template( + self, + body: dict, + location: str, + project_id: str, + ) -> dict[str, str]: + """ + Launch a Dataflow Job with a Flex Template and exit without waiting for the job completion. + + :param body: The request body. 
See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body + :param location: The location of the Dataflow job (for example europe-west1) + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :return: a Dataflow job response + """ + service: Resource = self.get_conn() + request = ( + service.projects() + .locations() + .flexTemplates() + .launch(projectId=project_id, body=body, location=location) + ) + response: dict = request.execute(num_retries=self.num_retries) + return response["job"] + + @staticmethod + def extract_job_id(job: dict) -> str: + try: + return job["id"] + except KeyError: + raise AirflowException( + "While reading job object after template execution error occurred. Job object has no id." + ) + @_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @GoogleBaseHook.fallback_to_default_project_id diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 4a6f197e14c8c..424cb8d805c61 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -41,6 +41,7 @@ from airflow.providers.google.cloud.links.dataflow import DataflowJobLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger +from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.version import version @@ -460,7 +461,7 @@ def on_kill(self) -> None: class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator): """ - Start a Templated Cloud Dataflow job; the parameters of the operation will be passed to the job. + Start a Dataflow job with a classic template; the parameters of the operation will be passed to the job. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -643,7 +644,7 @@ def __init__( self.deferrable = deferrable self.expected_terminal_state = expected_terminal_state - self.job: dict | None = None + self.job: dict[str, str] | None = None self._validate_deferrable_params() @@ -681,29 +682,34 @@ def set_current_job(current_job): if not self.location: self.location = DEFAULT_DATAFLOW_LOCATION - self.job = self.hook.start_template_dataflow( + if not self.deferrable: + self.job = self.hook.start_template_dataflow( + job_name=self.job_name, + variables=options, + parameters=self.parameters, + dataflow_template=self.template, + on_new_job_callback=set_current_job, + project_id=self.project_id, + location=self.location, + environment=self.environment, + append_job_name=self.append_job_name, + ) + job_id = self.hook.extract_job_id(self.job) + self.xcom_push(context, key="job_id", value=job_id) + return job_id + + self.job = self.hook.launch_job_with_template( job_name=self.job_name, variables=options, parameters=self.parameters, dataflow_template=self.template, - on_new_job_callback=set_current_job, project_id=self.project_id, + append_job_name=self.append_job_name, location=self.location, environment=self.environment, - append_job_name=self.append_job_name, ) - job_id = self.job.get("id") - - if job_id is None: - raise AirflowException( - "While reading job object after template execution error occurred. 
Job object has no id." - ) - - if not self.deferrable: - return job_id - - context["ti"].xcom_push(key="job_id", value=job_id) - + job_id = self.hook.extract_job_id(self.job) + DataflowJobLink.persist(self, context, self.project_id, self.location, job_id) self.defer( trigger=TemplateJobStartTrigger( project_id=self.project_id, @@ -714,16 +720,17 @@ def set_current_job(current_job): impersonation_chain=self.impersonation_chain, cancel_timeout=self.cancel_timeout, ), - method_name="execute_complete", + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, ) - def execute_complete(self, context: Context, event: dict[str, Any]): + def execute_complete(self, context: Context, event: dict[str, Any]) -> str: """Execute after trigger finishes its work.""" if event["status"] in ("error", "stopped"): self.log.info("status: %s, msg: %s", event["status"], event["message"]) raise AirflowException(event["message"]) job_id = event["job_id"] + self.xcom_push(context, key="job_id", value=job_id) self.log.info("Task %s completed with response %s", self.task_id, event["message"]) return job_id @@ -741,7 +748,7 @@ def on_kill(self) -> None: class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator): """ - Starts flex templates with the Dataflow pipeline. + Starts a Dataflow Job with a Flex Template. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -803,6 +810,9 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator): :param expected_terminal_state: The expected final status of the operator on which the corresponding Airflow task succeeds. When not specified, it will be determined by the hook. :param append_job_name: True if unique suffix has to be appended to job name. + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status while the job is in the + JOB_STATE_RUNNING state. """ template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id") @@ -821,6 +831,7 @@ def __init__( deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), append_job_name: bool = True, expected_terminal_state: str | None = None, + poll_sleep: int = 10, *args, **kwargs, ) -> None: @@ -832,11 +843,12 @@ def __init__( self.drain_pipeline = drain_pipeline self.cancel_timeout = cancel_timeout self.wait_until_finished = wait_until_finished - self.job: dict | None = None + self.job: dict[str, str] | None = None self.impersonation_chain = impersonation_chain self.deferrable = deferrable self.expected_terminal_state = expected_terminal_state self.append_job_name = append_job_name + self.poll_sleep = poll_sleep self._validate_deferrable_params() @@ -871,32 +883,35 @@ def set_current_job(current_job): self.job = current_job DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id")) - self.job = self.hook.start_flex_template( + if not self.deferrable: + self.job = self.hook.start_flex_template( + body=self.body, + location=self.location, + project_id=self.project_id, + on_new_job_callback=set_current_job, + ) + job_id = self.hook.extract_job_id(self.job) + self.xcom_push(context, key="job_id", value=job_id) + return self.job + + self.job = self.hook.launch_job_with_flex_template( body=self.body, location=self.location, project_id=self.project_id, - on_new_job_callback=set_current_job, ) - - job_id = self.job.get("id") - if job_id is None: - raise AirflowException( - "While reading job object after template execution error occurred. 
Job object has no id." - ) - - if not self.deferrable: - return self.job - + job_id = self.hook.extract_job_id(self.job) + DataflowJobLink.persist(self, context, self.project_id, self.location, job_id) self.defer( trigger=TemplateJobStartTrigger( project_id=self.project_id, job_id=job_id, location=self.location, gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, impersonation_chain=self.impersonation_chain, cancel_timeout=self.cancel_timeout, ), - method_name="execute_complete", + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, ) def _append_uuid_to_job_name(self): @@ -907,7 +922,7 @@ def _append_uuid_to_job_name(self): job_body["jobName"] = job_name self.log.info("Job name was changed to %s", job_name) - def execute_complete(self, context: Context, event: dict): + def execute_complete(self, context: Context, event: dict) -> dict[str, str]: """Execute after trigger finishes its work.""" if event["status"] in ("error", "stopped"): self.log.info("status: %s, msg: %s", event["status"], event["message"]) @@ -915,6 +930,7 @@ def execute_complete(self, context: Context, event: dict): job_id = event["job_id"] self.log.info("Task %s completed with response %s", job_id, event["message"]) + self.xcom_push(context, key="job_id", value=job_id) job = self.hook.get_job(job_id=job_id, project_id=self.project_id, location=self.location) return job diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py index 32f68a9fd7034..577c0bbf60059 100644 --- a/airflow/providers/google/cloud/triggers/dataflow.py +++ b/airflow/providers/google/cloud/triggers/dataflow.py @@ -138,7 +138,7 @@ async def run(self): return else: self.log.info("Job is still running...") - self.log.info("Current job status is: %s", status) + self.log.info("Current job status is: %s", status.name) self.log.info("Sleeping for %s seconds.", self.poll_sleep) await asyncio.sleep(self.poll_sleep) except Exception as e: diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst index d3f1bd6df4b6b..f9302af8c3cde 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst @@ -208,7 +208,7 @@ from the staging and execution steps. There are two types of templates for Dataf See the `official documentation for Dataflow templates `_ for more information. -Here is an example of running Classic template with +Here is an example of running a Dataflow job using a Classic Template with :class:`~airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`: .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py @@ -217,10 +217,18 @@ Here is an example of running Classic template with :start-after: [START howto_operator_start_template_job] :end-before: [END howto_operator_start_template_job] +Also for this action you can use the operator in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_start_template_job_deferrable] + :end-before: [END howto_operator_start_template_job_deferrable] + See the `list of Google-provided templates that can be used with this operator `_. 
-Here is an example of running Flex template with +Here is an example of running a Dataflow job using a Flex Template with :class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator`: .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py @@ -229,6 +237,14 @@ Here is an example of running Flex template with :start-after: [START howto_operator_start_flex_template_job] :end-before: [END howto_operator_start_flex_template_job] +Also for this action you can use the operator in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_start_flex_template_job_deferrable] + :end-before: [END howto_operator_start_flex_template_job_deferrable] + .. _howto/operator:DataflowStartSqlJobOperator: Dataflow SQL diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index 2458b48e8143b..1c8f768ea3aa2 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -1052,6 +1052,34 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl ) mock_uuid.assert_called_once_with() + @mock.patch(DATAFLOW_STRING.format("uuid.uuid4"), return_value=MOCK_UUID) + @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn")) + def test_launch_job_with_template(self, mock_conn, mock_uuid): + launch_method = ( + mock_conn.return_value.projects.return_value.locations.return_value.templates.return_value.launch + ) + launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}} + variables = {"zone": "us-central1-f", "tempLocation": "gs://test/temp"} + result = self.dataflow_hook.launch_job_with_template( + job_name=JOB_NAME, + variables=copy.deepcopy(variables), + parameters=PARAMETERS, + dataflow_template=TEST_TEMPLATE, + project_id=TEST_PROJECT, + ) + + launch_method.assert_called_once_with( + body={ + "jobName": f"test-dataflow-pipeline-{MOCK_UUID_PREFIX}", + "parameters": PARAMETERS, + "environment": variables, + }, + gcsPath="gs://dataflow-templates/wordcount/template_file", + projectId=TEST_PROJECT, + location=DEFAULT_DATAFLOW_LOCATION, + ) + assert result == {"id": TEST_JOB_ID} + @mock.patch(DATAFLOW_STRING.format("_DataflowJobsController")) @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn")) def test_start_flex_template(self, mock_conn, mock_controller): @@ -1088,6 +1116,26 @@ def test_start_flex_template(self, mock_conn, mock_controller): mock_controller.return_value.get_jobs.assert_called_once_with(refresh=True) assert result == {"id": TEST_JOB_ID} + @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn")) + def test_launch_job_with_flex_template(self, mock_conn): + expected_job = {"id": TEST_JOB_ID} + + mock_locations = mock_conn.return_value.projects.return_value.locations + launch_method = mock_locations.return_value.flexTemplates.return_value.launch + launch_method.return_value.execute.return_value = {"job": expected_job} + + result = self.dataflow_hook.launch_job_with_flex_template( + body={"launchParameter": TEST_FLEX_PARAMETERS}, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + ) + launch_method.assert_called_once_with( + projectId="test-project-id", + body={"launchParameter": TEST_FLEX_PARAMETERS}, + location=TEST_LOCATION, + ) + assert result == {"id": TEST_JOB_ID} + 
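
    # A hedged sketch, not part of the original patch: a possible happy-path companion to
    # test_extract_job_id_raises_exception further below, assuming it lives in the same
    # test class and reuses the module-level TEST_JOB_ID constant already used above.
    def test_extract_job_id(self):
        job = {"id": TEST_JOB_ID, "currentState": "JOB_STATE_DONE"}
        assert self.dataflow_hook.extract_job_id(job) == TEST_JOB_ID
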
@mock.patch(DATAFLOW_STRING.format("_DataflowJobsController")) @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn")) def test_cancel_job(self, mock_get_conn, jobs_controller): @@ -1177,6 +1225,10 @@ def test_start_sql_job(self, mock_run, mock_provide_authorized_gcloud, mock_get_ on_new_job_callback=mock.MagicMock(), ) + def test_extract_job_id_raises_exception(self): + with pytest.raises(AirflowException): + self.dataflow_hook.extract_job_id({"not_id": True}) + class TestDataflowJob: def setup_method(self): diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index 495287b9af731..ebbf471383760 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -102,6 +102,7 @@ GCP_CONN_ID = "test_gcp_conn_id" IMPERSONATION_CHAIN = ["impersonate", "this"] CANCEL_TIMEOUT = 10 * 420 +DATAFLOW_PATH = "airflow.providers.google.cloud.operators.dataflow" class TestDataflowCreatePythonJobOperator: @@ -488,11 +489,12 @@ def deferrable_operator(self): cancel_timeout=CANCEL_TIMEOUT, ) - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") - def test_exec(self, dataflow_mock, sync_operator): - start_template_hook = dataflow_mock.return_value.start_template_dataflow + @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute(self, hook_mock, mock_xcom_push, sync_operator): + start_template_hook = hook_mock.return_value.start_template_dataflow sync_operator.execute(None) - assert dataflow_mock.called + assert hook_mock.called expected_options = { "project": "test", "stagingLocation": "gs://test/staging", @@ -512,10 +514,27 @@ def test_exec(self, dataflow_mock, sync_operator): append_job_name=True, ) - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.defer") - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.hook") + @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.defer") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator): deferrable_operator.execute(mock.MagicMock()) + expected_variables = { + "project": "test", + "stagingLocation": "gs://test/staging", + "tempLocation": "gs://test/temp", + "zone": "us-central1-f", + "EXTRA_OPTION": "TEST_A", + } + mock_hook.return_value.launch_job_with_template.assert_called_once_with( + job_name=JOB_NAME, + variables=expected_variables, + parameters=PARAMETERS, + dataflow_template=TEMPLATE, + project_id=TEST_PROJECT, + append_job_name=True, + location=TEST_LOCATION, + environment={"maxWorkers": 2}, + ) mock_defer_method.assert_called_once() def test_validation_deferrable_params_raises_error(self): @@ -540,8 +559,9 @@ def test_validation_deferrable_params_raises_error(self): DataflowTemplatedJobStartOperator(**init_kwargs) @pytest.mark.db_test - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow") - def test_start_with_custom_region(self, dataflow_mock): + @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow") + def test_start_with_custom_region(self, dataflow_mock, mock_xcom_push): init_kwargs = { "task_id": TASK_ID, "template": TEMPLATE, @@ -560,8 +580,9 @@ def 
test_start_with_custom_region(self, dataflow_mock): assert kwargs["location"] == DEFAULT_DATAFLOW_LOCATION @pytest.mark.db_test - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow") - def test_start_with_location(self, dataflow_mock): + @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow") + def test_start_with_location(self, dataflow_mock, mock_xcom_push): init_kwargs = { "task_id": TASK_ID, "template": TEMPLATE, @@ -601,7 +622,7 @@ def deferrable_operator(self): deferrable=True, ) - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") def test_execute(self, mock_dataflow, sync_operator): sync_operator.execute(mock.MagicMock()) mock_dataflow.assert_called_once_with( @@ -640,16 +661,15 @@ def test_validation_deferrable_params_raises_error(self): with pytest.raises(ValueError): DataflowStartFlexTemplateOperator(**init_kwargs) - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator.defer") - @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") + @mock.patch(f"{DATAFLOW_PATH}.DataflowStartFlexTemplateOperator.defer") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator): deferrable_operator.execute(mock.MagicMock()) - mock_hook.return_value.start_flex_template.assert_called_once_with( + mock_hook.return_value.launch_job_with_flex_template.assert_called_once_with( body={"launchParameter": TEST_FLEX_PARAMETERS}, location=TEST_LOCATION, project_id=TEST_PROJECT, - on_new_job_callback=mock.ANY, ) mock_defer_method.assert_called_once() diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py index b6eec97a16e55..2a3e747eb7b5c 100644 --- a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py @@ -17,7 +17,8 @@ # under the License. """ -Example Airflow DAG for testing Google Dataflow +Example Airflow DAG for testing Google Dataflow. + :class:`~airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator` operator. 
""" @@ -27,6 +28,7 @@ from datetime import datetime from pathlib import Path +from airflow.models.baseoperator import chain from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.dataflow import ( DataflowStartFlexTemplateOperator, @@ -104,6 +106,7 @@ template="gs://dataflow-templates/latest/Word_Count", parameters={"inputFile": f"gs://{BUCKET_NAME}/{CSV_FILE_NAME}", "output": GCS_OUTPUT}, location=LOCATION, + wait_until_finished=True, ) # [END howto_operator_start_template_job] @@ -114,20 +117,43 @@ body=BODY, location=LOCATION, append_job_name=False, + wait_until_finished=True, ) # [END howto_operator_start_flex_template_job] + # [START howto_operator_start_template_job_deferrable] + start_template_job_deferrable = DataflowTemplatedJobStartOperator( + task_id="start_template_job_deferrable", + project_id=PROJECT_ID, + template="gs://dataflow-templates/latest/Word_Count", + parameters={"inputFile": f"gs://{BUCKET_NAME}/{CSV_FILE_NAME}", "output": GCS_OUTPUT}, + location=LOCATION, + deferrable=True, + ) + # [END howto_operator_start_template_job_deferrable] + + # [START howto_operator_start_flex_template_job_deferrable] + start_flex_template_job_deferrable = DataflowStartFlexTemplateOperator( + task_id="start_flex_template_job_deferrable", + project_id=PROJECT_ID, + body=BODY, + location=LOCATION, + append_job_name=False, + deferrable=True, + ) + # [END howto_operator_start_flex_template_job_deferrable] + delete_bucket = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE ) - ( - create_bucket - >> upload_file - >> upload_schema - >> start_template_job - >> start_flex_template_job - >> delete_bucket + chain( + create_bucket, + upload_file, + upload_schema, + [start_template_job, start_flex_template_job], + [start_template_job_deferrable, start_flex_template_job_deferrable], + delete_bucket, ) from tests.system.utils.watcher import watcher