Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -684,14 +684,14 @@ repos:
^airflow/providers/.*\.py$
exclude: ^.*/.*_vendor/
- id: check-get-lineage-collector-providers
language: pygrep
language: python
name: Check providers import hook lineage code from compat
description: Make sure you import from airflow.providers.common.compat.lineage.hook instead of
airflow.lineage.hook.
entry: "airflow\\.lineage\\.hook"
pass_filenames: true
entry: ./scripts/ci/pre_commit/check_get_lineage_collector_providers.py
files: ^airflow/providers/.*\.py$
exclude: ^airflow/providers/common/compat/.*\.py$
additional_dependencies: [ 'rich>=12.4.4' ]
- id: check-decorated-operator-implements-custom-name
name: Check @task decorator implements custom_operator_name
language: python
Expand Down
48 changes: 46 additions & 2 deletions airflow/io/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

from airflow.io.store import attach
from airflow.io.utils.stat import stat_result
from airflow.lineage.hook import get_hook_lineage_collector
from airflow.utils.log.logging_mixin import LoggingMixin

if typing.TYPE_CHECKING:
from fsspec import AbstractFileSystem
Expand All @@ -39,6 +41,42 @@
default = "file"


class TrackingFileWrapper(LoggingMixin):
    """Proxy around a file-like object that records dataset lineage.

    Every attribute access is delegated to the wrapped object.  Calls to
    ``read``/``write`` additionally register the originating path as an
    input/output dataset with the hook lineage collector.
    """

    def __init__(self, path: ObjectStoragePath, obj):
        super().__init__()
        # Path the wrapped handle was opened from; used as lineage context/uri.
        self._path: ObjectStoragePath = path
        # The underlying file-like object returned by fsspec's open().
        self._obj = obj

    def __getattr__(self, name):
        attr = getattr(self._obj, name)
        if callable(attr):
            # If the attribute is a method, wrap it in another method to intercept the call
            def wrapper(*args, **kwargs):
                # Routine call tracing only — must not pollute error logs.
                self.log.debug("Calling method: %s", name)
                if name == "read":
                    get_hook_lineage_collector().add_input_dataset(context=self._path, uri=str(self._path))
                elif name == "write":
                    get_hook_lineage_collector().add_output_dataset(context=self._path, uri=str(self._path))
                result = attr(*args, **kwargs)
                return result

            return wrapper
        return attr

    def __getitem__(self, key):
        # Intercept item access
        return self._obj[key]

    def __enter__(self):
        self._obj.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Propagate the wrapped object's return value: a truthy return from
        # __exit__ suppresses the exception, and discarding it would silently
        # change the underlying context manager's semantics.
        return self._obj.__exit__(exc_type, exc_val, exc_tb)


class ObjectStoragePath(CloudPath):
"""A path-like object for object storage."""

Expand Down Expand Up @@ -121,7 +159,7 @@ def namespace(self) -> str:
def open(self, mode="r", **kwargs):
    """Open the file pointed to by this path, wrapped for lineage tracking."""
    # fsspec expects ``block_size``; honor a stdlib-style ``buffering`` kwarg
    # as its default if the caller did not pass block_size explicitly.
    buffering = kwargs.pop("buffering", None)
    kwargs.setdefault("block_size", buffering)
    handle = self.fs.open(self.path, mode=mode, **kwargs)
    return TrackingFileWrapper(self, handle)

def stat(self) -> stat_result: # type: ignore[override]
"""Call ``stat`` and return the result."""
Expand Down Expand Up @@ -276,6 +314,11 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
if isinstance(dst, str):
dst = ObjectStoragePath(dst)

if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
# only emit this in "optimized" variants - else lineage will be captured by file writes/reads
get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self))
get_hook_lineage_collector().add_output_dataset(context=dst, uri=str(dst))

# same -> same
if self.samestore(dst):
self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
Expand Down Expand Up @@ -319,7 +362,6 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
continue

