diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 09928b7d5f..ab737fdd26 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -12,7 +12,7 @@ from dvc.dependency.s3 import S3Dependency from dvc.dependency.ssh import SSHDependency from dvc.output.base import BaseOutput -from dvc.remote import get_remote +from dvc.remote import get_cloud_tree from dvc.scheme import Schemes from .repo import RepoDependency @@ -54,8 +54,8 @@ def _get(stage, p, info): parsed = urlparse(p) if p else None if parsed and parsed.scheme == "remote": - remote = get_remote(stage.repo, name=parsed.netloc) - return DEP_MAP[remote.scheme](stage, p, info, remote=remote) + tree = get_cloud_tree(stage.repo, name=parsed.netloc) + return DEP_MAP[tree.scheme](stage, p, info, tree=tree) if info and info.get(RepoDependency.PARAM_REPO): repo = info.pop(RepoDependency.PARAM_REPO) diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 3ae2a77dfb..be9eb2a066 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -26,7 +26,7 @@ def __init__(self, def_repo, stage, *args, **kwargs): self.def_repo = def_repo super().__init__(stage, *args, **kwargs) - def _parse_path(self, remote, path): + def _parse_path(self, tree, path): return None @property diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index 57e6eff82d..3f7654d4e0 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -10,7 +10,7 @@ from dvc.output.local import LocalOutput from dvc.output.s3 import S3Output from dvc.output.ssh import SSHOutput -from dvc.remote import get_remote +from dvc.remote import get_cloud_tree from dvc.remote.hdfs import HDFSRemoteTree from dvc.remote.local import LocalRemoteTree from dvc.remote.s3 import S3RemoteTree @@ -66,13 +66,13 @@ def _get( parsed = urlparse(p) if parsed.scheme == "remote": - remote = get_remote(stage.repo, name=parsed.netloc) - return OUTS_MAP[remote.scheme]( + tree = get_cloud_tree(stage.repo, name=parsed.netloc) + return OUTS_MAP[tree.scheme]( stage, p, info, cache=cache, - remote=remote, + tree=tree, metric=metric, plot=plot, persist=persist, @@ -85,7 +85,7 @@ def _get( p, info, cache=cache, - remote=None, + tree=None, metric=metric, plot=plot, persist=persist, @@ -95,7 +95,7 @@ def _get( p, info, cache=cache, - remote=None, + tree=None, metric=metric, plot=plot, persist=persist, diff --git a/dvc/output/base.py b/dvc/output/base.py index 4c7a99dcda..aba44a7960 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -12,7 +12,7 @@ DvcException, RemoteCacheRequiredError, ) -from dvc.remote.base import BaseRemoteTree, Remote +from dvc.remote.base import BaseRemoteTree logger = logging.getLogger(__name__) @@ -47,7 +47,6 @@ def __init__(self, path): class BaseOutput: IS_DEPENDENCY = False - REMOTE_CLS = Remote TREE_CLS = BaseRemoteTree PARAM_PATH = "path" @@ -85,7 +84,7 @@ def __init__( stage, path, info=None, - remote=None, + tree=None, cache=True, metric=False, plot=False, @@ -106,24 +105,23 @@ def __init__( self.repo = stage.repo if stage else None self.def_path = path self.info = info - if remote: - self.remote = remote + if tree: + self.tree = tree else: - tree = self.TREE_CLS(self.repo, {}) - self.remote = self.REMOTE_CLS(tree) + self.tree = self.TREE_CLS(self.repo, {}) self.use_cache = False if self.IS_DEPENDENCY else cache self.metric = False if self.IS_DEPENDENCY else metric self.plot = False if self.IS_DEPENDENCY else plot self.persist = persist - self.path_info = self._parse_path(remote, path) + self.path_info = self._parse_path(tree, path) if self.use_cache and self.cache is None: raise RemoteCacheRequiredError(self.path_info) - def _parse_path(self, remote, path): - if remote: + def _parse_path(self, tree, path): + if tree: parsed = urlparse(path) - return remote.path_info / parsed.path.lstrip("/") + return tree.path_info / parsed.path.lstrip("/") return self.TREE_CLS.PATH_CLS(path) def __repr__(self): @@ -167,29 +165,29 @@ def cache_path(self): @property def checksum_type(self): - return self.remote.tree.PARAM_CHECKSUM + return self.tree.PARAM_CHECKSUM @property def checksum(self): - return self.info.get(self.remote.tree.PARAM_CHECKSUM) + return self.info.get(self.tree.PARAM_CHECKSUM) @checksum.setter def checksum(self, checksum): - self.info[self.remote.tree.PARAM_CHECKSUM] = checksum + self.info[self.tree.PARAM_CHECKSUM] = checksum def get_checksum(self): - return self.remote.get_hash(self.path_info) + return self.tree.get_hash(self.path_info) @property def is_dir_checksum(self): - return self.remote.is_dir_hash(self.checksum) + return self.tree.is_dir_hash(self.checksum) @property def exists(self): - return self.remote.tree.exists(self.path_info) + return self.tree.exists(self.path_info) def save_info(self): - return self.remote.save_info(self.path_info) + return self.tree.save_info(self.path_info) def changed_checksum(self): return self.checksum != self.get_checksum() @@ -222,13 +220,13 @@ def changed(self): @property def is_empty(self): - return self.remote.tree.is_empty(self.path_info) + return self.tree.is_empty(self.path_info) def isdir(self): - return self.remote.tree.isdir(self.path_info) + return self.tree.isdir(self.path_info) def isfile(self): - return self.remote.tree.isfile(self.path_info) + return self.tree.isfile(self.path_info) # pylint: disable=no-member @@ -316,7 +314,7 @@ def verify_metric(self): raise DvcException(f"verify metric is not supported for {self.scheme}") def download(self, to): - self.remote.tree.download(self.path_info, to.path_info) + self.tree.download(self.path_info, to.path_info) def checkout( self, @@ -342,7 +340,7 @@ def checkout( ) def remove(self, ignore_remove=False): - self.remote.tree.remove(self.path_info) + self.tree.remove(self.path_info) if self.scheme != "local": return @@ -354,7 +352,7 @@ def move(self, out): if self.scheme == "local" and self.use_scm_ignore: self.repo.scm.ignore_remove(self.fspath) - self.remote.tree.move(self.path_info, out.path_info) + self.tree.move(self.path_info, out.path_info) self.def_path = out.def_path self.path_info = out.path_info self.save() @@ -373,7 +371,7 @@ def get_files_number(self, filter_info=None): def unprotect(self): if self.exists: - self.remote.tree.unprotect(self.path_info) + self.tree.unprotect(self.path_info) def get_dir_cache(self, **kwargs): if not self.is_dir_checksum: @@ -433,8 +431,8 @@ def collect_used_dir_cache( filter_path = str(filter_info) if filter_info else None is_win = os.name == "nt" for entry in self.dir_cache: - checksum = entry[self.remote.tree.PARAM_CHECKSUM] - entry_relpath = entry[self.remote.tree.PARAM_RELPATH] + checksum = entry[self.tree.PARAM_CHECKSUM] + entry_relpath = entry[self.tree.PARAM_RELPATH] if is_win: entry_relpath = entry_relpath.replace("/", os.sep) entry_path = os.path.join(path, entry_relpath) diff --git a/dvc/output/local.py b/dvc/output/local.py index 9dcebbdc12..9783773578 100644 --- a/dvc/output/local.py +++ b/dvc/output/local.py @@ -5,7 +5,7 @@ from dvc.exceptions import DvcException from dvc.istextfile import istextfile from dvc.output.base import BaseOutput -from dvc.remote.local import LocalRemote, LocalRemoteTree +from dvc.remote.local import LocalRemoteTree from dvc.utils import relpath from dvc.utils.fs import path_isin @@ -13,7 +13,6 @@ class LocalOutput(BaseOutput): - REMOTE_CLS = LocalRemote TREE_CLS = LocalRemoteTree sep = os.sep @@ -23,10 +22,10 @@ def __init__(self, stage, path, *args, **kwargs): super().__init__(stage, path, *args, **kwargs) - def _parse_path(self, remote, path): + def _parse_path(self, tree, path): parsed = urlparse(path) if parsed.scheme == "remote": - p = remote.path_info / parsed.path.lstrip("/") + p = tree.path_info / parsed.path.lstrip("/") else: # NOTE: we can path either from command line or .dvc file, # so we should expect both posix and windows style paths. diff --git a/dvc/output/ssh.py b/dvc/output/ssh.py index 27d3e6d436..017f5de5a6 100644 --- a/dvc/output/ssh.py +++ b/dvc/output/ssh.py @@ -1,7 +1,6 @@ from dvc.output.base import BaseOutput -from dvc.remote.ssh import SSHRemote, SSHRemoteTree +from dvc.remote.ssh import SSHRemoteTree class SSHOutput(BaseOutput): - REMOTE_CLS = SSHRemote TREE_CLS = SSHRemoteTree diff --git a/dvc/repo/tree.py b/dvc/repo/tree.py index 1639d144a6..bc6e7fc9dd 100644 --- a/dvc/repo/tree.py +++ b/dvc/repo/tree.py @@ -49,11 +49,11 @@ def _get_granular_checksum(self, path, out, remote=None): raise FileNotFoundError dir_cache = out.get_dir_cache(remote=remote) for entry in dir_cache: - entry_relpath = entry[out.remote.tree.PARAM_RELPATH] + entry_relpath = entry[out.tree.PARAM_RELPATH] if os.name == "nt": entry_relpath = entry_relpath.replace("/", os.sep) if path == out.path_info / entry_relpath: - return entry[out.remote.tree.PARAM_CHECKSUM] + return entry[out.tree.PARAM_CHECKSUM] raise FileNotFoundError def open( @@ -156,7 +156,7 @@ def _add_dir(self, top, trie, out, download_callback=None, **kwargs): download_callback(downloaded) for entry in dir_cache: - entry_relpath = entry[out.remote.tree.PARAM_RELPATH] + entry_relpath = entry[out.tree.PARAM_RELPATH] if os.name == "nt": entry_relpath = entry_relpath.replace("/", os.sep) path_info = out.path_info / entry_relpath diff --git a/tests/func/test_tree.py b/tests/func/test_tree.py index 8cb49d10d4..55e495c38f 100644 --- a/tests/func/test_tree.py +++ b/tests/func/test_tree.py @@ -192,7 +192,7 @@ def test_repotree_walk_fetch(tmp_dir, dvc, scm, local_remote): assert os.path.exists(out.cache_path) for entry in out.dir_cache: - hash_ = entry[out.remote.tree.PARAM_CHECKSUM] + hash_ = entry[out.tree.PARAM_CHECKSUM] assert os.path.exists(dvc.cache.local.hash_to_path_info(hash_)) diff --git a/tests/unit/dependency/test_local.py b/tests/unit/dependency/test_local.py index bff5590f1f..49936cecca 100644 --- a/tests/unit/dependency/test_local.py +++ b/tests/unit/dependency/test_local.py @@ -15,6 +15,6 @@ def _get_dependency(self): def test_save_missing(self): d = self._get_dependency() - with mock.patch.object(d.remote.tree, "exists", return_value=False): + with mock.patch.object(d.tree, "exists", return_value=False): with self.assertRaises(d.DoesNotExistError): d.save() diff --git a/tests/unit/output/test_local.py b/tests/unit/output/test_local.py index 0cf50869b2..488db4b0f3 100644 --- a/tests/unit/output/test_local.py +++ b/tests/unit/output/test_local.py @@ -19,7 +19,7 @@ def _get_output(self): def test_save_missing(self): o = self._get_output() - with patch.object(o.remote.tree, "exists", return_value=False): + with patch.object(o.tree, "exists", return_value=False): with self.assertRaises(o.DoesNotExistError): o.save()