Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions dvc/remote/gdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ def init_drive(self):
),
)
)
self.remote_drive_id = None

@gdrive_retry
def gdrive_upload_file(
self, args, no_progress_bar=True, from_file="", progress_name=""
):
parent = {"id": args["parent_id"]}
item = self.drive.CreateFile(
{"title": args["title"], "parents": [{"id": args["parent_id"]}]}
{"title": args["title"], "parents": [parent]}
)

with open(from_file, "rb") as fobj:
Expand All @@ -133,7 +135,9 @@ def gdrive_upload_file(
def gdrive_download_file(
self, file_id, to_file, progress_name, no_progress_bar
):
gdrive_file = self.drive.CreateFile({"id": file_id})
param = {"id": file_id}
# it does not create a file on the remote
gdrive_file = self.drive.CreateFile(param)
bar_format = (
"Donwloading {desc:{ncols_desc}.{ncols_desc}}... "
+ Tqdm.format_sizeof(int(gdrive_file["fileSize"]), "B", 1024)
Expand All @@ -144,7 +148,12 @@ def gdrive_download_file(
gdrive_file.GetContentFile(to_file)

def gdrive_list_item(self, query):
file_list = self.drive.ListFile({"q": query, "maxResults": 1000})
param = {"q": query, "maxResults": 1000, "corpora": self.corpora}

if self.remote_drive_id:
param["driveId"] = self.remote_drive_id

file_list = self.drive.ListFile(param)

# Isolate and decorate fetching of remote drive items in pages
get_list = gdrive_retry(lambda: next(file_list, None))
Expand Down Expand Up @@ -240,21 +249,22 @@ def drive(self):

self._gdrive = GoogleDrive(gauth)

if self.bucket != "root" and self.bucket != "appDataFolder":
self.remote_drive_id = self.get_remote_drive_id(self.bucket)
self.corpora = "drive" if self.remote_drive_id else "default"
self.remote_root_id = self.get_remote_id(
self.path_info, create=True
)

self._cached_dirs, self._cached_ids = self.cache_root_dirs()

return self._gdrive

@gdrive_retry
def create_remote_dir(self, parent_id, title):
parent = {"id": parent_id}
item = self.drive.CreateFile(
{
"title": title,
"parents": [{"id": parent_id}],
"mimeType": FOLDER_MIME_TYPE,
}
{"title": title, "parents": [parent], "mimeType": FOLDER_MIME_TYPE}
)
item.Upload()
return item
Expand All @@ -272,12 +282,28 @@ def get_remote_item(self, name, parents_ids):

query += " and trashed=false and title='{}'".format(name)

param = {
"q": query,
# Remote might contain items with duplicated titles
"maxResults": 1,
Comment thread
shcheklein marked this conversation as resolved.
"corpora": self.corpora,
}

if self.remote_drive_id:
param["driveId"] = self.remote_drive_id

# Limit found remote items count to 1 in response
item_list = self.drive.ListFile(
{"q": query, "maxResults": 1}
).GetList()
item_list = self.drive.ListFile(param).GetList()
return next(iter(item_list), None)

@gdrive_retry
def get_remote_drive_id(self, remote_id):
param = {"id": remote_id}
# it does not create a file on the remote
item = self.drive.CreateFile(param)
Comment thread
shcheklein marked this conversation as resolved.
Outdated
item.FetchMetadata("driveId")
return item.get("driveId", None)

def resolve_remote_item_from_path(self, path_parts, create):
parents_ids = [self.bucket]
current_path = ""
Expand All @@ -301,6 +327,10 @@ def get_remote_id_from_cache(self, remote_path):
return []

def get_remote_id(self, path_info, create=False):
if not path_info.path and path_info.bucket:
# Case sensitive base path
return self.bucket

remote_ids = self.get_remote_id_from_cache(path_info.path)

if remote_ids:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def run(self):
# Extra dependencies for remote integrations

gs = ["google-cloud-storage==1.19.0"]
gdrive = ["pydrive2>=1.4.1"]
gdrive = ["pydrive2>=1.4.2"]
s3 = ["boto3>=1.9.201"]
azure = ["azure-storage-blob==2.1.0"]
oss = ["oss2==2.6.1"]
Expand Down