diff --git a/google/cloud/storage/blob.py b/google/cloud/storage/blob.py index 044ca492e..f7bf720c8 100644 --- a/google/cloud/storage/blob.py +++ b/google/cloud/storage/blob.py @@ -992,7 +992,7 @@ def download_to_file( timeout=_DEFAULT_TIMEOUT, checksum="md5", ): - """Download the contents of this blob into a file-like object. + """DEPRECATED. Download the contents of this blob into a file-like object. .. note:: @@ -1084,31 +1084,19 @@ def download_to_file( """ client = self._require_client(client) - download_url = self._get_download_url( - client, + client.download_blob_to_file( + self, + file_obj=file_obj, + start=start, + end=end, + raw_download=raw_download, if_generation_match=if_generation_match, if_generation_not_match=if_generation_not_match, if_metageneration_match=if_metageneration_match, if_metageneration_not_match=if_metageneration_not_match, + timeout=timeout, + checksum=checksum, ) - headers = _get_encryption_headers(self._encryption_key) - headers["accept-encoding"] = "gzip" - - transport = self._get_transport(client) - try: - self._do_download( - transport, - file_obj, - download_url, - headers, - start, - end, - raw_download, - timeout=timeout, - checksum=checksum, - ) - except resumable_media.InvalidResponse as exc: - _raise_from_invalid_response(exc) def download_to_filename( self, diff --git a/google/cloud/storage/bucket.py b/google/cloud/storage/bucket.py index 8fb28817f..3b51d9f82 100644 --- a/google/cloud/storage/bucket.py +++ b/google/cloud/storage/bucket.py @@ -1229,7 +1229,7 @@ def list_blobs( timeout=_DEFAULT_TIMEOUT, retry=DEFAULT_RETRY, ): - """Return an iterator used to find blobs in the bucket. + """DEPRECATED. Return an iterator used to find blobs in the bucket. .. note:: Direct use of this method is deprecated. Use ``Client.list_blobs`` instead. 
@@ -1329,52 +1329,24 @@ def list_blobs( >>> client = storage.Client() >>> bucket = storage.Bucket("my-bucket-name", user_project='my-project') - >>> all_blobs = list(bucket.list_blobs()) + >>> all_blobs = list(client.list_blobs(bucket)) """ - extra_params = {"projection": projection} - - if prefix is not None: - extra_params["prefix"] = prefix - - if delimiter is not None: - extra_params["delimiter"] = delimiter - - if start_offset is not None: - extra_params["startOffset"] = start_offset - - if end_offset is not None: - extra_params["endOffset"] = end_offset - - if include_trailing_delimiter is not None: - extra_params["includeTrailingDelimiter"] = include_trailing_delimiter - - if versions is not None: - extra_params["versions"] = versions - - if fields is not None: - extra_params["fields"] = fields - - if self.user_project is not None: - extra_params["userProject"] = self.user_project - client = self._require_client(client) - path = self.path + "/o" - api_request = functools.partial( - client._connection.api_request, timeout=timeout, retry=retry - ) - iterator = page_iterator.HTTPIterator( - client=client, - api_request=api_request, - path=path, - item_to_value=_item_to_blob, - page_token=page_token, + return client.list_blobs( + self, max_results=max_results, - extra_params=extra_params, - page_start=_blobs_page_start, + page_token=page_token, + prefix=prefix, + delimiter=delimiter, + start_offset=start_offset, + end_offset=end_offset, + include_trailing_delimiter=include_trailing_delimiter, + versions=versions, + projection=projection, + fields=fields, + timeout=timeout, + retry=retry, ) - iterator.bucket = self - iterator.prefixes = set() - return iterator def list_notifications( self, client=None, timeout=_DEFAULT_TIMEOUT, retry=DEFAULT_RETRY @@ -3111,7 +3083,7 @@ def make_public( for blob in blobs: blob.acl.all().grant_read() - blob.acl.save(client=client, timeout=timeout, retry=retry) + blob.acl.save(client=client, timeout=timeout) def make_private( self, 
diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index ab67cca2d..42358ef68 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -25,6 +25,8 @@ from google.auth.credentials import AnonymousCredentials +from google import resumable_media + from google.api_core import page_iterator from google.cloud._helpers import _LocalStack, _NOW from google.cloud.client import ClientWithProject @@ -39,8 +41,12 @@ _sign_message, ) from google.cloud.storage.batch import Batch -from google.cloud.storage.bucket import Bucket -from google.cloud.storage.blob import Blob +from google.cloud.storage.bucket import Bucket, _item_to_blob, _blobs_page_start +from google.cloud.storage.blob import ( + Blob, + _get_encryption_headers, + _raise_from_invalid_response, +) from google.cloud.storage.hmac_key import HMACKeyMetadata from google.cloud.storage.acl import BucketACL from google.cloud.storage.acl import DefaultObjectACL @@ -602,7 +608,20 @@ def create_bucket( bucket._set_properties(api_response) return bucket - def download_blob_to_file(self, blob_or_uri, file_obj, start=None, end=None): + def download_blob_to_file( + self, + blob_or_uri, + file_obj, + start=None, + end=None, + raw_download=False, + if_generation_match=None, + if_generation_not_match=None, + if_metageneration_match=None, + if_metageneration_not_match=None, + timeout=_DEFAULT_TIMEOUT, + checksum="md5", + ): """Download the contents of a blob object or blob URI into a file-like object. Args: @@ -617,6 +636,40 @@ def download_blob_to_file(self, blob_or_uri, file_obj, start=None, end=None): (Optional) The first byte in a range to be downloaded. end (int): (Optional) The last byte in a range to be downloaded. + raw_download (bool): + (Optional) If true, download the object without any expansion. + if_generation_match (long): + (Optional) Make the operation conditional on whether + the blob's current generation matches the given value. 
+ Setting to 0 makes the operation succeed only if there + are no live versions of the blob. + if_generation_not_match (long): + (Optional) Make the operation conditional on whether + the blob's current generation does not match the given + value. If no live blob exists, the precondition fails. + Setting to 0 makes the operation succeed only if there + is a live version of the blob. + if_metageneration_match (long): + (Optional) Make the operation conditional on whether the + blob's current metageneration matches the given value. + if_metageneration_not_match (long): + (Optional) Make the operation conditional on whether the + blob's current metageneration does not match the given value. + timeout ([Union[float, Tuple[float, float]]]): + (Optional) The number of seconds the transport should wait for the + server response. Depending on the retry strategy, a request may be + repeated several times using the same timeout each time. + Can also be passed as a tuple (connect_timeout, read_timeout). + See :meth:`requests.Session.request` documentation for details. + checksum (str): + (Optional) The type of checksum to compute to verify the integrity + of the object. The response headers must contain a checksum of the + requested type. If the headers lack an appropriate checksum (for + instance in the case of transcoded or ranged downloads where the + remote service does not know the correct checksum, including + downloads where chunk_size is set) an INFO-level log will be + emitted. Supported values are "md5", "crc32c" and None. The default + is "md5". Examples: Download a blob using a blob resource. 
@@ -642,11 +695,33 @@ def download_blob_to_file(self, blob_or_uri, file_obj, start=None, end=None): """ + if not isinstance(blob_or_uri, Blob): + blob_or_uri = Blob.from_string(blob_or_uri) + download_url = blob_or_uri._get_download_url( + self, + if_generation_match=if_generation_match, + if_generation_not_match=if_generation_not_match, + if_metageneration_match=if_metageneration_match, + if_metageneration_not_match=if_metageneration_not_match, + ) + headers = _get_encryption_headers(blob_or_uri._encryption_key) + headers["accept-encoding"] = "gzip" + + transport = self._http try: - blob_or_uri.download_to_file(file_obj, client=self, start=start, end=end) - except AttributeError: - blob = Blob.from_string(blob_or_uri, self) - blob.download_to_file(file_obj, client=self, start=start, end=end) + blob_or_uri._do_download( + transport, + file_obj, + download_url, + headers, + start, + end, + raw_download, + timeout=timeout, + checksum=checksum, + ) + except resumable_media.InvalidResponse as exc: + _raise_from_invalid_response(exc) def list_blobs( self, @@ -761,21 +836,50 @@ def list_blobs( >>> all_blobs = list(client.list_blobs(bucket)) """ bucket = self._bucket_arg_to_bucket(bucket_or_name) - return bucket.list_blobs( - max_results=max_results, - page_token=page_token, - prefix=prefix, - delimiter=delimiter, - start_offset=start_offset, - end_offset=end_offset, - include_trailing_delimiter=include_trailing_delimiter, - versions=versions, - projection=projection, - fields=fields, + + extra_params = {"projection": projection} + + if prefix is not None: + extra_params["prefix"] = prefix + + if delimiter is not None: + extra_params["delimiter"] = delimiter + + if start_offset is not None: + extra_params["startOffset"] = start_offset + + if end_offset is not None: + extra_params["endOffset"] = end_offset + + if include_trailing_delimiter is not None: + extra_params["includeTrailingDelimiter"] = include_trailing_delimiter + + if versions is not None: + 
extra_params["versions"] = versions + + if fields is not None: + extra_params["fields"] = fields + + if bucket.user_project is not None: + extra_params["userProject"] = bucket.user_project + + path = bucket.path + "/o" + api_request = functools.partial( + self._connection.api_request, timeout=timeout, retry=retry + ) + iterator = page_iterator.HTTPIterator( client=self, - timeout=timeout, - retry=retry, + api_request=api_request, + path=path, + item_to_value=_item_to_blob, + page_token=page_token, + max_results=max_results, + extra_params=extra_params, + page_start=_blobs_page_start, ) + iterator.bucket = bucket + iterator.prefixes = set() + return iterator def list_buckets( self, diff --git a/tests/unit/test_blob.py b/tests/unit/test_blob.py index 28f4e31d2..e7caa90a2 100644 --- a/tests/unit/test_blob.py +++ b/tests/unit/test_blob.py @@ -52,6 +52,12 @@ def _get_default_timeout(): return _DEFAULT_TIMEOUT + @staticmethod + def _make_client(*args, **kw): + from google.cloud.storage.client import Client + + return Client(*args, **kw) + def test_ctor_wo_encryption_key(self): BLOB_NAME = "blob-name" bucket = _Bucket() @@ -1176,7 +1182,7 @@ def test_download_to_file_with_failure(self): blob_name = "blob-name" media_link = "http://test.invalid" - client = mock.Mock(spec=[u"_http"]) + client = self._make_client() bucket = _Bucket(client) blob = self._make_one(blob_name, bucket=bucket) blob._properties["mediaLink"] = media_link @@ -1204,8 +1210,7 @@ def test_download_to_file_wo_media_link(self): blob_name = "blob-name" - client = mock.Mock(_connection=_Connection, spec=[u"_http"]) - client._connection.API_BASE_URL = "https://storage.googleapis.com" + client = self._make_client() bucket = _Bucket(client) blob = self._make_one(blob_name, bucket=bucket) blob._do_download = mock.Mock() @@ -1243,8 +1248,7 @@ def test_download_to_file_w_generation_match(self): ) ) - client = mock.Mock(_connection=_Connection, spec=[u"_http"])
- client._connection.API_BASE_URL = "https://storage.googleapis.com" + client = self._make_client() blob = self._make_one("blob-name", bucket=_Bucket(client)) blob._do_download = mock.Mock() file_obj = io.BytesIO() @@ -1265,7 +1269,7 @@ def test_download_to_file_w_generation_match(self): def _download_to_file_helper(self, use_chunks, raw_download, timeout=None): blob_name = "blob-name" - client = mock.Mock(spec=[u"_http"]) + client = self._make_client() bucket = _Bucket(client) media_link = "http://example.com/media/" properties = {"mediaLink": media_link} @@ -1324,7 +1328,7 @@ def _download_to_filename_helper(self, updated, raw_download, timeout=None): from google.cloud._testing import _NamedTemporaryFile blob_name = "blob-name" - client = mock.Mock(spec=["_http"]) + client = self._make_client() bucket = _Bucket(client) media_link = "http://example.com/media/" properties = {"mediaLink": media_link} @@ -1377,7 +1381,7 @@ def test_download_to_filename_w_generation_match(self): EXPECTED_LINK = MEDIA_LINK + "?ifGenerationMatch={}".format(GENERATION_NUMBER) HEADERS = {"accept-encoding": "gzip"} - client = mock.Mock(spec=["_http"]) + client = self._make_client() blob = self._make_one( "blob-name", bucket=_Bucket(client), properties={"mediaLink": MEDIA_LINK} @@ -1422,7 +1426,7 @@ def test_download_to_filename_corrupted(self): from google.resumable_media import DataCorruption blob_name = "blob-name" - client = mock.Mock(spec=["_http"]) + client = self._make_client() bucket = _Bucket(client) media_link = "http://example.com/media/" properties = {"mediaLink": media_link} @@ -1465,7 +1469,7 @@ def test_download_to_filename_w_key(self): blob_name = "blob-name" # Create a fake client/bucket and use them in the Blob() constructor. 
- client = mock.Mock(spec=["_http"]) + client = self._make_client() bucket = _Bucket(client) media_link = "http://example.com/media/" properties = {"mediaLink": media_link} @@ -1496,7 +1500,7 @@ def test_download_to_filename_w_key(self): def _download_as_bytes_helper(self, raw_download, timeout=None): blob_name = "blob-name" - client = mock.Mock(spec=["_http"]) + client = self._make_client() bucket = _Bucket(client) media_link = "http://example.com/media/" properties = {"mediaLink": media_link} diff --git a/tests/unit/test_bucket.py b/tests/unit/test_bucket.py index 255953dcb..f3f2b4cd0 100644 --- a/tests/unit/test_bucket.py +++ b/tests/unit/test_bucket.py @@ -449,6 +449,12 @@ def _get_default_timeout(): return _DEFAULT_TIMEOUT + @staticmethod + def _make_client(*args, **kw): + from google.cloud.storage.client import Client + + return Client(*args, **kw) + def _make_one(self, client=None, name=None, properties=None, user_project=None): if client is None: connection = _Connection() @@ -854,7 +860,8 @@ def test_get_blob_hit_with_kwargs(self): def test_list_blobs_defaults(self): NAME = "name" connection = _Connection({"items": []}) - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) iterator = bucket.list_blobs() blobs = list(iterator) @@ -892,7 +899,8 @@ def test_list_blobs_w_all_arguments_and_user_project(self): "userProject": USER_PROJECT, } connection = _Connection({"items": []}) - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(name=NAME, user_project=USER_PROJECT) iterator = bucket.list_blobs( max_results=MAX_RESULTS, @@ -1039,7 +1047,8 @@ def test_delete_hit_with_user_project(self): GET_BLOBS_RESP = {"items": []} connection = _Connection(GET_BLOBS_RESP) connection._delete_bucket = True - client = _Client(connection) + client = self._make_client() + client._base_connection = connection 
bucket = self._make_one(client=client, name=NAME, user_project=USER_PROJECT) result = bucket.delete(force=True, timeout=42) self.assertIsNone(result) @@ -1063,7 +1072,8 @@ def test_delete_force_delete_blobs(self): DELETE_BLOB1_RESP = DELETE_BLOB2_RESP = {} connection = _Connection(GET_BLOBS_RESP, DELETE_BLOB1_RESP, DELETE_BLOB2_RESP) connection._delete_bucket = True - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) result = bucket.delete(force=True) self.assertIsNone(result) @@ -1112,7 +1122,8 @@ def test_delete_force_miss_blobs(self): # Note the connection does not have a response for the blob. connection = _Connection(GET_BLOBS_RESP) connection._delete_bucket = True - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) result = bucket.delete(force=True) self.assertIsNone(result) @@ -1135,7 +1146,8 @@ def test_delete_too_many(self): GET_BLOBS_RESP = {"items": [{"name": BLOB_NAME1}, {"name": BLOB_NAME2}]} connection = _Connection(GET_BLOBS_RESP) connection._delete_bucket = True - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) # Make the Bucket refuse to delete with 2 objects. 
@@ -2288,13 +2300,11 @@ def test_versioning_enabled_getter(self): @mock.patch("warnings.warn") def test_create_deprecated(self, mock_warn): - from google.cloud.storage.client import Client - PROJECT = "PROJECT" BUCKET_NAME = "bucket-name" DATA = {"name": BUCKET_NAME} connection = _make_connection(DATA) - client = Client(project=PROJECT) + client = self._make_client(project=PROJECT) client._base_connection = connection bucket = self._make_one(client=client, name=BUCKET_NAME) @@ -2318,13 +2328,11 @@ def test_create_deprecated(self, mock_warn): ) def test_create_w_user_project(self): - from google.cloud.storage.client import Client - PROJECT = "PROJECT" BUCKET_NAME = "bucket-name" DATA = {"name": BUCKET_NAME} connection = _make_connection(DATA) - client = Client(project=PROJECT) + client = self._make_client(project=PROJECT) client._base_connection = connection bucket = self._make_one(client=client, name=BUCKET_NAME) @@ -2749,9 +2757,9 @@ def all(self): def grant_read(self): self._granted = True - def save(self, client=None, timeout=None, retry=DEFAULT_RETRY): + def save(self, client=None, timeout=None): _saved.append( - (self._bucket, self._name, self._granted, client, timeout, retry) + (self._bucket, self._name, self._granted, client, timeout) ) def item_to_blob(self, item): @@ -2762,16 +2770,18 @@ def item_to_blob(self, item): permissive = [{"entity": "allUsers", "role": _ACLEntity.READER_ROLE}] after = {"acl": permissive, "defaultObjectAcl": []} connection = _Connection(after, {"items": [{"name": BLOB_NAME}]}) - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) bucket.acl.loaded = True bucket.default_object_acl.loaded = True - with mock.patch("google.cloud.storage.bucket._item_to_blob", new=item_to_blob): + with mock.patch("google.cloud.storage.client._item_to_blob", new=item_to_blob): bucket.make_public(recursive=True, timeout=42, retry=DEFAULT_RETRY) + 
self.assertEqual(list(bucket.acl), permissive) self.assertEqual(list(bucket.default_object_acl), []) - self.assertEqual(_saved, [(bucket, BLOB_NAME, True, None, 42, DEFAULT_RETRY)]) + self.assertEqual(_saved, [(bucket, BLOB_NAME, True, None, 42)]) kw = connection._requested self.assertEqual(len(kw), 2) self.assertEqual(kw[0]["method"], "PATCH") @@ -2781,6 +2791,7 @@ def item_to_blob(self, item): self.assertEqual(kw[0]["timeout"], 42) self.assertEqual(kw[1]["method"], "GET") self.assertEqual(kw[1]["path"], "/b/%s/o" % NAME) + self.assertEqual(kw[1]["retry"], DEFAULT_RETRY) max_results = bucket._MAX_OBJECTS_FOR_ITERATION + 1 self.assertEqual( kw[1]["query_params"], {"maxResults": max_results, "projection": "full"} @@ -2798,7 +2809,8 @@ def test_make_public_recursive_too_many(self): BLOB_NAME2 = "blob-name2" GET_BLOBS_RESP = {"items": [{"name": BLOB_NAME1}, {"name": BLOB_NAME2}]} connection = _Connection(AFTER, GET_BLOBS_RESP) - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) bucket.acl.loaded = True bucket.default_object_acl.loaded = True @@ -2891,9 +2903,9 @@ def all(self): def revoke_read(self): self._granted = False - def save(self, client=None, timeout=None, retry=DEFAULT_RETRY): + def save(self, client=None, timeout=None): _saved.append( - (self._bucket, self._name, self._granted, client, timeout, retry) + (self._bucket, self._name, self._granted, client, timeout) ) def item_to_blob(self, item): @@ -2904,16 +2916,17 @@ def item_to_blob(self, item): no_permissions = [] after = {"acl": no_permissions, "defaultObjectAcl": []} connection = _Connection(after, {"items": [{"name": BLOB_NAME}]}) - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) bucket.acl.loaded = True bucket.default_object_acl.loaded = True - with mock.patch("google.cloud.storage.bucket._item_to_blob", 
new=item_to_blob): + with mock.patch("google.cloud.storage.client._item_to_blob", new=item_to_blob): bucket.make_private(recursive=True, timeout=42, retry=DEFAULT_RETRY) self.assertEqual(list(bucket.acl), no_permissions) self.assertEqual(list(bucket.default_object_acl), []) - self.assertEqual(_saved, [(bucket, BLOB_NAME, False, None, 42, DEFAULT_RETRY)]) + self.assertEqual(_saved, [(bucket, BLOB_NAME, False, None, 42)]) kw = connection._requested self.assertEqual(len(kw), 2) self.assertEqual(kw[0]["method"], "PATCH") @@ -2923,6 +2936,7 @@ def item_to_blob(self, item): self.assertEqual(kw[0]["timeout"], 42) self.assertEqual(kw[1]["method"], "GET") self.assertEqual(kw[1]["path"], "/b/%s/o" % NAME) + self.assertEqual(kw[1]["retry"], DEFAULT_RETRY) max_results = bucket._MAX_OBJECTS_FOR_ITERATION + 1 self.assertEqual( kw[1]["query_params"], {"maxResults": max_results, "projection": "full"} @@ -2938,7 +2952,8 @@ def test_make_private_recursive_too_many(self): BLOB_NAME2 = "blob-name2" GET_BLOBS_RESP = {"items": [{"name": BLOB_NAME1}, {"name": BLOB_NAME2}]} connection = _Connection(AFTER, GET_BLOBS_RESP) - client = _Client(connection) + client = self._make_client() + client._base_connection = connection bucket = self._make_one(client=client, name=NAME) bucket.acl.loaded = True bucket.default_object_acl.loaded = True @@ -2951,7 +2966,8 @@ def test_page_empty_response(self): from google.api_core import page_iterator connection = _Connection() - client = _Client(connection) + client = self._make_client() + client._base_connection = connection name = "name" bucket = self._make_one(client=client, name=name) iterator = bucket.list_blobs() @@ -2968,7 +2984,8 @@ def test_page_non_empty_response(self): blob_name = "blob-name" response = {"items": [{"name": blob_name}], "prefixes": ["foo"]} connection = _Connection() - client = _Client(connection) + client = self._make_client() + client._base_connection = connection name = "name" bucket = self._make_one(client=client, name=name) @@ 
-2998,8 +3015,7 @@ def test_cumulative_prefixes(self): "nextPageToken": "s39rmf9", } response2 = {"items": [], "prefixes": ["bar"]} - connection = _Connection() - client = _Client(connection) + client = self._make_client() name = "name" bucket = self._make_one(client=client, name=name) responses = [response1, response2] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 4efc35e98..ee0e387dd 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1007,32 +1007,72 @@ def test_create_bucket_w_object_success(self): json_sent = http.request.call_args_list[0][1]["data"] self.assertEqual(json_expected, json.loads(json_sent)) - def test_download_blob_to_file_with_blob(self): - project = "PROJECT" + def test_download_blob_to_file_with_failure(self): + from google.resumable_media import InvalidResponse + from google.cloud.storage.blob import Blob + from google.cloud.storage.constants import _DEFAULT_TIMEOUT + + raw_response = requests.Response() + raw_response.status_code = http_client.NOT_FOUND + raw_request = requests.Request("GET", "http://example.com") + raw_response.request = raw_request.prepare() + grmp_response = InvalidResponse(raw_response) + credentials = _make_credentials() - client = self._make_one(project=project, credentials=credentials) - blob = mock.Mock() + client = self._make_one(credentials=credentials) + blob = mock.create_autospec(Blob) + blob._encryption_key = None + blob._get_download_url = mock.Mock() + blob._do_download = mock.Mock() + blob._do_download.side_effect = grmp_response + file_obj = io.BytesIO() + with self.assertRaises(exceptions.NotFound): + client.download_blob_to_file(blob, file_obj) + + self.assertEqual(file_obj.tell(), 0) - client.download_blob_to_file(blob, file_obj) - blob.download_to_file.assert_called_once_with( - file_obj, client=client, start=None, end=None + headers = {"accept-encoding": "gzip"} + blob._do_download.assert_called_once_with( + client._http, + file_obj, + 
blob._get_download_url(), + headers, + None, + None, + False, + checksum="md5", + timeout=_DEFAULT_TIMEOUT, ) def test_download_blob_to_file_with_uri(self): + from google.cloud.storage.constants import _DEFAULT_TIMEOUT + project = "PROJECT" credentials = _make_credentials() client = self._make_one(project=project, credentials=credentials) blob = mock.Mock() file_obj = io.BytesIO() + blob._encryption_key = None + blob._get_download_url = mock.Mock() + blob._do_download = mock.Mock() with mock.patch( "google.cloud.storage.client.Blob.from_string", return_value=blob ): client.download_blob_to_file("gs://bucket_name/path/to/object", file_obj) - blob.download_to_file.assert_called_once_with( - file_obj, client=client, start=None, end=None + headers = {"accept-encoding": "gzip"} + blob._do_download.assert_called_once_with( + client._http, + file_obj, + blob._get_download_url(), + headers, + None, + None, + False, + checksum="md5", + timeout=_DEFAULT_TIMEOUT, ) def test_download_blob_to_file_with_invalid_uri(self): @@ -1044,6 +1084,51 @@ def test_download_blob_to_file_with_invalid_uri(self): with pytest.raises(ValueError, match="URI scheme must be gs"): client.download_blob_to_file("http://bucket_name/path/to/object", file_obj) + def _download_blob_to_file_helper(self, use_chunks, raw_download): + from google.cloud.storage.blob import Blob + from google.cloud.storage.constants import _DEFAULT_TIMEOUT + + credentials = _make_credentials() + client = self._make_one(credentials=credentials) + blob = mock.create_autospec(Blob) + blob._encryption_key = None + blob._get_download_url = mock.Mock() + if use_chunks: + blob._CHUNK_SIZE_MULTIPLE = 1 + blob.chunk_size = 3 + blob._do_download = mock.Mock() + + file_obj = io.BytesIO() + if raw_download: + client.download_blob_to_file(blob, file_obj, raw_download=True) + else: + client.download_blob_to_file(blob, file_obj) + + headers = {"accept-encoding": "gzip"} + blob._do_download.assert_called_once_with( + client._http, + file_obj, 
+ blob._get_download_url(), + headers, + None, + None, + raw_download, + checksum="md5", + timeout=_DEFAULT_TIMEOUT, + ) + + def test_download_blob_to_file_wo_chunks_wo_raw(self): + self._download_blob_to_file_helper(use_chunks=False, raw_download=False) + + def test_download_blob_to_file_w_chunks_wo_raw(self): + self._download_blob_to_file_helper(use_chunks=True, raw_download=False) + + def test_download_blob_to_file_wo_chunks_w_raw(self): + self._download_blob_to_file_helper(use_chunks=False, raw_download=True) + + def test_download_blob_to_file_w_chunks_w_raw(self): + self._download_blob_to_file_helper(use_chunks=True, raw_download=True) + def test_list_blobs(self): from google.cloud.storage.bucket import Bucket