diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 508e9f35bad1..dbd114975faa 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -3,6 +3,9 @@ ## 1.5.1 (Unreleased) +### Bug fixes + +- Fix AttributeException in StreamDownloadGenerator #11462 ## 1.5.0 (2020-05-04) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 0f064e618f23..9b0089401513 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -234,12 +234,12 @@ async def __anext__(self): else: await asyncio.sleep(retry_interval) headers = {'range': 'bytes=' + str(self.downloaded) + '-'} - resp = self.pipeline.run(self.request, stream=True, headers=headers) - if resp.status_code == 416: + resp = await self.pipeline.run(self.request, stream=True, headers=headers) + if resp.http_response.status_code == 416: raise chunk = await self.response.internal_response.content.read(self.block_size) if not chunk: - raise StopIteration() + raise StopAsyncIteration() self.downloaded += len(chunk) return chunk continue diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 33d924a4004d..d8f82b34f4da 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -175,13 +175,13 @@ class HttpTransport( @abc.abstractmethod def send(self, request, **kwargs): - # type: (PipelineRequest, Any) -> PipelineResponse + # type: (HTTPRequestType, Any) -> HTTPResponseType """Send the request using this HTTP sender. :param request: The pipeline request object - :type request: ~azure.core.pipeline.PipelineRequest + :type request: ~azure.core.transport.HTTPRequest :return: The pipeline response object. - :rtype: ~azure.core.pipeline.PipelineResponse + :rtype: ~azure.core.pipeline.transport.HttpResponse """ @abc.abstractmethod diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index d79999ed6297..bfc51ef6109b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -28,7 +28,7 @@ import abc from collections.abc import AsyncIterator -from typing import AsyncIterator as AsyncIteratorType, TypeVar, Generic +from typing import AsyncIterator as AsyncIteratorType, TypeVar, Generic, Any from ._base import ( _HttpResponseBase, _HttpClientTransportResponse, diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index 537f902ece3b..ea0e7b0e4c9c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -137,7 +137,7 @@ def __next__(self): time.sleep(retry_interval) headers = {'range': 'bytes=' + str(self.downloaded) + '-'} resp = self.pipeline.run(self.request, stream=True, headers=headers) - if resp.status_code == 416: + if resp.http_response.status_code == 416: raise chunk = next(self.iter_content_func) if not chunk: diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_stream_generator.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_stream_generator.py new file mode 100644 index 000000000000..f7368159615f --- /dev/null +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_stream_generator.py @@ -0,0 +1,107 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.core.pipeline.transport import ( + HttpRequest, + AsyncHttpResponse, + AsyncHttpTransport, +) +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator +from unittest import mock +import pytest + +@pytest.mark.asyncio +async def test_connection_error_response(): + class MockTransport(AsyncHttpTransport): + def __init__(self): + self._count = 0 + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + async def close(self): + pass + async def open(self): + pass + + async def send(self, request, **kwargs): + request = HttpRequest('GET', 'http://127.0.0.1/') + response = AsyncHttpResponse(request, None) + response.status_code = 200 + return response + + class MockContent(): + def __init__(self): + self._first = True + + async def read(self, block_size): + if self._first: + self._first = False + raise ConnectionError + return None + + class MockInternalResponse(): + def __init__(self): + self.headers = {} + self.content = MockContent() + + async def close(self): + pass + + class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + http_request = HttpRequest('GET', 'http://127.0.0.1/') + pipeline = AsyncPipeline(MockTransport()) + http_response = AsyncHttpResponse(http_request, None) + http_response.internal_response = MockInternalResponse() + stream = AioHttpStreamDownloadGenerator(pipeline, http_response) + with mock.patch('asyncio.sleep', new_callable=AsyncMock): + with pytest.raises(StopAsyncIteration): + await stream.__anext__() + +@pytest.mark.asyncio +async def test_connection_error_416(): + class MockTransport(AsyncHttpTransport): + def __init__(self): + self._count = 0 + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + async def close(self): + pass + async def open(self): + pass + + async def send(self, request, **kwargs): + request = HttpRequest('GET', 'http://127.0.0.1/') + response = AsyncHttpResponse(request, None) + response.status_code = 416 + return response + + class MockContent(): + async def read(self, block_size): + raise ConnectionError + + class MockInternalResponse(): + def __init__(self): + self.headers = {} + self.content = MockContent() + + async def close(self): + pass + + class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + http_request = HttpRequest('GET', 'http://127.0.0.1/') + pipeline = AsyncPipeline(MockTransport()) + http_response = AsyncHttpResponse(http_request, None) + http_response.internal_response = MockInternalResponse() + stream = AioHttpStreamDownloadGenerator(pipeline, http_response) + with mock.patch('asyncio.sleep', new_callable=AsyncMock): + with pytest.raises(ConnectionError): + await stream.__anext__() diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py new file mode 100644 index 000000000000..4f40bba885eb --- /dev/null +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -0,0 +1,101 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import requests +from azure.core.pipeline.transport import ( + HttpRequest, + HttpResponse, + HttpTransport, +) +from azure.core.pipeline import Pipeline, PipelineResponse +from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator +try: + from unittest import mock +except ImportError: + import mock +import pytest + +def test_connection_error_response(): + class MockTransport(HttpTransport): + def __init__(self): + self._count = 0 + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + def close(self): + pass + def open(self): + pass + + def send(self, request, **kwargs): + request = HttpRequest('GET', 'http://127.0.0.1/') + response = HttpResponse(request, None) + response.status_code = 200 + return response + + def next(self): + self.__next__() + + def __next__(self): + if self._count == 0: + self._count += 1 + raise requests.exceptions.ConnectionError + + class MockInternalResponse(): + def iter_content(self, block_size): + return MockTransport() + + def close(self): + pass + + http_request = HttpRequest('GET', 'http://127.0.0.1/') + pipeline = Pipeline(MockTransport()) + http_response = HttpResponse(http_request, None) + http_response.internal_response = MockInternalResponse() + stream = StreamDownloadGenerator(pipeline, http_response) + with mock.patch('time.sleep', return_value=None): + with pytest.raises(StopIteration): + stream.__next__() + +def test_connection_error_416(): + class MockTransport(HttpTransport): + def __init__(self): + self._count = 0 + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + def close(self): + pass + def open(self): + pass + + def send(self, request, **kwargs): + request = HttpRequest('GET', 'http://127.0.0.1/') + response = HttpResponse(request, None) + response.status_code = 416 + return response + + def next(self): + self.__next__() + + def __next__(self): + if self._count == 0: + self._count += 1 + raise requests.exceptions.ConnectionError + + class MockInternalResponse(): + def iter_content(self, block_size): + return MockTransport() + + def close(self): + pass + + http_request = HttpRequest('GET', 'http://127.0.0.1/') + pipeline = Pipeline(MockTransport()) + http_response = HttpResponse(http_request, None) + http_response.internal_response = MockInternalResponse() + stream = StreamDownloadGenerator(pipeline, http_response) + with mock.patch('time.sleep', return_value=None): + with pytest.raises(requests.exceptions.ConnectionError): + stream.__next__() \ No newline at end of file