Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions dvc/checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,14 @@ def _cache_is_copy(cache, path_info):


def _checkout_file(
path_info, fs, obj, cache, force, progress_callback=None, relink=False,
path_info,
fs,
obj,
cache,
force,
progress_callback=None,
relink=False,
state=None,
):
"""The file is changed we need to checkout a new copy"""
modified = False
Expand All @@ -170,7 +177,9 @@ def _checkout_file(
_link(cache, cache_info, path_info)
modified = True

fs.repo.state.save(path_info, fs, obj.hash_info)
if state:
state.save(path_info, fs, obj.hash_info)

if progress_callback:
progress_callback(str(path_info))

Expand Down Expand Up @@ -202,6 +211,7 @@ def _checkout_dir(
progress_callback=None,
relink=False,
dvcignore: Optional[DvcIgnoreFilter] = None,
state=None,
):
modified = False
# Create dir separately so that dir is created
Expand All @@ -221,6 +231,7 @@ def _checkout_dir(
force,
progress_callback,
relink,
state=None,
)
if entry_modified:
modified = True
Expand All @@ -232,7 +243,8 @@ def _checkout_dir(
or modified
)

fs.repo.state.save(path_info, fs, obj.hash_info)
if state:
state.save(path_info, fs, obj.hash_info)

# relink is not modified, assume it as nochange
return modified and not relink
Expand All @@ -247,10 +259,11 @@ def _checkout(
progress_callback=None,
relink=False,
dvcignore: Optional[DvcIgnoreFilter] = None,
state=None,
):
if not obj.hash_info.isdir:
ret = _checkout_file(
path_info, fs, obj, cache, force, progress_callback, relink
path_info, fs, obj, cache, force, progress_callback, relink, state,
)
else:
ret = _checkout_dir(
Expand All @@ -262,9 +275,11 @@ def _checkout(
progress_callback,
relink,
dvcignore=dvcignore,
state=state,
)

fs.repo.state.save_link(path_info, fs)
if state:
state.save_link(path_info, fs)

return ret

Expand All @@ -279,6 +294,7 @@ def checkout(
relink=False,
quiet=False,
dvcignore: Optional[DvcIgnoreFilter] = None,
state=None,
):
if path_info.scheme not in ["local", cache.fs.scheme]:
raise NotImplementedError
Expand Down Expand Up @@ -330,4 +346,5 @@ def checkout(
progress_callback,
relink,
dvcignore=dvcignore,
state=state,
)
2 changes: 1 addition & 1 deletion dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
else:
self.dvc_dir = os.path.abspath(os.path.realpath(dvc_dir))

self.wfs = LocalFileSystem(None, {"url": self.dvc_dir})
self.wfs = LocalFileSystem(url=self.dvc_dir)
self.fs = fs or self.wfs

self.load(validate=validate, config=config)
Expand Down
3 changes: 2 additions & 1 deletion dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
def _get(stage, p, info):
parsed = urlparse(p) if p else None
if parsed and parsed.scheme == "remote":
fs = get_cloud_fs(stage.repo, name=parsed.netloc)
cls, config = get_cloud_fs(stage.repo, name=parsed.netloc)
fs = cls(**config)
return DEP_MAP[fs.scheme](stage, p, info, fs=fs)

if info and info.get(RepoDependency.PARAM_REPO):
Expand Down
9 changes: 8 additions & 1 deletion dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ def download(self, to, jobs=None):
)
save(odb, obj, jobs=jobs)

checkout(to.path_info, to.fs, obj, odb, dvcignore=None)
checkout(
to.path_info,
to.fs,
obj,
odb,
dvcignore=None,
state=self.repo.state,
)

def update(self, rev=None):
if rev:
Expand Down
18 changes: 17 additions & 1 deletion dvc/fs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,20 @@ def get_cloud_fs(repo, **kwargs):
remote_conf = SCHEMA["remote"][str](remote_conf)
except Invalid as exc:
raise ConfigError(str(exc)) from None
return get_fs_cls(remote_conf)(repo, remote_conf)

if "jobs" not in remote_conf:
jobs = repo.config["core"].get("jobs")
if jobs:
remote_conf["jobs"] = jobs

if "checksum_jobs" not in remote_conf:
checksum_jobs = repo.config["core"].get("checksum_jobs")
if checksum_jobs:
remote_conf["checksum_jobs"] = checksum_jobs

cls = get_fs_cls(remote_conf)

if isinstance(cls, GDriveFileSystem):
remote_conf["gdrive_credentials_tmp_dir"] = repo.tmp_dir

Comment on lines 91 to 105
Copy link
Copy Markdown
Contributor Author

@efiop efiop May 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These config options would fit nicely into a dynamic config schema, if we had one. Will take a look...

return cls, remote_conf
6 changes: 3 additions & 3 deletions dvc/fs/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ class AzureFileSystem(FSSpecWrapper): # pylint:disable=abstract-method
"azure-identity": "azure.identity",
}

