From a63bb48df25d1ca914a06a7d8b62b1da831d12cb Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 28 Nov 2023 17:24:52 +0800 Subject: [PATCH 1/5] Ensure fs attached in tests are cleaned up --- tests/io/test_path.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/io/test_path.py b/tests/io/test_path.py index 1ac263c59f500..a5acda54b6a68 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -52,6 +52,13 @@ def _strip_protocol(cls, path) -> str: class TestFs: + def setup_class(self): + self._store_cache = _STORE_CACHE.copy() + + def teardown(self): + _STORE_CACHE.clear() + _STORE_CACHE.update(self._store_cache) + def test_alias(self): store = attach("file", alias="local") assert isinstance(store.fs, LocalFileSystem) @@ -100,6 +107,13 @@ def test_ls(self): assert not o.exists() + @pytest.fixture() + def fake_fs(self): + fs = mock.Mock() + fs._strip_protocol.return_value = "/" + fs.conn_id = "fake" + return fs + @pytest.mark.parametrize( "fn, args, fn2, path, expected_args, expected_kwargs", [ @@ -124,12 +138,8 @@ def test_ls(self): ), ], ) - def test_standard_extended_api(self, monkeypatch, fn, args, fn2, path, expected_args, expected_kwargs): - _fs = mock.Mock() - _fs._strip_protocol.return_value = "/" - _fs.conn_id = "fake" - - store = attach(protocol="file", conn_id="fake", fs=_fs) + def test_standard_extended_api(self, fake_fs, fn, args, fn2, path, expected_args, expected_kwargs): + store = attach(protocol="file", conn_id="fake", fs=fake_fs) o = ObjectStoragePath(path, conn_id="fake") getattr(o, fn)(**args) From cbbebe3c6bffc6c1145d5e6f22ff085d8086b80d Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 28 Nov 2023 17:25:27 +0800 Subject: [PATCH 2/5] Pass conn ID to ObjectStoragePath via URI This enables an alternative ObjectStoragePath init syntax, using the auth section in the URI to supply conn ID instead of a separate keyword argument. The explicit keyword argument is honored if supplied. --- airflow/example_dags/tutorial_objectstorage.py | 2 +- airflow/io/path.py | 7 ++++++- airflow/io/store/__init__.py | 2 +- .../core-concepts/objectstorage.rst | 16 +++++++++------- tests/io/test_path.py | 6 ++++++ 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/airflow/example_dags/tutorial_objectstorage.py b/airflow/example_dags/tutorial_objectstorage.py index 11d817400df23..4660aa3c8e8c1 100644 --- a/airflow/example_dags/tutorial_objectstorage.py +++ b/airflow/example_dags/tutorial_objectstorage.py @@ -43,7 +43,7 @@ } # [START create_object_storage_path] -base = ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default") +base = ObjectStoragePath("s3://aws_default@airflow-tutorial-data/") # [END create_object_storage_path] diff --git a/airflow/io/path.py b/airflow/io/path.py index f5eeb14eff043..911f29a6a35e4 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -92,6 +92,7 @@ def __new__( cls: type[PT], *args: str | os.PathLike, scheme: str | None = None, + conn_id: str | None = None, **kwargs: typing.Any, ) -> PT: args_list = list(args) @@ -137,7 +138,11 @@ def __new__( else: args_list.insert(0, parsed_url.path) - return cls._from_parts(args_list, url=parsed_url, **kwargs) # type: ignore + if parsed_url.username is not None: + conn_id = conn_id or parsed_url.username or None + parsed_url = parsed_url._replace(netloc=parsed_url.netloc.rsplit("@", 1)[-1]) + + return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore @functools.lru_cache def __hash__(self) -> int: diff --git a/airflow/io/store/__init__.py b/airflow/io/store/__init__.py index 6bf40c939f734..a5a4bd12ddba1 100644 --- a/airflow/io/store/__init__.py +++ b/airflow/io/store/__init__.py @@ -131,7 +131,7 @@ def attach( if not alias: alias = f"{protocol}-{conn_id}" if conn_id else protocol - if store := _STORE_CACHE.get(alias, None): + if store := _STORE_CACHE.get(alias): return store _STORE_CACHE[alias] = store = ObjectStore(protocol=protocol, conn_id=conn_id, fs=fs) diff --git a/docs/apache-airflow/core-concepts/objectstorage.rst b/docs/apache-airflow/core-concepts/objectstorage.rst index 046cb4852278c..d72a734293496 100644 --- a/docs/apache-airflow/core-concepts/objectstorage.rst +++ b/docs/apache-airflow/core-concepts/objectstorage.rst @@ -74,20 +74,22 @@ object you want to interact with. For example, to point to a bucket in s3, you w from airflow.io.path import ObjectStoragePath - base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default") # conn_id is optional + base = ObjectStoragePath("s3://aws_default@my-bucket/") +The username part of the URI is optional. It can alternatively be passed in as a separate keyword argument: + +.. code-block:: python + + # Equivalent to the previous example. + base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default") Listing file-objects: .. code-block:: python @task - def list_files() -> list(ObjectStoragePath): - files = [] - for f in base.iterdir(): - if f.is_file(): - files.append(f) - + def list_files() -> list[ObjectStoragePath]: + files = [f for f in base.iterdir() if f.is_file()] return files diff --git a/tests/io/test_path.py b/tests/io/test_path.py index a5acda54b6a68..54a675360f351 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -114,6 +114,12 @@ def fake_fs(self): fs.conn_id = "fake" return fs + def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs): + fake_fs.stat.return_value = {"stat": "result"} + attach(protocol="fake", conn_id="fake", fs=fake_fs) + p = ObjectStoragePath("fake://fake@bucket/path") + assert p.stat() == {"stat": "result", "conn_id": "fake", "protocol": "fake"} + @pytest.mark.parametrize( "fn, args, fn2, path, expected_args, expected_kwargs", [ From 0e6f7e02727bf35c24d4c89452c0607e26d0b041 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 30 Nov 2023 17:37:24 +0800 Subject: [PATCH 3/5] Add explaination on connection ID in tutorial --- docs/apache-airflow/tutorial/objectstorage.rst | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/apache-airflow/tutorial/objectstorage.rst b/docs/apache-airflow/tutorial/objectstorage.rst index 89ffe0e8f95d6..610450b931986 100644 --- a/docs/apache-airflow/tutorial/objectstorage.rst +++ b/docs/apache-airflow/tutorial/objectstorage.rst @@ -32,7 +32,7 @@ analytical database. You can do this by running ``pip install duckdb``. The tuto makes use of S3 Object Storage. This requires that the amazon provider is installed including ``s3fs`` by running ``pip install apache-airflow-providers-amazon[s3fs]``. If you would like to use a different storage provider, you can do so by changing the -url in the ``create_object_storage_path`` function to the appropriate url for your +URL in the ``create_object_storage_path`` function to the appropriate URL for your provider, for example by replacing ``s3://`` with ``gs://`` for Google Cloud Storage. You will also need the right provider to be installed then. Finally, you will need ``pandas``, which can be installed by running ``pip install pandas``. @@ -49,9 +49,19 @@ It is the fundamental building block of the Object Storage API. :start-after: [START create_object_storage_path] :end-before: [END create_object_storage_path] -The ObjectStoragePath constructor can take an optional connection id. If supplied -it will use the connection to obtain the right credentials to access the backend. -Otherwise it will revert to the default for that backend. +The username part of the URL given to ObjectStoragePath should be a connection ID. +The specified connection will be used to obtain the right credentials to access +the backend. If it is omitted, the default connection for the backend will be used. + +The connection ID can alternatively be passed in with a keyword argument: + +.. code-block:: python + + ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default") + +This is useful when reusing a URL defined for another purpose (e.g. Dataset), +which generally does not contain a username part. The explicit keyword argument +takes precedence over the URL's username value if both are specified. It is safe to instantiate an ObjectStoragePath at the root of your DAG. Connections will not be created until the path is used. This means that you can create the From a6fc4bd086ca73c766e69d2e0d1b3b6074170c03 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 30 Nov 2023 17:41:50 +0800 Subject: [PATCH 4/5] Add note on parsing logic --- airflow/io/path.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/io/path.py b/airflow/io/path.py index 911f29a6a35e4..8696f6238a663 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -140,6 +140,9 @@ def __new__( if parsed_url.username is not None: conn_id = conn_id or parsed_url.username or None + # If there are multiple @ in the host string, parse from the last. + # This matches the parsing logic in urllib.parse; see: + # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75aa/Lib/urllib/parse.py#L196 parsed_url = parsed_url._replace(netloc=parsed_url.netloc.rsplit("@", 1)[-1]) return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore From 50389cf9b64e56e434cfae10ece94866dce78375 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 30 Nov 2023 17:44:44 +0800 Subject: [PATCH 5/5] Do not use two different parser While unlikely, it is *possible* for CPython to change how username and password is parsed from netloc. To ensure we always do the right thing, it is better to just implement the logic outselves so it always work as expected. We need this logic to get the host info anyway, so it is not possible to solely rely on CPython's parser. --- airflow/io/path.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/io/path.py b/airflow/io/path.py index 8696f6238a663..0e6f80254baf5 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -138,12 +138,12 @@ def __new__( else: args_list.insert(0, parsed_url.path) - if parsed_url.username is not None: - conn_id = conn_id or parsed_url.username or None - # If there are multiple @ in the host string, parse from the last. - # This matches the parsing logic in urllib.parse; see: - # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75aa/Lib/urllib/parse.py#L196 - parsed_url = parsed_url._replace(netloc=parsed_url.netloc.rsplit("@", 1)[-1]) + # This matches the parsing logic in urllib.parse; see: + # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203 + userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@") + if have_info: + conn_id = conn_id or userinfo or None + parsed_url = parsed_url._replace(netloc=hostinfo) return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore