diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index 17d1638559860..cf4cf3c0ce657 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -233,6 +233,8 @@ def __init__(
         self.source_object_required = source_object_required
         self.exact_match = exact_match
         self.match_glob = match_glob
+        self.resolved_source_objects: set[str] = set()
+        self.resolved_target_objects: set[str] = set()

     def execute(self, context: Context):

@@ -540,7 +542,34 @@ def _copy_single_object(self, hook, source_object, destination_object):
             destination_object,
         )

+        self.resolved_source_objects.add(source_object)
+        if not destination_object:
+            self.resolved_target_objects.add(source_object)
+        else:
+            self.resolved_target_objects.add(destination_object)
+
         hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object)

         if self.move_object:
             hook.delete(self.source_bucket, source_object)
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Implementing _on_complete because the execute method preprocesses its inputs.
+        This means we don't have to normalize self.source_object and self.source_objects,
+        the destination bucket and so on.
+        """
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(namespace=f"gs://{self.source_bucket}", name=source)
+                for source in sorted(self.resolved_source_objects)
+            ],
+            outputs=[
+                Dataset(namespace=f"gs://{self.destination_bucket}", name=target)
+                for target in sorted(self.resolved_target_objects)
+            ],
+        )
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index de84f0f63f257..f316d6acf617d 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -83,6 +83,7 @@ def get_operator_classnames(cls) -> list[str]:
         return []

     def extract(self) -> OperatorLineage | None:
+        # OpenLineage methods are optional - if there's no method, return None
         try:
             return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start)  # type: ignore
         except AttributeError:
@@ -100,7 +101,15 @@ def extract_on_complete(self, task_instance) -> OperatorLineage | None:

     def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
         try:
-            facets = get_facets_method(*args)
+            facets: OperatorLineage = get_facets_method(*args)
+            # "Rewrite" the OperatorLineage to safeguard against a different version of the same
+            # class that used to live in the openlineage-airflow package outside the Airflow repo.
+            return OperatorLineage(
+                inputs=facets.inputs,
+                outputs=facets.outputs,
+                run_facets=facets.run_facets,
+                job_facets=facets.job_facets,
+            )
         except ImportError:
             self.log.exception(
                 "OpenLineage provider method failed to import OpenLineage integration. "
@@ -108,11 +117,4 @@ def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage |
             )
         except Exception:
             self.log.exception("OpenLineage provider method failed to extract data from provider. ")
") - else: - return OperatorLineage( - inputs=facets.inputs, - outputs=facets.outputs, - run_facets=facets.run_facets, - job_facets=facets.job_facets, - ) return None diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 7437642e4e9b6..6b953edd7f10e 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -539,7 +539,7 @@ def test_expected_output_full_tests_needed( { "affected-providers-list-as-string": "amazon apache.beam apache.cassandra cncf.kubernetes " "common.sql facebook google hashicorp microsoft.azure microsoft.mssql " - "mysql oracle postgres presto salesforce sftp ssh trino", + "mysql openlineage oracle postgres presto salesforce sftp ssh trino", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "needs-helm-tests": "false", @@ -564,8 +564,8 @@ def test_expected_output_full_tests_needed( { "affected-providers-list-as-string": "amazon apache.beam apache.cassandra " "cncf.kubernetes common.sql facebook google " - "hashicorp microsoft.azure microsoft.mssql mysql oracle postgres presto " - "salesforce sftp ssh trino", + "hashicorp microsoft.azure microsoft.mssql mysql openlineage oracle postgres " + "presto salesforce sftp ssh trino", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "image-build": "true", @@ -666,7 +666,7 @@ def test_expected_output_pull_request_v2_3( "affected-providers-list-as-string": "amazon apache.beam apache.cassandra " "cncf.kubernetes common.sql " "facebook google hashicorp microsoft.azure microsoft.mssql mysql " - "oracle postgres presto salesforce sftp ssh trino", + "openlineage oracle postgres presto salesforce sftp ssh trino", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "image-build": "true", @@ -685,6 +685,7 @@ def test_expected_output_pull_request_v2_3( "--package-filter apache-airflow-providers-microsoft-azure " "--package-filter apache-airflow-providers-microsoft-mssql " "--package-filter apache-airflow-providers-mysql " + "--package-filter apache-airflow-providers-openlineage " "--package-filter apache-airflow-providers-oracle " "--package-filter apache-airflow-providers-postgres " "--package-filter apache-airflow-providers-presto " @@ -697,7 +698,7 @@ def test_expected_output_pull_request_v2_3( "skip-provider-tests": "false", "parallel-test-types-list-as-string": "Providers[amazon] Always CLI " "Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.sql,facebook," - "hashicorp,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto," + "hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto," "salesforce,sftp,ssh,trino] Providers[google]", }, id="CLI tests and Google-related provider tests should run if cli/chart files changed", @@ -965,6 +966,7 @@ def test_upgrade_to_newer_dependencies(files: tuple[str, ...], expected_outputs: "--package-filter apache-airflow-providers-microsoft-azure " "--package-filter apache-airflow-providers-microsoft-mssql " "--package-filter apache-airflow-providers-mysql " + "--package-filter apache-airflow-providers-openlineage " "--package-filter apache-airflow-providers-oracle " "--package-filter apache-airflow-providers-postgres " "--package-filter apache-airflow-providers-presto " diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 03d821eff15e9..09e75c6eae21a 100644 --- a/generated/provider_dependencies.json +++ 
@@ -453,6 +453,7 @@
       "microsoft.azure",
       "microsoft.mssql",
       "mysql",
+      "openlineage",
       "oracle",
       "postgres",
       "presto",
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
index d29a505ba3290..cf525235b14a1 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
@@ -21,6 +21,7 @@
 from unittest import mock

 import pytest
+from openlineage.client.run import Dataset

 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator
@@ -827,3 +828,75 @@ def test_copy_files_into_a_folder(
             for src, dst in zip(expected_source_objects, expected_destination_objects)
         ]
         mock_hook.return_value.rewrite.assert_has_calls(mock_calls)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
+    def test_execute_simple_reports_openlineage(self, mock_hook):
+        operator = GCSToGCSOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=SOURCE_OBJECTS_SINGLE_FILE[0],
+            destination_bucket=DESTINATION_BUCKET,
+        )
+
+        operator.execute(None)
+
+        lineage = operator.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0] == Dataset(
+            namespace=f"gs://{TEST_BUCKET}", name=SOURCE_OBJECTS_SINGLE_FILE[0]
+        )
+        assert lineage.outputs[0] == Dataset(
+            namespace=f"gs://{DESTINATION_BUCKET}", name=SOURCE_OBJECTS_SINGLE_FILE[0]
+        )
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
+    def test_execute_multiple_reports_openlineage(self, mock_hook):
+        operator = GCSToGCSOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_objects=SOURCE_OBJECTS_LIST,
+            destination_bucket=DESTINATION_BUCKET,
+            destination_object=DESTINATION_OBJECT,
+        )
+
+        operator.execute(None)
+
+        lineage = operator.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 3
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs == [
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name=SOURCE_OBJECTS_LIST[0]),
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name=SOURCE_OBJECTS_LIST[1]),
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name=SOURCE_OBJECTS_LIST[2]),
+        ]
+        assert lineage.outputs[0] == Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name=DESTINATION_OBJECT)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
+    def test_execute_wildcard_reports_openlineage(self, mock_hook):
+        mock_hook.return_value.list.return_value = [
+            "test_object1.txt",
+            "test_object2.txt",
+        ]
+
+        operator = GCSToGCSOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=SOURCE_OBJECT_WILDCARD_SUFFIX,
+            destination_bucket=DESTINATION_BUCKET,
+            destination_object=DESTINATION_OBJECT,
+        )
+
+        operator.execute(None)
+
+        lineage = operator.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 2
+        assert len(lineage.outputs) == 2
+        assert lineage.inputs == [
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name="test_object1.txt"),
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name="test_object2.txt"),
+        ]
+        assert lineage.outputs == [
+            Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name="foo/bar/1.txt"),
+            Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name="foo/bar/2.txt"),
+        ]
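
For reviewers, a minimal sketch of how the new bookkeeping surfaces as lineage, mirroring test_execute_simple_reports_openlineage above. GCSHook is mocked so nothing touches real GCS, and the bucket and object names are illustrative placeholders, not values from this change.

from unittest import mock

from openlineage.client.run import Dataset

from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator

with mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook"):
    op = GCSToGCSOperator(
        task_id="copy_file",
        source_bucket="example-src-bucket",
        source_object="data/file.txt",
        destination_bucket="example-dst-bucket",
    )
    # execute() resolves the concrete object names and fills
    # resolved_source_objects / resolved_target_objects as it copies.
    op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
assert lineage.inputs == [Dataset(namespace="gs://example-src-bucket", name="data/file.txt")]
# With no destination_object set, the target keeps the source object name.
assert lineage.outputs == [Dataset(namespace="gs://example-dst-bucket", name="data/file.txt")]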
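
The base.py half of the change means no per-operator wiring is needed: the provider's default extractor discovers the method by name at runtime. A rough sketch of that dispatch, assuming DefaultExtractor (from the base.py module touched above) is constructed with the operator instance; its constructor is not shown in this diff, so treat the exact call as an assumption.

from airflow.providers.openlineage.extractors.base import DefaultExtractor

# "op" is the operator from the previous sketch, after execute() has run.
extractor = DefaultExtractor(op)  # assumed constructor: BaseExtractor(operator)
# Dispatches to op.get_openlineage_facets_on_complete via _get_openlineage_facets,
# applying the ImportError/Exception safeguards, and returns OperatorLineage or None.
lineage = extractor.extract_on_complete(task_instance=None)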