def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **config):
super().__init__(**config)

url = config.get("url")
self.path_info = self.PATH_CLS(url)
Expand All @@ -92,7 +92,7 @@ def __init__(self, repo, config):
self.path_info = self.PATH_CLS(url)
self.bucket = self.path_info.bucket

def _prepare_credentials(self, config):
def _prepare_credentials(self, **config):
from azure.identity.aio import DefaultAzureCredential

# Disable spam from failed cred types for DefaultAzureCredential
Expand Down
32 changes: 7 additions & 25 deletions dvc/fs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
from multiprocessing import cpu_count
from typing import Any, ClassVar, Dict, FrozenSet, Optional

from funcy import cached_property

from dvc.exceptions import DvcException
from dvc.path_info import URLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
from dvc.utils import tmp_fname
from dvc.utils.fs import makedirs, move
from dvc.utils.http import open_url
Expand Down Expand Up @@ -55,29 +52,13 @@ class BaseFileSystem:
PARAM_CHECKSUM: ClassVar[Optional[str]] = None
DETAIL_FIELDS: FrozenSet[str] = frozenset()

def __init__(self, repo, config):
self.repo = repo
self.config = config

self._check_requires()
def __init__(self, **kwargs):
self._check_requires(**kwargs)

self.path_info = None

@cached_property
def jobs(self):
return (
self.config.get("jobs")
or (self.repo and self.repo.config["core"].get("jobs"))
or self._JOBS
)

@cached_property
def hash_jobs(self):
return (
self.config.get("checksum_jobs")
or (self.repo and self.repo.config["core"].get("checksum_jobs"))
or self.HASH_JOBS
)
self.jobs = kwargs.get("jobs") or self._JOBS
self.hash_jobs = kwargs.get("checksum_jobs") or self.HASH_JOBS

@classmethod
def get_missing_deps(cls):
Expand All @@ -92,15 +73,16 @@ def get_missing_deps(cls):

return missing

def _check_requires(self):
def _check_requires(self, **kwargs):
from ..scheme import Schemes
from ..utils import format_link
from ..utils.pkg import PKG

missing = self.get_missing_deps()
if not missing:
return

url = self.config.get("url", f"{self.scheme}://")
url = kwargs.get("url", f"{self.scheme}://")

scheme = self.scheme
if scheme == Schemes.WEBDAVS:
Expand Down
5 changes: 3 additions & 2 deletions dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class DvcFileSystem(BaseFileSystem): # pylint:disable=abstract-method
scheme = "local"
PARAM_CHECKSUM = "md5"

def __init__(self, repo):
super().__init__(repo, {"url": repo.root_dir})
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.repo = kwargs["repo"]

def _find_outs(self, path, *args, **kwargs):
outs = self.repo.find_outs_by_path(path, *args, **kwargs)
Expand Down
10 changes: 6 additions & 4 deletions dvc/fs/fsspec_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

# pylint: disable=no-member
class FSSpecWrapper(BaseFileSystem):
def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.fs_args = {"skip_instance_cache": True}
self.fs_args.update(self._prepare_credentials(config))
self.fs_args.update(self._prepare_credentials(**kwargs))

