[Fix] Add retry for reading streaming data
xiafu-msft committed Apr 15, 2021
1 parent 576fdf5 commit 57443b9
Showing 3 changed files with 134 additions and 94 deletions.
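
The change wraps the step that reads the streamed response body in a small retry loop, in both the synchronous downloader (_download.py) and its async counterpart: each read gets up to 3 attempts, a dropped stream (requests.exceptions.ChunkedEncodingError / ConnectionError in the sync client, aiohttp.ClientPayloadError in the async client) re-issues the ranged download after a one-second pause, and a ServiceResponseError is raised once the retries are exhausted. To put the first read under the same protection, the process_content call that used to run in __init__ / _setup now runs inside _initial_request. Below is a simplified sketch of the sync pattern, not the committed code itself; download_with_retry, download_request and read_stream are placeholder names for the real client.download(...) call and the process_content(...) helper shown in the diff.

# Simplified sketch of the retry pattern introduced here (sync path).
# `download_request` and `read_stream` stand in for client.download(...)
# and process_content(...); only the control flow mirrors the diff.
import time

import requests
from azure.core.exceptions import ServiceResponseError


def download_with_retry(download_request, read_stream, retry_total=3):
    """Re-issue the ranged download whenever reading the streamed body fails."""
    retry_active = True
    while retry_active:
        response = download_request()  # issue (or re-issue) the GET
        try:
            content = read_stream(response)  # drain the streamed body
            retry_active = False
        except (requests.exceptions.ChunkedEncodingError,
                requests.exceptions.ConnectionError) as error:
            retry_total -= 1
            if retry_total <= 0:
                raise ServiceResponseError(error, error=error)
            time.sleep(1)  # brief pause before the next attempt
    return content
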
112 changes: 66 additions & 46 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
@@ -6,10 +6,13 @@

import sys
import threading
import time

import requests
import warnings
from io import BytesIO

from azure.core.exceptions import HttpResponseError
from azure.core.exceptions import HttpResponseError, ServiceResponseError
from azure.core.tracing.common import with_current_context
from ._shared.encryption import decrypt_blob
from ._shared.request_handlers import validate_and_format_range_headers
@@ -43,10 +46,9 @@ def process_range_and_offset(start_range, end_range, length, encryption):
def process_content(data, start_offset, end_offset, encryption):
if data is None:
raise ValueError("Response cannot be None.")
try:
content = b"".join(list(data))
except Exception as error:
raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)

content = b"".join(list(data))

if content and encryption.get("key") is not None or encryption.get("resolver") is not None:
try:
return decrypt_blob(
@@ -189,19 +191,28 @@ def _download_chunk(self, chunk_start, chunk_end):
)

try:
_, response = self.client.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self.validate_content,
data_stream_total=self.total_size,
download_stream_current=self.progress_total,
**self.request_options
)
retry_active = True
retry_total = 3
while retry_active:
_, response = self.client.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self.validate_content,
data_stream_total=self.total_size,
download_stream_current=self.progress_total,
**self.request_options
)
try:
chunk_data = process_content(response, offset[0], offset[1], self.encryption_options)
retry_active = False
except (requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError) as error:
retry_total -= 1
if retry_total <= 0:
raise ServiceResponseError(error, error=error)
time.sleep(1)
except HttpResponseError as error:
process_storage_error(error)

chunk_data = process_content(response, offset[0], offset[1], self.encryption_options)

# This makes sure that if_match is set so that we can validate
# that subsequent downloads are to an unmodified blob
if self.request_options.get("modified_access_conditions"):
@@ -334,16 +345,6 @@ def __init__(
# TODO: Set to the stored MD5 when the service returns this
self.properties.content_md5 = None

if self.size == 0:
self._current_content = b""
else:
self._current_content = process_content(
self._response,
self._initial_offset[0],
self._initial_offset[1],
self._encryption_options
)

def __len__(self):
return self.size

@@ -357,29 +358,47 @@ def _initial_request(self):
)

try:
location_mode, response = self._clients.blob.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self._validate_content,
data_stream_total=None,
download_stream_current=0,
**self._request_options
)

# Check the location we read from to ensure we use the same one
# for subsequent requests.
self._location_mode = location_mode
retry_active = True
retry_total = 3
while retry_active:
location_mode, response = self._clients.blob.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self._validate_content,
data_stream_total=None,
download_stream_current=0,
**self._request_options
)

# Parse the total file size and adjust the download size if ranges
# were specified
self._file_size = parse_length_from_content_range(response.properties.content_range)
if self._end_range is not None:
# Use the end range index unless it is over the end of the file
self.size = min(self._file_size, self._end_range - self._start_range + 1)
elif self._start_range is not None:
self.size = self._file_size - self._start_range
else:
self.size = self._file_size
# Check the location we read from to ensure we use the same one
# for subsequent requests.
self._location_mode = location_mode

# Parse the total file size and adjust the download size if ranges
# were specified
self._file_size = parse_length_from_content_range(response.properties.content_range)
if self._end_range is not None:
# Use the end range index unless it is over the end of the file
self.size = min(self._file_size, self._end_range - self._start_range + 1)
elif self._start_range is not None:
self.size = self._file_size - self._start_range
else:
self.size = self._file_size

try:
self._current_content = process_content(
response,
self._initial_offset[0],
self._initial_offset[1],
self._encryption_options
)
retry_active = False
except (requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError) as error:
retry_total -= 1
if retry_total <= 0:
raise ServiceResponseError(error, error=error)
time.sleep(1)

except HttpResponseError as error:
if self._start_range is None and error.response.status_code == 416:
Expand All @@ -399,6 +418,7 @@ def _initial_request(self):
# Set the download size to empty
self.size = 0
self._file_size = 0
self._current_content = b""
else:
process_storage_error(error)

@@ -54,7 +54,6 @@ async def send(self, request):
request.context.options.pop('raw_response_hook', self._response_callback)

response = await self.next.send(request)
await response.http_response.load_body()

will_retry = is_retry(response, request.context.options.get('mode'))
if not will_retry and download_stream_current is not None:
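
Dropping the eager load_body() call leaves the async response body un-drained when it reaches process_content, whose new async-for loop (next file) reads it chunk by chunk, so a connection dropped mid-stream surfaces as aiohttp.ClientPayloadError inside the downloader's new retry loop. A minimal sketch of that streaming read, assuming data is async-iterable over body chunks as in the updated process_content:

# Minimal sketch of the streamed read that replaces response.body().
# Assumes `data` yields body chunks when iterated asynchronously; an
# aiohttp.ClientPayloadError raised mid-iteration propagates out to the
# retry loops in _download_chunk / _initial_request.
async def read_streamed_body(data):
    content = b""
    async for chunk in data:
        content += chunk
    return content
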
@@ -7,11 +7,14 @@

import asyncio
import sys
import time
from io import BytesIO
from itertools import islice
import warnings

from azure.core.exceptions import HttpResponseError
from aiohttp import ClientPayloadError

from azure.core.exceptions import HttpResponseError, ServiceResponseError
from .._shared.encryption import decrypt_blob
from .._shared.request_handlers import validate_and_format_range_headers
from .._shared.response_handlers import process_storage_error, parse_length_from_content_range
@@ -21,10 +24,11 @@
async def process_content(data, start_offset, end_offset, encryption):
if data is None:
raise ValueError("Response cannot be None.")
try:
content = data.response.body()
except Exception as error:
raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)

content = b''
async for next_chunk in data:
content += next_chunk

if encryption.get('key') is not None or encryption.get('resolver') is not None:
try:
return decrypt_blob(
@@ -91,19 +95,29 @@ async def _download_chunk(self, chunk_start, chunk_end):
check_content_md5=self.validate_content
)
try:
_, response = await self.client.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self.validate_content,
data_stream_total=self.total_size,
download_stream_current=self.progress_total,
**self.request_options
)
retry_active = True
retry_total = 3
while retry_active:
_, response = await self.client.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self.validate_content,
data_stream_total=self.total_size,
download_stream_current=self.progress_total,
**self.request_options
)

try:
chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)
retry_active = False
except ClientPayloadError as error:
retry_total -= 1
if retry_total <= 0:
raise ServiceResponseError(error, error=error)
time.sleep(1)
except HttpResponseError as error:
process_storage_error(error)

chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)

# This makes sure that if_match is set so that we can validate
# that subsequent downloads are to an unmodified blob
if self.request_options.get('modified_access_conditions'):
@@ -243,16 +257,6 @@ async def _setup(self):
# TODO: Set to the stored MD5 when the service returns this
self.properties.content_md5 = None

if self.size == 0:
self._current_content = b""
else:
self._current_content = await process_content(
self._response,
self._initial_offset[0],
self._initial_offset[1],
self._encryption_options
)

async def _initial_request(self):
range_header, range_validation = validate_and_format_range_headers(
self._initial_range[0],
@@ -262,29 +266,45 @@ async def _initial_request(self):
check_content_md5=self._validate_content)

try:
location_mode, response = await self._clients.blob.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self._validate_content,
data_stream_total=None,
download_stream_current=0,
**self._request_options)

# Check the location we read from to ensure we use the same one
# for subsequent requests.
self._location_mode = location_mode

# Parse the total file size and adjust the download size if ranges
# were specified
self._file_size = parse_length_from_content_range(response.properties.content_range)
if self._end_range is not None:
# Use the length unless it is over the end of the file
self.size = min(self._file_size, self._end_range - self._start_range + 1)
elif self._start_range is not None:
self.size = self._file_size - self._start_range
else:
self.size = self._file_size
retry_active = True
retry_total = 3
while retry_active:
location_mode, response = await self._clients.blob.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self._validate_content,
data_stream_total=None,
download_stream_current=0,
**self._request_options)

# Check the location we read from to ensure we use the same one
# for subsequent requests.
self._location_mode = location_mode

# Parse the total file size and adjust the download size if ranges
# were specified
self._file_size = parse_length_from_content_range(response.properties.content_range)
if self._end_range is not None:
# Use the length unless it is over the end of the file
self.size = min(self._file_size, self._end_range - self._start_range + 1)
elif self._start_range is not None:
self.size = self._file_size - self._start_range
else:
self.size = self._file_size

try:
self._current_content = await process_content(
response,
self._initial_offset[0],
self._initial_offset[1],
self._encryption_options
)
retry_active = False
except ClientPayloadError as error:
retry_total -= 1
if retry_total <= 0:
raise ServiceResponseError(error, error=error)
time.sleep(1)
except HttpResponseError as error:
if self._start_range is None and error.response.status_code == 416:
# Get range will fail on an empty file. If the user did not
Expand All @@ -302,6 +322,7 @@ async def _initial_request(self):
# Set the download size to empty
self.size = 0
self._file_size = 0
self._current_content = b""
else:
process_storage_error(error)

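
The retry is internal to StorageStreamDownloader, so existing download calls pick it up without code changes. A hedged usage example; the connection string, container, and blob names below are placeholders:

# Hypothetical caller: nothing here changes with this commit, but readall()
# now retries (up to 3 attempts) if the underlying stream drops mid-read.
from azure.storage.blob import BlobClient

blob = BlobClient.from_connection_string(
    "<connection-string>", container_name="<container>", blob_name="<blob>"
)
data = blob.download_blob().readall()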