Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 77 additions & 58 deletions dvc/tree/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,20 @@ def __init__(self, repo, config):
container = self._az_config.get("storage", "container_name", None)
self.path_info = self.PATH_CLS(f"azure://{container}")

self._conn_kwargs = {
opt: config.get(opt) or self._az_config.get("storage", opt, None)
for opt in ["connection_string", "sas_token"]
}
self._conn_kwargs["account_name"] = self._az_config.get(
"storage", "account", None
)
self._conn_kwargs["account_key"] = self._az_config.get(
"storage", "key", None
self._conn_str = config.get(
"connection_string"
) or self._az_config.get("storage", "connection_string", None)

self._account_url = None
if not self._conn_str:
name = self._az_config.get("storage", "account", None)
self._account_url = f"https://{name}.blob.core.windows.net"

self._credential = config.get("sas_token") or self._az_config.get(
"storage", "sas_token", None
)
if not self._credential:
self._credential = self._az_config.get("storage", "key", None)

@cached_property
def _az_config(self):
Expand All @@ -62,64 +66,71 @@ def _az_config(self):
@cached_property
def blob_service(self):
# pylint: disable=no-name-in-module
from azure.storage.blob import BlockBlobService
from azure.common import AzureMissingResourceHttpError
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import ResourceNotFoundError

logger.debug(f"URL {self.path_info}")
logger.debug(f"Connection options {self._conn_kwargs}")
blob_service = BlockBlobService(**self._conn_kwargs)

if self._conn_str:
logger.debug(f"Using connection string '{self._conn_str}'")
blob_service = BlobServiceClient.from_connection_string(
self._conn_str, credential=self._credential
)
else:
logger.debug(f"Using account url '{self._account_url}'")
blob_service = BlobServiceClient(
self._account_url, credential=self._credential
)

logger.debug(f"Container name {self.path_info.bucket}")
container_client = blob_service.get_container_client(
self.path_info.bucket
)

try: # verify that container exists
blob_service.list_blobs(
self.path_info.bucket, delimiter="/", num_results=1
)
except AzureMissingResourceHttpError:
blob_service.create_container(self.path_info.bucket)
container_client.get_container_properties()
except ResourceNotFoundError:
container_client.create_container()

return blob_service

def get_etag(self, path_info):
etag = self.blob_service.get_blob_properties(
blob_client = self.blob_service.get_blob_client(
path_info.bucket, path_info.path
).properties.etag
)
etag = blob_client.get_blob_properties().etag
return etag.strip('"')

def _generate_download_url(self, path_info, expires=3600):
from azure.storage.blob import ( # pylint:disable=no-name-in-module
BlobPermissions,
BlobSasPermissions,
generate_blob_sas,
)

expires_at = datetime.utcnow() + timedelta(seconds=expires)

sas_token = self.blob_service.generate_blob_shared_access_signature(
path_info.bucket,
path_info.path,
permission=BlobPermissions.READ,
expiry=expires_at,
blob_client = self.blob_service.get_blob_client(
path_info.bucket, path_info.path
)
download_url = self.blob_service.make_blob_url(
path_info.bucket, path_info.path, sas_token=sas_token

sas_token = generate_blob_sas(
blob_client.account_name,
blob_client.container_name,
blob_client.blob_name,
account_key=blob_client.credential.account_key,
permission=BlobSasPermissions(read=True),
expiry=expires_at,
)
return download_url
return blob_client.url + "?" + sas_token

def exists(self, path_info, use_dvcignore=True):
    """Return True iff a blob named exactly ``path_info.path`` exists.

    Listing is done by prefix, so an exact-name comparison filters out
    blobs that merely share the prefix.  ``use_dvcignore`` is accepted
    for interface compatibility but not consulted here.
    """
    target = path_info.path
    for name in self._list_paths(path_info.bucket, target):
        if name == target:
            return True
    return False

def _list_paths(self, bucket, prefix):
blob_service = self.blob_service
next_marker = None
while True:
blobs = blob_service.list_blobs(
bucket, prefix=prefix, marker=next_marker
)

for blob in blobs:
yield blob.name

if not blobs.next_marker:
break

next_marker = blobs.next_marker
container_client = self.blob_service.get_container_client(bucket)
for blob in container_client.list_blobs(name_starts_with=prefix):
yield blob.name

def walk_files(self, path_info, **kwargs):
if not kwargs.pop("prefix", False):
Expand All @@ -137,29 +148,37 @@ def remove(self, path_info):
raise NotImplementedError

logger.debug(f"Removing {path_info}")
self.blob_service.delete_blob(path_info.bucket, path_info.path)
self.blob_service.get_blob_client(
path_info.bucket, path_info.path
).delete_blob()

def get_file_hash(self, path_info):
    """Return the hash for ``path_info`` — the blob's ETag is used as the hash."""
    etag = self.get_etag(path_info)
    return etag

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
):
with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar:
self.blob_service.create_blob_from_path(
to_info.bucket,
to_info.path,
from_file,
progress_callback=pbar.update_to,
)

blob_client = self.blob_service.get_blob_client(
to_info.bucket, to_info.path
)
total = os.path.getsize(from_file)
with open(from_file, "rb") as fobj:
with Tqdm.wrapattr(
fobj, "read", desc=name, total=total, disable=no_progress_bar
) as wrapped:
blob_client.upload_blob(wrapped)

def _download(
self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs
):
with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar:
self.blob_service.get_blob_to_path(
from_info.bucket,
from_info.path,
to_file,
progress_callback=pbar.update_to,
)
blob_client = self.blob_service.get_blob_client(
from_info.bucket, from_info.path
)
total = blob_client.get_blob_properties().size
stream = blob_client.download_blob()
with open(to_file, "wb") as fobj:
with Tqdm.wrapattr(
fobj, "write", desc=name, total=total, disable=no_progress_bar
) as wrapped:
stream.readinto(wrapped)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def run(self):
gs = ["google-cloud-storage==1.19.0"]
gdrive = ["pydrive2>=1.4.14"]
s3 = ["boto3>=1.9.201"]
azure = ["azure-storage-blob==2.1.0", "knack"]
azure = ["azure-storage-blob>=12.0", "knack"]
oss = ["oss2==2.6.1"]
ssh = ["paramiko>=2.5.0"]
hdfs = ["pyarrow>=0.17.0"]
Expand Down
12 changes: 6 additions & 6 deletions tests/remotes/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ class Azure(Base, CloudURLInfo):
@pytest.fixture(scope="session")
def azure_server(docker_compose, docker_services):
from azure.storage.blob import ( # pylint: disable=no-name-in-module
BlockBlobService,
BlobServiceClient,
)
from azure.common import ( # pylint: disable=no-name-in-module
AzureException,
from azure.core.exceptions import ( # pylint: disable=no-name-in-module
AzureError,
)

port = docker_services.port_for("azurite", 10000)
connection_string = TEST_AZURE_CONNECTION_STRING.format(port=port)

def _check():
try:
BlockBlobService(
connection_string=connection_string,
BlobServiceClient.from_connection_string(
connection_string
).list_containers()
return True
except AzureException:
except AzureError:
return False

docker_services.wait_until_responsive(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/remote/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_init_env_var(monkeypatch, dvc):
config = {"url": "azure://"}
tree = AzureTree(dvc, config)
assert tree.path_info == "azure://" + container_name
assert tree._conn_kwargs["connection_string"] == connection_string
assert tree._conn_str == connection_string


def test_init(dvc):
Expand All @@ -26,7 +26,7 @@ def test_init(dvc):
config = {"url": url, "connection_string": connection_string}
tree = AzureTree(dvc, config)
assert tree.path_info == url
assert tree._conn_kwargs["connection_string"] == connection_string
assert tree._conn_str == connection_string


def test_get_file_hash(tmp_dir, azure):
Expand Down