From e7650e51211f277e5aaa6d53c968819b593bcebe Mon Sep 17 00:00:00 2001
From: alireza
Date: Wed, 2 Mar 2022 00:25:50 +0330
Subject: [PATCH 1/3] Add start_cluster and stop_cluster methods to the
 Dataproc hook

---
 .../providers/google/cloud/hooks/dataproc.py | 121 ++++++++++++++++++
 1 file changed, 121 insertions(+)

diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 7f8cadd33c7af..cd4017cfd9243 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -581,6 +581,127 @@ def update_cluster(
         )
         return operation
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def start_cluster(
+        self,
+        region: str,
+        project_id: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ):
+        """
+        Starts a cluster in a project.
+
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Required. Name of the cluster to start.
+        :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC will fail
+            if a cluster with the specified UUID does not exist.
+        :param request_id: Optional. A unique id used to identify the request. If the server receives two
+            ``StartClusterRequest`` requests with the same id, then the second request will be ignored and
+            the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = self.get_cluster_client(region=region)
+        result = client.start_cluster(
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def stop_cluster(
+        self,
+        region: str,
+        project_id: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ):
+        """
+        Stops a cluster in a project.
+
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Required. Name of the cluster to stop.
+        :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC will fail
+            if a cluster with the specified UUID does not exist.
+        :param request_id: Optional. A unique id used to identify the request. If the server receives two
+            ``StopClusterRequest`` requests with the same id, then the second request will be ignored and
+            the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = self.get_cluster_client(region=region)
+        result = client.stop_cluster(
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
     @GoogleBaseHook.fallback_to_default_project_id
     def create_workflow_template(
         self,

From 00231c7d390361f2df55166357befce758c6ed20 Mon Sep 17 00:00:00 2001
From: alireza
Date: Wed, 2 Mar 2022 00:39:54 +0330
Subject: [PATCH 2/3] Add Dataproc cluster start and stop operators

---
 .../google/cloud/operators/dataproc.py | 249 ++++++++++++++++++
 1 file changed, 249 insertions(+)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 753ca287132d5..39762de5906aa 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -825,6 +825,255 @@ def execute(self, context: 'Context') -> None:
         operation.result()
         self.log.info("Cluster deleted.")
 
+
+class DataprocStartClusterOperator(BaseOperator):
+    """
+    Starts a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StartClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain')
+
+    operator_extra_links = (DataprocLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def _start_cluster(self, hook: DataprocHook):
+        hook.start_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _get_cluster(self, hook: DataprocHook) -> Cluster:
+        return hook.get_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
+        if cluster.status.state != cluster.status.State.ERROR:
+            return
+        self.log.info("Cluster is in ERROR state")
+        gcs_uri = hook.diagnose_cluster(
+            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+        )
+        self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
+        raise AirflowException("Cluster was started but is in ERROR state")
+
+    def _wait_for_cluster_in_starting_state(self, hook: DataprocHook) -> Cluster:
+        time_left = self.timeout
+        cluster = self._get_cluster(hook)
+        for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
+            if cluster.status.state != cluster.status.State.STARTING:
+                break
+            if time_left < 0:
+                raise AirflowException(
+                    f"Cluster {self.cluster_name} is still in STARTING state, aborting")
+            time.sleep(time_to_sleep)
+            time_left = time_left - time_to_sleep
+            cluster = self._get_cluster(hook)
+        return cluster
+
+    def execute(self, context: 'Context') -> None:
+        self.log.info('Starting cluster: %s', self.cluster_name)
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain)
+        # Save data required to display extra link no matter what the cluster status will be
+        DataprocLink.persist(
+            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
+        )
+        self._start_cluster(hook)
+        cluster = self._get_cluster(hook)
+        self._handle_error_state(hook, cluster)
+        if cluster.status.state == cluster.status.State.STARTING:
+            # Wait for cluster to be running
+            cluster = self._wait_for_cluster_in_starting_state(hook)
+            self._handle_error_state(hook, cluster)
+
+        self.log.info("Cluster started")
+
+
+class DataprocStopClusterOperator(BaseOperator):
+    """
+    Stops a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StopClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+ """ + + template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain') + + operator_extra_links = (DataprocLink(),) + + def __init__( + self, + *, + project_id: str, + region: str, + cluster_name: str, + cluster_uuid: Optional[str] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.cluster_name = cluster_name + self.cluster_uuid = cluster_uuid + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def _stop_cluster(self, hook: DataprocHook): + hook.stop_cluster( + project_id=self.project_id, + region=self.region, + cluster_name=self.cluster_name, + labels=self.labels, + cluster_config=self.cluster_config, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + def _get_cluster(self, hook: DataprocHook) -> Cluster: + return hook.get_cluster( + project_id=self.project_id, + region=self.region, + cluster_name=self.cluster_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: + if cluster.status.state != cluster.status.State.ERROR: + return + self.log.info("Cluster is in ERROR state") + gcs_uri = hook.diagnose_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id + ) + self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri) + raise AirflowException("Cluster was stopped but is in ERROR state") + + def _wait_for_cluster_in_stopting_state(self, hook: DataprocHook) -> Cluster: + time_left = self.timeout + cluster = self._get_cluster(hook) + for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120): + if cluster.status.state != cluster.status.State.STOPPED: + break + if time_left < 0: + raise AirflowException( + f"Cluster {self.cluster_name} is still STOPPING state, aborting") + time.sleep(time_to_sleep) + time_left = time_left - time_to_sleep + cluster = self._get_cluster(hook) + return cluster + + def execute(self, context: 'Context') -> None: + self.log.info('Stopping cluster: %s', self.cluster_name) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain) + # Save data required to display extra link no matter what the cluster status will be + DataprocLink.persist( + context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name + ) + self._stop_cluster(hook) + cluster = self._get_cluster(hook) + self._handle_error_state(hook, cluster) + if cluster.status.state == cluster.status.State.STOPPING: + # Wait for cluster to be STOPPED + cluster = self._wait_for_cluster_in_stopting_state(hook) + self._handle_error_state(hook, cluster) + + self.log.info("Cluster stopped") + class DataprocJobBaseOperator(BaseOperator): """ From c3c81c3144386d1de535c1c5e777270e727bb69e Mon Sep 17 00:00:00 2001 From: alireza Date: Wed, 2 Mar 2022 00:57:10 +0330 Subject: [PATCH 3/3] hooks in cluster start and stop moved to init --- .../google/cloud/operators/dataproc.py | 63 +++++++++---------- 1 file changed, 31 
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 39762de5906aa..8bae221d953dc 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -883,9 +883,10 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
 
-    def _start_cluster(self, hook: DataprocHook):
-        hook.start_cluster(
+    def _start_cluster(self):
+        self.hook.start_cluster(
             project_id=self.project_id,
             region=self.region,
             cluster_name=self.cluster_name,
@@ -897,8 +898,8 @@ def _start_cluster(self, hook: DataprocHook):
             metadata=self.metadata,
         )
 
-    def _get_cluster(self, hook: DataprocHook) -> Cluster:
-        return hook.get_cluster(
+    def _get_cluster(self) -> Cluster:
+        return self.hook.get_cluster(
             project_id=self.project_id,
             region=self.region,
             cluster_name=self.cluster_name,
@@ -907,19 +908,19 @@ def _get_cluster(self, hook: DataprocHook) -> Cluster:
             metadata=self.metadata,
         )
 
-    def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
+    def _handle_error_state(self, cluster: Cluster) -> None:
         if cluster.status.state != cluster.status.State.ERROR:
             return
         self.log.info("Cluster is in ERROR state")
-        gcs_uri = hook.diagnose_cluster(
+        gcs_uri = self.hook.diagnose_cluster(
             region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
         )
         self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
         raise AirflowException("Cluster was started but is in ERROR state")
 
-    def _wait_for_cluster_in_starting_state(self, hook: DataprocHook) -> Cluster:
+    def _wait_for_cluster_in_starting_state(self) -> Cluster:
         time_left = self.timeout
-        cluster = self._get_cluster(hook)
+        cluster = self._get_cluster()
         for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
             if cluster.status.state != cluster.status.State.STARTING:
                 break
@@ -928,24 +929,22 @@ def _wait_for_cluster_in_starting_state(self) -> Cluster:
                     f"Cluster {self.cluster_name} is still in STARTING state, aborting")
             time.sleep(time_to_sleep)
             time_left = time_left - time_to_sleep
-            cluster = self._get_cluster(hook)
+            cluster = self._get_cluster()
         return cluster
 
     def execute(self, context: 'Context') -> None:
         self.log.info('Starting cluster: %s', self.cluster_name)
-        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
-                            impersonation_chain=self.impersonation_chain)
         # Save data required to display extra link no matter what the cluster status will be
         DataprocLink.persist(
             context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
         )
-        self._start_cluster(hook)
-        cluster = self._get_cluster(hook)
-        self._handle_error_state(hook, cluster)
+        self._start_cluster()
+        cluster = self._get_cluster()
+        self._handle_error_state(cluster)
         if cluster.status.state == cluster.status.State.STARTING:
             # Wait for cluster to be running
-            cluster = self._wait_for_cluster_in_starting_state(hook)
-            self._handle_error_state(hook, cluster)
+            cluster = self._wait_for_cluster_in_starting_state()
+            self._handle_error_state(cluster)
 
         self.log.info("Cluster started")
 
@@ -1008,9 +1007,10 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
 
-    def _stop_cluster(self, hook: DataprocHook):
-        hook.stop_cluster(
+    def _stop_cluster(self):
+        self.hook.stop_cluster(
             project_id=self.project_id,
             region=self.region,
             cluster_name=self.cluster_name,
@@ -1022,8 +1022,8 @@ def _stop_cluster(self, hook: DataprocHook):
             metadata=self.metadata,
         )
 
-    def _get_cluster(self, hook: DataprocHook) -> Cluster:
-        return hook.get_cluster(
+    def _get_cluster(self) -> Cluster:
+        return self.hook.get_cluster(
             project_id=self.project_id,
             region=self.region,
             cluster_name=self.cluster_name,
@@ -1032,19 +1032,19 @@ def _get_cluster(self, hook: DataprocHook) -> Cluster:
             metadata=self.metadata,
         )
 
-    def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
+    def _handle_error_state(self, cluster: Cluster) -> None:
         if cluster.status.state != cluster.status.State.ERROR:
             return
         self.log.info("Cluster is in ERROR state")
-        gcs_uri = hook.diagnose_cluster(
+        gcs_uri = self.hook.diagnose_cluster(
             region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
         )
         self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
         raise AirflowException("Cluster was stopped but is in ERROR state")
 
-    def _wait_for_cluster_in_stopping_state(self, hook: DataprocHook) -> Cluster:
+    def _wait_for_cluster_in_stopping_state(self) -> Cluster:
         time_left = self.timeout
-        cluster = self._get_cluster(hook)
+        cluster = self._get_cluster()
         for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
             if cluster.status.state != cluster.status.State.STOPPING:
                 break
@@ -1053,24 +1053,23 @@ def _wait_for_cluster_in_stopping_state(self) -> Cluster:
                    f"Cluster {self.cluster_name} is still in STOPPING state, aborting")
             time.sleep(time_to_sleep)
             time_left = time_left - time_to_sleep
-            cluster = self._get_cluster(hook)
+            cluster = self._get_cluster()
         return cluster
 
     def execute(self, context: 'Context') -> None:
         self.log.info('Stopping cluster: %s', self.cluster_name)
-        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
-                            impersonation_chain=self.impersonation_chain)
+
         # Save data required to display extra link no matter what the cluster status will be
         DataprocLink.persist(
             context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
         )
-        self._stop_cluster(hook)
-        cluster = self._get_cluster(hook)
-        self._handle_error_state(hook, cluster)
+        self._stop_cluster()
+        cluster = self._get_cluster()
+        self._handle_error_state(cluster)
         if cluster.status.state == cluster.status.State.STOPPING:
             # Wait for cluster to be STOPPED
-            cluster = self._wait_for_cluster_in_stopting_state(hook)
-            self._handle_error_state(hook, cluster)
+            cluster = self._wait_for_cluster_in_stopping_state()
+            self._handle_error_state(cluster)
 
         self.log.info("Cluster stopped")
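---

A minimal usage sketch of the two new operators, assuming the series is applied
as-is. The project, region, cluster, and DAG names below are illustrative
placeholders, not values from the patches; the `timeout` argument doubles as
the per-request API timeout and the budget the operators poll against while
waiting for the RUNNING or STOPPED state.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataproc import (
        DataprocStartClusterOperator,
        DataprocStopClusterOperator,
    )

    with DAG(
        dag_id="example_dataproc_start_stop",  # hypothetical DAG id
        start_date=datetime(2022, 3, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        # Resume a previously stopped cluster and wait until it is RUNNING.
        start_cluster = DataprocStartClusterOperator(
            task_id="start_cluster",
            project_id="my-project",    # placeholder
            region="us-central1",       # placeholder
            cluster_name="my-cluster",  # placeholder
            timeout=10 * 60,            # seconds to poll before aborting
        )

        # Stop the same cluster and wait until it is STOPPED.
        stop_cluster = DataprocStopClusterOperator(
            task_id="stop_cluster",
            project_id="my-project",
            region="us-central1",
            cluster_name="my-cluster",
            timeout=10 * 60,
        )

        start_cluster >> stop_cluster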