diff --git a/dvc/config.py b/dvc/config.py index 1fc63acbb3..9cc7e80f71 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -184,6 +184,8 @@ class RelPath(str): }, "http": {**HTTP_COMMON, **REMOTE_COMMON}, "https": {**HTTP_COMMON, **REMOTE_COMMON}, + "webdav": {**HTTP_COMMON, **REMOTE_COMMON}, + "webdavs": {**HTTP_COMMON, **REMOTE_COMMON}, "remote": {str: object}, # Any of the above options are valid } ) diff --git a/dvc/path_info.py b/dvc/path_info.py index e502130aa8..d9c0d04f17 100644 --- a/dvc/path_info.py +++ b/dvc/path_info.py @@ -104,7 +104,14 @@ def __repr__(self): class URLInfo(_BasePath): - DEFAULT_PORTS = {"http": 80, "https": 443, "ssh": 22, "hdfs": 0} + DEFAULT_PORTS = { + "http": 80, + "https": 443, + "ssh": 22, + "hdfs": 0, + "webdav": 80, + "webdavs": 443, + } def __init__(self, url): p = urlparse(url) @@ -312,3 +319,19 @@ def __eq__(self, other): and self._path == other._path and self._extra_parts == other._extra_parts ) + + +class WebdavURLInfo(HTTPURLInfo): + def __init__(self, url): + super().__init__(url) + + @cached_property + def url(self): + return "{}://{}{}{}{}{}".format( + self.scheme.replace("webdav", "http"), + self.netloc, + self._spath, + (";" + self.params) if self.params else "", + ("?" + self.query) if self.query else "", + ("#" + self.fragment) if self.fragment else "", + ) diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 6c2de0a057..76c0047272 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -11,6 +11,8 @@ from dvc.remote.oss import RemoteOSS from dvc.remote.s3 import RemoteS3 from dvc.remote.ssh import RemoteSSH +from dvc.remote.webdav import RemoteWEBDAV +from dvc.remote.webdavs import RemoteWEBDAVS REMOTES = [ @@ -23,6 +25,8 @@ RemoteS3, RemoteSSH, RemoteOSS, + RemoteWEBDAV, + RemoteWEBDAVS, # NOTE: RemoteLOCAL is the default ] diff --git a/dvc/remote/webdav.py b/dvc/remote/webdav.py new file mode 100644 index 0000000000..15cc3a1ecc --- /dev/null +++ b/dvc/remote/webdav.py @@ -0,0 +1,65 @@ +import os.path + +from .http import RemoteHTTP +from dvc.scheme import Schemes +from dvc.progress import Tqdm +from dvc.exceptions import HTTPError +from dvc.path_info import WebdavURLInfo + + +class RemoteWEBDAV(RemoteHTTP): + scheme = Schemes.WEBDAV + path_cls = WebdavURLInfo + REQUEST_TIMEOUT = 20 + + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + def chunks(): + with open(from_file, "rb") as fd: + with Tqdm.wrapattr( + fd, + "read", + total=None + if no_progress_bar + else os.path.getsize(from_file), + leave=False, + desc=to_info.url if name is None else name, + disable=no_progress_bar, + ) as fd_wrapped: + while True: + chunk = fd_wrapped.read(self.CHUNK_SIZE) + if not chunk: + break + yield chunk + + self._create_collections(to_info) + response = self._request("PUT", to_info.url, data=chunks()) + if response.status_code not in (200, 201): + raise HTTPError(response.status_code, response.reason) + + def _create_collections(self, to_info): + url_cols = [x.url + "/" for x in to_info.parents][:-1] + from_idx = 0 + for idx, url in enumerate(url_cols): + from_idx = idx + if bool(self._request("HEAD", url)): + break + for url in reversed(url_cols[:from_idx]): + response = self._request("MKCOL", url) + if response.status_code not in (200, 201): + if bool(self._request("HEAD", url)): + continue + raise HTTPError(response.status_code, response.reason) + + def remove(self, path_info): + response = self._request("DELETE", path_info.url) + if response.status_code not in (200, 201, 204): + raise HTTPError(response.status_code, response.reason) + + def gc(self): + return super(RemoteHTTP, self).gc() + + def list_cache_paths(self, prefix=None, progress_callback=None): + raise NotImplementedError + + def walk_files(self, path_info): + raise NotImplementedError diff --git a/dvc/remote/webdavs.py b/dvc/remote/webdavs.py new file mode 100644 index 0000000000..1302123340 --- /dev/null +++ b/dvc/remote/webdavs.py @@ -0,0 +1,15 @@ +from .webdav import RemoteWEBDAV +from dvc.scheme import Schemes + + +class RemoteWEBDAVS(RemoteWEBDAV): + scheme = Schemes.WEBDAVS + + def gc(self): + raise NotImplementedError + + def list_cache_paths(self, prefix=None, progress_callback=None): + raise NotImplementedError + + def walk_files(self, path_info): + raise NotImplementedError diff --git a/dvc/scheme.py b/dvc/scheme.py index e64e24f5ac..76c6d7a497 100644 --- a/dvc/scheme.py +++ b/dvc/scheme.py @@ -9,3 +9,5 @@ class Schemes: GDRIVE = "gdrive" LOCAL = "local" OSS = "oss" + WEBDAV = "webdav" + WEBDAVS = "webdavs" diff --git a/tests/unit/remote/test_webdav.py b/tests/unit/remote/test_webdav.py new file mode 100644 index 0000000000..539ae7b406 --- /dev/null +++ b/tests/unit/remote/test_webdav.py @@ -0,0 +1,20 @@ +import pytest + +from dvc.exceptions import HTTPError +from dvc.path_info import WebdavURLInfo +from dvc.remote.webdav import RemoteWEBDAV +from tests.utils.httpd import StaticFileServer, WebDavSimpleHandler + + +def test_create_collections(dvc): + with StaticFileServer(handler_class=WebDavSimpleHandler) as httpd: + url0 = "webdav://localhost:{}/a/b/file.txt".format(httpd.server_port) + url1 = "webdav://localhost:{}/a/c/file.txt".format(httpd.server_port) + config = {"url": url0} + + remote = RemoteWEBDAV(dvc, config) + + remote._create_collections(WebdavURLInfo(url0)) + + with pytest.raises(HTTPError): + remote._create_collections(WebdavURLInfo(url1)) diff --git a/tests/unit/test_path_info.py b/tests/unit/test_path_info.py index 0b202fa124..075bb823d3 100644 --- a/tests/unit/test_path_info.py +++ b/tests/unit/test_path_info.py @@ -7,6 +7,7 @@ from dvc.path_info import HTTPURLInfo from dvc.path_info import PathInfo from dvc.path_info import URLInfo +from dvc.path_info import WebdavURLInfo TEST_DEPTH = len(pathlib.Path(__file__).parents) + 1 @@ -89,3 +90,10 @@ def test_https_url_info_str(): def test_path_info_as_posix(mocker, path, as_posix, osname): mocker.patch("os.name", osname) assert PathInfo(path).as_posix() == as_posix + + +def test_webdav_url_info_str(): + u1 = WebdavURLInfo("webdav://test.com/t1") + u2 = WebdavURLInfo("webdavs://test.com/t1") + assert u1.url == "http://test.com/t1" + assert u2.url == "https://test.com/t1" diff --git a/tests/utils/httpd.py b/tests/utils/httpd.py index 378bb75b3f..f0a48b6f13 100644 --- a/tests/utils/httpd.py +++ b/tests/utils/httpd.py @@ -65,6 +65,28 @@ def do_POST(self): self.end_headers() +class WebDavSimpleHandler(SimpleHTTPRequestHandler): + def do_HEAD(self): + if self.path == "/a/": + self.send_response(HTTPStatus.OK) + elif self.path == "/a/b/": + self.send_response(HTTPStatus.OK) + elif self.path == "/a/c/": + self.send_response(HTTPStatus.BAD_REQUEST) + else: + self.send_response(HTTPStatus.BAD_REQUEST) + self.end_headers() + + def do_MKCOL(self): + if self.path == "/a/b/": + self.send_response(HTTPStatus.CREATED) + elif self.path == "/a/c/": + self.send_response(HTTPStatus.BAD_REQUEST) + else: + self.send_response(HTTPStatus.BAD_REQUEST) + self.end_headers() + + class StaticFileServer: _lock = threading.Lock()