Skip to content

Commit

Permalink
[athena] KeyError bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Naglieri committed Jan 21, 2018
1 parent a1c171b commit e29c034
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
20 changes: 13 additions & 7 deletions stream_alert/athena_partition_refresh/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,12 @@ def get_messages(self, **kwargs):
# Backoff up to 5 times to limit the time spent in this operation
# relative to the entire Lambda duration.
max_tries = kwargs.get('max_tries', 5)

# This value restricts the max time of backoff each try.
# This means the total backoff time for one function call is:
# max_tries (attempts) * max_value (seconds)
max_value = kwargs.get('max_value', 5)

# Number of messages to poll from the stream.
max_messages = kwargs.get('max_messages', self.MAX_SQS_GET_MESSAGE_COUNT)
if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
Expand Down Expand Up @@ -487,6 +489,7 @@ def _delete_messages_from_queue():
# Determine the message batch for SQS message deletion
len_processed_messages = len(self.processed_messages)
batch = len_processed_messages if len_processed_messages < 10 else 10
# Pop processed records from the list to be deleted
message_batch = [self.processed_messages.pop() for _ in range(batch)]

# Try to delete the batch
Expand All @@ -497,16 +500,19 @@ def _delete_messages_from_queue():
for message in message_batch])

# Handle successful deletions
self.deleted_messages += len(resp['Successful'])

if resp.get('Successful'):
self.deleted_messages += len(resp['Successful'])
# Handle failed deletions
if resp.get('Failed'):
elif resp.get('Failed'):
LOGGER.error('Failed to delete the following (%d) messages:\n%s',
len(resp['Failed']), json.dumps(resp['Failed']))
# Add the failed messages back to the processed_messages attribute
failed_from_batch = [[message for message in message_batch if message['MessageId']
== failed_message['Id']] for failed_message in resp['Failed']]
self.processed_messages.extend(failed_from_batch)
# to be retried via backoff
self.processed_messages.extend([[message
for message
in message_batch
if message['MessageId'] == failed_message['Id']]
for failed_message in resp['Failed']])

return len(self.processed_messages)

Expand Down Expand Up @@ -563,7 +569,7 @@ def unique_s3_buckets_and_keys(self):
object_key = urllib.unquote(record['s3']['object']['key']).decode('utf8')
s3_buckets_and_keys[bucket_name].add(object_key)

# Add to a new list to denote processed messages
# Add to a new list to track successfully processed messages from the queue
self.processed_messages.append(message)

return s3_buckets_and_keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ def test_delete_messages_none_received(self, mock_logging):
def test_delete_messages_failure(self, mock_logging, mock_sqs_client):
"""Athena SQS - Delete Messages - Failure Response"""
instance = mock_sqs_client.return_value
instance.sqs_client.delete_message_batch.return_value = {
'Successful': [{'Id': '2'}], 'Failed': [{'Id': '1'}]}
instance.sqs_client.delete_message_batch.return_value = {'Failed': [{'Id': '1'}]}

self.client.get_messages()
self.client.unique_s3_buckets_and_keys()
Expand Down

0 comments on commit e29c034

Please sign in to comment.