src_obj._cp_file(dst)

return

# remote file -> remote dir
Expand All @@ -339,6 +381,8 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
path = ObjectStoragePath(path)

if self.samestore(path):
get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self))
get_hook_lineage_collector().add_output_dataset(context=path, uri=str(path))
return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)

# non-local copy
Expand Down
12 changes: 7 additions & 5 deletions airflow/lineage/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
# under the License.
from __future__ import annotations

from typing import Union
from typing import TYPE_CHECKING, Union

import attr

from airflow.datasets import Dataset
from airflow.hooks.base import BaseHook
from airflow.io.store import ObjectStore
from airflow.providers_manager import ProvidersManager
from airflow.utils.log.logging_mixin import LoggingMixin

# Store context what sent lineage.
LineageContext = Union[BaseHook, ObjectStore]
if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
from airflow.io.path import ObjectStoragePath

# Store context what sent lineage.
LineageContext = Union[BaseHook, ObjectStoragePath]

_hook_lineage_collector: HookLineageCollector | None = None

Expand Down
83 changes: 83 additions & 0 deletions scripts/ci/pre_commit/check_get_lineage_collector_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import ast
import sys
from pathlib import Path
from typing import NamedTuple

sys.path.insert(0, str(Path(__file__).parent.resolve()))
from common_precommit_utils import console, initialize_breeze_precommit

initialize_breeze_precommit(__name__, __file__)


class ImportTuple(NamedTuple):
    """One imported name: the source module path parts, name parts, and alias."""

    module: list[str]
    name: list[str]
    # ``asname`` from the AST is None when the import has no ``as`` alias.
    alias: str | None


def get_toplevel_imports(path: str):
    """Yield an :class:`ImportTuple` for every top-level import in *path*.

    Relative imports with no module part (``from . import x``) are skipped,
    matching the original behavior.
    """
    with open(path) as fh:
        root = ast.parse(fh.read(), path)

    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Import):
            # ``import a, b`` carries one alias per imported module; the
            # module path comes from EACH alias, not only the first one.
            for alias in node.names:
                parts = alias.name.split(".")
                yield ImportTuple(module=parts, name=parts, alias=alias.asname)
        elif isinstance(node, ast.ImportFrom) and node.module:
            module = node.module.split(".")
            for alias in node.names:
                yield ImportTuple(module=module, name=alias.name.split("."), alias=alias.asname)


# Violation messages collected across all checked files, so the summary after
# the per-file loop can report everything at once.
errors: list[str] = []


def main() -> int:
    """Scan each file named on the command line for top-level imports rooted at
    ``airflow.lineage.hook``; return 1 if any are found, 0 otherwise."""
    for path in sys.argv[1:]:
        import_count = 0
        local_error_count = 0
        for imp in get_toplevel_imports(path):
            import_count += 1
            if len(imp.module) > 2:
                # Flag any import whose dotted path starts with airflow.lineage.hook.
                if imp.module[:3] == ["airflow", "lineage", "hook"]:
                    local_error_count += 1
                    errors.append(f"{path}: ({'.'.join(imp.module)})")
        console.print(f"[blue]{path}:[/] Import count: {import_count}, error_count {local_error_count}")
    if errors:
        console.print(
            "[red]Some providers files import directly top level from `airflow.lineage.hook` and they are not allowed.[/]\n"
            "Only TYPE_CHECKING imports from `airflow.lineage.hook` is allowed in providers."
        )
        console.print("Error summary:")
        for error in errors:
            console.print(error)
        return 1
    else:
        console.print("[green]All good!")
        return 0


if __name__ == "__main__":
sys.exit(main())
1 change: 1 addition & 0 deletions scripts/ci/pre_commit/check_tests_in_right_folders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand Down
43 changes: 32 additions & 11 deletions tests/io/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,11 @@ def test_relative_to(self):
with pytest.raises(ValueError):
o1.relative_to(o3)

