From d9a29dfeab62ec5555c68ff1e0bd4bc641533ae5 Mon Sep 17 00:00:00 2001
From: Kacper Muda
Date: Mon, 22 Jul 2024 20:07:40 +0200
Subject: [PATCH 1/2] aip-62: implement translation mechanism from aip-60 to
 OpenLineage

Signed-off-by: Kacper Muda
---
 airflow/datasets/__init__.py                 | 30 ++++++-
 airflow/provider.yaml.schema.json            |  4 +
 airflow/providers/openlineage/utils/utils.py | 30 +++++++
 airflow/providers_manager.py                 | 82 ++++++++++++++------
 tests/datasets/test_dataset.py               | 48 ++++++++++--
 5 files changed, 163 insertions(+), 31 deletions(-)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 90cb52b6adf8a..dd273381896dc 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -56,6 +56,11 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
     return ProvidersManager().dataset_uri_handlers.get(scheme)
 
 
+def _get_normalized_scheme(uri: str) -> str:
+    parsed = urllib.parse.urlsplit(uri)
+    return parsed.scheme.lower()
+
+
 def _sanitize_uri(uri: str) -> str:
     """
     Sanitize a dataset URI.
@@ -72,7 +77,8 @@ def _sanitize_uri(uri: str) -> str:
     parsed = urllib.parse.urlsplit(uri)
     if not parsed.scheme and not parsed.netloc:  # Does not look like a URI.
         return uri
-    normalized_scheme = parsed.scheme.lower()
+    if not (normalized_scheme := _get_normalized_scheme(uri)):
+        return uri
     if normalized_scheme.startswith("x-"):
         return uri
     if normalized_scheme == "airflow":
@@ -231,6 +237,28 @@ def __eq__(self, other: Any) -> bool:
     def __hash__(self) -> int:
         return hash(self.uri)
 
+    @property
+    def normalized_uri(self) -> str | None:
+        """
+        Return the normalized, AIP-60-compliant version of the URI whenever possible.
+
+        Returns None if the scheme cannot be retrieved from the URI, if no normalizer
+        is registered for that scheme, or if the normalizer fails to parse the URI.
+
+        Otherwise, returns the result of the registered normalizer.
+        """
+        if not (normalized_scheme := _get_normalized_scheme(self.uri)):
+            return None
+
+        if (normalizer := _get_uri_normalizer(normalized_scheme)) is None:
+            return None
+        parsed = urllib.parse.urlsplit(self.uri)
+        try:
+            normalized_uri = normalizer(parsed)
+            return urllib.parse.urlunsplit(normalized_uri)
+        except ValueError:
+            return None
+
     def as_expression(self) -> Any:
         """
         Serialize the dataset into its scheduling expression.
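Note: a minimal usage sketch of the new `normalized_uri` property; the "postgres"
scheme is illustrative and assumes a provider has registered a URI normalizer for it:

    from airflow.datasets import Dataset

    ds = Dataset(uri="postgres://localhost:5432/database/schema/table")
    # Resolves a normalizer via ProvidersManager().dataset_uri_handlers and returns
    # the AIP-60 form, or None if the scheme is missing, no normalizer is
    # registered, or the normalizer raises ValueError.
    print(ds.normalized_uri)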
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index adbca7846d19e..8f11833ee1c8f 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -216,6 +216,10 @@
                 "factory": {
                     "type": ["string", "null"],
                     "description": "Dataset factory for specified URI. Creates AIP-60 compliant Dataset."
+                },
+                "to_openlineage_converter": {
+                    "type": ["string", "null"],
+                    "description": "OpenLineage converter function for specified URI schemes. Import path to a callable that accepts a Dataset and a LineageContext and returns an OpenLineage dataset."
                 }
             }
         }
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index 0689ea39774e4..a36f44b3d5220 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -55,6 +55,8 @@
 from airflow.utils.module_loading import import_string
 
 if TYPE_CHECKING:
+    from openlineage.client.run import Dataset as OpenLineageDataset
+
     from airflow.models import DagRun, TaskInstance
 
 
@@ -635,3 +637,31 @@ def should_use_external_connection(hook) -> bool:
     if not _IS_AIRFLOW_2_10_OR_HIGHER:
         return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
     return True
+
+
+def translate_airflow_dataset(dataset: Dataset, lineage_context) -> OpenLineageDataset | None:
+    """
+    Convert a Dataset with an AIP-60-compliant URI to an OpenLineageDataset.
+
+    Returns None if no URI normalizer is defined, if no dataset converter is found,
+    or if required core Airflow changes are missing and an ImportError is raised.
+    """
+    try:
+        from airflow.datasets import _get_normalized_scheme
+        from airflow.providers_manager import ProvidersManager
+
+        ol_converters = ProvidersManager().dataset_to_openlineage_converters
+        normalized_uri = dataset.normalized_uri
+    except (ImportError, AttributeError):
+        return None
+
+    if normalized_uri is None:
+        return None
+
+    if not (normalized_scheme := _get_normalized_scheme(normalized_uri)):
+        return None
+
+    if (airflow_to_ol_converter := ol_converters.get(normalized_scheme)) is None:
+        return None
+
+    return airflow_to_ol_converter(Dataset(uri=normalized_uri, extra=dataset.extra), lineage_context)
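Note: a sketch of calling the new helper end to end, assuming a converter is
registered for the "s3" scheme (the second patch below registers one); the
lineage_context is passed through untouched, so None suffices for converters that
ignore it:

    from airflow.datasets import Dataset
    from airflow.providers.openlineage.utils.utils import translate_airflow_dataset

    ol_dataset = translate_airflow_dataset(Dataset(uri="s3://bucket/key"), None)
    # An OpenLineage Dataset(namespace="s3://bucket", name="key"), or None when no
    # normalizer or converter is registered for the scheme.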
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index f6d29a51d12ca..dd3e841fa1662 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -428,6 +428,7 @@ def __init__(self):
         self._fs_set: set[str] = set()
         self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
         self._dataset_factories: dict[str, Callable[..., Dataset]] = {}
+        self._dataset_to_openlineage_converters: dict[str, Callable] = {}
         self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
         # keeps mapping between connection_types and hook class, package they come from
         self._hook_provider_dict: dict[str, HookClassProvider] = {}
@@ -525,10 +526,10 @@ def initialize_providers_filesystems(self):
         self._discover_filesystems()
 
     @provider_info_cache("dataset_uris")
-    def initialize_providers_dataset_uri_handlers_and_factories(self):
-        """Lazy initialization of provider dataset URI handlers."""
+    def initialize_providers_dataset_uri_resources(self):
+        """Lazy initialization of provider dataset URI handlers, factories, converters, etc."""
         self.initialize_providers_list()
-        self._discover_dataset_uri_handlers_and_factories()
+        self._discover_dataset_uri_resources()
 
     @provider_info_cache("hook_lineage_writers")
     @provider_info_cache("taskflow_decorators")
@@ -881,28 +882,52 @@ def _discover_filesystems(self) -> None:
                     self._fs_set.add(fs_module_name)
         self._fs_set = set(sorted(self._fs_set))
 
-    def _discover_dataset_uri_handlers_and_factories(self) -> None:
+    def _discover_dataset_uri_resources(self) -> None:
+        """Discover and register dataset URI handlers, factories, and converters for all providers."""
         from airflow.datasets import normalize_noop
 
-        for provider_package, provider in self._provider_dict.items():
-            for handler_info in provider.data.get("dataset-uris", []):
-                schemes = handler_info.get("schemes")
-                handler_path = handler_info.get("handler")
-                factory_path = handler_info.get("factory")
-                if schemes is None:
-                    continue
-
-                if handler_path is not None and (
-                    handler := _correctness_check(provider_package, handler_path, provider)
-                ):
-                    pass
-                else:
-                    handler = normalize_noop
-                self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-                if factory_path is not None and (
-                    factory := _correctness_check(provider_package, factory_path, provider)
-                ):
-                    self._dataset_factories.update((scheme, factory) for scheme in schemes)
+        def _safe_register_resource(
+            provider_package_name: str,
+            schemes_list: list[str],
+            resource_path: str | None,
+            resource_registry: dict,
+            default_resource: Any = None,
+        ):
+            """
+            Register a specific resource (handler, factory, or converter) for the given schemes.
+
+            If the resolved resource (either from the path or the default) is valid, it updates
+            the resource registry with the appropriate resource for each scheme.
+            """
+            resource = (
+                _correctness_check(provider_package_name, resource_path, provider)
+                if resource_path is not None
+                else default_resource
+            )
+            if resource:
+                resource_registry.update((scheme, resource) for scheme in schemes_list)
+
+        for provider_name, provider in self._provider_dict.items():
+            for uri_info in provider.data.get("dataset-uris", []):
+                if "schemes" not in uri_info or "handler" not in uri_info:
+                    continue  # Both schemes and handler must be explicitly set; handler may be set to null.
+                common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name}
+                _safe_register_resource(
+                    resource_path=uri_info["handler"],
+                    resource_registry=self._dataset_uri_handlers,
+                    default_resource=normalize_noop,
+                    **common_args,
+                )
+                _safe_register_resource(
+                    resource_path=uri_info.get("factory"),
+                    resource_registry=self._dataset_factories,
+                    **common_args,
+                )
+                _safe_register_resource(
+                    resource_path=uri_info.get("to_openlineage_converter"),
+                    resource_registry=self._dataset_to_openlineage_converters,
+                    **common_args,
+                )
 
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
@@ -1301,14 +1326,21 @@ def filesystem_module_names(self) -> list[str]:
 
     @property
     def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
-        self.initialize_providers_dataset_uri_handlers_and_factories()
+        self.initialize_providers_dataset_uri_resources()
         return self._dataset_factories
 
     @property
     def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
-        self.initialize_providers_dataset_uri_handlers_and_factories()
+        self.initialize_providers_dataset_uri_resources()
         return self._dataset_uri_handlers
 
+    @property
+    def dataset_to_openlineage_converters(
+        self,
+    ) -> dict[str, Callable]:
+        self.initialize_providers_dataset_uri_resources()
+        return self._dataset_to_openlineage_converters
+
     @property
     def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
         self.initialize_providers_configuration()
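Note: a short sketch of reading the new registry directly; keys are normalized URI
schemes and values are the imported converter callables (the "s3" lookup assumes a
provider such as amazon, updated in the second patch, is installed):

    from airflow.providers_manager import ProvidersManager

    converters = ProvidersManager().dataset_to_openlineage_converters
    convert = converters.get("s3")  # None if no provider registered the scheme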
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 19dbcf4b53df7..1b2d3d6d4c723 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -31,6 +31,7 @@
     DatasetAll,
     DatasetAny,
     _DatasetAliasCondition,
+    _get_normalized_scheme,
     _sanitize_uri,
 )
 from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetModel
@@ -454,31 +455,68 @@ def test_datasets_expression_error(expression: Callable[[], None], error: str) -> None:
     assert str(info.value) == error
 
 
-def mock_get_uri_normalizer(normalized_scheme):
+def test_get_normalized_scheme():
+    assert _get_normalized_scheme("http://example.com") == "http"
+    assert _get_normalized_scheme("HTTPS://example.com") == "https"
+    assert _get_normalized_scheme("ftp://example.com") == "ftp"
+    assert _get_normalized_scheme("file://") == "file"
+
+    assert _get_normalized_scheme("example.com") == ""
+    assert _get_normalized_scheme("") == ""
+    assert _get_normalized_scheme(" ") == ""
+
+
+def _mock_get_uri_normalizer_raising_error(normalized_scheme):
     def normalizer(uri):
         raise ValueError("Incorrect URI format")
 
     return normalizer
 
 
-@patch("airflow.datasets._get_uri_normalizer", mock_get_uri_normalizer)
+def _mock_get_uri_normalizer_noop(normalized_scheme):
+    def normalizer(uri):
+        return uri
+
+    return normalizer
+
+
+@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error)
 @patch("airflow.datasets.warnings.warn")
-def test__sanitize_uri_raises_warning(mock_warn):
+def test_sanitize_uri_raises_warning(mock_warn):
     _sanitize_uri("postgres://localhost:5432/database.schema.table")
     msg = mock_warn.call_args.args[0]
     assert "The dataset URI postgres://localhost:5432/database.schema.table is not AIP-60 compliant" in msg
     assert "In Airflow 3, this will raise an exception." in msg
 
 
-@patch("airflow.datasets._get_uri_normalizer", mock_get_uri_normalizer)
+@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error)
 @conf_vars({("core", "strict_dataset_uri_validation"): "True"})
-def test__sanitize_uri_raises_exception():
+def test_sanitize_uri_raises_exception():
     with pytest.raises(ValueError) as e_info:
         _sanitize_uri("postgres://localhost:5432/database.schema.table")
     assert isinstance(e_info.value, ValueError)
     assert str(e_info.value) == "Incorrect URI format"
 
 
+@patch("airflow.datasets._get_uri_normalizer", lambda x: None)
+def test_normalize_uri_no_normalizer_found():
+    dataset = Dataset(uri="any_uri_without_normalizer_defined")
+    assert dataset.normalized_uri is None
+
+
+@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error)
+def test_normalize_uri_invalid_uri():
+    dataset = Dataset(uri="any_uri_not_aip60_compliant")
+    assert dataset.normalized_uri is None
+
+
+@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_noop)
+@patch("airflow.datasets._get_normalized_scheme", lambda x: "valid_scheme")
+def test_normalize_uri_valid_uri():
+    dataset = Dataset(uri="valid_aip60_uri")
+    assert dataset.normalized_uri == "valid_aip60_uri"
+
+
 @pytest.mark.db_test
 @pytest.mark.usefixtures("clear_datasets")
 class Test_DatasetAliasCondition:
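Note: with the schema field above in place, a provider opts in entirely through its
provider.yaml; the shape below mirrors the s3 entry added by the second patch, which
follows here (the file scheme gets the analogous entry):

    dataset-uris:
      - schemes: [s3]
        handler: airflow.providers.amazon.aws.datasets.s3.sanitize_uri
        factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
        to_openlineage_converter: airflow.providers.amazon.aws.datasets.s3.convert_dataset_to_openlineage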
From bd2ef8c38f8a3387a170c22a43579bd13f9431d4 Mon Sep 17 00:00:00 2001
From: Kacper Muda
Date: Mon, 22 Jul 2024 20:07:44 +0200
Subject: [PATCH 2/2] aip-62: implement translation examples from aip-60 to
 OpenLineage

Signed-off-by: Kacper Muda
---
 airflow/providers/amazon/aws/datasets/s3.py   | 22 ++++++++
 airflow/providers/amazon/provider.yaml        |  3 +-
 airflow/providers/common/io/datasets/file.py  | 26 +++++++++
 airflow/providers/common/io/provider.yaml     |  3 +-
 .../providers/amazon/aws/datasets/test_s3.py  | 54 ++++++++++++++++++-
 .../providers/common/io/datasets/test_file.py | 45 +++++++++++++++-
 6 files changed, 149 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/datasets/s3.py b/airflow/providers/amazon/aws/datasets/s3.py
index 89889efe577b3..e6bed6dbe3dfa 100644
--- a/airflow/providers/amazon/aws/datasets/s3.py
+++ b/airflow/providers/amazon/aws/datasets/s3.py
@@ -16,8 +16,30 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+
+if TYPE_CHECKING:
+    from urllib.parse import SplitResult
+
+    from openlineage.client.run import Dataset as OpenLineageDataset
 
 
 def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
     return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
+
+
+def sanitize_uri(uri: SplitResult) -> SplitResult:
+    if not uri.netloc:
+        raise ValueError("URI format s3:// must contain a bucket name")
+    return uri
+
+
+def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
+    """Translate a Dataset with a valid AIP-60 URI to an OpenLineage dataset, with assistance from the hook."""
+    from openlineage.client.run import Dataset as OpenLineageDataset
+
+    bucket, key = S3Hook.parse_s3_url(dataset.uri)
+    return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/")
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 9dd76ac9fa3b1..309abcc23ad28 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -561,7 +561,8 @@ sensors:
 
 dataset-uris:
   - schemes: [s3]
-    handler: null
+    handler: airflow.providers.amazon.aws.datasets.s3.sanitize_uri
+    to_openlineage_converter: airflow.providers.amazon.aws.datasets.s3.convert_dataset_to_openlineage
     factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
 
 filesystems:
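Note: a usage sketch for the S3 converter above; S3Hook.parse_s3_url splits the URI
into bucket and key, and an empty key falls back to "/" (mirroring the tests below).
The converter does not inspect lineage_context, so None is passed here for brevity:

    from airflow.datasets import Dataset
    from airflow.providers.amazon.aws.datasets.s3 import convert_dataset_to_openlineage

    ol = convert_dataset_to_openlineage(Dataset(uri="s3://bucket/dir/file.txt"), None)
    assert ol.namespace == "s3://bucket"
    assert ol.name == "dir/file.txt"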
+ """ + from openlineage.client.run import Dataset as OpenLineageDataset + + parsed = urllib.parse.urlsplit(dataset.uri) + return OpenLineageDataset(namespace=f"file://{parsed.netloc}", name=parsed.path) diff --git a/airflow/providers/common/io/provider.yaml b/airflow/providers/common/io/provider.yaml index a45d3d7dfef4c..e644b3f07089d 100644 --- a/airflow/providers/common/io/provider.yaml +++ b/airflow/providers/common/io/provider.yaml @@ -53,7 +53,8 @@ xcom: dataset-uris: - schemes: [file] - handler: null + handler: airflow.providers.common.io.datasets.file.sanitize_uri + to_openlineage_converter: airflow.providers.common.io.datasets.file.convert_dataset_to_openlineage factory: airflow.providers.common.io.datasets.file.create_dataset config: diff --git a/tests/providers/amazon/aws/datasets/test_s3.py b/tests/providers/amazon/aws/datasets/test_s3.py index c7ffe252401e7..893d6acf677bc 100644 --- a/tests/providers/amazon/aws/datasets/test_s3.py +++ b/tests/providers/amazon/aws/datasets/test_s3.py @@ -16,8 +16,38 @@ # under the License. from __future__ import annotations +import urllib.parse + +import pytest + from airflow.datasets import Dataset -from airflow.providers.amazon.aws.datasets.s3 import create_dataset +from airflow.providers.amazon.aws.datasets.s3 import ( + convert_dataset_to_openlineage, + create_dataset, + sanitize_uri, +) +from airflow.providers.amazon.aws.hooks.s3 import S3Hook + + +def test_sanitize_uri(): + uri = sanitize_uri(urllib.parse.urlsplit("s3://bucket/dir/file.txt")) + result = sanitize_uri(uri) + assert result.scheme == "s3" + assert result.netloc == "bucket" + assert result.path == "/dir/file.txt" + + +def test_sanitize_uri_no_netloc(): + with pytest.raises(ValueError): + sanitize_uri(urllib.parse.urlsplit("s3://")) + + +def test_sanitize_uri_no_path(): + uri = sanitize_uri(urllib.parse.urlsplit("s3://bucket")) + result = sanitize_uri(uri) + assert result.scheme == "s3" + assert result.netloc == "bucket" + assert result.path == "" def test_create_dataset(): @@ -25,3 +55,25 @@ def test_create_dataset(): assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset( uri="s3://test-bucket/test-dir/test-path" ) + + +def test_sanitize_uri_trailing_slash(): + uri = sanitize_uri(urllib.parse.urlsplit("s3://bucket/")) + result = sanitize_uri(uri) + assert result.scheme == "s3" + assert result.netloc == "bucket" + assert result.path == "/" + + +def test_convert_dataset_to_openlineage_valid(): + uri = "s3://bucket/dir/file.txt" + ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=S3Hook()) + assert ol_dataset.namespace == "s3://bucket" + assert ol_dataset.name == "dir/file.txt" + + +@pytest.mark.parametrize("uri", ("s3://bucket", "s3://bucket/")) +def test_convert_dataset_to_openlineage_no_path(uri): + ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=S3Hook()) + assert ol_dataset.namespace == "s3://bucket" + assert ol_dataset.name == "/" diff --git a/tests/providers/common/io/datasets/test_file.py b/tests/providers/common/io/datasets/test_file.py index 43d63cb205586..b2e4fddf986fd 100644 --- a/tests/providers/common/io/datasets/test_file.py +++ b/tests/providers/common/io/datasets/test_file.py @@ -16,9 +16,52 @@ # under the License. 
diff --git a/tests/providers/common/io/datasets/test_file.py b/tests/providers/common/io/datasets/test_file.py
index 43d63cb205586..b2e4fddf986fd 100644
--- a/tests/providers/common/io/datasets/test_file.py
+++ b/tests/providers/common/io/datasets/test_file.py
@@ -16,9 +16,52 @@
 # under the License.
 from __future__ import annotations
 
+from urllib.parse import urlsplit, urlunsplit
+
+import pytest
+from openlineage.client.run import Dataset as OpenLineageDataset
+
 from airflow.datasets import Dataset
-from airflow.providers.common.io.datasets.file import create_dataset
+from airflow.providers.common.io.datasets.file import (
+    convert_dataset_to_openlineage,
+    create_dataset,
+    sanitize_uri,
+)
+
+
+@pytest.mark.parametrize(
+    ("uri", "expected"),
+    (
+        ("file:///valid/path/", "file:///valid/path/"),
+        ("file://C://dir/file", "file://C://dir/file"),
+    ),
+)
+def test_sanitize_uri_valid(uri, expected):
+    result = sanitize_uri(urlsplit(uri))
+    assert urlunsplit(result) == expected
+
+
+@pytest.mark.parametrize("uri", ("file://",))
+def test_sanitize_uri_invalid(uri):
+    with pytest.raises(ValueError):
+        sanitize_uri(urlsplit(uri))
 
 
 def test_file_dataset():
     assert create_dataset(path="/asdf/fdsa") == Dataset(uri="file:///asdf/fdsa")
+
+
+@pytest.mark.parametrize(
+    ("uri", "ol_dataset"),
+    (
+        ("file:///valid/path", OpenLineageDataset(namespace="file://", name="/valid/path")),
+        (
+            "file://127.0.0.1:8080/dir/file.csv",
+            OpenLineageDataset(namespace="file://127.0.0.1:8080", name="/dir/file.csv"),
+        ),
+        ("file:///C://dir/file", OpenLineageDataset(namespace="file://", name="/C://dir/file")),
+    ),
+)
+def test_convert_dataset_to_openlineage(uri, ol_dataset):
+    result = convert_dataset_to_openlineage(Dataset(uri=uri), None)
+    assert result == ol_dataset