@cached_property
def fs(self):
Expand Down Expand Up @@ -50,7 +50,9 @@ def _entry_hook(self, entry):
entries within info() and ls(detail=True) calls"""
return entry

def _prepare_credentials(self, config): # pylint: disable=unused-argument
def _prepare_credentials(
self, **config
): # pylint: disable=unused-argument
"""Prepare the arguments for authentication to the
host filesystem"""
return {}
Expand Down
16 changes: 9 additions & 7 deletions dvc/fs/gdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ class GDriveFileSystem(BaseFileSystem): # pylint:disable=abstract-method
DEFAULT_GDRIVE_CLIENT_ID = "710796635688-iivsgbgsb6uv1fap6635dhvuei09o66c.apps.googleusercontent.com" # noqa: E501
DEFAULT_GDRIVE_CLIENT_SECRET = "a1Fz59uTpVNeG_VGuSKDLJXv"

def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **config):
super().__init__(**config)

self.path_info = self.PATH_CLS(config["url"])

Expand All @@ -126,17 +126,19 @@ def __init__(self, repo, config):
self._client_id = config.get("gdrive_client_id")
self._client_secret = config.get("gdrive_client_secret")
self._validate_config()

tmp_dir = config["gdrive_credentials_tmp_dir"]
assert tmp_dir

self._gdrive_service_credentials_path = tmp_fname(
os.path.join(self.repo.tmp_dir, "")
os.path.join(tmp_dir, "")
)
self._gdrive_user_credentials_path = (
tmp_fname(os.path.join(self.repo.tmp_dir, ""))
tmp_fname(os.path.join(tmp_dir, ""))
if os.getenv(GDriveFileSystem.GDRIVE_CREDENTIALS_DATA)
else config.get(
"gdrive_user_credentials_file",
os.path.join(
self.repo.tmp_dir, self.DEFAULT_USER_CREDENTIALS_FILE,
),
os.path.join(tmp_dir, self.DEFAULT_USER_CREDENTIALS_FILE,),
)
)

Expand Down
2 changes: 1 addition & 1 deletion dvc/fs/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class GitFileSystem(BaseFileSystem): # pylint:disable=abstract-method
scheme = "local"

def __init__(self, root_dir, trie):
super().__init__(None, {})
super().__init__()
self._fs_root = root_dir
self.trie = trie

Expand Down
6 changes: 3 additions & 3 deletions dvc/fs/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class GSFileSystem(FSSpecWrapper): # pylint:disable=abstract-method
PARAM_CHECKSUM = "etag"
DETAIL_FIELDS = frozenset(("etag", "size"))

def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **config):
super().__init__(**config)

url = config.get("url", "gs://")
self.path_info = self.PATH_CLS(url)

def _prepare_credentials(self, config):
def _prepare_credentials(self, **config):
login_info = {"consistency": None}
login_info["project"] = config.get("projectname")
login_info["token"] = config.get("credentialpath")
Expand Down
4 changes: 2 additions & 2 deletions dvc/fs/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ class HDFSFileSystem(BaseFileSystem):
PARAM_CHECKSUM = "checksum"
TRAVERSE_PREFIX_LEN = 2

def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **config):
super().__init__(**config)

self.path_info = None
url = config.get("url")
Expand Down
4 changes: 2 additions & 2 deletions dvc/fs/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class HTTPFileSystem(BaseFileSystem): # pylint:disable=abstract-method
REQUEST_TIMEOUT = 60
CHUNK_SIZE = 2 ** 16

def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **config):
super().__init__(**config)

url = config.get("url")
if url:
Expand Down
18 changes: 7 additions & 11 deletions dvc/fs/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,22 @@ class LocalFileSystem(BaseFileSystem):
PARAM_PATH = "path"
TRAVERSE_PREFIX_LEN = 2

def __init__(self, repo, config):
super().__init__(repo, config)
def __init__(self, **config):
from fsspec.implementations.local import LocalFileSystem as LocalFS

super().__init__(**config)
self.fs = LocalFS()
url = config.get("url")
self.path_info = self.PATH_CLS(url) if url else None

@property
def fs_root(self):
return self.config.get("url")
self.fs_root = url

@staticmethod
def open(path_info, mode="r", encoding=None, **kwargs):
return open(path_info, mode=mode, encoding=encoding)

def exists(self, path_info) -> bool:
assert isinstance(path_info, str) or path_info.scheme == "local"
if self.repo:
ret = os.path.lexists(path_info)
else:
ret = os.path.exists(path_info)
return ret
return self.fs.exists(path_info)

def isfile(self, path_info) -> bool:
return os.path.isfile(path_info)
Expand Down
Loading