Retry intermittent S3 download failures
This builds on the recent work in botocore
(boto/botocore#210) to address two things:

* Ensure that any .read() call on the streaming response of
  `GetObject` will never hang. We do this by applying a timeout
  to the underlying socket.
* Catch and retry IncompleteReadError. There may be times when we
  don't receive the full contents of a `GetObject` request, such
  that the amount read does not match the Content-Length header.
  Botocore now checks for this and raises an exception, which the
  CLI catches and retries.
jamesls committed Jan 15, 2014
1 parent fd32ff2 commit ed5999f
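
For context, here is a minimal sketch of the retry pattern this commit introduces: a bounded retry loop around a streaming read, with a socket timeout applied to the response body. `download_with_retries`, `get_object`, and `write_chunk` are hypothetical placeholders and not part of the CLI code; only `set_socket_timeout()` and `IncompleteReadError` come from botocore. Both failure modes are retried from scratch rather than resumed, mirroring the per-part retry added to DownloadPartTask below.

    import socket

    from botocore.exceptions import IncompleteReadError

    READ_TIMEOUT = 60      # seconds to bound each read on the streaming socket
    TOTAL_ATTEMPTS = 5     # retry budget used in this sketch


    def download_with_retries(get_object, write_chunk, chunk_size=1024 * 1024):
        """Retry a streaming GetObject-style download on transient failures.

        ``get_object`` and ``write_chunk`` are placeholders: the first is
        assumed to return a botocore-style streaming body exposing
        ``set_socket_timeout()`` and ``read()``, the second persists a chunk.
        """
        for _attempt in range(TOTAL_ATTEMPTS):
            try:
                body = get_object()
                # Bound every read() so a stalled connection cannot hang forever.
                body.set_socket_timeout(READ_TIMEOUT)
                chunk = body.read(chunk_size)
                while chunk:
                    write_chunk(chunk)
                    chunk = body.read(chunk_size)
                return
            except (socket.timeout, socket.error):
                # Transient network failure: retry the whole request.
                continue
            except IncompleteReadError:
                # Fewer bytes arrived than Content-Length promised: retry.
                continue
        raise RuntimeError("Maximum number of attempts exceeded: %s" % TOTAL_ATTEMPTS)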
Showing 3 changed files with 109 additions and 24 deletions.
74 changes: 50 additions & 24 deletions awscli/customizations/s3/tasks.py
@@ -2,9 +2,11 @@
import math
import os
import time
import socket
import threading

from botocore.vendored import requests
from botocore.exceptions import IncompleteReadError

from awscli.customizations.s3.utils import find_bucket_key, MD5Error, \
operate, ReadFileChunk, relative_path
@@ -21,6 +23,10 @@ class DownloadCancelledError(Exception):
pass


class RetriesExeededError(Exception):
pass


def print_operation(filename, failed, dryrun=False):
"""
Helper function used to print out what an operation did and whether
@@ -292,17 +298,30 @@ class DownloadPartTask(object):

# Amount to read from response body at a time.
ITERATE_CHUNK_SIZE = 1024 * 1024
READ_TIMEOUT = 60
TOTAL_ATTEMPTS = 5

def __init__(self, part_number, chunk_size, result_queue, service,
filename, context):
filename, context, open=open):
self._part_number = part_number
self._chunk_size = chunk_size
self._result_queue = result_queue
self._filename = filename
self._service = filename.service
self._context = context
self._open = open

def __call__(self):
try:
self._download_part()
except Exception as e:
LOGGER.debug(
'Exception caught downloading byte range: %s',
e, exc_info=True)
self._context.cancel()
raise e

def _download_part(self):
total_file_size = self._filename.size
start_range = self._part_number * self._chunk_size
if self._part_number == int(total_file_size / self._chunk_size) - 1:
@@ -315,34 +334,42 @@ def __call__(self):
bucket, key = find_bucket_key(self._filename.src)
params = {'endpoint': self._filename.endpoint, 'bucket': bucket,
'key': key, 'range': range_param}
try:
LOGGER.debug("Making GetObject requests with byte range: %s",
range_param)
response_data, http = operate(self._service, 'GetObject',
params)
LOGGER.debug("Response received from GetObject")
body = response_data['Body']
self._write_to_file(body)
self._context.announce_completed_part(self._part_number)

message = print_operation(self._filename, 0)
total_parts = int(self._filename.size / self._chunk_size)
result = {'message': message, 'error': False,
'total_parts': total_parts}
self._result_queue.put(result)
except Exception as e:
LOGGER.debug(
'Exception caught downloading byte range: %s',
e, exc_info=True)
self._context.cancel()
raise e
for i in range(self.TOTAL_ATTEMPTS):
try:
LOGGER.debug("Making GetObject requests with byte range: %s",
range_param)
response_data, http = operate(self._service, 'GetObject',
params)
LOGGER.debug("Response received from GetObject")
body = response_data['Body']
self._write_to_file(body)
self._context.announce_completed_part(self._part_number)

message = print_operation(self._filename, 0)
total_parts = int(self._filename.size / self._chunk_size)
result = {'message': message, 'error': False,
'total_parts': total_parts}
self._result_queue.put(result)
return
except (socket.timeout, socket.error) as e:
LOGGER.debug("Socket timeout caught, retrying request, "
"(attempt %s / %s)", i, self.TOTAL_ATTEMPTS,
exc_info=True)
continue
except IncompleteReadError as e:
LOGGER.debug("Incomplete read detected: %s, (attempt %s / %s)",
e, i, self.TOTAL_ATTEMPTS)
continue
raise RetriesExeededError("Maximum number of attempts exceeded: %s" %
self.TOTAL_ATTEMPTS)

def _write_to_file(self, body):
self._context.wait_for_file_created()
LOGGER.debug("Writing part number %s to file: %s",
self._part_number, self._filename.dest)
iterate_chunk_size = self.ITERATE_CHUNK_SIZE
with open(self._filename.dest, 'rb+') as f:
body.set_socket_timeout(self.READ_TIMEOUT)
with self._open(self._filename.dest, 'rb+') as f:
f.seek(self._part_number * self._chunk_size)
current = body.read(iterate_chunk_size)
while current:
@@ -352,7 +379,6 @@ def _write_to_file(self, body):
self._part_number, self._filename.dest)



class CreateMultipartUploadTask(BasicTask):
def __init__(self, session, filename, parameters, result_queue,
upload_context):
1 change: 1 addition & 0 deletions tests/unit/customizations/s3/fake_session.py
@@ -262,6 +262,7 @@ def get_object(self, kwargs):
else:
body = body[int(beginning):(int(end) + 1)]
mock_response = BytesIO(body)
mock_response.set_socket_timeout = Mock()
response_data['Body'] = mock_response
etag = self.session.s3[bucket][key]['ETag']
response_data['ETag'] = etag + '--'
58 changes: 58 additions & 0 deletions tests/unit/customizations/s3/test_tasks.py
@@ -14,10 +14,15 @@
import random
import threading
import mock
import socket

from botocore.exceptions import IncompleteReadError

from awscli.customizations.s3.tasks import DownloadPartTask
from awscli.customizations.s3.tasks import MultipartUploadContext
from awscli.customizations.s3.tasks import UploadCancelledError
from awscli.customizations.s3.tasks import print_operation
from awscli.customizations.s3.tasks import RetriesExeededError


class TestMultipartUploadContext(unittest.TestCase):
@@ -239,3 +244,56 @@ def test_print_operation(self):
filename.dest_type = 's3'
message = print_operation(filename, failed=False)
self.assertIn(r'e:\foo', message)


class TestDownloadPartTask(unittest.TestCase):
def setUp(self):
self.result_queue = mock.Mock()
self.service = mock.Mock()
self.filename = mock.Mock()
self.filename.size = 10 * 1024 * 1024
self.filename.src = 'bucket/key'
self.filename.dest = 'local/file'
self.filename.service = self.service
self.filename.operation_name = 'download'
self.context = mock.Mock()
self.open = mock.MagicMock()

def test_socket_timeout_is_retried(self):
self.service.get_operation.return_value.call.side_effect = socket.error
task = DownloadPartTask(1, 1024 * 1024, self.result_queue,
self.service, self.filename, self.context)
# The mock is configured to keep raising a socket.error
# so we should cancel the download.
with self.assertRaises(RetriesExeededError):
task()
self.context.cancel.assert_called_with()
# And we retried the request multiple times.
self.assertEqual(DownloadPartTask.TOTAL_ATTEMPTS,
self.service.get_operation.call_count)

def test_download_succeeds(self):
body = mock.Mock()
body.read.return_value = b''
self.service.get_operation.return_value.call.side_effect = [
socket.error, (mock.Mock(), {'Body': body})]
context = mock.Mock()
task = DownloadPartTask(1, 1024 * 1024, self.result_queue,
self.service, self.filename, self.context,
self.open)
task()
self.assertEqual(self.result_queue.put.call_count, 1)
# And we tried twice, the first one failed, the second one
# succeeded.
self.assertEqual(self.service.get_operation.call_count, 2)

def test_incomplete_read_is_retried(self):
self.service.get_operation.return_value.call.side_effect = \
IncompleteReadError(actual_bytes=1, expected_bytes=2)
task = DownloadPartTask(1, 1024 * 1024, self.result_queue,
self.service, self.filename, self.context)
with self.assertRaises(RetriesExeededError):
task()
self.context.cancel.assert_called_with()
self.assertEqual(DownloadPartTask.TOTAL_ATTEMPTS,
self.service.get_operation.call_count)
