From 1892226d9633d9e376be6dcfa7ef217ab72f42b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?=
 =?UTF-8?q?=E0=A4=A4=29?=
Date: Fri, 29 Apr 2022 15:03:08 +0545
Subject: [PATCH 1/2] introduce data:status command

---
 dvc/cli/parser.py    |   2 +
 dvc/commands/data.py | 147 +++++++++++++++++++++++++++++++++++++++++++
 dvc/repo/__init__.py |   5 +-
 dvc/repo/ls.py       |  28 +++++----
 4 files changed, 168 insertions(+), 14 deletions(-)
 create mode 100644 dvc/commands/data.py

diff --git a/dvc/cli/parser.py b/dvc/cli/parser.py
index 012b6cb737..6b43a72109 100644
--- a/dvc/cli/parser.py
+++ b/dvc/cli/parser.py
@@ -14,6 +14,7 @@
     config,
     daemon,
     dag,
+    data,
     data_sync,
     destroy,
     diff,
@@ -86,6 +87,7 @@
     experiments,
     check_ignore,
     machine,
+    data,
 ]
 
 
diff --git a/dvc/commands/data.py b/dvc/commands/data.py
new file mode 100644
index 0000000000..f5f51c5d0f
--- /dev/null
+++ b/dvc/commands/data.py
@@ -0,0 +1,147 @@
+import argparse
+import logging
+import os
+from collections import defaultdict
+from functools import partial
+from operator import itemgetter
+
+from funcy import log_durations
+
+from dvc.cli.command import CmdBase
+from dvc.cli.utils import fix_subparsers
+from dvc.ui import ui
+
+logger = logging.getLogger(__name__)
+
+print_durations = partial(
+    log_durations,
+    ui.error_write
+    if logger.isEnabledFor(logging.TRACE)  # type: ignore[attr-defined]
+    else logger.trace,  # type: ignore[attr-defined]
+)
+
+
+class CmdDataStatus(CmdBase):
+    """Report combined DVC and Git state of the repository as JSON."""
+
+    def _process_data(
+        self,
+        ls_data,
+        status_data,
+        diff_data,
+        git_staged,
+        git_unstaged,
+        git_untracked,
+    ):
+        """Merge the ls, status, diff and Git results into one mapping.
+
+        Paths that `status_data` marks as "modified" are reported under
+        "stage_modified" and are left out of the diff-derived buckets.
+        """
+        files = set(map(itemgetter("path"), ls_data))
+        ret = defaultdict(list)
+
+        stage_modified = set()
+        not_in_cache = set()
+        for stage_status in status_data.values():
+            for out_stats in stage_status:
+                if isinstance(out_stats, dict):
+                    for stats in out_stats.values():
+                        if isinstance(stats, dict):
+                            for path, typ in stats.items():
+                                if typ == "modified":
+                                    stage_modified.add(path)
+                                if typ == "not in cache":
+                                    not_in_cache.add(path)
+
+        diff_type_map = {
+            "modified": "modified_against_head",
+            "added": "added",
+            "deleted": "deleted",
+            "renamed": "renamed",
+        }
+        for typ, diff_p in diff_data.items():
+            if typ not in diff_type_map:
+                continue
+            for info in diff_p:
+                path = info["path"]
+                # NOTE(review): "renamed" entries appear to carry an
+                # old/new mapping instead of a plain path; a dict is
+                # unhashable and cannot be looked up in a set - confirm.
+                if isinstance(path, dict) or path not in stage_modified:
+                    ret[diff_type_map[typ]].append(path)
+
+        # Sets have no stable order; sort so the report is reproducible.
+        ret.update(
+            {
+                "stage_modified": sorted(stage_modified),
+                "not_in_cache": sorted(not_in_cache),
+                "dvc_tracked": sorted(files),
+                "git_staged": git_staged,
+                "git_unstaged": list(git_unstaged),
+                "git_untracked": list(git_untracked),
+            }
+        )
+        return ret
+
+    def _patch_clone(self):
+        from funcy import monkey
+
+        from dvc.scm import Git
+
+        @monkey(Git, "clone")
+        def clone(url, *args, **kwargs):
+            with print_durations(f"cloning {os.path.basename(url)}"):
+                return clone.original(url, *args, **kwargs)
+
+    @print_durations()
+    def run(self):
+        """Gather Git and DVC state, print the report, and return 0."""
+        from dvc.repo import lock_repo
+
+        with print_durations("scm_status"):
+            git_staged, git_unstaged, git_untracked = self.repo.scm.status()
+
+        self._patch_clone()
+        with lock_repo(self.repo):
+            with print_durations("ls"):
+                # pylint: disable=protected-access
+                ls_data = self.repo._ls(recursive=True, dvc_only=True)
+            with print_durations("status"):
+                status_data = self.repo.status()
+            with print_durations("diff"):
+                diff_data = self.repo.diff()
+
+        processed = self._process_data(
+            ls_data,
+            status_data,
+            diff_data,
+            git_staged,
+            git_unstaged,
+            git_untracked,
+        )
+        ui.write_json(processed)
+        return 0
+
+
+def add_parser(subparsers, parent_parser):
+    """Register the `data` command and its `status` subcommand."""
+    data_parser = subparsers.add_parser(
+        "data",
+        parents=[parent_parser],
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    data_subparsers = data_parser.add_subparsers(
+        dest="cmd",
+        help="Use `dvc data CMD --help` to display command-specific help.",
+    )
+    fix_subparsers(data_subparsers)
+    data_status_parser = data_subparsers.add_parser(
+        "status",
+        parents=[parent_parser],
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    data_status_parser.add_argument(
+        "--json", action="store_true", default=False
+    )
+    data_status_parser.set_defaults(func=CmdDataStatus)
diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py
index e807d8c01c..c8c9848454 100644
--- a/dvc/repo/__init__.py
+++ b/dvc/repo/__init__.py
@@ -66,7 +66,8 @@ class Repo:
     from dvc.repo.imp import imp  # type: ignore[misc]
     from dvc.repo.imp_url import imp_url
     from dvc.repo.install import install  # type: ignore[misc]
-    from dvc.repo.ls import ls as _ls  # type: ignore[misc]
+    from dvc.repo.ls import _ls  # type: ignore[misc]
+    from dvc.repo.ls import ls  # type: ignore[misc]
     from dvc.repo.move import move
     from dvc.repo.pull import pull
     from dvc.repo.push import push
@@ -76,7 +77,7 @@ class Repo:
     from dvc.repo.status import status
     from dvc.repo.update import update
 
-    ls = staticmethod(_ls)
+    ls = staticmethod(ls)  # type: ignore[misc]
     get = staticmethod(_get)
     get_url = staticmethod(_get_url)
 
diff --git a/dvc/repo/ls.py b/dvc/repo/ls.py
index 55c41030a0..98e9341cef 100644
--- a/dvc/repo/ls.py
+++ b/dvc/repo/ls.py
@@ -26,25 +26,29 @@ def ls(url, path=None, rev=None, recursive=None, dvc_only=False):
             "isexec": bool,
         }
     """
-    from . import Repo
+    from dvc.repo import Repo
 
     with Repo.open(url, rev=rev, subrepos=True, uninitialized=True) as repo:
-        path = path or ""
+        # pylint: disable=protected-access
+        return repo._ls(path=path, recursive=recursive, dvc_only=dvc_only)
 
-        ret = _ls(repo.repo_fs, path, recursive, dvc_only)
 
-        if path and not ret:
-            raise PathMissingError(path, repo, dvc_only=dvc_only)
+def _ls(repo, path=None, recursive=None, dvc_only=False):
+    path = path or ""
+    ret = list_files(repo.repo_fs, path, recursive, dvc_only)
 
-        ret_list = []
-        for path, info in ret.items():
-            info["path"] = path
-            ret_list.append(info)
-        ret_list.sort(key=lambda f: f["path"])
-        return ret_list
+    if path and not ret:
+        raise PathMissingError(path, repo, dvc_only=dvc_only)
 
+    ret_list = []
+    for path, info in ret.items():
+        info["path"] = path
+        ret_list.append(info)
+    ret_list.sort(key=lambda f: f["path"])
+    return ret_list
 
-def _ls(fs, path, recursive=None, dvc_only=False):
+
+def list_files(fs, path, recursive=None, dvc_only=False):
     try:
         fs_path = fs.info(path)["name"]
     except FileNotFoundError:

From 89ad11c3fac9974ed5ed80517aa44fa21e8c5d2a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?=
 =?UTF-8?q?=E0=A4=A4=29?=
Date: Fri, 29 Apr 2022 15:18:21 +0545
Subject: [PATCH 2/2] patch dvc.scm.clone instead of scmrepo.git.Git

---
 dvc/commands/data.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dvc/commands/data.py b/dvc/commands/data.py
index f5f51c5d0f..9a5bbe6a07 100644
--- a/dvc/commands/data.py
+++ b/dvc/commands/data.py
@@ -87,9 +87,9 @@ def _process_data(
     def _patch_clone(self):
         from funcy import monkey
 
-        from dvc.scm import Git
+        from dvc import scm
 
-        @monkey(Git, "clone")
+        @monkey(scm, "clone")
         def clone(url, *args, **kwargs):
             with print_durations(f"cloning {os.path.basename(url)}"):
                 return clone.original(url, *args, **kwargs)