121 changes: 121 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -581,6 +581,127 @@ def update_cluster(
)
return operation

@GoogleBaseHook.fallback_to_default_project_id
def start_cluster(
self,
region: str,
project_id: str,
cluster_name: str,
cluster_config: Union[Dict, Cluster],
labels: Optional[Dict[str, str]] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
):
"""
Starts a cluster in a project.

:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
:param region: Required. The Cloud Dataproc region in which to handle the request.
:param cluster_name: Required. Name of the cluster to start.
:param labels: Labels that will be assigned to the cluster.
:param cluster_config: Required. The cluster config.
If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.cloud.dataproc_v1.types.ClusterConfig`
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``StartClusterRequest`` requests with the same id, then the second request will be ignored and
the first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
# Dataproc labels must conform to the following regex:
# [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
# semantic versioning spec: x.y.z).
labels = labels or {}
labels.update({'airflow-version': 'v' +
airflow_version.replace('.', '-').replace('+', '-')})
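# Illustrative example (not part of this change): if airflow_version were "2.3.4+astro.1",
# the label applied would be {'airflow-version': 'v2-3-4-astro-1'}, which satisfies the
# regex above.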

cluster = {
"project_id": project_id,
"cluster_name": cluster_name,
"config": cluster_config,
"labels": labels,
}

client = self.get_cluster_client(region=region)
result = client.start_cluster(
request={
'project_id': project_id,
'region': region,
'cluster': cluster,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def stop_cluster(
self,
region: str,
project_id: str,
cluster_name: str,
cluster_config: Union[Dict, Cluster],
labels: Optional[Dict[str, str]] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
):
"""
Stops a cluster in a project.

:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
:param region: Required. The Cloud Dataproc region in which to handle the request.
:param cluster_name: Required. Name of the cluster to stop.
:param labels: Labels that will be assigned to the cluster.
:param cluster_config: Required. The cluster config.
If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.cloud.dataproc_v1.types.ClusterConfig`
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``StopClusterRequest`` requests with the same id, then the second request will be ignored and
the first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
# Dataproc labels must conform to the following regex:
# [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
# semantic versioning spec: x.y.z).
labels = labels or {}
labels.update({'airflow-version': 'v' +
airflow_version.replace('.', '-').replace('+', '-')})

cluster = {
"project_id": project_id,
"cluster_name": cluster_name,
"config": cluster_config,
"labels": labels,
}

client = self.get_cluster_client(region=region)
result = client.stop_cluster(
request={
'project_id': project_id,
'region': region,
'cluster': cluster,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result


@GoogleBaseHook.fallback_to_default_project_id
def create_workflow_template(
self,
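As a quick illustration of how the new hook methods might be invoked, here is a hedged usage sketch; the connection id, project, region, cluster name, and empty config are placeholders, not values from this PR:

from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook(gcp_conn_id="google_cloud_default")  # assumes a configured GCP connection

# Stop an existing cluster, then start it again later (placeholder identifiers).
hook.stop_cluster(
    project_id="my-project",
    region="europe-west1",
    cluster_name="etl-cluster",
    cluster_config={},  # required by this hook's signature; empty dict used for the sketch
)
hook.start_cluster(
    project_id="my-project",
    region="europe-west1",
    cluster_name="etl-cluster",
    cluster_config={},
)

Both calls return the long-running operation produced by the cluster controller client, so a caller can wait on the result if needed.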
248 changes: 248 additions & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -825,6 +825,254 @@ def execute(self, context: 'Context') -> None:
operation.result()
self.log.info("Cluster deleted.")

class DataprocStartClusterOperator(BaseOperator):
"""
Starts a cluster in a project.

:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
:param region: Required. The Cloud Dataproc region in which to handle the request (templated).
:param cluster_name: Required. The cluster name (templated).
:param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
if cluster with specified UUID does not exist.
:param cluster_config: Optional. The cluster config passed through to the hook. If a dict is
provided, it must be of the same form as the protobuf message
:class:`~google.cloud.dataproc_v1.types.ClusterConfig`.
:param labels: Optional. Labels that will be assigned to the cluster.
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``StartClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain')

operator_extra_links = (DataprocLink(),)

def __init__(
self,
*,
project_id: str,
region: str,
cluster_name: str,
cluster_uuid: Optional[str] = None,
cluster_config: Optional[Union[Dict, Cluster]] = None,
labels: Optional[Dict[str, str]] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.project_id = project_id
self.region = region
self.cluster_name = cluster_name
self.cluster_uuid = cluster_uuid
self.cluster_config = cluster_config
self.labels = labels
self.request_id = request_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)

def _start_cluster(self):
self.hook.start_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
labels=self.labels,
cluster_config=self.cluster_config,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

def _get_cluster(self) -> Cluster:
return self.hook.get_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

def _handle_error_state(self, cluster: Cluster) -> None:
if cluster.status.state != cluster.status.State.ERROR:
return
self.log.info("Cluster is in ERROR state")
gcs_uri = self.hook.diagnose_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
raise AirflowException("Cluster was started but is in ERROR state")

def _wait_for_cluster_in_starting_state(self) -> Cluster:
# Treat a missing timeout as "wait indefinitely" so the comparison below stays valid.
time_left = self.timeout if self.timeout is not None else float("inf")
cluster = self._get_cluster()
for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
if cluster.status.state != cluster.status.State.STARTING:
break
if time_left < 0:
raise AirflowException(
f"Cluster {self.cluster_name} is still in STARTING state, aborting")
time.sleep(time_to_sleep)
time_left = time_left - time_to_sleep
cluster = self._get_cluster()
return cluster

def execute(self, context: 'Context') -> None:
self.log.info('Starting cluster: %s', self.cluster_name)
# Save data required to display extra link no matter what the cluster status will be
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
self._start_cluster()
cluster = self._get_cluster()
self._handle_error_state(cluster)
if cluster.status.state == cluster.status.State.STARTING:
# Wait for cluster to be running
cluster = self._wait_for_cluster_in_starting_state()
self._handle_error_state(cluster)

self.log.info("Cluster started")


class DataprocStopClusterOperator(BaseOperator):
"""
Stops a cluster in a project.

:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
:param region: Required. The Cloud Dataproc region in which to handle the request (templated).
:param cluster_name: Required. The cluster name (templated).
:param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
if cluster with specified UUID does not exist.
:param cluster_config: Optional. The cluster config passed through to the hook. If a dict is
provided, it must be of the same form as the protobuf message
:class:`~google.cloud.dataproc_v1.types.ClusterConfig`.
:param labels: Optional. Labels that will be assigned to the cluster.
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``StopClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain')

operator_extra_links = (DataprocLink(),)

def __init__(
self,
*,
project_id: str,
region: str,
cluster_name: str,
cluster_uuid: Optional[str] = None,
cluster_config: Optional[Union[Dict, Cluster]] = None,
labels: Optional[Dict[str, str]] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.project_id = project_id
self.region = region
self.cluster_name = cluster_name
self.cluster_uuid = cluster_uuid
self.cluster_config = cluster_config
self.labels = labels
self.request_id = request_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)

def _stop_cluster(self):
self.hook.stop_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
labels=self.labels,
cluster_config=self.cluster_config,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

def _get_cluster(self) -> Cluster:
return self.hook.get_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

def _handle_error_state(self, cluster: Cluster) -> None:
if cluster.status.state != cluster.status.State.ERROR:
return
self.log.info("Cluster is in ERROR state")
gcs_uri = self.hook.diagnose_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
raise AirflowException("Cluster was stopped but is in ERROR state")

def _wait_for_cluster_in_stopping_state(self) -> Cluster:
# Treat a missing timeout as "wait indefinitely" so the comparison below stays valid.
time_left = self.timeout if self.timeout is not None else float("inf")
cluster = self._get_cluster()
for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
if cluster.status.state != cluster.status.State.STOPPING:
break
if time_left < 0:
raise AirflowException(
f"Cluster {self.cluster_name} is still in STOPPING state, aborting")
time.sleep(time_to_sleep)
time_left = time_left - time_to_sleep
cluster = self._get_cluster()
return cluster

def execute(self, context: 'Context') -> None:
self.log.info('Stopping cluster: %s', self.cluster_name)

# Save data required to display extra link no matter what the cluster status will be
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
self._stop_cluster()
cluster = self._get_cluster()
self._handle_error_state(cluster)
if cluster.status.state == cluster.status.State.STOPPING:
# Wait for cluster to be STOPPED
cluster = self._wait_for_cluster_in_stopping_state()
self._handle_error_state(cluster)

self.log.info("Cluster stopped")


class DataprocJobBaseOperator(BaseOperator):
"""
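For context, a hedged sketch of how the two new operators might be wired into a DAG; the DAG id, schedule, project, region, and cluster name are placeholders, not values from this PR:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocStartClusterOperator,
    DataprocStopClusterOperator,
)

with DAG(
    dag_id="example_dataproc_start_stop",  # placeholder DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    start_cluster = DataprocStartClusterOperator(
        task_id="start_cluster",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        cluster_name="etl-cluster",  # placeholder
    )
    stop_cluster = DataprocStopClusterOperator(
        task_id="stop_cluster",
        project_id="my-project",
        region="europe-west1",
        cluster_name="etl-cluster",
    )
    start_cluster >> stop_cluster

Both operators only need the project, region, and cluster name; the optional cluster_config and labels arguments exist to mirror the hook signature and can normally be omitted.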