Skip to content

Commit

Permalink
[tests] updating some unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ryandeivert committed Feb 14, 2018
1 parent dec72bd commit 0a6fbc4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
16 changes: 8 additions & 8 deletions stream_alert/athena_partition_refresh/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,10 @@ def _load_config():

class ConfigError(Exception):
    """Custom StreamAlertAthena Config Exception Class.

    Raised when the athena_partition_refresh Lambda configuration is
    missing or invalid. The docstring is a sufficient class body, so the
    redundant trailing ``pass`` statement has been removed.
    """


class AthenaPartitionRefreshError(Exception):
    """Generic Athena Partition Error for erroring the Lambda function.

    Raised to fail the Lambda invocation when an Athena partition
    operation (e.g. an ``ADD PARTITION`` query) does not succeed. The
    docstring is a sufficient class body, so the redundant trailing
    ``pass`` statement has been removed.
    """


class StreamAlertAthenaClient(object):
Expand All @@ -116,7 +114,7 @@ class StreamAlertAthenaClient(object):
athena_results_key: The key in S3 to store Athena query results
"""
DATABASE_DEFAULT = 'default'
DEFAULT_DATABASE_STREAMALERT = 'streamalert'
DEFAULT_DATABASE_STREAMALERT = '{}_streamalert'
DEFAULT_S3_PREFIX = 'athena_partition_refresh'

STREAMALERTS_REGEX = re.compile(r'alerts/dt=(?P<year>\d{4})'
Expand All @@ -139,17 +137,17 @@ def __init__(self, config, **kwargs):
results_key_prefix (str): The S3 key prefix to store Athena results
"""
self.config = config
self.prefix = self.config['global']['account']['prefix']

region = self.config['global']['account']['region']
prefix = self.config['global']['account']['prefix']
self.athena_client = boto3.client('athena', region_name=region)

athena_config = self.config['lambda']['athena_partition_refresh_config']

# Get the S3 bucket to store Athena query results
results_bucket = athena_config.get('results_bucket', '').strip()
if results_bucket == '':
self.athena_results_bucket = 's3://{}.streamalert.athena-results'.format(prefix)
self.athena_results_bucket = 's3://{}.streamalert.athena-results'.format(self.prefix)
elif results_bucket[:5] != 's3://':
self.athena_results_bucket = 's3://{}'.format(results_bucket)
else:
Expand All @@ -168,7 +166,7 @@ def sa_database(self):
database = self.config['lambda']['athena_partition_refresh_config'].get('database_name', '')
database = database.replace(' ', '') # strip any spaces which are invalid database names
if database == '':
return self.DEFAULT_DATABASE_STREAMALERT
return self.DEFAULT_DATABASE_STREAMALERT.format(self.prefix)

return database

Expand Down Expand Up @@ -402,7 +400,7 @@ def add_hive_partition(self, s3_buckets_and_keys):

if not query_success:
raise AthenaPartitionRefreshError(
'The add hive partition query has failed:\n%s', query
'The add hive partition query has failed:\n{}'.format(query)
)

LOGGER.info('Successfully added the following partitions:\n%s',
Expand Down Expand Up @@ -639,7 +637,9 @@ def handler(*_):

# Check that the 'streamalert' database exists before running queries
if not stream_alert_athena.check_database_exists():
raise AthenaPartitionRefreshError('The \'streamalert\' database does not exist')
raise AthenaPartitionRefreshError(
'The \'{}\' database does not exist'.format(stream_alert_athena.sa_database)
)

if not stream_alert_athena.add_hive_partition(s3_buckets_and_keys):
LOGGER.error('Failed to add hive partition(s)')
Expand Down
2 changes: 1 addition & 1 deletion stream_alert_cli/terraform/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generate_athena(config):
prefix = config['global']['account']['prefix']
database = athena_config.get('database_name', '').strip()
if database == '':
database = 'streamalert'
database = '{}_streamalert'.format(prefix)

results_bucket_name = athena_config.get('results_bucket', '').strip()
if results_bucket_name == '':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ def setup(self):
self.mock_sqs.start()

sqs = boto3.resource('sqs', region_name=TEST_REGION)
self.queue = sqs.create_queue(QueueName=StreamAlertSQSClient.QUEUENAME)

prefix = CONFIG_DATA['global']['account']['prefix']
name = StreamAlertSQSClient.DEFAULT_QUEUE_NAME.format(prefix)

self.queue = sqs.create_queue(QueueName=name)
self.client = StreamAlertSQSClient(CONFIG_DATA)

# Create a fake s3 notification message to send
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/stream_alert_cli/terraform/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@ def test_generate_athena():
'source_object_key': 'lambda/athena/source.zip',
'third_party_libraries': []
}

prefix = CONFIG['global']['account']['prefix']

expected_athena_config = {
'module': {
'stream_alert_athena': {
's3_logging_bucket': '{}.streamalert.s3-logging'.format(prefix),
'source': 'modules/tf_stream_alert_athena',
'database_name': '{}_streamalert'.format(prefix),
'queue_name':
'{}_streamalert_athena_data_bucket_notifications'.format(prefix),
'results_bucket': '{}.streamalert.athena-results'.format(prefix),
'current_version': '$LATEST',
'enable_metrics': False,
'lambda_handler': 'main.handler',
Expand Down

0 comments on commit 0a6fbc4

Please sign in to comment.