diff --git a/dvc/cache.py b/dvc/cache.py index 80a566b912..a03912c470 100644 --- a/dvc/cache.py +++ b/dvc/cache.py @@ -28,13 +28,13 @@ def _make_remote_property(name): """ def getter(self): - from dvc.remote import Remote + from dvc.remote import Cache as CloudCache remote = self.config.get(name) if not remote: return None - return Remote(self.repo, name=remote) + return CloudCache(self.repo, name=remote) getter.__name__ = name return cached_property(getter) @@ -50,7 +50,7 @@ class Cache: CACHE_DIR = "cache" def __init__(self, repo): - from dvc.remote import Remote + from dvc.remote import Cache as CloudCache self.repo = repo self.config = config = repo.config["cache"] @@ -62,7 +62,7 @@ def __init__(self, repo): else: settings = {**config, "url": config["dir"]} - self.local = Remote(repo, **settings) + self.local = CloudCache(repo, **settings) s3 = _make_remote_property("s3") gs = _make_remote_property("gs") diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index f76344f119..afa1f3fbe7 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -64,9 +64,7 @@ def _get_checksum(self, locked=True): # We are polluting our repo cache with some dir listing here if tree.isdir(path): - return self.repo.cache.local.get_dir_checksum( - path, tree=tree - ) + return self.repo.cache.local.get_checksum(path, tree) return tree.get_file_checksum(path) def status(self): diff --git a/dvc/external_repo.py b/dvc/external_repo.py index 2c66be806c..2efa0414c9 100644 --- a/dvc/external_repo.py +++ b/dvc/external_repo.py @@ -126,8 +126,9 @@ def download_update(result): raise PathMissingError(path, self.url) save_info = self.local_cache.save( path, + self.repo_tree, None, - tree=self.repo_tree, + save_link=False, download_callback=download_update, ) save_infos.append(save_info) diff --git a/dvc/output/base.py b/dvc/output/base.py index de1e4f403d..e57ae6e8fa 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -267,7 +267,7 @@ def save(self): def commit(self): if self.use_cache: - self.cache.save(self.path_info, self.info) + self.cache.save(self.path_info, self.cache.tree, self.info) def dumpd(self): ret = copy(self.info) diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 9c71b18052..4ea41c4ee9 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -1,16 +1,25 @@ import posixpath from urllib.parse import urlparse -from dvc.remote.azure import AzureRemote +from dvc.remote.azure import AzureCache, AzureRemote from dvc.remote.gdrive import GDriveRemote -from dvc.remote.gs import GSRemote -from dvc.remote.hdfs import HDFSRemote +from dvc.remote.gs import GSCache, GSRemote +from dvc.remote.hdfs import HDFSCache, HDFSRemote from dvc.remote.http import HTTPRemote from dvc.remote.https import HTTPSRemote -from dvc.remote.local import LocalRemote +from dvc.remote.local import LocalCache, LocalRemote from dvc.remote.oss import OSSRemote -from dvc.remote.s3 import S3Remote -from dvc.remote.ssh import SSHRemote +from dvc.remote.s3 import S3Cache, S3Remote +from dvc.remote.ssh import SSHCache, SSHRemote + +CACHES = [ + AzureCache, + GSCache, + HDFSCache, + S3Cache, + SSHCache, + # LocalCache is the default +] REMOTES = [ AzureRemote, @@ -26,21 +35,30 @@ ] -def _get(remote_conf): - for remote in REMOTES: +def _get(remote_conf, remotes, default): + for remote in remotes: if remote.supported(remote_conf): return remote - return LocalRemote + return default -def Remote(repo, **kwargs): +def _get_conf(repo, **kwargs): name = kwargs.get("name") if name: remote_conf = 
repo.config["remote"][name.lower()] else: remote_conf = kwargs - remote_conf = _resolve_remote_refs(repo.config, remote_conf) - return _get(remote_conf)(repo, remote_conf) + return _resolve_remote_refs(repo.config, remote_conf) + + +def Remote(repo, **kwargs): + remote_conf = _get_conf(repo, **kwargs) + return _get(remote_conf, REMOTES, LocalRemote)(repo, remote_conf) + + +def Cache(repo, **kwargs): + remote_conf = _get_conf(repo, **kwargs) + return _get(remote_conf, CACHES, LocalCache)(repo, remote_conf) def _resolve_remote_refs(config, remote_conf): diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 4feb57f40b..46606d11a9 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -7,7 +7,7 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree +from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin from dvc.scheme import Schemes logger = logging.getLogger(__name__) @@ -108,6 +108,9 @@ def remove(self, path_info): logger.debug(f"Removing {path_info}") self.blob_service.delete_blob(path_info.bucket, path_info.path) + def get_file_checksum(self, path_info): + return self.get_etag(path_info) + def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs ): @@ -134,10 +137,11 @@ def _download( class AzureRemote(BaseRemote): scheme = Schemes.AZURE REQUIRES = {"azure-storage-blob": "azure.storage.blob"} + TREE_CLS = AzureRemoteTree PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 5000 - TREE_CLS = AzureRemoteTree - def get_file_checksum(self, path_info): - return self.tree.get_etag(path_info) + +class AzureCache(AzureRemote, CacheMixin): + pass diff --git a/dvc/remote/base.py b/dvc/remote/base.py index b84791ecaa..625f0fcd18 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -25,7 +25,6 @@ from dvc.progress import Tqdm from dvc.remote.index import RemoteIndex, RemoteIndexNoop from dvc.remote.slow_link_detection import slow_link_guard -from dvc.scm.tree import is_working_tree from dvc.state import StateNoop from dvc.utils import tmp_fname from dvc.utils.fs import makedirs, move @@ -85,6 +84,7 @@ def wrapper(remote_obj, *args, **kwargs): class BaseRemoteTree: SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} PATH_CLS = URLInfo + CHECKSUM_DIR_SUFFIX = ".dir" def __init__(self, remote, config): self.remote = remote @@ -103,6 +103,14 @@ def dir_mode(self): def scheme(self): return self.remote.scheme + @property + def state(self): + return self.remote.state + + @property + def cache(self): + return self.remote.cache + def open(self, path_info, mode="r", encoding=None): if hasattr(self, "_generate_download_url"): get_url = partial(self._generate_download_url, path_info) @@ -164,6 +172,140 @@ def hardlink(self, from_info, to_info): def reflink(self, from_info, to_info): raise RemoteActionNotImplemented("reflink", self.scheme) + @classmethod + def is_dir_checksum(cls, checksum): + if not checksum: + return False + return checksum.endswith(cls.CHECKSUM_DIR_SUFFIX) + + def get_checksum(self, path_info, tree=None, **kwargs): + assert isinstance(path_info, str) or path_info.scheme == self.scheme + + if not tree: + tree = self + + if not tree.exists(path_info): + return None + + if tree == self: + checksum = self.state.get(path_info) + else: + checksum = None + + # If we have dir checksum in state db, but dir cache file is lost, + # then we need to recollect the dir via .get_dir_checksum() call below, + # see 
https://github.com/iterative/dvc/issues/2219 for context + if ( + checksum + and self.is_dir_checksum(checksum) + and not tree.exists(self.cache.checksum_to_path_info(checksum)) + ): + checksum = None + + if checksum: + return checksum + + if tree.isdir(path_info): + checksum = self.get_dir_checksum(path_info, tree, **kwargs) + else: + checksum = tree.get_file_checksum(path_info) + + if checksum and self.exists(path_info): + self.state.save(path_info, checksum) + + return checksum + + def get_file_checksum(self, path_info): + raise NotImplementedError + + def get_dir_checksum(self, path_info, tree, **kwargs): + if not self.cache: + raise RemoteCacheRequiredError(path_info) + + dir_info = self._collect_dir(path_info, tree, **kwargs) + return self._save_dir_info(dir_info, path_info) + + def _calculate_checksums(self, file_infos, tree): + file_infos = list(file_infos) + with Tqdm( + total=len(file_infos), + unit="md5", + desc="Computing file/dir hashes (only done once)", + ) as pbar: + worker = pbar.wrap_fn(tree.get_file_checksum) + with ThreadPoolExecutor( + max_workers=self.remote.checksum_jobs + ) as executor: + tasks = executor.map(worker, file_infos) + checksums = dict(zip(file_infos, tasks)) + return checksums + + def _collect_dir(self, path_info, tree, **kwargs): + file_infos = set() + + for fname in tree.walk_files(path_info, **kwargs): + if DvcIgnore.DVCIGNORE_FILE == fname.name: + raise DvcIgnoreInCollectedDirError(fname.parent) + + file_infos.add(fname) + + checksums = {fi: self.state.get(fi) for fi in file_infos} + not_in_state = { + fi for fi, checksum in checksums.items() if checksum is None + } + + new_checksums = self._calculate_checksums(not_in_state, tree) + checksums.update(new_checksums) + + result = [ + { + self.remote.PARAM_CHECKSUM: checksums[fi], + # NOTE: this is lossy transformation: + # "hey\there" -> "hey/there" + # "hey/there" -> "hey/there" + # The latter is fine filename on Windows, which + # will transform to dir/file on back transform. 
+ # + # Yes, this is a BUG, as long as we permit "/" in + # filenames on Windows and "\" on Unix + self.remote.PARAM_RELPATH: fi.relative_to( + path_info + ).as_posix(), + } + for fi in file_infos + ] + + # Sorting the list by path to ensure reproducibility + return sorted(result, key=itemgetter(self.remote.PARAM_RELPATH)) + + def _save_dir_info(self, dir_info, path_info): + checksum, tmp_info = self._get_dir_info_checksum(dir_info) + new_info = self.cache.checksum_to_path_info(checksum) + if self.cache.changed_cache_file(checksum): + self.cache.tree.makedirs(new_info.parent) + self.cache.tree.move( + tmp_info, new_info, mode=self.remote.CACHE_MODE + ) + + if self.exists(path_info): + self.state.save(path_info, checksum) + self.state.save(new_info, checksum) + + return checksum + + def _get_dir_info_checksum(self, dir_info): + tmp = tempfile.NamedTemporaryFile(delete=False).name + with open(tmp, "w+") as fobj: + json.dump(dir_info, fobj, sort_keys=True) + + tree = self.cache.tree + from_info = PathInfo(tmp) + to_info = tree.path_info / tmp_fname("") + tree.upload(from_info, to_info, no_progress_bar=True) + + checksum = tree.get_file_checksum(to_info) + self.CHECKSUM_DIR_SUFFIX + return checksum, to_info + def upload(self, from_info, to_info, name=None, no_progress_bar=False): if not hasattr(self, "_upload"): raise RemoteActionNotImplemented("upload", self.scheme) @@ -277,6 +419,8 @@ def _download_file( class BaseRemote: + """Base cloud remote class.""" + scheme = "base" REQUIRES = {} JOBS = 4 * cpu_count() @@ -389,188 +533,400 @@ def supported(cls, config): def cache(self): return getattr(self.repo.cache, self.scheme) - def get_file_checksum(self, path_info): - raise NotImplementedError + @classmethod + def is_dir_checksum(cls, checksum): + return cls.TREE_CLS.is_dir_checksum(checksum) - def _calculate_checksums(self, file_infos): - file_infos = list(file_infos) - with Tqdm( - total=len(file_infos), - unit="md5", - desc="Computing file/dir hashes (only done once)", - ) as pbar: - worker = pbar.wrap_fn(self.get_file_checksum) - with ThreadPoolExecutor( - max_workers=self.checksum_jobs - ) as executor: - tasks = executor.map(worker, file_infos) - checksums = dict(zip(file_infos, tasks)) - return checksums + def get_checksum(self, path_info, **kwargs): + return self.tree.get_checksum(path_info, **kwargs) - def _collect_dir(self, path_info, tree=None, save_tree=False, **kwargs): - file_infos = set() + def checksum_to_path_info(self, checksum): + return self.path_info / checksum[0:2] / checksum[2:] - if tree: - walk_files = tree.walk_files - else: - walk_files = self.tree.walk_files + def path_to_checksum(self, path): + parts = self.tree.PATH_CLS(path).parts[-2:] - for fname in walk_files(path_info, **kwargs): - if DvcIgnore.DVCIGNORE_FILE == fname.name: - raise DvcIgnoreInCollectedDirError(fname.parent) + if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): + raise ValueError(f"Bad cache file path '{path}'") - file_infos.add(fname) + return "".join(parts) - if tree: - checksums = {fi: tree.get_file_checksum(fi) for fi in file_infos} - if save_tree: - for fi, checksum in checksums.items(): - self._save_file(fi, checksum, tree=tree, **kwargs) - else: - checksums = {fi: self.state.get(fi) for fi in file_infos} - not_in_state = { - fi for fi, checksum in checksums.items() if checksum is None - } + def save_info(self, path_info, tree=None, **kwargs): + return { + self.PARAM_CHECKSUM: self.tree.get_checksum( + path_info, tree=tree, **kwargs + ) + } - new_checksums = 
self._calculate_checksums(not_in_state) - checksums.update(new_checksums) + def open(self, *args, **kwargs): + return self.tree.open(*args, **kwargs) - result = [ - { - self.PARAM_CHECKSUM: checksums[fi], - # NOTE: this is lossy transformation: - # "hey\there" -> "hey/there" - # "hey/there" -> "hey/there" - # The latter is fine filename on Windows, which - # will transform to dir/file on back transform. - # - # Yes, this is a BUG, as long as we permit "/" in - # filenames on Windows and "\" on Unix - self.PARAM_RELPATH: fi.relative_to(path_info).as_posix(), - } - for fi in file_infos - ] + @staticmethod + def protect(path_info): + pass - # Sorting the list by path to ensure reproducibility - return sorted(result, key=itemgetter(self.PARAM_RELPATH)) + def is_protected(self, path_info): + return False - def get_dir_checksum(self, path_info, tree=None): - if not self.cache: - raise RemoteCacheRequiredError(path_info) + @staticmethod + def unprotect(path_info): + pass - dir_info = self._collect_dir(path_info, tree=None) - if tree: - # don't save state entry for path_info if it is a tree path - path_info = None - return self._save_dir_info(dir_info, path_info) + def list_paths(self, prefix=None, progress_callback=None): + if prefix: + if len(prefix) > 2: + path_info = self.path_info / prefix[:2] / prefix[2:] + else: + path_info = self.path_info / prefix[:2] + else: + path_info = self.path_info + if progress_callback: + for file_info in self.tree.walk_files(path_info): + progress_callback() + yield file_info.path + else: + yield from self.tree.walk_files(path_info) - def _save_dir_info(self, dir_info, path_info=None): - checksum, tmp_info = self._get_dir_info_checksum(dir_info) - new_info = self.cache.checksum_to_path_info(checksum) - if self.cache.changed_cache_file(checksum): - self.cache.tree.makedirs(new_info.parent) - self.cache.tree.move(tmp_info, new_info, mode=self.CACHE_MODE) + def list_checksums(self, prefix=None, progress_callback=None): + """Iterate over remote checksums. - if path_info: - self.state.save(path_info, checksum) - self.state.save(new_info, checksum) + If `prefix` is specified, only checksums which begin with `prefix` + will be returned. + """ + for path in self.list_paths(prefix, progress_callback): + try: + yield self.path_to_checksum(path) + except ValueError: + logger.debug( + "'%s' doesn't look like a cache file, skipping", path + ) - return checksum + def all(self, jobs=None, name=None): + """Iterate over all checksums in the remote. - def _get_dir_info_checksum(self, dir_info): - tmp = tempfile.NamedTemporaryFile(delete=False).name - with open(tmp, "w+") as fobj: - json.dump(dir_info, fobj, sort_keys=True) + Checksums will be fetched in parallel threads according to prefix + (except for small remotes) and a progress bar will be displayed. 
+ """ + logger.debug( + "Fetching all checksums from '{}'".format( + name if name else "remote cache" + ) + ) - from_info = PathInfo(tmp) - to_info = self.cache.path_info / tmp_fname("") - self.cache.tree.upload(from_info, to_info, no_progress_bar=True) + if not self.CAN_TRAVERSE: + return self.list_checksums() - checksum = self.get_file_checksum(to_info) + self.CHECKSUM_DIR_SUFFIX - return checksum, to_info + remote_size, remote_checksums = self._estimate_remote_size(name=name) + return self._list_checksums_traverse( + remote_size, remote_checksums, jobs, name + ) - def get_dir_cache(self, checksum): - assert checksum + def checksums_exist(self, checksums, jobs=None, name=None): + """Check if the given checksums are stored in the remote. - dir_info = self._dir_info.get(checksum) - if dir_info: - return dir_info + There are two ways of performing this check: - try: - dir_info = self.load_dir_cache(checksum) - except DirCacheError: - dir_info = [] + - Traverse method: Get a list of all the files in the remote + (traversing the cache directory) and compare it with + the given checksums. Cache entries will be retrieved in parallel + threads according to prefix (i.e. entries starting with, "00...", + "01...", and so on) and a progress bar will be displayed. - self._dir_info[checksum] = dir_info - return dir_info + - Exists method: For each given checksum, run the `exists` + method and filter the checksums that aren't on the remote. + This is done in parallel threads. + It also shows a progress bar when performing the check. - def load_dir_cache(self, checksum): - path_info = self.checksum_to_path_info(checksum) + The reason for such an odd logic is that most of the remotes + take much shorter time to just retrieve everything they have under + a certain prefix (e.g. s3, gs, ssh, hdfs). Other remotes that can + check if particular file exists much quicker, use their own + implementation of checksums_exist (see ssh, local). - try: - with self.cache.open(path_info, "r") as fobj: - d = json.load(fobj) - except (ValueError, FileNotFoundError) as exc: - raise DirCacheError(checksum) from exc + Which method to use will be automatically determined after estimating + the size of the remote cache, and comparing the estimated size with + len(checksums). To estimate the size of the remote cache, we fetch + a small subset of cache entries (i.e. entries starting with "00..."). + Based on the number of entries in that subset, the size of the full + cache can be estimated, since the cache is evenly distributed according + to checksum. 
- if not isinstance(d, list): - logger.error( - "dir cache file format error '%s' [skipping the file]", - path_info, - ) - return [] + Returns: + A list with checksums that were found in the remote + """ + # Remotes which do not use traverse prefix should override + # checksums_exist() (see ssh, local) + assert self.TRAVERSE_PREFIX_LEN >= 2 - if self.tree.PATH_CLS == WindowsPathInfo: - # only need to convert it for Windows - for info in d: - # NOTE: here is a BUG, see comment to .as_posix() below - info[self.PARAM_RELPATH] = info[self.PARAM_RELPATH].replace( - "/", self.tree.PATH_CLS.sep + checksums = set(checksums) + indexed_checksums = set(self.index.intersection(checksums)) + checksums -= indexed_checksums + logger.debug( + "Matched '{}' indexed checksums".format(len(indexed_checksums)) + ) + if not checksums: + return indexed_checksums + + if len(checksums) == 1 or not self.CAN_TRAVERSE: + remote_checksums = self._list_checksums_exists( + checksums, jobs, name + ) + return list(indexed_checksums) + remote_checksums + + # Max remote size allowed for us to use traverse method + remote_size, remote_checksums = self._estimate_remote_size( + checksums, name + ) + + traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE + # For sufficiently large remotes, traverse must be weighted to account + # for performance overhead from large lists/sets. + # From testing with S3, for remotes with 1M+ files, object_exists is + # faster until len(checksums) is at least 10k~100k + if remote_size > self.TRAVERSE_THRESHOLD_SIZE: + traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER + else: + traverse_weight = traverse_pages + if len(checksums) < traverse_weight: + logger.debug( + "Large remote ('{}' checksums < '{}' traverse weight), " + "using object_exists for remaining checksums".format( + len(checksums), traverse_weight + ) + ) + return ( + list(indexed_checksums) + + list(checksums & remote_checksums) + + self._list_checksums_exists( + checksums - remote_checksums, jobs, name ) + ) - return d + logger.debug( + "Querying '{}' checksums via traverse".format(len(checksums)) + ) + remote_checksums = set( + self._list_checksums_traverse( + remote_size, remote_checksums, jobs, name + ) + ) + return list(indexed_checksums) + list( + checksums & set(remote_checksums) + ) - @classmethod - def is_dir_checksum(cls, checksum): - if not checksum: - return False - return checksum.endswith(cls.CHECKSUM_DIR_SUFFIX) + def _checksums_with_limit( + self, limit, prefix=None, progress_callback=None + ): + count = 0 + for checksum in self.list_checksums(prefix, progress_callback): + yield checksum + count += 1 + if count > limit: + logger.debug( + "`list_checksums()` returned max '{}' checksums, " + "skipping remaining results".format(limit) + ) + return - def get_checksum(self, path_info): - assert isinstance(path_info, str) or path_info.scheme == self.scheme + def _max_estimation_size(self, checksums): + # Max remote size allowed for us to use traverse method + return max( + self.TRAVERSE_THRESHOLD_SIZE, + len(checksums) + / self.TRAVERSE_WEIGHT_MULTIPLIER + * self.LIST_OBJECT_PAGE_SIZE, + ) - if not self.tree.exists(path_info): - return None + def _estimate_remote_size(self, checksums=None, name=None): + """Estimate remote cache size based on number of entries beginning with + "00..." prefix. 
+ """ + prefix = "0" * self.TRAVERSE_PREFIX_LEN + total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) + if checksums: + max_checksums = self._max_estimation_size(checksums) + else: + max_checksums = None - checksum = self.state.get(path_info) + with Tqdm( + desc="Estimating size of " + + (f"cache in '{name}'" if name else "remote cache"), + unit="file", + ) as pbar: - # If we have dir checksum in state db, but dir cache file is lost, - # then we need to recollect the dir via .get_dir_checksum() call below, - # see https://github.com/iterative/dvc/issues/2219 for context - if ( - checksum - and self.is_dir_checksum(checksum) - and not self.tree.exists( - self.cache.checksum_to_path_info(checksum) - ) - ): - checksum = None + def update(n=1): + pbar.update(n * total_prefixes) - if checksum: - return checksum + if max_checksums: + checksums = self._checksums_with_limit( + max_checksums / total_prefixes, prefix, update + ) + else: + checksums = self.list_checksums(prefix, update) - if self.tree.isdir(path_info): - checksum = self.get_dir_checksum(path_info) + remote_checksums = set(checksums) + if remote_checksums: + remote_size = total_prefixes * len(remote_checksums) + else: + remote_size = total_prefixes + logger.debug(f"Estimated remote size: {remote_size} files") + return remote_size, remote_checksums + + def _list_checksums_traverse( + self, remote_size, remote_checksums, jobs=None, name=None + ): + """Iterate over all checksums in the remote cache. + Checksums are fetched in parallel according to prefix, except in + cases where the remote size is very small. + + All checksums from the remote (including any from the size + estimation step passed via the `remote_checksums` argument) will be + returned. + + NOTE: For large remotes the list of checksums will be very + big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list) + and we don't really need all of it at the same time, so it makes + sense to use a generator to gradually iterate over it, without + keeping all of it in memory. + """ + num_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE + if num_pages < 256 / self.JOBS: + # Fetching prefixes in parallel requires at least 255 more + # requests, for small enough remotes it will be faster to fetch + # entire cache without splitting it into prefixes. 
+ # + # NOTE: this ends up re-fetching checksums that were already + # fetched during remote size estimation + traverse_prefixes = [None] + initial = 0 else: - checksum = self.get_file_checksum(path_info) + yield from remote_checksums + initial = len(remote_checksums) + traverse_prefixes = [f"{i:02x}" for i in range(1, 256)] + if self.TRAVERSE_PREFIX_LEN > 2: + traverse_prefixes += [ + "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN) + for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2)) + ] + with Tqdm( + desc="Querying " + + (f"cache in '{name}'" if name else "remote cache"), + total=remote_size, + initial=initial, + unit="file", + ) as pbar: - if checksum: - self.state.save(path_info, checksum) + def list_with_update(prefix): + return list( + self.list_checksums( + prefix=prefix, progress_callback=pbar.update + ) + ) - return checksum + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + in_remote = executor.map(list_with_update, traverse_prefixes,) + yield from itertools.chain.from_iterable(in_remote) + + def _list_checksums_exists(self, checksums, jobs=None, name=None): + logger.debug( + "Querying {} checksums via object_exists".format(len(checksums)) + ) + with Tqdm( + desc="Querying " + + ("cache in " + name if name else "remote cache"), + total=len(checksums), + unit="file", + ) as pbar: - def save_info(self, path_info): - return {self.PARAM_CHECKSUM: self.get_checksum(path_info)} + def exists_with_progress(path_info): + ret = self.tree.exists(path_info) + pbar.update_msg(str(path_info)) + return ret + + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + path_infos = map(self.checksum_to_path_info, checksums) + in_remote = executor.map(exists_with_progress, path_infos) + ret = list(itertools.compress(checksums, in_remote)) + return ret + + @index_locked + def gc(self, named_cache, jobs=None): + used = set(named_cache.scheme_keys("local")) + + if self.scheme != "": + used.update(named_cache.scheme_keys(self.scheme)) + + removed = False + # checksums must be sorted to ensure we always remove .dir files first + for checksum in sorted( + self.all(jobs, str(self.path_info)), + key=self.is_dir_checksum, + reverse=True, + ): + if checksum in used: + continue + path_info = self.checksum_to_path_info(checksum) + if self.is_dir_checksum(checksum): + # backward compatibility + self._remove_unpacked_dir(checksum) + self.tree.remove(path_info) + removed = True + if removed: + self.index.clear() + return removed + + def _remove_unpacked_dir(self, checksum): + pass + + +class CacheMixin: + """BaseRemote extensions for cache link/checkout operations.""" + + # Override to return path as a string instead of PathInfo for clouds + # which support string paths (see local) + def checksum_to_path(self, checksum): + return self.checksum_to_path_info(checksum) + + def get_dir_cache(self, checksum): + assert checksum + + dir_info = self._dir_info.get(checksum) + if dir_info: + return dir_info + + try: + dir_info = self.load_dir_cache(checksum) + except DirCacheError: + dir_info = [] + + self._dir_info[checksum] = dir_info + return dir_info + + def load_dir_cache(self, checksum): + path_info = self.checksum_to_path_info(checksum) + + try: + with self.cache.open(path_info, "r") as fobj: + d = json.load(fobj) + except (ValueError, FileNotFoundError) as exc: + raise DirCacheError(checksum) from exc + + if not isinstance(d, list): + logger.error( + "dir cache file format error '%s' [skipping the file]", + path_info, + ) + return [] + + if self.tree.PATH_CLS == 
WindowsPathInfo: + # only need to convert it for Windows + for info in d: + # NOTE: here is a BUG, see comment to .as_posix() below + info[self.PARAM_RELPATH] = info[self.PARAM_RELPATH].replace( + "/", self.tree.PATH_CLS.sep + ) + + return d def changed(self, path_info, checksum_info): """Checks if data has changed. @@ -669,25 +1025,11 @@ def _do_link(self, from_info, to_info, link_method): "Created '%s': %s -> %s", self.cache_types[0], from_info, to_info, ) - def _save_file( - self, path_info, checksum, save_link=True, tree=None, **kwargs - ): + def _save_file(self, path_info, tree, checksum, save_link=True, **kwargs): assert checksum cache_info = self.checksum_to_path_info(checksum) - if tree: - if self.changed_cache(checksum): - with tree.open(path_info, mode="rb") as fobj: - # if tree has fetch enabled, DVC out will be fetched on - # open and we do not need to read/copy any data - if not ( - tree.isdvc(path_info, strict=False) and tree.fetch - ): - self.tree.copy_fobj(fobj, cache_info) - callback = kwargs.get("download_callback") - if callback: - callback(1) - else: + if tree == self.tree: if self.changed_cache(checksum): self.tree.move(path_info, cache_info, mode=self.CACHE_MODE) self.link(cache_info, path_info) @@ -702,12 +1044,23 @@ def _save_file( if save_link: self.state.save_link(path_info) - - # we need to update path and cache, since in case of reflink, - # or copy cache type moving original file results in updates on - # next executed command, which causes md5 recalculation - if not tree or is_working_tree(tree): + # we need to update path and cache, since in case of reflink, + # or copy cache type moving original file results in updates on + # next executed command, which causes md5 recalculation self.state.save(path_info, checksum) + else: + if self.changed_cache(checksum): + with tree.open(path_info, mode="rb") as fobj: + # if tree has fetch enabled, DVC out will be fetched on + # open and we do not need to read/copy any data + if not ( + tree.isdvc(path_info, strict=False) and tree.fetch + ): + self.tree.copy_fobj(fobj, cache_info) + callback = kwargs.get("download_callback") + if callback: + callback(1) + self.state.save(cache_info, checksum) return {self.PARAM_CHECKSUM: checksum} @@ -733,169 +1086,46 @@ def _cache_is_copy(self, path_info): self.cache_type_confirmed = True return self.cache_types[0] == "copy" - def _save_dir( - self, path_info, checksum, save_link=True, tree=None, **kwargs - ): - if tree: - dir_info = self._collect_dir( - path_info, tree=tree, save_tree=True, **kwargs - ) - checksum = self._save_dir_info(dir_info) - else: - dir_info = self.get_dir_cache(checksum) - - for entry in Tqdm( - dir_info, desc="Saving " + path_info.name, unit="file" - ): - entry_info = path_info / entry[self.PARAM_RELPATH] - entry_checksum = entry[self.PARAM_CHECKSUM] - self._save_file(entry_info, entry_checksum, save_link=False) - - if save_link: - self.state.save_link(path_info) - - cache_info = self.checksum_to_path_info(checksum) - self.state.save(cache_info, checksum) - if not tree or is_working_tree(tree): - self.state.save(path_info, checksum) - return {self.PARAM_CHECKSUM: checksum} - - @staticmethod - def protect(path_info): - pass - - def save( - self, path_info, checksum_info, save_link=True, tree=None, **kwargs - ): - if path_info.scheme != self.scheme: - raise RemoteActionNotImplemented( - f"save {path_info.scheme} -> {self.scheme}", self.scheme, - ) - - if tree: - if tree.isdir(path_info): - # save checksum will be computed during tree walk - checksum = None - else: 
- checksum = tree.get_file_checksum(path_info) - else: - checksum = checksum_info[self.PARAM_CHECKSUM] - return self._save(path_info, checksum, save_link, tree, **kwargs) - - def _save(self, path_info, checksum, save_link=True, tree=None, **kwargs): - if tree: - logger.debug("Saving tree path '%s' to cache.", path_info) - else: - to_info = self.checksum_to_path_info(checksum) - logger.debug("Saving '%s' to '%s'.", path_info, to_info) - - if tree: - isdir = tree.isdir - save_link = False - else: - isdir = self.tree.isdir - - if isdir(path_info): - return self._save_dir( - path_info, checksum, save_link, tree, **kwargs - ) - return self._save_file(path_info, checksum, save_link, tree, **kwargs) - - def open(self, *args, **kwargs): - return self.tree.open(*args, **kwargs) - - def path_to_checksum(self, path): - parts = self.tree.PATH_CLS(path).parts[-2:] - - if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): - raise ValueError(f"Bad cache file path '{path}'") - - return "".join(parts) - - def checksum_to_path_info(self, checksum): - return self.path_info / checksum[0:2] / checksum[2:] - - # Return path as a string instead of PathInfo for remotes which support - # string paths (see local) - checksum_to_path = checksum_to_path_info - - def list_cache_paths(self, prefix=None, progress_callback=None): - if prefix: - if len(prefix) > 2: - path_info = self.path_info / prefix[:2] / prefix[2:] - else: - path_info = self.path_info / prefix[:2] - else: - path_info = self.path_info - if progress_callback: - for file_info in self.tree.walk_files(path_info): - progress_callback() - yield file_info.path - else: - yield from self.tree.walk_files(path_info) - - def cache_checksums(self, prefix=None, progress_callback=None): - """Iterate over remote cache checksums. - - If `prefix` is specified, only checksums which begin with `prefix` - will be returned. - """ - for path in self.list_cache_paths(prefix, progress_callback): - try: - yield self.path_to_checksum(path) - except ValueError: - logger.debug( - "'%s' doesn't look like a cache file, skipping", path - ) - - def all(self, jobs=None, name=None): - """Iterate over all checksums in the remote cache. - - Checksums will be fetched in parallel threads according to prefix - (except for small remotes) and a progress bar will be displayed. 
- """ - logger.debug( - "Fetching all checksums from '{}'".format( - name if name else "remote cache" + def _save_dir(self, path_info, tree, checksum, save_link=True, **kwargs): + dir_info = self.get_dir_cache(checksum) + for entry in Tqdm( + dir_info, desc="Saving " + path_info.name, unit="file" + ): + entry_info = path_info / entry[self.PARAM_RELPATH] + entry_checksum = entry[self.PARAM_CHECKSUM] + self._save_file( + entry_info, tree, entry_checksum, save_link=False, **kwargs ) - ) - if not self.CAN_TRAVERSE: - return self.cache_checksums() + if save_link: + self.state.save_link(path_info) + if self.tree.exists(path_info): + self.state.save(path_info, checksum) - remote_size, remote_checksums = self._estimate_cache_size(name=name) - return self._cache_checksums_traverse( - remote_size, remote_checksums, jobs, name - ) + cache_info = self.checksum_to_path_info(checksum) + self.state.save(cache_info, checksum) + return {self.PARAM_CHECKSUM: checksum} - @index_locked - def gc(self, named_cache, jobs=None): - used = set(named_cache.scheme_keys("local")) + def save(self, path_info, tree, checksum_info, save_link=True, **kwargs): + if path_info.scheme != self.scheme: + raise RemoteActionNotImplemented( + f"save {path_info.scheme} -> {self.scheme}", self.scheme, + ) - if self.scheme != "": - used.update(named_cache.scheme_keys(self.scheme)) + if not checksum_info: + checksum_info = self.save_info(path_info, tree=tree, **kwargs) + checksum = checksum_info[self.PARAM_CHECKSUM] + return self._save(path_info, tree, checksum, save_link, **kwargs) - removed = False - # checksums must be sorted to ensure we always remove .dir files first - for checksum in sorted( - self.all(jobs, str(self.path_info)), - key=self.is_dir_checksum, - reverse=True, - ): - if checksum in used: - continue - path_info = self.checksum_to_path_info(checksum) - if self.is_dir_checksum(checksum): - # backward compatibility - self._remove_unpacked_dir(checksum) - self.tree.remove(path_info) - removed = True - if removed: - self.index.clear() - return removed + def _save(self, path_info, tree, checksum, save_link=True, **kwargs): + to_info = self.checksum_to_path_info(checksum) + logger.debug("Saving '%s' to '%s'.", path_info, to_info) - def is_protected(self, path_info): - return False + if tree.isdir(path_info): + return self._save_dir( + path_info, tree, checksum, save_link, **kwargs + ) + return self._save_file(path_info, tree, checksum, save_link, **kwargs) def changed_cache_file(self, checksum): """Compare the given checksum with the (corresponding) actual one. @@ -965,232 +1195,6 @@ def changed_cache(self, checksum, path_info=None, filter_info=None): ) return self.changed_cache_file(checksum) - def cache_exists(self, checksums, jobs=None, name=None): - """Check if the given checksums are stored in the remote. - - There are two ways of performing this check: - - - Traverse method: Get a list of all the files in the remote - (traversing the cache directory) and compare it with - the given checksums. Cache entries will be retrieved in parallel - threads according to prefix (i.e. entries starting with, "00...", - "01...", and so on) and a progress bar will be displayed. - - - Exists method: For each given checksum, run the `exists` - method and filter the checksums that aren't on the remote. - This is done in parallel threads. - It also shows a progress bar when performing the check. 
- - The reason for such an odd logic is that most of the remotes - take much shorter time to just retrieve everything they have under - a certain prefix (e.g. s3, gs, ssh, hdfs). Other remotes that can - check if particular file exists much quicker, use their own - implementation of cache_exists (see ssh, local). - - Which method to use will be automatically determined after estimating - the size of the remote cache, and comparing the estimated size with - len(checksums). To estimate the size of the remote cache, we fetch - a small subset of cache entries (i.e. entries starting with "00..."). - Based on the number of entries in that subset, the size of the full - cache can be estimated, since the cache is evenly distributed according - to checksum. - - Returns: - A list with checksums that were found in the remote - """ - # Remotes which do not use traverse prefix should override - # cache_exists() (see ssh, local) - assert self.TRAVERSE_PREFIX_LEN >= 2 - - checksums = set(checksums) - indexed_checksums = set(self.index.intersection(checksums)) - checksums -= indexed_checksums - logger.debug( - "Matched '{}' indexed checksums".format(len(indexed_checksums)) - ) - if not checksums: - return indexed_checksums - - if len(checksums) == 1 or not self.CAN_TRAVERSE: - remote_checksums = self._cache_object_exists(checksums, jobs, name) - return list(indexed_checksums) + remote_checksums - - # Max remote size allowed for us to use traverse method - remote_size, remote_checksums = self._estimate_cache_size( - checksums, name - ) - - traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE - # For sufficiently large remotes, traverse must be weighted to account - # for performance overhead from large lists/sets. - # From testing with S3, for remotes with 1M+ files, object_exists is - # faster until len(checksums) is at least 10k~100k - if remote_size > self.TRAVERSE_THRESHOLD_SIZE: - traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER - else: - traverse_weight = traverse_pages - if len(checksums) < traverse_weight: - logger.debug( - "Large remote ('{}' checksums < '{}' traverse weight), " - "using object_exists for remaining checksums".format( - len(checksums), traverse_weight - ) - ) - return ( - list(indexed_checksums) - + list(checksums & remote_checksums) - + self._cache_object_exists( - checksums - remote_checksums, jobs, name - ) - ) - - logger.debug( - "Querying '{}' checksums via traverse".format(len(checksums)) - ) - remote_checksums = set( - self._cache_checksums_traverse( - remote_size, remote_checksums, jobs, name - ) - ) - return list(indexed_checksums) + list( - checksums & set(remote_checksums) - ) - - def _checksums_with_limit( - self, limit, prefix=None, progress_callback=None - ): - count = 0 - for checksum in self.cache_checksums(prefix, progress_callback): - yield checksum - count += 1 - if count > limit: - logger.debug( - "`cache_checksums()` returned max '{}' checksums, " - "skipping remaining results".format(limit) - ) - return - - def _max_estimation_size(self, checksums): - # Max remote size allowed for us to use traverse method - return max( - self.TRAVERSE_THRESHOLD_SIZE, - len(checksums) - / self.TRAVERSE_WEIGHT_MULTIPLIER - * self.LIST_OBJECT_PAGE_SIZE, - ) - - def _estimate_cache_size(self, checksums=None, name=None): - """Estimate remote cache size based on number of entries beginning with - "00..." prefix. 
- """ - prefix = "0" * self.TRAVERSE_PREFIX_LEN - total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) - if checksums: - max_checksums = self._max_estimation_size(checksums) - else: - max_checksums = None - - with Tqdm( - desc="Estimating size of " - + (f"cache in '{name}'" if name else "remote cache"), - unit="file", - ) as pbar: - - def update(n=1): - pbar.update(n * total_prefixes) - - if max_checksums: - checksums = self._checksums_with_limit( - max_checksums / total_prefixes, prefix, update - ) - else: - checksums = self.cache_checksums(prefix, update) - - remote_checksums = set(checksums) - if remote_checksums: - remote_size = total_prefixes * len(remote_checksums) - else: - remote_size = total_prefixes - logger.debug(f"Estimated remote size: {remote_size} files") - return remote_size, remote_checksums - - def _cache_checksums_traverse( - self, remote_size, remote_checksums, jobs=None, name=None - ): - """Iterate over all checksums in the remote cache. - Checksums are fetched in parallel according to prefix, except in - cases where the remote size is very small. - - All checksums from the remote (including any from the size - estimation step passed via the `remote_checksums` argument) will be - returned. - - NOTE: For large remotes the list of checksums will be very - big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list) - and we don't really need all of it at the same time, so it makes - sense to use a generator to gradually iterate over it, without - keeping all of it in memory. - """ - num_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE - if num_pages < 256 / self.JOBS: - # Fetching prefixes in parallel requires at least 255 more - # requests, for small enough remotes it will be faster to fetch - # entire cache without splitting it into prefixes. 
- # - # NOTE: this ends up re-fetching checksums that were already - # fetched during remote size estimation - traverse_prefixes = [None] - initial = 0 - else: - yield from remote_checksums - initial = len(remote_checksums) - traverse_prefixes = [f"{i:02x}" for i in range(1, 256)] - if self.TRAVERSE_PREFIX_LEN > 2: - traverse_prefixes += [ - "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN) - for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2)) - ] - with Tqdm( - desc="Querying " - + (f"cache in '{name}'" if name else "remote cache"), - total=remote_size, - initial=initial, - unit="file", - ) as pbar: - - def list_with_update(prefix): - return list( - self.cache_checksums( - prefix=prefix, progress_callback=pbar.update - ) - ) - - with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: - in_remote = executor.map(list_with_update, traverse_prefixes,) - yield from itertools.chain.from_iterable(in_remote) - - def _cache_object_exists(self, checksums, jobs=None, name=None): - logger.debug( - "Querying {} checksums via object_exists".format(len(checksums)) - ) - with Tqdm( - desc="Querying " - + ("cache in " + name if name else "remote cache"), - total=len(checksums), - unit="file", - ) as pbar: - - def exists_with_progress(path_info): - ret = self.tree.exists(path_info) - pbar.update_msg(str(path_info)) - return ret - - with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: - path_infos = map(self.checksum_to_path_info, checksums) - in_remote = executor.map(exists_with_progress, path_infos) - ret = list(itertools.compress(checksums, in_remote)) - return ret - def already_cached(self, path_info): current = self.get_checksum(path_info) @@ -1384,10 +1388,3 @@ def get_files_number(self, path_info, checksum, filter_info): filter_info.isin_or_eq(path_info / entry[self.PARAM_CHECKSUM]) for entry in self.get_dir_cache(checksum) ) - - @staticmethod - def unprotect(path_info): - pass - - def _remove_unpacked_dir(self, checksum): - pass diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 0204a8997e..77049929cf 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -545,6 +545,9 @@ def remove(self, path_info): item_id = self._get_item_id(path_info) self.gdrive_delete_file(item_id) + def get_file_checksum(self, path_info): + raise NotImplementedError + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): dirname = to_info.parent assert dirname @@ -562,11 +565,8 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): class GDriveRemote(BaseRemote): scheme = Schemes.GDRIVE REQUIRES = {"pydrive2": "pydrive2"} + TREE_CLS = GDriveRemoteTree DEFAULT_VERIFY = True # Always prefer traverse for GDrive since API usage quotas are a concern. 
TRAVERSE_WEIGHT_MULTIPLIER = 1 TRAVERSE_PREFIX_LEN = 2 - TREE_CLS = GDriveRemoteTree - - def get_file_checksum(self, path_info): - raise NotImplementedError diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 4d71198b77..341589739d 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -9,7 +9,7 @@ from dvc.exceptions import DvcException from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree +from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin from dvc.scheme import Schemes logger = logging.getLogger(__name__) @@ -157,6 +157,20 @@ def copy(self, from_info, to_info): to_bucket = self.gs.bucket(to_info.bucket) from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) + def get_file_checksum(self, path_info): + import base64 + import codecs + + bucket = path_info.bucket + path = path_info.path + blob = self.gs.bucket(bucket).get_blob(path) + if not blob: + return None + + b64_md5 = blob.md5_hash + md5 = base64.b64decode(b64_md5) + return codecs.getencoder("hex")(md5)[0].decode("utf-8") + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): bucket = self.gs.bucket(to_info.bucket) _upload_to_bucket( @@ -184,19 +198,9 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): class GSRemote(BaseRemote): scheme = Schemes.GS REQUIRES = {"google-cloud-storage": "google.cloud.storage"} - PARAM_CHECKSUM = "md5" TREE_CLS = GSRemoteTree + PARAM_CHECKSUM = "md5" - def get_file_checksum(self, path_info): - import base64 - import codecs - bucket = path_info.bucket - path = path_info.path - blob = self.tree.gs.bucket(bucket).get_blob(path) - if not blob: - return None - - b64_md5 = blob.md5_hash - md5 = base64.b64decode(b64_md5) - return codecs.getencoder("hex")(md5)[0].decode("utf-8") +class GSCache(GSRemote, CacheMixin): + pass diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index e39ebc318b..629b5767bd 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -11,7 +11,7 @@ from dvc.scheme import Schemes from dvc.utils import fix_env, tmp_fname -from .base import BaseRemote, BaseRemoteTree, RemoteCmdError +from .base import BaseRemote, BaseRemoteTree, CacheMixin, RemoteCmdError from .pool import get_connection logger = logging.getLogger(__name__) @@ -122,28 +122,6 @@ def copy(self, from_info, to_info, **_kwargs): self.remove(tmp_info) raise - def _upload(self, from_file, to_info, **_kwargs): - with self.hdfs(to_info) as hdfs: - hdfs.mkdir(posixpath.dirname(to_info.path)) - tmp_file = tmp_fname(to_info.path) - with open(from_file, "rb") as fobj: - hdfs.upload(tmp_file, fobj) - hdfs.rename(tmp_file, to_info.path) - - def _download(self, from_info, to_file, **_kwargs): - with self.hdfs(from_info) as hdfs: - with open(to_file, "wb+") as fobj: - hdfs.download(from_info.path, fobj) - - -class HDFSRemote(BaseRemote): - scheme = Schemes.HDFS - REGEX = r"^hdfs://((?P.*)@)?.*$" - PARAM_CHECKSUM = "checksum" - REQUIRES = {"pyarrow": "pyarrow"} - TRAVERSE_PREFIX_LEN = 2 - TREE_CLS = HDFSRemoteTree - def hadoop_fs(self, cmd, user=None): cmd = "hadoop fs -" + cmd if user: @@ -182,3 +160,29 @@ def get_file_checksum(self, path_info): f"checksum {path_info.path}", user=path_info.user ) return self._group(regex, stdout, "checksum") + + def _upload(self, from_file, to_info, **_kwargs): + with self.hdfs(to_info) as hdfs: + hdfs.mkdir(posixpath.dirname(to_info.path)) + tmp_file = tmp_fname(to_info.path) + with open(from_file, "rb") as fobj: + hdfs.upload(tmp_file, fobj) + 
hdfs.rename(tmp_file, to_info.path) + + def _download(self, from_info, to_file, **_kwargs): + with self.hdfs(from_info) as hdfs: + with open(to_file, "wb+") as fobj: + hdfs.download(from_info.path, fobj) + + +class HDFSRemote(BaseRemote): + scheme = Schemes.HDFS + REGEX = r"^hdfs://((?P.*)@)?.*$" + PARAM_CHECKSUM = "checksum" + REQUIRES = {"pyarrow": "pyarrow"} + TREE_CLS = HDFSRemoteTree + TRAVERSE_PREFIX_LEN = 2 + + +class HDFSCache(HDFSRemote, CacheMixin): + pass diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 154550bd16..941b6de25c 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -121,6 +121,25 @@ def request(self, method, url, **kwargs): def exists(self, path_info): return bool(self.request("HEAD", path_info.url)) + def get_file_checksum(self, path_info): + url = path_info.url + headers = self.request("HEAD", url).headers + etag = headers.get("ETag") or headers.get("Content-MD5") + + if not etag: + raise DvcException( + "could not find an ETag or " + "Content-MD5 header for '{url}'".format(url=url) + ) + + if etag.startswith("W/"): + raise DvcException( + "Weak ETags are not supported." + " (Etag: '{etag}', URL: '{url}')".format(etag=etag, url=url) + ) + + return etag + def _download(self, from_info, to_file, name=None, no_progress_bar=False): response = self.request("GET", from_info.url, stream=True) if response.status_code != 200: @@ -174,26 +193,7 @@ class HTTPRemote(BaseRemote): CAN_TRAVERSE = False TREE_CLS = HTTPRemoteTree - def get_file_checksum(self, path_info): - url = path_info.url - headers = self.tree.request("HEAD", url).headers - etag = headers.get("ETag") or headers.get("Content-MD5") - - if not etag: - raise DvcException( - "could not find an ETag or " - "Content-MD5 header for '{url}'".format(url=url) - ) - - if etag.startswith("W/"): - raise DvcException( - "Weak ETags are not supported." 
- " (Etag: '{etag}', URL: '{url}')".format(etag=etag, url=url) - ) - - return etag - - def list_cache_paths(self, prefix=None, progress_callback=None): + def list_paths(self, prefix=None, progress_callback=None): raise NotImplementedError def gc(self): diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 1b52f1f233..b778ff46c6 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -18,6 +18,7 @@ STATUS_NEW, BaseRemote, BaseRemoteTree, + CacheMixin, index_locked, ) from dvc.remote.index import RemoteIndexNoop @@ -43,7 +44,8 @@ class LocalRemoteTree(BaseRemoteTree): def __init__(self, remote, config): super().__init__(remote, config) - self.path_info = config.get("url") + url = config.get("url") + self.path_info = self.PATH_CLS(url) if url else None @property def repo(self): @@ -193,6 +195,9 @@ def reflink(self, from_info, to_info): os.chmod(tmp_info, self.file_mode) os.rename(tmp_info, to_info) + def get_file_checksum(self, path_info): + return file_md5(path_info)[0] + @staticmethod def getsize(path_info): return os.path.getsize(path_info) @@ -241,26 +246,119 @@ def wrapper(from_info, to_info, *args, **kwargs): class LocalRemote(BaseRemote): scheme = Schemes.LOCAL - PARAM_CHECKSUM = "md5" - PARAM_PATH = "path" - TRAVERSE_PREFIX_LEN = 2 INDEX_CLS = RemoteIndexNoop TREE_CLS = LocalRemoteTree - UNPACKED_DIR_SUFFIX = ".unpacked" - + PARAM_CHECKSUM = "md5" + PARAM_PATH = "path" DEFAULT_CACHE_TYPES = ["reflink", "copy"] + TRAVERSE_PREFIX_LEN = 2 + UNPACKED_DIR_SUFFIX = ".unpacked" CACHE_MODE = 0o444 - def __init__(self, repo, config): - super().__init__(repo, config) - self.cache_dir = config.get("url") - @property def state(self): return self.repo.state + def get(self, md5): + if not md5: + return None + + return self.checksum_to_path_info(md5).url + + def _unprotect_file(self, path): + if System.is_symlink(path) or System.is_hardlink(path): + logger.debug(f"Unprotecting '{path}'") + tmp = os.path.join(os.path.dirname(path), "." + uuid()) + + # The operations order is important here - if some application + # would access the file during the process of copyfile then it + # would get only the part of file. So, at first, the file should be + # copied with the temporary name, and then original file should be + # replaced by new. + copyfile(path, tmp, name="Unprotecting '{}'".format(relpath(path))) + remove(path) + os.rename(tmp, path) + + else: + logger.debug( + "Skipping copying for '{}', since it is not " + "a symlink or a hardlink.".format(path) + ) + + os.chmod(path, self.tree.file_mode) + + def _unprotect_dir(self, path): + assert is_working_tree(self.repo.tree) + + for fname in self.repo.tree.walk_files(path): + self._unprotect_file(fname) + + def unprotect(self, path_info): + path = path_info.fspath + if not os.path.exists(path): + raise DvcException(f"can't unprotect non-existing data '{path}'") + + if os.path.isdir(path): + self._unprotect_dir(path) + else: + self._unprotect_file(path) + + def protect(self, path_info): + path = os.fspath(path_info) + mode = self.CACHE_MODE + + try: + os.chmod(path, mode) + except OSError as exc: + # There is nothing we need to do in case of a read-only file system + if exc.errno == errno.EROFS: + return + + # In shared cache scenario, we might not own the cache file, so we + # need to check if cache file is already protected. 
+ if exc.errno not in [errno.EPERM, errno.EACCES]: + raise + + actual = stat.S_IMODE(os.stat(path).st_mode) + if actual != mode: + raise + + def is_protected(self, path_info): + try: + mode = os.stat(path_info).st_mode + except FileNotFoundError: + return False + + return stat.S_IMODE(mode) == self.CACHE_MODE + + def list_paths(self, prefix=None, progress_callback=None): + assert self.path_info is not None + if prefix: + path_info = self.path_info / prefix[:2] + if not self.tree.exists(path_info): + return + else: + path_info = self.path_info + if progress_callback: + for path in walk_files(path_info): + progress_callback() + yield path + else: + yield from walk_files(path_info) + + def _remove_unpacked_dir(self, checksum): + info = self.checksum_to_path_info(checksum) + path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) + self.tree.remove(path_info) + + +class LocalCache(LocalRemote, CacheMixin): + def __init__(self, repo, config): + super().__init__(repo, config) + self.cache_dir = config.get("url") + @property def cache_dir(self): return self.tree.path_info.fspath if self.tree.path_info else None @@ -285,26 +383,17 @@ def checksum_to_path(self, checksum): f"{self.cache_path}{os.sep}{checksum[0:2]}{os.sep}{checksum[2:]}" ) - def list_cache_paths(self, prefix=None, progress_callback=None): - assert self.path_info is not None - if prefix: - path_info = self.path_info / prefix[:2] - if not self.tree.exists(path_info): - return - else: - path_info = self.path_info - if progress_callback: - for path in walk_files(path_info): - progress_callback() - yield path - else: - yield from walk_files(path_info) - - def get(self, md5): - if not md5: - return None - - return self.checksum_to_path_info(md5).url + def checksums_exist(self, checksums, jobs=None, name=None): + return [ + checksum + for checksum in Tqdm( + checksums, + unit="file", + desc="Querying " + + ("cache in " + name if name else "local cache"), + ) + if not self.changed_cache_file(checksum) + ] def already_cached(self, path_info): assert path_info.scheme in ["", "local"] @@ -322,21 +411,6 @@ def _verify_link(self, path_info, link_type): super()._verify_link(path_info, link_type) - def get_file_checksum(self, path_info): - return file_md5(path_info)[0] - - def cache_exists(self, checksums, jobs=None, name=None): - return [ - checksum - for checksum in Tqdm( - checksums, - unit="file", - desc="Querying " - + ("cache in " + name if name else "local cache"), - ) - if not self.changed_cache_file(checksum) - ] - @index_locked def status( self, @@ -376,7 +450,7 @@ def _status( logger.debug("Collecting information from local cache...") local_exists = frozenset( - self.cache_exists(md5s, jobs=jobs, name=self.cache_dir) + self.checksums_exist(md5s, jobs=jobs, name=self.cache_dir) ) # This is a performance optimization. 
We can safely assume that, @@ -396,7 +470,7 @@ def _status( md5s.difference_update(remote_exists) if md5s: remote_exists.update( - remote.cache_exists( + remote.checksums_exist( md5s, jobs=jobs, name=str(remote.path_info) ) ) @@ -439,7 +513,7 @@ def _indexed_dir_checksums(self, named_cache, remote, dir_md5s): indexed_dir_exists = set() if indexed_dirs: indexed_dir_exists.update( - remote._cache_object_exists(indexed_dirs) + remote._list_checksums_exists(indexed_dirs) ) missing_dirs = indexed_dirs.difference(indexed_dir_exists) if missing_dirs: @@ -451,7 +525,7 @@ def _indexed_dir_checksums(self, named_cache, remote, dir_md5s): # Check if non-indexed (new) dir checksums exist on remote dir_exists = dir_md5s.intersection(indexed_dir_exists) - dir_exists.update(remote._cache_object_exists(dir_md5s - dir_exists)) + dir_exists.update(remote._list_checksums_exists(dir_md5s - dir_exists)) # If .dir checksum exists on the remote, assume directory contents # still exists on the remote @@ -658,74 +732,3 @@ def _log_missing_caches(checksum_info_dict): "nor on remote. Missing cache files: {}".format(missing_desc) ) logger.warning(msg) - - def _unprotect_file(self, path): - if System.is_symlink(path) or System.is_hardlink(path): - logger.debug(f"Unprotecting '{path}'") - tmp = os.path.join(os.path.dirname(path), "." + uuid()) - - # The operations order is important here - if some application - # would access the file during the process of copyfile then it - # would get only the part of file. So, at first, the file should be - # copied with the temporary name, and then original file should be - # replaced by new. - copyfile(path, tmp, name="Unprotecting '{}'".format(relpath(path))) - remove(path) - os.rename(tmp, path) - - else: - logger.debug( - "Skipping copying for '{}', since it is not " - "a symlink or a hardlink.".format(path) - ) - - os.chmod(path, self.tree.file_mode) - - def _unprotect_dir(self, path): - assert is_working_tree(self.repo.tree) - - for fname in self.repo.tree.walk_files(path): - self._unprotect_file(fname) - - def unprotect(self, path_info): - path = path_info.fspath - if not os.path.exists(path): - raise DvcException(f"can't unprotect non-existing data '{path}'") - - if os.path.isdir(path): - self._unprotect_dir(path) - else: - self._unprotect_file(path) - - def protect(self, path_info): - path = os.fspath(path_info) - mode = self.CACHE_MODE - - try: - os.chmod(path, mode) - except OSError as exc: - # There is nothing we need to do in case of a read-only file system - if exc.errno == errno.EROFS: - return - - # In shared cache scenario, we might not own the cache file, so we - # need to check if cache file is already protected. 
diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py
index 642a743bb5..5ed0d67091 100644
--- a/dvc/remote/s3.py
+++ b/dvc/remote/s3.py
@@ -8,7 +8,7 @@
 from dvc.exceptions import DvcException, ETagMismatchError
 from dvc.path_info import CloudURLInfo
 from dvc.progress import Tqdm
-from dvc.remote.base import BaseRemote, BaseRemoteTree
+from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin
 from dvc.scheme import Schemes
 
 logger = logging.getLogger(__name__)
@@ -305,6 +305,9 @@ def _copy(cls, s3, from_info, to_info, extra_args):
         if etag != cached_etag:
             raise ETagMismatchError(etag, cached_etag)
 
+    def get_file_checksum(self, path_info):
+        return self.get_etag(self.s3, path_info.bucket, path_info.path)
+
     def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
         total = os.path.getsize(from_file)
         with Tqdm(
@@ -339,7 +342,6 @@ class S3Remote(BaseRemote):
     PARAM_CHECKSUM = "etag"
     TREE_CLS = S3RemoteTree
 
-    def get_file_checksum(self, path_info):
-        return self.tree.get_etag(
-            self.tree.s3, path_info.bucket, path_info.path
-        )
+
+class S3Cache(S3Remote, CacheMixin):
+    pass
diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py
index 54a968d4f4..f5784f0ab3 100644
--- a/dvc/remote/ssh/__init__.py
+++ b/dvc/remote/ssh/__init__.py
@@ -14,7 +14,7 @@
 
 import dvc.prompt as prompt
 from dvc.progress import Tqdm
-from dvc.remote.base import BaseRemote, BaseRemoteTree
+from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin
 from dvc.remote.pool import get_connection
 from dvc.scheme import Schemes
 from dvc.utils import to_chunks
@@ -225,6 +225,13 @@ def reflink(self, from_info, to_info):
         with self.ssh(from_info) as ssh:
             ssh.reflink(from_info.path, to_info.path)
 
+    def get_file_checksum(self, path_info):
+        if path_info.scheme != self.scheme:
+            raise NotImplementedError
+
+        with self.ssh(path_info) as ssh:
+            return ssh.md5(path_info.path)
+
     def getsize(self, path_info):
         with self.ssh(path_info) as ssh:
             return ssh.getsize(path_info.path)
@@ -253,26 +260,18 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
 
 
 class SSHRemote(BaseRemote):
     scheme = Schemes.SSH
     REQUIRES = {"paramiko": "paramiko"}
-    JOBS = 4
+    TREE_CLS = SSHRemoteTree
+    PARAM_CHECKSUM = "md5"
 
     # At any given time some of the connections will go over network and
     # paramiko stuff, so we would ideally have it double of server processors.
     # We use conservative setting of 4 instead to not exhaust max sessions.
     CHECKSUM_JOBS = 4
-    TRAVERSE_PREFIX_LEN = 2
-    TREE_CLS = SSHRemoteTree
-    DEFAULT_CACHE_TYPES = ["copy"]
+    TRAVERSE_PREFIX_LEN = 2
 
-    def get_file_checksum(self, path_info):
-        if path_info.scheme != self.scheme:
-            raise NotImplementedError
-
-        with self.tree.ssh(path_info) as ssh:
-            return ssh.md5(path_info.path)
-
-    def list_cache_paths(self, prefix=None, progress_callback=None):
+    def list_paths(self, prefix=None, progress_callback=None):
         if prefix:
             root = posixpath.join(self.path_info.path, prefix[:2])
         else:
@@ -316,7 +315,7 @@ def _exists(chunk_and_channel):
 
         return results
 
-    def cache_exists(self, checksums, jobs=None, name=None):
+    def checksums_exist(self, checksums, jobs=None, name=None):
         """This is older implementation used in remote/base.py
         We are reusing it in RemoteSSH, because SSH's batch_exists proved to be
         faster than current approach (relying on exists(path_info)) applied in
@@ -345,3 +344,7 @@ def exists_with_progress(chunks):
         in_remote = itertools.chain.from_iterable(results)
         ret = list(itertools.compress(checksums, in_remote))
         return ret
+
+
+class SSHCache(SSHRemote, CacheMixin):
+    pass
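The S3 and SSH changes follow one pattern: the per-scheme `get_file_checksum` moves onto the tree class (ETag for S3, remote `md5` for SSH), and the cache-capable variant is just the remote class plus `CacheMixin`. An illustrative call site under that split — assuming an existing `remote` object and `path_info`, as in the tests further down — would look like this:

```python
# Before this patch the remote computed per-scheme checksums itself:
#     checksum = remote.get_file_checksum(path_info)
# After it, the tree owns that responsibility (ETag on S3, md5 over SSH):
checksum = remote.tree.get_file_checksum(path_info)

# The cache-capable class adds no code of its own:
#     class S3Cache(S3Remote, CacheMixin):
#         pass
```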
diff --git a/tests/func/remote/test_index.py b/tests/func/remote/test_index.py
index 81691bf091..517d1ec9ed 100644
--- a/tests/func/remote/test_index.py
+++ b/tests/func/remote/test_index.py
@@ -15,14 +15,14 @@ def remote(tmp_dir, dvc, tmp_path_factory, mocker):
     dvc.config["remote"]["upstream"] = {"url": url}
     dvc.config["core"]["remote"] = "upstream"
 
-    # patch cache_exists since the RemoteLOCAL normally overrides
-    # RemoteBASE.cache_exists.
-    def cache_exists(self, *args, **kwargs):
-        return BaseRemote.cache_exists(self, *args, **kwargs)
+    # patch checksums_exist since the LocalRemote normally overrides
+    # BaseRemote.checksums_exist.
+    def checksums_exist(self, *args, **kwargs):
+        return BaseRemote.checksums_exist(self, *args, **kwargs)
 
-    mocker.patch.object(LocalRemote, "cache_exists", cache_exists)
+    mocker.patch.object(LocalRemote, "checksums_exist", checksums_exist)
 
-    # patch index class since RemoteLOCAL normally overrides index class
+    # patch index class since LocalRemote normally overrides index class
     mocker.patch.object(LocalRemote, "INDEX_CLS", RemoteIndex)
 
     return dvc.cloud.get_remote("upstream")
diff --git a/tests/func/test_add.py b/tests/func/test_add.py
index f88564420f..1669974d34 100644
--- a/tests/func/test_add.py
+++ b/tests/func/test_add.py
@@ -340,7 +340,7 @@ def test(self):
 
 
 def test_should_collect_dir_cache_only_once(mocker, tmp_dir, dvc):
     tmp_dir.gen({"data/data": "foo"})
-    get_dir_checksum_counter = mocker.spy(LocalRemote, "get_dir_checksum")
+    get_dir_checksum_counter = mocker.spy(LocalRemoteTree, "get_dir_checksum")
     ret = main(["add", "data"])
     assert ret == 0
diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py
index 848074d1f6..b417cd8e0d 100644
--- a/tests/func/test_checkout.py
+++ b/tests/func/test_checkout.py
@@ -16,7 +16,7 @@
     NoOutputOrStageError,
 )
 from dvc.main import main
-from dvc.remote import S3Remote
+from dvc.remote import S3Cache, S3Remote
 from dvc.remote.local import LocalRemote
 from dvc.repo import Repo as DvcRepo
 from dvc.stage import Stage
@@ -755,7 +755,7 @@ def test_checkout_recursive(tmp_dir, dvc):
     not S3.should_test(), reason="Only run with S3 credentials"
 )
 def test_checkout_for_external_outputs(tmp_dir, dvc):
-    dvc.cache.s3 = S3Remote(dvc, {"url": S3.get_url()})
+    dvc.cache.s3 = S3Cache(dvc, {"url": S3.get_url()})
 
     remote = S3Remote(dvc, {"url": S3.get_url()})
     file_path = remote.path_info / "foo"
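In the checkout test above, the external cache and the plain remote are now built from different classes even though both point at S3. A sketch of that setup with placeholder URLs (the real test uses `S3.get_url()` and a live `dvc` repo object):

```python
from dvc.remote import S3Cache, S3Remote

# Checksum/link-aware cache on S3 (bucket/prefix is a placeholder):
dvc.cache.s3 = S3Cache(dvc, {"url": "s3://bucket/cache"})
# Plain upload/download remote holding the external output:
remote = S3Remote(dvc, {"url": "s3://bucket/data"})
```

Keeping the two constructions separate makes it explicit which object is expected to carry cache behaviour layered in via `CacheMixin`.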
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index 9783fbd7a7..02611f4e15 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -24,6 +24,7 @@
     SSHRemote,
 )
 from dvc.remote.base import STATUS_DELETED, STATUS_NEW, STATUS_OK
+from dvc.remote.local import LocalRemoteTree
 from dvc.stage.exceptions import StageNotFound
 from dvc.utils import file_md5
 from dvc.utils.fs import remove
@@ -547,7 +548,9 @@ def _get_cloud_class(self):
 
     def _prepare_repo(self):
         remote = self.cloud.get_remote()
-        self.main(["remote", "add", "-d", TEST_REMOTE, remote.cache_dir])
+        self.main(
+            ["remote", "add", "-d", TEST_REMOTE, remote.path_info.fspath]
+        )
         self.dvc.add(self.DATA)
         self.dvc.add(self.DATA_SUB)
 
@@ -614,7 +617,7 @@ def test(self):
 
 def test_checksum_recalculation(mocker, dvc, tmp_dir):
     tmp_dir.gen({"foo": "foo"})
-    test_get_file_checksum = mocker.spy(LocalRemote, "get_file_checksum")
+    test_get_file_checksum = mocker.spy(LocalRemoteTree, "get_file_checksum")
     url = Local.get_url()
     ret = main(["remote", "add", "-d", TEST_REMOTE, url])
     assert ret == 0
@@ -693,7 +696,7 @@ def test_verify_checksums(
     remove("dir")
     remove(dvc.cache.local.cache_dir)
 
-    checksum_spy = mocker.spy(dvc.cache.local, "get_file_checksum")
+    checksum_spy = mocker.spy(dvc.cache.local.tree, "get_file_checksum")
 
     dvc.pull()
     assert checksum_spy.call_count == 0
diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py
index 9b8a4cea8d..bb5c61ad23 100644
--- a/tests/func/test_remote.py
+++ b/tests/func/test_remote.py
@@ -10,7 +10,7 @@
 from dvc.exceptions import DownloadError, UploadError
 from dvc.main import main
 from dvc.path_info import PathInfo
-from dvc.remote.base import BaseRemote, RemoteCacheRequiredError
+from dvc.remote.base import BaseRemoteTree, RemoteCacheRequiredError
 from dvc.remote.local import LocalRemoteTree
 from dvc.utils.fs import remove
 from tests.basic_env import TestDvc
@@ -151,24 +151,24 @@ def test_dir_checksum_should_be_key_order_agnostic(tmp_dir, dvc):
 
     path_info = PathInfo("data")
     with dvc.state:
         with patch.object(
-            BaseRemote,
+            BaseRemoteTree,
             "_collect_dir",
             return_value=[
                 {"relpath": "1", "md5": "1"},
                 {"relpath": "2", "md5": "2"},
             ],
         ):
-            checksum1 = dvc.cache.local.get_dir_checksum(path_info)
+            checksum1 = dvc.cache.local.get_checksum(path_info)
 
         with patch.object(
-            BaseRemote,
+            BaseRemoteTree,
             "_collect_dir",
             return_value=[
                 {"md5": "1", "relpath": "1"},
                 {"md5": "2", "relpath": "2"},
             ],
        ):
-            checksum2 = dvc.cache.local.get_dir_checksum(path_info)
+            checksum2 = dvc.cache.local.get_checksum(path_info)
 
     assert checksum1 == checksum2
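`test_dir_checksum_should_be_key_order_agnostic` above now patches `_collect_dir` on `BaseRemoteTree` and calls `get_checksum` on the cache. The property it asserts — that the directory checksum is unaffected by the key order of the collected entries — comes from hashing a canonical serialization of the listing. The snippet below only illustrates that idea; it is not DVC's actual `.dir` checksum implementation:

```python
import hashlib
import json


def dir_checksum_sketch(entries):
    # Serialize the listing with a stable entry and key order, then hash it.
    canonical = json.dumps(
        sorted(entries, key=lambda e: e["relpath"]), sort_keys=True
    )
    return hashlib.md5(canonical.encode()).hexdigest() + ".dir"


assert dir_checksum_sketch(
    [{"relpath": "1", "md5": "1"}, {"relpath": "2", "md5": "2"}]
) == dir_checksum_sketch(
    [{"md5": "1", "relpath": "1"}, {"md5": "2", "relpath": "2"}]
)
```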
diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py
index 0b1f2d5bc9..58f93a5b0c 100644
--- a/tests/func/test_s3.py
+++ b/tests/func/test_s3.py
@@ -5,7 +5,7 @@
 import pytest
 from moto import mock_s3
 
-from dvc.remote.s3 import S3Remote, S3RemoteTree
+from dvc.remote.s3 import S3Cache, S3Remote, S3RemoteTree
 from tests.remotes import S3
 
 # from https://github.com/spulec/moto/blob/v1.3.5/tests/test_s3/test_s3.py#L40
@@ -54,16 +54,16 @@ def test_copy_singlepart_preserve_etag():
     ],
 )
 def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm):
-    remote = S3Remote(dvc, {"url": str(base_info.parent)})
-    s3 = remote.tree.s3
+    cache = S3Cache(dvc, {"url": str(base_info.parent)})
+    s3 = cache.tree.s3
     s3.create_bucket(Bucket=base_info.bucket)
     s3.put_object(
         Bucket=base_info.bucket, Key=(base_info / "from").path, Body="data"
     )
-    remote.link(base_info / "from", base_info / "to")
+    cache.link(base_info / "from", base_info / "to")
 
-    assert remote.tree.exists(base_info / "from")
-    assert remote.tree.exists(base_info / "to")
+    assert cache.tree.exists(base_info / "from")
+    assert cache.tree.exists(base_info / "to")
 
 
 @mock_s3
diff --git a/tests/func/test_tree.py b/tests/func/test_tree.py
index 1df0000d5e..daffd364e1 100644
--- a/tests/func/test_tree.py
+++ b/tests/func/test_tree.py
@@ -218,7 +218,7 @@ def test_repotree_cache_save(tmp_dir, dvc, scm, erepo_dir, setup_remote):
     with erepo_dir.dvc.state:
         cache = dvc.cache.local
         with cache.state:
-            cache.save(PathInfo(erepo_dir / "dir"), None, tree=tree)
+            cache.save(PathInfo(erepo_dir / "dir"), tree, None)
 
     for checksum in expected:
         assert os.path.exists(cache.checksum_to_path_info(checksum))
diff --git a/tests/unit/output/test_local.py b/tests/unit/output/test_local.py
index 00944524da..c4b214ecec 100644
--- a/tests/unit/output/test_local.py
+++ b/tests/unit/output/test_local.py
@@ -3,7 +3,7 @@
 from mock import patch
 
 from dvc.output import LocalOutput
-from dvc.remote.local import LocalRemote
+from dvc.remote.local import LocalCache
 from dvc.stage import Stage
 from dvc.utils import relpath
 from tests.basic_env import TestDvc
@@ -79,7 +79,7 @@ def test_return_0_on_no_cache(self):
 
     @patch.object(LocalOutput, "checksum", "12345678.dir")
     @patch.object(
-        LocalRemote,
+        LocalCache,
         "get_dir_cache",
         return_value=[{"md5": "asdf"}, {"md5": "qwe"}],
     )
diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py
index b49f2ee631..c9009e9f71 100644
--- a/tests/unit/remote/test_azure.py
+++ b/tests/unit/remote/test_azure.py
@@ -42,7 +42,7 @@ def test_get_file_checksum(tmp_dir):
     to_info = remote.tree.PATH_CLS(Azure.get_url())
     remote.tree.upload(PathInfo("foo"), to_info)
     assert remote.tree.exists(to_info)
-    checksum = remote.get_file_checksum(to_info)
+    checksum = remote.tree.get_file_checksum(to_info)
     assert checksum
     assert isinstance(checksum, str)
     assert checksum.strip("'").strip('"') == checksum
diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py
index 481f9f3f1f..9bbabb901c 100644
--- a/tests/unit/remote/test_base.py
+++ b/tests/unit/remote/test_base.py
@@ -41,18 +41,18 @@ def test_cmd_error(dvc):
         REMOTE_CLS(dvc, config).tree.remove("file")
 
 
-@mock.patch.object(BaseRemote, "_cache_checksums_traverse")
-@mock.patch.object(BaseRemote, "_cache_object_exists")
-def test_cache_exists(object_exists, traverse, dvc):
+@mock.patch.object(BaseRemote, "_list_checksums_traverse")
+@mock.patch.object(BaseRemote, "_list_checksums_exists")
+def test_checksums_exist(object_exists, traverse, dvc):
     remote = BaseRemote(dvc, {})
 
     # remote does not support traverse
     remote.CAN_TRAVERSE = False
 
     with mock.patch.object(
-        remote, "cache_checksums", return_value=list(range(256))
+        remote, "list_checksums", return_value=list(range(256))
     ):
         checksums = set(range(1000))
-        remote.cache_exists(checksums)
+        remote.checksums_exist(checksums)
         object_exists.assert_called_with(checksums, None, None)
         traverse.assert_not_called()
@@ -62,10 +62,10 @@ def test_cache_exists(object_exists, traverse, dvc):
     object_exists.reset_mock()
     traverse.reset_mock()
     with mock.patch.object(
-        remote, "cache_checksums", return_value=list(range(256))
+        remote, "list_checksums", return_value=list(range(256))
     ):
         checksums = list(range(1000))
-        remote.cache_exists(checksums)
+        remote.checksums_exist(checksums)
         # verify that _cache_paths_with_max() short circuits
         # before returning all 256 remote checksums
         max_checksums = math.ceil(
@@ -83,10 +83,10 @@
     traverse.reset_mock()
     remote.JOBS = 16
     with mock.patch.object(
-        remote, "cache_checksums", return_value=list(range(256))
+        remote, "list_checksums", return_value=list(range(256))
     ):
         checksums = list(range(1000000))
-        remote.cache_exists(checksums)
+        remote.checksums_exist(checksums)
         object_exists.assert_not_called()
         traverse.assert_called_with(
             256 * pow(16, remote.TRAVERSE_PREFIX_LEN),
@@ -97,44 +97,44 @@
 
 
 @mock.patch.object(
-    BaseRemote, "cache_checksums", return_value=[],
+    BaseRemote, "list_checksums", return_value=[],
 )
 @mock.patch.object(
     BaseRemote, "path_to_checksum", side_effect=lambda x: x,
 )
-def test_cache_checksums_traverse(path_to_checksum, cache_checksums, dvc):
+def test_list_checksums_traverse(path_to_checksum, list_checksums, dvc):
     remote = BaseRemote(dvc, {})
     remote.tree.path_info = PathInfo("foo")
 
     # parallel traverse
     size = 256 / remote.JOBS * remote.LIST_OBJECT_PAGE_SIZE
-    list(remote._cache_checksums_traverse(size, {0}))
+    list(remote._list_checksums_traverse(size, {0}))
     for i in range(1, 16):
-        cache_checksums.assert_any_call(
+        list_checksums.assert_any_call(
             prefix=f"{i:03x}", progress_callback=CallableOrNone
         )
     for i in range(1, 256):
-        cache_checksums.assert_any_call(
+        list_checksums.assert_any_call(
             prefix=f"{i:02x}", progress_callback=CallableOrNone
         )
 
     # default traverse (small remote)
     size -= 1
-    cache_checksums.reset_mock()
-    list(remote._cache_checksums_traverse(size - 1, {0}))
-    cache_checksums.assert_called_with(
+    list_checksums.reset_mock()
+    list(remote._list_checksums_traverse(size - 1, {0}))
+    list_checksums.assert_called_with(
         prefix=None, progress_callback=CallableOrNone
     )
 
 
-def test_cache_checksums(dvc):
+def test_list_checksums(dvc):
     remote = BaseRemote(dvc, {})
     remote.tree.path_info = PathInfo("foo")
 
     with mock.patch.object(
-        remote, "list_cache_paths", return_value=["12/3456", "bar"]
+        remote, "list_paths", return_value=["12/3456", "bar"]
     ):
-        checksums = list(remote.cache_checksums())
+        checksums = list(remote.list_checksums())
 
     assert checksums == ["123456"]
 
diff --git a/tests/unit/remote/test_local.py b/tests/unit/remote/test_local.py
index a42e3d50c0..fe2cb016d3 100644
--- a/tests/unit/remote/test_local.py
+++ b/tests/unit/remote/test_local.py
@@ -5,7 +5,7 @@
 
 from dvc.cache import NamedCache
 from dvc.path_info import PathInfo
-from dvc.remote.local import LocalRemote
+from dvc.remote.local import LocalCache
 
 
 def test_status_download_optimization(mocker, dvc):
@@ -13,28 +13,28 @@ def test_status_download_optimization(mocker, dvc):
     And the desired files to fetch are already on the local cache,
     Don't check the existence of the desired files on the remote cache
     """
-    remote = LocalRemote(dvc, {})
+    cache = LocalCache(dvc, {})
 
     infos = NamedCache()
     infos.add("local", "acbd18db4cc2f85cedef654fccc4a4d8", "foo")
     infos.add("local", "37b51d194a7513e45b56f6524f2d51f2", "bar")
 
     local_exists = list(infos["local"])
-    mocker.patch.object(remote, "cache_exists", return_value=local_exists)
+    mocker.patch.object(cache, "checksums_exist", return_value=local_exists)
 
     other_remote = mocker.Mock()
     other_remote.url = "other_remote"
-    other_remote.cache_exists.return_value = []
+    other_remote.checksums_exist.return_value = []
 
-    remote.status(infos, other_remote, download=True)
+    cache.status(infos, other_remote, download=True)
 
-    assert other_remote.cache_exists.call_count == 0
+    assert other_remote.checksums_exist.call_count == 0
 
 
 @pytest.mark.parametrize("link_name", ["hardlink", "symlink"])
 def test_is_protected(tmp_dir, dvc, link_name):
-    remote = LocalRemote(dvc, {})
-    link_method = getattr(remote.tree, link_name)
+    cache = LocalCache(dvc, {})
+    link_method = getattr(cache.tree, link_name)
 
     (tmp_dir / "foo").write_text("foo")
 
@@ -43,47 +43,47 @@ def test_is_protected(tmp_dir, dvc, link_name):
 
     link_method(foo, link)
 
-    assert not remote.is_protected(foo)
-    assert not remote.is_protected(link)
+    assert not cache.is_protected(foo)
+    assert not cache.is_protected(link)
 
-    remote.protect(foo)
+    cache.protect(foo)
 
-    assert remote.is_protected(foo)
-    assert remote.is_protected(link)
+    assert cache.is_protected(foo)
+    assert cache.is_protected(link)
 
-    remote.unprotect(link)
+    cache.unprotect(link)
 
-    assert not remote.is_protected(link)
+    assert not cache.is_protected(link)
 
     if os.name == "nt" and link_name == "hardlink":
         # NOTE: NTFS doesn't allow deleting read-only files, which forces us to
         # set write perms on the link, which propagates to the source.
-        assert not remote.is_protected(foo)
+        assert not cache.is_protected(foo)
     else:
-        assert remote.is_protected(foo)
+        assert cache.is_protected(foo)
 
 
 @pytest.mark.parametrize("err", [errno.EPERM, errno.EACCES])
 def test_protect_ignore_errors(tmp_dir, mocker, err):
     tmp_dir.gen("foo", "foo")
     foo = PathInfo("foo")
-    remote = LocalRemote(None, {})
+    cache = LocalCache(None, {})
 
-    remote.protect(foo)
+    cache.protect(foo)
     mock_chmod = mocker.patch(
         "os.chmod", side_effect=OSError(err, "something")
     )
-    remote.protect(foo)
+    cache.protect(foo)
     assert mock_chmod.called
 
 
 def test_protect_ignore_erofs(tmp_dir, mocker):
     tmp_dir.gen("foo", "foo")
     foo = PathInfo("foo")
-    remote = LocalRemote(None, {})
+    cache = LocalCache(None, {})
 
     mock_chmod = mocker.patch(
         "os.chmod", side_effect=OSError(errno.EROFS, "read-only fs")
     )
-    remote.protect(foo)
+    cache.protect(foo)
     assert mock_chmod.called
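For downstream code and tests that reference these APIs by name, the renames applied throughout this patch can be summarized as follows (a quick-reference sketch collected from the hunks above, not itself part of the diff):

```python
# Old name -> new name, as exercised by the updated tests.
RENAMES = {
    "cache_exists": "checksums_exist",
    "cache_checksums": "list_checksums",
    "list_cache_paths": "list_paths",
    "_cache_object_exists": "_list_checksums_exists",
    "_cache_checksums_traverse": "_list_checksums_traverse",
    "get_dir_checksum": "get_checksum",  # called on the cache object
}

# For example, a test that previously patched the old name:
#     mocker.patch.object(LocalRemote, "cache_exists", ...)
# would now patch:
#     mocker.patch.object(LocalRemote, "checksums_exist", ...)
```

Class-level, cache construction moves from the plain remote classes to their cache counterparts (`LocalRemote` to `LocalCache`, `S3Remote` to `S3Cache`, plus the new `SSHCache`), while checksum spies and `_collect_dir` patches now target the tree classes.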