diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 87118f4b5a53d..e9246c41afcda 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -684,14 +684,14 @@ repos:
           ^airflow/providers/.*\.py$
         exclude: ^.*/.*_vendor/
       - id: check-get-lineage-collector-providers
-        language: pygrep
+        language: python
         name: Check providers import hook lineage code from compat
         description: Make sure you import from airflow.provider.common.compat.lineage.hook
           instead of airflow.lineage.hook.
-        entry: "airflow\\.lineage\\.hook"
-        pass_filenames: true
+        entry: ./scripts/ci/pre_commit/check_get_lineage_collector_providers.py
         files: ^airflow/providers/.*\.py$
         exclude: ^airflow/providers/common/compat/.*\.py$
+        additional_dependencies: [ 'rich>=12.4.4' ]
       - id: check-decorated-operator-implements-custom-name
         name: Check @task decorator implements custom_operator_name
         language: python
diff --git a/airflow/io/path.py b/airflow/io/path.py
index 5f39782e2ac8f..7c8d1f9f19be0 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -29,6 +29,8 @@

 from airflow.io.store import attach
 from airflow.io.utils.stat import stat_result
+from airflow.lineage.hook import get_hook_lineage_collector
+from airflow.utils.log.logging_mixin import LoggingMixin

 if typing.TYPE_CHECKING:
     from fsspec import AbstractFileSystem
@@ -39,6 +41,42 @@
 default = "file"


+class TrackingFileWrapper(LoggingMixin):
+    """Wrapper that tracks file operations to intercept lineage."""
+
+    def __init__(self, path: ObjectStoragePath, obj):
+        super().__init__()
+        self._path: ObjectStoragePath = path
+        self._obj = obj
+
+    def __getattr__(self, name):
+        attr = getattr(self._obj, name)
+        if callable(attr):
+            # If the attribute is a method, wrap it in another method to intercept the call
+            def wrapper(*args, **kwargs):
+                self.log.error("Calling method: %s", name)
+                if name == "read":
+                    get_hook_lineage_collector().add_input_dataset(context=self._path, uri=str(self._path))
+                elif name == "write":
+                    get_hook_lineage_collector().add_output_dataset(context=self._path, uri=str(self._path))
+                result = attr(*args, **kwargs)
+                return result
+
+            return wrapper
+        return attr
+
+    def __getitem__(self, key):
+        # Intercept item access
+        return self._obj[key]
+
+    def __enter__(self):
+        self._obj.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._obj.__exit__(exc_type, exc_val, exc_tb)
+
+
 class ObjectStoragePath(CloudPath):
     """A path-like object for object storage."""

@@ -121,7 +159,7 @@ def namespace(self) -> str:
     def open(self, mode="r", **kwargs):
         """Open the file pointed to by this path."""
         kwargs.setdefault("block_size", kwargs.pop("buffering", None))
-        return self.fs.open(self.path, mode=mode, **kwargs)
+        return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))

     def stat(self) -> stat_result:  # type: ignore[override]
         """Call ``stat`` and return the result."""
@@ -276,6 +314,11 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
         if isinstance(dst, str):
             dst = ObjectStoragePath(dst)

+        if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
+            # only emit this in "optimized" variants - else lineage will be captured by file writes/reads
+            get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self))
+            get_hook_lineage_collector().add_output_dataset(context=dst, uri=str(dst))
+
         # same -> same
         if self.samestore(dst):
             self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
@@ -319,7 +362,6 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
                     continue

                 src_obj._cp_file(dst)

-            return

         # remote file -> remote dir
@@ -339,6 +381,8 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
             path = ObjectStoragePath(path)

         if self.samestore(path):
+            get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self))
+            get_hook_lineage_collector().add_output_dataset(context=path, uri=str(path))
             return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)

         # non-local copy
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index ee12e1624e12d..744db3fb38d83 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -17,18 +17,20 @@
 # under the License.
 from __future__ import annotations

-from typing import Union
+from typing import TYPE_CHECKING, Union

 import attr

 from airflow.datasets import Dataset
-from airflow.hooks.base import BaseHook
-from airflow.io.store import ObjectStore
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.log.logging_mixin import LoggingMixin

-# Store context what sent lineage.
-LineageContext = Union[BaseHook, ObjectStore]
+if TYPE_CHECKING:
+    from airflow.hooks.base import BaseHook
+    from airflow.io.path import ObjectStoragePath
+
+    # Store context what sent lineage.
+    LineageContext = Union[BaseHook, ObjectStoragePath]

 _hook_lineage_collector: HookLineageCollector | None = None
