diff --git a/airflow/gcp/hooks/dataproc.py b/airflow/gcp/hooks/dataproc.py
index cf846e1545af7..75e4706243f6d 100644
--- a/airflow/gcp/hooks/dataproc.py
+++ b/airflow/gcp/hooks/dataproc.py
@@ -565,6 +565,143 @@ def cancel(self, project_id: str, job_id: str, region: str = 'global') -> Dict:
             jobId=job_id
         )
 
+    def get_final_cluster_state(self, project_id, region, cluster_name, logger):
+        """
+        Poll a cluster until it reports a state, then return that state.
+
+        :param project_id: the Google Cloud project the cluster belongs to
+        :param region: the Dataproc region of the cluster
+        :param cluster_name: the name of the cluster to poll
+        :param logger: the logger to which polling progress is written
+        :return: the first state reported by the cluster
+        """
+        while True:
+            state = DataProcHook.get_cluster_state(self.get_conn(), project_id, region, cluster_name)
+            if state is None:
+                logger.info("No state for cluster '%s'", cluster_name)
+                time.sleep(15)
+            else:
+                logger.info("State for cluster '%s' is %s", cluster_name, state)
+                return state
+
+    @staticmethod
+    def get_cluster_state(service, project_id, region, cluster_name):
+        """
+        Get the state of a cluster if it has one, otherwise None.
+
+        :param service: the Dataproc API service object
+        :param project_id: the Google Cloud project the cluster belongs to
+        :param region: the Dataproc region of the cluster
+        :param cluster_name: the name of the cluster
+        :return: the cluster's state, or None if the cluster does not exist
+            or has no status yet
+        """
+        cluster = DataProcHook.find_cluster(service, project_id, region, cluster_name)
+        if cluster and 'status' in cluster:
+            return cluster['status']['state']
+        return None
+
+    @staticmethod
+    def find_cluster(service, project_id, region, cluster_name):
+        """
+        Retrieve a cluster from the project/region if it exists, otherwise None.
+
+        :param service: the Dataproc API service object
+        :param project_id: the Google Cloud project to search
+        :param region: the Dataproc region to search
+        :param cluster_name: the name of the cluster to look for
+        :return: the matching cluster resource, or None if there is no match
+        """
+        cluster_list = DataProcHook.get_cluster_list_for_project(service, project_id, region)
+        cluster = [c for c in cluster_list if c['clusterName'] == cluster_name]
+        if cluster:
+            return cluster[0]
+        return None
+
+    @staticmethod
+    def get_cluster_list_for_project(service, project_id, region):
+        """
+        List all clusters in a given project/region, or an empty list if none exist.
+
+        :param service: the Dataproc API service object
+        :param project_id: the Google Cloud project to list clusters for
+        :param region: the Dataproc region to list clusters for
+        :return: a (possibly empty) list of cluster resources
+        """
+        result = service.projects().regions().clusters().list(
+            projectId=project_id,
+            region=region
+        ).execute()
+        return result.get('clusters', [])
+
+    @staticmethod
+    def execute_dataproc_diagnose(service, project_id, region, cluster_name):
+        """
+        Execute the diagnose command against a given cluster. This is useful for
+        gathering debugging information when something has gone wrong or cluster
+        creation has failed.
+
+        :param service: the Dataproc API service object
+        :param project_id: the Google Cloud project the cluster belongs to
+        :param region: the Dataproc region of the cluster
+        :param cluster_name: the name of the cluster to diagnose
+        :return: the identifier of the diagnose operation being executed
+        """
+        response = service.projects().regions().clusters().diagnose(
+            projectId=project_id,
+            region=region,
+            clusterName=cluster_name,
+            body={}
+        ).execute()
+        operation_name = response['name']
+        return operation_name
+
+    @staticmethod
+    def execute_delete(service, project_id, region, cluster_name):
+        """
+        Delete the specified cluster.
+
+        :param service: the Dataproc API service object
+        :param project_id: the Google Cloud project the cluster belongs to
+        :param region: the Dataproc region of the cluster
+        :param cluster_name: the name of the cluster to delete
+        :return: the identifier of the delete operation being executed
+        """
+        response = service.projects().regions().clusters().delete(
+            projectId=project_id,
+            region=region,
+            clusterName=cluster_name
+        ).execute(num_retries=5)
+        operation_name = response['name']
+        return operation_name
+
+    @staticmethod
+    def wait_for_operation_done(service, operation_name):
+        """
+        Poll for the completion of a specific GCP operation.
+
+        :param service: the Dataproc API service object
+        :param operation_name: the identifier of the operation to wait for
+        :return: the response of the completed operation
+        """
+        while True:
+            response = service.projects().regions().operations().get(
+                name=operation_name
+            ).execute(num_retries=5)
+
+            if response.get('done'):
+                return response
+            time.sleep(15)
+
+    @staticmethod
+    def wait_for_operation_done_or_error(service, operation_name):
+        """
+        Block until the specified operation is done. Raises an AirflowException
+        if the operation completed with an error.
+
+        :param service: the Dataproc API service object
+        :param operation_name: the identifier of the operation to wait for
+        :return: None
+        """
+        response = DataProcHook.wait_for_operation_done(service, operation_name)
+        if 'error' in response:
+            raise AirflowException(str(response['error']))
+
 
 setattr(
     DataProcHook,
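Taken together, these helpers form a small toolkit for inspecting and repairing a cluster before acting on it. Below is a minimal sketch of how they compose; the connection id, project, region, and cluster names are assumptions for illustration, not values from this patch.

# Illustrative sketch only -- the connection id, project, region, and
# cluster name below are assumptions, not values taken from this patch.
import logging

from airflow.gcp.hooks.dataproc import DataProcHook

hook = DataProcHook(gcp_conn_id='google_cloud_default')
service = hook.get_conn()

state = DataProcHook.get_cluster_state(service, 'my-project', 'us-central1', 'my-cluster')
if state == 'ERROR':
    # Collect diagnostics first, then delete the broken cluster and block
    # until the delete operation finishes (or surfaces an error).
    diagnose_op = DataProcHook.execute_dataproc_diagnose(
        service, 'my-project', 'us-central1', 'my-cluster')
    result = DataProcHook.wait_for_operation_done(service, diagnose_op)
    logging.info("Diagnose output: %s", result.get('response', {}).get('outputUri'))
    delete_op = DataProcHook.execute_delete(
        service, 'my-project', 'us-central1', 'my-cluster')
    DataProcHook.wait_for_operation_done_or_error(service, delete_op)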
diff --git a/airflow/gcp/operators/dataproc.py b/airflow/gcp/operators/dataproc.py
index bcbcfe48e0393..4910ccaa33abf 100644
--- a/airflow/gcp/operators/dataproc.py
+++ b/airflow/gcp/operators/dataproc.py
@@ -276,8 +276,36 @@ def __init__(self,
                 )
             ), "num_workers == 0 means single node mode - no preemptibles allowed"
 
+    def _cluster_ready(self, state, service):
+        if state == 'RUNNING':
+            return True
+        if state == 'DELETING':
+            raise Exception("Tried to create a cluster, but it is in DELETING; something went wrong.")
+        if state == 'ERROR':
+            cluster = DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name)
+            try:
+                error_details = cluster['status']['details']
+            except KeyError:
+                error_details = 'Unknown error in cluster creation, ' \
+                                'check Google Cloud console for details.'
+
+            self.log.info('Dataproc cluster creation resulted in an ERROR state; running diagnostics')
+            self.log.info(error_details)
+            diagnose_operation_name = \
+                DataProcHook.execute_dataproc_diagnose(service, self.project_id,
+                                                       self.region, self.cluster_name)
+            diagnose_result = DataProcHook.wait_for_operation_done(service, diagnose_operation_name)
+            if diagnose_result.get('response') and diagnose_result.get('response').get('outputUri'):
+                output_uri = diagnose_result.get('response').get('outputUri')
+                self.log.info('Diagnostic information for ERROR cluster available at [%s]', output_uri)
+            else:
+                self.log.info('Diagnostic information could not be retrieved!')
+
+            raise Exception(error_details)
+        return False
+
     def _get_init_action_timeout(self):
-        match = re.match(r"^(\d+)(s|m)$", self.init_action_timeout)
+        match = re.match(r"^(\d+)([sm])$", self.init_action_timeout)
         if match:
             if match.group(2) == "s":
                 return self.init_action_timeout
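The regex change in this hunk (mirrored in _get_graceful_decommission_timeout further down) swaps an alternation for a character class; the accepted language is identical, as this quick check illustrates:

import re

# Both patterns match the same strings; the character class simply avoids
# an unnecessary alternation group.
for pattern in (r"^(\d+)(s|m)$", r"^(\d+)([sm])$"):
    match = re.match(pattern, "90s")
    assert match and match.group(1) == "90" and match.group(2) == "s"
    assert re.match(pattern, "90h") is None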
@@ -445,11 +473,51 @@ def _build_cluster_data(self):
 
         return cluster_data
 
+    def _usable_existing_cluster_present(self, service):
+        existing_cluster = DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name)
+        if existing_cluster:
+            self.log.info(
+                'Cluster %s already exists... Checking status...',
+                self.cluster_name
+            )
+            existing_status = self.hook.get_final_cluster_state(self.project_id,
+                                                                self.region, self.cluster_name, self.log)
+
+            if existing_status == 'RUNNING':
+                self.log.info('Cluster exists and is already running. Using it.')
+                return True
+
+            elif existing_status == 'DELETING':
+                while DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name) \
+                        and DataProcHook.get_cluster_state(service, self.project_id,
+                                                           self.region, self.cluster_name) == 'DELETING':
+                    self.log.info('Existing cluster is deleting, waiting for it to finish')
+                    time.sleep(15)
+
+            elif existing_status == 'ERROR':
+                self.log.info('Existing cluster in ERROR state, deleting it first')
+
+                operation_name = DataProcHook.execute_delete(service, self.project_id,
+                                                             self.region, self.cluster_name)
+                self.log.info("Cluster delete operation name: %s", operation_name)
+                DataProcHook.wait_for_operation_done_or_error(service, operation_name)
+
+        return False
+
     def start(self):
         """
         Create a new cluster on Google Cloud Dataproc.
         """
         self.log.info('Creating cluster: %s', self.cluster_name)
+        hook = DataProcHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to
+        )
+        service = hook.get_conn()
+
+        if self._usable_existing_cluster_present(service):
+            return True
+
         cluster_data = self._build_cluster_data()
 
         return (
@@ -542,7 +610,7 @@ def _build_scale_cluster_data(self):
 
     @staticmethod
     def _get_graceful_decommission_timeout(timeout):
-        match = re.match(r"^(\d+)(s|m|h|d)$", timeout)
+        match = re.match(r"^(\d+)([smdh])$", timeout)
         if match:
             if match.group(2) == "s":
                 return timeout
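The net effect of start() consulting _usable_existing_cluster_present first is that cluster creation becomes idempotent for a healthy cluster. A sketch of the resulting behavior from a DAG author's perspective; the DAG and resource names here are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.dataproc import DataprocClusterCreateOperator

with DAG('dataproc_reuse_example',
         start_date=datetime(2019, 1, 1),
         schedule_interval=None) as dag:
    create_cluster = DataprocClusterCreateOperator(
        task_id='create_cluster',
        project_id='my-project',
        region='us-central1',
        cluster_name='my-cluster',
        num_workers=2,
        zone='us-central1-a',
    )

# If 'my-cluster' is already RUNNING, the task reuses it; if it is DELETING,
# the task waits for the deletion to finish before creating a new cluster;
# if it is in ERROR, the task deletes it first and then creates a fresh one.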
""" self.log.info('Creating cluster: %s', self.cluster_name) + hook = DataProcHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + service = hook.get_conn() + + if self._usable_existing_cluster_present(service): + return True + cluster_data = self._build_cluster_data() return ( @@ -542,7 +610,7 @@ def _build_scale_cluster_data(self): @staticmethod def _get_graceful_decommission_timeout(timeout): - match = re.match(r"^(\d+)(s|m|h|d)$", timeout) + match = re.match(r"^(\d+)([smdh])$", timeout) if match: if match.group(2) == "s": return timeout diff --git a/tests/gcp/operators/test_dataproc.py b/tests/gcp/operators/test_dataproc.py index 8a7f92bb868cd..cdddbc77ffd29 100644 --- a/tests/gcp/operators/test_dataproc.py +++ b/tests/gcp/operators/test_dataproc.py @@ -399,8 +399,12 @@ def test_create_cluster(self): self.operation = {'name': 'operation', 'done': True} self.mock_execute = Mock() self.mock_execute.execute.return_value = self.operation + self.mock_list = Mock() + self.mock_list_execute = {} + self.mock_list.execute.return_value = self.mock_list_execute self.mock_clusters = Mock() self.mock_clusters.create.return_value = self.mock_execute + self.mock_clusters.list.return_value = self.mock_list self.mock_regions = Mock() self.mock_regions.clusters.return_value = self.mock_clusters self.mock_projects = Mock() @@ -523,6 +527,45 @@ def test_create_cluster_with_multiple_masters(self): 'labels': {'airflow-version': mock.ANY}}) hook.wait.assert_called_once_with(self.operation) + def test_create_cluster_deletes_error_cluster(self): + # Setup service.projects().regions().clusters().create() + # .execute() + # pylint:disable=attribute-defined-outside-init + self.operation = {'name': 'operation', 'done': True} + self.mock_execute = Mock() + self.mock_execute.execute.return_value = self.operation + self.mock_list = Mock() + self.mock_list_execute = {'clusters': [{'clusterName': CLUSTER_NAME, 'status': {'state': 'ERROR'}}]} + self.mock_list.execute.return_value = self.mock_list_execute + self.mock_clusters = Mock() + self.mock_clusters.create.return_value = self.mock_execute + self.mock_clusters.list.return_value = self.mock_list + self.mock_regions = Mock() + self.mock_regions.clusters.return_value = self.mock_clusters + self.mock_projects = Mock() + self.mock_projects.regions.return_value = self.mock_regions + self.mock_conn = Mock() + self.mock_conn.projects.return_value = self.mock_projects + + with patch(HOOK) as mock_hook: + hook = mock_hook() + hook.get_conn.return_value = self.mock_conn + hook.wait.return_value = None + hook.get_final_cluster_state.return_value = "ERROR" + + dataproc_task = DataprocClusterCreateOperator( + task_id=TASK_ID, + region=GCP_REGION, + cluster_name=CLUSTER_NAME, + project_id=GCP_PROJECT_ID, + num_workers=NUM_WORKERS, + zone=GCE_ZONE, + dag=self.dag + ) + with patch.object(dataproc_task.log, 'info') as mock_info: + dataproc_task.execute(None) + mock_info.assert_any_call('Existing cluster in ERROR state, deleting it first') + def test_build_cluster_data_internal_ip_only_without_subnetwork(self): def create_cluster_with_invalid_internal_ip_only_setup(): @@ -699,7 +742,7 @@ def test_delete_cluster(self): projectId=GCP_PROJECT_ID, clusterName=CLUSTER_NAME, requestId=mock.ANY) - hook.wait.assert_called_once_with(self.operation) + hook.wait.assert_called_with(self.operation) def test_render_template(self): task = DataprocClusterDeleteOperator( @@ -792,7 +835,7 @@ def setUp(self): schedule_interval='@daily') @mock.patch( - 
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
         new_callable=PropertyMock,
         return_value=GCP_PROJECT_ID
     )
@@ -877,7 +920,7 @@ def setUp(self):
             schedule_interval='@daily')
 
     @mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
         new_callable=PropertyMock,
         return_value=GCP_PROJECT_ID
     )
@@ -962,7 +1005,7 @@ def setUp(self):
             schedule_interval='@daily')
 
     @mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
         new_callable=PropertyMock,
         return_value=GCP_PROJECT_ID
     )
@@ -1052,7 +1095,7 @@ def setUp(self):
             schedule_interval='@daily')
 
     @mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
         new_callable=PropertyMock,
         return_value=GCP_PROJECT_ID
     )
@@ -1140,7 +1183,7 @@ def setUp(self):
             schedule_interval='@daily')
 
     @mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
         new_callable=PropertyMock,
         return_value=GCP_PROJECT_ID
     )
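A note on the Mock scaffolding used throughout these tests: each return_value hop stands in for one fluent call on the googleapiclient service object, so a plain dict can be served at the end of the chain. A stripped-down, self-contained illustration (not taken from the test file):

from unittest.mock import Mock

# Each return_value hop mirrors one call in
# service.projects().regions().clusters().list().execute()
mock_list = Mock()
mock_list.execute.return_value = {'clusters': []}
mock_clusters = Mock()
mock_clusters.list.return_value = mock_list
mock_regions = Mock()
mock_regions.clusters.return_value = mock_clusters
mock_projects = Mock()
mock_projects.regions.return_value = mock_regions
service = Mock()
service.projects.return_value = mock_projects

assert service.projects().regions().clusters().list().execute() == {'clusters': []}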