2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -626,7 +626,7 @@
"google": {
"deps": [
"PyOpenSSL>=23.0.0",
"apache-airflow-providers-common-compat>=1.1.0",
"apache-airflow-providers-common-compat>=1.2.1",
"apache-airflow-providers-common-sql>=1.7.2",
"apache-airflow>=2.8.0",
"asgiref>=3.5.2",
45 changes: 45 additions & 0 deletions providers/src/airflow/providers/google/assets/gcs.py
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.common.compat.assets import Asset
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url

if TYPE_CHECKING:
from urllib.parse import SplitResult

from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset


def create_asset(*, bucket: str, key: str, extra: dict | None = None) -> Asset:
return Asset(uri=f"gs://{bucket}/{key}", extra=extra)


def sanitize_uri(uri: SplitResult) -> SplitResult:
if not uri.netloc:
raise ValueError("URI format gs:// must contain a bucket name")
return uri


def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset:
"""Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset

bucket, key = _parse_gcs_url(asset.uri)
return OpenLineageDataset(namespace=f"gs://{bucket}", name=key if key else "/")
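
For reference, a minimal usage sketch of the three new helpers above (bucket and object names are made up; the expected behavior mirrors the tests added at the bottom of this PR):

```python
# Illustrative only -- values are hypothetical, not taken from the PR.
import urllib.parse

from airflow.providers.google.assets.gcs import (
    convert_asset_to_openlineage,
    create_asset,
    sanitize_uri,
)

# Factory: builds an Asset with a gs:// URI from bucket + key.
asset = create_asset(bucket="my-bucket", key="dir/file.txt")
print(asset.uri)  # gs://my-bucket/dir/file.txt

# Sanitizer: validates that a bucket (netloc) is present, otherwise raises ValueError.
split = sanitize_uri(urllib.parse.urlsplit(asset.uri))
print(split.netloc, split.path)  # my-bucket /dir/file.txt

# OpenLineage converter: the bucket becomes the namespace, the object path the name.
ol_dataset = convert_asset_to_openlineage(asset=asset, lineage_context=None)
print(ol_dataset.namespace, ol_dataset.name)  # gs://my-bucket dir/file.txt
```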
52 changes: 50 additions & 2 deletions providers/src/airflow/providers/google/cloud/hooks/gcs.py
@@ -43,6 +43,7 @@
from requests import Session

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import (
@@ -214,6 +215,16 @@ def copy(
destination_object = source_bucket.copy_blob( # type: ignore[attr-defined]
blob=source_object, destination_bucket=destination_bucket, new_name=destination_object
)
get_hook_lineage_collector().add_input_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined]
)
get_hook_lineage_collector().add_output_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": destination_bucket.name, "key": destination_object.name}, # type: ignore[union-attr]
)

self.log.info(
"Object %s in bucket %s copied to object %s in bucket %s",
@@ -267,6 +278,16 @@ def rewrite(
).rewrite(source=source_object, token=token)

self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten)
get_hook_lineage_collector().add_input_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined]
)
get_hook_lineage_collector().add_output_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": destination_bucket.name, "key": destination_object}, # type: ignore[attr-defined]
)
self.log.info(
"Object %s in bucket %s rewritten to object %s in bucket %s",
source_object.name, # type: ignore[attr-defined]
@@ -345,9 +366,18 @@ def download(

if filename:
blob.download_to_filename(filename, timeout=timeout)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)
get_hook_lineage_collector().add_output_asset(
context=self, scheme="file", asset_kwargs={"path": filename}
)
self.log.info("File downloaded to %s", filename)
return filename
else:
get_hook_lineage_collector().add_input_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)
return blob.download_as_bytes()

except GoogleCloudError:
@@ -555,6 +585,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
_call_with_retry(
partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="file", asset_kwargs={"path": filename}
)

if gzip:
os.remove(filename)
@@ -576,6 +609,10 @@ def _call_with_retry(f: Callable[[], None]) -> None:
else:
raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")

get_hook_lineage_collector().add_output_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)

def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool:
"""
Check for the existence of a file in Google Cloud Storage.
@@ -691,6 +728,9 @@ def delete(self, bucket_name: str, object_name: str) -> None:
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob.delete()
get_hook_lineage_collector().add_input_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)

self.log.info("Blob %s deleted.", object_name)

@@ -1198,9 +1238,17 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
client = self.get_conn()
bucket = client.bucket(bucket_name)
destination_blob = bucket.blob(destination_object)
destination_blob.compose(
sources=[bucket.blob(blob_name=source_object) for source_object in source_objects]
source_blobs = [bucket.blob(blob_name=source_object) for source_object in source_objects]
destination_blob.compose(sources=source_blobs)
get_hook_lineage_collector().add_output_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": destination_blob.name}
)
for single_source_blob in source_blobs:
get_hook_lineage_collector().add_input_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": bucket.name, "key": single_source_blob.name},
)

self.log.info("Completed successfully.")

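The collector calls added throughout this hook all follow one pattern: the hook passes a scheme plus keyword arguments, and the collector resolves them through the asset factory registered for that scheme (for "gs", the new create_asset above). A minimal standalone sketch of that pattern with hypothetical names (inside the hook methods, context is self):

```python
# Hypothetical sketch of the lineage-collection pattern used in the hook changes above.
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector

collector = get_hook_lineage_collector()

# Resolves to Asset(uri="gs://source-bucket/data/input.csv") via the "gs" factory.
collector.add_input_asset(
    context=None,  # the hook methods pass `self` here
    scheme="gs",
    asset_kwargs={"bucket": "source-bucket", "key": "data/input.csv"},
)

# Local files touched by download/upload are reported with the "file" scheme instead.
collector.add_output_asset(
    context=None,
    scheme="file",
    asset_kwargs={"path": "/tmp/input.csv"},
)
```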
14 changes: 11 additions & 3 deletions providers/src/airflow/providers/google/provider.yaml
@@ -97,7 +97,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-compat>=1.1.0
- apache-airflow-providers-common-compat>=1.2.1
- apache-airflow-providers-common-sql>=1.7.2
- asgiref>=3.5.2
- dill>=0.2.3
@@ -777,15 +777,23 @@ asset-uris:
- schemes: [gcp]
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
handler: airflow.providers.google.assets.bigquery.sanitize_uri
- schemes: [gs]
handler: airflow.providers.google.assets.gcs.sanitize_uri
factory: airflow.providers.google.assets.gcs.create_asset
to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage

# dataset has been renamed to asset in Airflow 3.0
# This is kept for backward compatibility.
dataset-uris:
- schemes: [gcp]
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
handler: airflow.providers.google.assets.bigquery.sanitize_uri
- schemes: [gs]
handler: airflow.providers.google.assets.gcs.sanitize_uri
factory: airflow.providers.google.assets.gcs.create_asset
to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage

hooks:
- integration-name: Google Ads
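With the gs scheme registered above (handler, factory, and OpenLineage converter), gs:// URIs become usable as assets/datasets in DAG definitions. A small hypothetical sketch; the uri keyword matches how create_asset constructs the object, and on Airflow 2.x the compat layer maps Asset to Dataset:

```python
from airflow.providers.common.compat.assets import Asset

# Hypothetical asset; sanitize_uri validates that the URI carries a bucket name.
daily_report = Asset(uri="gs://example-bucket/reports/daily.csv")

# Equivalent construction through the registered factory:
# create_asset(bucket="example-bucket", key="reports/daily.csv")
```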
2 changes: 1 addition & 1 deletion providers/tests/google/assets/test_bigquery.py
@@ -21,7 +21,7 @@

import pytest

from airflow.providers.google.datasets.bigquery import sanitize_uri
from airflow.providers.google.assets.bigquery import sanitize_uri


def test_sanitize_uri_pass() -> None:
74 changes: 74 additions & 0 deletions providers/tests/google/assets/test_gcs.py
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import urllib.parse

import pytest

from airflow.providers.common.compat.assets import Asset
from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset, sanitize_uri


def test_sanitize_uri():
uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/dir/file.txt"))
result = sanitize_uri(uri)
assert result.scheme == "gs"
assert result.netloc == "bucket"
assert result.path == "/dir/file.txt"


def test_sanitize_uri_no_netloc():
with pytest.raises(ValueError):
sanitize_uri(urllib.parse.urlsplit("gs://"))


def test_sanitize_uri_no_path():
uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket"))
result = sanitize_uri(uri)
assert result.scheme == "gs"
assert result.netloc == "bucket"
assert result.path == ""


def test_create_asset():
assert create_asset(bucket="test-bucket", key="test-path") == Asset(uri="gs://test-bucket/test-path")
assert create_asset(bucket="test-bucket", key="test-dir/test-path") == Asset(
uri="gs://test-bucket/test-dir/test-path"
)


def test_sanitize_uri_trailing_slash():
uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/"))
result = sanitize_uri(uri)
assert result.scheme == "gs"
assert result.netloc == "bucket"
assert result.path == "/"


def test_convert_asset_to_openlineage_valid():
uri = "gs://bucket/dir/file.txt"
ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=None)
assert ol_dataset.namespace == "gs://bucket"
assert ol_dataset.name == "dir/file.txt"


@pytest.mark.parametrize("uri", ("gs://bucket", "gs://bucket/"))
def test_convert_asset_to_openlineage_no_path(uri):
ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=None)
assert ol_dataset.namespace == "gs://bucket"
assert ol_dataset.name == "/"