From c81795c0476a58029bc0359088b915d41351fbd9 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Mon, 23 Dec 2019 13:19:24 +0200 Subject: [PATCH] dvc: support granularity for fetch/pull/push/status/checkout Using output path or a subdir/subfile path within an output now works with `dvc push/pull/fetch/status -c` commands. Other commands don't support the same logic for now, as there are some questions about what commands like `dvc remove` should do when given a specific output path. Example: dvc add data dvc pull data/subdir # will only pull files within data/subdir Related to #2458 --- dvc/output/base.py | 33 +++++++++++++++---- dvc/path_info.py | 5 ++- dvc/remote/base.py | 65 +++++++++++++++++++++++++++++-------- dvc/repo/__init__.py | 61 ++++++++++++++++++++-------------- dvc/repo/checkout.py | 21 +++++++----- dvc/stage.py | 33 ++++++++++++++++--- tests/func/test_checkout.py | 13 ++++++++ tests/unit/test_repo.py | 45 +++++++++++++++++++++++++ 8 files changed, 216 insertions(+), 60 deletions(-) diff --git a/dvc/output/base.py b/dvc/output/base.py index ffb68ac6af..b92a888321 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -283,11 +283,18 @@ def download(self, to): self.remote.download(self.path_info, to.path_info) def checkout( - self, force=False, progress_callback=None, tag=None, relink=False + self, + force=False, + progress_callback=None, + tag=None, + relink=False, + filter_info=None, ): if not self.use_cache: if progress_callback: - progress_callback(str(self.path_info), self.get_files_number()) + progress_callback( + str(self.path_info), self.get_files_number(filter_info) + ) return None if tag: @@ -301,6 +308,7 @@ def checkout( force=force, progress_callback=progress_callback, relink=relink, + filter_info=filter_info, ) def remove(self, ignore_remove=False): @@ -324,17 +332,21 @@ def move(self, out): if self.scheme == "local" and self.use_scm_ignore: self.repo.scm.ignore(self.fspath) - def get_files_number(self): + def 
get_files_number(self, filter_info=None): if not self.use_cache: return 0 - return self.cache.get_files_number(self.checksum) + return self.cache.get_files_number( + self.path_info, self.checksum, filter_info + ) def unprotect(self): if self.exists: self.remote.unprotect(self.path_info) - def _collect_used_dir_cache(self, remote=None, force=False, jobs=None): + def _collect_used_dir_cache( + self, remote=None, force=False, jobs=None, filter_info=None + ): """Get a list of `info`s related to the given directory. - Pull the directory entry from the remote cache if it was changed. @@ -383,8 +395,9 @@ def _collect_used_dir_cache(self, remote=None, force=False, jobs=None): for entry in self.dir_cache: checksum = entry[self.remote.PARAM_CHECKSUM] - path_info = self.path_info / entry[self.remote.PARAM_RELPATH] - cache.add(self.scheme, checksum, str(path_info)) + info = self.path_info / entry[self.remote.PARAM_RELPATH] + if not filter_info or info.isin_or_eq(filter_info): + cache.add(self.scheme, checksum, str(info)) return cache @@ -400,6 +413,12 @@ def get_used_cache(self, **kwargs): if not self.use_cache: return NamedCache() + if self.stage.is_repo_import: + cache = NamedCache() + dep, = self.stage.deps + cache.external[dep.repo_pair].add(dep.def_path) + return cache + if not self.info: logger.warning( "Output '{}'({}) is missing version info. 
Cache for it will " diff --git a/dvc/path_info.py b/dvc/path_info.py index 773ae7dcff..a4a24726e4 100644 --- a/dvc/path_info.py +++ b/dvc/path_info.py @@ -29,7 +29,10 @@ def overlaps(self, other): other = self.__class__(other) elif self.__class__ != other.__class__: return False - return self == other or self.isin(other) or other.isin(self) + return self.isin_or_eq(other) or other.isin(self) + + def isin_or_eq(self, other): + return self == other or self.isin(other) class PathInfo(pathlib.PurePath, _BasePath): diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 0f66b27ab0..cb012a210b 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -739,24 +739,36 @@ def changed_cache_file(self, checksum): return True - def _changed_dir_cache(self, checksum): + def _changed_dir_cache(self, checksum, path_info=None, filter_info=None): if self.changed_cache_file(checksum): return True - if not self._changed_unpacked_dir(checksum): + if not (path_info and filter_info) and not self._changed_unpacked_dir( + checksum + ): return False for entry in self.get_dir_cache(checksum): entry_checksum = entry[self.PARAM_CHECKSUM] + + if path_info and filter_info: + entry_info = path_info / entry[self.PARAM_RELPATH] + if not entry_info.isin_or_eq(filter_info): + continue + if self.changed_cache_file(entry_checksum): return True - self._update_unpacked_dir(checksum) + if not (path_info and filter_info): + self._update_unpacked_dir(checksum) + return False - def changed_cache(self, checksum): + def changed_cache(self, checksum, path_info=None, filter_info=None): if self.is_dir_checksum(checksum): - return self._changed_dir_cache(checksum) + return self._changed_dir_cache( + checksum, path_info=path_info, filter_info=filter_info + ) return self.changed_cache_file(checksum) def cache_exists(self, checksums, jobs=None, name=None): @@ -849,7 +861,13 @@ def makedirs(self, path_info): pass def _checkout_dir( - self, path_info, checksum, force, progress_callback=None, relink=False + self, 
+ path_info, + checksum, + force, + progress_callback=None, + relink=False, + filter_info=None, ): # Create dir separately so that dir is created # even if there are no files in it @@ -866,6 +884,9 @@ def _checkout_dir( entry_cache_info = self.checksum_to_path_info(entry_checksum) entry_info = path_info / relative_path + if filter_info and not entry_info.isin_or_eq(filter_info): + continue + entry_checksum_info = {self.PARAM_CHECKSUM: entry_checksum} if relink or self.changed(entry_info, entry_checksum_info): self.safe_remove(entry_info, force=force) @@ -896,6 +917,7 @@ def checkout( force=False, progress_callback=None, relink=False, + filter_info=None, ): if path_info.scheme not in ["local", self.scheme]: raise NotImplementedError @@ -916,7 +938,9 @@ def checkout( logger.debug(msg.format(str(path_info))) skip = True - elif self.changed_cache(checksum): + elif self.changed_cache( + checksum, path_info=path_info, filter_info=filter_info + ): msg = "Cache '{}' not found. File '{}' won't be created." logger.warning(msg.format(checksum, str(path_info))) self.safe_remove(path_info, force=force) @@ -925,15 +949,19 @@ def checkout( if failed or skip: if progress_callback: progress_callback( - str(path_info), self.get_files_number(checksum) + str(path_info), + self.get_files_number( + self.path_info, checksum, filter_info + ), ) return failed msg = "Checking out '{}' with cache '{}'." 
logger.debug(msg.format(str(path_info), checksum)) - self._checkout(path_info, checksum, force, progress_callback, relink) - return None + self._checkout( + path_info, checksum, force, progress_callback, relink, filter_info + ) def _checkout( self, @@ -942,23 +970,32 @@ def _checkout( force=False, progress_callback=None, relink=False, + filter_info=None, ): if not self.is_dir_checksum(checksum): return self._checkout_file( path_info, checksum, force, progress_callback=progress_callback ) return self._checkout_dir( - path_info, checksum, force, progress_callback, relink + path_info, checksum, force, progress_callback, relink, filter_info ) - def get_files_number(self, checksum): + def get_files_number(self, path_info, checksum, filter_info): + from funcy.py3 import ilen + if not checksum: return 0 - if self.is_dir_checksum(checksum): + if not self.is_dir_checksum(checksum): + return 1 + + if not filter_info: return len(self.get_dir_cache(checksum)) - return 1 + return ilen( + filter_info.isin_or_eq(path_info / entry[self.PARAM_CHECKSUM]) + for entry in self.get_dir_cache(checksum) + ) @staticmethod def unprotect(path_info): diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 8a0afc2d3c..5fa1a2cfd3 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -217,6 +217,18 @@ def collect(self, target, with_deps=False, recursive=False, graph=None): for n in nx.dfs_postorder_nodes(pipeline, node) ] + def collect_granular(self, target, *args, **kwargs): + if not target: + return [(stage, None) for stage in self.stages] + + try: + out, = self.find_outs_by_path(target, strict=False) + filter_info = PathInfo(os.path.abspath(target)) + return [(out.stage, filter_info)] + except OutputNotFoundError: + stages = self.collect(target, *args, **kwargs) + return [(stage, None) for stage in stages] + def used_cache( self, targets=None, @@ -242,6 +254,7 @@ def used_cache( A dictionary with Schemes (representing output's location) as keys, and a list with the outputs' 
`dumpd` as values. """ + from funcy.py2 import icat from dvc.cache import NamedCache cache = NamedCache() @@ -251,28 +264,24 @@ def used_cache( all_tags=all_tags, all_commits=all_commits, ): - if targets: - stages = [] - for target in targets: - collected = self.collect( - target, recursive=recursive, with_deps=with_deps - ) - stages.extend(collected) - else: - stages = self.stages - - for stage in stages: - if stage.is_repo_import: - dep, = stage.deps - cache.external[dep.repo_pair].add(dep.def_path) - continue - - for out in stage.outs: - used_cache = out.get_used_cache( - remote=remote, force=force, jobs=jobs - ) - suffix = "({})".format(branch) if branch else "" - cache.update(used_cache, suffix=suffix) + targets = targets or [None] + + pairs = icat( + self.collect_granular( + target, recursive=recursive, with_deps=with_deps + ) + for target in targets + ) + + suffix = "({})".format(branch) if branch else "" + for stage, filter_info in pairs: + used_cache = stage.get_used_cache( + remote=remote, + force=force, + jobs=jobs, + filter_info=filter_info, + ) + cache.update(used_cache, suffix=suffix) return cache @@ -421,18 +430,20 @@ def collect_stages(self): def stages(self): return get_stages(self.graph) - def find_outs_by_path(self, path, outs=None, recursive=False): + def find_outs_by_path(self, path, outs=None, recursive=False, strict=True): if not outs: outs = [out for stage in self.stages for out in stage.outs] abs_path = os.path.abspath(path) + path_info = PathInfo(abs_path) is_dir = self.tree.isdir(abs_path) + match = path_info.__eq__ if strict else path_info.isin_or_eq def func(out): - if out.scheme == "local" and out.fspath == abs_path: + if out.scheme == "local" and match(out.path_info): return True - if is_dir and recursive and out.path_info.isin(abs_path): + if is_dir and recursive and out.path_info.isin(path_info): return True return False diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index 040590d86c..c0c6e4417d 100644 --- 
a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -20,8 +20,10 @@ def _cleanup_unused_links(repo): repo.state.remove_unused_links(used) -def get_all_files_numbers(stages): - return sum(stage.get_all_files_number() for stage in stages) +def get_all_files_numbers(pairs): + return sum( + stage.get_all_files_number(filter_info) for stage, filter_info in pairs + ) def _checkout( @@ -34,36 +36,37 @@ def _checkout( ): from dvc.stage import StageFileDoesNotExistError, StageFileBadNameError - stages = set() - if not targets: targets = [None] _cleanup_unused_links(self) + pairs = set() for target in targets: try: - new = self.collect( - target, with_deps=with_deps, recursive=recursive + pairs.update( + self.collect_granular( + target, with_deps=with_deps, recursive=recursive + ) ) - stages.update(new) except (StageFileDoesNotExistError, StageFileBadNameError) as exc: if not target: raise raise CheckoutErrorSuggestGit(target, exc) - total = get_all_files_numbers(stages) + total = get_all_files_numbers(pairs) if total == 0: logger.info("Nothing to do") failed = [] with Tqdm( total=total, unit="file", desc="Checkout", disable=total == 0 ) as pbar: - for stage in stages: + for stage, filter_info in pairs: failed.extend( stage.checkout( force=force, progress_callback=pbar.update_desc, relink=relink, + filter_info=filter_info, ) ) if failed: diff --git a/dvc/stage.py b/dvc/stage.py index 7aafa5d750..e171a98f1c 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -959,15 +959,28 @@ def check_missing_outputs(self): if paths: raise MissingDataSource(paths) + def _filter_outs(self, path_info): + def _func(o): + return path_info.isin_or_eq(o.path_info) + + return filter(_func, self.outs) if path_info else self.outs + @rwlocked(write=["outs"]) - def checkout(self, force=False, progress_callback=None, relink=False): + def checkout( + self, + force=False, + progress_callback=None, + relink=False, + filter_info=None, + ): failed_checkouts = [] - for out in self.outs: + for out in 
self._filter_outs(filter_info): failed = out.checkout( force=force, tag=self.tag, progress_callback=progress_callback, relink=relink, + filter_info=filter_info, ) if failed: failed_checkouts.append(failed) @@ -1016,5 +1029,17 @@ def _already_cached(self): ) ) - def get_all_files_number(self): - return sum(out.get_files_number() for out in self.outs) + def get_all_files_number(self, filter_info=None): + return sum( + out.get_files_number(filter_info) + for out in self._filter_outs(filter_info) + ) + + def get_used_cache(self, *args, **kwargs): + from .cache import NamedCache + + cache = NamedCache() + for out in self._filter_outs(kwargs.get("filter_info")): + cache.update(out.get_used_cache(*args, **kwargs)) + + return cache diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index e7991e1417..797e424708 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -510,3 +510,16 @@ def test_checkout_relink_protected(tmp_dir, dvc, link): dvc.checkout(["foo.dvc"], relink=True) assert not os.access("foo", os.W_OK) + + +@pytest.mark.parametrize( + "target", + [os.path.join("dir", "subdir"), os.path.join("dir", "subdir", "file")], +) +def test_partial_checkout(tmp_dir, dvc, target): + tmp_dir.dvc_gen({"dir": {"subdir": {"file": "file"}, "other": "other"}}) + shutil.rmtree("dir") + dvc.checkout([target]) + assert list(walk_files("dir", None)) == [ + os.path.join("dir", "subdir", "file") + ] diff --git a/tests/unit/test_repo.py b/tests/unit/test_repo.py index 0d7577bd4c..7b515d5c3f 100644 --- a/tests/unit/test_repo.py +++ b/tests/unit/test_repo.py @@ -1,6 +1,51 @@ import os +import pytest + def test_is_dvc_internal(dvc): assert dvc.is_dvc_internal(os.path.join("path", "to", ".dvc", "file")) assert not dvc.is_dvc_internal(os.path.join("path", "to-non-.dvc", "file")) + + +@pytest.mark.parametrize( + "path", + [ + os.path.join("dir", "subdir", "file"), + os.path.join("dir", "subdir"), + "dir", + ], +) +def test_find_outs_by_path(tmp_dir, 
dvc, path): + stage, = tmp_dir.dvc_gen( + {"dir": {"subdir": {"file": "file"}, "other": "other"}} + ) + + outs = dvc.find_outs_by_path(path, strict=False) + assert len(outs) == 1 + assert outs[0].path_info == stage.outs[0].path_info + + +@pytest.mark.parametrize( + "path", + [os.path.join("dir", "subdir", "file"), os.path.join("dir", "subdir")], +) +def test_used_cache(tmp_dir, dvc, path): + from dvc.cache import NamedCache + + tmp_dir.dvc_gen({"dir": {"subdir": {"file": "file"}, "other": "other"}}) + expected = NamedCache.make( + "local", "70922d6bf66eb073053a82f77d58c536.dir", "dir" + ) + expected.add( + "local", + "8c7dd922ad47494fc02c388e12c00eac", + os.path.join("dir", "subdir", "file"), + ) + + with dvc.state: + used_cache = dvc.used_cache([path]) + assert ( + used_cache._items == expected._items + and used_cache.external == expected.external + )