diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 5d9985f1b2..e6e28839d5 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -1,5 +1,4 @@ import logging -import os import threading from datetime import datetime, timedelta @@ -16,23 +15,38 @@ class AzureRemoteTree(BaseRemoteTree): scheme = Schemes.AZURE PATH_CLS = CloudURLInfo - REQUIRES = {"azure-storage-blob": "azure.storage.blob"} + REQUIRES = { + "azure-storage-blob": "azure.storage.blob", + "azure-cli-core": "azure.cli.core", + } PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 5000 def __init__(self, repo, config): + from azure.cli.core import get_default_cli + super().__init__(repo, config) + # NOTE: az_config takes care of env vars + az_config = get_default_cli().config + url = config.get("url", "azure://") self.path_info = self.PATH_CLS(url) if not self.path_info.bucket: - container = os.getenv("AZURE_STORAGE_CONTAINER_NAME") + container = az_config.get("storage", "container_name", None) self.path_info = self.PATH_CLS(f"azure://{container}") - self.connection_string = config.get("connection_string") or os.getenv( - "AZURE_STORAGE_CONNECTION_STRING" + self._conn_kwargs = { + opt: config.get(opt) or az_config.get("storage", opt, None) + for opt in ["connection_string", "sas_token"] + } + self._conn_kwargs["account_name"] = az_config.get( + "storage", "account", None + ) + self._conn_kwargs["account_key"] = az_config.get( + "storage", "key", None ) @wrap_prop(threading.Lock()) @@ -43,10 +57,8 @@ def blob_service(self): from azure.common import AzureMissingResourceHttpError logger.debug(f"URL {self.path_info}") - logger.debug(f"Connection string {self.connection_string}") - blob_service = BlockBlobService( - connection_string=self.connection_string - ) + logger.debug(f"Connection options {self._conn_kwargs}") + blob_service = BlockBlobService(**self._conn_kwargs) logger.debug(f"Container name {self.path_info.bucket}") try: # verify that container exists blob_service.list_blobs( diff --git a/setup.py b/setup.py index ab63ab0c16..8d86b18122 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def run(self): gs = ["google-cloud-storage==1.19.0"] gdrive = ["pydrive2>=1.4.14"] s3 = ["boto3>=1.9.201"] -azure = ["azure-storage-blob==2.1.0"] +azure = ["azure-storage-blob==2.1.0", "azure-cli-core>=2.0.70"] oss = ["oss2==2.6.1"] ssh = ["paramiko>=2.5.0"] hdfs = ["pyarrow>=0.17.0"] diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py index 360ab3927f..2c601c30d1 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/remote/test_azure.py @@ -20,7 +20,7 @@ def test_init_env_var(monkeypatch, dvc): config = {"url": "azure://"} tree = AzureRemoteTree(dvc, config) assert tree.path_info == "azure://" + container_name - assert tree.connection_string == connection_string + assert tree._conn_kwargs["connection_string"] == connection_string def test_init(dvc): @@ -29,7 +29,7 @@ def test_init(dvc): config = {"url": url, "connection_string": connection_string} tree = AzureRemoteTree(dvc, config) assert tree.path_info == url - assert tree.connection_string == connection_string + assert tree._conn_kwargs["connection_string"] == connection_string def test_get_file_hash(tmp_dir):