diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index 9fa61f906f9fe..279b9dd21a862 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -17,6 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
+import re
 import select
 import subprocess
 import time
@@ -166,7 +167,7 @@ def __init__(self,
 
     def get_conn(self):
         """
-        Returns a Google Cloud Storage service object.
+        Returns a Google Cloud Dataflow service object.
         """
         http_authorized = self._authorize()
         return build(
@@ -191,10 +192,7 @@ def _set_variables(variables):
 
     def start_java_dataflow(self, task_id, variables, dataflow, job_class=None,
                             append_job_name=True):
-        if append_job_name:
-            name = task_id + "-" + str(uuid.uuid1())[:8]
-        else:
-            name = task_id
+        name = self._build_dataflow_job_name(task_id, append_job_name)
         variables['jobName'] = name
 
         def label_formatter(labels_dict):
@@ -207,19 +205,13 @@ def label_formatter(labels_dict):
 
     def start_template_dataflow(self, task_id, variables, parameters, dataflow_template,
                                 append_job_name=True):
-        if append_job_name:
-            name = task_id + "-" + str(uuid.uuid1())[:8]
-        else:
-            name = task_id
+        name = self._build_dataflow_job_name(task_id, append_job_name)
         self._start_template_dataflow(
             name, variables, parameters, dataflow_template)
 
     def start_python_dataflow(self, task_id, variables, dataflow, py_options,
                               append_job_name=True):
-        if append_job_name:
-            name = task_id + "-" + str(uuid.uuid1())[:8]
-        else:
-            name = task_id
+        name = self._build_dataflow_job_name(task_id, append_job_name)
         variables['job_name'] = name
 
         def label_formatter(labels_dict):
@@ -229,6 +221,36 @@ def label_formatter(labels_dict):
                              ["python"] + py_options + [dataflow],
                              label_formatter)
 
+    @staticmethod
+    def _build_dataflow_job_name(task_id, append_job_name=True):
+        """
+        Builds a Dataflow job name from an Airflow task_id.
+
+        Underscores in ``task_id`` are replaced with hyphens and the result
+        must match ``^[a-z]([-a-z0-9]*[a-z0-9])?$``.
+
+        :param task_id: the task_id the job name is derived from
+        :param append_job_name: if True, a hyphen and the first 8 characters
+            of ``uuid.uuid1()`` are appended to the name
+        :return: the job name
+        :raises ValueError: if the name does not match the pattern above
+        """
+        base_job_name = str(task_id).replace('_', '-')
+
+        # Raise explicitly rather than ``assert``: asserts are stripped under
+        # ``python -O``, which would let invalid names through unchecked.
+        if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name):
+            raise ValueError(
+                'Invalid job_name ({}); the name must consist of '
+                'only the characters [-a-z0-9], starting with a '
+                'letter and ending with a letter or number '.format(
+                    base_job_name))
+
+        if append_job_name:
+            return base_job_name + "-" + str(uuid.uuid1())[:8]
+
+        return base_job_name
+
     def _build_cmd(self, task_id, variables, label_formatter):
         command = ["--runner=DataflowRunner"]
         if variables is not None:
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index f16dcdfcbc1e1..90714c6ee4f62 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -172,6 +172,82 @@ def poll_resp_error():
         self.assertRaises(Exception, dataflow.wait_for_done)
         mock_logging.warning.assert_has_calls([call('test'), call('error')])
 
+    def test_valid_dataflow_job_name(self):
+        job_name = self.dataflow_hook._build_dataflow_job_name(
+            task_id=TASK_ID, append_job_name=False
+        )
+
+        self.assertEqual(job_name, TASK_ID)
+
+    def test_fix_underscore_in_task_id(self):
+        task_id_with_underscore = 'test_example'
+        fixed_job_name = task_id_with_underscore.replace(
+            '_', '-'
+        )
+        job_name = self.dataflow_hook._build_dataflow_job_name(
+            task_id=task_id_with_underscore, append_job_name=False
+        )
+
+        self.assertEqual(job_name, fixed_job_name)
+
+    def test_append_job_name_adds_unique_suffix(self):
+        job_name = self.dataflow_hook._build_dataflow_job_name(
+            task_id='df-job', append_job_name=True
+        )
+
+        # The suffix is a hyphen plus the first 8 characters of a uuid1.
+        self.assertTrue(job_name.startswith('df-job-'))
+        self.assertEqual(len(job_name), len('df-job-') + 8)
+
+    def test_invalid_dataflow_job_name(self):
+        invalid_job_name = '9test_invalid_name'
+        fixed_name = invalid_job_name.replace(
+            '_', '-')
+
+        with self.assertRaises(ValueError) as e:
+            self.dataflow_hook._build_dataflow_job_name(
+                task_id=invalid_job_name, append_job_name=False
+            )
+        # Test whether the job_name is present in the Error msg
+        self.assertIn('Invalid job_name ({})'.format(fixed_name),
+                      str(e.exception))
+
+    def test_dataflow_job_regex_check(self):
+
+        self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
+            task_id='df-job-1', append_job_name=False
+        ), 'df-job-1')
+
+        self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
+            task_id='df-job', append_job_name=False
+        ), 'df-job')
+
+        self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
+            task_id='dfjob', append_job_name=False
+        ), 'dfjob')
+
+        self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
+            task_id='dfjob1', append_job_name=False
+        ), 'dfjob1')
+
+        self.assertRaises(
+            ValueError,
+            self.dataflow_hook._build_dataflow_job_name,
+            task_id='1dfjob', append_job_name=False
+        )
+
+        self.assertRaises(
+            ValueError,
+            self.dataflow_hook._build_dataflow_job_name,
+            task_id='dfjob@', append_job_name=False
+        )
+
+        self.assertRaises(
+            ValueError,
+            self.dataflow_hook._build_dataflow_job_name,
+            task_id='df^jo', append_job_name=False
+        )
+
 
 class DataFlowTemplateHookTest(unittest.TestCase):
 