From 15d3a71632b7e8c866a2b4f531f88e80b9bcf0f3 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Tue, 22 Oct 2024 16:10:24 +0200 Subject: [PATCH] feat: add Hook Level Lineage support for GCSHook Signed-off-by: Kacper Muda --- generated/provider_dependencies.json | 2 +- .../google/{datasets => assets}/__init__.py | 0 .../google/{datasets => assets}/bigquery.py | 0 .../airflow/providers/google/assets/gcs.py | 45 +++++ .../providers/google/cloud/hooks/gcs.py | 52 ++++- .../airflow/providers/google/provider.yaml | 14 +- .../tests/google/assets/test_bigquery.py | 2 +- providers/tests/google/assets/test_gcs.py | 74 +++++++ .../tests/google/cloud/hooks/test_gcs.py | 190 ++++++++++++++++++ 9 files changed, 372 insertions(+), 7 deletions(-) rename providers/src/airflow/providers/google/{datasets => assets}/__init__.py (100%) rename providers/src/airflow/providers/google/{datasets => assets}/bigquery.py (100%) create mode 100644 providers/src/airflow/providers/google/assets/gcs.py create mode 100644 providers/tests/google/assets/test_gcs.py diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 7bfbd76acad7f..6cbafaab1d91f 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -626,7 +626,7 @@ "google": { "deps": [ "PyOpenSSL>=23.0.0", - "apache-airflow-providers-common-compat>=1.1.0", + "apache-airflow-providers-common-compat>=1.2.1", "apache-airflow-providers-common-sql>=1.7.2", "apache-airflow>=2.8.0", "asgiref>=3.5.2", diff --git a/providers/src/airflow/providers/google/datasets/__init__.py b/providers/src/airflow/providers/google/assets/__init__.py similarity index 100% rename from providers/src/airflow/providers/google/datasets/__init__.py rename to providers/src/airflow/providers/google/assets/__init__.py diff --git a/providers/src/airflow/providers/google/datasets/bigquery.py b/providers/src/airflow/providers/google/assets/bigquery.py similarity index 100% rename from providers/src/airflow/providers/google/datasets/bigquery.py rename to providers/src/airflow/providers/google/assets/bigquery.py diff --git a/providers/src/airflow/providers/google/assets/gcs.py b/providers/src/airflow/providers/google/assets/gcs.py new file mode 100644 index 0000000000000..4df6995787ecc --- /dev/null +++ b/providers/src/airflow/providers/google/assets/gcs.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.common.compat.assets import Asset +from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url + +if TYPE_CHECKING: + from urllib.parse import SplitResult + + from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset + + +def create_asset(*, bucket: str, key: str, extra: dict | None = None) -> Asset: + return Asset(uri=f"gs://{bucket}/{key}", extra=extra) + + +def sanitize_uri(uri: SplitResult) -> SplitResult: + if not uri.netloc: + raise ValueError("URI format gs:// must contain a bucket name") + return uri + + +def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset: + """Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook.""" + from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset + + bucket, key = _parse_gcs_url(asset.uri) + return OpenLineageDataset(namespace=f"gs://{bucket}", name=key if key else "/") diff --git a/providers/src/airflow/providers/google/cloud/hooks/gcs.py b/providers/src/airflow/providers/google/cloud/hooks/gcs.py index fb48fcd190609..995418f183489 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/gcs.py +++ b/providers/src/airflow/providers/google/cloud/hooks/gcs.py @@ -43,6 +43,7 @@ from requests import Session from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector from airflow.providers.google.cloud.utils.helpers import normalize_directory_path from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import ( @@ -214,6 +215,16 @@ def copy( destination_object = source_bucket.copy_blob( # type: ignore[attr-defined] blob=source_object, destination_bucket=destination_bucket, new_name=destination_object ) + get_hook_lineage_collector().add_input_asset( + context=self, + scheme="gs", + asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined] + ) + get_hook_lineage_collector().add_output_asset( + context=self, + scheme="gs", + asset_kwargs={"bucket": destination_bucket.name, "key": destination_object.name}, # type: ignore[union-attr] + ) self.log.info( "Object %s in bucket %s copied to object %s in bucket %s", @@ -267,6 +278,16 @@ def rewrite( ).rewrite(source=source_object, token=token) self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten) + get_hook_lineage_collector().add_input_asset( + context=self, + scheme="gs", + asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined] + ) + get_hook_lineage_collector().add_output_asset( + context=self, + scheme="gs", + asset_kwargs={"bucket": destination_bucket.name, "key": destination_object}, # type: ignore[attr-defined] + ) self.log.info( "Object %s in bucket %s rewritten to object %s in bucket %s", source_object.name, # type: ignore[attr-defined] @@ -345,9 +366,18 @@ def download( if filename: blob.download_to_filename(filename, timeout=timeout) + get_hook_lineage_collector().add_input_asset( + context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name} + ) + get_hook_lineage_collector().add_output_asset( + context=self, scheme="file", asset_kwargs={"path": filename} + ) self.log.info("File downloaded to %s", filename) return filename else: + 
get_hook_lineage_collector().add_input_asset( + context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name} + ) return blob.download_as_bytes() except GoogleCloudError: @@ -555,6 +585,9 @@ def _call_with_retry(f: Callable[[], None]) -> None: _call_with_retry( partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout) ) + get_hook_lineage_collector().add_input_asset( + context=self, scheme="file", asset_kwargs={"path": filename} + ) if gzip: os.remove(filename) @@ -576,6 +609,10 @@ def _call_with_retry(f: Callable[[], None]) -> None: else: raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.") + get_hook_lineage_collector().add_output_asset( + context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name} + ) + def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool: """ Check for the existence of a file in Google Cloud Storage. @@ -691,6 +728,9 @@ def delete(self, bucket_name: str, object_name: str) -> None: bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name) blob.delete() + get_hook_lineage_collector().add_input_asset( + context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name} + ) self.log.info("Blob %s deleted.", object_name) @@ -1198,9 +1238,17 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec client = self.get_conn() bucket = client.bucket(bucket_name) destination_blob = bucket.blob(destination_object) - destination_blob.compose( - sources=[bucket.blob(blob_name=source_object) for source_object in source_objects] + source_blobs = [bucket.blob(blob_name=source_object) for source_object in source_objects] + destination_blob.compose(sources=source_blobs) + get_hook_lineage_collector().add_output_asset( + context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": destination_blob.name} ) + for single_source_blob in source_blobs: + get_hook_lineage_collector().add_input_asset( + context=self, + scheme="gs", + asset_kwargs={"bucket": bucket.name, "key": single_source_blob.name}, + ) self.log.info("Completed successfully.") diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 09e7b6643cb25..5f027c13de12e 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -97,7 +97,7 @@ versions: dependencies: - apache-airflow>=2.8.0 - - apache-airflow-providers-common-compat>=1.1.0 + - apache-airflow-providers-common-compat>=1.2.1 - apache-airflow-providers-common-sql>=1.7.2 - asgiref>=3.5.2 - dill>=0.2.3 @@ -777,7 +777,11 @@ asset-uris: - schemes: [gcp] handler: null - schemes: [bigquery] - handler: airflow.providers.google.datasets.bigquery.sanitize_uri + handler: airflow.providers.google.assets.bigquery.sanitize_uri + - schemes: [gs] + handler: airflow.providers.google.assets.gcs.sanitize_uri + factory: airflow.providers.google.assets.gcs.create_asset + to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage # dataset has been renamed to asset in Airflow 3.0 # This is kept for backward compatibility. 
@@ -785,7 +789,11 @@ dataset-uris: - schemes: [gcp] handler: null - schemes: [bigquery] - handler: airflow.providers.google.datasets.bigquery.sanitize_uri + handler: airflow.providers.google.assets.bigquery.sanitize_uri + - schemes: [gs] + handler: airflow.providers.google.assets.gcs.sanitize_uri + factory: airflow.providers.google.assets.gcs.create_asset + to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage hooks: - integration-name: Google Ads diff --git a/providers/tests/google/assets/test_bigquery.py b/providers/tests/google/assets/test_bigquery.py index 45da4ffb1eb71..b2f416a36bed8 100644 --- a/providers/tests/google/assets/test_bigquery.py +++ b/providers/tests/google/assets/test_bigquery.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.google.datasets.bigquery import sanitize_uri +from airflow.providers.google.assets.bigquery import sanitize_uri def test_sanitize_uri_pass() -> None: diff --git a/providers/tests/google/assets/test_gcs.py b/providers/tests/google/assets/test_gcs.py new file mode 100644 index 0000000000000..e9920302b0e0a --- /dev/null +++ b/providers/tests/google/assets/test_gcs.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+import urllib.parse
+
+import pytest
+
+from airflow.providers.common.compat.assets import Asset
+from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset, sanitize_uri
+
+
+def test_sanitize_uri():
+    uri = urllib.parse.urlsplit("gs://bucket/dir/file.txt")
+    result = sanitize_uri(uri)
+    assert result.scheme == "gs"
+    assert result.netloc == "bucket"
+    assert result.path == "/dir/file.txt"
+
+
+def test_sanitize_uri_no_netloc():
+    with pytest.raises(ValueError):
+        sanitize_uri(urllib.parse.urlsplit("gs://"))
+
+
+def test_sanitize_uri_no_path():
+    uri = urllib.parse.urlsplit("gs://bucket")
+    result = sanitize_uri(uri)
+    assert result.scheme == "gs"
+    assert result.netloc == "bucket"
+    assert result.path == ""
+
+
+def test_create_asset():
+    assert create_asset(bucket="test-bucket", key="test-path") == Asset(uri="gs://test-bucket/test-path")
+    assert create_asset(bucket="test-bucket", key="test-dir/test-path") == Asset(
+        uri="gs://test-bucket/test-dir/test-path"
+    )
+
+
+def test_sanitize_uri_trailing_slash():
+    uri = urllib.parse.urlsplit("gs://bucket/")
+    result = sanitize_uri(uri)
+    assert result.scheme == "gs"
+    assert result.netloc == "bucket"
+    assert result.path == "/"
+
+
+def test_convert_asset_to_openlineage_valid():
+    uri = "gs://bucket/dir/file.txt"
+    ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=None)
+    assert ol_dataset.namespace == "gs://bucket"
+    assert ol_dataset.name == "dir/file.txt"
+
+
+@pytest.mark.parametrize("uri", ("gs://bucket", "gs://bucket/"))
+def test_convert_asset_to_openlineage_no_path(uri):
+    ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=None)
+    assert ol_dataset.namespace == "gs://bucket"
+    assert ol_dataset.name == "/"
diff --git a/providers/tests/google/cloud/hooks/test_gcs.py b/providers/tests/google/cloud/hooks/test_gcs.py
index 464534bd11d93..7c06990d08cab 100644
--- a/providers/tests/google/cloud/hooks/test_gcs.py
+++ b/providers/tests/google/cloud/hooks/test_gcs.py
@@ -36,6 +36,7 @@
 from google.cloud.storage.retry import DEFAULT_RETRY
 
 from airflow.exceptions import AirflowException
+from airflow.providers.common.compat.assets import Asset
 from airflow.providers.google.cloud.hooks import gcs
 from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
 from airflow.providers.google.common.consts import CLIENT_INFO
@@ -43,6 +44,7 @@
 from airflow.version import version
 
 from providers.tests.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
+from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GCS_STRING = "airflow.providers.google.cloud.hooks.gcs.{}"
@@ -413,6 +415,41 @@ def test_copy_empty_source_object(self):
 
         assert str(ctx.value) == "source_bucket and source_object cannot be empty."
 
+ @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch("google.cloud.storage.Bucket.copy_blob") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_copy_exposes_lineage(self, mock_service, mock_copy, hook_lineage_collector): + source_bucket_name = "test-source-bucket" + source_object_name = "test-source-object" + destination_bucket_name = "test-dest-bucket" + destination_object_name = "test-dest-object" + + source_bucket = storage.Bucket(mock_service, source_bucket_name) + mock_copy.return_value = storage.Blob( + name=destination_object_name, bucket=storage.Bucket(mock_service, destination_bucket_name) + ) + mock_service.return_value.bucket.side_effect = ( + lambda name: source_bucket + if name == source_bucket_name + else storage.Bucket(mock_service, destination_bucket_name) + ) + + self.gcs_hook.copy( + source_bucket=source_bucket_name, + source_object=source_object_name, + destination_bucket=destination_bucket_name, + destination_object=destination_object_name, + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"gs://{source_bucket_name}/{source_object_name}" + ) + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"gs://{destination_bucket_name}/{destination_object_name}" + ) + @mock.patch("google.cloud.storage.Bucket") @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_rewrite(self, mock_service, mock_bucket): @@ -474,6 +511,40 @@ def test_rewrite_empty_source_object(self): assert str(ctx.value) == "source_bucket and source_object cannot be empty." + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_rewrite_exposes_lineage(self, mock_service, hook_lineage_collector): + source_bucket_name = "test-source-bucket" + source_object_name = "test-source-object" + destination_bucket_name = "test-dest-bucket" + destination_object_name = "test-dest-object" + + dest_bucket = storage.Bucket(mock_service, destination_bucket_name) + blob = MagicMock(spec=storage.Blob) + blob.rewrite = MagicMock(return_value=(None, None, None)) + dest_bucket.blob = MagicMock(return_value=blob) + mock_service.return_value.bucket.side_effect = ( + lambda name: storage.Bucket(mock_service, source_bucket_name) + if name == source_bucket_name + else dest_bucket + ) + + self.gcs_hook.rewrite( + source_bucket=source_bucket_name, + source_object=source_object_name, + destination_bucket=destination_bucket_name, + destination_object=destination_object_name, + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"gs://{source_bucket_name}/{source_object_name}" + ) + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"gs://{destination_bucket_name}/{destination_object_name}" + ) + @mock.patch("google.cloud.storage.Bucket") @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_delete(self, mock_service, mock_bucket): @@ -502,6 +573,22 @@ def test_delete_nonexisting_object(self, mock_service): with pytest.raises(exceptions.NotFound): self.gcs_hook.delete(bucket_name=test_bucket, object_name=test_object) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, 
reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_delete_exposes_lineage(self, mock_service, hook_lineage_collector): + test_bucket = "test_bucket" + test_object = "test_object" + + mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, test_bucket) + + self.gcs_hook.delete(bucket_name=test_bucket, object_name=test_object) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 0 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"gs://{test_bucket}/{test_object}" + ) + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_delete_bucket(self, mock_service): test_bucket = "test bucket" @@ -729,6 +816,33 @@ def test_compose_without_destination_object(self, mock_service): assert str(ctx.value) == "bucket_name and destination_object cannot be empty." + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_compose_exposes_lineage(self, mock_service, hook_lineage_collector): + test_bucket = "test_bucket" + source_object_names = ["test-source-object1", "test-source-object2"] + destination_object_name = "test-dest-object" + + mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, test_bucket) + + self.gcs_hook.compose( + bucket_name=test_bucket, + source_objects=source_object_names, + destination_object=destination_object_name, + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 2 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"gs://{test_bucket}/{source_object_names[0]}" + ) + assert hook_lineage_collector.collected_assets.inputs[1].asset == Asset( + uri=f"gs://{test_bucket}/{source_object_names[1]}" + ) + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"gs://{test_bucket}/{destination_object_name}" + ) + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_download_as_bytes(self, mock_service): test_bucket = "test_bucket" @@ -743,6 +857,23 @@ def test_download_as_bytes(self, mock_service): assert response == test_object_bytes download_method.assert_called_once_with() + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch("google.cloud.storage.Blob.download_as_bytes") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_download_as_bytes_exposes_lineage(self, mock_service, mock_download, hook_lineage_collector): + source_bucket_name = "test-source-bucket" + source_object_name = "test-source-object" + + mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name) + + self.gcs_hook.download(bucket_name=source_bucket_name, object_name=source_object_name, filename=None) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 0 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"gs://{source_bucket_name}/{source_object_name}" + ) + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_download_to_file(self, mock_service): test_bucket = "test_bucket" @@ -766,6 +897,27 @@ def test_download_to_file(self, mock_service): assert response == test_file download_filename_method.assert_called_once_with(test_file, 
timeout=60) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch("google.cloud.storage.Blob.download_to_filename") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_download_to_file_exposes_lineage(self, mock_service, mock_download, hook_lineage_collector): + source_bucket_name = "test-source-bucket" + source_object_name = "test-source-object" + file_name = "test.txt" + + mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name) + + self.gcs_hook.download( + bucket_name=source_bucket_name, object_name=source_object_name, filename=file_name + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"gs://{source_bucket_name}/{source_object_name}" + ) + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset(uri=f"file://{file_name}") + @mock.patch(GCS_STRING.format("NamedTemporaryFile")) @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_provide_file(self, mock_service, mock_temp_file): @@ -999,6 +1151,27 @@ def test_upload_file(self, mock_service, testdata_file): assert metadata == blob_object.return_value.metadata + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch("google.cloud.storage.Blob.upload_from_filename") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_upload_file_exposes_lineage(self, mock_service, mock_upload, hook_lineage_collector): + source_bucket_name = "test-source-bucket" + source_object_name = "test-source-object" + file_name = "test.txt" + + mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name) + + self.gcs_hook.upload( + bucket_name=source_bucket_name, object_name=source_object_name, filename=file_name + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"gs://{source_bucket_name}/{source_object_name}" + ) + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset(uri=f"file://{file_name}") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_upload_cache_control(self, mock_service, testdata_file): test_bucket = "test_bucket" @@ -1042,6 +1215,23 @@ def test_upload_data_bytes(self, mock_service, testdata_bytes): upload_method.assert_called_once_with(testdata_bytes, content_type="text/plain", timeout=60) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") + @mock.patch("google.cloud.storage.Blob.upload_from_string") + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_upload_data_exposes_lineage(self, mock_service, mock_upload, hook_lineage_collector): + source_bucket_name = "test-source-bucket" + source_object_name = "test-source-object" + + mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name) + + self.gcs_hook.upload(bucket_name=source_bucket_name, object_name=source_object_name, data="test") + + assert len(hook_lineage_collector.collected_assets.inputs) == 0 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"gs://{source_bucket_name}/{source_object_name}" + ) + 
@mock.patch(GCS_STRING.format("BytesIO")) @mock.patch(GCS_STRING.format("gz.GzipFile")) @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
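
A quick usage sketch of the new asset helpers (illustrative only: the bucket
and object names below are made up, and it assumes the google provider plus
apache-airflow-providers-common-compat>=1.2.1 from the dependency bump above
are installed):

    from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset

    # create_asset() builds an AIP-60 style gs:// URI from a bucket and key.
    asset = create_asset(bucket="my-bucket", key="data/file.csv")
    assert asset.uri == "gs://my-bucket/data/file.csv"

    # convert_asset_to_openlineage() splits the URI back into an OpenLineage
    # namespace (gs://<bucket>) and name (<key>, or "/" when there is no key).
    ol_dataset = convert_asset_to_openlineage(asset=asset, lineage_context=None)
    assert ol_dataset.namespace == "gs://my-bucket"
    assert ol_dataset.name == "data/file.csv"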