diff --git a/dvc/command/imp.py b/dvc/command/imp.py index 9528282a09..6b896c59b8 100644 --- a/dvc/command/imp.py +++ b/dvc/command/imp.py @@ -19,6 +19,7 @@ def run(self): rev=self.args.rev, no_exec=self.args.no_exec, desc=self.args.desc, + jobs=self.args.jobs, ) except DvcException: logger.exception( @@ -82,4 +83,15 @@ def add_parser(subparsers, parent_parser): "This doesn't affect any DVC operations." ), ) + import_parser.add_argument( + "-j", + "--jobs", + type=int, + help=( + "Number of jobs to run simultaneously. " + "The default value is 4 * cpu_count(). " + "For SSH remotes, the default is 4. " + ), + metavar="", + ) import_parser.set_defaults(func=CmdImport) diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 7eb2f87f0a..558a3a464c 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -68,14 +68,14 @@ def save(self): def dumpd(self): return {self.PARAM_PATH: self.def_path, self.PARAM_REPO: self.def_repo} - def download(self, to): + def download(self, to, jobs=None): cache = self.repo.cache.local with self._make_repo(cache_dir=cache.cache_dir) as repo: if self.def_repo.get(self.PARAM_REV_LOCK) is None: self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev() - _, _, cache_infos = repo.fetch_external([self.def_path]) + _, _, cache_infos = repo.fetch_external([self.def_path], jobs=jobs) cache.checkout(to.path_info, cache_infos[0]) diff --git a/dvc/output/base.py b/dvc/output/base.py index 1f4de887b7..b25f424a5c 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -333,8 +333,8 @@ def dumpd(self): def verify_metric(self): raise DvcException(f"verify metric is not supported for {self.scheme}") - def download(self, to): - self.tree.download(self.path_info, to.path_info) + def download(self, to, jobs=None): + self.tree.download(self.path_info, to.path_info, jobs=jobs) def checkout( self, diff --git a/dvc/repo/imp_url.py b/dvc/repo/imp_url.py index 974cd8b83c..cad9652b4b 100644 --- a/dvc/repo/imp_url.py +++ b/dvc/repo/imp_url.py @@ -19,6 +19,7 @@ def imp_url( frozen=True, no_exec=False, desc=None, + jobs=None, ): from dvc.dvcfile import Dvcfile from dvc.stage import Stage, create_stage @@ -61,7 +62,7 @@ def imp_url( if no_exec: stage.ignore_outs() else: - stage.run() + stage.run(jobs=jobs) stage.frozen = frozen diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index 66dca0158c..6c416702c8 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -472,7 +472,8 @@ def run( self.remove_outs(ignore_remove=False, force=False) if not self.frozen and self.is_import: - sync_import(self, dry, force) + jobs = kwargs.get("jobs", None) + sync_import(self, dry, force, jobs) elif not self.frozen and self.cmd: run_stage(self, dry, force, **kwargs) else: diff --git a/dvc/stage/imports.py b/dvc/stage/imports.py index 3ca25e2d7e..01792002ac 100644 --- a/dvc/stage/imports.py +++ b/dvc/stage/imports.py @@ -13,7 +13,7 @@ def update_import(stage, rev=None): stage.frozen = frozen -def sync_import(stage, dry=False, force=False): +def sync_import(stage, dry=False, force=False, jobs=None): """Synchronize import's outs to the workspace.""" logger.info( "Importing '{dep}' -> '{out}'".format( @@ -27,4 +27,4 @@ def sync_import(stage, dry=False, force=False): stage.outs[0].checkout() else: stage.save_deps() - stage.deps[0].download(stage.outs[0]) + stage.deps[0].download(stage.outs[0], jobs=jobs) diff --git a/dvc/tree/base.py b/dvc/tree/base.py index 1e76c6d36c..4e6925a6b6 100644 --- a/dvc/tree/base.py +++ b/dvc/tree/base.py @@ -377,6 +377,7 @@ def download( no_progress_bar=False, file_mode=None, dir_mode=None, + jobs=None, ): if not hasattr(self, "_download"): raise RemoteActionNotImplemented("download", self.scheme) @@ -393,14 +394,27 @@ def download( if self.isdir(from_info): return self._download_dir( - from_info, to_info, name, no_progress_bar, file_mode, dir_mode + from_info, + to_info, + name, + no_progress_bar, + file_mode, + dir_mode, + jobs, ) return self._download_file( from_info, to_info, name, no_progress_bar, file_mode, dir_mode ) def _download_dir( - self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + self, + from_info, + to_info, + name, + no_progress_bar, + file_mode, + dir_mode, + jobs, ): from_infos = list(self.walk_files(from_info)) to_infos = ( @@ -422,7 +436,8 @@ def _download_dir( dir_mode=dir_mode, ) ) - with ThreadPoolExecutor(max_workers=self.jobs) as executor: + max_workers = jobs or self.jobs + with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ executor.submit(download_files, from_info, to_info) for from_info, to_info in zip(from_infos, to_infos) diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 0983e1bb0b..996bafd1c2 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -475,3 +475,25 @@ def test_import_with_no_exec(tmp_dir, dvc, erepo_dir): dst = tmp_dir / "foo_imported" assert not dst.exists() + + +def test_import_with_jobs(mocker, dvc, erepo_dir): + from dvc.data_cloud import DataCloud + + with erepo_dir.chdir(): + erepo_dir.dvc_gen( + { + "dir1": { + "file1": "file1", + "file2": "file2", + "file3": "file3", + "file4": "file4", + }, + }, + commit="init", + ) + + spy = mocker.spy(DataCloud, "pull") + dvc.imp(os.fspath(erepo_dir), "dir1", jobs=3) + run_jobs = tuple(spy.call_args_list[0])[1].get("jobs") + assert run_jobs == 3 diff --git a/tests/unit/command/test_imp.py b/tests/unit/command/test_imp.py index ff04e65032..b9933aa497 100644 --- a/tests/unit/command/test_imp.py +++ b/tests/unit/command/test_imp.py @@ -16,6 +16,8 @@ def test_import(mocker): "version", "--desc", "description", + "--jobs", + "3", ] ) assert cli_args.func == CmdImport @@ -33,6 +35,7 @@ def test_import(mocker): rev="version", no_exec=False, desc="description", + jobs=3, ) @@ -67,4 +70,5 @@ def test_import_no_exec(mocker): rev="version", no_exec=True, desc="description", + jobs=None, )