From db4ee7388e9329badb4c84499b113ae0ce1d3862 Mon Sep 17 00:00:00 2001
From: Ulada Zakharava
Date: Thu, 9 Jan 2025 14:55:00 +0000
Subject: [PATCH] Add Dataplex Catalog Entry Group operators

---
 .../operators/cloud/dataplex.rst              |  94 ++++
 docs/spelling_wordlist.txt                    |   2 +
 generated/provider_dependencies.json          |   2 +-
 .../providers/google/cloud/hooks/dataplex.py  | 212 +++++++-
 .../providers/google/cloud/links/dataplex.py  |  51 +-
 .../google/cloud/operators/dataplex.py        | 490 +++++++++++++++++-
 .../airflow/providers/google/provider.yaml    |   4 +-
 .../tests/google/cloud/hooks/test_dataplex.py | 114 ++++
 .../tests/google/cloud/links/test_dataplex.py | 168 ++++++
 .../google/cloud/operators/test_dataplex.py   | 143 +++++
 .../dataplex/example_dataplex_catalog.py      | 118 +++++
 tests/always/test_project_structure.py        |   2 +-
 12 files changed, 1392 insertions(+), 8 deletions(-)
 create mode 100644 providers/tests/google/cloud/links/test_dataplex.py
 create mode 100644 providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py

diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst b/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst
index cbeb5eafcd037..1846f1ad41b5e 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst
@@ -417,3 +417,97 @@ To get a Data Profile scan job you can use:
     :dedent: 4
     :start-after: [START howto_dataplex_get_data_profile_job_operator]
     :end-before: [END howto_dataplex_get_data_profile_job_operator]
+
+
+Google Dataplex Catalog Operators
+=================================
+
+Dataplex Catalog provides a unified inventory of Google Cloud resources, such as BigQuery, as well as other
+resources, such as on-premises resources. Dataplex Catalog automatically retrieves metadata for Google Cloud
+resources, while you bring metadata for third-party resources into Dataplex Catalog yourself.
+
+For more information about Dataplex Catalog, visit the `Dataplex Catalog product documentation <https://cloud.google.com/dataplex/docs/catalog-overview>`__.
+
+.. _howto/operator:DataplexCatalogCreateEntryGroupOperator:
+
+Create an EntryGroup
+--------------------
+
+To create an Entry Group in a specific location in Dataplex Catalog, you can
+use :class:`~airflow.providers.google.cloud.operators.dataplex.DataplexCatalogCreateEntryGroupOperator`.
+For more information about the available fields to pass when creating an Entry Group, visit the `Entry Group resource configuration <https://cloud.google.com/dataplex/docs/reference/rest/v1/projects.locations.entryGroups>`__.
+
+A simple Entry Group configuration can look as follows:
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py
+    :language: python
+    :dedent: 0
+    :start-after: [START howto_dataplex_entry_group_configuration]
+    :end-before: [END howto_dataplex_entry_group_configuration]
+
+With this configuration, you can create an Entry Group resource:
+
+:class:`~airflow.providers.google.cloud.operators.dataplex.DataplexCatalogCreateEntryGroupOperator`
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_dataplex_catalog_create_entry_group]
+    :end-before: [END howto_operator_dataplex_catalog_create_entry_group]
+
+.. _howto/operator:DataplexCatalogDeleteEntryGroupOperator:
+
+Delete an EntryGroup
+--------------------
+
+To delete an Entry Group in a specific location in Dataplex Catalog, you can
+use :class:`~airflow.providers.google.cloud.operators.dataplex.DataplexCatalogDeleteEntryGroupOperator`.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_dataplex_catalog_delete_entry_group]
+    :end-before: [END howto_operator_dataplex_catalog_delete_entry_group]
+
+.. _howto/operator:DataplexCatalogListEntryGroupsOperator:
+
+List EntryGroups
+----------------
+
+To list all Entry Groups in a specific location in Dataplex Catalog, you can
+use :class:`~airflow.providers.google.cloud.operators.dataplex.DataplexCatalogListEntryGroupsOperator`.
+This operator also supports filtering and ordering the result of the operation.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_dataplex_catalog_list_entry_groups]
+    :end-before: [END howto_operator_dataplex_catalog_list_entry_groups]
+
+.. _howto/operator:DataplexCatalogGetEntryGroupOperator:
+
+Get an EntryGroup
+-----------------
+
+To retrieve an Entry Group in a specific location in Dataplex Catalog, you can
+use :class:`~airflow.providers.google.cloud.operators.dataplex.DataplexCatalogGetEntryGroupOperator`.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_dataplex_catalog_get_entry_group]
+    :end-before: [END howto_operator_dataplex_catalog_get_entry_group]
+
+.. _howto/operator:DataplexCatalogUpdateEntryGroupOperator:
+
+Update an EntryGroup
+--------------------
+
+To update an Entry Group in a specific location in Dataplex Catalog, you can
+use :class:`~airflow.providers.google.cloud.operators.dataplex.DataplexCatalogUpdateEntryGroupOperator`.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_dataplex_catalog_update_entry_group]
+    :end-before: [END howto_operator_dataplex_catalog_update_entry_group]
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index d1a1e62d521fe..f0f7d90518fb5 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -574,6 +574,8 @@ encodable
 encryptor
 enqueue
 enqueued
+EntryGroup
+EntryGroups
 entrypoint
 entrypoints
 Enum
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 2724c6a73d4e9..8797be6c641b0 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -656,7 +656,7 @@
       "google-cloud-datacatalog>=3.23.0",
       "google-cloud-dataflow-client>=0.8.6",
       "google-cloud-dataform>=0.5.0",
-      "google-cloud-dataplex>=1.10.0",
+      "google-cloud-dataplex>=2.6.0",
       "google-cloud-dataproc-metastore>=1.12.0",
       "google-cloud-dataproc>=5.12.0",
       "google-cloud-dlp>=3.12.0",
diff --git a/providers/src/airflow/providers/google/cloud/hooks/dataplex.py b/providers/src/airflow/providers/google/cloud/hooks/dataplex.py
index 387dfb00a50c2..cb2c7e41a2067 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/dataplex.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/dataplex.py
@@ -20,15 +20,23 @@
 
 import time
 from collections.abc import Sequence
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any
 
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.dataplex_v1 import DataplexServiceClient, DataScanServiceAsyncClient, DataScanServiceClient
+from google.cloud.dataplex_v1 import (
+    DataplexServiceClient,
+    DataScanServiceAsyncClient,
+    DataScanServiceClient,
+)
+from google.cloud.dataplex_v1.services.catalog_service import CatalogServiceClient
+from google.protobuf.field_mask_pb2 import FieldMask
 from google.cloud.dataplex_v1.types import (
     Asset,
     DataScan,
     DataScanJob,
+    EntryGroup,
     Lake,
     Task,
     Zone,
@@ -47,6 +55,7 @@
     from google.api_core.operation import Operation
     from google.api_core.retry import Retry
     from google.api_core.retry_async import AsyncRetry
+    from google.cloud.dataplex_v1.services.catalog_service.pagers import ListEntryGroupsPager
     from googleapiclient.discovery import Resource
 
 PATH_DATA_SCAN = "projects/{project_id}/locations/{region}/dataScans/{data_scan_id}"
@@ -110,6 +119,14 @@ def get_dataplex_data_scan_client(self) -> DataScanServiceClient:
             credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
+    def get_dataplex_catalog_client(self) -> CatalogServiceClient:
+        """Return CatalogServiceClient."""
+        client_options = ClientOptions(api_endpoint="dataplex.googleapis.com:443")
+
+        return CatalogServiceClient(
+            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+        )
+
     def wait_for_operation(self, timeout: float | None, operation: Operation):
         """Wait for long-lasting operation to complete."""
         try:
@@ -118,6 +135,200 @@ def wait_for_operation(self, timeout: float | None, operation: Operation):
             error = operation.exception(timeout=timeout)
             raise AirflowException(error)
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_entry_group(
+        self,
+        location: str,
+        entry_group_id: str,
+        entry_group_configuration: EntryGroup | dict,
+        project_id: str = PROVIDE_PROJECT_ID,
+        validate_only: bool = False,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Create an EntryGroup resource.
+
+        :param location: Required. The ID of the Google Cloud location that the entry group belongs to.
+        :param entry_group_id: Required. EntryGroup identifier.
+        :param entry_group_configuration: Required. EntryGroup configuration body.
+        :param project_id: Optional. The ID of the Google Cloud project that the entry group belongs to.
+        :param validate_only: Optional. If set, performs request validation, but does not actually execute
+            the create request.
+        :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests
+            will not be retried.
+        :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete.
+            Note that if `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Optional. Additional metadata that is provided to the method.
+        """
+        client = self.get_dataplex_catalog_client()
+        return client.create_entry_group(
+            request={
+                "parent": client.common_location_path(project_id, location),
+                "entry_group_id": entry_group_id,
+                "entry_group": entry_group_configuration,
+                "validate_only": validate_only,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_entry_group(
+        self,
+        location: str,
+        entry_group_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> EntryGroup:
+        """
+        Get an EntryGroup resource.
+
+        :param location: Required. The ID of the Google Cloud location that the entry group belongs to.
+        :param entry_group_id: Required. EntryGroup identifier.
+        :param project_id: Optional. The ID of the Google Cloud project that the entry group belongs to.
+        :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests
+            will not be retried.
+        :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete.
+            Note that if `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Optional. Additional metadata that is provided to the method.
+        """
+        client = self.get_dataplex_catalog_client()
+        return client.get_entry_group(
+            request={
+                "name": client.entry_group_path(project_id, location, entry_group_id),
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_entry_group(
+        self,
+        location: str,
+        entry_group_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Delete an EntryGroup resource.
+
+        :param location: Required. The ID of the Google Cloud location that the entry group belongs to.
+        :param entry_group_id: Required. EntryGroup identifier.
+        :param project_id: Optional. The ID of the Google Cloud project that the entry group belongs to.
+        :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests
+            will not be retried.
+        :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete.
+            Note that if `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Optional. Additional metadata that is provided to the method.
+        """
+        client = self.get_dataplex_catalog_client()
+        return client.delete_entry_group(
+            request={
+                "name": client.entry_group_path(project_id, location, entry_group_id),
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_entry_groups(
+        self,
+        location: str,
+        filter_by: str | None = None,
+        order_by: str | None = None,
+        page_size: int | None = None,
+        page_token: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> ListEntryGroupsPager:
+        """
+        List EntryGroup resources from a specific location.
+
+        :param location: Required. The ID of the Google Cloud location that the entry groups belong to.
+        :param filter_by: Optional. Filter to apply on the list results.
+        :param order_by: Optional. Fields to order the results by.
+        :param page_size: Optional. Maximum number of EntryGroups to return on one page.
+        :param page_token: Optional. Token to retrieve the next page of results.
+        :param project_id: Optional. The ID of the Google Cloud project that the entry groups belong to.
+        :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests
+            will not be retried.
+        :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete.
+            Note that if `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Optional. Additional metadata that is provided to the method.
+ """ + client = self.get_dataplex_catalog_client() + return client.list_entry_groups( + request={ + "parent": client.common_location_path(project_id, location), + "filter": filter_by, + "order_by": order_by, + "page_size": page_size, + "page_token": page_token, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_entry_group( + self, + location: str, + entry_group_id: str, + entry_group_configuration: dict | EntryGroup, + project_id: str = PROVIDE_PROJECT_ID, + update_mask: list[str] | FieldMask | None = None, + validate_only: bool | None = False, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Update an EntryGroup resource. + + :param entry_group_id: Required. ID of the EntryGroup to update. + :param entry_group_configuration: Required. The updated configuration body of the EntryGroup. + :param location: Required. The ID of the Google Cloud location that the task belongs to. + :param update_mask: Optional. Names of fields whose values to overwrite on an entry group. + If this parameter is absent or empty, all modifiable fields are overwritten. If such + fields are non-required and omitted in the request body, their values are emptied. + :param project_id: Optional. The ID of the Google Cloud project that the task belongs to. + :param validate_only: Optional. The service validates the request without performing any mutations. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. 
+ """ + client = self.get_dataplex_catalog_client() + _entry_group = ( + deepcopy(entry_group_configuration) + if isinstance(entry_group_configuration, dict) + else EntryGroup.to_dict(entry_group_configuration) + ) + _entry_group["name"] = client.entry_group_path(project_id, location, entry_group_id) + return client.update_entry_group( + request={ + "entry_group": _entry_group, + "update_mask": FieldMask(paths=update_mask) if type(update_mask) is list else update_mask, + "validate_only": validate_only, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + @GoogleBaseHook.fallback_to_default_project_id def create_task( self, diff --git a/providers/src/airflow/providers/google/cloud/links/dataplex.py b/providers/src/airflow/providers/google/cloud/links/dataplex.py index 80d4b2cb9c07d..e0bd1ae584479 100644 --- a/providers/src/airflow/providers/google/cloud/links/dataplex.py +++ b/providers/src/airflow/providers/google/cloud/links/dataplex.py @@ -30,8 +30,10 @@ DATAPLEX_TASK_LINK = DATAPLEX_BASE_LINK + "/{lake_id}.{task_id};location={region}/jobs?project={project_id}" DATAPLEX_TASKS_LINK = DATAPLEX_BASE_LINK + "?project={project_id}&qLake={lake_id}.{region}" -DATAPLEX_LAKE_LINK = ( - "https://console.cloud.google.com/dataplex/lakes/{lake_id};location={region}?project={project_id}" +DATAPLEX_LAKE_LINK = "/dataplex/lakes/{lake_id};location={region}?project={project_id}" +DATAPLEX_CATALOG_ENTRY_GROUPS_LINK = "/dataplex/catalog/entry-groups?project={project_id}" +DATAPLEX_CATALOG_ENTRY_GROUP_LINK = ( + "/dataplex/projects/{project_id}/locations/{location}/entryGroups/{entry_group_id}?project={project_id}" ) @@ -103,3 +105,48 @@ def persist( "project_id": task_instance.project_id, }, ) + + +class DataplexCatalogEntryGroupLink(BaseGoogleLink): + """Helper class for constructing Dataplex Catalog EntryGroup link.""" + + name = "Dataplex Catalog EntryGroup" + key = "dataplex_catalog_entry_group_key" + format_str = DATAPLEX_CATALOG_ENTRY_GROUP_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=DataplexCatalogEntryGroupLink.key, + value={ + "entry_group_id": task_instance.entry_group_id, + "location": task_instance.location, + "project_id": task_instance.project_id, + }, + ) + + +class DataplexCatalogEntryGroupsLink(BaseGoogleLink): + """Helper class for constructing Dataplex Catalog EntryGroups link.""" + + name = "Dataplex Catalog EntryGroups" + key = "dataplex_catalog_entry_groups_key" + format_str = DATAPLEX_CATALOG_ENTRY_GROUPS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=DataplexCatalogEntryGroupsLink.key, + value={ + "location": task_instance.location, + "project_id": task_instance.project_id, + }, + ) diff --git a/providers/src/airflow/providers/google/cloud/operators/dataplex.py b/providers/src/airflow/providers/google/cloud/operators/dataplex.py index 8f7a0d694b9f3..33874063955ff 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataplex.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataplex.py @@ -20,8 +20,11 @@ import time from collections.abc import Sequence +from functools import cached_property from typing import TYPE_CHECKING, Any +from google.protobuf.json_format import MessageToDict + from airflow.exceptions import AirflowException from airflow.providers.google.cloud.triggers.dataplex import ( DataplexDataProfileJobTrigger, @@ -33,15 +36,26 @@ from airflow.utils.context 
import Context -from google.api_core.exceptions import AlreadyExists, GoogleAPICallError +from google.api_core.exceptions import AlreadyExists, GoogleAPICallError, NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry, exponential_sleep_generator -from google.cloud.dataplex_v1.types import Asset, DataScan, DataScanJob, Lake, Task, Zone +from google.cloud.dataplex_v1.types import ( + Asset, + DataScan, + DataScanJob, + EntryGroup, + Lake, + ListEntryGroupsResponse, + Task, + Zone, +) from googleapiclient.errors import HttpError from airflow.configuration import conf from airflow.providers.google.cloud.hooks.dataplex import AirflowDataQualityScanException, DataplexHook from airflow.providers.google.cloud.links.dataplex import ( + DataplexCatalogEntryGroupLink, + DataplexCatalogEntryGroupsLink, DataplexLakeLink, DataplexTaskLink, DataplexTasksLink, @@ -2093,3 +2107,475 @@ def execute(self, context: Context): ) hook.wait_for_operation(timeout=self.timeout, operation=operation) self.log.info("Dataplex asset %s deleted successfully!", self.asset_id) + + +class DataplexCatalogBaseOperator(GoogleCloudBaseOperator): + """ + Base class for all Dataplex Catalog operators. + + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "project_id", + "location", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + location: str, + gcp_conn_id: str = "google_cloud_default", + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + @cached_property + def hook(self) -> DataplexHook: + return DataplexHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class DataplexCatalogCreateEntryGroupOperator(DataplexCatalogBaseOperator): + """ + Create an EntryGroup resource. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataplexCatalogCreateEntryGroupOperator` + + :param entry_group_id: Required. EntryGroup identifier. + :param entry_group_configuration: Required. EntryGroup configuration. + For more details please see API documentation: + https://cloud.google.com/dataplex/docs/reference/rest/v1/projects.locations.entryGroups#EntryGroup + :param validate_request: Optional. If set, performs request validation, but does not actually + execute the request. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entry_group_id", "entry_group_configuration"} | set(DataplexCatalogBaseOperator.template_fields) + ) + operator_extra_links = (DataplexCatalogEntryGroupLink(),) + + def __init__( + self, + entry_group_id: str, + entry_group_configuration: EntryGroup | dict, + validate_request: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.entry_group_id = entry_group_id + self.entry_group_configuration = entry_group_configuration + self.validate_request = validate_request + + def execute(self, context: Context): + DataplexCatalogEntryGroupLink.persist( + context=context, + task_instance=self, + ) + + if self.validate_request: + self.log.info("Validating a Create Dataplex Catalog EntryGroup request.") + else: + self.log.info("Creating a Dataplex Catalog EntryGroup.") + + try: + operation = self.hook.create_entry_group( + entry_group_id=self.entry_group_id, + entry_group_configuration=self.entry_group_configuration, + location=self.location, + project_id=self.project_id, + validate_only=self.validate_request, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + entry_group = self.hook.wait_for_operation(timeout=self.timeout, operation=operation) + except AlreadyExists: + entry_group = self.hook.get_entry_group( + entry_group_id=self.entry_group_id, + location=self.location, + project_id=self.project_id, + ) + self.log.info( + "Dataplex Catalog EntryGroup %s already exists.", + self.entry_group_id, + ) + result = EntryGroup.to_dict(entry_group) + return result + except Exception as ex: + raise AirflowException(ex) + else: + result = EntryGroup.to_dict(entry_group) if not self.validate_request else None + + if not self.validate_request: + self.log.info("Dataplex Catalog EntryGroup %s was successfully created.", self.entry_group_id) + return result + + +class DataplexCatalogGetEntryGroupOperator(DataplexCatalogBaseOperator): + """ + Get an EntryGroup resource. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataplexCatalogGetEntryGroupOperator` + + :param entry_group_id: Required. EntryGroup identifier. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entry_group_id"} | set(DataplexCatalogBaseOperator.template_fields) + ) + operator_extra_links = (DataplexCatalogEntryGroupLink(),) + + def __init__( + self, + entry_group_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.entry_group_id = entry_group_id + + def execute(self, context: Context): + DataplexCatalogEntryGroupLink.persist( + context=context, + task_instance=self, + ) + self.log.info( + "Retrieving Dataplex Catalog EntryGroup %s.", + self.entry_group_id, + ) + try: + entry_group = self.hook.get_entry_group( + entry_group_id=self.entry_group_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.info( + "Dataplex Catalog EntryGroup %s not found.", + self.entry_group_id, + ) + raise AirflowException(NotFound) + except Exception as ex: + raise AirflowException(ex) + + return EntryGroup.to_dict(entry_group) + + +class DataplexCatalogDeleteEntryGroupOperator(DataplexCatalogBaseOperator): + """ + Delete an EntryGroup resource. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataplexCatalogDeleteEntryGroupOperator` + + :param entry_group_id: Required. EntryGroup identifier. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entry_group_id"} | set(DataplexCatalogBaseOperator.template_fields) + ) + + def __init__( + self, + entry_group_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.entry_group_id = entry_group_id + + def execute(self, context: Context): + self.log.info( + "Deleting Dataplex Catalog EntryGroup %s.", + self.entry_group_id, + ) + try: + operation = self.hook.delete_entry_group( + entry_group_id=self.entry_group_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.hook.wait_for_operation(timeout=self.timeout, operation=operation) + + except NotFound: + self.log.info( + "Dataplex Catalog EntryGroup %s not found.", + self.entry_group_id, + ) + raise AirflowException(NotFound) + except Exception as ex: + raise AirflowException(ex) + return None + + +class DataplexCatalogListEntryGroupsOperator(DataplexCatalogBaseOperator): + """ + List EntryGroup resources. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataplexCatalogListEntryGroupsOperator` + + :param filter_by: Optional. Filter to apply on the list results. + :param order_by: Optional. Fields to order the results by. + :param page_size: Optional. Maximum number of EntryGroups to return on the page. + :param page_token: Optional. Token to retrieve the next page of results. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple(DataplexCatalogBaseOperator.template_fields) + operator_extra_links = (DataplexCatalogEntryGroupsLink(),) + + def __init__( + self, + page_size: int | None = None, + page_token: str | None = None, + filter_by: str | None = None, + order_by: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.page_size = page_size + self.page_token = page_token + self.filter_by = filter_by + self.order_by = order_by + + def execute(self, context: Context): + DataplexCatalogEntryGroupsLink.persist( + context=context, + task_instance=self, + ) + self.log.info( + "Listing Dataplex Catalog EntryGroup from location %s.", + self.location, + ) + try: + entry_group_on_page = self.hook.list_entry_groups( + location=self.location, + project_id=self.project_id, + page_size=self.page_size, + page_token=self.page_token, + filter_by=self.filter_by, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("EntryGroup on page: %s", entry_group_on_page) + self.xcom_push( + context=context, + key="entry_group_page", + value=ListEntryGroupsResponse.to_dict(entry_group_on_page._response), + ) + except Exception as ex: + raise AirflowException(ex) + + # Constructing list to return EntryGroups in readable format + entry_groups_list = [ + MessageToDict(entry_group._pb, preserving_proto_field_name=True) + for entry_group in next(iter(entry_group_on_page.pages)).entry_groups + ] + return entry_groups_list + + +class DataplexCatalogUpdateEntryGroupOperator(DataplexCatalogBaseOperator): + """ + Update an EntryGroup resource. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataplexCatalogUpdateEntryGroupOperator` + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param location: Required. The ID of the Google Cloud region that the task belongs to. + :param update_mask: Optional. Names of fields whose values to overwrite on an entry group. + If this parameter is absent or empty, all modifiable fields are overwritten. If such + fields are non-required and omitted in the request body, their values are emptied. + :param entry_group_id: Required. ID of the EntryGroup to update. + :param entry_group_configuration: Required. The updated configuration body of the EntryGroup. + For more details please see API documentation: + https://cloud.google.com/dataplex/docs/reference/rest/v1/projects.locations.entryGroups#EntryGroup + :param validate_only: Optional. The service validates the request without performing any mutations. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param gcp_conn_id: Optional. The connection ID to use when fetching connection info. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"entry_group_id", "entry_group_configuration", "update_mask"} + | set(DataplexCatalogBaseOperator.template_fields) + ) + operator_extra_links = (DataplexCatalogEntryGroupLink(),) + + def __init__( + self, + entry_group_id: str, + entry_group_configuration: dict | EntryGroup, + update_mask: list[str] | FieldMask | None = None, + validate_request: bool | None = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.entry_group_id = entry_group_id + self.entry_group_configuration = entry_group_configuration + self.update_mask = update_mask + self.validate_request = validate_request + + def execute(self, context: Context): + DataplexCatalogEntryGroupLink.persist( + context=context, + task_instance=self, + ) + + if self.validate_request: + self.log.info("Validating an Update Dataplex Catalog EntryGroup request.") + else: + self.log.info( + "Updating Dataplex Catalog EntryGroup %s.", + self.entry_group_id, + ) + try: + operation = self.hook.update_entry_group( + location=self.location, + project_id=self.project_id, + entry_group_id=self.entry_group_id, + entry_group_configuration=self.entry_group_configuration, + update_mask=self.update_mask, + validate_only=self.validate_request, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + entry_group = self.hook.wait_for_operation(timeout=self.timeout, operation=operation) + + except NotFound as ex: + self.log.info("Specified EntryGroup was not found.") + raise AirflowException(ex) + except Exception as exc: + raise AirflowException(exc) + else: + result = EntryGroup.to_dict(entry_group) if not self.validate_request else None + + if not self.validate_request: + self.log.info("EntryGroup %s was successfully updated.", self.entry_group_id) + return result diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 772c8babdeeff..97277806b855c 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -131,7 +131,7 @@ dependencies: - google-cloud-datacatalog>=3.23.0 - google-cloud-dataflow-client>=0.8.6 - google-cloud-dataform>=0.5.0 - - google-cloud-dataplex>=1.10.0 + - google-cloud-dataplex>=2.6.0 - google-cloud-dataproc>=5.12.0 - google-cloud-dataproc-metastore>=1.12.0 - google-cloud-dlp>=3.12.0 @@ -1203,6 +1203,8 @@ extra-links: - airflow.providers.google.cloud.links.dataplex.DataplexTaskLink - airflow.providers.google.cloud.links.dataplex.DataplexTasksLink - airflow.providers.google.cloud.links.dataplex.DataplexLakeLink + - airflow.providers.google.cloud.links.dataplex.DataplexCatalogEntryGroupLink + - airflow.providers.google.cloud.links.dataplex.DataplexCatalogEntryGroupsLink - airflow.providers.google.cloud.links.bigquery.BigQueryDatasetLink - airflow.providers.google.cloud.links.bigquery.BigQueryTableLink - airflow.providers.google.cloud.links.bigquery.BigQueryJobDetailLink diff --git a/providers/tests/google/cloud/hooks/test_dataplex.py b/providers/tests/google/cloud/hooks/test_dataplex.py index 8f1f5d9866619..4a4f550eca67b 100644 --- a/providers/tests/google/cloud/hooks/test_dataplex.py +++ b/providers/tests/google/cloud/hooks/test_dataplex.py @@ -19,6 +19,7 @@ from unittest 
import mock

 from google.api_core.gapic_v1.method import DEFAULT
+from google.protobuf.field_mask_pb2 import FieldMask

 from airflow.providers.google.cloud.operators.dataplex import DataplexHook
@@ -30,6 +31,9 @@
 DATAPLEX_HOOK_DS_CLIENT = (
     "airflow.providers.google.cloud.hooks.dataplex.DataplexHook.get_dataplex_data_scan_client"
 )
+DATAPLEX_CATALOG_HOOK_CLIENT = (
+    "airflow.providers.google.cloud.hooks.dataplex.DataplexHook.get_dataplex_catalog_client"
+)
 
 PROJECT_ID = "project-id"
 REGION = "region"
@@ -44,12 +48,21 @@
 ASSET_ID = "test_asset_id"
 ZONE_ID = "test_zone_id"
 JOB_ID = "job_id"
+
+LOCATION = "us-central1"
+ENTRY_GROUP_ID = "entry-group-id"
+ENTRY_GROUP_BODY = {"description": "Some descr"}
+ENTRY_GROUP_UPDATED_BODY = {"description": "Some new descr"}
+UPDATE_MASK = ["description"]
+
+COMMON_PARENT = f"projects/{PROJECT_ID}/locations/{LOCATION}"
 DATA_SCAN_NAME = f"projects/{PROJECT_ID}/locations/{REGION}/dataScans/{DATA_SCAN_ID}"
 DATA_SCAN_JOB_NAME = f"projects/{PROJECT_ID}/locations/{REGION}/dataScans/{DATA_SCAN_ID}/jobs/{JOB_ID}"
 ZONE_NAME = f"projects/{PROJECT_ID}/locations/{REGION}/lakes/{LAKE_ID}"
 ZONE_PARENT = f"projects/{PROJECT_ID}/locations/{REGION}/lakes/{LAKE_ID}/zones/{ZONE_ID}"
 ASSET_PARENT = f"projects/{PROJECT_ID}/locations/{REGION}/lakes/{LAKE_ID}/zones/{ZONE_ID}/assets/{ASSET_ID}"
 DATASCAN_PARENT = f"projects/{PROJECT_ID}/locations/{REGION}"
+ENTRY_GROUP_PARENT = f"projects/{PROJECT_ID}/locations/{LOCATION}/entryGroups/{ENTRY_GROUP_ID}"
 
 
 class TestDataplexHook:
@@ -311,3 +324,104 @@ def test_get_data_scan(self, mock_client):
             timeout=None,
             metadata=(),
         )
+
+    @mock.patch(DATAPLEX_CATALOG_HOOK_CLIENT)
+    def test_create_entry_group(self, mock_client):
+        mock_common_location_path = mock_client.return_value.common_location_path
+        mock_common_location_path.return_value = COMMON_PARENT
+        self.hook.create_entry_group(
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            entry_group_id=ENTRY_GROUP_ID,
+            entry_group_configuration=ENTRY_GROUP_BODY,
+            validate_only=False,
+        )
+        mock_client.return_value.create_entry_group.assert_called_once_with(
+            request=dict(
+                parent=COMMON_PARENT,
+                entry_group_id=ENTRY_GROUP_ID,
+                entry_group=ENTRY_GROUP_BODY,
+                validate_only=False,
+            ),
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+
+    @mock.patch(DATAPLEX_CATALOG_HOOK_CLIENT)
+    def test_delete_entry_group(self, mock_client):
+        mock_entry_group_path = mock_client.return_value.entry_group_path
+        mock_entry_group_path.return_value = ENTRY_GROUP_PARENT
+        self.hook.delete_entry_group(project_id=PROJECT_ID, location=LOCATION, entry_group_id=ENTRY_GROUP_ID)
+
+        mock_client.return_value.delete_entry_group.assert_called_once_with(
+            request=dict(
+                name=ENTRY_GROUP_PARENT,
+            ),
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+
+    @mock.patch(DATAPLEX_CATALOG_HOOK_CLIENT)
+    def test_list_entry_groups(self, mock_client):
+        mock_common_location_path = mock_client.return_value.common_location_path
+        mock_common_location_path.return_value = COMMON_PARENT
+        self.hook.list_entry_groups(
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            order_by="name",
+            page_size=2,
+            filter_by="'description' = 'Some descr'",
+        )
+        mock_client.return_value.list_entry_groups.assert_called_once_with(
+            request=dict(
+                parent=COMMON_PARENT,
+                page_size=2,
+                page_token=None,
+                filter="'description' = 'Some descr'",
+                order_by="name",
+            ),
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+
+    @mock.patch(DATAPLEX_CATALOG_HOOK_CLIENT)
+    def test_get_entry_group(self, mock_client):
+        mock_entry_group_path = mock_client.return_value.entry_group_path
+        mock_entry_group_path.return_value = ENTRY_GROUP_PARENT
+        self.hook.get_entry_group(project_id=PROJECT_ID, location=LOCATION, entry_group_id=ENTRY_GROUP_ID)
+
+        mock_client.return_value.get_entry_group.assert_called_once_with(
+            request=dict(
+                name=ENTRY_GROUP_PARENT,
+            ),
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+
+    @mock.patch(DATAPLEX_CATALOG_HOOK_CLIENT)
+    def test_update_entry_group(self, mock_client):
+        mock_entry_group_path = mock_client.return_value.entry_group_path
+        mock_entry_group_path.return_value = ENTRY_GROUP_PARENT
+        self.hook.update_entry_group(
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            entry_group_id=ENTRY_GROUP_ID,
+            entry_group_configuration=ENTRY_GROUP_UPDATED_BODY,
+            update_mask=UPDATE_MASK,
+            validate_only=False,
+        )
+
+        mock_client.return_value.update_entry_group.assert_called_once_with(
+            request=dict(
+                entry_group={**ENTRY_GROUP_UPDATED_BODY, "name": ENTRY_GROUP_PARENT},
+                update_mask=FieldMask(paths=UPDATE_MASK),
+                validate_only=False,
+            ),
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
diff --git a/providers/tests/google/cloud/links/test_dataplex.py b/providers/tests/google/cloud/links/test_dataplex.py
new file mode 100644
index 0000000000000..05661c84bd3ee
--- /dev/null
+++ b/providers/tests/google/cloud/links/test_dataplex.py
@@ -0,0 +1,168 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations + +import pytest + +from airflow.providers.google.cloud.links.dataplex import ( + DataplexCatalogEntryGroupLink, + DataplexCatalogEntryGroupsLink, + DataplexLakeLink, + DataplexTaskLink, + DataplexTasksLink, +) +from airflow.providers.google.cloud.operators.dataplex import ( + DataplexCatalogCreateEntryGroupOperator, + DataplexCatalogGetEntryGroupOperator, + DataplexCreateLakeOperator, + DataplexCreateTaskOperator, + DataplexListTasksOperator, +) + +TEST_LOCATION = "test-location" +TEST_PROJECT_ID = "test-project-id" +TEST_ENTRY_GROUP_ID = "test-entry-group-id" +TEST_ENTRY_GROUP_ID_BODY = {"description": "some description"} +TEST_ENTRY_GROUPS_ID = "test-entry-groups-id" +TEST_TASK_ID = "test-task-id" +TEST_TASKS_ID = "test-tasks-id" +TEST_LAKE_ID = "test-lake-id" +TEST_LAKE_BODY = {"name": "some_name"} + +DATAPLEX_BASE_LINK = "https://console.cloud.google.com/dataplex/" +EXPECTED_DATAPLEX_CATALOG_ENTRY_GROUP_LINK = ( + DATAPLEX_BASE_LINK + + f"projects/{TEST_PROJECT_ID}/locations/{TEST_LOCATION}/entryGroups/{TEST_ENTRY_GROUP_ID}?project={TEST_PROJECT_ID}" +) +EXPECTED_DATAPLEX_CATALOG_ENTRY_GROUPS_LINK = ( + DATAPLEX_BASE_LINK + f"catalog/entry-groups?project={TEST_PROJECT_ID}" +) +DATAPLEX_LAKE_LINK = ( + DATAPLEX_BASE_LINK + f"lakes/{TEST_LAKE_ID};location={TEST_LOCATION}?project={TEST_PROJECT_ID}" +) +EXPECTED_DATAPLEX_TASK_LINK = ( + DATAPLEX_BASE_LINK + + f"process/tasks/{TEST_LAKE_ID}.{TEST_TASK_ID};location={TEST_LOCATION}/jobs?project={TEST_PROJECT_ID}" +) +EXPECTED_DATAPLEX_TASKS_LINK = ( + DATAPLEX_BASE_LINK + f"process/tasks?project={TEST_PROJECT_ID}&qLake={TEST_LAKE_ID}.{TEST_LOCATION}" +) + + +class TestDataplexTaskLink: + @pytest.mark.db_test + def test_get_link(self, create_task_instance_of_operator, session): + expected_url = EXPECTED_DATAPLEX_TASK_LINK + link = DataplexTaskLink() + ti = create_task_instance_of_operator( + DataplexCreateTaskOperator, + dag_id="test_link_dag", + task_id="test_link_task", + region=TEST_LOCATION, + lake_id=TEST_LAKE_ID, + project_id=TEST_PROJECT_ID, + body=TEST_LAKE_BODY, + dataplex_task_id=TEST_TASK_ID, + ) + session.add(ti) + session.commit() + link.persist(context={"ti": ti}, task_instance=ti.task) + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + assert actual_url == expected_url + + +class TestDataplexTasksLink: + @pytest.mark.db_test + def test_get_link(self, create_task_instance_of_operator, session): + expected_url = EXPECTED_DATAPLEX_TASKS_LINK + link = DataplexTasksLink() + ti = create_task_instance_of_operator( + DataplexListTasksOperator, + dag_id="test_link_dag", + task_id="test_link_task", + region=TEST_LOCATION, + lake_id=TEST_LAKE_ID, + project_id=TEST_PROJECT_ID, + ) + session.add(ti) + session.commit() + link.persist(context={"ti": ti}, task_instance=ti.task) + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + assert actual_url == expected_url + + +class TestDataplexLakeLink: + @pytest.mark.db_test + def test_get_link(self, create_task_instance_of_operator, session): + expected_url = DATAPLEX_LAKE_LINK + link = DataplexLakeLink() + ti = create_task_instance_of_operator( + DataplexCreateLakeOperator, + dag_id="test_link_dag", + task_id="test_link_task", + region=TEST_LOCATION, + lake_id=TEST_LAKE_ID, + project_id=TEST_PROJECT_ID, + body={}, + ) + session.add(ti) + session.commit() + link.persist(context={"ti": ti}, task_instance=ti.task) + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + assert actual_url == expected_url + + +class 
TestDataplexCatalogEntryGroupLink: + @pytest.mark.db_test + def test_get_link(self, create_task_instance_of_operator, session): + expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_GROUP_LINK + link = DataplexCatalogEntryGroupLink() + ti = create_task_instance_of_operator( + DataplexCatalogGetEntryGroupOperator, + dag_id="test_link_dag", + task_id="test_link_task", + location=TEST_LOCATION, + entry_group_id=TEST_ENTRY_GROUP_ID, + project_id=TEST_PROJECT_ID, + ) + session.add(ti) + session.commit() + link.persist(context={"ti": ti}, task_instance=ti.task) + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + assert actual_url == expected_url + + +class TestDataplexCatalogEntryGroupsLink: + @pytest.mark.db_test + def test_get_link(self, create_task_instance_of_operator, session): + expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_GROUPS_LINK + link = DataplexCatalogEntryGroupsLink() + ti = create_task_instance_of_operator( + DataplexCatalogCreateEntryGroupOperator, + dag_id="test_link_dag", + task_id="test_link_task", + location=TEST_LOCATION, + entry_group_id=TEST_ENTRY_GROUP_ID, + entry_group_configuration=TEST_ENTRY_GROUP_ID_BODY, + project_id=TEST_PROJECT_ID, + ) + session.add(ti) + session.commit() + link.persist(context={"ti": ti}, task_instance=ti.task) + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + assert actual_url == expected_url diff --git a/providers/tests/google/cloud/operators/test_dataplex.py b/providers/tests/google/cloud/operators/test_dataplex.py index 1eec9008e2c10..2aff961623bb3 100644 --- a/providers/tests/google/cloud/operators/test_dataplex.py +++ b/providers/tests/google/cloud/operators/test_dataplex.py @@ -20,9 +20,15 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.dataplex_v1.services.catalog_service.pagers import ListEntryGroupsPager +from google.cloud.dataplex_v1.types import ListEntryGroupsRequest, ListEntryGroupsResponse from airflow.exceptions import TaskDeferred from airflow.providers.google.cloud.operators.dataplex import ( + DataplexCatalogCreateEntryGroupOperator, + DataplexCatalogDeleteEntryGroupOperator, + DataplexCatalogGetEntryGroupOperator, + DataplexCatalogListEntryGroupsOperator, DataplexCreateAssetOperator, DataplexCreateLakeOperator, DataplexCreateOrUpdateDataProfileScanOperator, @@ -51,6 +57,7 @@ DATASCANJOB_STR = "airflow.providers.google.cloud.operators.dataplex.DataScanJob" ZONE_STR = "airflow.providers.google.cloud.operators.dataplex.Zone" ASSET_STR = "airflow.providers.google.cloud.operators.dataplex.Asset" +ENTRY_GROUP_STR = "airflow.providers.google.cloud.operators.dataplex.EntryGroup" PROJECT_ID = "project-id" REGION = "region" @@ -72,6 +79,7 @@ ASSET_ID = "test_asset_id" ZONE_ID = "test_zone_id" JOB_ID = "test_job_id" +ENTRY_GROUP_NAME = "test_entry_group" class TestDataplexCreateTaskOperator: @@ -734,3 +742,138 @@ def test_execute(self, hook_mock): timeout=None, metadata=(), ) + + +class TestDataplexCatalogCreateEntryGroupOperator: + @mock.patch(ENTRY_GROUP_STR) + @mock.patch(HOOK_STR) + def test_execute(self, hook_mock, entry_group_mock): + op = DataplexCatalogCreateEntryGroupOperator( + task_id="create_task", + project_id=PROJECT_ID, + location=REGION, + entry_group_id=ENTRY_GROUP_NAME, + entry_group_configuration=BODY, + validate_request=None, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + entry_group_mock.return_value.to_dict.return_value = None + hook_mock.return_value.wait_for_operation.return_value = None + 
op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + hook_mock.return_value.create_entry_group.assert_called_once_with( + entry_group_id=ENTRY_GROUP_NAME, + entry_group_configuration=BODY, + location=REGION, + project_id=PROJECT_ID, + validate_only=None, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + +class TestDataplexCatalogGetEntryGroupOperator: + @mock.patch(ENTRY_GROUP_STR) + @mock.patch(HOOK_STR) + def test_execute(self, hook_mock, entry_group_mock): + op = DataplexCatalogGetEntryGroupOperator( + project_id=PROJECT_ID, + location=REGION, + entry_group_id=ENTRY_GROUP_NAME, + task_id="get_task", + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + op.execute(context=mock.MagicMock()) + entry_group_mock.return_value.to_dict.return_value = None + hook_mock.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + hook_mock.return_value.get_entry_group.assert_called_once_with( + project_id=PROJECT_ID, + location=REGION, + entry_group_id=ENTRY_GROUP_NAME, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + +class TestDataplexCatalogDeleteEntryGroupOperator: + @mock.patch(HOOK_STR) + def test_execute(self, hook_mock): + op = DataplexCatalogDeleteEntryGroupOperator( + project_id=PROJECT_ID, + location=REGION, + entry_group_id=ENTRY_GROUP_NAME, + task_id="delete_task", + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + hook_mock.return_value.wait_for_operation.return_value = None + op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + hook_mock.return_value.delete_entry_group.assert_called_once_with( + project_id=PROJECT_ID, + location=REGION, + entry_group_id=ENTRY_GROUP_NAME, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + +class TestDataplexCatalogListEntryGroupsOperator: + @mock.patch(ENTRY_GROUP_STR) + @mock.patch(HOOK_STR) + def test_execute(self, hook_mock, entry_group_mock): + op = DataplexCatalogListEntryGroupsOperator( + project_id=PROJECT_ID, + location=REGION, + task_id="list_task", + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + hook_mock.return_value.list_entry_groups.return_value = ListEntryGroupsPager( + response=( + ListEntryGroupsResponse( + entry_groups=[ + { + "name": "aaa", + "description": "Test Entry Group 1", + "display_name": "Entry Group One", + } + ] + ) + ), + method=mock.MagicMock(), + request=ListEntryGroupsRequest(parent=""), + ) + + entry_group_mock.return_value.to_dict.return_value = None + op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + hook_mock.return_value.list_entry_groups.assert_called_once_with( + project_id=PROJECT_ID, + location=REGION, + page_size=None, + page_token=None, + filter_by=None, + order_by=None, + retry=DEFAULT, + timeout=None, + metadata=(), + ) diff --git a/providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py b/providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py new file mode 100644 index 0000000000000..8eec8a317d640 --- /dev/null +++ b/providers/tests/system/google/cloud/dataplex/example_dataplex_catalog.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use Dataplex Catalog. +""" + +from __future__ import annotations + +import datetime +import os + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.dataplex import ( + DataplexCatalogCreateEntryGroupOperator, + DataplexCatalogDeleteEntryGroupOperator, + DataplexCatalogGetEntryGroupOperator, + DataplexCatalogListEntryGroupsOperator, + DataplexCatalogUpdateEntryGroupOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +from providers.tests.system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +DAG_ID = "dataplex_catalog" +GCP_LOCATION = "us-central1" + +ENTRY_GROUP_NAME = f"{DAG_ID}_entry_group_{ENV_ID}".replace("_", "-") +# [START howto_dataplex_entry_group_configuration] +ENTRY_GROUP_BODY = {"display_name": "Display Name", "description": "Some description"} +# [END howto_dataplex_entry_group_configuration] + +with DAG( + DAG_ID, + start_date=datetime.datetime(2021, 1, 1), + schedule="@once", + tags=["example", "dataplex_catalog"], +) as dag: + # [START howto_operator_dataplex_catalog_create_entry_group] + create_entry_group = DataplexCatalogCreateEntryGroupOperator( + task_id="create_entry_group", + project_id=PROJECT_ID, + location=GCP_LOCATION, + entry_group_id=ENTRY_GROUP_NAME, + entry_group_configuration=ENTRY_GROUP_BODY, + validate_request=False, + ) + # [END howto_operator_dataplex_catalog_create_entry_group] + + # [START howto_operator_dataplex_catalog_get_entry_group] + get_entry_group = DataplexCatalogGetEntryGroupOperator( + task_id="get_entry_group", + project_id=PROJECT_ID, + location=GCP_LOCATION, + entry_group_id=ENTRY_GROUP_NAME, + ) + # [END howto_operator_dataplex_catalog_get_entry_group] + + # [START howto_operator_dataplex_catalog_list_entry_groups] + list_entry_group = DataplexCatalogListEntryGroupsOperator( + task_id="list_entry_group", + project_id=PROJECT_ID, + location=GCP_LOCATION, + order_by="name", + filter_by='display_name = "Display Name"', + ) + # [END howto_operator_dataplex_catalog_list_entry_groups] + + # [START howto_operator_dataplex_catalog_update_entry_group] + update_entry_group = DataplexCatalogUpdateEntryGroupOperator( + task_id="update_entry_group", + project_id=PROJECT_ID, + location=GCP_LOCATION, + entry_group_id=ENTRY_GROUP_NAME, + entry_group_configuration={"display_name": "Updated Display Name"}, + update_mask=["display_name"], + ) + # [END howto_operator_dataplex_catalog_update_entry_group] + + # [START howto_operator_dataplex_catalog_delete_entry_group] + delete_entry_group = DataplexCatalogDeleteEntryGroupOperator( + task_id="delete_entry_group", + project_id=PROJECT_ID, + location=GCP_LOCATION, + 
entry_group_id=ENTRY_GROUP_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END howto_operator_dataplex_catalog_delete_entry_group] + + create_entry_group >> get_entry_group >> list_entry_group >> update_entry_group >> delete_entry_group + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index f12b3ad6a6684..b894acdb3a86b 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -116,7 +116,6 @@ def test_providers_modules_should_have_tests(self): "providers/tests/google/cloud/links/test_dataflow.py", "providers/tests/google/cloud/links/test_dataform.py", "providers/tests/google/cloud/links/test_datafusion.py", - "providers/tests/google/cloud/links/test_dataplex.py", "providers/tests/google/cloud/links/test_dataprep.py", "providers/tests/google/cloud/links/test_dataproc.py", "providers/tests/google/cloud/links/test_datastore.py", @@ -396,6 +395,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest "airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator", "airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator", "airflow.providers.google.cloud.operators.dataproc._DataprocStartStopClusterBaseOperator", + "airflow.providers.google.cloud.operators.dataplex.DataplexCatalogBaseOperator", "airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator", "airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator", "airflow.providers.google.marketing_platform.operators.search_ads._GoogleSearchAdsBaseOperator",
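
A minimal usage sketch (outside the patch above): ``DataplexCatalogListEntryGroupsOperator`` only
materializes the first page of results via ``next(iter(entry_group_on_page.pages))``. When every page
is needed, the ``ListEntryGroupsPager`` returned by the new ``DataplexHook.list_entry_groups()`` can be
iterated directly, since GAPIC pagers fetch follow-up pages lazily. The project and location values
below are hypothetical placeholders.

    from airflow.providers.google.cloud.hooks.dataplex import DataplexHook
    from google.cloud.dataplex_v1.types import EntryGroup

    hook = DataplexHook(gcp_conn_id="google_cloud_default")
    pager = hook.list_entry_groups(
        project_id="my-project",  # hypothetical project ID
        location="us-central1",   # hypothetical location
        page_size=50,
    )
    # Iterating the pager yields EntryGroup messages across all pages, not just the first one.
    all_entry_groups = [EntryGroup.to_dict(entry_group) for entry_group in pager]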
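A minimal sketch (outside the patch) of how ``DataplexHook.update_entry_group()`` treats a plain-list
``update_mask`` before calling the API: the list is wrapped in a protobuf ``FieldMask`` and the fully
qualified resource name is injected into the request body, so only the named fields are overwritten
server-side. The identifiers below are hypothetical placeholders.

    from google.protobuf.field_mask_pb2 import FieldMask

    entry_group_configuration = {"description": "new description"}  # fields to change
    update_mask = ["description"]  # only these fields are overwritten

    entry_group_body = dict(entry_group_configuration)
    entry_group_body["name"] = "projects/my-project/locations/us-central1/entryGroups/my-group"
    field_mask = FieldMask(paths=update_mask) if isinstance(update_mask, list) else update_mask
    assert list(field_mask.paths) == ["description"]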