diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 7f8cadd33c7af..cd4017cfd9243 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -581,6 +581,127 @@ def update_cluster(
         )
         return operation
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def start_cluster(
+        self,
+        region: str,
+        project_id: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ):
+        """
+        Starts a cluster in a project.
+
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Required. Name of the cluster to start.
+        :param cluster_uuid: Optional. If specified, the RPC will fail if a cluster with the
+            specified UUID does not exist.
+        :param request_id: Optional. A unique id used to identify the request. If the server receives two
+            ``StartClusterRequest`` requests with the same id, then the second request will be ignored and
+            the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = self.get_cluster_client(region=region)
+        result = client.start_cluster(
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def stop_cluster(
+        self,
+        region: str,
+        project_id: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ):
+        """
+        Stops a cluster in a project.
+
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Required. Name of the cluster to stop.
+        :param cluster_uuid: Optional. If specified, the RPC will fail if a cluster with the
+            specified UUID does not exist.
+        :param request_id: Optional. A unique id used to identify the request. If the server receives two
+            ``StopClusterRequest`` requests with the same id, then the second request will be ignored and
+            the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = self.get_cluster_client(region=region)
+        result = client.stop_cluster(
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
     @GoogleBaseHook.fallback_to_default_project_id
     def create_workflow_template(
         self,
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 753ca287132d5..8bae221d953dc 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -825,6 +825,254 @@ def execute(self, context: 'Context') -> None:
         operation.result()
         self.log.info("Cluster deleted.")
 
+
+class DataprocStartClusterOperator(BaseOperator):
+    """
+    Starts a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StartClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain')
+
+    operator_extra_links = (DataprocLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+
+    def _start_cluster(self):
+        self.hook.start_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _get_cluster(self) -> Cluster:
+        return self.hook.get_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _handle_error_state(self, cluster: Cluster) -> None:
+        if cluster.status.state != cluster.status.State.ERROR:
+            return
+        self.log.info("Cluster is in ERROR state")
+        gcs_uri = self.hook.diagnose_cluster(
+            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+        )
+        self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
+        raise AirflowException("Cluster was started but is in ERROR state")
+
+    def _wait_for_cluster_in_starting_state(self) -> Cluster:
+        # Fall back to one hour of polling when no explicit timeout was given.
+        time_left = self.timeout or 1 * 60 * 60
+        cluster = self._get_cluster()
+        for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
+            if cluster.status.state != cluster.status.State.STARTING:
+                break
+            if time_left < 0:
+                raise AirflowException(f"Cluster {self.cluster_name} is still in STARTING state, aborting")
+            time.sleep(time_to_sleep)
+            time_left = time_left - time_to_sleep
+            cluster = self._get_cluster()
+        return cluster
+
+    def execute(self, context: 'Context') -> None:
+        self.log.info('Starting cluster: %s', self.cluster_name)
+        # Save data required to display extra link no matter what the cluster status will be
+        DataprocLink.persist(
+            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
+        )
+        self._start_cluster()
+        cluster = self._get_cluster()
+        self._handle_error_state(cluster)
+        if cluster.status.state == cluster.status.State.STARTING:
+            # Wait for the cluster to leave the STARTING state
+            cluster = self._wait_for_cluster_in_starting_state()
+            self._handle_error_state(cluster)
+
+        self.log.info("Cluster started")
+
+
+class DataprocStopClusterOperator(BaseOperator):
+    """
+    Stops a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StopClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain')
+
+    operator_extra_links = (DataprocLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+
+    def _stop_cluster(self):
+        self.hook.stop_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _get_cluster(self) -> Cluster:
+        return self.hook.get_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _handle_error_state(self, cluster: Cluster) -> None:
+        if cluster.status.state != cluster.status.State.ERROR:
+            return
+        self.log.info("Cluster is in ERROR state")
+        gcs_uri = self.hook.diagnose_cluster(
+            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+        )
+        self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
+        raise AirflowException("Cluster was stopped but is in ERROR state")
+
+    def _wait_for_cluster_in_stopping_state(self) -> Cluster:
+        # Fall back to one hour of polling when no explicit timeout was given.
+        time_left = self.timeout or 1 * 60 * 60
+        cluster = self._get_cluster()
+        for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
+            if cluster.status.state != cluster.status.State.STOPPING:
+                break
+            if time_left < 0:
+                raise AirflowException(f"Cluster {self.cluster_name} is still in STOPPING state, aborting")
+            time.sleep(time_to_sleep)
+            time_left = time_left - time_to_sleep
+            cluster = self._get_cluster()
+        return cluster
+
+    def execute(self, context: 'Context') -> None:
+        self.log.info('Stopping cluster: %s', self.cluster_name)
+        # Save data required to display extra link no matter what the cluster status will be
+        DataprocLink.persist(
+            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
+        )
+        self._stop_cluster()
+        cluster = self._get_cluster()
+        self._handle_error_state(cluster)
+        if cluster.status.state == cluster.status.State.STOPPING:
+            # Wait for the cluster to leave the STOPPING state
+            cluster = self._wait_for_cluster_in_stopping_state()
+            self._handle_error_state(cluster)
+
+        self.log.info("Cluster stopped")
+
 
 class DataprocJobBaseOperator(BaseOperator):
     """
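For reference, here is a minimal sketch of how the two new operators could be wired into a DAG once this change lands. The project, region, cluster name, and DAG id are placeholder values, and the import path assumes the providers package layout used elsewhere in this diff.

```python
# Hypothetical usage sketch for the operators added in this PR.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocStartClusterOperator,
    DataprocStopClusterOperator,
)

with DAG(
    dag_id="example_dataproc_start_stop",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Resume a previously stopped cluster before running jobs against it.
    start_cluster = DataprocStartClusterOperator(
        task_id="start_cluster",
        project_id="example-project",
        region="us-central1",
        cluster_name="example-cluster",
        timeout=1 * 60 * 60,  # allow up to an hour for the cluster to leave STARTING
    )

    # Stop the cluster again once the work is done to save on costs.
    stop_cluster = DataprocStopClusterOperator(
        task_id="stop_cluster",
        project_id="example-project",
        region="us-central1",
        cluster_name="example-cluster",
    )

    start_cluster >> stop_cluster
```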