diff --git a/scripts/ci/pre_commit/check_get_lineage_collector_providers.py b/scripts/ci/pre_commit/check_get_lineage_collector_providers.py
new file mode 100755
index 0000000000000..b7c811f6f792b
--- /dev/null
+++ b/scripts/ci/pre_commit/check_get_lineage_collector_providers.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+from typing import NamedTuple
+
+sys.path.insert(0, str(Path(__file__).parent.resolve()))
+from common_precommit_utils import console, initialize_breeze_precommit
+
+initialize_breeze_precommit(__name__, __file__)
+
+
+class ImportTuple(NamedTuple):
+    module: list[str]
+    name: list[str]
+    alias: str
+
+
+def get_toplevel_imports(path: str):
+    with open(path) as fh:
+        root = ast.parse(fh.read(), path)
+
+    for node in ast.iter_child_nodes(root):
+        if isinstance(node, ast.Import):
+            module: list[str] = node.names[0].name.split(".") if node.names else []
+        elif isinstance(node, ast.ImportFrom) and node.module:
+            module = node.module.split(".")
+        else:
+            continue
+
+        for n in node.names:  # type: ignore[attr-defined]
+            yield ImportTuple(module=module, name=n.name.split("."), alias=n.asname)
+
+
+errors: list[str] = []
+
+
+def main() -> int:
+    for path in sys.argv[1:]:
+        import_count = 0
+        local_error_count = 0
+        for imp in get_toplevel_imports(path):
+            import_count += 1
+            if len(imp.module) > 2:
+                if imp.module[:3] == ["airflow", "lineage", "hook"]:
+                    local_error_count += 1
+                    errors.append(f"{path}: ({'.'.join(imp.module)})")
+        console.print(f"[blue]{path}:[/] Import count: {import_count}, error_count {local_error_count}")
+    if errors:
+        console.print(
+            "[red]Some provider files import directly from `airflow.lineage.hook` at the top level; this is not allowed.[/]\n"
+            "Only TYPE_CHECKING imports from `airflow.lineage.hook` are allowed in providers."
+        )
+        console.print("Error summary:")
+        for error in errors:
+            console.print(error)
+        return 1
+    else:
+        console.print("[green]All good!")
+        return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/scripts/ci/pre_commit/check_tests_in_right_folders.py b/scripts/ci/pre_commit/check_tests_in_right_folders.py
index 1d556fa748cd5..b4f68baffa1f1 100755
--- a/scripts/ci/pre_commit/check_tests_in_right_folders.py
+++ b/scripts/ci/pre_commit/check_tests_in_right_folders.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index 1ccdfbfb79805..c9c5a99bc42f8 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -267,9 +267,11 @@ def test_relative_to(self):
         with pytest.raises(ValueError):
             o1.relative_to(o3)

-    def test_move_local(self):
-        _from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
-        _to = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
+    def test_move_local(self, hook_lineage_collector):
+        _from_path = f"file:///tmp/{str(uuid.uuid4())}"
+        _to_path = f"file:///tmp/{str(uuid.uuid4())}"
+        _from = ObjectStoragePath(_from_path)
+        _to = ObjectStoragePath(_to_path)
         _from.touch()
         _from.move(_to)

@@ -278,13 +280,19 @@ def test_move_local(self):

         _to.unlink()

-    def test_move_remote(self):
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)
+
+    def test_move_remote(self, hook_lineage_collector):
         attach("fakefs", fs=FakeRemoteFileSystem())
-        _from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
-        print(_from)
-        _to = ObjectStoragePath(f"fakefs:///tmp/{str(uuid.uuid4())}")
-        print(_to)
+        _from_path = f"file:///tmp/{str(uuid.uuid4())}"
+        _to_path = f"fakefs:///tmp/{str(uuid.uuid4())}"
+
+        _from = ObjectStoragePath(_from_path)
+        _to = ObjectStoragePath(_to_path)

         _from.touch()
         _from.move(_to)

@@ -293,7 +301,12 @@ def test_move_remote(self):

         _to.unlink()

-    def test_copy_remote_remote(self):
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from))
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=str(_to))
+
+    def test_copy_remote_remote(self, hook_lineage_collector):
         attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
         attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))

@@ -301,13 +314,15 @@ def test_copy_remote_remote(self):
         dir_src = f"bucket1/{str(uuid.uuid4())}"
         dir_dst = f"bucket2/{str(uuid.uuid4())}"
         key = "foo/bar/baz.txt"

-        _from = ObjectStoragePath(f"ffs://{dir_src}")
+        _from_path = f"ffs://{dir_src}"
+        _from = ObjectStoragePath(_from_path)
         _from_file = _from / key
         _from_file.touch()
         assert _from.bucket == "bucket1"
         assert _from_file.exists()

-        _to = ObjectStoragePath(f"ffs2://{dir_dst}")
+        _to_path = f"ffs2://{dir_dst}"
+        _to = ObjectStoragePath(_to_path)
         _from.copy(_to)
         assert _to.bucket == "bucket2"
@@ -319,6 +334,12 @@ def test_copy_remote_remote(self):
         _from.rmdir(recursive=True)
         _to.rmdir(recursive=True)

+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from_file))
+
+        # Empty file - shutil.copyfileobj does nothing
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 0
+
     def test_serde_objectstoragepath(self):
         path = "file:///bucket/key/part1/part2"
         o = ObjectStoragePath(path)
diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py
new file mode 100644
index 0000000000000..dab0bee6ecefb
--- /dev/null
+++ b/tests/io/test_wrapper.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+from unittest.mock import patch
+
+from airflow.datasets import Dataset
+from airflow.io.path import ObjectStoragePath
+
+
+@patch("airflow.providers_manager.ProvidersManager")
+def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector):
+    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
+    uri = f"file:///tmp/{str(uuid.uuid4())}"
+    path = ObjectStoragePath(uri)
+    file = path.open("w")
+    file.write("aaa")
+    file.close()
+
+    assert len(hook_lineage_collector.outputs) == 1
+    assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
+
+    file = path.open("r")
+    file.read()
+    file.close()
+
+    path.unlink(missing_ok=True)
+
+    assert len(hook_lineage_collector.inputs) == 1
+    assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)
+
+
+@patch("airflow.providers_manager.ProvidersManager")
+def test_wrapper_works_with_contextmanager(providers_manager, hook_lineage_collector):
+    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
+    uri = f"file:///tmp/{str(uuid.uuid4())}"
+    path = ObjectStoragePath(uri)
+    with path.open("w") as file:
+        file.write("asdf")
+
+    assert len(hook_lineage_collector.outputs) == 1
+    assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
+
+    with path.open("r") as file:
+        file.read()
+    path.unlink(missing_ok=True)
+
+    assert len(hook_lineage_collector.inputs) == 1
+    assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)