diff --git a/dvc/command/pkg.py b/dvc/command/pkg.py index c82f46c218..8ddaa8c498 100644 --- a/dvc/command/pkg.py +++ b/dvc/command/pkg.py @@ -15,7 +15,10 @@ class CmdPkgInstall(CmdBase): def run(self): try: self.repo.pkg.install( - self.args.url, version=self.args.version, name=self.args.name + self.args.url, + version=self.args.version, + name=self.args.name, + force=self.args.force, ) return 0 except DvcException: @@ -113,6 +116,13 @@ def add_parser(subparsers, parent_parser): "from URL." ), ) + pkg_install_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Reinstall package if it is already installed.", + ) pkg_install_parser.set_defaults(func=CmdPkgInstall) PKG_UNINSTALL_HELP = "Uninstall package(s)." diff --git a/dvc/pkg.py b/dvc/pkg.py index ad7194bfd5..80d7b9cdec 100644 --- a/dvc/pkg.py +++ b/dvc/pkg.py @@ -58,15 +58,17 @@ def repo(self): def installed(self): return os.path.exists(self.path) - def install(self, cache_dir=None): + def install(self, cache_dir=None, force=False): import git if self.installed: - logger.info( - "Skipping installing '{}'('{}') as it is already " - "installed.".format(self.name, self.url) - ) - return + if not force: + logger.info( + "Skipping installing '{}'('{}') as it is already " + "installed.".format(self.name, self.url) + ) + return + self.uninstall() git.Repo.clone_from( self.url, self.path, depth=1, no_single_branch=True @@ -113,9 +115,9 @@ def __init__(self, repo): self.pkg_dir = os.path.join(repo.dvc_dir, self.PKG_DIR) self.cache_dir = repo.cache.local.cache_dir - def install(self, url, **kwargs): + def install(self, url, force=False, **kwargs): pkg = Pkg(self.pkg_dir, url=url, **kwargs) - pkg.install(cache_dir=self.cache_dir) + pkg.install(cache_dir=self.cache_dir, force=force) def uninstall(self, name): pkg = Pkg(self.pkg_dir, name=name) diff --git a/tests/func/test_pkg.py b/tests/func/test_pkg.py index 14bfdba8cf..5a49e16c79 100644 --- a/tests/func/test_pkg.py +++ b/tests/func/test_pkg.py @@ -47,6 +47,24 @@ def test_uninstall_corrupted(repo_dir, dvc_repo): assert not os.path.exists(mypkg_dir) +def test_force_install(repo_dir, dvc_repo, pkg): + name = os.path.basename(pkg.root_dir) + pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") + mypkg_dir = os.path.join(pkg_dir, name) + + os.makedirs(mypkg_dir) + + dvc_repo.pkg.install(pkg.root_dir) + assert not os.listdir(mypkg_dir) + + dvc_repo.pkg.install(pkg.root_dir, force=True) + assert os.path.exists(pkg_dir) + assert os.path.isdir(pkg_dir) + assert os.path.exists(mypkg_dir) + assert os.path.isdir(mypkg_dir) + assert os.path.isdir(os.path.join(mypkg_dir, ".git")) + + def test_install_version(repo_dir, dvc_repo, pkg): name = os.path.basename(pkg.root_dir) pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") diff --git a/tests/unit/command/test_pkg.py b/tests/unit/command/test_pkg.py index 0e2ebe81a6..9ee45d1bb5 100644 --- a/tests/unit/command/test_pkg.py +++ b/tests/unit/command/test_pkg.py @@ -5,7 +5,16 @@ def test_pkg_install(mocker, dvc_repo): args = parse_args( - ["pkg", "install", "url", "--version", "version", "--name", "name"] + [ + "pkg", + "install", + "url", + "--version", + "version", + "--name", + "name", + "--force", + ] ) assert args.func == CmdPkgInstall @@ -14,7 +23,9 @@ def test_pkg_install(mocker, dvc_repo): assert cmd.run() == 0 - m.assert_called_once_with("url", version="version", name="name") + m.assert_called_once_with( + "url", version="version", name="name", force=True + ) def test_pkg_uninstall(mocker, dvc_repo):