diff --git a/dvc/command/cache.py b/dvc/command/cache.py index 785bf40dc2..d6d38fbdf9 100644 --- a/dvc/command/cache.py +++ b/dvc/command/cache.py @@ -11,10 +11,21 @@ class CmdCacheDir(CmdConfig): def run(self): if self.args.value is None and not self.args.unset: - logger.info(self.config["cache"]["dir"]) + if self.args.level: + conf = self.config.read(level=self.args.level) + else: + # Use merged config with default values + conf = self.config + self._check(conf, False, "cache", "dir") + logger.info(conf["cache"]["dir"]) return 0 - with self.config.edit(level=self.args.level) as edit: - edit["cache"]["dir"] = self.args.value + with self.config.edit(level=self.args.level) as conf: + if self.args.unset: + self._check(conf, False, "cache", "dir") + del conf["cache"]["dir"] + else: + self._check(conf, False, "cache") + conf["cache"]["dir"] = self.args.value return 0 diff --git a/dvc/command/config.py b/dvc/command/config.py index 70deb90f50..d684f3527d 100644 --- a/dvc/command/config.py +++ b/dvc/command/config.py @@ -52,11 +52,10 @@ def run(self): ) return 1 - conf = self.config.load_one(self.args.level) + conf = self.config.read(self.args.level) prefix = self._config_file_prefix( self.args.show_origin, self.config, self.args.level ) - logger.info("\n".join(self._format_config(conf, prefix))) return 0 @@ -67,11 +66,10 @@ def run(self): remote, section, opt = self.args.name if self.args.value is None and not self.args.unset: - conf = self.config.load_one(self.args.level) + conf = self.config.read(self.args.level) prefix = self._config_file_prefix( self.args.show_origin, self.config, self.args.level ) - if remote: conf = conf["remote"] self._check(conf, remote, section, opt) @@ -117,6 +115,7 @@ def _config_file_prefix(show_origin, config, level): if not show_origin: return "" + level = level or "repo" fname = config.files[level] if level in ["local", "repo"]: @@ -141,14 +140,21 @@ def _config_file_prefix(show_origin, config, level): const="system", help="Use system config.", ) +level_group.add_argument( + "--repo", + dest="level", + action="store_const", + const="repo", + help="Use repo config (.dvc/config).", +) level_group.add_argument( "--local", dest="level", action="store_const", const="local", - help="Use local config.", + help="Use local config (.dvc/config.local).", ) -parent_config_parser.set_defaults(level="repo") +parent_config_parser.set_defaults(level=None) def add_parser(subparsers, parent_parser): diff --git a/dvc/command/remote.py b/dvc/command/remote.py index de3f566531..79a46958a6 100644 --- a/dvc/command/remote.py +++ b/dvc/command/remote.py @@ -46,13 +46,14 @@ def run(self): self._check_exists(conf) del conf["remote"][self.args.name] + up_to_level = self.args.level or "repo" # Remove core.remote refs to this remote in any shadowing configs for level in reversed(self.config.LEVELS): with self.config.edit(level) as conf: if conf["core"].get("remote") == self.args.name: del conf["core"]["remote"] - if level == self.args.level: + if level == up_to_level: break return 0 @@ -79,7 +80,7 @@ class CmdRemoteDefault(CmdRemote): def run(self): if self.args.name is None and not self.args.unset: - conf = self.config.load_one(self.args.level) + conf = self.config.read(self.args.level) try: print(conf["core"]["remote"]) except KeyError: @@ -107,7 +108,7 @@ def run(self): class CmdRemoteList(CmdRemote): def run(self): - conf = self.config.load_one(self.args.level) + conf = self.config.read(self.args.level) for name, conf in conf["remote"].items(): logger.info("{}\t{}".format(name, conf["url"])) return 0 @@ -133,8 +134,9 @@ def run(self): del conf["remote"][self.args.name] self._rename_default(conf) + up_to_level = self.args.level or "repo" for level in reversed(self.config.LEVELS): - if level == self.args.level: + if level == up_to_level: break with self.config.edit(level) as level_conf: self._rename_default(level_conf) diff --git a/dvc/config.py b/dvc/config.py index f5d4536215..a6b992e728 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -446,8 +446,16 @@ def load_config_to_level(self, level=None): merge(merged_conf, self.load_one(merge_level)) return merged_conf + def read(self, level=None): + # NOTE: we read from a merged config by default, same as git config + if level is None: + return self.load_config_to_level() + return self.load_one(level) + @contextmanager - def edit(self, level="repo"): + def edit(self, level=None): + # NOTE: we write to repo config by default, same as git config + level = level or "repo" if level in {"repo", "local"} and self.dvc_dir is None: raise ConfigError("Not inside a DVC repo") diff --git a/tests/func/test_cache.py b/tests/func/test_cache.py index 53c8f65643..7e73dde7a1 100644 --- a/tests/func/test_cache.py +++ b/tests/func/test_cache.py @@ -1,5 +1,6 @@ import os import stat +import textwrap import configobj import pytest @@ -213,3 +214,27 @@ def test_shared_cache(tmp_dir, dvc, group, dir_mode): for fname in fnames: path = os.path.join(root, fname) assert stat.S_IMODE(os.stat(path).st_mode) == 0o444 + + +def test_cache_dir_local(tmp_dir, dvc, caplog): + (tmp_dir / ".dvc" / "config.local").write_text( + textwrap.dedent( + """\ + [cache] + dir = some/path + """ + ) + ) + path = os.path.join(dvc.dvc_dir, "some", "path") + + caplog.clear() + assert main(["cache", "dir", "--local"]) == 0 + assert path in caplog.text + + caplog.clear() + assert main(["cache", "dir"]) == 0 + assert path in caplog.text + + caplog.clear() + assert main(["cache", "dir", "--repo"]) == 251 + assert "option 'dir' doesn't exist in section 'cache'" in caplog.text diff --git a/tests/func/test_config.py b/tests/func/test_config.py index a208cbcdaf..57e8d2b133 100644 --- a/tests/func/test_config.py +++ b/tests/func/test_config.py @@ -1,114 +1,179 @@ import os +import textwrap -import configobj import pytest from dvc.config import Config, ConfigError from dvc.main import main -from tests.basic_env import TestDvc -class TestConfigCLI(TestDvc): - def _contains(self, section, field, value, local=False): - fname = self.dvc.config.files["local" if local else "repo"] - - config = configobj.ConfigObj(fname) - if section not in config.keys(): - return False - - if field not in config[section].keys(): - return False - - if config[section][field] != value: - return False - - return True - - def test_root(self): - ret = main(["root"]) - self.assertEqual(ret, 0) - - # NOTE: check that `dvc root` is not blocked with dvc lock - with self.dvc.lock: - ret = main(["root"]) - self.assertEqual(ret, 0) - - def _do_test(self, local=False): - section = "core" - field = "analytics" - section_field = f"{section}.{field}" - value = "True" - newvalue = "False" - - base = ["config"] - if local: - base.append("--local") - - ret = main(base + [section_field, value]) - self.assertEqual(ret, 0) - self.assertTrue(self._contains(section, field, value, local)) - - ret = main(base + [section_field, value, "--show-origin"]) - self.assertEqual(ret, 1) - - ret = main(base + [section_field]) - self.assertEqual(ret, 0) - - ret = main(base + ["--show-origin", section_field]) - self.assertEqual(ret, 0) - - ret = main(base + [section_field, newvalue]) - self.assertEqual(ret, 0) - self.assertTrue(self._contains(section, field, newvalue, local)) - self.assertFalse(self._contains(section, field, value, local)) - - ret = main(base + [section_field, "--unset"]) - self.assertEqual(ret, 0) - self.assertFalse(self._contains(section, field, value, local)) - - ret = main(base + [section_field, "--unset", "--show-origin"]) - self.assertEqual(ret, 1) - - ret = main(base + ["--list"]) - self.assertEqual(ret, 0) +def test_config_set(tmp_dir, dvc): + assert main(["config", "core.analytics", "false"]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == textwrap.dedent( + """\ + [core] + no_scm = True + analytics = false + """ + ) + assert not (tmp_dir / ".dvc" / "config.local").exists() + + assert main(["config", "core.analytics", "true"]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == textwrap.dedent( + """\ + [core] + no_scm = True + analytics = true + """ + ) + assert not (tmp_dir / ".dvc" / "config.local").exists() + + assert main(["config", "core.analytics", "--unset"]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == textwrap.dedent( + """\ + [core] + no_scm = True + """ + ) + assert not (tmp_dir / ".dvc" / "config.local").exists() - ret = main(base + ["--list", "--show-origin"]) - self.assertEqual(ret, 0) - def test(self): - self._do_test(False) +def test_config_set_local(tmp_dir, dvc): + assert main(["config", "core.analytics", "false", "--local"]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == textwrap.dedent( + """\ + [core] + no_scm = True + """ + ) + assert (tmp_dir / ".dvc" / "config.local").read_text() == textwrap.dedent( + """\ + [core] + analytics = false + """ + ) - def test_local(self): - self._do_test(True) + assert main(["config", "core.analytics", "true", "--local"]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == textwrap.dedent( + """\ + [core] + no_scm = True + """ + ) + assert (tmp_dir / ".dvc" / "config.local").read_text() == textwrap.dedent( + """\ + [core] + analytics = true + """ + ) - def test_non_existing(self): - ret = main(["config", "non_existing_section.field"]) - self.assertEqual(ret, 251) + assert main(["config", "core.analytics", "--unset", "--local"]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == textwrap.dedent( + """\ + [core] + no_scm = True + """ + ) + assert (tmp_dir / ".dvc" / "config.local").read_text() == "\n" - ret = main(["config", "global.non_existing_field"]) - self.assertEqual(ret, 251) - ret = main(["config", "non_existing_section.field", "-u"]) - self.assertEqual(ret, 251) +@pytest.mark.parametrize( + "args, ret, msg", + [ + (["core.analytics"], 0, "False"), + (["core.remote"], 0, "myremote"), + (["remote.myremote.profile"], 0, "iterative"), + (["remote.myremote.profile", "--local"], 0, "iterative"), + ( + ["remote.myremote.profile", "--repo"], + 251, + "option 'profile' doesn't exist", + ), + (["remote.other.url"], 0, "gs://bucket/path"), + (["remote.other.url", "--local"], 0, "gs://bucket/path"), + (["remote.other.url", "--repo"], 251, "remote 'other' doesn't exist"), + ], +) +def test_config_get(tmp_dir, dvc, caplog, args, ret, msg): + (tmp_dir / ".dvc" / "config").write_text( + textwrap.dedent( + """\ + [core] + no_scm = true + analytics = False + remote = myremote + ['remote "myremote"'] + url = s3://bucket/path + region = us-east-2 + """ + ) + ) + (tmp_dir / ".dvc" / "config.local").write_text( + textwrap.dedent( + """\ + ['remote "myremote"'] + profile = iterative + ['remote "other"'] + url = gs://bucket/path + """ + ) + ) - ret = main(["config", "global.non_existing_field", "-u"]) - self.assertEqual(ret, 251) + caplog.clear() + assert main(["config"] + args) == ret + assert msg in caplog.text - ret = main(["config", "core.remote", "myremote"]) - self.assertEqual(ret, 0) - ret = main(["config", "core.non_existing_field", "-u"]) - self.assertEqual(ret, 251) +def test_config_list(tmp_dir, dvc, caplog): + (tmp_dir / ".dvc" / "config").write_text( + textwrap.dedent( + """\ + [core] + no_scm = true + analytics = False + remote = myremote + ['remote "myremote"'] + url = s3://bucket/path + region = us-east-2 + """ + ) + ) + (tmp_dir / ".dvc" / "config.local").write_text( + textwrap.dedent( + """\ + ['remote "myremote"'] + profile = iterative + access_key_id = abcde + secret_access_key = 123456 + ['remote "other"'] + url = gs://bucket/path + """ + ) + ) - def test_invalid_config_list(self): - ret = main(["config"]) - self.assertEqual(ret, 1) + caplog.clear() + assert main(["config", "--list"]) == 0 + assert "remote.myremote.url=s3://bucket/path" in caplog.text + assert "remote.myremote.region=us-east-2" in caplog.text + assert "remote.myremote.profile=iterative" in caplog.text + assert "remote.myremote.access_key_id=abcde" in caplog.text + assert "remote.myremote.secret_access_key=123456" in caplog.text + assert "remote.other.url=gs://bucket/path" in caplog.text + assert "core.analytics=False" in caplog.text + assert "core.no_scm=true" in caplog.text + assert "core.remote=myremote" in caplog.text - ret = main(["config", "--list", "core.analytics"]) - self.assertEqual(ret, 1) - ret = main(["config", "--list", "-u"]) - self.assertEqual(ret, 1) +@pytest.mark.parametrize( + "args", [["core.analytics"], ["core.analytics", "false"], ["--unset"]] +) +def test_list_bad_args(tmp_dir, dvc, caplog, args): + caplog.clear() + assert main(["config", "--list"] + args) == 1 + assert ( + "-l/--list can't be used together with any of these options: " + "-u/--unset, name, value" + ) in caplog.text def test_set_invalid_key(dvc): diff --git a/tests/func/test_root.py b/tests/func/test_root.py new file mode 100644 index 0000000000..f39562c5f9 --- /dev/null +++ b/tests/func/test_root.py @@ -0,0 +1,13 @@ +from dvc.main import main + + +def test_root(tmp_dir, dvc, caplog): + assert main(["root"]) == 0 + assert ".\n" in caplog.text + + +def test_root_locked(tmp_dir, dvc, caplog): + # NOTE: check that `dvc root` is not blocked with dvc lock + with dvc.lock: + assert main(["root"]) == 0 + assert ".\n" in caplog.text