From ea1584a2d648bd6aaf62cae2cbf24b6f9ac057a1 Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev
Date: Wed, 12 Aug 2020 14:46:45 +0300
Subject: [PATCH] dvc: migrate to azure-storage-blob >= 12.0

There have been some pretty drastic changes in 12.0, so migration is a
little bit more involved.

Fixes #3546
---
 dvc/tree/azure.py               | 135 ++++++++++++++++++--------------
 setup.py                        |   2 +-
 tests/remotes/azure.py          |  12 +--
 tests/unit/remote/test_azure.py |   4 +-
 4 files changed, 86 insertions(+), 67 deletions(-)

diff --git a/dvc/tree/azure.py b/dvc/tree/azure.py
index 19f9ecdcb9..78ada11b89 100644
--- a/dvc/tree/azure.py
+++ b/dvc/tree/azure.py
@@ -35,16 +35,20 @@ def __init__(self, repo, config):
         container = self._az_config.get("storage", "container_name", None)
         self.path_info = self.PATH_CLS(f"azure://{container}")
 
-        self._conn_kwargs = {
-            opt: config.get(opt) or self._az_config.get("storage", opt, None)
-            for opt in ["connection_string", "sas_token"]
-        }
-        self._conn_kwargs["account_name"] = self._az_config.get(
-            "storage", "account", None
-        )
-        self._conn_kwargs["account_key"] = self._az_config.get(
-            "storage", "key", None
+        self._conn_str = config.get(
+            "connection_string"
+        ) or self._az_config.get("storage", "connection_string", None)
+
+        self._account_url = None
+        if not self._conn_str:
+            name = self._az_config.get("storage", "account", None)
+            self._account_url = f"https://{name}.blob.core.windows.net"
+
+        self._credential = config.get("sas_token") or self._az_config.get(
+            "storage", "sas_token", None
         )
+        if not self._credential:
+            self._credential = self._az_config.get("storage", "key", None)
 
     @cached_property
     def _az_config(self):
@@ -62,64 +66,71 @@ def _az_config(self):
     @cached_property
     def blob_service(self):
         # pylint: disable=no-name-in-module
-        from azure.storage.blob import BlockBlobService
-        from azure.common import AzureMissingResourceHttpError
+        from azure.storage.blob import BlobServiceClient
+        from azure.core.exceptions import ResourceNotFoundError
 
         logger.debug(f"URL {self.path_info}")
-        logger.debug(f"Connection options {self._conn_kwargs}")
-        blob_service = BlockBlobService(**self._conn_kwargs)
+
+        if self._conn_str:
+            logger.debug(f"Using connection string '{self._conn_str}'")
+            blob_service = BlobServiceClient.from_connection_string(
+                self._conn_str, credential=self._credential
+            )
+        else:
+            logger.debug(f"Using account url '{self._account_url}'")
+            blob_service = BlobServiceClient(
+                self._account_url, credential=self._credential
+            )
+
         logger.debug(f"Container name {self.path_info.bucket}")
+        container_client = blob_service.get_container_client(
+            self.path_info.bucket
+        )
+
         try:  # verify that container exists
-            blob_service.list_blobs(
-                self.path_info.bucket, delimiter="/", num_results=1
-            )
-        except AzureMissingResourceHttpError:
-            blob_service.create_container(self.path_info.bucket)
+            container_client.get_container_properties()
+        except ResourceNotFoundError:
+            container_client.create_container()
+
         return blob_service
 
     def get_etag(self, path_info):
-        etag = self.blob_service.get_blob_properties(
+        blob_client = self.blob_service.get_blob_client(
             path_info.bucket, path_info.path
-        ).properties.etag
+        )
+        etag = blob_client.get_blob_properties().etag
         return etag.strip('"')
 
     def _generate_download_url(self, path_info, expires=3600):
         from azure.storage.blob import (  # pylint:disable=no-name-in-module
-            BlobPermissions,
+            BlobSasPermissions,
+            generate_blob_sas,
         )
 
         expires_at = datetime.utcnow() + timedelta(seconds=expires)
 
-        sas_token = self.blob_service.generate_blob_shared_access_signature(
-            path_info.bucket,
-            path_info.path,
-            permission=BlobPermissions.READ,
-            expiry=expires_at,
+        blob_client = self.blob_service.get_blob_client(
+            path_info.bucket, path_info.path
         )
-        download_url = self.blob_service.make_blob_url(
-            path_info.bucket, path_info.path, sas_token=sas_token
+
+        sas_token = generate_blob_sas(
+            blob_client.account_name,
+            blob_client.container_name,
+            blob_client.blob_name,
+            account_key=blob_client.credential.account_key,
+            permission=BlobSasPermissions(read=True),
+            expiry=expires_at,
         )
-        return download_url
+        return blob_client.url + "?" + sas_token
 
     def exists(self, path_info, use_dvcignore=True):
         paths = self._list_paths(path_info.bucket, path_info.path)
         return any(path_info.path == path for path in paths)
 
     def _list_paths(self, bucket, prefix):
-        blob_service = self.blob_service
-        next_marker = None
-        while True:
-            blobs = blob_service.list_blobs(
-                bucket, prefix=prefix, marker=next_marker
-            )
-
-            for blob in blobs:
-                yield blob.name
-
-            if not blobs.next_marker:
-                break
-
-            next_marker = blobs.next_marker
+        container_client = self.blob_service.get_container_client(bucket)
+        for blob in container_client.list_blobs(name_starts_with=prefix):
+            yield blob.name
 
     def walk_files(self, path_info, **kwargs):
         if not kwargs.pop("prefix", False):
@@ -137,7 +148,9 @@ def remove(self, path_info):
             raise NotImplementedError
 
         logger.debug(f"Removing {path_info}")
-        self.blob_service.delete_blob(path_info.bucket, path_info.path)
+        self.blob_service.get_blob_client(
+            path_info.bucket, path_info.path
+        ).delete_blob()
 
     def get_file_hash(self, path_info):
         return self.get_etag(path_info)
@@ -145,21 +158,27 @@ def get_file_hash(self, path_info):
     def _upload(
         self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
     ):
-        with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar:
-            self.blob_service.create_blob_from_path(
-                to_info.bucket,
-                to_info.path,
-                from_file,
-                progress_callback=pbar.update_to,
-            )
+
+        blob_client = self.blob_service.get_blob_client(
+            to_info.bucket, to_info.path
+        )
+        total = os.path.getsize(from_file)
+        with open(from_file, "rb") as fobj:
+            with Tqdm.wrapattr(
+                fobj, "read", desc=name, total=total, disable=no_progress_bar
+            ) as wrapped:
+                blob_client.upload_blob(wrapped)
 
     def _download(
        self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs
     ):
-        with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar:
-            self.blob_service.get_blob_to_path(
-                from_info.bucket,
-                from_info.path,
-                to_file,
-                progress_callback=pbar.update_to,
-            )
+        blob_client = self.blob_service.get_blob_client(
+            from_info.bucket, from_info.path
+        )
+        total = blob_client.get_blob_properties().size
+        stream = blob_client.download_blob()
+        with open(to_file, "wb") as fobj:
+            with Tqdm.wrapattr(
+                fobj, "write", desc=name, total=total, disable=no_progress_bar
+            ) as wrapped:
+                stream.readinto(wrapped)
diff --git a/setup.py b/setup.py
index ed4c1c3096..300b5a23bc 100644
--- a/setup.py
+++ b/setup.py
@@ -89,7 +89,7 @@ def run(self):
 gs = ["google-cloud-storage==1.19.0"]
 gdrive = ["pydrive2>=1.4.14"]
 s3 = ["boto3>=1.9.201"]
-azure = ["azure-storage-blob==2.1.0", "knack"]
+azure = ["azure-storage-blob>=12.0", "knack"]
 oss = ["oss2==2.6.1"]
 ssh = ["paramiko>=2.5.0"]
 hdfs = ["pyarrow>=0.17.0"]
diff --git a/tests/remotes/azure.py b/tests/remotes/azure.py
index 8dae521d2d..27db33298d 100644
--- a/tests/remotes/azure.py
+++ b/tests/remotes/azure.py
@@ -25,10 +25,10 @@ class Azure(Base, CloudURLInfo):
 @pytest.fixture(scope="session")
 def azure_server(docker_compose, docker_services):
     from azure.storage.blob import (  # pylint: disable=no-name-in-module
-        BlockBlobService,
+        BlobServiceClient,
     )
-    from azure.common import (  # pylint: disable=no-name-in-module
-        AzureException,
+    from azure.core.exceptions import (  # pylint: disable=no-name-in-module
+        AzureError,
     )
 
     port = docker_services.port_for("azurite", 10000)
@@ -36,11 +36,11 @@ def azure_server(docker_compose, docker_services):
 
     def _check():
         try:
-            BlockBlobService(
-                connection_string=connection_string,
+            BlobServiceClient.from_connection_string(
+                connection_string
             ).list_containers()
             return True
-        except AzureException:
+        except AzureError:
             return False
 
     docker_services.wait_until_responsive(
diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py
index 0af52f046e..a0614bcf62 100644
--- a/tests/unit/remote/test_azure.py
+++ b/tests/unit/remote/test_azure.py
@@ -17,7 +17,7 @@ def test_init_env_var(monkeypatch, dvc):
     config = {"url": "azure://"}
     tree = AzureTree(dvc, config)
     assert tree.path_info == "azure://" + container_name
-    assert tree._conn_kwargs["connection_string"] == connection_string
+    assert tree._conn_str == connection_string
 
 
 def test_init(dvc):
@@ -26,7 +26,7 @@ def test_init(dvc):
     config = {"url": url, "connection_string": connection_string}
     tree = AzureTree(dvc, config)
     assert tree.path_info == url
-    assert tree._conn_kwargs["connection_string"] == connection_string
+    assert tree._conn_str == connection_string
 
 
 def test_get_file_hash(tmp_dir, azure):