Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions upath/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import os
import re
from pathlib import PurePath
from typing import TYPE_CHECKING
from typing import Any

if TYPE_CHECKING:
from upath.core import UPath

__all__ = [
"get_upath_protocol",
"normalize_empty_netloc",
"compatible_protocol",
]

# Regular expression to match fsspec style protocols.
Expand Down Expand Up @@ -59,3 +64,15 @@ def normalize_empty_netloc(pth: str) -> str:
path = m.group("path")
pth = f"{protocol}:///{path}"
return pth


def compatible_protocol(protocol: str, *args: str | os.PathLike[str] | UPath) -> bool:
"""check if UPath protocols are compatible"""
for arg in args:
other_protocol = get_upath_protocol(arg)
# consider protocols equivalent if they match up to the first "+"
other_protocol = other_protocol.partition("+")[0]
# protocols: only identical (or empty "") protocols can combine
if other_protocol and other_protocol != protocol:
return False
return True
24 changes: 7 additions & 17 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from upath._flavour import LazyFlavourDescriptor
from upath._flavour import upath_get_kwargs_from_url
from upath._flavour import upath_urijoin
from upath._protocol import compatible_protocol
from upath._protocol import get_upath_protocol
from upath._stat import UPathStatResult
from upath.registry import get_upath_class
Expand Down Expand Up @@ -251,23 +252,12 @@ def __init__(
self._storage_options = storage_options.copy()

# check that UPath subclasses in args are compatible
# --> ensures items in _raw_paths are compatible
for arg in args:
if not isinstance(arg, UPath):
continue
# protocols: only identical (or empty "") protocols can combine
if arg.protocol and arg.protocol != self._protocol:
raise TypeError("can't combine different UPath protocols as parts")
# storage_options: args may not define other storage_options
if any(
self._storage_options.get(key) != value
for key, value in arg.storage_options.items()
):
# TODO:
# Future versions of UPath could verify that storage_options
# can be combined between UPath instances. Not sure if this
# is really necessary though. A warning might be enough...
pass
# TODO:
# Future versions of UPath could verify that storage_options
# can be combined between UPath instances. Not sure if this
# is really necessary though. A warning might be enough...
if not compatible_protocol(self._protocol, *args):
raise ValueError("can't combine incompatible UPath protocols")

# fill ._raw_paths
if hasattr(self, "_raw_paths"):
Expand Down
7 changes: 7 additions & 0 deletions upath/implementations/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
class CloudPath(UPath):
__slots__ = ()

def __init__(
self, *args, protocol: str | None = None, **storage_options: Any
) -> None:
super().__init__(*args, protocol=protocol, **storage_options)
if not self.drive and len(self.parts) > 1:
raise ValueError("non key-like path provided (bucket/container missing)")

@classmethod
def _transform_init_args(
cls,
Expand Down
15 changes: 15 additions & 0 deletions upath/implementations/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import MutableMapping
from urllib.parse import SplitResult

from upath._protocol import compatible_protocol
from upath.core import UPath

__all__ = [
Expand Down Expand Up @@ -141,6 +142,8 @@ def __new__(
raise NotImplementedError(
f"cannot instantiate {cls.__name__} on your system"
)
if not compatible_protocol("", *args):
raise ValueError("can't combine incompatible UPath protocols")
obj = super().__new__(cls, *args)
obj._protocol = ""
return obj # type: ignore[return-value]
Expand All @@ -152,6 +155,11 @@ def __init__(
self._drv, self._root, self._parts = type(self)._parse_args(args)
_upath_init(self)

def _make_child(self, args):
if not compatible_protocol(self._protocol, *args):
raise ValueError("can't combine incompatible UPath protocols")
return super()._make_child(args)

@classmethod
def _from_parts(cls, *args, **kwargs):
obj = super(Path, cls)._from_parts(*args, **kwargs)
Expand Down Expand Up @@ -205,6 +213,8 @@ def __new__(
raise NotImplementedError(
f"cannot instantiate {cls.__name__} on your system"
)
if not compatible_protocol("", *args):
raise ValueError("can't combine incompatible UPath protocols")
obj = super().__new__(cls, *args)
obj._protocol = ""
return obj # type: ignore[return-value]
Expand All @@ -216,6 +226,11 @@ def __init__(
self._drv, self._root, self._parts = self._parse_args(args)
_upath_init(self)

def _make_child(self, args):
if not compatible_protocol(self._protocol, *args):
raise ValueError("can't combine incompatible UPath protocols")
return super()._make_child(args)

@classmethod
def _from_parts(cls, *args, **kwargs):
obj = super(Path, cls)._from_parts(*args, **kwargs)
Expand Down
29 changes: 29 additions & 0 deletions upath/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,32 @@ def test_query_string(uri, query_str):
p = UPath(uri)
assert str(p).endswith(query_str)
assert p.path.endswith(query_str)


@pytest.mark.parametrize(
"base,join",
[
("/a", "s3://bucket/b"),
("s3://bucket/a", "gs://b/c"),
("gs://bucket/a", "memory://b/c"),
("memory://bucket/a", "s3://b/c"),
],
)
def test_joinpath_on_protocol_mismatch(base, join):
with pytest.raises(ValueError):
UPath(base).joinpath(UPath(join))
with pytest.raises(ValueError):
UPath(base) / UPath(join)


@pytest.mark.parametrize(
"base,join",
[
("/a", "s3://bucket/b"),
("s3://bucket/a", "gs://b/c"),
("gs://bucket/a", "memory://b/c"),
("memory://bucket/a", "s3://b/c"),
],
)
def test_joinuri_on_protocol_mismatch(base, join):
assert UPath(base).joinuri(UPath(join)) == UPath(join)