def test_move_local(self):
_from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
_to = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
def test_move_local(self, hook_lineage_collector):
_from_path = f"file:///tmp/{str(uuid.uuid4())}"
_to_path = f"file:///tmp/{str(uuid.uuid4())}"
_from = ObjectStoragePath(_from_path)
_to = ObjectStoragePath(_to_path)

_from.touch()
_from.move(_to)
Expand All @@ -278,13 +280,19 @@ def test_move_local(self):

_to.unlink()

def test_move_remote(self):
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)

def test_move_remote(self, hook_lineage_collector):
attach("fakefs", fs=FakeRemoteFileSystem())

_from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
print(_from)
_to = ObjectStoragePath(f"fakefs:///tmp/{str(uuid.uuid4())}")
print(_to)
_from_path = f"file:///tmp/{str(uuid.uuid4())}"
_to_path = f"fakefs:///tmp/{str(uuid.uuid4())}"

_from = ObjectStoragePath(_from_path)
_to = ObjectStoragePath(_to_path)

_from.touch()
_from.move(_to)
Expand All @@ -293,21 +301,28 @@ def test_move_remote(self):

_to.unlink()

def test_copy_remote_remote(self):
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from))
assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=str(_to))

def test_copy_remote_remote(self, hook_lineage_collector):
attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))

dir_src = f"bucket1/{str(uuid.uuid4())}"
dir_dst = f"bucket2/{str(uuid.uuid4())}"
key = "foo/bar/baz.txt"

_from = ObjectStoragePath(f"ffs://{dir_src}")
_from_path = f"ffs://{dir_src}"
_from = ObjectStoragePath(_from_path)
_from_file = _from / key
_from_file.touch()
assert _from.bucket == "bucket1"
assert _from_file.exists()

_to = ObjectStoragePath(f"ffs2://{dir_dst}")
_to_path = f"ffs2://{dir_dst}"
_to = ObjectStoragePath(_to_path)
_from.copy(_to)

assert _to.bucket == "bucket2"
Expand All @@ -319,6 +334,12 @@ def test_copy_remote_remote(self):
_from.rmdir(recursive=True)
_to.rmdir(recursive=True)

assert len(hook_lineage_collector.collected_datasets.inputs) == 1
assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from_file))

# Empty file - shutil.copyfileobj does nothing
assert len(hook_lineage_collector.collected_datasets.outputs) == 0

def test_serde_objectstoragepath(self):
path = "file:///bucket/key/part1/part2"
o = ObjectStoragePath(path)
Expand Down
64 changes: 64 additions & 0 deletions tests/io/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import uuid
from unittest.mock import patch

from airflow.datasets import Dataset
from airflow.io.path import ObjectStoragePath


@patch("airflow.providers_manager.ProvidersManager")
def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector):
    """Writes and reads through the wrapper are captured as output/input datasets."""
    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
    uri = f"file:///tmp/{uuid.uuid4()}"
    storage_path = ObjectStoragePath(uri)

    handle = storage_path.open("w")
    handle.write("aaa")
    handle.close()

    assert len(hook_lineage_collector.outputs) == 1
    assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)

    handle = storage_path.open("r")
    handle.read()
    handle.close()

    storage_path.unlink(missing_ok=True)

    assert len(hook_lineage_collector.inputs) == 1
    assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)


@patch("airflow.providers_manager.ProvidersManager")
def test_wrapper_works_with_contextmanager(providers_manager, hook_lineage_collector):
    """Lineage is also captured when the wrapper is used as a context manager."""
    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
    uri = f"file:///tmp/{uuid.uuid4()}"
    storage_path = ObjectStoragePath(uri)

    with storage_path.open("w") as handle:
        handle.write("asdf")

    assert len(hook_lineage_collector.outputs) == 1
    assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)

    with storage_path.open("r") as handle:
        handle.read()
    storage_path.unlink(missing_ok=True)

    assert len(hook_lineage_collector.inputs) == 1
    assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)