diff --git a/airflow/example_dags/tutorial_objectstorage.py b/airflow/example_dags/tutorial_objectstorage.py
index 11d817400df23..4660aa3c8e8c1 100644
--- a/airflow/example_dags/tutorial_objectstorage.py
+++ b/airflow/example_dags/tutorial_objectstorage.py
@@ -43,7 +43,7 @@
 }
 
 # [START create_object_storage_path]
-base = ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default")
+base = ObjectStoragePath("s3://aws_default@airflow-tutorial-data/")
 # [END create_object_storage_path]
 
 
diff --git a/airflow/io/path.py b/airflow/io/path.py
index f5eeb14eff043..0e6f80254baf5 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -92,6 +92,7 @@ def __new__(
         cls: type[PT],
         *args: str | os.PathLike,
         scheme: str | None = None,
+        conn_id: str | None = None,
         **kwargs: typing.Any,
     ) -> PT:
         args_list = list(args)
@@ -137,7 +138,14 @@ def __new__(
             else:
                 args_list.insert(0, parsed_url.path)
 
-        return cls._from_parts(args_list, url=parsed_url, **kwargs)  # type: ignore
+        # This matches the parsing logic in urllib.parse; see:
+        # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
+        userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
+        if have_info:
+            conn_id = conn_id or userinfo or None
+            parsed_url = parsed_url._replace(netloc=hostinfo)
+
+        return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs)  # type: ignore
 
     @functools.lru_cache
     def __hash__(self) -> int:
diff --git a/airflow/io/store/__init__.py b/airflow/io/store/__init__.py
index 6bf40c939f734..a5a4bd12ddba1 100644
--- a/airflow/io/store/__init__.py
+++ b/airflow/io/store/__init__.py
@@ -131,7 +131,7 @@ def attach(
     if not alias:
         alias = f"{protocol}-{conn_id}" if conn_id else protocol
 
-    if store := _STORE_CACHE.get(alias, None):
+    if store := _STORE_CACHE.get(alias):
         return store
 
     _STORE_CACHE[alias] = store = ObjectStore(protocol=protocol, conn_id=conn_id, fs=fs)
diff --git a/docs/apache-airflow/core-concepts/objectstorage.rst b/docs/apache-airflow/core-concepts/objectstorage.rst
index 046cb4852278c..d72a734293496 100644
--- a/docs/apache-airflow/core-concepts/objectstorage.rst
+++ b/docs/apache-airflow/core-concepts/objectstorage.rst
@@ -74,20 +74,22 @@ object you want to interact with. For example, to point to a bucket in s3, you w
 
     from airflow.io.path import ObjectStoragePath
 
-    base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default")  # conn_id is optional
+    base = ObjectStoragePath("s3://aws_default@my-bucket/")
+
+The username part of the URI is optional. It can alternatively be passed in as a separate keyword argument:
+
+.. code-block:: python
+
+    # Equivalent to the previous example.
+    base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default")
 
 Listing file-objects:
 
 .. code-block:: python
 
     @task
-    def list_files() -> list(ObjectStoragePath):
-        files = []
-        for f in base.iterdir():
-            if f.is_file():
-                files.append(f)
-
+    def list_files() -> list[ObjectStoragePath]:
+        files = [f for f in base.iterdir() if f.is_file()]
         return files
diff --git a/docs/apache-airflow/tutorial/objectstorage.rst b/docs/apache-airflow/tutorial/objectstorage.rst
index 89ffe0e8f95d6..610450b931986 100644
--- a/docs/apache-airflow/tutorial/objectstorage.rst
+++ b/docs/apache-airflow/tutorial/objectstorage.rst
@@ -32,7 +32,7 @@ analytical database. You can do this by running ``pip install duckdb``. The tuto
 makes use of S3 Object Storage.
 This requires that the amazon provider is installed including ``s3fs`` by running ``pip install apache-airflow-providers-amazon[s3fs]``.
 If you would like to use a different storage provider, you can do so by changing the
-url in the ``create_object_storage_path`` function to the appropriate url for your
+URL in the ``create_object_storage_path`` function to the appropriate URL for your
 provider, for example by replacing ``s3://`` with ``gs://`` for Google Cloud Storage.
 You will also need the right provider to be installed then. Finally, you will need
 ``pandas``, which can be installed by running ``pip install pandas``.
@@ -49,9 +49,19 @@ It is the fundamental building block of the Object Storage API.
     :start-after: [START create_object_storage_path]
     :end-before: [END create_object_storage_path]
 
-The ObjectStoragePath constructor can take an optional connection id. If supplied
-it will use the connection to obtain the right credentials to access the backend.
-Otherwise it will revert to the default for that backend.
+The username part of the URL given to ObjectStoragePath should be a connection ID.
+The specified connection will be used to obtain the right credentials to access
+the backend. If it is omitted, the default connection for the backend will be used.
+
+The connection ID can alternatively be passed in with a keyword argument:
+
+.. code-block:: python
+
+    ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default")
+
+This is useful when reusing a URL defined for another purpose (e.g. Dataset),
+which generally does not contain a username part. The explicit keyword argument
+takes precedence over the URL's username value if both are specified.
 
 It is safe to instantiate an ObjectStoragePath at the root of your DAG. Connections
 will not be created until the path is used. This means that you can create the
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index 1ac263c59f500..54a675360f351 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -52,6 +52,13 @@ def _strip_protocol(cls, path) -> str:
 
 
 class TestFs:
+    def setup_class(self):
+        self._store_cache = _STORE_CACHE.copy()
+
+    def teardown(self):
+        _STORE_CACHE.clear()
+        _STORE_CACHE.update(self._store_cache)
+
     def test_alias(self):
         store = attach("file", alias="local")
         assert isinstance(store.fs, LocalFileSystem)
@@ -100,6 +107,19 @@ def test_ls(self):
 
         assert not o.exists()
 
+    @pytest.fixture()
+    def fake_fs(self):
+        fs = mock.Mock()
+        fs._strip_protocol.return_value = "/"
+        fs.conn_id = "fake"
+        return fs
+
+    def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs):
+        fake_fs.stat.return_value = {"stat": "result"}
+        attach(protocol="fake", conn_id="fake", fs=fake_fs)
+        p = ObjectStoragePath("fake://fake@bucket/path")
+        assert p.stat() == {"stat": "result", "conn_id": "fake", "protocol": "fake"}
+
     @pytest.mark.parametrize(
         "fn, args, fn2, path, expected_args, expected_kwargs",
         [
@@ -124,12 +144,8 @@ def test_ls(self):
             ),
         ],
     )
-    def test_standard_extended_api(self, monkeypatch, fn, args, fn2, path, expected_args, expected_kwargs):
-        _fs = mock.Mock()
-        _fs._strip_protocol.return_value = "/"
-        _fs.conn_id = "fake"
-
-        store = attach(protocol="file", conn_id="fake", fs=_fs)
+    def test_standard_extended_api(self, fake_fs, fn, args, fn2, path, expected_args, expected_kwargs):
+        store = attach(protocol="file", conn_id="fake", fs=fake_fs)
         o = ObjectStoragePath(path, conn_id="fake")
 
         getattr(o, fn)(**args)
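Reviewer note, not part of the patch: the netloc handling added in airflow/io/path.py depends only on urllib.parse, so it can be exercised in isolation. Below is a minimal standalone sketch of that logic; split_conn_id is a hypothetical helper name used purely for illustration, and the last example shows the precedence rule documented above (an explicit conn_id keyword wins over the URL's username part).

    from urllib.parse import urlsplit


    def split_conn_id(url: str, conn_id: str | None = None) -> tuple[str | None, str]:
        """Split a connection ID out of the user part of *url* (hypothetical helper)."""
        parsed = urlsplit(url)
        # rpartition mirrors urllib.parse: only the last "@" separates userinfo
        # from hostinfo, so earlier "@" characters stay in the user part.
        userinfo, have_info, hostinfo = parsed.netloc.rpartition("@")
        if have_info:
            # An explicitly passed conn_id takes precedence over the URL's username.
            conn_id = conn_id or userinfo or None
            parsed = parsed._replace(netloc=hostinfo)
        return conn_id, parsed.geturl()


    print(split_conn_id("s3://aws_default@my-bucket/data"))
    # -> ('aws_default', 's3://my-bucket/data')
    print(split_conn_id("s3://my-bucket/data", conn_id="aws_default"))
    # -> ('aws_default', 's3://my-bucket/data')
    print(split_conn_id("s3://other@my-bucket/data", conn_id="aws_default"))
    # -> ('aws_default', 's3://my-bucket/data'): the keyword argument wins

Because the stripped netloc is written back into the parsed URL before _from_parts is called, the connection ID never leaks into the path itself; this is what the test_objectstoragepath_init_conn_id_in_uri test relies on when it constructs "fake://fake@bucket/path" and still resolves the attached "fake" store.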