137 changes: 137 additions & 0 deletions airflow/gcp/hooks/dataproc.py
@@ -565,6 +565,143 @@ def cancel(self, project_id: str, job_id: str, region: str = 'global') -> Dict:
jobId=job_id
)

def get_final_cluster_state(self, project_id, region, cluster_name, logger):
"""
Poll for the state of a cluster until one is available

:param project_id:
:param region:
:param cluster_name:
:param logger:
:return:
"""
while True:
state = DataProcHook.get_cluster_state(self.get_conn(), project_id, region, cluster_name)
if state is None:
logger.info("No state for cluster '%s'", cluster_name)
time.sleep(15)
else:
logger.info("State for cluster '%s' is %s", cluster_name, state)
return state

@staticmethod
def get_cluster_state(service, project_id, region, cluster_name):
"""
Get the state of a cluster if it has one, otherwise None
:param service:
:param project_id:
:param region:
:param cluster_name:
:return:
"""
cluster = DataProcHook.find_cluster(service, project_id, region, cluster_name)
if cluster and 'status' in cluster:
return cluster['status']['state']
else:
return None

@staticmethod
def find_cluster(service, project_id, region, cluster_name):
"""
Retrieve a cluster from the project/region if it exists, otherwise None
:param service:
:param project_id:
:param region:
:param cluster_name:
:return:
"""
cluster_list = DataProcHook.get_cluster_list_for_project(service, project_id, region)
cluster = [c for c in cluster_list if c['clusterName'] == cluster_name]
if cluster:
return cluster[0]
return None

@staticmethod
def get_cluster_list_for_project(service, project_id, region):
"""
List all clusters for a given project/region, an empty list if none exist
:param service:
:param project_id:
:param region:
:return:
"""
result = service.projects().regions().clusters().list(
projectId=project_id,
region=region
).execute()
return result.get('clusters', [])

@staticmethod
def execute_dataproc_diagnose(service, project_id, region, cluster_name):
"""
Execute the diagonse command against a given cluster, useful to get debugging
information if something has gone wrong or cluster creation failed.
:param service:
:param project_id:
:param region:
:param cluster_name:
:return:
"""
response = service.projects().regions().clusters().diagnose(
projectId=project_id,
region=region,
clusterName=cluster_name,
body={}
).execute()
operation_name = response['name']
return operation_name

@staticmethod
def execute_delete(service, project_id, region, cluster_name):
"""
Delete a specified cluster
:param service:
:param project_id:
:param region:
:param cluster_name:
:return: The identifier of the operation being executed
"""
response = service.projects().regions().clusters().delete(
projectId=project_id,
region=region,
clusterName=cluster_name
).execute(num_retries=5)
operation_name = response['name']
return operation_name

@staticmethod
def wait_for_operation_done(service, operation_name):
"""
Poll for the completion of a specific GCP operation
:param service:
:param operation_name:
:return: The response code of the completed operation
"""
while True:
response = service.projects().regions().operations().get(
name=operation_name
).execute(num_retries=5)

if response.get('done'):
return response
time.sleep(15)

@staticmethod
def wait_for_operation_done_or_error(service, operation_name):
"""
Block until the specified operation is done. Throws an AirflowException if
the operation completed but had an error
:param service:
:param operation_name:
:return:
"""
        response = DataProcHook.wait_for_operation_done(service, operation_name)
        if 'error' in response:
            raise AirflowException(str(response['error']))


setattr(
DataProcHook,
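For context, a minimal usage sketch of the new hook helpers (not part of this PR; the connection id, project, region, and cluster name below are illustrative):

    from airflow.gcp.hooks.dataproc import DataProcHook

    hook = DataProcHook(gcp_conn_id='google_cloud_default')
    service = hook.get_conn()

    # Block until the cluster reports a state; clean up if it ended in ERROR.
    state = hook.get_final_cluster_state('my-project', 'us-central1', 'my-cluster', hook.log)
    if state == 'ERROR':
        operation_name = DataProcHook.execute_delete(service, 'my-project', 'us-central1', 'my-cluster')
        DataProcHook.wait_for_operation_done_or_error(service, operation_name)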
72 changes: 70 additions & 2 deletions airflow/gcp/operators/dataproc.py
@@ -276,8 +276,36 @@ def __init__(self,
)
), "num_workers == 0 means single node mode - no preemptibles allowed"

def _cluster_ready(self, state, service):
if state == 'RUNNING':
return True
if state == 'DELETING':
            raise Exception("Tried to create a cluster, but it is in DELETING state; something went wrong.")
if state == 'ERROR':
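            # Surface the cluster's own error details, then kick off a Dataproc
            # diagnose run so debugging output is available before failing.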
cluster = DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name)
try:
error_details = cluster['status']['details']
except KeyError:
error_details = 'Unknown error in cluster creation, ' \
'check Google Cloud console for details.'

            self.log.info('Dataproc cluster creation resulted in an ERROR state; running diagnostics')
self.log.info(error_details)
diagnose_operation_name = \
DataProcHook.execute_dataproc_diagnose(service, self.project_id,
self.region, self.cluster_name)
            diagnose_result = DataProcHook.wait_for_operation_done(service, diagnose_operation_name)
            output_uri = diagnose_result.get('response', {}).get('outputUri')
            if output_uri:
                self.log.info('Diagnostic information for ERROR cluster available at [%s]', output_uri)
            else:
                self.log.info('Diagnostic information could not be retrieved!')

raise Exception(error_details)
return False

def _get_init_action_timeout(self):
match = re.match(r"^(\d+)(s|m)$", self.init_action_timeout)
match = re.match(r"^(\d+)([sm])$", self.init_action_timeout)
if match:
if match.group(2) == "s":
return self.init_action_timeout
@@ -445,11 +473,51 @@ def _build_cluster_data(self):

return cluster_data

def _usable_existing_cluster_present(self, service):
existing_cluster = DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name)
if existing_cluster:
self.log.info(
'Cluster %s already exists... Checking status...',
self.cluster_name
)
existing_status = self.hook.get_final_cluster_state(self.project_id,
self.region, self.cluster_name, self.log)

if existing_status == 'RUNNING':
self.log.info('Cluster exists and is already running. Using it.')
return True

elif existing_status == 'DELETING':
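                # An earlier delete is still in flight; wait for it to finish so
                # the create below does not race against the teardown.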
while DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name) \
and DataProcHook.get_cluster_state(service, self.project_id,
self.region, self.cluster_name) == 'DELETING':
self.log.info('Existing cluster is deleting, waiting for it to finish')
time.sleep(15)

elif existing_status == 'ERROR':
self.log.info('Existing cluster in ERROR state, deleting it first')

operation_name = DataProcHook.execute_delete(service, self.project_id,
self.region, self.cluster_name)
self.log.info("Cluster delete operation name: %s", operation_name)
DataProcHook.wait_for_operation_done_or_error(service, operation_name)

return False

def start(self):
"""
Create a new cluster on Google Cloud Dataproc.
"""
self.log.info('Creating cluster: %s', self.cluster_name)
hook = DataProcHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to
)
service = hook.get_conn()

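        # Reuse a healthy existing cluster when possible; otherwise fall
        # through and create a new one.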
if self._usable_existing_cluster_present(service):
return True

cluster_data = self._build_cluster_data()

return (
@@ -542,7 +610,7 @@ def _build_scale_cluster_data(self):

@staticmethod
def _get_graceful_decommission_timeout(timeout):
match = re.match(r"^(\d+)(s|m|h|d)$", timeout)
match = re.match(r"^(\d+)([smdh])$", timeout)
if match:
if match.group(2) == "s":
return timeout
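A minimal DAG sketch of the resulting operator behavior (not part of this PR; project, zone, and cluster values are illustrative): an existing RUNNING cluster is reused, a DELETING cluster is waited on, and an ERROR cluster is deleted before a fresh one is created.

    from datetime import datetime

    from airflow import DAG
    from airflow.gcp.operators.dataproc import DataprocClusterCreateOperator

    with DAG('dataproc_example', start_date=datetime(2019, 1, 1), schedule_interval='@daily') as dag:
        create_cluster = DataprocClusterCreateOperator(
            task_id='create_cluster',
            project_id='example-project',
            region='us-central1',
            cluster_name='example-cluster',
            num_workers=2,
            zone='us-central1-a',
        )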
55 changes: 49 additions & 6 deletions tests/gcp/operators/test_dataproc.py
@@ -399,8 +399,12 @@ def test_create_cluster(self):
self.operation = {'name': 'operation', 'done': True}
self.mock_execute = Mock()
self.mock_execute.execute.return_value = self.operation
self.mock_list = Mock()
self.mock_list_execute = {}
self.mock_list.execute.return_value = self.mock_list_execute
self.mock_clusters = Mock()
self.mock_clusters.create.return_value = self.mock_execute
self.mock_clusters.list.return_value = self.mock_list
self.mock_regions = Mock()
self.mock_regions.clusters.return_value = self.mock_clusters
self.mock_projects = Mock()
@@ -523,6 +527,45 @@ def test_create_cluster_with_multiple_masters(self):
'labels': {'airflow-version': mock.ANY}})
hook.wait.assert_called_once_with(self.operation)

def test_create_cluster_deletes_error_cluster(self):
        # Set up service.projects().regions().clusters().create().execute()
# pylint:disable=attribute-defined-outside-init
self.operation = {'name': 'operation', 'done': True}
self.mock_execute = Mock()
self.mock_execute.execute.return_value = self.operation
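        # The mocked list() call reports an existing cluster already in ERROR
        # state, which should force the delete-then-recreate path.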
self.mock_list = Mock()
self.mock_list_execute = {'clusters': [{'clusterName': CLUSTER_NAME, 'status': {'state': 'ERROR'}}]}
self.mock_list.execute.return_value = self.mock_list_execute
self.mock_clusters = Mock()
self.mock_clusters.create.return_value = self.mock_execute
self.mock_clusters.list.return_value = self.mock_list
self.mock_regions = Mock()
self.mock_regions.clusters.return_value = self.mock_clusters
self.mock_projects = Mock()
self.mock_projects.regions.return_value = self.mock_regions
self.mock_conn = Mock()
self.mock_conn.projects.return_value = self.mock_projects

with patch(HOOK) as mock_hook:
hook = mock_hook()
hook.get_conn.return_value = self.mock_conn
hook.wait.return_value = None
hook.get_final_cluster_state.return_value = "ERROR"

dataproc_task = DataprocClusterCreateOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
project_id=GCP_PROJECT_ID,
num_workers=NUM_WORKERS,
zone=GCE_ZONE,
dag=self.dag
)
with patch.object(dataproc_task.log, 'info') as mock_info:
dataproc_task.execute(None)
mock_info.assert_any_call('Existing cluster in ERROR state, deleting it first')

def test_build_cluster_data_internal_ip_only_without_subnetwork(self):

def create_cluster_with_invalid_internal_ip_only_setup():
@@ -699,7 +742,7 @@ def test_delete_cluster(self):
projectId=GCP_PROJECT_ID,
clusterName=CLUSTER_NAME,
requestId=mock.ANY)
-        hook.wait.assert_called_once_with(self.operation)
+        hook.wait.assert_called_with(self.operation)

def test_render_template(self):
task = DataprocClusterDeleteOperator(
@@ -792,7 +835,7 @@ def setUp(self):
schedule_interval='@daily')

@mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID
)
@@ -877,7 +920,7 @@ def setUp(self):
schedule_interval='@daily')

@mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID
)
@@ -962,7 +1005,7 @@ def setUp(self):
schedule_interval='@daily')

@mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID
)
@@ -1052,7 +1095,7 @@ def setUp(self):
schedule_interval='@daily')

@mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID
)
@@ -1140,7 +1183,7 @@ def setUp(self):
schedule_interval='@daily')

@mock.patch(
-        'airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook.project_id',
+        'airflow.gcp.hooks.dataproc.DataProcHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID
)