From 374614f54a5bcba8bf9e2bcf048cb80720e828d1 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 14:05:32 +0900 Subject: [PATCH 01/15] remote: move upload()/download() into tree --- dvc/output/base.py | 2 +- dvc/remote/azure.py | 44 +++---- dvc/remote/base.py | 240 ++++++++++++++++++------------------- dvc/remote/gdrive.py | 28 +++-- dvc/remote/gs.py | 46 +++---- dvc/remote/hdfs.py | 26 ++-- dvc/remote/http.py | 64 +++++----- dvc/remote/local.py | 44 +++---- dvc/remote/oss.py | 32 ++--- dvc/remote/s3.py | 54 ++++----- dvc/remote/ssh/__init__.py | 40 +++---- 11 files changed, 311 insertions(+), 309 deletions(-) diff --git a/dvc/output/base.py b/dvc/output/base.py index fe3c77c664..d25b79a5df 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -300,7 +300,7 @@ def verify_metric(self): raise DvcException(f"verify metric is not supported for {self.scheme}") def download(self, to): - self.remote.download(self.path_info, to.path_info) + self.remote.tree.download(self.path_info, to.path_info) def checkout( self, diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 0b662b4a08..383c1947d1 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -46,6 +46,28 @@ def remove(self, path_info): logger.debug(f"Removing {path_info}") self.blob_service.delete_blob(path_info.bucket, path_info.path) + def _upload( + self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + ): + with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: + self.blob_service.create_blob_from_path( + to_info.bucket, + to_info.path, + from_file, + progress_callback=pbar.update_to, + ) + + def _download( + self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs + ): + with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: + self.blob_service.get_blob_to_path( + from_info.bucket, + from_info.path, + to_file, + progress_callback=pbar.update_to, + ) + class AzureRemote(BaseRemote): scheme = Schemes.AZURE @@ -127,25 +149,3 @@ def list_cache_paths(self, prefix=None, progress_callback=None): return self.list_paths( self.path_info.bucket, prefix, progress_callback ) - - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: - self.blob_service.create_blob_from_path( - to_info.bucket, - to_info.path, - from_file, - progress_callback=pbar.update_to, - ) - - def _download( - self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs - ): - with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: - self.blob_service.get_blob_to_path( - from_info.bucket, - from_info.path, - to_file, - progress_callback=pbar.update_to, - ) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 60d502dbec..7f3f1bb07a 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -168,6 +168,125 @@ def hardlink(self, from_info, to_info): def reflink(self, from_info, to_info): raise RemoteActionNotImplemented("reflink", self.scheme) + def _handle_transfer_exception( + self, from_info, to_info, exception, operation + ): + if isinstance(exception, OSError) and exception.errno == errno.EMFILE: + raise exception + + logger.exception( + "failed to %s '%s' to '%s'", operation, from_info, to_info + ) + return 1 + + def upload(self, from_info, to_info, name=None, no_progress_bar=False): + if not hasattr(self, "_upload"): + raise RemoteActionNotImplemented("upload", self.scheme) + + if to_info.scheme != self.scheme: + raise NotImplementedError + + if 
from_info.scheme != "local": + raise NotImplementedError + + logger.debug("Uploading '%s' to '%s'", from_info, to_info) + + name = name or from_info.name + + try: + self._upload( + from_info.fspath, + to_info, + name=name, + no_progress_bar=no_progress_bar, + ) + except Exception as e: + return self._handle_transfer_exception( + from_info, to_info, e, "upload" + ) + + return 0 + + def download( + self, + from_info, + to_info, + name=None, + no_progress_bar=False, + file_mode=None, + dir_mode=None, + ): + if not hasattr(self, "_download"): + raise RemoteActionNotImplemented("download", self.scheme) + + if from_info.scheme != self.scheme: + raise NotImplementedError + + if to_info.scheme == self.scheme != "local": + self.copy(from_info, to_info) + return 0 + + if to_info.scheme != "local": + raise NotImplementedError + + if self.isdir(from_info): + return self._download_dir( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + return self._download_file( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + + def _download_dir( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): + from_infos = list(self.walk_files(from_info)) + to_infos = ( + to_info / info.relative_to(from_info) for info in from_infos + ) + + with Tqdm( + total=len(from_infos), + desc="Downloading directory", + unit="Files", + disable=no_progress_bar, + ) as pbar: + download_files = pbar.wrap_fn( + partial( + self._download_file, + name=name, + no_progress_bar=True, + file_mode=file_mode, + dir_mode=dir_mode, + ) + ) + with ThreadPoolExecutor(max_workers=self.JOBS) as executor: + futures = executor.map(download_files, from_infos, to_infos) + return sum(futures) + + def _download_file( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): + makedirs(to_info.parent, exist_ok=True, mode=dir_mode) + + logger.debug("Downloading '%s' to '%s'", from_info, to_info) + name = name or to_info.name + + tmp_file = tmp_fname(to_info) + + try: + self._download( + from_info, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + except Exception as e: + return self._handle_transfer_exception( + from_info, to_info, e, "download" + ) + + move(tmp_file, to_info, mode=file_mode) + + return 0 + class BaseRemote: scheme = "base" @@ -374,7 +493,7 @@ def _get_dir_info_checksum(self, dir_info): from_info = PathInfo(tmp) to_info = self.cache.path_info / tmp_fname("") - self.cache.upload(from_info, to_info, no_progress_bar=True) + self.cache.tree.upload(from_info, to_info, no_progress_bar=True) checksum = self.get_file_checksum(to_info) + self.CHECKSUM_DIR_SUFFIX return checksum, to_info @@ -694,125 +813,6 @@ def _save(self, path_info, checksum, save_link=True, tree=None, **kwargs): def open(self, *args, **kwargs): return self.tree.open(*args, **kwargs) - def _handle_transfer_exception( - self, from_info, to_info, exception, operation - ): - if isinstance(exception, OSError) and exception.errno == errno.EMFILE: - raise exception - - logger.exception( - "failed to %s '%s' to '%s'", operation, from_info, to_info - ) - return 1 - - def upload(self, from_info, to_info, name=None, no_progress_bar=False): - if not hasattr(self, "_upload"): - raise RemoteActionNotImplemented("upload", self.scheme) - - if to_info.scheme != self.scheme: - raise NotImplementedError - - if from_info.scheme != "local": - raise NotImplementedError - - logger.debug("Uploading '%s' to '%s'", from_info, to_info) - - name = name or from_info.name - - try: - self._upload( - from_info.fspath, - 
to_info,
-                name=name,
-                no_progress_bar=no_progress_bar,
-            )
-        except Exception as e:
-            return self._handle_transfer_exception(
-                from_info, to_info, e, "upload"
-            )
-
-        return 0
-
-    def download(
-        self,
-        from_info,
-        to_info,
-        name=None,
-        no_progress_bar=False,
-        file_mode=None,
-        dir_mode=None,
-    ):
-        if not hasattr(self, "_download"):
-            raise RemoteActionNotImplemented("download", self.scheme)
-
-        if from_info.scheme != self.scheme:
-            raise NotImplementedError
-
-        if to_info.scheme == self.scheme != "local":
-            self.tree.copy(from_info, to_info)
-            return 0
-
-        if to_info.scheme != "local":
-            raise NotImplementedError
-
-        if self.tree.isdir(from_info):
-            return self._download_dir(
-                from_info, to_info, name, no_progress_bar, file_mode, dir_mode
-            )
-        return self._download_file(
-            from_info, to_info, name, no_progress_bar, file_mode, dir_mode
-        )
-
-    def _download_dir(
-        self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode
-    ):
-        from_infos = list(self.tree.walk_files(from_info))
-        to_infos = (
-            to_info / info.relative_to(from_info) for info in from_infos
-        )
-
-        with Tqdm(
-            total=len(from_infos),
-            desc="Downloading directory",
-            unit="Files",
-            disable=no_progress_bar,
-        ) as pbar:
-            download_files = pbar.wrap_fn(
-                partial(
-                    self._download_file,
-                    name=name,
-                    no_progress_bar=True,
-                    file_mode=file_mode,
-                    dir_mode=dir_mode,
-                )
-            )
-            with ThreadPoolExecutor(max_workers=self.JOBS) as executor:
-                futures = executor.map(download_files, from_infos, to_infos)
-                return sum(futures)
-
-    def _download_file(
-        self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode
-    ):
-        makedirs(to_info.parent, exist_ok=True, mode=dir_mode)
-
-        logger.debug("Downloading '%s' to '%s'", from_info, to_info)
-        name = name or to_info.name
-
-        tmp_file = tmp_fname(to_info)
-
-        try:
-            self._download(
-                from_info, tmp_file, name=name, no_progress_bar=no_progress_bar
-            )
-        except Exception as e:
-            return self._handle_transfer_exception(
-                from_info, to_info, e, "download"
-            )
-
-        move(tmp_file, to_info, mode=file_mode)
-
-        return 0
-
     def path_to_checksum(self, path):
         parts = self.path_cls(path).parts[-2:]
 
diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py
index 3cef0760ee..970d823d6d 100644
--- a/dvc/remote/gdrive.py
+++ b/dvc/remote/gdrive.py
@@ -100,6 +100,21 @@ def remove(self, path_info):
         item_id = self.remote.get_item_id(path_info)
         self.remote.gdrive_delete_file(item_id)
 
+    def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
+        dirname = to_info.parent
+        assert dirname
+        parent_id = self.remote.get_item_id(dirname, True)
+
+        self.remote._gdrive_upload_file(
+            parent_id, to_info.name, no_progress_bar, from_file, name
+        )
+
+    def _download(self, from_info, to_file, name=None, no_progress_bar=False):
+        item_id = self.remote.get_item_id(from_info)
+        self.remote._gdrive_download_file(
+            item_id, to_file, name, no_progress_bar
+        )
+
 
 class GDriveRemote(BaseRemote):
     scheme = Schemes.GDRIVE
@@ -523,19 +538,6 @@ def get_item_id(self, path_info, create=False, use_cache=True, hint=None):
         assert not create
         raise GDrivePathNotFound(path_info, hint)
 
-    def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
-        dirname = to_info.parent
-        assert dirname
-        parent_id = self.get_item_id(dirname, True)
-
-        self._gdrive_upload_file(
-            parent_id, to_info.name, no_progress_bar, from_file, name
-        )
-
-    def _download(self, from_info, to_file, name=None, no_progress_bar=False):
-        item_id = self.get_item_id(from_info)
-        self._gdrive_download_file(item_id, to_file, name, 
no_progress_bar) - def list_cache_paths(self, prefix=None, progress_callback=None): if not self._ids_cache["ids"]: return diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 5b42c8f24f..8395c6e4b4 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -134,6 +134,29 @@ def copy(self, from_info, to_info): to_bucket = self.gs.bucket(to_info.bucket) from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + bucket = self.gs.bucket(to_info.bucket) + _upload_to_bucket( + bucket, + from_file, + to_info, + name=name, + no_progress_bar=no_progress_bar, + ) + + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + bucket = self.gs.bucket(from_info.bucket) + blob = bucket.get_blob(from_info.path) + with open(to_file, mode="wb") as fobj: + with Tqdm.wrapattr( + fobj, + "write", + desc=name or from_info.path, + total=blob.size, + disable=no_progress_bar, + ) as wrapped: + blob.download_to_file(wrapped) + class GSRemote(BaseRemote): scheme = Schemes.GS @@ -194,26 +217,3 @@ def list_cache_paths(self, prefix=None, progress_callback=None): return self.list_paths( self.path_info, prefix=prefix, progress_callback=progress_callback ) - - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - bucket = self.gs.bucket(to_info.bucket) - _upload_to_bucket( - bucket, - from_file, - to_info, - name=name, - no_progress_bar=no_progress_bar, - ) - - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - bucket = self.gs.bucket(from_info.bucket) - blob = bucket.get_blob(from_info.path) - with open(to_file, mode="wb") as fobj: - with Tqdm.wrapattr( - fobj, - "write", - desc=name or from_info.path, - total=blob.size, - disable=no_progress_bar, - ) as wrapped: - blob.download_to_file(wrapped) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 4e9430ab7c..fa97c431a1 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -72,6 +72,19 @@ def copy(self, from_info, to_info, **_kwargs): self.remove(tmp_info) raise + def _upload(self, from_file, to_info, **_kwargs): + with self.hdfs(to_info) as hdfs: + hdfs.mkdir(posixpath.dirname(to_info.path)) + tmp_file = tmp_fname(to_info.path) + with open(from_file, "rb") as fobj: + hdfs.upload(tmp_file, fobj) + hdfs.rename(tmp_file, to_info.path) + + def _download(self, from_info, to_file, **_kwargs): + with self.hdfs(from_info) as hdfs: + with open(to_file, "wb+") as fobj: + hdfs.download(from_info.path, fobj) + class HDFSRemote(BaseRemote): scheme = Schemes.HDFS @@ -148,19 +161,6 @@ def get_file_checksum(self, path_info): ) return self._group(regex, stdout, "checksum") - def _upload(self, from_file, to_info, **_kwargs): - with self.hdfs(to_info) as hdfs: - hdfs.mkdir(posixpath.dirname(to_info.path)) - tmp_file = tmp_fname(to_info.path) - with open(from_file, "rb") as fobj: - hdfs.upload(tmp_file, fobj) - hdfs.rename(tmp_file, to_info.path) - - def _download(self, from_info, to_file, **_kwargs): - with self.hdfs(from_info) as hdfs: - with open(to_file, "wb+") as fobj: - hdfs.download(from_info.path, fobj) - def list_cache_paths(self, prefix=None, progress_callback=None): if not self.tree.exists(self.path_info): return diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 62252ea720..32ede406cc 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -27,38 +27,8 @@ class HTTPRemoteTree(BaseRemoteTree): def exists(self, path_info): return bool(self.remote.request("HEAD", path_info.url)) - -class HTTPRemote(BaseRemote): - 
scheme = Schemes.HTTP - path_cls = HTTPURLInfo - SESSION_RETRIES = 5 - SESSION_BACKOFF_FACTOR = 0.1 - REQUEST_TIMEOUT = 10 - CHUNK_SIZE = 2 ** 16 - PARAM_CHECKSUM = "etag" - CAN_TRAVERSE = False - TREE_CLS = HTTPRemoteTree - - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url") - if url: - self.path_info = self.path_cls(url) - user = config.get("user", None) - if user: - self.path_info.user = user - else: - self.path_info = None - - self.auth = config.get("auth", None) - self.custom_auth_header = config.get("custom_auth_header", None) - self.password = config.get("password", None) - self.ask_password = config.get("ask_password", False) - self.headers = {} - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - response = self.request("GET", from_info.url, stream=True) + response = self.remote.request("GET", from_info.url, stream=True) if response.status_code != 200: raise HTTPError(response.status_code, response.reason) with open(to_file, "wb") as fd: @@ -94,7 +64,7 @@ def chunks(): break yield chunk - response = self.request("POST", to_info.url, data=chunks()) + response = self.remote.request("POST", to_info.url, data=chunks()) if response.status_code not in (200, 201): raise HTTPError(response.status_code, response.reason) @@ -102,6 +72,36 @@ def _content_length(self, response): res = response.headers.get("Content-Length") return int(res) if res else None + +class HTTPRemote(BaseRemote): + scheme = Schemes.HTTP + path_cls = HTTPURLInfo + SESSION_RETRIES = 5 + SESSION_BACKOFF_FACTOR = 0.1 + REQUEST_TIMEOUT = 10 + CHUNK_SIZE = 2 ** 16 + PARAM_CHECKSUM = "etag" + CAN_TRAVERSE = False + TREE_CLS = HTTPRemoteTree + + def __init__(self, repo, config): + super().__init__(repo, config) + + url = config.get("url") + if url: + self.path_info = self.path_cls(url) + user = config.get("user", None) + if user: + self.path_info.user = user + else: + self.path_info = None + + self.auth = config.get("auth", None) + self.custom_auth_header = config.get("custom_auth_header", None) + self.password = config.get("password", None) + self.ask_password = config.get("ask_password", False) + self.headers = {} + def get_file_checksum(self, path_info): url = path_info.url headers = self.request("HEAD", url).headers diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 6ae7451326..860cf81378 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -208,6 +208,26 @@ def reflink(self, from_info, to_info): def getsize(path_info): return os.path.getsize(path_info) + def _upload( + self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + ): + makedirs(to_info.parent, exist_ok=True) + + tmp_file = tmp_fname(to_info) + copyfile( + from_file, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + + self.remote.protect(tmp_file) + os.rename(tmp_file, to_info) + + def _download( + self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs + ): + copyfile( + from_info, to_file, no_progress_bar=no_progress_bar, name=name + ) + class LocalRemote(BaseRemote): scheme = Schemes.LOCAL @@ -309,26 +329,6 @@ def cache_exists(self, checksums, jobs=None, name=None): if not self.changed_cache_file(checksum) ] - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - makedirs(to_info.parent, exist_ok=True) - - tmp_file = tmp_fname(to_info) - copyfile( - from_file, tmp_file, name=name, no_progress_bar=no_progress_bar - ) - - self.protect(tmp_file) - os.rename(tmp_file, to_info) - - def _download( - 
self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs - ): - copyfile( - from_info, to_file, no_progress_bar=no_progress_bar, name=name - ) - @index_locked def status( self, @@ -511,14 +511,14 @@ def _process( if download: func = partial( - remote.download, + remote.tree.download, dir_mode=self.tree.dir_mode, file_mode=self.tree.file_mode, ) status = STATUS_DELETED desc = "Downloading" else: - func = remote.upload + func = remote.tree.upload status = STATUS_NEW desc = "Uploading" diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index bf29d79064..d690fd58db 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -34,6 +34,22 @@ def remove(self, path_info): logger.debug(f"Removing oss://{path_info}") self.oss_service.delete_object(path_info.path) + def _upload( + self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + ): + with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: + self.oss_service.put_object_from_file( + to_info.path, from_file, progress_callback=pbar.update_to + ) + + def _download( + self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs + ): + with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: + self.oss_service.get_object_to_file( + from_info.path, to_file, progress_callback=pbar.update_to + ) + class OSSRemote(BaseRemote): """ @@ -122,19 +138,3 @@ def list_cache_paths(self, prefix=None, progress_callback=None): else: prefix = self.path_info.path return self.list_paths(prefix, progress_callback) - - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: - self.oss_service.put_object_from_file( - to_info.path, from_file, progress_callback=pbar.update_to - ) - - def _download( - self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs - ): - with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: - self.oss_service.get_object_to_file( - from_info.path, to_file, progress_callback=pbar.update_to - ) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index e98b12fea0..9ede8b2e17 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -193,6 +193,33 @@ def _copy(cls, s3, from_info, to_info, extra_args): if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + total = os.path.getsize(from_file) + with Tqdm( + disable=no_progress_bar, total=total, bytes=True, desc=name + ) as pbar: + self.s3.upload_file( + from_file, + to_info.bucket, + to_info.path, + Callback=pbar.update, + ExtraArgs=self.extra_args, + ) + + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + if no_progress_bar: + total = None + else: + total = self.s3.head_object( + Bucket=from_info.bucket, Key=from_info.path + )["ContentLength"] + with Tqdm( + disable=no_progress_bar, total=total, bytes=True, desc=name + ) as pbar: + self.s3.download_file( + from_info.bucket, from_info.path, to_file, Callback=pbar.update + ) + class S3Remote(BaseRemote): scheme = Schemes.S3 @@ -305,33 +332,6 @@ def list_cache_paths(self, prefix=None, progress_callback=None): self.path_info, prefix=prefix, progress_callback=progress_callback ) - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - total = os.path.getsize(from_file) - with Tqdm( - disable=no_progress_bar, total=total, bytes=True, desc=name - ) as pbar: - self.s3.upload_file( - from_file, - to_info.bucket, - to_info.path, - Callback=pbar.update, - 
ExtraArgs=self.extra_args, - ) - - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - if no_progress_bar: - total = None - else: - total = self.s3.head_object( - Bucket=from_info.bucket, Key=from_info.path - )["ContentLength"] - with Tqdm( - disable=no_progress_bar, total=total, bytes=True, desc=name - ) as pbar: - self.s3.download_file( - from_info.bucket, from_info.path, to_file, Callback=pbar.update - ) - def _append_aws_grants_to_extra_args(self, config): # Keys for extra_args can be one of the following list: # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 416b2d4eed..114567ef5c 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -131,6 +131,26 @@ def getsize(self, path_info): with self.ssh(path_info) as ssh: return ssh.getsize(path_info.path) + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + assert from_info.isin(self.path_info) + with self.ssh(self.path_info) as ssh: + ssh.download( + from_info.path, + to_file, + progress_title=name, + no_progress_bar=no_progress_bar, + ) + + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + assert to_info.isin(self.path_info) + with self.ssh(self.path_info) as ssh: + ssh.upload( + from_file, + to_info.path, + progress_title=name, + no_progress_bar=no_progress_bar, + ) + class SSHRemote(BaseRemote): scheme = Schemes.SSH @@ -255,26 +275,6 @@ def get_file_checksum(self, path_info): with self.ssh(path_info) as ssh: return ssh.md5(path_info.path) - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - assert from_info.isin(self.path_info) - with self.ssh(self.path_info) as ssh: - ssh.download( - from_info.path, - to_file, - progress_title=name, - no_progress_bar=no_progress_bar, - ) - - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - assert to_info.isin(self.path_info) - with self.ssh(self.path_info) as ssh: - ssh.upload( - from_file, - to_info.path, - progress_title=name, - no_progress_bar=no_progress_bar, - ) - def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: root = posixpath.join(self.path_info.path, prefix[:2]) From afa9fa05064c8dc133a56828df84f7f4fb2b3b0c Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 14:46:16 +0900 Subject: [PATCH 02/15] remote: move path_info/path_cls into tree --- dvc/output/base.py | 2 +- dvc/output/local.py | 4 ++-- dvc/remote/base.py | 20 ++++++++++---------- dvc/remote/local.py | 11 +++++++---- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/dvc/output/base.py b/dvc/output/base.py index d25b79a5df..9c347fee77 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -119,7 +119,7 @@ def _parse_path(self, remote, path): if remote: parsed = urlparse(path) return remote.path_info / parsed.path.lstrip("/") - return self.REMOTE.path_cls(path) + return self.REMOTE.TREE_CLS.PATH_CLS(path) def __repr__(self): return "{class_name}: '{def_path}'".format( diff --git a/dvc/output/local.py b/dvc/output/local.py index f84d3478b4..3f992d7211 100644 --- a/dvc/output/local.py +++ b/dvc/output/local.py @@ -33,12 +33,12 @@ def _parse_path(self, remote, path): # # FIXME: if we have Windows path containing / or posix one with \ # then we have #2059 bug and can't really handle that. 
- p = self.REMOTE.path_cls(path) + p = self.REMOTE.TREE_CLS.PATH_CLS(path) if not p.is_absolute(): p = self.stage.wdir / p abs_p = os.path.abspath(os.path.normpath(p)) - return self.REMOTE.path_cls(abs_p) + return self.REMOTE.TREE_CLS.PATH_CLS(abs_p) def __str__(self): if not self.is_in_repo: diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 7f3f1bb07a..4e9d480834 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -85,6 +85,7 @@ def wrapper(remote_obj, *args, **kwargs): class BaseRemoteTree: SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} + PATH_CLS = URLInfo def __init__(self, remote, config): self.remote = remote @@ -103,10 +104,6 @@ def dir_mode(self): def scheme(self): return self.remote.scheme - @property - def path_cls(self): - return self.remote.path_cls - def open(self, path_info, mode="r", encoding=None): if hasattr(self, "_generate_download_url"): get_url = partial(self._generate_download_url, path_info) @@ -133,7 +130,7 @@ def iscopy(self, path_info): """Check if this file is an independent copy.""" return False # We can't be sure by default - def walk_files(self, path_info): + def walk_files(self, path_info, **kwargs): """Return a generator with `PathInfo`s to all the files""" raise NotImplementedError @@ -260,7 +257,7 @@ def _download_dir( dir_mode=dir_mode, ) ) - with ThreadPoolExecutor(max_workers=self.JOBS) as executor: + with ThreadPoolExecutor(max_workers=self.remote.JOBS) as executor: futures = executor.map(download_files, from_infos, to_infos) return sum(futures) @@ -290,7 +287,6 @@ def _download_file( class BaseRemote: scheme = "base" - path_cls = URLInfo REQUIRES = {} JOBS = 4 * cpu_count() INDEX_CLS = RemoteIndex @@ -338,6 +334,10 @@ def __init__(self, repo, config): self.tree = self.TREE_CLS(self, config) + @property + def path_info(self): + return self.tree.path_info + @classmethod def get_missing_deps(cls): import importlib @@ -529,12 +529,12 @@ def load_dir_cache(self, checksum): ) return [] - if self.path_cls == WindowsPathInfo: + if self.tree.PATH_CLS == WindowsPathInfo: # only need to convert it for Windows for info in d: # NOTE: here is a BUG, see comment to .as_posix() below info[self.PARAM_RELPATH] = info[self.PARAM_RELPATH].replace( - "/", self.path_cls.sep + "/", self.tree.PATH_CLS.sep ) return d @@ -814,7 +814,7 @@ def open(self, *args, **kwargs): return self.tree.open(*args, **kwargs) def path_to_checksum(self, path): - parts = self.path_cls(path).parts[-2:] + parts = self.tree.PATH_CLS(path).parts[-2:] if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): raise ValueError(f"Bad cache file path '{path}'") diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 860cf81378..c639d26625 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -39,6 +39,11 @@ class LocalRemoteTree(BaseRemoteTree): SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} + PATH_CLS = PathInfo + + def __init__(self, remote, config): + super().__init__(remote, config) + self.path_info = config.get("url") @property def repo(self): @@ -231,7 +236,6 @@ def _download( class LocalRemote(BaseRemote): scheme = Schemes.LOCAL - path_cls = PathInfo PARAM_CHECKSUM = "md5" PARAM_PATH = "path" TRAVERSE_PREFIX_LEN = 2 @@ -247,7 +251,6 @@ class LocalRemote(BaseRemote): def __init__(self, repo, config): super().__init__(repo, config) self.cache_dir = config.get("url") - self._dir_info = {} @property def state(self): @@ -255,11 +258,11 @@ def state(self): @property def cache_dir(self): - return self.path_info.fspath if self.path_info 
else None + return self.tree.path_info.fspath if self.tree.path_info else None @cache_dir.setter def cache_dir(self, value): - self.path_info = PathInfo(value) if value else None + self.tree.path_info = PathInfo(value) if value else None @classmethod def supported(cls, config): From d278d27768351746f46090417baee9f7436ee9da Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 15:05:50 +0900 Subject: [PATCH 03/15] remote.azure: finish moving methods into tree --- dvc/remote/azure.py | 143 ++++++++++++++++---------------- tests/unit/remote/test_azure.py | 12 +-- 2 files changed, 79 insertions(+), 76 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 383c1947d1..240be41fac 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -1,6 +1,5 @@ import logging import os -import posixpath import threading from datetime import datetime, timedelta @@ -15,9 +14,47 @@ class AzureRemoteTree(BaseRemoteTree): - @property + PATH_CLS = CloudURLInfo + + def __init__(self, remote, config): + super().__init__(remote, config) + + url = config.get("url", "azure://") + self.path_info = self.PATH_CLS(url) + + if not self.path_info.bucket: + container = os.getenv("AZURE_STORAGE_CONTAINER_NAME") + self.path_info = self.PATH_CLS(f"azure://{container}") + + self.connection_string = config.get("connection_string") or os.getenv( + "AZURE_STORAGE_CONNECTION_STRING" + ) + + @wrap_prop(threading.Lock()) + @cached_property def blob_service(self): - return self.remote.blob_service + from azure.storage.blob import BlockBlobService + from azure.common import AzureMissingResourceHttpError + + logger.debug(f"URL {self.path_info}") + logger.debug(f"Connection string {self.connection_string}") + blob_service = BlockBlobService( + connection_string=self.connection_string + ) + logger.debug(f"Container name {self.path_info.bucket}") + try: # verify that container exists + blob_service.list_blobs( + self.path_info.bucket, delimiter="/", num_results=1 + ) + except AzureMissingResourceHttpError: + blob_service.create_container(self.path_info.bucket) + return blob_service + + def get_etag(self, path_info): + etag = self.blob_service.get_blob_properties( + path_info.bucket, path_info.path + ).properties.etag + return etag.strip('"') def _generate_download_url(self, path_info, expires=3600): from azure.storage.blob import BlobPermissions @@ -36,9 +73,36 @@ def _generate_download_url(self, path_info, expires=3600): return download_url def exists(self, path_info): - paths = self.remote.list_paths(path_info.bucket, path_info.path) + paths = self._list_paths(path_info.bucket, path_info.path) return any(path_info.path == path for path in paths) + def _list_paths(self, bucket, prefix, progress_callback=None): + blob_service = self.blob_service + next_marker = None + while True: + blobs = blob_service.list_blobs( + bucket, prefix=prefix, marker=next_marker + ) + + for blob in blobs: + if progress_callback: + progress_callback() + yield blob.name + + if not blobs.next_marker: + break + + next_marker = blobs.next_marker + + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths( + path_info.bucket, path_info.path, **kwargs + ): + if fname.endswith("/"): + continue + + yield path_info.replace(path=fname) + def remove(self, path_info): if path_info.scheme != self.scheme: raise NotImplementedError @@ -71,81 +135,20 @@ def _download( class AzureRemote(BaseRemote): scheme = Schemes.AZURE - path_cls = CloudURLInfo REQUIRES = {"azure-storage-blob": "azure.storage.blob"} PARAM_CHECKSUM = 
"etag" COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 5000 TREE_CLS = AzureRemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url", "azure://") - self.path_info = self.path_cls(url) - - if not self.path_info.bucket: - container = os.getenv("AZURE_STORAGE_CONTAINER_NAME") - self.path_info = self.path_cls(f"azure://{container}") - - self.connection_string = config.get("connection_string") or os.getenv( - "AZURE_STORAGE_CONNECTION_STRING" - ) - - @wrap_prop(threading.Lock()) - @cached_property - def blob_service(self): - from azure.storage.blob import BlockBlobService - from azure.common import AzureMissingResourceHttpError - - logger.debug(f"URL {self.path_info}") - logger.debug(f"Connection string {self.connection_string}") - blob_service = BlockBlobService( - connection_string=self.connection_string - ) - logger.debug(f"Container name {self.path_info.bucket}") - try: # verify that container exists - blob_service.list_blobs( - self.path_info.bucket, delimiter="/", num_results=1 - ) - except AzureMissingResourceHttpError: - blob_service.create_container(self.path_info.bucket) - return blob_service - - def get_etag(self, path_info): - etag = self.blob_service.get_blob_properties( - path_info.bucket, path_info.path - ).properties.etag - return etag.strip('"') - def get_file_checksum(self, path_info): - return self.get_etag(path_info) - - def list_paths(self, bucket, prefix, progress_callback=None): - blob_service = self.blob_service - next_marker = None - while True: - blobs = blob_service.list_blobs( - bucket, prefix=prefix, marker=next_marker - ) - - for blob in blobs: - if progress_callback: - progress_callback() - yield blob.name - - if not blobs.next_marker: - break - - next_marker = blobs.next_marker + return self.tree.get_etag(path_info) def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: - prefix = posixpath.join( - self.path_info.path, prefix[:2], prefix[2:] - ) + path_info = self.path_info / prefix[:2] / prefix[2:] else: - prefix = self.path_info.path - return self.list_paths( - self.path_info.bucket, prefix, progress_callback + path_info = self.path_info + return self.tree.walk_files( + path_info, progress_callback=progress_callback ) diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py index b5462193da..b49f2ee631 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/remote/test_azure.py @@ -19,8 +19,8 @@ def test_init_env_var(monkeypatch, dvc): config = {"url": "azure://"} remote = AzureRemote(dvc, config) - assert remote.path_info == "azure://" + container_name - assert remote.connection_string == connection_string + assert remote.tree.path_info == "azure://" + container_name + assert remote.tree.connection_string == connection_string def test_init(dvc): @@ -28,8 +28,8 @@ def test_init(dvc): url = f"azure://{container_name}/{prefix}" config = {"url": url, "connection_string": connection_string} remote = AzureRemote(dvc, config) - assert remote.path_info == url - assert remote.connection_string == connection_string + assert remote.tree.path_info == url + assert remote.tree.connection_string == connection_string def test_get_file_checksum(tmp_dir): @@ -39,8 +39,8 @@ def test_get_file_checksum(tmp_dir): tmp_dir.gen("foo", "foo") remote = AzureRemote(None, {}) - to_info = remote.path_cls(Azure.get_url()) - remote.upload(PathInfo("foo"), to_info) + to_info = remote.tree.PATH_CLS(Azure.get_url()) + remote.tree.upload(PathInfo("foo"), to_info) assert remote.tree.exists(to_info) 
checksum = remote.get_file_checksum(to_info) assert checksum From c561d0fb729c9fe984847072eb37b1d4037ee5ea Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 15:55:12 +0900 Subject: [PATCH 04/15] remote.gs: finish moving methods into tree --- dvc/remote/gs.py | 78 +++++++++++++++++------------------- tests/unit/remote/test_gs.py | 12 +++--- 2 files changed, 42 insertions(+), 48 deletions(-) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 8395c6e4b4..31462b4a56 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -1,6 +1,5 @@ import logging import os.path -import posixpath import threading from datetime import timedelta from functools import wraps @@ -66,9 +65,27 @@ def _upload_to_bucket( class GSRemoteTree(BaseRemoteTree): - @property + PATH_CLS = CloudURLInfo + + def __init__(self, remote, config): + super().__init__(remote, config) + + url = config.get("url", "gs:///") + self.path_info = self.PATH_CLS(url) + + self.projectname = config.get("projectname", None) + self.credentialpath = config.get("credentialpath") + + @wrap_prop(threading.Lock()) + @cached_property def gs(self): - return self.remote.gs + from google.cloud.storage import Client + + return ( + Client.from_service_account_json(self.credentialpath) + if self.credentialpath + else Client(self.projectname) + ) def _generate_download_url(self, path_info, expires=3600): expiration = timedelta(seconds=int(expires)) @@ -89,7 +106,7 @@ def exists(self, path_info): def isdir(self, path_info): dir_path = path_info / "" - return bool(list(self.remote.list_paths(dir_path, max_items=1))) + return bool(list(self._list_paths(dir_path, max_items=1))) def isfile(self, path_info): if path_info.path.endswith("/"): @@ -98,8 +115,16 @@ def isfile(self, path_info): blob = self.gs.bucket(path_info.bucket).blob(path_info.path) return blob.exists() - def walk_files(self, path_info): - for fname in self.remote.list_paths(path_info / ""): + def _list_paths(self, path_info, max_items=None, progress_callback=None): + for blob in self.gs.bucket(path_info.bucket).list_blobs( + prefix=path_info.path, max_results=max_items + ): + if progress_callback: + progress_callback() + yield blob.name + + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths(path_info / "", **kwargs): # skip nested empty directories if fname.endswith("/"): continue @@ -160,31 +185,10 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): class GSRemote(BaseRemote): scheme = Schemes.GS - path_cls = CloudURLInfo REQUIRES = {"google-cloud-storage": "google.cloud.storage"} PARAM_CHECKSUM = "md5" TREE_CLS = GSRemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url", "gs:///") - self.path_info = self.path_cls(url) - - self.projectname = config.get("projectname", None) - self.credentialpath = config.get("credentialpath") - - @wrap_prop(threading.Lock()) - @cached_property - def gs(self): - from google.cloud.storage import Client - - return ( - Client.from_service_account_json(self.credentialpath) - if self.credentialpath - else Client(self.projectname) - ) - def get_file_checksum(self, path_info): import base64 import codecs @@ -199,21 +203,11 @@ def get_file_checksum(self, path_info): md5 = base64.b64decode(b64_md5) return codecs.getencoder("hex")(md5)[0].decode("utf-8") - def list_paths( - self, path_info, max_items=None, prefix=None, progress_callback=None - ): + def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: - prefix = 
posixpath.join(path_info.path, prefix[:2], prefix[2:]) + path_info = self.path_info / prefix[:2] / prefix[2:] else: - prefix = path_info.path - for blob in self.gs.bucket(path_info.bucket).list_blobs( - prefix=path_info.path, max_results=max_items - ): - if progress_callback: - progress_callback() - yield blob.name - - def list_cache_paths(self, prefix=None, progress_callback=None): - return self.list_paths( - self.path_info, prefix=prefix, progress_callback=progress_callback + path_info = self.path_info + return self.tree.walk_files( + path_info, progress_callback=progress_callback ) diff --git a/tests/unit/remote/test_gs.py b/tests/unit/remote/test_gs.py index 3f8643011c..cee6a004b3 100644 --- a/tests/unit/remote/test_gs.py +++ b/tests/unit/remote/test_gs.py @@ -18,16 +18,16 @@ def test_init(dvc): remote = GSRemote(dvc, CONFIG) - assert remote.path_info == URL - assert remote.projectname == PROJECT - assert remote.credentialpath == CREDENTIALPATH + assert remote.tree.path_info == URL + assert remote.tree.projectname == PROJECT + assert remote.tree.credentialpath == CREDENTIALPATH @mock.patch("google.cloud.storage.Client.from_service_account_json") def test_gs(mock_client, dvc): remote = GSRemote(dvc, CONFIG) - assert remote.credentialpath - remote.gs() + assert remote.tree.credentialpath + remote.tree.gs() mock_client.assert_called_once_with(CREDENTIALPATH) @@ -36,7 +36,7 @@ def test_gs_no_credspath(mock_client, dvc): config = CONFIG.copy() del config["credentialpath"] remote = GSRemote(dvc, config) - remote.gs() + remote.tree.gs() mock_client.assert_called_with(PROJECT) From 0666fa1f95ba64ed52372bfdc649a0b399857a85 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 16:03:52 +0900 Subject: [PATCH 05/15] remote.hdfs: finish moving methods to tree --- dvc/remote/hdfs.py | 114 +++++++++++++++++++++++---------------------- 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index fa97c431a1..c9f4a2f4ee 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -18,9 +18,34 @@ class HDFSRemoteTree(BaseRemoteTree): - @property - def hdfs(self): - return self.remote.hdfs + def __init__(self, remote, config): + super().__init__(remote, config) + + self.path_info = None + url = config.get("url") + if not url: + return + + parsed = urlparse(url) + user = parsed.username or config.get("user") + + self.path_info = self.PATH_CLS.from_parts( + scheme=self.scheme, + host=parsed.hostname, + user=user, + port=parsed.port, + path=parsed.path, + ) + + def hdfs(self, path_info): + import pyarrow + + return get_connection( + pyarrow.hdfs.connect, + path_info.host, + path_info.port, + user=path_info.user, + ) @contextmanager def open(self, path_info, mode="r", encoding=None): @@ -47,6 +72,31 @@ def exists(self, path_info): with self.hdfs(path_info) as hdfs: return hdfs.exists(path_info.path) + def walk_files(self, path_info, progress_callback=None, **kwargs): + if not self.exists(path_info): + return + + root = path_info.path + dirs = deque([root]) + + with self.hdfs(self.path_info) as hdfs: + if not hdfs.exists(root): + return + while dirs: + try: + entries = hdfs.ls(dirs.pop(), detail=True) + for entry in entries: + if entry["kind"] == "directory": + dirs.append(urlparse(entry["name"]).path) + elif entry["kind"] == "file": + if progress_callback: + progress_callback() + yield urlparse(entry["name"]).path + except OSError: + # When searching for a specific prefix pyarrow raises an + # exception if the specified cache dir does not exist + 
pass + def remove(self, path_info): if path_info.scheme != "hdfs": raise NotImplementedError @@ -94,34 +144,6 @@ class HDFSRemote(BaseRemote): TRAVERSE_PREFIX_LEN = 2 TREE_CLS = HDFSRemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) - self.path_info = None - url = config.get("url") - if not url: - return - - parsed = urlparse(url) - user = parsed.username or config.get("user") - - self.path_info = self.path_cls.from_parts( - scheme=self.scheme, - host=parsed.hostname, - user=user, - port=parsed.port, - path=parsed.path, - ) - - def hdfs(self, path_info): - import pyarrow - - return get_connection( - pyarrow.hdfs.connect, - path_info.host, - path_info.port, - user=path_info.user, - ) - def hadoop_fs(self, cmd, user=None): cmd = "hadoop fs -" + cmd if user: @@ -162,30 +184,10 @@ def get_file_checksum(self, path_info): return self._group(regex, stdout, "checksum") def list_cache_paths(self, prefix=None, progress_callback=None): - if not self.tree.exists(self.path_info): - return - if prefix: - root = posixpath.join(self.path_info.path, prefix[:2]) + path_info = self.path_info / prefix[:2] else: - root = self.path_info.path - dirs = deque([root]) - - with self.hdfs(self.path_info) as hdfs: - if prefix and not hdfs.exists(root): - return - while dirs: - try: - entries = hdfs.ls(dirs.pop(), detail=True) - for entry in entries: - if entry["kind"] == "directory": - dirs.append(urlparse(entry["name"]).path) - elif entry["kind"] == "file": - if progress_callback: - progress_callback() - yield urlparse(entry["name"]).path - except OSError as e: - # When searching for a specific prefix pyarrow raises an - # exception if the specified cache dir does not exist - if not prefix: - raise e + path_info = self.path_info + return self.tree.walk_files( + path_info, progress_callback=progress_callback + ) From b7521ae59aa7bd4f7ba1382c27805b4c6044b992 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 16:18:37 +0900 Subject: [PATCH 06/15] remote.http: finish moving methods into tree --- dvc/remote/http.py | 158 +++++++++++++++++---------------- tests/unit/remote/test_http.py | 18 ++-- 2 files changed, 89 insertions(+), 87 deletions(-) diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 32ede406cc..14aa821ec5 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -24,72 +24,19 @@ def ask_password(host, user): class HTTPRemoteTree(BaseRemoteTree): - def exists(self, path_info): - return bool(self.remote.request("HEAD", path_info.url)) - - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - response = self.remote.request("GET", from_info.url, stream=True) - if response.status_code != 200: - raise HTTPError(response.status_code, response.reason) - with open(to_file, "wb") as fd: - with Tqdm.wrapattr( - fd, - "write", - total=None - if no_progress_bar - else self._content_length(response), - leave=False, - desc=from_info.url if name is None else name, - disable=no_progress_bar, - ) as fd_wrapped: - for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE): - fd_wrapped.write(chunk) + PATH_CLS = HTTPURLInfo - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - def chunks(): - with open(from_file, "rb") as fd: - with Tqdm.wrapattr( - fd, - "read", - total=None - if no_progress_bar - else os.path.getsize(from_file), - leave=False, - desc=to_info.url if name is None else name, - disable=no_progress_bar, - ) as fd_wrapped: - while True: - chunk = fd_wrapped.read(self.CHUNK_SIZE) - if not chunk: - break - yield 
chunk - - response = self.remote.request("POST", to_info.url, data=chunks()) - if response.status_code not in (200, 201): - raise HTTPError(response.status_code, response.reason) - - def _content_length(self, response): - res = response.headers.get("Content-Length") - return int(res) if res else None - - -class HTTPRemote(BaseRemote): - scheme = Schemes.HTTP - path_cls = HTTPURLInfo SESSION_RETRIES = 5 SESSION_BACKOFF_FACTOR = 0.1 REQUEST_TIMEOUT = 10 CHUNK_SIZE = 2 ** 16 - PARAM_CHECKSUM = "etag" - CAN_TRAVERSE = False - TREE_CLS = HTTPRemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) + def __init__(self, remote, config): + super().__init__(remote, config) url = config.get("url") if url: - self.path_info = self.path_cls(url) + self.path_info = self.PATH_CLS(url) user = config.get("user", None) if user: self.path_info.user = user @@ -102,26 +49,7 @@ def __init__(self, repo, config): self.ask_password = config.get("ask_password", False) self.headers = {} - def get_file_checksum(self, path_info): - url = path_info.url - headers = self.request("HEAD", url).headers - etag = headers.get("ETag") or headers.get("Content-MD5") - - if not etag: - raise DvcException( - "could not find an ETag or " - "Content-MD5 header for '{url}'".format(url=url) - ) - - if etag.startswith("W/"): - raise DvcException( - "Weak ETags are not supported." - " (Etag: '{etag}', URL: '{url}')".format(etag=etag, url=url) - ) - - return etag - - def auth_method(self, path_info=None): + def _auth_method(self, path_info=None): from requests.auth import HTTPBasicAuth, HTTPDigestAuth if path_info is None: @@ -168,7 +96,7 @@ def request(self, method, url, **kwargs): res = self._session.request( method, url, - auth=self.auth_method(), + auth=self._auth_method(), headers=self.headers, **kwargs, ) @@ -190,5 +118,79 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException: raise DvcException(f"could not perform a {method} request") + def exists(self, path_info): + return bool(self.request("HEAD", path_info.url)) + + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + response = self.request("GET", from_info.url, stream=True) + if response.status_code != 200: + raise HTTPError(response.status_code, response.reason) + with open(to_file, "wb") as fd: + with Tqdm.wrapattr( + fd, + "write", + total=None + if no_progress_bar + else self._content_length(response), + leave=False, + desc=from_info.url if name is None else name, + disable=no_progress_bar, + ) as fd_wrapped: + for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE): + fd_wrapped.write(chunk) + + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + def chunks(): + with open(from_file, "rb") as fd: + with Tqdm.wrapattr( + fd, + "read", + total=None + if no_progress_bar + else os.path.getsize(from_file), + leave=False, + desc=to_info.url if name is None else name, + disable=no_progress_bar, + ) as fd_wrapped: + while True: + chunk = fd_wrapped.read(self.CHUNK_SIZE) + if not chunk: + break + yield chunk + + response = self.request("POST", to_info.url, data=chunks()) + if response.status_code not in (200, 201): + raise HTTPError(response.status_code, response.reason) + + def _content_length(self, response): + res = response.headers.get("Content-Length") + return int(res) if res else None + + +class HTTPRemote(BaseRemote): + scheme = Schemes.HTTP + PARAM_CHECKSUM = "etag" + CAN_TRAVERSE = False + TREE_CLS = HTTPRemoteTree + + def get_file_checksum(self, path_info): + 
url = path_info.url + headers = self.tree.request("HEAD", url).headers + etag = headers.get("ETag") or headers.get("Content-MD5") + + if not etag: + raise DvcException( + "could not find an ETag or " + "Content-MD5 header for '{url}'".format(url=url) + ) + + if etag.startswith("W/"): + raise DvcException( + "Weak ETags are not supported." + " (Etag: '{etag}', URL: '{url}')".format(etag=etag, url=url) + ) + + return etag + def gc(self): raise NotImplementedError diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index 06998f22f4..b8f0d15d6c 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -14,7 +14,7 @@ def test_download_fails_on_error_code(dvc): remote = HTTPRemote(dvc, config) with pytest.raises(HTTPError): - remote._download(URLInfo(url) / "missing.txt", "missing.txt") + remote.tree._download(URLInfo(url) / "missing.txt", "missing.txt") def test_public_auth_method(dvc): @@ -27,7 +27,7 @@ def test_public_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() is None + assert remote.tree._auth_method() is None def test_basic_auth_method(dvc): @@ -46,8 +46,8 @@ def test_basic_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() == auth - assert isinstance(remote.auth_method(), HTTPBasicAuth) + assert remote.tree._auth_method() == auth + assert isinstance(remote.tree._auth_method(), HTTPBasicAuth) def test_digest_auth_method(dvc): @@ -66,8 +66,8 @@ def test_digest_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() == auth - assert isinstance(remote.auth_method(), HTTPDigestAuth) + assert remote.tree._auth_method() == auth + assert isinstance(remote.tree._auth_method(), HTTPDigestAuth) def test_custom_auth_method(dvc): @@ -83,6 +83,6 @@ def test_custom_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() is None - assert header in remote.headers - assert remote.headers[header] == password + assert remote.tree._auth_method() is None + assert header in remote.tree.headers + assert remote.tree.headers[header] == password From 5006a338d538ce21dca29ecdadb80b2f87374e2b Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 16:26:06 +0900 Subject: [PATCH 07/15] remote.oss: finish moving methods into tree --- dvc/remote/oss.py | 131 ++++++++++++++++++---------------- tests/unit/remote/test_oss.py | 8 +-- 2 files changed, 72 insertions(+), 67 deletions(-) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index d690fd58db..c055682f8c 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -1,6 +1,5 @@ import logging import os -import posixpath import threading from funcy import cached_property, wrap_prop @@ -14,19 +13,78 @@ class OSSRemoteTree(BaseRemoteTree): - @property + PATH_CLS = CloudURLInfo + + def __init__(self, repo, config): + super().__init__(repo, config) + + url = config.get("url") + self.path_info = self.PATH_CLS(url) if url else None + + self.endpoint = config.get("oss_endpoint") or os.getenv("OSS_ENDPOINT") + + self.key_id = ( + config.get("oss_key_id") + or os.getenv("OSS_ACCESS_KEY_ID") + or "defaultId" + ) + + self.key_secret = ( + config.get("oss_key_secret") + or os.getenv("OSS_ACCESS_KEY_SECRET") + or "defaultSecret" + ) + + @wrap_prop(threading.Lock()) + @cached_property def oss_service(self): - return self.remote.oss_service + import oss2 + + logger.debug(f"URL: {self.path_info}") + logger.debug(f"key id: {self.key_id}") + logger.debug(f"key secret: {self.key_secret}") + + auth = 
oss2.Auth(self.key_id, self.key_secret)
+        bucket = oss2.Bucket(auth, self.endpoint, self.path_info.bucket)
+
+        # Ensure bucket exists
+        try:
+            bucket.get_bucket_info()
+        except oss2.exceptions.NoSuchBucket:
+            bucket.create_bucket(
+                oss2.BUCKET_ACL_PUBLIC_READ,
+                oss2.models.BucketCreateConfig(
+                    oss2.BUCKET_STORAGE_CLASS_STANDARD
+                ),
+            )
+        return bucket
 
     def _generate_download_url(self, path_info, expires=3600):
-        assert path_info.bucket == self.remote.path_info.bucket
+        assert path_info.bucket == self.path_info.bucket
 
         return self.oss_service.sign_url("GET", path_info.path, expires)
 
     def exists(self, path_info):
-        paths = self.remote.list_paths(path_info.path)
+        paths = self._list_paths(path_info)
         return any(path_info.path == path for path in paths)
 
+    def _list_paths(self, path_info, progress_callback=None):
+        import oss2
+
+        for blob in oss2.ObjectIterator(
+            self.oss_service, prefix=path_info.path
+        ):
+            if progress_callback:
+                progress_callback()
+            yield blob.key
+
+    def walk_files(self, path_info, **kwargs):
+        for fname in self._list_paths(path_info, **kwargs):
+            if fname.endswith("/"):
+                continue
+
+            yield path_info.replace(path=fname)
+
     def remove(self, path_info):
         if path_info.scheme != self.scheme:
             raise NotImplementedError
 
 
 class OSSRemote(BaseRemote):
     """
 
     scheme = Schemes.OSS
-    path_cls = CloudURLInfo
     REQUIRES = {"oss2": "oss2"}
     PARAM_CHECKSUM = "etag"
     COPY_POLL_SECONDS = 5
     LIST_OBJECT_PAGE_SIZE = 100
     TREE_CLS = OSSRemoteTree
 
-    def __init__(self, repo, config):
-        super().__init__(repo, config)
-
-        url = config.get("url")
-        self.path_info = self.path_cls(url) if url else None
-
-        self.endpoint = config.get("oss_endpoint") or os.getenv("OSS_ENDPOINT")
-
-        self.key_id = (
-            config.get("oss_key_id")
-            or os.getenv("OSS_ACCESS_KEY_ID")
-            or "defaultId"
-        )
-
-        self.key_secret = (
-            config.get("oss_key_secret")
-            or os.getenv("OSS_ACCESS_KEY_SECRET")
-            or "defaultSecret"
-        )
-
-    @wrap_prop(threading.Lock())
-    @cached_property
-    def oss_service(self):
-        import oss2
-
-        logger.debug(f"URL: {self.path_info}")
-        logger.debug(f"key id: {self.key_id}")
-        logger.debug(f"key secret: {self.key_secret}")
-
-        auth = oss2.Auth(self.key_id, self.key_secret)
-        bucket = oss2.Bucket(auth, self.endpoint, self.path_info.bucket)
-
-        # Ensure bucket exists
-        try:
-            bucket.get_bucket_info()
-        except oss2.exceptions.NoSuchBucket:
-            bucket.create_bucket(
-                oss2.BUCKET_ACL_PUBLIC_READ,
-                oss2.models.BucketCreateConfig(
-                    oss2.BUCKET_STORAGE_CLASS_STANDARD
-                ),
-            )
-        return bucket
-
-    def list_paths(self, prefix, progress_callback=None):
-        import oss2
-
-        for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
-            if progress_callback:
-                progress_callback()
-            yield blob.key
-
     def list_cache_paths(self, prefix=None, progress_callback=None):
         if prefix:
-            prefix = posixpath.join(
-                self.path_info.path, prefix[:2], prefix[2:]
-            )
+            path_info = self.path_info / prefix[:2] / prefix[2:]
         else:
-            prefix = self.path_info.path
-        return self.list_paths(prefix, progress_callback)
+            path_info = self.path_info
+        return self.tree.walk_files(
+            path_info, progress_callback=progress_callback
+        )
diff --git a/tests/unit/remote/test_oss.py b/tests/unit/remote/test_oss.py
index d6bd515392..3bffe14a43 100644
--- a/tests/unit/remote/test_oss.py
+++ b/tests/unit/remote/test_oss.py
@@ -16,7 +16,7 @@ def test_init(dvc):
         "oss_endpoint": endpoint,
     }
     remote = OSSRemote(dvc, config)
-    assert remote.path_info == url
-    assert remote.endpoint == endpoint
-    assert remote.key_id == key_id
-    assert 
remote.key_secret == key_secret + assert remote.tree.path_info == url + assert remote.tree.endpoint == endpoint + assert remote.tree.key_id == key_id + assert remote.tree.key_secret == key_secret From a5bed95845c6435f1971d635aa6344627451950a Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 16:40:46 +0900 Subject: [PATCH 08/15] remote.s3: finish moving methods into tree --- dvc/remote/s3.py | 280 +++++++++++++++++------------------ tests/unit/remote/test_s3.py | 14 +- 2 files changed, 147 insertions(+), 147 deletions(-) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 9ede8b2e17..e65580eff3 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -1,6 +1,5 @@ import logging import os -import posixpath import threading from funcy import cached_property, wrap_prop @@ -16,9 +15,103 @@ class S3RemoteTree(BaseRemoteTree): - @property + PATH_CLS = CloudURLInfo + + def __init__(self, repo, config): + super().__init__(repo, config) + + url = config.get("url", "s3://") + self.path_info = self.PATH_CLS(url) + + self.region = config.get("region") + self.profile = config.get("profile") + self.endpoint_url = config.get("endpointurl") + + if config.get("listobjects"): + self.list_objects_api = "list_objects" + else: + self.list_objects_api = "list_objects_v2" + + self.use_ssl = config.get("use_ssl", True) + + self.extra_args = {} + + self.sse = config.get("sse") + if self.sse: + self.extra_args["ServerSideEncryption"] = self.sse + + self.sse_kms_key_id = config.get("sse_kms_key_id") + if self.sse_kms_key_id: + self.extra_args["SSEKMSKeyId"] = self.sse_kms_key_id + + self.acl = config.get("acl") + if self.acl: + self.extra_args["ACL"] = self.acl + + self._append_aws_grants_to_extra_args(config) + + shared_creds = config.get("credentialpath") + if shared_creds: + os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + + @wrap_prop(threading.Lock()) + @cached_property def s3(self): - return self.remote.s3 + import boto3 + + session = boto3.session.Session( + profile_name=self.profile, region_name=self.region + ) + + return session.client( + "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl + ) + + @classmethod + def get_etag(cls, s3, bucket, path): + obj = cls.get_head_object(s3, bucket, path) + + return obj["ETag"].strip('"') + + @staticmethod + def get_head_object(s3, bucket, path, *args, **kwargs): + + try: + obj = s3.head_object(Bucket=bucket, Key=path, *args, **kwargs) + except Exception as exc: + raise DvcException(f"s3://{bucket}/{path} does not exist") from exc + return obj + + def _append_aws_grants_to_extra_args(self, config): + # Keys for extra_args can be one of the following list: + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS + """ + ALLOWED_UPLOAD_ARGS = [ + 'ACL', 'CacheControl', 'ContentDisposition', 'ContentEncoding', + 'ContentLanguage', 'ContentType', 'Expires', 'GrantFullControl', + 'GrantRead', 'GrantReadACP', 'GrantWriteACP', 'Metadata', + 'RequestPayer', 'ServerSideEncryption', 'StorageClass', + 'SSECustomerAlgorithm', 'SSECustomerKey', 'SSECustomerKeyMD5', + 'SSEKMSKeyId', 'WebsiteRedirectLocation' + ] + """ + + grants = { + "grant_full_control": "GrantFullControl", + "grant_read": "GrantRead", + "grant_read_acp": "GrantReadACP", + "grant_write_acp": "GrantWriteACP", + } + + for grant_option, extra_args_key in grants.items(): + if config.get(grant_option): + if self.acl: + raise ConfigError( + "`acl` and `grant_*` AWS S3 config options " + "are 
mutually exclusive" + ) + + self.extra_args[extra_args_key] = config.get(grant_option) def _generate_download_url(self, path_info, expires=3600): params = {"Bucket": path_info.bucket, "Key": path_info.path} @@ -56,7 +149,7 @@ def isdir(self, path_info): # While `data/al/` will return nothing. # dir_path = path_info / "" - return bool(list(self.remote.list_paths(dir_path, max_items=1))) + return bool(list(self._list_paths(dir_path, max_items=1))) def isfile(self, path_info): from botocore.exceptions import ClientError @@ -73,8 +166,35 @@ def isfile(self, path_info): return True - def walk_files(self, path_info, max_items=None): - for fname in self.remote.list_paths(path_info / "", max_items): + def _list_objects(self, path_info, max_items=None, progress_callback=None): + """ Read config for list object api, paginate through list objects.""" + kwargs = { + "Bucket": path_info.bucket, + "Prefix": path_info.path, + "PaginationConfig": {"MaxItems": max_items}, + } + paginator = self.s3.get_paginator(self.list_objects_api) + for page in paginator.paginate(**kwargs): + contents = page.get("Contents", ()) + if progress_callback: + for item in contents: + progress_callback() + yield item + else: + yield from contents + + def _list_paths(self, path_info, max_items=None, progress_callback=None): + return ( + item["Key"] + for item in self._list_objects( + path_info, max_items, progress_callback + ) + ) + + def walk_files(self, path_info, max_items=None, **kwargs): + for fname in self._list_paths( + path_info / "", max_items=max_items, **kwargs + ): if fname.endswith("/"): continue @@ -100,7 +220,7 @@ def makedirs(self, path_info): self.s3.put_object(Bucket=path_info.bucket, Key=dir_path.path, Body="") def copy(self, from_info, to_info): - self._copy(self.s3, from_info, to_info, self.remote.extra_args) + self._copy(self.s3, from_info, to_info, self.extra_args) @classmethod def _copy_multipart( @@ -114,7 +234,7 @@ def _copy_multipart( parts = [] byte_position = 0 for i in range(1, n_parts + 1): - obj = S3Remote.get_head_object( + obj = S3RemoteTree.get_head_object( s3, from_info.bucket, from_info.path, PartNumber=i ) part_size = obj["ContentLength"] @@ -169,7 +289,9 @@ def _copy(cls, s3, from_info, to_info, extra_args): # object is transfered in the same chunks as it was originally. 
from boto3.s3.transfer import TransferConfig - obj = S3Remote.get_head_object(s3, from_info.bucket, from_info.path) + obj = S3RemoteTree.get_head_object( + s3, from_info.bucket, from_info.path + ) etag = obj["ETag"].strip('"') size = obj["ContentLength"] @@ -189,7 +311,7 @@ def _copy(cls, s3, from_info, to_info, extra_args): Config=TransferConfig(multipart_threshold=size + 1), ) - cached_etag = S3Remote.get_etag(s3, to_info.bucket, to_info.path) + cached_etag = S3RemoteTree.get_etag(s3, to_info.bucket, to_info.path) if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) @@ -223,142 +345,20 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): class S3Remote(BaseRemote): scheme = Schemes.S3 - path_cls = CloudURLInfo REQUIRES = {"boto3": "boto3"} PARAM_CHECKSUM = "etag" TREE_CLS = S3RemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url", "s3://") - self.path_info = self.path_cls(url) - - self.region = config.get("region") - self.profile = config.get("profile") - self.endpoint_url = config.get("endpointurl") - - if config.get("listobjects"): - self.list_objects_api = "list_objects" - else: - self.list_objects_api = "list_objects_v2" - - self.use_ssl = config.get("use_ssl", True) - - self.extra_args = {} - - self.sse = config.get("sse") - if self.sse: - self.extra_args["ServerSideEncryption"] = self.sse - - self.sse_kms_key_id = config.get("sse_kms_key_id") - if self.sse_kms_key_id: - self.extra_args["SSEKMSKeyId"] = self.sse_kms_key_id - - self.acl = config.get("acl") - if self.acl: - self.extra_args["ACL"] = self.acl - - self._append_aws_grants_to_extra_args(config) - - shared_creds = config.get("credentialpath") - if shared_creds: - os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) - - @wrap_prop(threading.Lock()) - @cached_property - def s3(self): - import boto3 - - session = boto3.session.Session( - profile_name=self.profile, region_name=self.region - ) - - return session.client( - "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl - ) - - @classmethod - def get_etag(cls, s3, bucket, path): - obj = cls.get_head_object(s3, bucket, path) - - return obj["ETag"].strip('"') - def get_file_checksum(self, path_info): - return self.get_etag(self.s3, path_info.bucket, path_info.path) - - @staticmethod - def get_head_object(s3, bucket, path, *args, **kwargs): - - try: - obj = s3.head_object(Bucket=bucket, Key=path, *args, **kwargs) - except Exception as exc: - raise DvcException(f"s3://{bucket}/{path} does not exist") from exc - return obj - - def _list_objects( - self, path_info, max_items=None, prefix=None, progress_callback=None - ): - """ Read config for list object api, paginate through list objects.""" - kwargs = { - "Bucket": path_info.bucket, - "Prefix": path_info.path, - "PaginationConfig": {"MaxItems": max_items}, - } - if prefix: - kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2]) - paginator = self.s3.get_paginator(self.list_objects_api) - for page in paginator.paginate(**kwargs): - contents = page.get("Contents", ()) - if progress_callback: - for item in contents: - progress_callback() - yield item - else: - yield from contents - - def list_paths( - self, path_info, max_items=None, prefix=None, progress_callback=None - ): - return ( - item["Key"] - for item in self._list_objects( - path_info, max_items, prefix, progress_callback - ) + return self.tree.get_etag( + self.tree.s3, path_info.bucket, path_info.path ) def list_cache_paths(self, prefix=None, 
progress_callback=None): - return self.list_paths( - self.path_info, prefix=prefix, progress_callback=progress_callback + if prefix: + path_info = self.path_info / prefix[:2] / prefix[2:] + else: + path_info = self.path_info + return self.tree.walk_files( + path_info, progress_callback=progress_callback ) - - def _append_aws_grants_to_extra_args(self, config): - # Keys for extra_args can be one of the following list: - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS - """ - ALLOWED_UPLOAD_ARGS = [ - 'ACL', 'CacheControl', 'ContentDisposition', 'ContentEncoding', - 'ContentLanguage', 'ContentType', 'Expires', 'GrantFullControl', - 'GrantRead', 'GrantReadACP', 'GrantWriteACP', 'Metadata', - 'RequestPayer', 'ServerSideEncryption', 'StorageClass', - 'SSECustomerAlgorithm', 'SSECustomerKey', 'SSECustomerKeyMD5', - 'SSEKMSKeyId', 'WebsiteRedirectLocation' - ] - """ - - grants = { - "grant_full_control": "GrantFullControl", - "grant_read": "GrantRead", - "grant_read_acp": "GrantReadACP", - "grant_write_acp": "GrantWriteACP", - } - - for grant_option, extra_args_key in grants.items(): - if config.get(grant_option): - if self.acl: - raise ConfigError( - "`acl` and `grant_*` AWS S3 config options " - "are mutually exclusive" - ) - - self.extra_args[extra_args_key] = config.get(grant_option) diff --git a/tests/unit/remote/test_s3.py b/tests/unit/remote/test_s3.py index bf726db616..d61c0d735d 100644 --- a/tests/unit/remote/test_s3.py +++ b/tests/unit/remote/test_s3.py @@ -22,7 +22,7 @@ def test_init(dvc): config = {"url": url} remote = S3Remote(dvc, config) - assert remote.path_info == url + assert remote.tree.path_info == url def test_grants(dvc): @@ -34,16 +34,16 @@ def test_grants(dvc): "grant_full_control": "id=full-control-permission-id", } remote = S3Remote(dvc, config) + tree = remote.tree assert ( - remote.extra_args["GrantRead"] + tree.extra_args["GrantRead"] == "id=read-permission-id,id=other-read-permission-id" ) - assert remote.extra_args["GrantReadACP"] == "id=read-acp-permission-id" - assert remote.extra_args["GrantWriteACP"] == "id=write-acp-permission-id" + assert tree.extra_args["GrantReadACP"] == "id=read-acp-permission-id" + assert tree.extra_args["GrantWriteACP"] == "id=write-acp-permission-id" assert ( - remote.extra_args["GrantFullControl"] - == "id=full-control-permission-id" + tree.extra_args["GrantFullControl"] == "id=full-control-permission-id" ) @@ -57,4 +57,4 @@ def test_grants_mutually_exclusive_acl_error(dvc, grants): def test_sse_kms_key_id(dvc): remote = S3Remote(dvc, {"url": url, "sse_kms_key_id": "key"}) - assert remote.extra_args["SSEKMSKeyId"] == "key" + assert remote.tree.extra_args["SSEKMSKeyId"] == "key" From e45b7d147198e0e165f5a6c3876bf182230aa0e7 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 16:52:11 +0900 Subject: [PATCH 09/15] remote.ssh: finish moving methods into tree --- dvc/remote/ssh/__init__.py | 211 +++++++++++++++--------------- tests/unit/remote/ssh/test_ssh.py | 36 ++--- 2 files changed, 122 insertions(+), 125 deletions(-) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 114567ef5c..2936843b92 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -34,9 +34,107 @@ def ask_password(host, user, port): class SSHRemoteTree(BaseRemoteTree): - @property - def ssh(self): - return self.remote.ssh + DEFAULT_PORT = 22 + TIMEOUT = 1800 + + def __init__(self, repo, config): + super().__init__(repo, 
config) + url = config.get("url") + if url: + parsed = urlparse(url) + user_ssh_config = self._load_user_ssh_config(parsed.hostname) + + host = user_ssh_config.get("hostname", parsed.hostname) + user = ( + config.get("user") + or parsed.username + or user_ssh_config.get("user") + or getpass.getuser() + ) + port = ( + config.get("port") + or parsed.port + or self._try_get_ssh_config_port(user_ssh_config) + or self.DEFAULT_PORT + ) + self.path_info = self.PATH_CLS.from_parts( + scheme=self.scheme, + host=host, + user=user, + port=port, + path=parsed.path, + ) + else: + self.path_info = None + user_ssh_config = {} + + self.keyfile = config.get( + "keyfile" + ) or self._try_get_ssh_config_keyfile(user_ssh_config) + self.timeout = config.get("timeout", self.TIMEOUT) + self.password = config.get("password", None) + self.ask_password = config.get("ask_password", False) + self.gss_auth = config.get("gss_auth", False) + proxy_command = user_ssh_config.get("proxycommand", False) + if proxy_command: + import paramiko + + self.sock = paramiko.ProxyCommand(proxy_command) + else: + self.sock = None + + @staticmethod + def ssh_config_filename(): + return os.path.expanduser(os.path.join("~", ".ssh", "config")) + + @staticmethod + def _load_user_ssh_config(hostname): + import paramiko + + user_config_file = SSHRemoteTree.ssh_config_filename() + user_ssh_config = {} + if hostname and os.path.exists(user_config_file): + ssh_config = paramiko.SSHConfig() + with open(user_config_file) as f: + # For whatever reason parsing directly from f is unreliable + f_copy = io.StringIO(f.read()) + ssh_config.parse(f_copy) + user_ssh_config = ssh_config.lookup(hostname) + return user_ssh_config + + @staticmethod + def _try_get_ssh_config_port(user_ssh_config): + return silent(int)(user_ssh_config.get("port")) + + @staticmethod + def _try_get_ssh_config_keyfile(user_ssh_config): + return first(user_ssh_config.get("identityfile") or ()) + + def ensure_credentials(self, path_info=None): + if path_info is None: + path_info = self.path_info + + # NOTE: we use the same password regardless of the server :( + if self.ask_password and self.password is None: + host, user, port = path_info.host, path_info.user, path_info.port + self.password = ask_password(host, user, port) + + def ssh(self, path_info): + self.ensure_credentials(path_info) + + from .connection import SSHConnection + + return get_connection( + SSHConnection, + path_info.host, + username=path_info.user, + port=path_info.port, + key_filename=self.keyfile, + timeout=self.timeout, + password=self.password, + gss_auth=self.gss_auth, + sock=self.sock, + ) @contextmanager def open(self, path_info, mode="r", encoding=None): @@ -158,8 +256,6 @@ class SSHRemote(BaseRemote): JOBS = 4 PARAM_CHECKSUM = "md5" - DEFAULT_PORT = 22 - TIMEOUT = 1800 # At any given time some of the connections will go over network and # paramiko stuff, so we would ideally have it double of server processors. # We use conservative setting of 4 instead to not exhaust max sessions. 
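With this patch the connection setup lives on the tree, so callers now open SSH sessions through remote.tree.ssh() rather than the remote itself. A minimal usage sketch, assuming a hypothetical repo object and URL that are not taken from this series:

    from dvc.remote.ssh import SSHRemote

    # Hypothetical repo/config, for illustration only.
    remote = SSHRemote(repo, {"url": "ssh://user@example.com:2222/cache"})

    # ensure_credentials() runs inside tree.ssh(), which then returns
    # a pooled SSHConnection as a context manager.
    with remote.tree.ssh(remote.tree.path_info) as ssh:
        found = ssh.exists(remote.tree.path_info.path)
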
@@ -169,110 +265,11 @@ class SSHRemote(BaseRemote): DEFAULT_CACHE_TYPES = ["copy"] - def __init__(self, repo, config): - super().__init__(repo, config) - url = config.get("url") - if url: - parsed = urlparse(url) - user_ssh_config = self._load_user_ssh_config(parsed.hostname) - - host = user_ssh_config.get("hostname", parsed.hostname) - user = ( - config.get("user") - or parsed.username - or user_ssh_config.get("user") - or getpass.getuser() - ) - port = ( - config.get("port") - or parsed.port - or self._try_get_ssh_config_port(user_ssh_config) - or self.DEFAULT_PORT - ) - self.path_info = self.path_cls.from_parts( - scheme=self.scheme, - host=host, - user=user, - port=port, - path=parsed.path, - ) - else: - self.path_info = None - user_ssh_config = {} - - self.keyfile = config.get( - "keyfile" - ) or self._try_get_ssh_config_keyfile(user_ssh_config) - self.timeout = config.get("timeout", self.TIMEOUT) - self.password = config.get("password", None) - self.ask_password = config.get("ask_password", False) - self.gss_auth = config.get("gss_auth", False) - proxy_command = user_ssh_config.get("proxycommand", False) - if proxy_command: - import paramiko - - self.sock = paramiko.ProxyCommand(proxy_command) - else: - self.sock = None - - @staticmethod - def ssh_config_filename(): - return os.path.expanduser(os.path.join("~", ".ssh", "config")) - - @staticmethod - def _load_user_ssh_config(hostname): - import paramiko - - user_config_file = SSHRemote.ssh_config_filename() - user_ssh_config = {} - if hostname and os.path.exists(user_config_file): - ssh_config = paramiko.SSHConfig() - with open(user_config_file) as f: - # For whatever reason parsing directly from f is unreliable - f_copy = io.StringIO(f.read()) - ssh_config.parse(f_copy) - user_ssh_config = ssh_config.lookup(hostname) - return user_ssh_config - - @staticmethod - def _try_get_ssh_config_port(user_ssh_config): - return silent(int)(user_ssh_config.get("port")) - - @staticmethod - def _try_get_ssh_config_keyfile(user_ssh_config): - return first(user_ssh_config.get("identityfile") or ()) - - def ensure_credentials(self, path_info=None): - if path_info is None: - path_info = self.path_info - - # NOTE: we use the same password regardless of the server :( - if self.ask_password and self.password is None: - host, user, port = path_info.host, path_info.user, path_info.port - self.password = ask_password(host, user, port) - - def ssh(self, path_info): - self.ensure_credentials(path_info) - - from .connection import SSHConnection - - return get_connection( - SSHConnection, - path_info.host, - username=path_info.user, - port=path_info.port, - key_filename=self.keyfile, - timeout=self.timeout, - password=self.password, - gss_auth=self.gss_auth, - sock=self.sock, - ) - def get_file_checksum(self, path_info): if path_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(path_info) as ssh: + with self.tree.ssh(path_info) as ssh: return ssh.md5(path_info.path) def list_cache_paths(self, prefix=None, progress_callback=None): @@ -280,7 +277,7 @@ def list_cache_paths(self, prefix=None, progress_callback=None): root = posixpath.join(self.path_info.path, prefix[:2]) else: root = self.path_info.path - with self.ssh(self.path_info) as ssh: + with self.tree.ssh(self.path_info) as ssh: if prefix and not ssh.exists(root): return # If we simply return an iterator then with above closes instantly @@ -306,7 +303,7 @@ def _exists(chunk_and_channel): callback(path) return ret - with self.ssh(path_infos[0]) as ssh: + with self.tree.ssh(path_infos[0]) 
as ssh: channels = ssh.open_max_sftp_channels() max_workers = len(channels) diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py index b0b22bec9a..288a1d7b13 100644 --- a/tests/unit/remote/ssh/test_ssh.py +++ b/tests/unit/remote/ssh/test_ssh.py @@ -5,7 +5,7 @@ import pytest from mock import mock_open, patch -from dvc.remote.ssh import SSHRemote +from dvc.remote.ssh import SSHRemote, SSHRemoteTree from dvc.system import System from tests.remotes import SSHMocked @@ -28,13 +28,13 @@ def test_url(dvc): config = {"url": url} remote = SSHRemote(dvc, config) - assert remote.path_info == url + assert remote.tree.path_info == url def test_no_path(dvc): config = {"url": "ssh://127.0.0.1"} remote = SSHRemote(dvc, config) - assert remote.path_info.path == "" + assert remote.tree.path_info.path == "" mock_ssh_config = """ @@ -69,9 +69,9 @@ def test_ssh_host_override_from_config( ): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.path_info.host == expected_host + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.path_info.host == expected_host @pytest.mark.parametrize( @@ -97,9 +97,9 @@ def test_ssh_host_override_from_config( def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.path_info.user == expected_user + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.path_info.user == expected_user @pytest.mark.parametrize( @@ -108,7 +108,7 @@ def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): ({"url": "ssh://example.com:2222"}, 2222), ({"url": "ssh://example.com"}, 1234), ({"url": "ssh://example.com", "port": 4321}, 4321), - ({"url": "ssh://not_in_ssh_config.com"}, SSHRemote.DEFAULT_PORT), + ({"url": "ssh://not_in_ssh_config.com"}, SSHRemoteTree.DEFAULT_PORT), ({"url": "ssh://not_in_ssh_config.com:2222"}, 2222), ({"url": "ssh://not_in_ssh_config.com:2222", "port": 4321}, 4321), ], @@ -122,8 +122,8 @@ def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) assert remote.path_info.port == expected_port @@ -157,9 +157,9 @@ def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port): def test_ssh_keyfile(mock_file, mock_exists, dvc, config, expected_keyfile): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.keyfile == expected_keyfile + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.keyfile == expected_keyfile @pytest.mark.parametrize( @@ -179,9 +179,9 @@ def test_ssh_keyfile(mock_file, mock_exists, dvc, 
config, expected_keyfile): def test_ssh_gss_auth(mock_file, mock_exists, dvc, config, expected_gss_auth): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.gss_auth == expected_gss_auth + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.gss_auth == expected_gss_auth def test_hardlink_optimization(dvc, tmp_dir, ssh_server): From b02b9eed53c3320c9adc4b8874b028c336880aa1 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 17:01:58 +0900 Subject: [PATCH 10/15] remote.gdrive: finish moving methods into tree --- dvc/remote/gdrive.py | 121 ++++++++++++++++++------------- tests/unit/remote/test_gdrive.py | 14 ++-- 2 files changed, 76 insertions(+), 59 deletions(-) diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 970d823d6d..25fbc9e489 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -88,43 +88,7 @@ def __init__(self, url): class GDriveRemoteTree(BaseRemoteTree): - def exists(self, path_info): - try: - self.remote.get_item_id(path_info) - except GDrivePathNotFound: - return False - else: - return True - - def remove(self, path_info): - item_id = self.remote.get_item_id(path_info) - self.remote.gdrive_delete_file(item_id) - - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - dirname = to_info.parent - assert dirname - parent_id = self.get_item_id(dirname, True) - - self._gdrive_upload_file( - parent_id, to_info.name, no_progress_bar, from_file, name - ) - - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - item_id = self.remote.get_item_id(from_info) - self.remote._gdrive_download_file( - item_id, to_file, name, no_progress_bar - ) - - -class GDriveRemote(BaseRemote): - scheme = Schemes.GDRIVE - path_cls = GDriveURLInfo - REQUIRES = {"pydrive2": "pydrive2"} - DEFAULT_VERIFY = True - # Always prefer traverse for GDrive since API usage quotas are a concern. - TRAVERSE_WEIGHT_MULTIPLIER = 1 - TRAVERSE_PREFIX_LEN = 2 - TREE_CLS = GDriveRemoteTree + PATH_CLS = GDriveURLInfo GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA" DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json" @@ -132,9 +96,10 @@ class GDriveRemote(BaseRemote): DEFAULT_GDRIVE_CLIENT_ID = "710796635688-iivsgbgsb6uv1fap6635dhvuei09o66c.apps.googleusercontent.com" # noqa: E501 DEFAULT_GDRIVE_CLIENT_SECRET = "a1Fz59uTpVNeG_VGuSKDLJXv" - def __init__(self, repo, config): - super().__init__(repo, config) - self.path_info = self.path_cls(config["url"]) + def __init__(self, remote, config): + super().__init__(remote, config) + + self.path_info = self.PATH_CLS(config["url"]) if not self.path_info.bucket: raise DvcException( @@ -161,11 +126,12 @@ def __init__(self, repo, config): self._validate_config() self._gdrive_user_credentials_path = ( tmp_fname(os.path.join(self.repo.tmp_dir, "")) - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA) + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) else config.get( "gdrive_user_credentials_file", os.path.join( - self.repo.tmp_dir, self.DEFAULT_USER_CREDENTIALS_FILE + self.remote.repo.tmp_dir, + self.DEFAULT_USER_CREDENTIALS_FILE, ), ) ) @@ -203,8 +169,8 @@ def credentials_location(self): Useful for tests, exception messages, etc. Returns either env variable name if it's set or actual path to the credentials file. 
""" - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): - return GDriveRemote.GDRIVE_CREDENTIALS_DATA + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): + return GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA if os.path.exists(self._gdrive_user_credentials_path): return self._gdrive_user_credentials_path return None @@ -218,7 +184,7 @@ def _validate_credentials(auth, settings): DVC config client id or secret but forgets to remove the cached credentials file. """ - if not os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): + if not os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): if ( settings["client_config"]["client_id"] != auth.credentials.client_id @@ -241,10 +207,10 @@ def _drive(self): from pydrive2.auth import GoogleAuth from pydrive2.drive import GoogleDrive - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): with open(self._gdrive_user_credentials_path, "w") as cred_file: cred_file.write( - os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA) + os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) ) auth_settings = { @@ -291,7 +257,7 @@ def _drive(self): gauth.ServiceAuth() else: gauth.CommandLineAuth() - GDriveRemote._validate_credentials(gauth, auth_settings) + GDriveRemoteTree._validate_credentials(gauth, auth_settings) # Handle AuthenticationError, RefreshError and other auth failures # It's hard to come up with a narrow exception, since PyDrive throws @@ -300,7 +266,7 @@ def _drive(self): except Exception as exc: raise GDriveAuthError(self.credentials_location) from exc finally: - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): os.remove(self._gdrive_user_credentials_path) return GoogleDrive(gauth) @@ -311,7 +277,7 @@ def _ids_cache(self): cache = { "dirs": defaultdict(list), "ids": {}, - "root_id": self.get_item_id( + "root_id": self._get_item_id( self.path_info, use_cache=False, hint="Confirm the directory exists and you can access it.", @@ -528,7 +494,7 @@ def _path_to_item_ids(self, path, create, use_cache): [self._create_dir(min(parent_ids), title, path)] if create else [] ) - def get_item_id(self, path_info, create=False, use_cache=True, hint=None): + def _get_item_id(self, path_info, create=False, use_cache=True, hint=None): assert path_info.bucket == self._bucket item_ids = self._path_to_item_ids(path_info.path, create, use_cache) @@ -538,7 +504,15 @@ def get_item_id(self, path_info, create=False, use_cache=True, hint=None): assert not create raise GDrivePathNotFound(path_info, hint) - def list_cache_paths(self, prefix=None, progress_callback=None): + def exists(self, path_info): + try: + self._get_item_id(path_info) + except GDrivePathNotFound: + return False + else: + return True + + def _list_paths(self, prefix=None, progress_callback=None): if not self._ids_cache["ids"]: return @@ -561,5 +535,48 @@ def list_cache_paths(self, prefix=None, progress_callback=None): self._ids_cache["ids"][parent_id], item["title"] ) + def walk_files(self, path_info, **kwargs): + if path_info == self.path_info: + prefix = None + else: + prefix = path_info.relative_to(self.path_info).path + return self._list_paths(prefix=prefix, **kwargs) + + def remove(self, path_info): + item_id = self._get_item_id(path_info) + self.gdrive_delete_file(item_id) + + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + dirname = to_info.parent + assert dirname + parent_id = self._get_item_id(dirname, True) + + self._gdrive_upload_file( + parent_id, to_info.name, 
no_progress_bar, from_file, name + ) + + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + item_id = self._get_item_id(from_info) + self._gdrive_download_file(item_id, to_file, name, no_progress_bar) + + +class GDriveRemote(BaseRemote): + scheme = Schemes.GDRIVE + REQUIRES = {"pydrive2": "pydrive2"} + DEFAULT_VERIFY = True + # Always prefer traverse for GDrive since API usage quotas are a concern. + TRAVERSE_WEIGHT_MULTIPLIER = 1 + TRAVERSE_PREFIX_LEN = 2 + TREE_CLS = GDriveRemoteTree + def get_file_checksum(self, path_info): raise NotImplementedError + + def list_cache_paths(self, prefix=None, progress_callback=None): + if prefix: + path_info = self.path_info / prefix[2:] + else: + path_info = self.path_info + return self.tree.walk_files( + path_info, progress_callback=progress_callback + ) diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index aeda479c10..b8b5ed83e8 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -2,7 +2,7 @@ import pytest -from dvc.remote.gdrive import GDriveAuthError, GDriveRemote +from dvc.remote.gdrive import GDriveAuthError, GDriveRemote, GDriveRemoteTree USER_CREDS_TOKEN_REFRESH_ERROR = '{"access_token": "", "client_id": "", "client_secret": "", "refresh_token": "", "token_expiry": "", "token_uri": "https://oauth2.googleapis.com/token", "user_agent": null, "revoke_uri": "https://oauth2.googleapis.com/revoke", "id_token": null, "id_token_jwt": null, "token_response": {"access_token": "", "expires_in": 3600, "scope": "https://www.googleapis.com/auth/drive.appdata https://www.googleapis.com/auth/drive", "token_type": "Bearer"}, "scopes": ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/drive.appdata"], "token_info_uri": "https://oauth2.googleapis.com/tokeninfo", "invalid": true, "_class": "OAuth2Credentials", "_module": "oauth2client.client"}' # noqa: E501 @@ -18,20 +18,20 @@ class TestRemoteGDrive: def test_init(self, dvc): remote = GDriveRemote(dvc, self.CONFIG) - assert str(remote.path_info) == self.CONFIG["url"] + assert str(remote.tree.path_info) == self.CONFIG["url"] def test_drive(self, dvc): remote = GDriveRemote(dvc, self.CONFIG) os.environ[ - GDriveRemote.GDRIVE_CREDENTIALS_DATA + GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_TOKEN_REFRESH_ERROR with pytest.raises(GDriveAuthError): - remote._drive + remote.tree._drive - os.environ[GDriveRemote.GDRIVE_CREDENTIALS_DATA] = "" + os.environ[GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA] = "" remote = GDriveRemote(dvc, self.CONFIG) os.environ[ - GDriveRemote.GDRIVE_CREDENTIALS_DATA + GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_MISSED_KEY_ERROR with pytest.raises(GDriveAuthError): - remote._drive + remote.tree._drive From ffdb91d12535e08b61374160a50968975f7d9730 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 17:20:18 +0900 Subject: [PATCH 11/15] tests: update remote unit tests --- tests/remotes.py | 8 ++++---- tests/unit/remote/test_base.py | 4 ++-- tests/unit/remote/test_remote.py | 2 +- tests/unit/remote/test_remote_tree.py | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/remotes.py b/tests/remotes.py index d1a58531ac..a43e7709be 100644 --- a/tests/remotes.py +++ b/tests/remotes.py @@ -7,7 +7,7 @@ from moto.s3 import mock_s3 -from dvc.remote import GDriveRemote +from dvc.remote.gdrive import GDriveRemoteTree from dvc.remote.gs import GSRemote from dvc.remote.s3 import S3Remote from dvc.utils import env2bool @@ -82,7 
+82,7 @@ def remote(cls, repo): @staticmethod def put_objects(remote, objects): - s3 = remote.s3 + s3 = remote.tree.s3 bucket = remote.path_info.bucket s3.create_bucket(Bucket=bucket) for key, body in objects.items(): @@ -130,7 +130,7 @@ def remote(cls, repo): @staticmethod def put_objects(remote, objects): - client = remote.gs + client = remote.tree.gs bucket = client.get_bucket(remote.path_info.bucket) for key, body in objects.items(): bucket.blob((remote.path_info / key).path).upload_from_string(body) @@ -139,7 +139,7 @@ def put_objects(remote, objects): class GDrive: @staticmethod def should_test(): - return os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA) is not None + return os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) is not None @staticmethod def create_dir(dvc, url): diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index fa2325fb2b..481f9f3f1f 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -104,7 +104,7 @@ def test_cache_exists(object_exists, traverse, dvc): ) def test_cache_checksums_traverse(path_to_checksum, cache_checksums, dvc): remote = BaseRemote(dvc, {}) - remote.path_info = PathInfo("foo") + remote.tree.path_info = PathInfo("foo") # parallel traverse size = 256 / remote.JOBS * remote.LIST_OBJECT_PAGE_SIZE @@ -129,7 +129,7 @@ def test_cache_checksums_traverse(path_to_checksum, cache_checksums, dvc): def test_cache_checksums(dvc): remote = BaseRemote(dvc, {}) - remote.path_info = PathInfo("foo") + remote.tree.path_info = PathInfo("foo") with mock.patch.object( remote, "list_cache_paths", return_value=["12/3456", "bar"] diff --git a/tests/unit/remote/test_remote.py b/tests/unit/remote/test_remote.py index 8e0f82a5d8..34dc6948c6 100644 --- a/tests/unit/remote/test_remote.py +++ b/tests/unit/remote/test_remote.py @@ -35,7 +35,7 @@ def test_makedirs_not_create_for_top_level_path(remote_cls, dvc, mocker): remote = remote_cls(dvc, {"url": url}) mocked_client = mocker.PropertyMock() # we use remote clients with same name as scheme to interact with remote - mocker.patch.object(remote_cls, remote.scheme, mocked_client) + mocker.patch.object(remote_cls.TREE_CLS, remote.scheme, mocked_client) remote.tree.makedirs(remote.path_info) assert not mocked_client.called diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index fc9f557df2..e9b8dd213e 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -3,7 +3,7 @@ import pytest from dvc.path_info import PathInfo -from dvc.remote.s3 import S3Remote +from dvc.remote.s3 import S3Remote, S3RemoteTree from dvc.utils.fs import walk_files from tests.remotes import GCP, S3Mocked @@ -88,7 +88,7 @@ def test_walk_files(remote): @pytest.mark.parametrize("remote", [S3Mocked], indirect=True) def test_copy_preserve_etag_across_buckets(remote, dvc): - s3 = remote.s3 + s3 = remote.tree.s3 s3.create_bucket(Bucket="another") another = S3Remote(dvc, {"url": "s3://another", "region": "us-east-1"}) @@ -98,8 +98,8 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): remote.tree.copy(from_info, to_info) - from_etag = S3Remote.get_etag(s3, from_info.bucket, from_info.path) - to_etag = S3Remote.get_etag(s3, "another", "foo") + from_etag = S3RemoteTree.get_etag(s3, from_info.bucket, from_info.path) + to_etag = S3RemoteTree.get_etag(s3, "another", "foo") assert from_etag == to_etag @@ -141,7 +141,7 @@ def test_isfile(remote): def test_download_dir(remote, tmpdir): path = str(tmpdir / "data") to_info = 
PathInfo(path) - remote.download(remote.path_info / "data", to_info) + remote.tree.download(remote.path_info / "data", to_info) assert os.path.isdir(path) data_dir = tmpdir / "data" assert len(list(walk_files(path))) == 7 From 75c5c0f0e0b4c76c77249626a8dd8484a61538fc Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 3 Jun 2020 18:24:27 +0900 Subject: [PATCH 12/15] tests: update func tests --- tests/func/remote/test_gdrive.py | 10 +++++----- tests/func/remote/test_index.py | 8 ++++---- tests/func/test_checkout.py | 4 ++-- tests/func/test_external_repo.py | 4 ++-- tests/func/test_import.py | 2 +- tests/func/test_remote.py | 14 ++++++++------ tests/func/test_repro.py | 6 +++--- tests/func/test_s3.py | 14 +++++++++----- tests/remotes.py | 2 +- 9 files changed, 35 insertions(+), 29 deletions(-) diff --git a/tests/func/remote/test_gdrive.py b/tests/func/remote/test_gdrive.py index 77aacc903a..0daecaee21 100644 --- a/tests/func/remote/test_gdrive.py +++ b/tests/func/remote/test_gdrive.py @@ -4,19 +4,19 @@ import configobj from dvc.main import main -from dvc.remote import GDriveRemote +from dvc.remote.gdrive import GDriveRemoteTree from dvc.repo import Repo def test_relative_user_credentials_file_config_setting(tmp_dir, dvc): # CI sets it to test GDrive, here we want to test the work with file system # based, regular credentials - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): - del os.environ[GDriveRemote.GDRIVE_CREDENTIALS_DATA] + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): + del os.environ[GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA] credentials = os.path.join("secrets", "credentials.json") - # GDriveRemote.credentials_location helper checks for file existence, + # GDriveRemoteTree.credentials_location helper checks for file existence, # create the file tmp_dir.gen(credentials, "{'token': 'test'}") @@ -50,6 +50,6 @@ def test_relative_user_credentials_file_config_setting(tmp_dir, dvc): # Check that in the remote itself we got an absolute path remote = repo.cloud.get_remote(remote_name) - assert os.path.normpath(remote.credentials_location) == os.path.join( + assert os.path.normpath(remote.tree.credentials_location) == os.path.join( str_path, credentials ) diff --git a/tests/func/remote/test_index.py b/tests/func/remote/test_index.py index cc98e7ebf1..81691bf091 100644 --- a/tests/func/remote/test_index.py +++ b/tests/func/remote/test_index.py @@ -5,7 +5,7 @@ from dvc.exceptions import DownloadError, UploadError from dvc.remote.base import BaseRemote from dvc.remote.index import RemoteIndex -from dvc.remote.local import LocalRemote +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.utils.fs import remove @@ -80,7 +80,7 @@ def test_clear_on_download_err(tmp_dir, dvc, tmp_path_factory, remote, mocker): remove(dvc.cache.local.cache_dir) mocked_clear = mocker.patch.object(remote.INDEX_CLS, "clear") - mocker.patch.object(LocalRemote, "_download", side_effect=Exception) + mocker.patch.object(LocalRemoteTree, "_download", side_effect=Exception) with pytest.raises(DownloadError): dvc.pull() mocked_clear.assert_called_once_with() @@ -90,14 +90,14 @@ def test_partial_upload(tmp_dir, dvc, tmp_path_factory, remote, mocker): tmp_dir.dvc_gen({"foo": "foo content"}) tmp_dir.dvc_gen({"bar": {"baz": "baz content"}}) - original = LocalRemote._upload + original = LocalRemoteTree._upload def unreliable_upload(self, from_file, to_info, name=None, **kwargs): if "baz" in name: raise Exception("stop baz") return original(self, from_file, to_info, name, **kwargs) - 
mocker.patch.object(LocalRemote, "_upload", unreliable_upload) + mocker.patch.object(LocalRemoteTree, "_upload", unreliable_upload) with pytest.raises(UploadError): dvc.push() with remote.index: diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index bafba3a4c4..848074d1f6 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -759,7 +759,7 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): remote = S3Remote(dvc, {"url": S3.get_url()}) file_path = remote.path_info / "foo" - remote.s3.put_object( + remote.tree.s3.put_object( Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo" ) @@ -770,7 +770,7 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): assert stats == {**empty_checkout, "added": [str(file_path)]} assert remote.tree.exists(file_path) - remote.s3.put_object( + remote.tree.s3.put_object( Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo\nfoo" ) stats = dvc.checkout(force=True) diff --git a/tests/func/test_external_repo.py b/tests/func/test_external_repo.py index f42a191c08..a2f8ef098c 100644 --- a/tests/func/test_external_repo.py +++ b/tests/func/test_external_repo.py @@ -4,7 +4,7 @@ from dvc.external_repo import external_repo from dvc.path_info import PathInfo -from dvc.remote import LocalRemote +from dvc.remote.local import LocalRemoteTree from dvc.scm.git import Git from dvc.utils import relpath from dvc.utils.fs import remove @@ -49,7 +49,7 @@ def test_cache_reused(erepo_dir, mocker, setup_remote): erepo_dir.dvc_gen("file", "text", commit="add file") erepo_dir.dvc.push() - download_spy = mocker.spy(LocalRemote, "download") + download_spy = mocker.spy(LocalRemoteTree, "download") # Use URL to prevent any fishy optimizations url = f"file://{erepo_dir}" diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 6ab037c8bc..a3b4328d06 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -238,7 +238,7 @@ def test_download_error_pulling_imported_stage(tmp_dir, dvc, erepo_dir): remove(dst_cache) with patch( - "dvc.remote.LocalRemote._download", side_effect=Exception + "dvc.remote.local.LocalRemoteTree._download", side_effect=Exception ), pytest.raises(DownloadError): dvc.pull(["foo_imported.dvc"]) diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 40d3b8156b..9b8a4cea8d 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -10,8 +10,8 @@ from dvc.exceptions import DownloadError, UploadError from dvc.main import main from dvc.path_info import PathInfo -from dvc.remote import LocalRemote from dvc.remote.base import BaseRemote, RemoteCacheRequiredError +from dvc.remote.local import LocalRemoteTree from dvc.utils.fs import remove from tests.basic_env import TestDvc from tests.remotes import Local @@ -181,14 +181,14 @@ def test_partial_push_n_pull(tmp_dir, dvc, tmp_path_factory, setup_remote): baz = tmp_dir.dvc_gen({"baz": {"foo": "baz content"}})[0].outs[0] # Faulty upload version, failing on foo - original = LocalRemote._upload + original = LocalRemoteTree._upload def unreliable_upload(self, from_file, to_info, name=None, **kwargs): if "foo" in name: raise Exception("stop foo") return original(self, from_file, to_info, name, **kwargs) - with patch.object(LocalRemote, "_upload", unreliable_upload): + with patch.object(LocalRemoteTree, "_upload", unreliable_upload): with pytest.raises(UploadError) as upload_error_info: dvc.push() assert upload_error_info.value.amount == 3 @@ -206,7 +206,7 @@ def unreliable_upload(self, from_file, to_info, 
name=None, **kwargs): dvc.push() remove(dvc.cache.local.cache_dir) - with patch.object(LocalRemote, "_download", side_effect=Exception): + with patch.object(LocalRemoteTree, "_download", side_effect=Exception): with pytest.raises(DownloadError) as download_error_info: dvc.pull() # error count should be len(.dir + standalone file checksums) @@ -221,7 +221,7 @@ def test_raise_on_too_many_open_files( tmp_dir.dvc_gen({"file": "file content"}) mocker.patch.object( - LocalRemote, + LocalRemoteTree, "_upload", side_effect=OSError(errno.EMFILE, "Too many open files"), ) @@ -255,7 +255,9 @@ def test_push_order(tmp_dir, dvc, tmp_path_factory, mocker, setup_remote): tmp_dir.dvc_gen({"foo": {"bar": "bar content"}}) tmp_dir.dvc_gen({"baz": "baz content"}) - mocked_upload = mocker.patch.object(LocalRemote, "_upload", return_value=0) + mocked_upload = mocker.patch.object( + LocalRemoteTree, "_upload", return_value=0 + ) dvc.push() # last uploaded file should be dir checksum assert mocked_upload.call_args[0][0].endswith(".dir") diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index ef677b13eb..08f44e25c4 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -26,7 +26,7 @@ from dvc.main import main from dvc.output.base import BaseOutput from dvc.path_info import URLInfo -from dvc.remote.local import LocalRemote +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.repo import Repo as DvcRepo from dvc.stage import Stage from dvc.stage.exceptions import StageFileDoesNotExistError @@ -1384,9 +1384,9 @@ def test_force_import(self): self.assertEqual(ret, 0) patch_download = patch.object( - LocalRemote, + LocalRemoteTree, "download", - side_effect=LocalRemote.download, + side_effect=LocalRemoteTree.download, autospec=True, ) diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index f2f4cef9d3..0b1f2d5bc9 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -30,7 +30,7 @@ def wrapped(*args, **kwargs): def _get_src_dst(): - base_info = S3Remote.path_cls(S3.get_url()) + base_info = S3RemoteTree.PATH_CLS(S3.get_url()) return base_info / "from", base_info / "to" @@ -48,12 +48,16 @@ def test_copy_singlepart_preserve_etag(): @mock_s3 @pytest.mark.parametrize( "base_info", - [S3Remote.path_cls("s3://bucket/"), S3Remote.path_cls("s3://bucket/ns/")], + [ + S3RemoteTree.PATH_CLS("s3://bucket/"), + S3RemoteTree.PATH_CLS("s3://bucket/ns/"), + ], ) def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm): remote = S3Remote(dvc, {"url": str(base_info.parent)}) - remote.s3.create_bucket(Bucket=base_info.bucket) - remote.s3.put_object( + s3 = remote.tree.s3 + s3.create_bucket(Bucket=base_info.bucket) + s3.put_object( Bucket=base_info.bucket, Key=(base_info / "from").path, Body="data" ) remote.link(base_info / "from", base_info / "to") @@ -64,7 +68,7 @@ def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm): @mock_s3 def test_makedirs_doesnot_try_on_top_level_paths(tmp_dir, dvc, scm): - base_info = S3Remote.path_cls("s3://bucket/") + base_info = S3RemoteTree.PATH_CLS("s3://bucket/") remote = S3Remote(dvc, {"url": str(base_info)}) remote.tree.makedirs(base_info) diff --git a/tests/remotes.py b/tests/remotes.py index a43e7709be..1a3c66b334 100644 --- a/tests/remotes.py +++ b/tests/remotes.py @@ -7,7 +7,7 @@ from moto.s3 import mock_s3 -from dvc.remote.gdrive import GDriveRemoteTree +from dvc.remote.gdrive import GDriveRemote, GDriveRemoteTree from dvc.remote.gs import GSRemote from dvc.remote.s3 import S3Remote from dvc.utils import 
env2bool From f7804b86d6c6a449ce05c09bbf057820ab4947b8 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 4 Jun 2020 14:23:10 +0900 Subject: [PATCH 13/15] remote: use walk_files() for all remotes * list_cache_paths() now uses tree.walk_files() for all remotes except local/ssh --- dvc/remote/azure.py | 13 +------------ dvc/remote/base.py | 14 +++++++++++++- dvc/remote/gdrive.py | 13 +------------ dvc/remote/gs.py | 13 +------------ dvc/remote/hdfs.py | 13 +------------ dvc/remote/http.py | 3 +++ dvc/remote/oss.py | 13 +------------ dvc/remote/s3.py | 27 +++++---------------------- 8 files changed, 26 insertions(+), 83 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 240be41fac..4feb57f40b 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -76,7 +76,7 @@ def exists(self, path_info): paths = self._list_paths(path_info.bucket, path_info.path) return any(path_info.path == path for path in paths) - def _list_paths(self, bucket, prefix, progress_callback=None): + def _list_paths(self, bucket, prefix): blob_service = self.blob_service next_marker = None while True: @@ -85,8 +85,6 @@ def _list_paths(self, bucket, prefix, progress_callback=None): ) for blob in blobs: - if progress_callback: - progress_callback() yield blob.name if not blobs.next_marker: @@ -143,12 +141,3 @@ class AzureRemote(BaseRemote): def get_file_checksum(self, path_info): return self.tree.get_etag(path_info) - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - path_info = self.path_info / prefix[:2] / prefix[2:] - else: - path_info = self.path_info - return self.tree.walk_files( - path_info, progress_callback=progress_callback - ) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 4e9d480834..99f9333ce8 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -829,7 +829,19 @@ def checksum_to_path_info(self, checksum): checksum_to_path = checksum_to_path_info def list_cache_paths(self, prefix=None, progress_callback=None): - raise NotImplementedError + if prefix: + if len(prefix) > 2: + path_info = self.path_info / prefix[:2] / prefix[2:] + else: + path_info = self.path_info / prefix[:2] + else: + path_info = self.path_info + if progress_callback: + for file_info in self.tree.walk_files(path_info): + progress_callback() + yield file_info.path + else: + yield from self.tree.walk_files(path_info) def cache_checksums(self, prefix=None, progress_callback=None): """Iterate over remote cache checksums. 
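The consolidated BaseRemote.list_cache_paths() default above maps a checksum prefix onto the two-level cache layout before delegating to tree.walk_files(). A minimal sketch of that mapping, assuming the usual <first two characters>/<remainder> cache structure:

    from dvc.path_info import PathInfo

    def prefix_to_path_info(base, prefix=None):
        # Checksum "123456..." is stored as <base>/12/3456..., so a
        # prefix longer than two characters narrows the walk to a
        # single subpath, while a two-character prefix keeps the
        # whole prefix directory.
        if not prefix:
            return base
        if len(prefix) > 2:
            return base / prefix[:2] / prefix[2:]
        return base / prefix[:2]

    path_info = prefix_to_path_info(PathInfo("cache"), "123456")
    # -> cache/12/3456; tree.walk_files(path_info) then yields every
    # file under that path.
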
diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 25fbc9e489..c9262f50f0 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -512,7 +512,7 @@ def exists(self, path_info): else: return True - def _list_paths(self, prefix=None, progress_callback=None): + def _list_paths(self, prefix=None): if not self._ids_cache["ids"]: return @@ -528,8 +528,6 @@ def _list_paths(self, prefix=None, progress_callback=None): query = f"({parents_query}) and trashed=false" for item in self._gdrive_list(query): - if progress_callback: - progress_callback() parent_id = item["parents"][0]["id"] yield posixpath.join( self._ids_cache["ids"][parent_id], item["title"] @@ -571,12 +569,3 @@ class GDriveRemote(BaseRemote): def get_file_checksum(self, path_info): raise NotImplementedError - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - path_info = self.path_info / prefix[2:] - else: - path_info = self.path_info - return self.tree.walk_files( - path_info, progress_callback=progress_callback - ) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 31462b4a56..9079750762 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -115,12 +115,10 @@ def isfile(self, path_info): blob = self.gs.bucket(path_info.bucket).blob(path_info.path) return blob.exists() - def _list_paths(self, path_info, max_items=None, progress_callback=None): + def _list_paths(self, path_info, max_items=None): for blob in self.gs.bucket(path_info.bucket).list_blobs( prefix=path_info.path, max_results=max_items ): - if progress_callback: - progress_callback() yield blob.name def walk_files(self, path_info, **kwargs): @@ -202,12 +200,3 @@ def get_file_checksum(self, path_info): b64_md5 = blob.md5_hash md5 = base64.b64decode(b64_md5) return codecs.getencoder("hex")(md5)[0].decode("utf-8") - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - path_info = self.path_info / prefix[:2] / prefix[2:] - else: - path_info = self.path_info - return self.tree.walk_files( - path_info, progress_callback=progress_callback - ) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index c9f4a2f4ee..54bcf09665 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -72,7 +72,7 @@ def exists(self, path_info): with self.hdfs(path_info) as hdfs: return hdfs.exists(path_info.path) - def walk_files(self, path_info, progress_callback=None, **kwargs): + def walk_files(self, path_info, **kwargs): if not self.exists(path_info): return @@ -89,8 +89,6 @@ def walk_files(self, path_info, progress_callback=None, **kwargs): if entry["kind"] == "directory": dirs.append(urlparse(entry["name"]).path) elif entry["kind"] == "file": - if progress_callback: - progress_callback() yield urlparse(entry["name"]).path except OSError: # When searching for a specific prefix pyarrow raises an @@ -182,12 +180,3 @@ def get_file_checksum(self, path_info): f"checksum {path_info.path}", user=path_info.user ) return self._group(regex, stdout, "checksum") - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - path_info = self.path_info / prefix[:2] - else: - path_info = self.path_info - return self.tree.walk_files( - path_info, progress_callback=progress_callback - ) diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 14aa821ec5..e0c2957e88 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -192,5 +192,8 @@ def get_file_checksum(self, path_info): return etag + def list_cache_paths(self, prefix=None, progress_callback=None): + raise NotImplementedError + def gc(self): raise 
NotImplementedError diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index c055682f8c..43eccf147e 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -68,14 +68,12 @@ def exists(self, path_info): paths = self.list_paths(path_info.path) return any(path_info.path == path for path in paths) - def _list_files(self, path_info, progress_callback=None): + def _list_files(self, path_info): import oss2 for blob in oss2.ObjectIterator( self.oss_service, prefix=path_info.path ): - if progress_callback: - progress_callback() yield blob.key def walk_files(self, path_info, **kwargs): @@ -134,12 +132,3 @@ class OSSRemote(BaseRemote): COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 100 TREE_CLS = OSSRemoteTree - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - path_info = self.path_info / prefix[:2], prefix[2:] - else: - path_info = self.path_info - return self.tree.walk_files( - path_info, progress_callback=progress_callback - ) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index e65580eff3..cb5a4568cc 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -166,7 +166,7 @@ def isfile(self, path_info): return True - def _list_objects(self, path_info, max_items=None, progress_callback=None): + def _list_objects(self, path_info, max_items=None): """ Read config for list object api, paginate through list objects.""" kwargs = { "Bucket": path_info.bucket, @@ -176,19 +176,11 @@ def _list_objects(self, path_info, max_items=None, progress_callback=None): paginator = self.s3.get_paginator(self.list_objects_api) for page in paginator.paginate(**kwargs): contents = page.get("Contents", ()) - if progress_callback: - for item in contents: - progress_callback() - yield item - else: - yield from contents - - def _list_paths(self, path_info, max_items=None, progress_callback=None): + yield from contents + + def _list_paths(self, path_info, max_items=None): return ( - item["Key"] - for item in self._list_objects( - path_info, max_items, progress_callback - ) + item["Key"] for item in self._list_objects(path_info, max_items) ) def walk_files(self, path_info, max_items=None, **kwargs): @@ -353,12 +345,3 @@ def get_file_checksum(self, path_info): return self.tree.get_etag( self.tree.s3, path_info.bucket, path_info.path ) - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - path_info = self.path_info / prefix[:2] / prefix[2:] - else: - path_info = self.path_info - return self.tree.walk_files( - path_info, progress_callback=progress_callback - ) From 233650e3c08d91402925fe4cee1be33ff619994b Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 4 Jun 2020 14:46:12 +0900 Subject: [PATCH 14/15] fix DS warnings --- dvc/remote/base.py | 5 ++--- dvc/remote/hdfs.py | 3 ++- dvc/remote/http.py | 3 ++- dvc/remote/local.py | 5 +++-- dvc/remote/s3.py | 6 ++---- dvc/remote/ssh/__init__.py | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 99f9333ce8..46ae988a28 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -165,9 +165,8 @@ def hardlink(self, from_info, to_info): def reflink(self, from_info, to_info): raise RemoteActionNotImplemented("reflink", self.scheme) - def _handle_transfer_exception( - self, from_info, to_info, exception, operation - ): + @staticmethod + def _handle_transfer_exception(from_info, to_info, exception, operation): if isinstance(exception, OSError) and exception.errno == errno.EMFILE: raise exception diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 
54bcf09665..c6d244431c 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -37,7 +37,8 @@ def __init__(self, remote, config): path=parsed.path, ) - def hdfs(self, path_info): + @staticmethod + def hdfs(path_info): import pyarrow return get_connection( diff --git a/dvc/remote/http.py b/dvc/remote/http.py index e0c2957e88..154550bd16 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -162,7 +162,8 @@ def chunks(): if response.status_code not in (200, 201): raise HTTPError(response.status_code, response.reason) - def _content_length(self, response): + @staticmethod + def _content_length(response): res = response.headers.get("Content-Length") return int(res) if res else None diff --git a/dvc/remote/local.py b/dvc/remote/local.py index c639d26625..12ed0e691b 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -95,7 +95,7 @@ def iscopy(self, path_info): System.is_symlink(path_info) or System.is_hardlink(path_info) ) - def walk_files(self, path_info): + def walk_files(self, path_info, **kwargs): if self.work_tree: tree = self.work_tree else: @@ -226,8 +226,9 @@ def _upload( self.remote.protect(tmp_file) os.rename(tmp_file, to_info) + @staticmethod def _download( - self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs + from_info, to_file, name=None, no_progress_bar=False, **_kwargs ): copyfile( from_info, to_file, no_progress_bar=no_progress_bar, name=name diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index cb5a4568cc..642a743bb5 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -183,10 +183,8 @@ def _list_paths(self, path_info, max_items=None): item["Key"] for item in self._list_objects(path_info, max_items) ) - def walk_files(self, path_info, max_items=None, **kwargs): - for fname in self._list_paths( - path_info / "", max_items=max_items, **kwargs - ): + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths(path_info / "", **kwargs): if fname.endswith("/"): continue diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 2936843b92..e111f62cdf 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -160,7 +160,7 @@ def isfile(self, path_info): with self.ssh(path_info) as ssh: return ssh.isfile(path_info.path) - def walk_files(self, path_info): + def walk_files(self, path_info, **kwargs): with self.ssh(path_info) as ssh: for fname in ssh.walk_files(path_info.path): yield path_info.replace(path=fname) From 9be71f1ab47cdb2ebf5722ca55541c9eb480bc2f Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 4 Jun 2020 15:19:45 +0900 Subject: [PATCH 15/15] bugfixes --- dvc/remote/hdfs.py | 3 ++- dvc/remote/oss.py | 6 +++--- dvc/remote/ssh/__init__.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index c6d244431c..e39ebc318b 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -90,7 +90,8 @@ def walk_files(self, path_info, **kwargs): if entry["kind"] == "directory": dirs.append(urlparse(entry["name"]).path) elif entry["kind"] == "file": - yield urlparse(entry["name"]).path + path = urlparse(entry["name"]).path + yield path_info.replace(path=path) except OSError: # When searching for a specific prefix pyarrow raises an # exception if the specified cache dir does not exist diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 43eccf147e..8dcee9d584 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -65,10 +65,10 @@ def _generate_download_url(self, path_info, expires=3600): return self.oss_service.sign_url("GET", 
path_info.path, expires) def exists(self, path_info): - paths = self.list_paths(path_info.path) + paths = self._list_paths(path_info) return any(path_info.path == path for path in paths) - def _list_files(self, path_info): + def _list_paths(self, path_info): import oss2 for blob in oss2.ObjectIterator( @@ -77,7 +77,7 @@ def _list_files(self, path_info): yield blob.key def walk_files(self, path_info, **kwargs): - for fname in self._list_paths(path_info, **kwargs): + for fname in self._list_paths(path_info): if fname.endswith("/"): continue diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index e111f62cdf..54a968d4f4 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -326,7 +326,7 @@ def cache_exists(self, checksums, jobs=None, name=None): return list(set(checksums) & set(self.all())) # possibly prompt for credentials before "Querying" progress output - self.ensure_credentials() + self.tree.ensure_credentials() with Tqdm( desc="Querying "