diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py b/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py new file mode 100644 index 0000000000000..b7d3d708e6141 --- /dev/null +++ b/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Cloud Dataflow service +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.dataflow import DataflowStartSqlJobOperator +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +BQ_SQL_DATASET = os.environ.get("DATAFLOW_BQ_SQL_DATASET", "airflow_dataflow_samples") +BQ_SQL_TABLE_INPUT = os.environ.get("BQ_SQL_TABLE_INPUT", "beam_input") +BQ_SQL_TABLE_OUTPUT = os.environ.get("BQ_SQL_TABLE_OUTPUT", "beam_output") +DATAFLOW_SQL_JOB_NAME = os.environ.get("DATAFLOW_SQL_JOB_NAME", "dataflow-sql") +DATAFLOW_SQL_LOCATION = os.environ.get("DATAFLOW_SQL_LOCATION", "us-west1") + +with models.DAG( + dag_id="example_gcp_dataflow_sql", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=['example'], +) as dag_sql: + start_sql = DataflowStartSqlJobOperator( + task_id="start_sql_query", + job_name=DATAFLOW_SQL_JOB_NAME, + query=f""" + SELECT + sales_region as sales_region, + count(state_id) as count_state + FROM + bigquery.table.`{GCP_PROJECT_ID}`.`{BQ_SQL_DATASET}`.`{BQ_SQL_TABLE_INPUT}` + WHERE state_id >= @state_id_min + GROUP BY sales_region; + """, + options={ + "bigquery-project": GCP_PROJECT_ID, + "bigquery-dataset": BQ_SQL_DATASET, + "bigquery-table": BQ_SQL_TABLE_OUTPUT, + "bigquery-write-disposition": "write-truncate", + "parameter": "state_id_min:INT64:2", + }, + location=DATAFLOW_SQL_LOCATION, + do_xcom_push=True, + ) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index cf9f559598e89..e17001ff2532b 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -28,7 +28,7 @@ import warnings from copy import deepcopy from tempfile import TemporaryDirectory -from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, cast from googleapiclient.discovery import build @@ -39,11 +39,11 @@ # This is the default location # https://cloud.google.com/dataflow/pipelines/specifying-exec-params -DEFAULT_DATAFLOW_LOCATION = 'us-central1' +DEFAULT_DATAFLOW_LOCATION = "us-central1" JOB_ID_PATTERN = re.compile( - r'Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]' + r"Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]" )
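For reference, a quick standalone check of the `JOB_ID_PATTERN` regex above. The two log lines are made-up examples of what the Java and Python runners print, not output captured from a real job:

```python
import re

# Same pattern as in the hook: the Java runner logs "Submitted job: <id>",
# the Python runner logs "Created job with id: [<id>]".
JOB_ID_PATTERN = re.compile(
    r"Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]"
)

sample_lines = [
    "INFO: Submitted job: 2020-11-20_01_02_03-1234567890",              # illustrative Java runner line
    "INFO:root:Created job with id: [2020-11-20_04_05_06-9876543210]",  # illustrative Python runner line
]
for line in sample_lines:
    matched = JOB_ID_PATTERN.search(line)
    if matched:
        print(matched.group("job_id_java") or matched.group("job_id_python"))
```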
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name @@ -66,7 +66,7 @@ def inner_wrapper(self: "DataflowHook", *args, **kwargs): ) parameter_location = kwargs.get(parameter_name) - variables_location = kwargs.get('variables', {}).get(variable_key_name) + variables_location = kwargs.get("variables", {}).get(variable_key_name) if parameter_location and variables_location: raise AirflowException( @@ -76,9 +76,9 @@ def inner_wrapper(self: "DataflowHook", *args, **kwargs): if parameter_location or variables_location: kwargs[parameter_name] = parameter_location or variables_location if variables_location: - copy_variables = deepcopy(kwargs['variables']) + copy_variables = deepcopy(kwargs["variables"]) del copy_variables[variable_key_name] - kwargs['variables'] = copy_variables + kwargs["variables"] = copy_variables return func(self, *args, **kwargs) @@ -87,8 +87,8 @@ def inner_wrapper(self: "DataflowHook", *args, **kwargs): return _wrapper -_fallback_to_location_from_variables = _fallback_variable_parameter('location', 'region') -_fallback_to_project_id_from_variables = _fallback_variable_parameter('project_id', 'project') +_fallback_to_location_from_variables = _fallback_variable_parameter("location", "region") +_fallback_to_project_id_from_variables = _fallback_variable_parameter("project_id", "project") class DataflowJobStatus: @@ -186,7 +186,7 @@ def is_job_running(self) -> bool: return False for job in self._jobs: - if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES: + if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES: return True return False @@ -203,17 +203,21 @@ def _get_current_jobs(self) -> List[dict]: elif self._job_name: jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower()) if len(jobs) == 1: - self._job_id = jobs[0]['id'] + self._job_id = jobs[0]["id"] return jobs else: - raise Exception('Missing both dataflow job ID and name.') + raise Exception("Missing both dataflow job ID and name.") def _fetch_job_by_id(self, job_id: str) -> dict: return ( self._dataflow.projects() .locations() .jobs() - .get(projectId=self._project_number, location=self._job_location, jobId=job_id) + .get( + projectId=self._project_number, + location=self._job_location, + jobId=job_id, + ) .execute(num_retries=self._num_retries) ) @@ -239,7 +243,7 @@ def _fetch_all_jobs(self) -> List[dict]: def _fetch_jobs_by_prefix_name(self, prefix_name: str) -> List[dict]: jobs = self._fetch_all_jobs() - jobs = [job for job in jobs if job['name'].startswith(prefix_name)] + jobs = [job for job in jobs if job["name"].startswith(prefix_name)] return jobs def _refresh_jobs(self) -> None: @@ -253,9 +257,13 @@ def _refresh_jobs(self) -> None: if self._jobs: for job in self._jobs: - self.log.info('Google Cloud DataFlow job %s is state: %s', job['name'], job['currentState']) + self.log.info( + "Google Cloud DataFlow job %s is state: %s", + job["name"], + job["currentState"], + ) else: - self.log.info('Google Cloud DataFlow job not available yet..') + self.log.info("Google Cloud DataFlow job not available yet..") def _check_dataflow_job_state(self, job) -> bool: """ @@ -266,22 +274,22 @@ def _check_dataflow_job_state(self, job) -> bool: :rtype: bool :raise: Exception """ - if DataflowJobStatus.JOB_STATE_DONE == job['currentState']: + if DataflowJobStatus.JOB_STATE_DONE == job["currentState"]: return True - elif DataflowJobStatus.JOB_STATE_FAILED == job['currentState']: - raise Exception("Google Cloud Dataflow job {} has failed.".format(job['name'])) - elif 
DataflowJobStatus.JOB_STATE_CANCELLED == job['currentState']: - raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job['name'])) + elif DataflowJobStatus.JOB_STATE_FAILED == job["currentState"]: + raise Exception("Google Cloud Dataflow job {} has failed.".format(job["name"])) + elif DataflowJobStatus.JOB_STATE_CANCELLED == job["currentState"]: + raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job["name"])) elif ( - DataflowJobStatus.JOB_STATE_RUNNING == job['currentState'] - and DataflowJobType.JOB_TYPE_STREAMING == job['type'] + DataflowJobStatus.JOB_STATE_RUNNING == job["currentState"] + and DataflowJobType.JOB_TYPE_STREAMING == job["type"] ): return True - elif job['currentState'] in DataflowJobStatus.AWAITING_STATES: + elif job["currentState"] in DataflowJobStatus.AWAITING_STATES: return False self.log.debug("Current job: %s", str(job)) raise Exception( - "Google Cloud Dataflow job {} was unknown state: {}".format(job['name'], job['currentState']) + "Google Cloud Dataflow job {} was unknown state: {}".format(job["name"], job["currentState"]) ) def wait_for_done(self) -> None: @@ -293,10 +301,12 @@ def wait_for_done(self) -> None: time.sleep(self._poll_sleep) self._refresh_jobs() - def get_jobs(self, refresh=False) -> List[dict]: + def get_jobs(self, refresh: bool = False) -> List[dict]: """ Returns Dataflow jobs. + :param refresh: Forces the latest data to be fetched. + :type refresh: bool :return: list of jobs :rtype: list """ @@ -310,14 +320,14 @@ def get_jobs(self, refresh=False) -> List[dict]: def cancel(self) -> None: """Cancels or drains current job""" jobs = self.get_jobs() - job_ids = [job['id'] for job in jobs if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES] + job_ids = [job["id"] for job in jobs if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES] if job_ids: batch = self._dataflow.new_batch_http_request() self.log.info("Canceling jobs: %s", ", ".join(job_ids)) for job in jobs: requested_state = ( DataflowJobStatus.JOB_STATE_DRAINED - if self.drain_pipeline and job['type'] == DataflowJobType.JOB_TYPE_STREAMING + if self.drain_pipeline and job["type"] == DataflowJobType.JOB_TYPE_STREAMING else DataflowJobStatus.JOB_STATE_CANCELLED ) batch.add( @@ -327,7 +337,7 @@ def cancel(self) -> None: .update( projectId=self._project_number, location=self._job_location, - jobId=job['id'], + jobId=job["id"], body={"requestedState": requested_state}, ) ) @@ -338,14 +348,20 @@ def cancel(self) -> None: class _DataflowRunner(LoggingMixin): def __init__( - self, cmd: List[str], on_new_job_id_callback: Optional[Callable[[str], None]] = None + self, + cmd: List[str], + on_new_job_id_callback: Optional[Callable[[str], None]] = None, ) -> None: super().__init__() - self.log.info("Running command: %s", ' '.join(shlex.quote(c) for c in cmd)) + self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd)) self.on_new_job_id_callback = on_new_job_id_callback self.job_id: Optional[str] = None self._proc = subprocess.Popen( - cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True + cmd, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, ) def _process_fd(self, fd): @@ -382,7 +398,7 @@ def _process_line_and_extract_job_id(self, line: str) -> None: # Job id info: https://goo.gl/SE29y9. 
matched_job = JOB_ID_PATTERN.search(line) if matched_job: - job_id = matched_job.group('job_id_java') or matched_job.group('job_id_python') + job_id = matched_job.group("job_id_java") or matched_job.group("job_id_python") self.log.info("Found Job ID: %s", job_id) self.job_id = job_id if self.on_new_job_id_callback: @@ -449,7 +465,7 @@ def __init__( def get_conn(self) -> build: """Returns a Google Cloud Dataflow service object.""" http_authorized = self._authorize() - return build('dataflow', 'v1b3', http=http_authorized, cache_discovery=False) + return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False) @GoogleBaseHook.provide_gcp_credential_file def _start_dataflow( @@ -457,13 +473,17 @@ def _start_dataflow( variables: dict, name: str, command_prefix: List[str], - label_formatter: Callable[[dict], List[str]], project_id: str, multiple_jobs: bool = False, on_new_job_id_callback: Optional[Callable[[str], None]] = None, location: str = DEFAULT_DATAFLOW_LOCATION, ) -> None: - cmd = command_prefix + self._build_cmd(variables, label_formatter, project_id) + cmd = command_prefix + [ + "--runner=DataflowRunner", + f"--project={project_id}", + ] + if variables: + cmd.extend(self._options_to_args(variables)) runner = _DataflowRunner(cmd=cmd, on_new_job_id_callback=on_new_job_id_callback) job_id = runner.wait_for_done() job_controller = _DataflowJobsController( @@ -517,18 +537,17 @@ def start_java_dataflow( :type location: str """ name = self._build_dataflow_job_name(job_name, append_job_name) - variables['jobName'] = name - variables['region'] = location + variables["jobName"] = name + variables["region"] = location - def label_formatter(labels_dict): - return ['--labels={}'.format(json.dumps(labels_dict).replace(' ', ''))] + if "labels" in variables: + variables["labels"] = json.dumps(variables["labels"], separators=(",", ":")) command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar] self._start_dataflow( variables=variables, name=name, command_prefix=command_prefix, - label_formatter=label_formatter, project_id=project_id, multiple_jobs=multiple_jobs, on_new_job_id_callback=on_new_job_id_callback, @@ -591,21 +610,21 @@ def start_template_dataflow( # available keys for runtime environment are listed here: # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment environment_keys = [ - 'numWorkers', - 'maxWorkers', - 'zone', - 'serviceAccountEmail', - 'tempLocation', - 'bypassTempDirValidation', - 'machineType', - 'additionalExperiments', - 'network', - 'subnetwork', - 'additionalUserLabels', - 'kmsKeyName', - 'ipConfiguration', - 'workerRegion', - 'workerZone', + "numWorkers", + "maxWorkers", + "zone", + "serviceAccountEmail", + "tempLocation", + "bypassTempDirValidation", + "machineType", + "additionalExperiments", + "network", + "subnetwork", + "additionalUserLabels", + "kmsKeyName", + "ipConfiguration", + "workerRegion", + "workerZone", ] for key in variables: @@ -628,12 +647,16 @@ def start_template_dataflow( projectId=project_id, location=location, gcsPath=dataflow_template, - body={"jobName": name, "parameters": parameters, "environment": environment}, + body={ + "jobName": name, + "parameters": parameters, + "environment": environment, + }, ) ) response = request.execute(num_retries=self.num_retries) - job_id = response['job']['id'] + job_id = response["job"]["id"] if on_new_job_id_callback: on_new_job_id_callback(job_id) @@ -679,7 +702,7 @@ def start_flex_template( .launch(projectId=project_id, body=body, 
location=location) ) response = request.execute(num_retries=self.num_retries) - job_id = response['job']['id'] + job_id = response["job"]["id"] if on_new_job_id_callback: on_new_job_id_callback(job_id) @@ -753,11 +776,11 @@ def start_python_dataflow( # pylint: disable=too-many-arguments :type location: str """ name = self._build_dataflow_job_name(job_name, append_job_name) - variables['job_name'] = name - variables['region'] = location + variables["job_name"] = name + variables["region"] = location - def label_formatter(labels_dict): - return [f'--labels={key}={value}' for key, value in labels_dict.items()] + if "labels" in variables: + variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()] if py_requirements is not None: if not py_requirements and not py_system_site_packages: @@ -773,7 +796,7 @@ def label_formatter(labels_dict): ) raise AirflowException(warning_invalid_environment) - with TemporaryDirectory(prefix='dataflow-venv') as tmp_dir: + with TemporaryDirectory(prefix="dataflow-venv") as tmp_dir: py_interpreter = prepare_virtualenv( venv_directory=tmp_dir, python_bin=py_interpreter, @@ -786,7 +809,6 @@ def label_formatter(labels_dict): variables=variables, name=name, command_prefix=command_prefix, - label_formatter=label_formatter, project_id=project_id, on_new_job_id_callback=on_new_job_id_callback, location=location, @@ -798,7 +820,6 @@ def label_formatter(labels_dict): variables=variables, name=name, command_prefix=command_prefix, - label_formatter=label_formatter, project_id=project_id, on_new_job_id_callback=on_new_job_id_callback, location=location, @@ -806,13 +827,13 @@ def label_formatter(labels_dict): @staticmethod def _build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str: - base_job_name = str(job_name).replace('_', '-') + base_job_name = str(job_name).replace("_", "-") if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name): raise ValueError( - 'Invalid job_name ({}); the name must consist of' - 'only the characters [-a-z0-9], starting with a ' - 'letter and ending with a letter or number '.format(base_job_name) + "Invalid job_name ({}); the name must consist of" + "only the characters [-a-z0-9], starting with a " + "letter and ending with a letter or number ".format(base_job_name) ) if append_job_name: @@ -823,29 +844,21 @@ def _build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str return safe_job_name @staticmethod - def _build_cmd(variables: dict, label_formatter: Callable, project_id: str) -> List[str]: - command = [ - "--runner=DataflowRunner", - f"--project={project_id}", - ] - if variables is None: - return command - + def _options_to_args(variables: dict) -> List[str]: + if not variables: + return [] # The logic of this method should be compatible with Apache Beam: # https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/ # apache_beam/options/pipeline_options.py#L230-L251 + args: List[str] = [] for attr, value in variables.items(): - if attr == 'labels': - command += label_formatter(value) - elif value is None: - command.append(f"--{attr}") - elif isinstance(value, bool) and value: - command.append(f"--{attr}") + if value is None or (isinstance(value, bool) and value): + args.append(f"--{attr}") elif isinstance(value, list): - command.extend([f"--{attr}={v}" for v in value]) + args.extend([f"--{attr}={v}" for v in value]) else: - command.append(f"--{attr}={value}") - return command + args.append(f"--{attr}={value}") + return args 
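As a rough illustration of what the new `_options_to_args` helper produces, the same mapping can be exercised standalone; the option values below are invented for the example:

```python
# Standalone copy of the _options_to_args logic above, fed with invented option values.
from typing import List


def options_to_args(variables: dict) -> List[str]:
    args: List[str] = []
    for attr, value in variables.items():
        if value is None or (isinstance(value, bool) and value):
            args.append(f"--{attr}")
        elif isinstance(value, list):
            args.extend(f"--{attr}={v}" for v in value)
        else:
            args.append(f"--{attr}={value}")
    return args


print(
    options_to_args(
        {
            "project": "example-project",            # plain value -> --project=example-project
            "streaming": True,                       # truthy bool -> bare --streaming
            "labels": ["airflow-version=v2-0-0"],    # list -> one flag per element
            "temp_location": "gs://example-bucket/tmp",
            "setup_file": None,                      # None -> bare --setup_file
        }
    )
)
# ['--project=example-project', '--streaming', '--labels=airflow-version=v2-0-0',
#  '--temp_location=gs://example-bucket/tmp', '--setup_file']
```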
@_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @@ -884,6 +897,7 @@ def is_job_dataflow_running( location=location, poll_sleep=self.poll_sleep, drain_pipeline=self.drain_pipeline, + num_retries=self.num_retries, ) return jobs_controller.is_job_running() @@ -918,5 +932,78 @@ def cancel_job( location=location, poll_sleep=self.poll_sleep, drain_pipeline=self.drain_pipeline, + num_retries=self.num_retries, ) jobs_controller.cancel() + + @GoogleBaseHook.fallback_to_default_project_id + def start_sql_job( + self, + job_name: str, + query: str, + options: Dict[str, Any], + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + ): + """ + Starts Dataflow SQL query. + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :type job_name: str + :param query: The SQL query to execute. + :type query: str + :param options: Job parameters to be executed. + For more information, look at: + `gcloud beta dataflow sql query + <https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query>`__ + command reference + :param location: The location of the Dataflow job (for example europe-west1) + :type location: str + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param on_new_job_id_callback: Callback called when the job ID is known. + :type on_new_job_id_callback: callable + :return: the new job object + """ + cmd = [ + "gcloud", + "dataflow", + "sql", + "query", + query, + f"--project={project_id}", + "--format=value(job.id)", + f"--job-name={job_name}", + f"--region={location}", + *(self._options_to_args(options)), + ] + self.log.info("Executing command: %s", " ".join([shlex.quote(c) for c in cmd])) + with self.provide_authorized_gcloud(): + proc = subprocess.run( # pylint: disable=subprocess-run-check + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + self.log.info("Output: %s", proc.stdout.decode()) + self.log.warning("Stderr: %s", proc.stderr.decode()) + self.log.info("Exit code %d", proc.returncode) + if proc.returncode != 0: + raise AirflowException(f"Process exit with non-zero exit code. 
Exit code: {proc.returncode}") + job_id = proc.stdout.decode().strip() + + self.log.info("Created job ID: %s", job_id) + if on_new_job_id_callback: + on_new_job_id_callback(job_id) + + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + num_retries=self.num_retries, + drain_pipeline=self.drain_pipeline, + ) + jobs_controller.wait_for_done() + + return jobs_controller.get_jobs(refresh=True)[0] diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 3edc5df88cf0e..d2844a360b192 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -173,8 +173,8 @@ class DataflowCreateJavaJobOperator(BaseOperator): """ - template_fields = ['options', 'jar', 'job_name'] - ui_color = '#0273d4' + template_fields = ["options", "jar", "job_name"] + ui_color = "#0273d4" # pylint: disable=too-many-arguments @apply_defaults @@ -182,12 +182,12 @@ def __init__( self, *, jar: str, - job_name: str = '{{task.task_id}}', + job_name: str = "{{task.task_id}}", dataflow_default_options: Optional[dict] = None, options: Optional[dict] = None, project_id: Optional[str] = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, poll_sleep: int = 10, job_class: Optional[str] = None, @@ -199,8 +199,8 @@ def __init__( dataflow_default_options = dataflow_default_options or {} options = options or {} - options.setdefault('labels', {}).update( - {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} + options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} ) self.project_id = project_id self.location = location @@ -243,7 +243,7 @@ def execute(self, context): if not is_running: with ExitStack() as exit_stack: - if self.jar.lower().startswith('gs://'): + if self.jar.lower().startswith("gs://"): gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to) tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member gcs_hook.provide_file(object_url=self.jar) @@ -373,30 +373,30 @@ class DataflowTemplatedJobStartOperator(BaseOperator): """ template_fields = [ - 'template', - 'job_name', - 'options', - 'parameters', - 'project_id', - 'location', - 'gcp_conn_id', - 'impersonation_chain', - 'environment', + "template", + "job_name", + "options", + "parameters", + "project_id", + "location", + "gcp_conn_id", + "impersonation_chain", + "environment", ] - ui_color = '#0273d4' + ui_color = "#0273d4" @apply_defaults def __init__( # pylint: disable=too-many-arguments self, *, template: str, - job_name: str = '{{task.task_id}}', + job_name: str = "{{task.task_id}}", options: Optional[Dict[str, Any]] = None, dataflow_default_options: Optional[Dict[str, Any]] = None, parameters: Optional[Dict[str, str]] = None, project_id: Optional[str] = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, poll_sleep: int = 10, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, @@ -475,7 +475,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator): :type drain_pipeline: bool """ - template_fields = ["body", 'location', 'project_id', 'gcp_conn_id'] + template_fields = ["body", 
"location", "project_id", "gcp_conn_id"] @apply_defaults def __init__( @@ -483,7 +483,7 @@ def __init__( body: Dict, location: str, project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, drain_pipeline: bool = False, *args, @@ -501,7 +501,9 @@ def __init__( def execute(self, context): self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, drain_pipeline=self.drain_pipeline + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + drain_pipeline=self.drain_pipeline, ) def set_current_job_id(job_id): @@ -522,6 +524,102 @@ def on_kill(self) -> None: self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) +class DataflowStartSqlJobOperator(BaseOperator): + """ + Starts Dataflow SQL query. + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :type job_name: str + :param query: The SQL query to execute. + :type query: str + :param options: Job parameters to be executed. It can be a dictionary with the following keys. + + For more information, look at: + `https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query + `__ + command reference + + :param options: dict + :param location: The location of the Dataflow job (for example europe-west1) + :type location: str + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud + Platform. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it + instead of canceling during during killing task instance. 
See: + https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :type drain_pipeline: bool + """ + + template_fields = [ + "job_name", + "query", + "options", + "location", + "project_id", + "gcp_conn_id", + ] + + @apply_defaults + def __init__( + self, + job_name: str, + query: str, + options: Dict[str, Any], + location: str = DEFAULT_DATAFLOW_LOCATION, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + drain_pipeline: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.job_name = job_name + self.query = query + self.options = options + self.location = location + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.drain_pipeline = drain_pipeline + self.job_id = None + self.hook: Optional[DataflowHook] = None + + def execute(self, context): + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + drain_pipeline=self.drain_pipeline, + ) + + def set_current_job_id(job_id): + self.job_id = job_id + + job = self.hook.start_sql_job( + job_name=self.job_name, + query=self.query, + options=self.options, + location=self.location, + project_id=self.project_id, + on_new_job_id_callback=set_current_job_id, + ) + + return job + + def on_kill(self) -> None: + self.log.info("On kill.") + if self.job_id: + self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) + + # pylint: disable=too-many-instance-attributes class DataflowCreatePythonJobOperator(BaseOperator): """ @@ -596,14 +694,14 @@ class DataflowCreatePythonJobOperator(BaseOperator): :type drain_pipeline: bool """ - template_fields = ['options', 'dataflow_default_options', 'job_name', 'py_file'] + template_fields = ["options", "dataflow_default_options", "job_name", "py_file"] @apply_defaults def __init__( # pylint: disable=too-many-arguments self, *, py_file: str, - job_name: str = '{{task.task_id}}', + job_name: str = "{{task.task_id}}", dataflow_default_options: Optional[dict] = None, options: Optional[dict] = None, py_interpreter: str = "python3", @@ -612,7 +710,7 @@ def __init__( # pylint: disable=too-many-arguments py_system_site_packages: bool = False, project_id: Optional[str] = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, poll_sleep: int = 10, drain_pipeline: bool = False, @@ -626,8 +724,8 @@ def __init__( # pylint: disable=too-many-arguments self.py_options = py_options or [] self.dataflow_default_options = dataflow_default_options or {} self.options = options or {} - self.options.setdefault('labels', {}).update( - {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} + self.options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} ) self.py_interpreter = py_interpreter self.py_requirements = py_requirements @@ -644,7 +742,7 @@ def __init__( # pylint: disable=too-many-arguments def execute(self, context): """Execute the python dataflow job.""" with ExitStack() as exit_stack: - if self.py_file.lower().startswith('gs://'): + if self.py_file.lower().startswith("gs://"): gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to) tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member gcs_hook.provide_file(object_url=self.py_file) @@ -660,7 +758,7 @@ def execute(self, context): dataflow_options = 
self.dataflow_default_options.copy() dataflow_options.update(self.options) # Convert argument names from lowerCamelCase to snake case. - camel_to_snake = lambda name: re.sub(r'[A-Z]', lambda x: '_' + x.group(0).lower(), name) + camel_to_snake = lambda name: re.sub(r"[A-Z]", lambda x: "_" + x.group(0).lower(), name) formatted_options = {camel_to_snake(key): dataflow_options[key] for key in dataflow_options} def set_current_job_id(job_id): diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index be486eed511df..6dbe1f37fa4c0 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -19,6 +19,7 @@ import copy import shlex +import subprocess import unittest from typing import Any, Dict from unittest import mock @@ -90,6 +91,23 @@ }, } TEST_PROJECT_ID = 'test-project-id' +TEST_SQL_JOB_NAME = 'test-sql-job-name' +TEST_DATASET = 'test-dataset' +TEST_SQL_OPTIONS = { + "bigquery-project": TEST_PROJECT, + "bigquery-dataset": TEST_DATASET, + "bigquery-table": "beam_output", + 'bigquery-write-disposition': "write-truncate", +} +TEST_SQL_QUERY = """ +SELECT + sales_region as sales_region, + count(state_id) as count_state +FROM + bigquery.table.test-project.beam_samples.beam_table +GROUP BY sales_region; +""" +TEST_SQL_JOB_ID = 'test-job-id' class TestFallbackToVariables(unittest.TestCase): @@ -873,10 +891,80 @@ def test_cancel_job(self, mock_get_conn, jobs_controller): name=UNIQUE_JOB_NAME, poll_sleep=10, project_number=TEST_PROJECT, + num_retries=5, drain_pipeline=False, ) jobs_controller.cancel() + @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.provide_authorized_gcloud')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + @mock.patch(DATAFLOW_STRING.format('subprocess.run')) + def test_start_sql_job( + self, mock_run, mock_get_conn, mock_provide_authorized_gcloud, mock_controller + ): + test_job = {'id': "TEST_JOB_ID"} + mock_controller.return_value.get_jobs.return_value = [test_job] + mock_run.return_value = mock.MagicMock( + stdout=f"{TEST_JOB_ID}\n".encode(), stderr=f"{TEST_JOB_ID}\n".encode(), returncode=0 + ) + on_new_job_id_callback = mock.MagicMock() + result = self.dataflow_hook.start_sql_job( + job_name=TEST_SQL_JOB_NAME, + query=TEST_SQL_QUERY, + options=TEST_SQL_OPTIONS, + location=TEST_LOCATION, + project_id=TEST_PROJECT, + on_new_job_id_callback=on_new_job_id_callback, + ) + mock_run.assert_called_once_with( + [ + 'gcloud', + 'dataflow', + 'sql', + 'query', + TEST_SQL_QUERY, + '--project=test-project', + '--format=value(job.id)', + '--job-name=test-sql-job-name', + '--region=custom-location', + '--bigquery-project=test-project', + '--bigquery-dataset=test-dataset', + '--bigquery-table=beam_output', + '--bigquery-write-disposition=write-truncate', + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + mock_controller.assert_called_once_with( + dataflow=mock_get_conn.return_value, + job_id=TEST_JOB_ID, + location=TEST_LOCATION, + poll_sleep=10, + project_number=TEST_PROJECT, + num_retries=5, + drain_pipeline=False, + ) + mock_controller.return_value.wait_for_done.assert_called_once() + self.assertEqual(result, test_job) + + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.provide_authorized_gcloud')) + @mock.patch(DATAFLOW_STRING.format('subprocess.run')) + def test_start_sql_job_failed_to_run(self, 
mock_run, mock_provide_authorized_gcloud, mock_get_conn): + mock_run.return_value = mock.MagicMock( + stdout=f"{TEST_JOB_ID}\n".encode(), stderr=f"{TEST_JOB_ID}\n".encode(), returncode=1 + ) + with self.assertRaises(AirflowException): + self.dataflow_hook.start_sql_job( + job_name=TEST_SQL_JOB_NAME, + query=TEST_SQL_QUERY, + options=TEST_SQL_OPTIONS, + location=TEST_LOCATION, + project_id=TEST_PROJECT, + on_new_job_id_callback=mock.MagicMock(), + ) + class TestDataflowJob(unittest.TestCase): def setUp(self): diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index 02d95e52179ae..9cb7490990e0c 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -18,6 +18,7 @@ # import unittest +from copy import deepcopy from unittest import mock from airflow.providers.google.cloud.operators.dataflow import ( @@ -25,6 +26,7 @@ DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator, DataflowStartFlexTemplateOperator, + DataflowStartSqlJobOperator, DataflowTemplatedJobStartOperator, ) from airflow.version import version @@ -69,7 +71,24 @@ }, } TEST_LOCATION = 'custom-location' -TEST_PROJECT_ID = 'test-project-id' +TEST_PROJECT = "test-project" +TEST_SQL_JOB_NAME = 'test-sql-job-name' +TEST_DATASET = 'test-dataset' +TEST_SQL_OPTIONS = { + "bigquery-project": TEST_PROJECT, + "bigquery-dataset": TEST_DATASET, + "bigquery-table": "beam_output", + 'bigquery-write-disposition': "write-truncate", +} +TEST_SQL_QUERY = """ +SELECT + sales_region as sales_region, + count(state_id) as count_state +FROM + bigquery.table.test-project.beam_samples.beam_table +GROUP BY sales_region; +""" +TEST_SQL_JOB_ID = 'test-job-id' class TestDataflowPythonOperator(unittest.TestCase): @@ -309,14 +328,14 @@ def test_execute(self, mock_dataflow): task_id="start_flex_template_streaming_beam_sql", body={"launchParameter": TEST_FLEX_PARAMETERS}, do_xcom_push=True, - project_id=TEST_PROJECT_ID, + project_id=TEST_PROJECT, location=TEST_LOCATION, ) start_flex_template.execute(mock.MagicMock()) mock_dataflow.return_value.start_flex_template.assert_called_once_with( body={"launchParameter": TEST_FLEX_PARAMETERS}, location=TEST_LOCATION, - project_id=TEST_PROJECT_ID, + project_id=TEST_PROJECT, on_new_job_id_callback=mock.ANY, ) @@ -326,11 +345,40 @@ def test_on_kill(self): body={"launchParameter": TEST_FLEX_PARAMETERS}, do_xcom_push=True, location=TEST_LOCATION, - project_id=TEST_PROJECT_ID, + project_id=TEST_PROJECT, ) start_flex_template.hook = mock.MagicMock() start_flex_template.job_id = JOB_ID start_flex_template.on_kill() start_flex_template.hook.cancel_job.assert_called_once_with( - job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT_ID + job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT + ) + + +class TestDataflowSqlOperator(unittest.TestCase): + @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') + def test_execute(self, mock_hook): + start_sql = DataflowStartSqlJobOperator( + task_id="start_sql_query", + job_name=TEST_SQL_JOB_NAME, + query=TEST_SQL_QUERY, + options=deepcopy(TEST_SQL_OPTIONS), + location=TEST_LOCATION, + do_xcom_push=True, + ) + + start_sql.execute(mock.MagicMock()) + mock_hook.assert_called_once_with( + gcp_conn_id='google_cloud_default', delegate_to=None, drain_pipeline=False + ) + mock_hook.return_value.start_sql_job.assert_called_once_with( + job_name=TEST_SQL_JOB_NAME, + query=TEST_SQL_QUERY, + 
options=TEST_SQL_OPTIONS, + location=TEST_LOCATION, + project_id=None, + on_new_job_id_callback=mock.ANY, ) + start_sql.job_id = TEST_SQL_JOB_ID + start_sql.on_kill() + mock_hook.return_value.cancel_job.assert_called_once_with(job_id='test-job-id', project_id=None) diff --git a/tests/providers/google/cloud/operators/test_dataflow_system.py b/tests/providers/google/cloud/operators/test_dataflow_system.py index 28abed6244b5b..4c9a3160c3753 100644 --- a/tests/providers/google/cloud/operators/test_dataflow_system.py +++ b/tests/providers/google/cloud/operators/test_dataflow_system.py @@ -33,6 +33,11 @@ PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION, PUBSUB_FLEX_TEMPLATE_TOPIC, ) +from airflow.providers.google.cloud.example_dags.example_dataflow_sql import ( + BQ_SQL_DATASET, + DATAFLOW_SQL_JOB_NAME, + DATAFLOW_SQL_LOCATION, +) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAFLOW_KEY, GCP_GCS_TRANSFER_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -241,3 +246,117 @@ def tearDown(self) -> None: # Delete the Cloud Storage bucket self.execute_cmd(["gsutil", "rm", "-r", f"gs://{GCS_FLEX_TEMPLATE_BUCKET_NAME}"]) + + +@pytest.mark.backend("mysql", "postgres") +@pytest.mark.credential_file(GCP_GCS_TRANSFER_KEY) +class CloudDataflowExampleDagSqlSystemTest(GoogleSystemTest): + @provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id()) + def setUp(self) -> None: + # Build image with pipeline + with NamedTemporaryFile(suffix=".csv") as f: + f.write( + textwrap.dedent( + """\ + state_id,state_code,state_name,sales_region + 1,MO,Missouri,Region_1 + 2,SC,South Carolina,Region_1 + 3,IN,Indiana,Region_1 + 6,DE,Delaware,Region_2 + 15,VT,Vermont,Region_2 + 16,DC,District of Columbia,Region_2 + 19,CT,Connecticut,Region_2 + 20,ME,Maine,Region_2 + 35,PA,Pennsylvania,Region_2 + 38,NJ,New Jersey,Region_2 + 47,MA,Massachusetts,Region_2 + 54,RI,Rhode Island,Region_2 + 55,NY,New York,Region_2 + 60,MD,Maryland,Region_2 + 66,NH,New Hampshire,Region_2 + 4,CA,California,Region_3 + 8,AK,Alaska,Region_3 + 37,WA,Washington,Region_3 + 61,OR,Oregon,Region_3 + 33,HI,Hawaii,Region_4 + 59,AS,American Samoa,Region_4 + 65,GU,Guam,Region_4 + 5,IA,Iowa,Region_5 + 32,NV,Nevada,Region_5 + 11,PR,Puerto Rico,Region_6 + 17,CO,Colorado,Region_6 + 18,MS,Mississippi,Region_6 + 41,AL,Alabama,Region_6 + 42,AR,Arkansas,Region_6 + 43,FL,Florida,Region_6 + 44,NM,New Mexico,Region_6 + 46,GA,Georgia,Region_6 + 48,KS,Kansas,Region_6 + 52,AZ,Arizona,Region_6 + 56,TN,Tennessee,Region_6 + 58,TX,Texas,Region_6 + 63,LA,Louisiana,Region_6 + 7,ID,Idaho,Region_7 + 12,IL,Illinois,Region_7 + 13,ND,North Dakota,Region_7 + 31,MN,Minnesota,Region_7 + 34,MT,Montana,Region_7 + 36,SD,South Dakota,Region_7 + 50,MI,Michigan,Region_7 + 51,UT,Utah,Region_7 + 64,WY,Wyoming,Region_7 + 9,NE,Nebraska,Region_8 + 10,VA,Virginia,Region_8 + 14,OK,Oklahoma,Region_8 + 39,NC,North Carolina,Region_8 + 40,WV,West Virginia,Region_8 + 45,KY,Kentucky,Region_8 + 53,WI,Wisconsin,Region_8 + 57,OH,Ohio,Region_8 + 49,VI,United States Virgin Islands,Region_9 + 62,MP,Commonwealth of the Northern Mariana Islands,Region_9 + """ + ).encode() + ) + f.flush() + + self.execute_cmd(["bq", "mk", "--dataset", f'{self._project_id()}:{BQ_SQL_DATASET}']) + + self.execute_cmd( + ["bq", "load", "--autodetect", "--source_format=CSV", f"{BQ_SQL_DATASET}.beam_input", f.name] + ) + + @provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id()) + def test_run_example_dag_function(self): 
+ self.run_dag("example_gcp_dataflow_sql", CLOUD_DAG_FOLDER) + + @provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id()) + def tearDown(self) -> None: + # Execute test query + self.execute_cmd( + [ + 'bq', + 'query', + '--use_legacy_sql=false', + f'select * FROM `{self._project_id()}.{BQ_SQL_DATASET}.beam_output`', + ] + ) + + # Stop the Dataflow pipelines. + self.execute_cmd( + [ + "bash", + "-c", + textwrap.dedent( + f"""\ + gcloud dataflow jobs list \ + --region={DATAFLOW_SQL_LOCATION} \ + --filter 'NAME:{DATAFLOW_SQL_JOB_NAME} AND STATE=Running' \ + --format 'value(JOB_ID)' \ + | xargs -r gcloud dataflow jobs cancel --region={DATAFLOW_SQL_LOCATION} + """ + ), + ] + ) + # Delete the BigQuery dataset, + self.execute_cmd(["bq", "rm", "-r", "-f", "-d", f'{self._project_id()}:{BQ_SQL_DATASET}'])
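For context, a small sketch of the `gcloud` invocation that `DataflowHook.start_sql_job` ends up building for the example DAG above; the project, dataset, and table names are the example defaults, not real resources:

```python
import shlex

# Placeholder values mirroring the example DAG defaults.
project_id = "example-project"
location = "us-west1"
job_name = "dataflow-sql"
query = (
    "SELECT sales_region as sales_region, count(state_id) as count_state "
    "FROM bigquery.table.`example-project`.`airflow_dataflow_samples`.`beam_input` "
    "WHERE state_id >= @state_id_min GROUP BY sales_region;"
)
options = {
    "bigquery-project": project_id,
    "bigquery-dataset": "airflow_dataflow_samples",
    "bigquery-table": "beam_output",
    "bigquery-write-disposition": "write-truncate",
    "parameter": "state_id_min:INT64:2",
}

# Same argument layout the hook passes to subprocess.run before polling the job.
cmd = [
    "gcloud", "dataflow", "sql", "query", query,
    f"--project={project_id}",
    "--format=value(job.id)",
    f"--job-name={job_name}",
    f"--region={location}",
    *[f"--{key}={value}" for key, value in options.items()],
]
print(" ".join(shlex.quote(c) for c in cmd))
```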