diff --git a/dvc/output/base.py b/dvc/output/base.py index fe3c77c664..9c347fee77 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -119,7 +119,7 @@ def _parse_path(self, remote, path): if remote: parsed = urlparse(path) return remote.path_info / parsed.path.lstrip("/") - return self.REMOTE.path_cls(path) + return self.REMOTE.TREE_CLS.PATH_CLS(path) def __repr__(self): return "{class_name}: '{def_path}'".format( @@ -300,7 +300,7 @@ def verify_metric(self): raise DvcException(f"verify metric is not supported for {self.scheme}") def download(self, to): - self.remote.download(self.path_info, to.path_info) + self.remote.tree.download(self.path_info, to.path_info) def checkout( self, diff --git a/dvc/output/local.py b/dvc/output/local.py index f84d3478b4..3f992d7211 100644 --- a/dvc/output/local.py +++ b/dvc/output/local.py @@ -33,12 +33,12 @@ def _parse_path(self, remote, path): # # FIXME: if we have Windows path containing / or posix one with \ # then we have #2059 bug and can't really handle that. - p = self.REMOTE.path_cls(path) + p = self.REMOTE.TREE_CLS.PATH_CLS(path) if not p.is_absolute(): p = self.stage.wdir / p abs_p = os.path.abspath(os.path.normpath(p)) - return self.REMOTE.path_cls(abs_p) + return self.REMOTE.TREE_CLS.PATH_CLS(abs_p) def __str__(self): if not self.is_in_repo: diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 0b662b4a08..4feb57f40b 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -1,6 +1,5 @@ import logging import os -import posixpath import threading from datetime import datetime, timedelta @@ -15,56 +14,17 @@ class AzureRemoteTree(BaseRemoteTree): - @property - def blob_service(self): - return self.remote.blob_service - - def _generate_download_url(self, path_info, expires=3600): - from azure.storage.blob import BlobPermissions - - expires_at = datetime.utcnow() + timedelta(seconds=expires) - - sas_token = self.blob_service.generate_blob_shared_access_signature( - path_info.bucket, - path_info.path, - permission=BlobPermissions.READ, - expiry=expires_at, - ) - download_url = self.blob_service.make_blob_url( - path_info.bucket, path_info.path, sas_token=sas_token - ) - return download_url - - def exists(self, path_info): - paths = self.remote.list_paths(path_info.bucket, path_info.path) - return any(path_info.path == path for path in paths) - - def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - logger.debug(f"Removing {path_info}") - self.blob_service.delete_blob(path_info.bucket, path_info.path) + PATH_CLS = CloudURLInfo - -class AzureRemote(BaseRemote): - scheme = Schemes.AZURE - path_cls = CloudURLInfo - REQUIRES = {"azure-storage-blob": "azure.storage.blob"} - PARAM_CHECKSUM = "etag" - COPY_POLL_SECONDS = 5 - LIST_OBJECT_PAGE_SIZE = 5000 - TREE_CLS = AzureRemoteTree - - def __init__(self, repo, config): - super().__init__(repo, config) + def __init__(self, remote, config): + super().__init__(remote, config) url = config.get("url", "azure://") - self.path_info = self.path_cls(url) + self.path_info = self.PATH_CLS(url) if not self.path_info.bucket: container = os.getenv("AZURE_STORAGE_CONTAINER_NAME") - self.path_info = self.path_cls(f"azure://{container}") + self.path_info = self.PATH_CLS(f"azure://{container}") self.connection_string = config.get("connection_string") or os.getenv( "AZURE_STORAGE_CONNECTION_STRING" @@ -96,10 +56,27 @@ def get_etag(self, path_info): ).properties.etag return etag.strip('"') - def get_file_checksum(self, path_info): - return self.get_etag(path_info) + def _generate_download_url(self, path_info, expires=3600): + from azure.storage.blob import BlobPermissions + + expires_at = datetime.utcnow() + timedelta(seconds=expires) - def list_paths(self, bucket, prefix, progress_callback=None): + sas_token = self.blob_service.generate_blob_shared_access_signature( + path_info.bucket, + path_info.path, + permission=BlobPermissions.READ, + expiry=expires_at, + ) + download_url = self.blob_service.make_blob_url( + path_info.bucket, path_info.path, sas_token=sas_token + ) + return download_url + + def exists(self, path_info): + paths = self._list_paths(path_info.bucket, path_info.path) + return any(path_info.path == path for path in paths) + + def _list_paths(self, bucket, prefix): blob_service = self.blob_service next_marker = None while True: @@ -108,8 +85,6 @@ def list_paths(self, bucket, prefix, progress_callback=None): ) for blob in blobs: - if progress_callback: - progress_callback() yield blob.name if not blobs.next_marker: @@ -117,16 +92,21 @@ def list_paths(self, bucket, prefix, progress_callback=None): next_marker = blobs.next_marker - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - prefix = posixpath.join( - self.path_info.path, prefix[:2], prefix[2:] - ) - else: - prefix = self.path_info.path - return self.list_paths( - self.path_info.bucket, prefix, progress_callback - ) + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths( + path_info.bucket, path_info.path, **kwargs + ): + if fname.endswith("/"): + continue + + yield path_info.replace(path=fname) + + def remove(self, path_info): + if path_info.scheme != self.scheme: + raise NotImplementedError + + logger.debug(f"Removing {path_info}") + self.blob_service.delete_blob(path_info.bucket, path_info.path) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs @@ -149,3 +129,15 @@ def _download( to_file, progress_callback=pbar.update_to, ) + + +class AzureRemote(BaseRemote): + scheme = Schemes.AZURE + REQUIRES = {"azure-storage-blob": "azure.storage.blob"} + PARAM_CHECKSUM = "etag" + COPY_POLL_SECONDS = 5 + LIST_OBJECT_PAGE_SIZE = 5000 + TREE_CLS = AzureRemoteTree + + def get_file_checksum(self, path_info): + return self.tree.get_etag(path_info) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 60d502dbec..46ae988a28 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -85,6 +85,7 @@ def wrapper(remote_obj, *args, **kwargs): class BaseRemoteTree: SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} + PATH_CLS = URLInfo def __init__(self, remote, config): self.remote = remote @@ -103,10 +104,6 @@ def dir_mode(self): def scheme(self): return self.remote.scheme - @property - def path_cls(self): - return self.remote.path_cls - def open(self, path_info, mode="r", encoding=None): if hasattr(self, "_generate_download_url"): get_url = partial(self._generate_download_url, path_info) @@ -133,7 +130,7 @@ def iscopy(self, path_info): """Check if this file is an independent copy.""" return False # We can't be sure by default - def walk_files(self, path_info): + def walk_files(self, path_info, **kwargs): """Return a generator with `PathInfo`s to all the files""" raise NotImplementedError @@ -168,10 +165,127 @@ def hardlink(self, from_info, to_info): def reflink(self, from_info, to_info): raise RemoteActionNotImplemented("reflink", self.scheme) + @staticmethod + def _handle_transfer_exception(from_info, to_info, exception, operation): + if isinstance(exception, OSError) and exception.errno == errno.EMFILE: + raise exception + + logger.exception( + "failed to %s '%s' to '%s'", operation, from_info, to_info + ) + return 1 + + def upload(self, from_info, to_info, name=None, no_progress_bar=False): + if not hasattr(self, "_upload"): + raise RemoteActionNotImplemented("upload", self.scheme) + + if to_info.scheme != self.scheme: + raise NotImplementedError + + if from_info.scheme != "local": + raise NotImplementedError + + logger.debug("Uploading '%s' to '%s'", from_info, to_info) + + name = name or from_info.name + + try: + self._upload( + from_info.fspath, + to_info, + name=name, + no_progress_bar=no_progress_bar, + ) + except Exception as e: + return self._handle_transfer_exception( + from_info, to_info, e, "upload" + ) + + return 0 + + def download( + self, + from_info, + to_info, + name=None, + no_progress_bar=False, + file_mode=None, + dir_mode=None, + ): + if not hasattr(self, "_download"): + raise RemoteActionNotImplemented("download", self.scheme) + + if from_info.scheme != self.scheme: + raise NotImplementedError + + if to_info.scheme == self.scheme != "local": + self.copy(from_info, to_info) + return 0 + + if to_info.scheme != "local": + raise NotImplementedError + + if self.isdir(from_info): + return self._download_dir( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + return self._download_file( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + + def _download_dir( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): + from_infos = list(self.walk_files(from_info)) + to_infos = ( + to_info / info.relative_to(from_info) for info in from_infos + ) + + with Tqdm( + total=len(from_infos), + desc="Downloading directory", + unit="Files", + disable=no_progress_bar, + ) as pbar: + download_files = pbar.wrap_fn( + partial( + self._download_file, + name=name, + no_progress_bar=True, + file_mode=file_mode, + dir_mode=dir_mode, + ) + ) + with ThreadPoolExecutor(max_workers=self.remote.JOBS) as executor: + futures = executor.map(download_files, from_infos, to_infos) + return sum(futures) + + def _download_file( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): + makedirs(to_info.parent, exist_ok=True, mode=dir_mode) + + logger.debug("Downloading '%s' to '%s'", from_info, to_info) + name = name or to_info.name + + tmp_file = tmp_fname(to_info) + + try: + self._download( + from_info, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + except Exception as e: + return self._handle_transfer_exception( + from_info, to_info, e, "download" + ) + + move(tmp_file, to_info, mode=file_mode) + + return 0 + class BaseRemote: scheme = "base" - path_cls = URLInfo REQUIRES = {} JOBS = 4 * cpu_count() INDEX_CLS = RemoteIndex @@ -219,6 +333,10 @@ def __init__(self, repo, config): self.tree = self.TREE_CLS(self, config) + @property + def path_info(self): + return self.tree.path_info + @classmethod def get_missing_deps(cls): import importlib @@ -374,7 +492,7 @@ def _get_dir_info_checksum(self, dir_info): from_info = PathInfo(tmp) to_info = self.cache.path_info / tmp_fname("") - self.cache.upload(from_info, to_info, no_progress_bar=True) + self.cache.tree.upload(from_info, to_info, no_progress_bar=True) checksum = self.get_file_checksum(to_info) + self.CHECKSUM_DIR_SUFFIX return checksum, to_info @@ -410,12 +528,12 @@ def load_dir_cache(self, checksum): ) return [] - if self.path_cls == WindowsPathInfo: + if self.tree.PATH_CLS == WindowsPathInfo: # only need to convert it for Windows for info in d: # NOTE: here is a BUG, see comment to .as_posix() below info[self.PARAM_RELPATH] = info[self.PARAM_RELPATH].replace( - "/", self.path_cls.sep + "/", self.tree.PATH_CLS.sep ) return d @@ -694,127 +812,8 @@ def _save(self, path_info, checksum, save_link=True, tree=None, **kwargs): def open(self, *args, **kwargs): return self.tree.open(*args, **kwargs) - def _handle_transfer_exception( - self, from_info, to_info, exception, operation - ): - if isinstance(exception, OSError) and exception.errno == errno.EMFILE: - raise exception - - logger.exception( - "failed to %s '%s' to '%s'", operation, from_info, to_info - ) - return 1 - - def upload(self, from_info, to_info, name=None, no_progress_bar=False): - if not hasattr(self, "_upload"): - raise RemoteActionNotImplemented("upload", self.scheme) - - if to_info.scheme != self.scheme: - raise NotImplementedError - - if from_info.scheme != "local": - raise NotImplementedError - - logger.debug("Uploading '%s' to '%s'", from_info, to_info) - - name = name or from_info.name - - try: - self._upload( - from_info.fspath, - to_info, - name=name, - no_progress_bar=no_progress_bar, - ) - except Exception as e: - return self._handle_transfer_exception( - from_info, to_info, e, "upload" - ) - - return 0 - - def download( - self, - from_info, - to_info, - name=None, - no_progress_bar=False, - file_mode=None, - dir_mode=None, - ): - if not hasattr(self, "_download"): - raise RemoteActionNotImplemented("download", self.scheme) - - if from_info.scheme != self.scheme: - raise NotImplementedError - - if to_info.scheme == self.scheme != "local": - self.tree.copy(from_info, to_info) - return 0 - - if to_info.scheme != "local": - raise NotImplementedError - - if self.tree.isdir(from_info): - return self._download_dir( - from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ) - return self._download_file( - from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ) - - def _download_dir( - self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ): - from_infos = list(self.tree.walk_files(from_info)) - to_infos = ( - to_info / info.relative_to(from_info) for info in from_infos - ) - - with Tqdm( - total=len(from_infos), - desc="Downloading directory", - unit="Files", - disable=no_progress_bar, - ) as pbar: - download_files = pbar.wrap_fn( - partial( - self._download_file, - name=name, - no_progress_bar=True, - file_mode=file_mode, - dir_mode=dir_mode, - ) - ) - with ThreadPoolExecutor(max_workers=self.JOBS) as executor: - futures = executor.map(download_files, from_infos, to_infos) - return sum(futures) - - def _download_file( - self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ): - makedirs(to_info.parent, exist_ok=True, mode=dir_mode) - - logger.debug("Downloading '%s' to '%s'", from_info, to_info) - name = name or to_info.name - - tmp_file = tmp_fname(to_info) - - try: - self._download( - from_info, tmp_file, name=name, no_progress_bar=no_progress_bar - ) - except Exception as e: - return self._handle_transfer_exception( - from_info, to_info, e, "download" - ) - - move(tmp_file, to_info, mode=file_mode) - - return 0 - def path_to_checksum(self, path): - parts = self.path_cls(path).parts[-2:] + parts = self.tree.PATH_CLS(path).parts[-2:] if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): raise ValueError(f"Bad cache file path '{path}'") @@ -829,7 +828,19 @@ def checksum_to_path_info(self, checksum): checksum_to_path = checksum_to_path_info def list_cache_paths(self, prefix=None, progress_callback=None): - raise NotImplementedError + if prefix: + if len(prefix) > 2: + path_info = self.path_info / prefix[:2] / prefix[2:] + else: + path_info = self.path_info / prefix[:2] + else: + path_info = self.path_info + if progress_callback: + for file_info in self.tree.walk_files(path_info): + progress_callback() + yield file_info.path + else: + yield from self.tree.walk_files(path_info) def cache_checksums(self, prefix=None, progress_callback=None): """Iterate over remote cache checksums. diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 3cef0760ee..c9262f50f0 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -88,28 +88,7 @@ def __init__(self, url): class GDriveRemoteTree(BaseRemoteTree): - def exists(self, path_info): - try: - self.remote.get_item_id(path_info) - except GDrivePathNotFound: - return False - else: - return True - - def remove(self, path_info): - item_id = self.remote.get_item_id(path_info) - self.remote.gdrive_delete_file(item_id) - - -class GDriveRemote(BaseRemote): - scheme = Schemes.GDRIVE - path_cls = GDriveURLInfo - REQUIRES = {"pydrive2": "pydrive2"} - DEFAULT_VERIFY = True - # Always prefer traverse for GDrive since API usage quotas are a concern. - TRAVERSE_WEIGHT_MULTIPLIER = 1 - TRAVERSE_PREFIX_LEN = 2 - TREE_CLS = GDriveRemoteTree + PATH_CLS = GDriveURLInfo GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA" DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json" @@ -117,9 +96,10 @@ class GDriveRemote(BaseRemote): DEFAULT_GDRIVE_CLIENT_ID = "710796635688-iivsgbgsb6uv1fap6635dhvuei09o66c.apps.googleusercontent.com" # noqa: E501 DEFAULT_GDRIVE_CLIENT_SECRET = "a1Fz59uTpVNeG_VGuSKDLJXv" - def __init__(self, repo, config): - super().__init__(repo, config) - self.path_info = self.path_cls(config["url"]) + def __init__(self, remote, config): + super().__init__(remote, config) + + self.path_info = self.PATH_CLS(config["url"]) if not self.path_info.bucket: raise DvcException( @@ -146,11 +126,12 @@ def __init__(self, repo, config): self._validate_config() self._gdrive_user_credentials_path = ( tmp_fname(os.path.join(self.repo.tmp_dir, "")) - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA) + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) else config.get( "gdrive_user_credentials_file", os.path.join( - self.repo.tmp_dir, self.DEFAULT_USER_CREDENTIALS_FILE + self.remote.repo.tmp_dir, + self.DEFAULT_USER_CREDENTIALS_FILE, ), ) ) @@ -188,8 +169,8 @@ def credentials_location(self): Useful for tests, exception messages, etc. Returns either env variable name if it's set or actual path to the credentials file. """ - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): - return GDriveRemote.GDRIVE_CREDENTIALS_DATA + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): + return GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA if os.path.exists(self._gdrive_user_credentials_path): return self._gdrive_user_credentials_path return None @@ -203,7 +184,7 @@ def _validate_credentials(auth, settings): DVC config client id or secret but forgets to remove the cached credentials file. """ - if not os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): + if not os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): if ( settings["client_config"]["client_id"] != auth.credentials.client_id @@ -226,10 +207,10 @@ def _drive(self): from pydrive2.auth import GoogleAuth from pydrive2.drive import GoogleDrive - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): with open(self._gdrive_user_credentials_path, "w") as cred_file: cred_file.write( - os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA) + os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) ) auth_settings = { @@ -276,7 +257,7 @@ def _drive(self): gauth.ServiceAuth() else: gauth.CommandLineAuth() - GDriveRemote._validate_credentials(gauth, auth_settings) + GDriveRemoteTree._validate_credentials(gauth, auth_settings) # Handle AuthenticationError, RefreshError and other auth failures # It's hard to come up with a narrow exception, since PyDrive throws @@ -285,7 +266,7 @@ def _drive(self): except Exception as exc: raise GDriveAuthError(self.credentials_location) from exc finally: - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): os.remove(self._gdrive_user_credentials_path) return GoogleDrive(gauth) @@ -296,7 +277,7 @@ def _ids_cache(self): cache = { "dirs": defaultdict(list), "ids": {}, - "root_id": self.get_item_id( + "root_id": self._get_item_id( self.path_info, use_cache=False, hint="Confirm the directory exists and you can access it.", @@ -513,7 +494,7 @@ def _path_to_item_ids(self, path, create, use_cache): [self._create_dir(min(parent_ids), title, path)] if create else [] ) - def get_item_id(self, path_info, create=False, use_cache=True, hint=None): + def _get_item_id(self, path_info, create=False, use_cache=True, hint=None): assert path_info.bucket == self._bucket item_ids = self._path_to_item_ids(path_info.path, create, use_cache) @@ -523,20 +504,15 @@ def get_item_id(self, path_info, create=False, use_cache=True, hint=None): assert not create raise GDrivePathNotFound(path_info, hint) - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - dirname = to_info.parent - assert dirname - parent_id = self.get_item_id(dirname, True) - - self._gdrive_upload_file( - parent_id, to_info.name, no_progress_bar, from_file, name - ) - - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - item_id = self.get_item_id(from_info) - self._gdrive_download_file(item_id, to_file, name, no_progress_bar) + def exists(self, path_info): + try: + self._get_item_id(path_info) + except GDrivePathNotFound: + return False + else: + return True - def list_cache_paths(self, prefix=None, progress_callback=None): + def _list_paths(self, prefix=None): if not self._ids_cache["ids"]: return @@ -552,12 +528,44 @@ def list_cache_paths(self, prefix=None, progress_callback=None): query = f"({parents_query}) and trashed=false" for item in self._gdrive_list(query): - if progress_callback: - progress_callback() parent_id = item["parents"][0]["id"] yield posixpath.join( self._ids_cache["ids"][parent_id], item["title"] ) + def walk_files(self, path_info, **kwargs): + if path_info == self.path_info: + prefix = None + else: + prefix = path_info.relative_to(self.path_info).path + return self._list_paths(prefix=prefix, **kwargs) + + def remove(self, path_info): + item_id = self._get_item_id(path_info) + self.gdrive_delete_file(item_id) + + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + dirname = to_info.parent + assert dirname + parent_id = self._get_item_id(dirname, True) + + self._gdrive_upload_file( + parent_id, to_info.name, no_progress_bar, from_file, name + ) + + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + item_id = self._get_item_id(from_info) + self._gdrive_download_file(item_id, to_file, name, no_progress_bar) + + +class GDriveRemote(BaseRemote): + scheme = Schemes.GDRIVE + REQUIRES = {"pydrive2": "pydrive2"} + DEFAULT_VERIFY = True + # Always prefer traverse for GDrive since API usage quotas are a concern. + TRAVERSE_WEIGHT_MULTIPLIER = 1 + TRAVERSE_PREFIX_LEN = 2 + TREE_CLS = GDriveRemoteTree + def get_file_checksum(self, path_info): raise NotImplementedError diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 5b42c8f24f..9079750762 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -1,6 +1,5 @@ import logging import os.path -import posixpath import threading from datetime import timedelta from functools import wraps @@ -66,9 +65,27 @@ def _upload_to_bucket( class GSRemoteTree(BaseRemoteTree): - @property + PATH_CLS = CloudURLInfo + + def __init__(self, remote, config): + super().__init__(remote, config) + + url = config.get("url", "gs:///") + self.path_info = self.PATH_CLS(url) + + self.projectname = config.get("projectname", None) + self.credentialpath = config.get("credentialpath") + + @wrap_prop(threading.Lock()) + @cached_property def gs(self): - return self.remote.gs + from google.cloud.storage import Client + + return ( + Client.from_service_account_json(self.credentialpath) + if self.credentialpath + else Client(self.projectname) + ) def _generate_download_url(self, path_info, expires=3600): expiration = timedelta(seconds=int(expires)) @@ -89,7 +106,7 @@ def exists(self, path_info): def isdir(self, path_info): dir_path = path_info / "" - return bool(list(self.remote.list_paths(dir_path, max_items=1))) + return bool(list(self._list_paths(dir_path, max_items=1))) def isfile(self, path_info): if path_info.path.endswith("/"): @@ -98,8 +115,14 @@ def isfile(self, path_info): blob = self.gs.bucket(path_info.bucket).blob(path_info.path) return blob.exists() - def walk_files(self, path_info): - for fname in self.remote.list_paths(path_info / ""): + def _list_paths(self, path_info, max_items=None): + for blob in self.gs.bucket(path_info.bucket).list_blobs( + prefix=path_info.path, max_results=max_items + ): + yield blob.name + + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths(path_info / "", **kwargs): # skip nested empty directories if fname.endswith("/"): continue @@ -134,67 +157,6 @@ def copy(self, from_info, to_info): to_bucket = self.gs.bucket(to_info.bucket) from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) - -class GSRemote(BaseRemote): - scheme = Schemes.GS - path_cls = CloudURLInfo - REQUIRES = {"google-cloud-storage": "google.cloud.storage"} - PARAM_CHECKSUM = "md5" - TREE_CLS = GSRemoteTree - - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url", "gs:///") - self.path_info = self.path_cls(url) - - self.projectname = config.get("projectname", None) - self.credentialpath = config.get("credentialpath") - - @wrap_prop(threading.Lock()) - @cached_property - def gs(self): - from google.cloud.storage import Client - - return ( - Client.from_service_account_json(self.credentialpath) - if self.credentialpath - else Client(self.projectname) - ) - - def get_file_checksum(self, path_info): - import base64 - import codecs - - bucket = path_info.bucket - path = path_info.path - blob = self.gs.bucket(bucket).get_blob(path) - if not blob: - return None - - b64_md5 = blob.md5_hash - md5 = base64.b64decode(b64_md5) - return codecs.getencoder("hex")(md5)[0].decode("utf-8") - - def list_paths( - self, path_info, max_items=None, prefix=None, progress_callback=None - ): - if prefix: - prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:]) - else: - prefix = path_info.path - for blob in self.gs.bucket(path_info.bucket).list_blobs( - prefix=path_info.path, max_results=max_items - ): - if progress_callback: - progress_callback() - yield blob.name - - def list_cache_paths(self, prefix=None, progress_callback=None): - return self.list_paths( - self.path_info, prefix=prefix, progress_callback=progress_callback - ) - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): bucket = self.gs.bucket(to_info.bucket) _upload_to_bucket( @@ -217,3 +179,24 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): disable=no_progress_bar, ) as wrapped: blob.download_to_file(wrapped) + + +class GSRemote(BaseRemote): + scheme = Schemes.GS + REQUIRES = {"google-cloud-storage": "google.cloud.storage"} + PARAM_CHECKSUM = "md5" + TREE_CLS = GSRemoteTree + + def get_file_checksum(self, path_info): + import base64 + import codecs + + bucket = path_info.bucket + path = path_info.path + blob = self.gs.bucket(bucket).get_blob(path) + if not blob: + return None + + b64_md5 = blob.md5_hash + md5 = base64.b64decode(b64_md5) + return codecs.getencoder("hex")(md5)[0].decode("utf-8") diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 4e9430ab7c..e39ebc318b 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -18,9 +18,35 @@ class HDFSRemoteTree(BaseRemoteTree): - @property - def hdfs(self): - return self.remote.hdfs + def __init__(self, remote, config): + super().__init__(remote, config) + + self.path_info = None + url = config.get("url") + if not url: + return + + parsed = urlparse(url) + user = parsed.username or config.get("user") + + self.path_info = self.PATH_CLS.from_parts( + scheme=self.scheme, + host=parsed.hostname, + user=user, + port=parsed.port, + path=parsed.path, + ) + + @staticmethod + def hdfs(path_info): + import pyarrow + + return get_connection( + pyarrow.hdfs.connect, + path_info.host, + path_info.port, + user=path_info.user, + ) @contextmanager def open(self, path_info, mode="r", encoding=None): @@ -47,6 +73,30 @@ def exists(self, path_info): with self.hdfs(path_info) as hdfs: return hdfs.exists(path_info.path) + def walk_files(self, path_info, **kwargs): + if not self.exists(path_info): + return + + root = path_info.path + dirs = deque([root]) + + with self.hdfs(self.path_info) as hdfs: + if not hdfs.exists(root): + return + while dirs: + try: + entries = hdfs.ls(dirs.pop(), detail=True) + for entry in entries: + if entry["kind"] == "directory": + dirs.append(urlparse(entry["name"]).path) + elif entry["kind"] == "file": + path = urlparse(entry["name"]).path + yield path_info.replace(path=path) + except OSError: + # When searching for a specific prefix pyarrow raises an + # exception if the specified cache dir does not exist + pass + def remove(self, path_info): if path_info.scheme != "hdfs": raise NotImplementedError @@ -72,6 +122,19 @@ def copy(self, from_info, to_info, **_kwargs): self.remove(tmp_info) raise + def _upload(self, from_file, to_info, **_kwargs): + with self.hdfs(to_info) as hdfs: + hdfs.mkdir(posixpath.dirname(to_info.path)) + tmp_file = tmp_fname(to_info.path) + with open(from_file, "rb") as fobj: + hdfs.upload(tmp_file, fobj) + hdfs.rename(tmp_file, to_info.path) + + def _download(self, from_info, to_file, **_kwargs): + with self.hdfs(from_info) as hdfs: + with open(to_file, "wb+") as fobj: + hdfs.download(from_info.path, fobj) + class HDFSRemote(BaseRemote): scheme = Schemes.HDFS @@ -81,34 +144,6 @@ class HDFSRemote(BaseRemote): TRAVERSE_PREFIX_LEN = 2 TREE_CLS = HDFSRemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) - self.path_info = None - url = config.get("url") - if not url: - return - - parsed = urlparse(url) - user = parsed.username or config.get("user") - - self.path_info = self.path_cls.from_parts( - scheme=self.scheme, - host=parsed.hostname, - user=user, - port=parsed.port, - path=parsed.path, - ) - - def hdfs(self, path_info): - import pyarrow - - return get_connection( - pyarrow.hdfs.connect, - path_info.host, - path_info.port, - user=path_info.user, - ) - def hadoop_fs(self, cmd, user=None): cmd = "hadoop fs -" + cmd if user: @@ -147,45 +182,3 @@ def get_file_checksum(self, path_info): f"checksum {path_info.path}", user=path_info.user ) return self._group(regex, stdout, "checksum") - - def _upload(self, from_file, to_info, **_kwargs): - with self.hdfs(to_info) as hdfs: - hdfs.mkdir(posixpath.dirname(to_info.path)) - tmp_file = tmp_fname(to_info.path) - with open(from_file, "rb") as fobj: - hdfs.upload(tmp_file, fobj) - hdfs.rename(tmp_file, to_info.path) - - def _download(self, from_info, to_file, **_kwargs): - with self.hdfs(from_info) as hdfs: - with open(to_file, "wb+") as fobj: - hdfs.download(from_info.path, fobj) - - def list_cache_paths(self, prefix=None, progress_callback=None): - if not self.tree.exists(self.path_info): - return - - if prefix: - root = posixpath.join(self.path_info.path, prefix[:2]) - else: - root = self.path_info.path - dirs = deque([root]) - - with self.hdfs(self.path_info) as hdfs: - if prefix and not hdfs.exists(root): - return - while dirs: - try: - entries = hdfs.ls(dirs.pop(), detail=True) - for entry in entries: - if entry["kind"] == "directory": - dirs.append(urlparse(entry["name"]).path) - elif entry["kind"] == "file": - if progress_callback: - progress_callback() - yield urlparse(entry["name"]).path - except OSError as e: - # When searching for a specific prefix pyarrow raises an - # exception if the specified cache dir does not exist - if not prefix: - raise e diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 62252ea720..154550bd16 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -24,27 +24,19 @@ def ask_password(host, user): class HTTPRemoteTree(BaseRemoteTree): - def exists(self, path_info): - return bool(self.remote.request("HEAD", path_info.url)) - + PATH_CLS = HTTPURLInfo -class HTTPRemote(BaseRemote): - scheme = Schemes.HTTP - path_cls = HTTPURLInfo SESSION_RETRIES = 5 SESSION_BACKOFF_FACTOR = 0.1 REQUEST_TIMEOUT = 10 CHUNK_SIZE = 2 ** 16 - PARAM_CHECKSUM = "etag" - CAN_TRAVERSE = False - TREE_CLS = HTTPRemoteTree - def __init__(self, repo, config): - super().__init__(repo, config) + def __init__(self, remote, config): + super().__init__(remote, config) url = config.get("url") if url: - self.path_info = self.path_cls(url) + self.path_info = self.PATH_CLS(url) user = config.get("user", None) if user: self.path_info.user = user @@ -57,71 +49,7 @@ def __init__(self, repo, config): self.ask_password = config.get("ask_password", False) self.headers = {} - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - response = self.request("GET", from_info.url, stream=True) - if response.status_code != 200: - raise HTTPError(response.status_code, response.reason) - with open(to_file, "wb") as fd: - with Tqdm.wrapattr( - fd, - "write", - total=None - if no_progress_bar - else self._content_length(response), - leave=False, - desc=from_info.url if name is None else name, - disable=no_progress_bar, - ) as fd_wrapped: - for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE): - fd_wrapped.write(chunk) - - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): - def chunks(): - with open(from_file, "rb") as fd: - with Tqdm.wrapattr( - fd, - "read", - total=None - if no_progress_bar - else os.path.getsize(from_file), - leave=False, - desc=to_info.url if name is None else name, - disable=no_progress_bar, - ) as fd_wrapped: - while True: - chunk = fd_wrapped.read(self.CHUNK_SIZE) - if not chunk: - break - yield chunk - - response = self.request("POST", to_info.url, data=chunks()) - if response.status_code not in (200, 201): - raise HTTPError(response.status_code, response.reason) - - def _content_length(self, response): - res = response.headers.get("Content-Length") - return int(res) if res else None - - def get_file_checksum(self, path_info): - url = path_info.url - headers = self.request("HEAD", url).headers - etag = headers.get("ETag") or headers.get("Content-MD5") - - if not etag: - raise DvcException( - "could not find an ETag or " - "Content-MD5 header for '{url}'".format(url=url) - ) - - if etag.startswith("W/"): - raise DvcException( - "Weak ETags are not supported." - " (Etag: '{etag}', URL: '{url}')".format(etag=etag, url=url) - ) - - return etag - - def auth_method(self, path_info=None): + def _auth_method(self, path_info=None): from requests.auth import HTTPBasicAuth, HTTPDigestAuth if path_info is None: @@ -168,7 +96,7 @@ def request(self, method, url, **kwargs): res = self._session.request( method, url, - auth=self.auth_method(), + auth=self._auth_method(), headers=self.headers, **kwargs, ) @@ -190,5 +118,83 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException: raise DvcException(f"could not perform a {method} request") + def exists(self, path_info): + return bool(self.request("HEAD", path_info.url)) + + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + response = self.request("GET", from_info.url, stream=True) + if response.status_code != 200: + raise HTTPError(response.status_code, response.reason) + with open(to_file, "wb") as fd: + with Tqdm.wrapattr( + fd, + "write", + total=None + if no_progress_bar + else self._content_length(response), + leave=False, + desc=from_info.url if name is None else name, + disable=no_progress_bar, + ) as fd_wrapped: + for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE): + fd_wrapped.write(chunk) + + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): + def chunks(): + with open(from_file, "rb") as fd: + with Tqdm.wrapattr( + fd, + "read", + total=None + if no_progress_bar + else os.path.getsize(from_file), + leave=False, + desc=to_info.url if name is None else name, + disable=no_progress_bar, + ) as fd_wrapped: + while True: + chunk = fd_wrapped.read(self.CHUNK_SIZE) + if not chunk: + break + yield chunk + + response = self.request("POST", to_info.url, data=chunks()) + if response.status_code not in (200, 201): + raise HTTPError(response.status_code, response.reason) + + @staticmethod + def _content_length(response): + res = response.headers.get("Content-Length") + return int(res) if res else None + + +class HTTPRemote(BaseRemote): + scheme = Schemes.HTTP + PARAM_CHECKSUM = "etag" + CAN_TRAVERSE = False + TREE_CLS = HTTPRemoteTree + + def get_file_checksum(self, path_info): + url = path_info.url + headers = self.tree.request("HEAD", url).headers + etag = headers.get("ETag") or headers.get("Content-MD5") + + if not etag: + raise DvcException( + "could not find an ETag or " + "Content-MD5 header for '{url}'".format(url=url) + ) + + if etag.startswith("W/"): + raise DvcException( + "Weak ETags are not supported." + " (Etag: '{etag}', URL: '{url}')".format(etag=etag, url=url) + ) + + return etag + + def list_cache_paths(self, prefix=None, progress_callback=None): + raise NotImplementedError + def gc(self): raise NotImplementedError diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 6ae7451326..12ed0e691b 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -39,6 +39,11 @@ class LocalRemoteTree(BaseRemoteTree): SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} + PATH_CLS = PathInfo + + def __init__(self, remote, config): + super().__init__(remote, config) + self.path_info = config.get("url") @property def repo(self): @@ -90,7 +95,7 @@ def iscopy(self, path_info): System.is_symlink(path_info) or System.is_hardlink(path_info) ) - def walk_files(self, path_info): + def walk_files(self, path_info, **kwargs): if self.work_tree: tree = self.work_tree else: @@ -208,10 +213,30 @@ def reflink(self, from_info, to_info): def getsize(path_info): return os.path.getsize(path_info) + def _upload( + self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + ): + makedirs(to_info.parent, exist_ok=True) + + tmp_file = tmp_fname(to_info) + copyfile( + from_file, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + + self.remote.protect(tmp_file) + os.rename(tmp_file, to_info) + + @staticmethod + def _download( + from_info, to_file, name=None, no_progress_bar=False, **_kwargs + ): + copyfile( + from_info, to_file, no_progress_bar=no_progress_bar, name=name + ) + class LocalRemote(BaseRemote): scheme = Schemes.LOCAL - path_cls = PathInfo PARAM_CHECKSUM = "md5" PARAM_PATH = "path" TRAVERSE_PREFIX_LEN = 2 @@ -227,7 +252,6 @@ class LocalRemote(BaseRemote): def __init__(self, repo, config): super().__init__(repo, config) self.cache_dir = config.get("url") - self._dir_info = {} @property def state(self): @@ -235,11 +259,11 @@ def state(self): @property def cache_dir(self): - return self.path_info.fspath if self.path_info else None + return self.tree.path_info.fspath if self.tree.path_info else None @cache_dir.setter def cache_dir(self, value): - self.path_info = PathInfo(value) if value else None + self.tree.path_info = PathInfo(value) if value else None @classmethod def supported(cls, config): @@ -309,26 +333,6 @@ def cache_exists(self, checksums, jobs=None, name=None): if not self.changed_cache_file(checksum) ] - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - makedirs(to_info.parent, exist_ok=True) - - tmp_file = tmp_fname(to_info) - copyfile( - from_file, tmp_file, name=name, no_progress_bar=no_progress_bar - ) - - self.protect(tmp_file) - os.rename(tmp_file, to_info) - - def _download( - self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs - ): - copyfile( - from_info, to_file, no_progress_bar=no_progress_bar, name=name - ) - @index_locked def status( self, @@ -511,14 +515,14 @@ def _process( if download: func = partial( - remote.download, + remote.tree.download, dir_mode=self.tree.dir_mode, file_mode=self.tree.file_mode, ) status = STATUS_DELETED desc = "Downloading" else: - func = remote.upload + func = remote.tree.upload status = STATUS_NEW desc = "Uploading" diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index bf29d79064..8dcee9d584 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -1,6 +1,5 @@ import logging import os -import posixpath import threading from funcy import cached_property, wrap_prop @@ -14,59 +13,13 @@ class OSSRemoteTree(BaseRemoteTree): - @property - def oss_service(self): - return self.remote.oss_service - - def _generate_download_url(self, path_info, expires=3600): - assert path_info.bucket == self.remote.path_info.bucket - - return self.oss_service.sign_url("GET", path_info.path, expires) - - def exists(self, path_info): - paths = self.remote.list_paths(path_info.path) - return any(path_info.path == path for path in paths) - - def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - logger.debug(f"Removing oss://{path_info}") - self.oss_service.delete_object(path_info.path) - - -class OSSRemote(BaseRemote): - """ - oss2 document: - https://www.alibabacloud.com/help/doc-detail/32026.htm - - - Examples - ---------- - $ dvc remote add myremote oss://my-bucket/path - Set key id, key secret and endpoint using modify command - $ dvc remote modify myremote oss_key_id my-key-id - $ dvc remote modify myremote oss_key_secret my-key-secret - $ dvc remote modify myremote oss_endpoint endpoint - or environment variables - $ export OSS_ACCESS_KEY_ID="my-key-id" - $ export OSS_ACCESS_KEY_SECRET="my-key-secret" - $ export OSS_ENDPOINT="endpoint" - """ - - scheme = Schemes.OSS - path_cls = CloudURLInfo - REQUIRES = {"oss2": "oss2"} - PARAM_CHECKSUM = "etag" - COPY_POLL_SECONDS = 5 - LIST_OBJECT_PAGE_SIZE = 100 - TREE_CLS = OSSRemoteTree + PATH_CLS = CloudURLInfo def __init__(self, repo, config): super().__init__(repo, config) url = config.get("url") - self.path_info = self.path_cls(url) if url else None + self.path_info = self.PATH_CLS(url) if url else None self.endpoint = config.get("oss_endpoint") or os.getenv("OSS_ENDPOINT") @@ -106,22 +59,36 @@ def oss_service(self): ) return bucket - def list_paths(self, prefix, progress_callback=None): + def _generate_download_url(self, path_info, expires=3600): + assert path_info.bucket == self.path_info.bucket + + return self.oss_service.sign_url("GET", path_info.path, expires) + + def exists(self, path_info): + paths = self._list_paths(path_info) + return any(path_info.path == path for path in paths) + + def _list_paths(self, path_info): import oss2 - for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix): - if progress_callback: - progress_callback() + for blob in oss2.ObjectIterator( + self.oss_service, prefix=path_info.path + ): yield blob.key - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - prefix = posixpath.join( - self.path_info.path, prefix[:2], prefix[2:] - ) - else: - prefix = self.path_info.path - return self.list_paths(prefix, progress_callback) + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths(path_info): + if fname.endswith("/"): + continue + + yield path_info.replace(path=fname) + + def remove(self, path_info): + if path_info.scheme != self.scheme: + raise NotImplementedError + + logger.debug(f"Removing oss://{path_info}") + self.oss_service.delete_object(path_info.path) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs @@ -138,3 +105,30 @@ def _download( self.oss_service.get_object_to_file( from_info.path, to_file, progress_callback=pbar.update_to ) + + +class OSSRemote(BaseRemote): + """ + oss2 document: + https://www.alibabacloud.com/help/doc-detail/32026.htm + + + Examples + ---------- + $ dvc remote add myremote oss://my-bucket/path + Set key id, key secret and endpoint using modify command + $ dvc remote modify myremote oss_key_id my-key-id + $ dvc remote modify myremote oss_key_secret my-key-secret + $ dvc remote modify myremote oss_endpoint endpoint + or environment variables + $ export OSS_ACCESS_KEY_ID="my-key-id" + $ export OSS_ACCESS_KEY_SECRET="my-key-secret" + $ export OSS_ENDPOINT="endpoint" + """ + + scheme = Schemes.OSS + REQUIRES = {"oss2": "oss2"} + PARAM_CHECKSUM = "etag" + COPY_POLL_SECONDS = 5 + LIST_OBJECT_PAGE_SIZE = 100 + TREE_CLS = OSSRemoteTree diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index e98b12fea0..642a743bb5 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -1,6 +1,5 @@ import logging import os -import posixpath import threading from funcy import cached_property, wrap_prop @@ -16,9 +15,103 @@ class S3RemoteTree(BaseRemoteTree): - @property + PATH_CLS = CloudURLInfo + + def __init__(self, repo, config): + super().__init__(repo, config) + + url = config.get("url", "s3://") + self.path_info = self.PATH_CLS(url) + + self.region = config.get("region") + self.profile = config.get("profile") + self.endpoint_url = config.get("endpointurl") + + if config.get("listobjects"): + self.list_objects_api = "list_objects" + else: + self.list_objects_api = "list_objects_v2" + + self.use_ssl = config.get("use_ssl", True) + + self.extra_args = {} + + self.sse = config.get("sse") + if self.sse: + self.extra_args["ServerSideEncryption"] = self.sse + + self.sse_kms_key_id = config.get("sse_kms_key_id") + if self.sse_kms_key_id: + self.extra_args["SSEKMSKeyId"] = self.sse_kms_key_id + + self.acl = config.get("acl") + if self.acl: + self.extra_args["ACL"] = self.acl + + self._append_aws_grants_to_extra_args(config) + + shared_creds = config.get("credentialpath") + if shared_creds: + os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + + @wrap_prop(threading.Lock()) + @cached_property def s3(self): - return self.remote.s3 + import boto3 + + session = boto3.session.Session( + profile_name=self.profile, region_name=self.region + ) + + return session.client( + "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl + ) + + @classmethod + def get_etag(cls, s3, bucket, path): + obj = cls.get_head_object(s3, bucket, path) + + return obj["ETag"].strip('"') + + @staticmethod + def get_head_object(s3, bucket, path, *args, **kwargs): + + try: + obj = s3.head_object(Bucket=bucket, Key=path, *args, **kwargs) + except Exception as exc: + raise DvcException(f"s3://{bucket}/{path} does not exist") from exc + return obj + + def _append_aws_grants_to_extra_args(self, config): + # Keys for extra_args can be one of the following list: + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS + """ + ALLOWED_UPLOAD_ARGS = [ + 'ACL', 'CacheControl', 'ContentDisposition', 'ContentEncoding', + 'ContentLanguage', 'ContentType', 'Expires', 'GrantFullControl', + 'GrantRead', 'GrantReadACP', 'GrantWriteACP', 'Metadata', + 'RequestPayer', 'ServerSideEncryption', 'StorageClass', + 'SSECustomerAlgorithm', 'SSECustomerKey', 'SSECustomerKeyMD5', + 'SSEKMSKeyId', 'WebsiteRedirectLocation' + ] + """ + + grants = { + "grant_full_control": "GrantFullControl", + "grant_read": "GrantRead", + "grant_read_acp": "GrantReadACP", + "grant_write_acp": "GrantWriteACP", + } + + for grant_option, extra_args_key in grants.items(): + if config.get(grant_option): + if self.acl: + raise ConfigError( + "`acl` and `grant_*` AWS S3 config options " + "are mutually exclusive" + ) + + self.extra_args[extra_args_key] = config.get(grant_option) def _generate_download_url(self, path_info, expires=3600): params = {"Bucket": path_info.bucket, "Key": path_info.path} @@ -56,7 +149,7 @@ def isdir(self, path_info): # While `data/al/` will return nothing. # dir_path = path_info / "" - return bool(list(self.remote.list_paths(dir_path, max_items=1))) + return bool(list(self._list_paths(dir_path, max_items=1))) def isfile(self, path_info): from botocore.exceptions import ClientError @@ -73,8 +166,25 @@ def isfile(self, path_info): return True - def walk_files(self, path_info, max_items=None): - for fname in self.remote.list_paths(path_info / "", max_items): + def _list_objects(self, path_info, max_items=None): + """ Read config for list object api, paginate through list objects.""" + kwargs = { + "Bucket": path_info.bucket, + "Prefix": path_info.path, + "PaginationConfig": {"MaxItems": max_items}, + } + paginator = self.s3.get_paginator(self.list_objects_api) + for page in paginator.paginate(**kwargs): + contents = page.get("Contents", ()) + yield from contents + + def _list_paths(self, path_info, max_items=None): + return ( + item["Key"] for item in self._list_objects(path_info, max_items) + ) + + def walk_files(self, path_info, **kwargs): + for fname in self._list_paths(path_info / "", **kwargs): if fname.endswith("/"): continue @@ -100,7 +210,7 @@ def makedirs(self, path_info): self.s3.put_object(Bucket=path_info.bucket, Key=dir_path.path, Body="") def copy(self, from_info, to_info): - self._copy(self.s3, from_info, to_info, self.remote.extra_args) + self._copy(self.s3, from_info, to_info, self.extra_args) @classmethod def _copy_multipart( @@ -114,7 +224,7 @@ def _copy_multipart( parts = [] byte_position = 0 for i in range(1, n_parts + 1): - obj = S3Remote.get_head_object( + obj = S3RemoteTree.get_head_object( s3, from_info.bucket, from_info.path, PartNumber=i ) part_size = obj["ContentLength"] @@ -169,7 +279,9 @@ def _copy(cls, s3, from_info, to_info, extra_args): # object is transfered in the same chunks as it was originally. from boto3.s3.transfer import TransferConfig - obj = S3Remote.get_head_object(s3, from_info.bucket, from_info.path) + obj = S3RemoteTree.get_head_object( + s3, from_info.bucket, from_info.path + ) etag = obj["ETag"].strip('"') size = obj["ContentLength"] @@ -189,122 +301,10 @@ def _copy(cls, s3, from_info, to_info, extra_args): Config=TransferConfig(multipart_threshold=size + 1), ) - cached_etag = S3Remote.get_etag(s3, to_info.bucket, to_info.path) + cached_etag = S3RemoteTree.get_etag(s3, to_info.bucket, to_info.path) if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) - -class S3Remote(BaseRemote): - scheme = Schemes.S3 - path_cls = CloudURLInfo - REQUIRES = {"boto3": "boto3"} - PARAM_CHECKSUM = "etag" - TREE_CLS = S3RemoteTree - - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url", "s3://") - self.path_info = self.path_cls(url) - - self.region = config.get("region") - self.profile = config.get("profile") - self.endpoint_url = config.get("endpointurl") - - if config.get("listobjects"): - self.list_objects_api = "list_objects" - else: - self.list_objects_api = "list_objects_v2" - - self.use_ssl = config.get("use_ssl", True) - - self.extra_args = {} - - self.sse = config.get("sse") - if self.sse: - self.extra_args["ServerSideEncryption"] = self.sse - - self.sse_kms_key_id = config.get("sse_kms_key_id") - if self.sse_kms_key_id: - self.extra_args["SSEKMSKeyId"] = self.sse_kms_key_id - - self.acl = config.get("acl") - if self.acl: - self.extra_args["ACL"] = self.acl - - self._append_aws_grants_to_extra_args(config) - - shared_creds = config.get("credentialpath") - if shared_creds: - os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) - - @wrap_prop(threading.Lock()) - @cached_property - def s3(self): - import boto3 - - session = boto3.session.Session( - profile_name=self.profile, region_name=self.region - ) - - return session.client( - "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl - ) - - @classmethod - def get_etag(cls, s3, bucket, path): - obj = cls.get_head_object(s3, bucket, path) - - return obj["ETag"].strip('"') - - def get_file_checksum(self, path_info): - return self.get_etag(self.s3, path_info.bucket, path_info.path) - - @staticmethod - def get_head_object(s3, bucket, path, *args, **kwargs): - - try: - obj = s3.head_object(Bucket=bucket, Key=path, *args, **kwargs) - except Exception as exc: - raise DvcException(f"s3://{bucket}/{path} does not exist") from exc - return obj - - def _list_objects( - self, path_info, max_items=None, prefix=None, progress_callback=None - ): - """ Read config for list object api, paginate through list objects.""" - kwargs = { - "Bucket": path_info.bucket, - "Prefix": path_info.path, - "PaginationConfig": {"MaxItems": max_items}, - } - if prefix: - kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2]) - paginator = self.s3.get_paginator(self.list_objects_api) - for page in paginator.paginate(**kwargs): - contents = page.get("Contents", ()) - if progress_callback: - for item in contents: - progress_callback() - yield item - else: - yield from contents - - def list_paths( - self, path_info, max_items=None, prefix=None, progress_callback=None - ): - return ( - item["Key"] - for item in self._list_objects( - path_info, max_items, prefix, progress_callback - ) - ) - - def list_cache_paths(self, prefix=None, progress_callback=None): - return self.list_paths( - self.path_info, prefix=prefix, progress_callback=progress_callback - ) - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): total = os.path.getsize(from_file) with Tqdm( @@ -332,33 +332,14 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): from_info.bucket, from_info.path, to_file, Callback=pbar.update ) - def _append_aws_grants_to_extra_args(self, config): - # Keys for extra_args can be one of the following list: - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS - """ - ALLOWED_UPLOAD_ARGS = [ - 'ACL', 'CacheControl', 'ContentDisposition', 'ContentEncoding', - 'ContentLanguage', 'ContentType', 'Expires', 'GrantFullControl', - 'GrantRead', 'GrantReadACP', 'GrantWriteACP', 'Metadata', - 'RequestPayer', 'ServerSideEncryption', 'StorageClass', - 'SSECustomerAlgorithm', 'SSECustomerKey', 'SSECustomerKeyMD5', - 'SSEKMSKeyId', 'WebsiteRedirectLocation' - ] - """ - - grants = { - "grant_full_control": "GrantFullControl", - "grant_read": "GrantRead", - "grant_read_acp": "GrantReadACP", - "grant_write_acp": "GrantWriteACP", - } - for grant_option, extra_args_key in grants.items(): - if config.get(grant_option): - if self.acl: - raise ConfigError( - "`acl` and `grant_*` AWS S3 config options " - "are mutually exclusive" - ) +class S3Remote(BaseRemote): + scheme = Schemes.S3 + REQUIRES = {"boto3": "boto3"} + PARAM_CHECKSUM = "etag" + TREE_CLS = S3RemoteTree - self.extra_args[extra_args_key] = config.get(grant_option) + def get_file_checksum(self, path_info): + return self.tree.get_etag( + self.tree.s3, path_info.bucket, path_info.path + ) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 416b2d4eed..54a968d4f4 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -34,120 +34,8 @@ def ask_password(host, user, port): class SSHRemoteTree(BaseRemoteTree): - @property - def ssh(self): - return self.remote.ssh - - @contextmanager - def open(self, path_info, mode="r", encoding=None): - assert mode in {"r", "rt", "rb", "wb"} - - with self.ssh(path_info) as ssh, closing( - ssh.sftp.open(path_info.path, mode) - ) as fd: - if "b" in mode: - yield fd - else: - yield io.TextIOWrapper(fd, encoding=encoding) - - def exists(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.exists(path_info.path) - - def isdir(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.isdir(path_info.path) - - def isfile(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.isfile(path_info.path) - - def walk_files(self, path_info): - with self.ssh(path_info) as ssh: - for fname in ssh.walk_files(path_info.path): - yield path_info.replace(path=fname) - - def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - with self.ssh(path_info) as ssh: - ssh.remove(path_info.path) - - def makedirs(self, path_info): - with self.ssh(path_info) as ssh: - ssh.makedirs(path_info.path) - - def move(self, from_info, to_info, mode=None): - assert mode is None - if from_info.scheme != self.scheme or to_info.scheme != self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.move(from_info.path, to_info.path) - - def copy(self, from_info, to_info): - if not from_info.scheme == to_info.scheme == self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.atomic_copy(from_info.path, to_info.path) - - def symlink(self, from_info, to_info): - if not from_info.scheme == to_info.scheme == self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.symlink(from_info.path, to_info.path) - - def hardlink(self, from_info, to_info): - if not from_info.scheme == to_info.scheme == self.scheme: - raise NotImplementedError - - # See dvc/remote/local/__init__.py - hardlink() - if self.getsize(from_info) == 0: - - with self.ssh(to_info) as ssh: - ssh.sftp.open(to_info.path, "w").close() - - logger.debug( - "Created empty file: {src} -> {dest}".format( - src=str(from_info), dest=str(to_info) - ) - ) - return - - with self.ssh(from_info) as ssh: - ssh.hardlink(from_info.path, to_info.path) - - def reflink(self, from_info, to_info): - if from_info.scheme != self.scheme or to_info.scheme != self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.reflink(from_info.path, to_info.path) - - def getsize(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.getsize(path_info.path) - - -class SSHRemote(BaseRemote): - scheme = Schemes.SSH - REQUIRES = {"paramiko": "paramiko"} - - JOBS = 4 - PARAM_CHECKSUM = "md5" DEFAULT_PORT = 22 TIMEOUT = 1800 - # At any given time some of the connections will go over network and - # paramiko stuff, so we would ideally have it double of server processors. - # We use conservative setting of 4 instead to not exhaust max sessions. - CHECKSUM_JOBS = 4 - TRAVERSE_PREFIX_LEN = 2 - TREE_CLS = SSHRemoteTree - - DEFAULT_CACHE_TYPES = ["copy"] def __init__(self, repo, config): super().__init__(repo, config) @@ -169,7 +57,7 @@ def __init__(self, repo, config): or self._try_get_ssh_config_port(user_ssh_config) or self.DEFAULT_PORT ) - self.path_info = self.path_cls.from_parts( + self.path_info = self.PATH_CLS.from_parts( scheme=self.scheme, host=host, user=user, @@ -203,7 +91,7 @@ def ssh_config_filename(): def _load_user_ssh_config(hostname): import paramiko - user_config_file = SSHRemote.ssh_config_filename() + user_config_file = SSHRemoteTree.ssh_config_filename() user_ssh_config = {} if hostname and os.path.exists(user_config_file): ssh_config = paramiko.SSHConfig() @@ -248,12 +136,98 @@ def ssh(self, path_info): sock=self.sock, ) - def get_file_checksum(self, path_info): + @contextmanager + def open(self, path_info, mode="r", encoding=None): + assert mode in {"r", "rt", "rb", "wb"} + + with self.ssh(path_info) as ssh, closing( + ssh.sftp.open(path_info.path, mode) + ) as fd: + if "b" in mode: + yield fd + else: + yield io.TextIOWrapper(fd, encoding=encoding) + + def exists(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.exists(path_info.path) + + def isdir(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.isdir(path_info.path) + + def isfile(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.isfile(path_info.path) + + def walk_files(self, path_info, **kwargs): + with self.ssh(path_info) as ssh: + for fname in ssh.walk_files(path_info.path): + yield path_info.replace(path=fname) + + def remove(self, path_info): if path_info.scheme != self.scheme: raise NotImplementedError with self.ssh(path_info) as ssh: - return ssh.md5(path_info.path) + ssh.remove(path_info.path) + + def makedirs(self, path_info): + with self.ssh(path_info) as ssh: + ssh.makedirs(path_info.path) + + def move(self, from_info, to_info, mode=None): + assert mode is None + if from_info.scheme != self.scheme or to_info.scheme != self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.move(from_info.path, to_info.path) + + def copy(self, from_info, to_info): + if not from_info.scheme == to_info.scheme == self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.atomic_copy(from_info.path, to_info.path) + + def symlink(self, from_info, to_info): + if not from_info.scheme == to_info.scheme == self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.symlink(from_info.path, to_info.path) + + def hardlink(self, from_info, to_info): + if not from_info.scheme == to_info.scheme == self.scheme: + raise NotImplementedError + + # See dvc/remote/local/__init__.py - hardlink() + if self.getsize(from_info) == 0: + + with self.ssh(to_info) as ssh: + ssh.sftp.open(to_info.path, "w").close() + + logger.debug( + "Created empty file: {src} -> {dest}".format( + src=str(from_info), dest=str(to_info) + ) + ) + return + + with self.ssh(from_info) as ssh: + ssh.hardlink(from_info.path, to_info.path) + + def reflink(self, from_info, to_info): + if from_info.scheme != self.scheme or to_info.scheme != self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.reflink(from_info.path, to_info.path) + + def getsize(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.getsize(path_info.path) def _download(self, from_info, to_file, name=None, no_progress_bar=False): assert from_info.isin(self.path_info) @@ -275,12 +249,35 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False): no_progress_bar=no_progress_bar, ) + +class SSHRemote(BaseRemote): + scheme = Schemes.SSH + REQUIRES = {"paramiko": "paramiko"} + + JOBS = 4 + PARAM_CHECKSUM = "md5" + # At any given time some of the connections will go over network and + # paramiko stuff, so we would ideally have it double of server processors. + # We use conservative setting of 4 instead to not exhaust max sessions. + CHECKSUM_JOBS = 4 + TRAVERSE_PREFIX_LEN = 2 + TREE_CLS = SSHRemoteTree + + DEFAULT_CACHE_TYPES = ["copy"] + + def get_file_checksum(self, path_info): + if path_info.scheme != self.scheme: + raise NotImplementedError + + with self.tree.ssh(path_info) as ssh: + return ssh.md5(path_info.path) + def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: root = posixpath.join(self.path_info.path, prefix[:2]) else: root = self.path_info.path - with self.ssh(self.path_info) as ssh: + with self.tree.ssh(self.path_info) as ssh: if prefix and not ssh.exists(root): return # If we simply return an iterator then with above closes instantly @@ -306,7 +303,7 @@ def _exists(chunk_and_channel): callback(path) return ret - with self.ssh(path_infos[0]) as ssh: + with self.tree.ssh(path_infos[0]) as ssh: channels = ssh.open_max_sftp_channels() max_workers = len(channels) @@ -329,7 +326,7 @@ def cache_exists(self, checksums, jobs=None, name=None): return list(set(checksums) & set(self.all())) # possibly prompt for credentials before "Querying" progress output - self.ensure_credentials() + self.tree.ensure_credentials() with Tqdm( desc="Querying " diff --git a/tests/func/remote/test_gdrive.py b/tests/func/remote/test_gdrive.py index 77aacc903a..0daecaee21 100644 --- a/tests/func/remote/test_gdrive.py +++ b/tests/func/remote/test_gdrive.py @@ -4,19 +4,19 @@ import configobj from dvc.main import main -from dvc.remote import GDriveRemote +from dvc.remote.gdrive import GDriveRemoteTree from dvc.repo import Repo def test_relative_user_credentials_file_config_setting(tmp_dir, dvc): # CI sets it to test GDrive, here we want to test the work with file system # based, regular credentials - if os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA): - del os.environ[GDriveRemote.GDRIVE_CREDENTIALS_DATA] + if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA): + del os.environ[GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA] credentials = os.path.join("secrets", "credentials.json") - # GDriveRemote.credentials_location helper checks for file existence, + # GDriveRemoteTree.credentials_location helper checks for file existence, # create the file tmp_dir.gen(credentials, "{'token': 'test'}") @@ -50,6 +50,6 @@ def test_relative_user_credentials_file_config_setting(tmp_dir, dvc): # Check that in the remote itself we got an absolute path remote = repo.cloud.get_remote(remote_name) - assert os.path.normpath(remote.credentials_location) == os.path.join( + assert os.path.normpath(remote.tree.credentials_location) == os.path.join( str_path, credentials ) diff --git a/tests/func/remote/test_index.py b/tests/func/remote/test_index.py index cc98e7ebf1..81691bf091 100644 --- a/tests/func/remote/test_index.py +++ b/tests/func/remote/test_index.py @@ -5,7 +5,7 @@ from dvc.exceptions import DownloadError, UploadError from dvc.remote.base import BaseRemote from dvc.remote.index import RemoteIndex -from dvc.remote.local import LocalRemote +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.utils.fs import remove @@ -80,7 +80,7 @@ def test_clear_on_download_err(tmp_dir, dvc, tmp_path_factory, remote, mocker): remove(dvc.cache.local.cache_dir) mocked_clear = mocker.patch.object(remote.INDEX_CLS, "clear") - mocker.patch.object(LocalRemote, "_download", side_effect=Exception) + mocker.patch.object(LocalRemoteTree, "_download", side_effect=Exception) with pytest.raises(DownloadError): dvc.pull() mocked_clear.assert_called_once_with() @@ -90,14 +90,14 @@ def test_partial_upload(tmp_dir, dvc, tmp_path_factory, remote, mocker): tmp_dir.dvc_gen({"foo": "foo content"}) tmp_dir.dvc_gen({"bar": {"baz": "baz content"}}) - original = LocalRemote._upload + original = LocalRemoteTree._upload def unreliable_upload(self, from_file, to_info, name=None, **kwargs): if "baz" in name: raise Exception("stop baz") return original(self, from_file, to_info, name, **kwargs) - mocker.patch.object(LocalRemote, "_upload", unreliable_upload) + mocker.patch.object(LocalRemoteTree, "_upload", unreliable_upload) with pytest.raises(UploadError): dvc.push() with remote.index: diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index bafba3a4c4..848074d1f6 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -759,7 +759,7 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): remote = S3Remote(dvc, {"url": S3.get_url()}) file_path = remote.path_info / "foo" - remote.s3.put_object( + remote.tree.s3.put_object( Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo" ) @@ -770,7 +770,7 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): assert stats == {**empty_checkout, "added": [str(file_path)]} assert remote.tree.exists(file_path) - remote.s3.put_object( + remote.tree.s3.put_object( Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo\nfoo" ) stats = dvc.checkout(force=True) diff --git a/tests/func/test_external_repo.py b/tests/func/test_external_repo.py index f42a191c08..a2f8ef098c 100644 --- a/tests/func/test_external_repo.py +++ b/tests/func/test_external_repo.py @@ -4,7 +4,7 @@ from dvc.external_repo import external_repo from dvc.path_info import PathInfo -from dvc.remote import LocalRemote +from dvc.remote.local import LocalRemoteTree from dvc.scm.git import Git from dvc.utils import relpath from dvc.utils.fs import remove @@ -49,7 +49,7 @@ def test_cache_reused(erepo_dir, mocker, setup_remote): erepo_dir.dvc_gen("file", "text", commit="add file") erepo_dir.dvc.push() - download_spy = mocker.spy(LocalRemote, "download") + download_spy = mocker.spy(LocalRemoteTree, "download") # Use URL to prevent any fishy optimizations url = f"file://{erepo_dir}" diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 6ab037c8bc..a3b4328d06 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -238,7 +238,7 @@ def test_download_error_pulling_imported_stage(tmp_dir, dvc, erepo_dir): remove(dst_cache) with patch( - "dvc.remote.LocalRemote._download", side_effect=Exception + "dvc.remote.local.LocalRemoteTree._download", side_effect=Exception ), pytest.raises(DownloadError): dvc.pull(["foo_imported.dvc"]) diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 40d3b8156b..9b8a4cea8d 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -10,8 +10,8 @@ from dvc.exceptions import DownloadError, UploadError from dvc.main import main from dvc.path_info import PathInfo -from dvc.remote import LocalRemote from dvc.remote.base import BaseRemote, RemoteCacheRequiredError +from dvc.remote.local import LocalRemoteTree from dvc.utils.fs import remove from tests.basic_env import TestDvc from tests.remotes import Local @@ -181,14 +181,14 @@ def test_partial_push_n_pull(tmp_dir, dvc, tmp_path_factory, setup_remote): baz = tmp_dir.dvc_gen({"baz": {"foo": "baz content"}})[0].outs[0] # Faulty upload version, failing on foo - original = LocalRemote._upload + original = LocalRemoteTree._upload def unreliable_upload(self, from_file, to_info, name=None, **kwargs): if "foo" in name: raise Exception("stop foo") return original(self, from_file, to_info, name, **kwargs) - with patch.object(LocalRemote, "_upload", unreliable_upload): + with patch.object(LocalRemoteTree, "_upload", unreliable_upload): with pytest.raises(UploadError) as upload_error_info: dvc.push() assert upload_error_info.value.amount == 3 @@ -206,7 +206,7 @@ def unreliable_upload(self, from_file, to_info, name=None, **kwargs): dvc.push() remove(dvc.cache.local.cache_dir) - with patch.object(LocalRemote, "_download", side_effect=Exception): + with patch.object(LocalRemoteTree, "_download", side_effect=Exception): with pytest.raises(DownloadError) as download_error_info: dvc.pull() # error count should be len(.dir + standalone file checksums) @@ -221,7 +221,7 @@ def test_raise_on_too_many_open_files( tmp_dir.dvc_gen({"file": "file content"}) mocker.patch.object( - LocalRemote, + LocalRemoteTree, "_upload", side_effect=OSError(errno.EMFILE, "Too many open files"), ) @@ -255,7 +255,9 @@ def test_push_order(tmp_dir, dvc, tmp_path_factory, mocker, setup_remote): tmp_dir.dvc_gen({"foo": {"bar": "bar content"}}) tmp_dir.dvc_gen({"baz": "baz content"}) - mocked_upload = mocker.patch.object(LocalRemote, "_upload", return_value=0) + mocked_upload = mocker.patch.object( + LocalRemoteTree, "_upload", return_value=0 + ) dvc.push() # last uploaded file should be dir checksum assert mocked_upload.call_args[0][0].endswith(".dir") diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index ef677b13eb..08f44e25c4 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -26,7 +26,7 @@ from dvc.main import main from dvc.output.base import BaseOutput from dvc.path_info import URLInfo -from dvc.remote.local import LocalRemote +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.repo import Repo as DvcRepo from dvc.stage import Stage from dvc.stage.exceptions import StageFileDoesNotExistError @@ -1384,9 +1384,9 @@ def test_force_import(self): self.assertEqual(ret, 0) patch_download = patch.object( - LocalRemote, + LocalRemoteTree, "download", - side_effect=LocalRemote.download, + side_effect=LocalRemoteTree.download, autospec=True, ) diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index f2f4cef9d3..0b1f2d5bc9 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -30,7 +30,7 @@ def wrapped(*args, **kwargs): def _get_src_dst(): - base_info = S3Remote.path_cls(S3.get_url()) + base_info = S3RemoteTree.PATH_CLS(S3.get_url()) return base_info / "from", base_info / "to" @@ -48,12 +48,16 @@ def test_copy_singlepart_preserve_etag(): @mock_s3 @pytest.mark.parametrize( "base_info", - [S3Remote.path_cls("s3://bucket/"), S3Remote.path_cls("s3://bucket/ns/")], + [ + S3RemoteTree.PATH_CLS("s3://bucket/"), + S3RemoteTree.PATH_CLS("s3://bucket/ns/"), + ], ) def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm): remote = S3Remote(dvc, {"url": str(base_info.parent)}) - remote.s3.create_bucket(Bucket=base_info.bucket) - remote.s3.put_object( + s3 = remote.tree.s3 + s3.create_bucket(Bucket=base_info.bucket) + s3.put_object( Bucket=base_info.bucket, Key=(base_info / "from").path, Body="data" ) remote.link(base_info / "from", base_info / "to") @@ -64,7 +68,7 @@ def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm): @mock_s3 def test_makedirs_doesnot_try_on_top_level_paths(tmp_dir, dvc, scm): - base_info = S3Remote.path_cls("s3://bucket/") + base_info = S3RemoteTree.PATH_CLS("s3://bucket/") remote = S3Remote(dvc, {"url": str(base_info)}) remote.tree.makedirs(base_info) diff --git a/tests/remotes.py b/tests/remotes.py index d1a58531ac..1a3c66b334 100644 --- a/tests/remotes.py +++ b/tests/remotes.py @@ -7,7 +7,7 @@ from moto.s3 import mock_s3 -from dvc.remote import GDriveRemote +from dvc.remote.gdrive import GDriveRemote, GDriveRemoteTree from dvc.remote.gs import GSRemote from dvc.remote.s3 import S3Remote from dvc.utils import env2bool @@ -82,7 +82,7 @@ def remote(cls, repo): @staticmethod def put_objects(remote, objects): - s3 = remote.s3 + s3 = remote.tree.s3 bucket = remote.path_info.bucket s3.create_bucket(Bucket=bucket) for key, body in objects.items(): @@ -130,7 +130,7 @@ def remote(cls, repo): @staticmethod def put_objects(remote, objects): - client = remote.gs + client = remote.tree.gs bucket = client.get_bucket(remote.path_info.bucket) for key, body in objects.items(): bucket.blob((remote.path_info / key).path).upload_from_string(body) @@ -139,7 +139,7 @@ def put_objects(remote, objects): class GDrive: @staticmethod def should_test(): - return os.getenv(GDriveRemote.GDRIVE_CREDENTIALS_DATA) is not None + return os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) is not None @staticmethod def create_dir(dvc, url): diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py index b0b22bec9a..288a1d7b13 100644 --- a/tests/unit/remote/ssh/test_ssh.py +++ b/tests/unit/remote/ssh/test_ssh.py @@ -5,7 +5,7 @@ import pytest from mock import mock_open, patch -from dvc.remote.ssh import SSHRemote +from dvc.remote.ssh import SSHRemote, SSHRemoteTree from dvc.system import System from tests.remotes import SSHMocked @@ -28,13 +28,13 @@ def test_url(dvc): config = {"url": url} remote = SSHRemote(dvc, config) - assert remote.path_info == url + assert remote.tree.path_info == url def test_no_path(dvc): config = {"url": "ssh://127.0.0.1"} remote = SSHRemote(dvc, config) - assert remote.path_info.path == "" + assert remote.tree.path_info.path == "" mock_ssh_config = """ @@ -69,9 +69,9 @@ def test_ssh_host_override_from_config( ): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.path_info.host == expected_host + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.path_info.host == expected_host @pytest.mark.parametrize( @@ -97,9 +97,9 @@ def test_ssh_host_override_from_config( def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.path_info.user == expected_user + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.path_info.user == expected_user @pytest.mark.parametrize( @@ -108,7 +108,7 @@ def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): ({"url": "ssh://example.com:2222"}, 2222), ({"url": "ssh://example.com"}, 1234), ({"url": "ssh://example.com", "port": 4321}, 4321), - ({"url": "ssh://not_in_ssh_config.com"}, SSHRemote.DEFAULT_PORT), + ({"url": "ssh://not_in_ssh_config.com"}, SSHRemoteTree.DEFAULT_PORT), ({"url": "ssh://not_in_ssh_config.com:2222"}, 2222), ({"url": "ssh://not_in_ssh_config.com:2222", "port": 4321}, 4321), ], @@ -122,8 +122,8 @@ def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) assert remote.path_info.port == expected_port @@ -157,9 +157,9 @@ def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port): def test_ssh_keyfile(mock_file, mock_exists, dvc, config, expected_keyfile): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.keyfile == expected_keyfile + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.keyfile == expected_keyfile @pytest.mark.parametrize( @@ -179,9 +179,9 @@ def test_ssh_keyfile(mock_file, mock_exists, dvc, config, expected_keyfile): def test_ssh_gss_auth(mock_file, mock_exists, dvc, config, expected_gss_auth): remote = SSHRemote(dvc, config) - mock_exists.assert_called_with(SSHRemote.ssh_config_filename()) - mock_file.assert_called_with(SSHRemote.ssh_config_filename()) - assert remote.gss_auth == expected_gss_auth + mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) + mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) + assert remote.tree.gss_auth == expected_gss_auth def test_hardlink_optimization(dvc, tmp_dir, ssh_server): diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py index b5462193da..b49f2ee631 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/remote/test_azure.py @@ -19,8 +19,8 @@ def test_init_env_var(monkeypatch, dvc): config = {"url": "azure://"} remote = AzureRemote(dvc, config) - assert remote.path_info == "azure://" + container_name - assert remote.connection_string == connection_string + assert remote.tree.path_info == "azure://" + container_name + assert remote.tree.connection_string == connection_string def test_init(dvc): @@ -28,8 +28,8 @@ def test_init(dvc): url = f"azure://{container_name}/{prefix}" config = {"url": url, "connection_string": connection_string} remote = AzureRemote(dvc, config) - assert remote.path_info == url - assert remote.connection_string == connection_string + assert remote.tree.path_info == url + assert remote.tree.connection_string == connection_string def test_get_file_checksum(tmp_dir): @@ -39,8 +39,8 @@ def test_get_file_checksum(tmp_dir): tmp_dir.gen("foo", "foo") remote = AzureRemote(None, {}) - to_info = remote.path_cls(Azure.get_url()) - remote.upload(PathInfo("foo"), to_info) + to_info = remote.tree.PATH_CLS(Azure.get_url()) + remote.tree.upload(PathInfo("foo"), to_info) assert remote.tree.exists(to_info) checksum = remote.get_file_checksum(to_info) assert checksum diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index fa2325fb2b..481f9f3f1f 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -104,7 +104,7 @@ def test_cache_exists(object_exists, traverse, dvc): ) def test_cache_checksums_traverse(path_to_checksum, cache_checksums, dvc): remote = BaseRemote(dvc, {}) - remote.path_info = PathInfo("foo") + remote.tree.path_info = PathInfo("foo") # parallel traverse size = 256 / remote.JOBS * remote.LIST_OBJECT_PAGE_SIZE @@ -129,7 +129,7 @@ def test_cache_checksums_traverse(path_to_checksum, cache_checksums, dvc): def test_cache_checksums(dvc): remote = BaseRemote(dvc, {}) - remote.path_info = PathInfo("foo") + remote.tree.path_info = PathInfo("foo") with mock.patch.object( remote, "list_cache_paths", return_value=["12/3456", "bar"] diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index aeda479c10..b8b5ed83e8 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -2,7 +2,7 @@ import pytest -from dvc.remote.gdrive import GDriveAuthError, GDriveRemote +from dvc.remote.gdrive import GDriveAuthError, GDriveRemote, GDriveRemoteTree USER_CREDS_TOKEN_REFRESH_ERROR = '{"access_token": "", "client_id": "", "client_secret": "", "refresh_token": "", "token_expiry": "", "token_uri": "https://oauth2.googleapis.com/token", "user_agent": null, "revoke_uri": "https://oauth2.googleapis.com/revoke", "id_token": null, "id_token_jwt": null, "token_response": {"access_token": "", "expires_in": 3600, "scope": "https://www.googleapis.com/auth/drive.appdata https://www.googleapis.com/auth/drive", "token_type": "Bearer"}, "scopes": ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/drive.appdata"], "token_info_uri": "https://oauth2.googleapis.com/tokeninfo", "invalid": true, "_class": "OAuth2Credentials", "_module": "oauth2client.client"}' # noqa: E501 @@ -18,20 +18,20 @@ class TestRemoteGDrive: def test_init(self, dvc): remote = GDriveRemote(dvc, self.CONFIG) - assert str(remote.path_info) == self.CONFIG["url"] + assert str(remote.tree.path_info) == self.CONFIG["url"] def test_drive(self, dvc): remote = GDriveRemote(dvc, self.CONFIG) os.environ[ - GDriveRemote.GDRIVE_CREDENTIALS_DATA + GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_TOKEN_REFRESH_ERROR with pytest.raises(GDriveAuthError): - remote._drive + remote.tree._drive - os.environ[GDriveRemote.GDRIVE_CREDENTIALS_DATA] = "" + os.environ[GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA] = "" remote = GDriveRemote(dvc, self.CONFIG) os.environ[ - GDriveRemote.GDRIVE_CREDENTIALS_DATA + GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_MISSED_KEY_ERROR with pytest.raises(GDriveAuthError): - remote._drive + remote.tree._drive diff --git a/tests/unit/remote/test_gs.py b/tests/unit/remote/test_gs.py index 3f8643011c..cee6a004b3 100644 --- a/tests/unit/remote/test_gs.py +++ b/tests/unit/remote/test_gs.py @@ -18,16 +18,16 @@ def test_init(dvc): remote = GSRemote(dvc, CONFIG) - assert remote.path_info == URL - assert remote.projectname == PROJECT - assert remote.credentialpath == CREDENTIALPATH + assert remote.tree.path_info == URL + assert remote.tree.projectname == PROJECT + assert remote.tree.credentialpath == CREDENTIALPATH @mock.patch("google.cloud.storage.Client.from_service_account_json") def test_gs(mock_client, dvc): remote = GSRemote(dvc, CONFIG) - assert remote.credentialpath - remote.gs() + assert remote.tree.credentialpath + remote.tree.gs() mock_client.assert_called_once_with(CREDENTIALPATH) @@ -36,7 +36,7 @@ def test_gs_no_credspath(mock_client, dvc): config = CONFIG.copy() del config["credentialpath"] remote = GSRemote(dvc, config) - remote.gs() + remote.tree.gs() mock_client.assert_called_with(PROJECT) diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index 06998f22f4..b8f0d15d6c 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -14,7 +14,7 @@ def test_download_fails_on_error_code(dvc): remote = HTTPRemote(dvc, config) with pytest.raises(HTTPError): - remote._download(URLInfo(url) / "missing.txt", "missing.txt") + remote.tree._download(URLInfo(url) / "missing.txt", "missing.txt") def test_public_auth_method(dvc): @@ -27,7 +27,7 @@ def test_public_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() is None + assert remote.tree._auth_method() is None def test_basic_auth_method(dvc): @@ -46,8 +46,8 @@ def test_basic_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() == auth - assert isinstance(remote.auth_method(), HTTPBasicAuth) + assert remote.tree._auth_method() == auth + assert isinstance(remote.tree._auth_method(), HTTPBasicAuth) def test_digest_auth_method(dvc): @@ -66,8 +66,8 @@ def test_digest_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() == auth - assert isinstance(remote.auth_method(), HTTPDigestAuth) + assert remote.tree._auth_method() == auth + assert isinstance(remote.tree._auth_method(), HTTPDigestAuth) def test_custom_auth_method(dvc): @@ -83,6 +83,6 @@ def test_custom_auth_method(dvc): remote = HTTPRemote(dvc, config) - assert remote.auth_method() is None - assert header in remote.headers - assert remote.headers[header] == password + assert remote.tree._auth_method() is None + assert header in remote.tree.headers + assert remote.tree.headers[header] == password diff --git a/tests/unit/remote/test_oss.py b/tests/unit/remote/test_oss.py index d6bd515392..3bffe14a43 100644 --- a/tests/unit/remote/test_oss.py +++ b/tests/unit/remote/test_oss.py @@ -16,7 +16,7 @@ def test_init(dvc): "oss_endpoint": endpoint, } remote = OSSRemote(dvc, config) - assert remote.path_info == url - assert remote.endpoint == endpoint - assert remote.key_id == key_id - assert remote.key_secret == key_secret + assert remote.tree.path_info == url + assert remote.tree.endpoint == endpoint + assert remote.tree.key_id == key_id + assert remote.tree.key_secret == key_secret diff --git a/tests/unit/remote/test_remote.py b/tests/unit/remote/test_remote.py index 8e0f82a5d8..34dc6948c6 100644 --- a/tests/unit/remote/test_remote.py +++ b/tests/unit/remote/test_remote.py @@ -35,7 +35,7 @@ def test_makedirs_not_create_for_top_level_path(remote_cls, dvc, mocker): remote = remote_cls(dvc, {"url": url}) mocked_client = mocker.PropertyMock() # we use remote clients with same name as scheme to interact with remote - mocker.patch.object(remote_cls, remote.scheme, mocked_client) + mocker.patch.object(remote_cls.TREE_CLS, remote.scheme, mocked_client) remote.tree.makedirs(remote.path_info) assert not mocked_client.called diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index fc9f557df2..e9b8dd213e 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -3,7 +3,7 @@ import pytest from dvc.path_info import PathInfo -from dvc.remote.s3 import S3Remote +from dvc.remote.s3 import S3Remote, S3RemoteTree from dvc.utils.fs import walk_files from tests.remotes import GCP, S3Mocked @@ -88,7 +88,7 @@ def test_walk_files(remote): @pytest.mark.parametrize("remote", [S3Mocked], indirect=True) def test_copy_preserve_etag_across_buckets(remote, dvc): - s3 = remote.s3 + s3 = remote.tree.s3 s3.create_bucket(Bucket="another") another = S3Remote(dvc, {"url": "s3://another", "region": "us-east-1"}) @@ -98,8 +98,8 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): remote.tree.copy(from_info, to_info) - from_etag = S3Remote.get_etag(s3, from_info.bucket, from_info.path) - to_etag = S3Remote.get_etag(s3, "another", "foo") + from_etag = S3RemoteTree.get_etag(s3, from_info.bucket, from_info.path) + to_etag = S3RemoteTree.get_etag(s3, "another", "foo") assert from_etag == to_etag @@ -141,7 +141,7 @@ def test_isfile(remote): def test_download_dir(remote, tmpdir): path = str(tmpdir / "data") to_info = PathInfo(path) - remote.download(remote.path_info / "data", to_info) + remote.tree.download(remote.path_info / "data", to_info) assert os.path.isdir(path) data_dir = tmpdir / "data" assert len(list(walk_files(path))) == 7 diff --git a/tests/unit/remote/test_s3.py b/tests/unit/remote/test_s3.py index bf726db616..d61c0d735d 100644 --- a/tests/unit/remote/test_s3.py +++ b/tests/unit/remote/test_s3.py @@ -22,7 +22,7 @@ def test_init(dvc): config = {"url": url} remote = S3Remote(dvc, config) - assert remote.path_info == url + assert remote.tree.path_info == url def test_grants(dvc): @@ -34,16 +34,16 @@ def test_grants(dvc): "grant_full_control": "id=full-control-permission-id", } remote = S3Remote(dvc, config) + tree = remote.tree assert ( - remote.extra_args["GrantRead"] + tree.extra_args["GrantRead"] == "id=read-permission-id,id=other-read-permission-id" ) - assert remote.extra_args["GrantReadACP"] == "id=read-acp-permission-id" - assert remote.extra_args["GrantWriteACP"] == "id=write-acp-permission-id" + assert tree.extra_args["GrantReadACP"] == "id=read-acp-permission-id" + assert tree.extra_args["GrantWriteACP"] == "id=write-acp-permission-id" assert ( - remote.extra_args["GrantFullControl"] - == "id=full-control-permission-id" + tree.extra_args["GrantFullControl"] == "id=full-control-permission-id" ) @@ -57,4 +57,4 @@ def test_grants_mutually_exclusive_acl_error(dvc, grants): def test_sse_kms_key_id(dvc): remote = S3Remote(dvc, {"url": url, "sse_kms_key_id": "key"}) - assert remote.extra_args["SSEKMSKeyId"] == "key" + assert remote.tree.extra_args["SSEKMSKeyId"] == "key"