diff --git a/stream_alert/athena_partition_refresh/main.py b/stream_alert/athena_partition_refresh/main.py index b02e58ba1..304eb460d 100644 --- a/stream_alert/athena_partition_refresh/main.py +++ b/stream_alert/athena_partition_refresh/main.py @@ -98,12 +98,10 @@ def _load_config(): class ConfigError(Exception): """Custom StreamAlertAthena Config Exception Class""" - pass class AthenaPartitionRefreshError(Exception): """Generic Athena Partition Error for erroring the Lambda function""" - pass class StreamAlertAthenaClient(object): @@ -116,7 +114,7 @@ class StreamAlertAthenaClient(object): athenea_results_key: The key in S3 to store Athena query results """ DATABASE_DEFAULT = 'default' - DEFAULT_DATABASE_STREAMALERT = 'streamalert' + DEFAULT_DATABASE_STREAMALERT = '{}_streamalert' DEFAULT_S3_PREFIX = 'athena_partition_refresh' STREAMALERTS_REGEX = re.compile(r'alerts/dt=(?P\d{4})' @@ -139,9 +137,9 @@ def __init__(self, config, **kwargs): results_key_prefix (str): The S3 key prefix to store Athena results """ self.config = config + self.prefix = self.config['global']['account']['prefix'] region = self.config['global']['account']['region'] - prefix = self.config['global']['account']['prefix'] self.athena_client = boto3.client('athena', region_name=region) athena_config = self.config['lambda']['athena_partition_refresh_config'] @@ -149,7 +147,7 @@ def __init__(self, config, **kwargs): # GEt the S3 bucket to store Athena query results results_bucket = athena_config.get('results_bucket', '').strip() if results_bucket == '': - self.athena_results_bucket = 's3://{}.streamalert.athena-results'.format(prefix) + self.athena_results_bucket = 's3://{}.streamalert.athena-results'.format(self.prefix) elif results_bucket[:5] != 's3://': self.athena_results_bucket = 's3://{}'.format(results_bucket) else: @@ -168,7 +166,7 @@ def sa_database(self): database = self.config['lambda']['athena_partition_refresh_config'].get('database_name', '') database = database.replace(' ', '') # strip any spaces which are invalid database names if database == '': - return self.DEFAULT_DATABASE_STREAMALERT + return self.DEFAULT_DATABASE_STREAMALERT.format(self.prefix) return database @@ -402,7 +400,7 @@ def add_hive_partition(self, s3_buckets_and_keys): if not query_success: raise AthenaPartitionRefreshError( - 'The add hive partition query has failed:\n%s', query + 'The add hive partition query has failed:\n{}'.format(query) ) LOGGER.info('Successfully added the following partitions:\n%s', @@ -639,7 +637,9 @@ def handler(*_): # Check that the 'streamalert' database exists before running queries if not stream_alert_athena.check_database_exists(): - raise AthenaPartitionRefreshError('The \'streamalert\' database does not exist') + raise AthenaPartitionRefreshError( + 'The \'{}\' database does not exist'.format(stream_alert_athena.sa_database) + ) if not stream_alert_athena.add_hive_partition(s3_buckets_and_keys): LOGGER.error('Failed to add hive partition(s)') diff --git a/stream_alert_cli/terraform/athena.py b/stream_alert_cli/terraform/athena.py index 9d6a5e5bd..f9dfb9431 100644 --- a/stream_alert_cli/terraform/athena.py +++ b/stream_alert_cli/terraform/athena.py @@ -36,7 +36,7 @@ def generate_athena(config): prefix = config['global']['account']['prefix'] database = athena_config.get('database_name', '').strip() if database == '': - database = 'streamalert' + database = '{}_streamalert'.format(prefix) results_bucket_name = athena_config.get('results_bucket', '').strip() if results_bucket_name == '': diff --git a/tests/unit/stream_alert_athena_partition_refresh/test_main.py b/tests/unit/stream_alert_athena_partition_refresh/test_main.py index c0de8b80b..50cb30f07 100644 --- a/tests/unit/stream_alert_athena_partition_refresh/test_main.py +++ b/tests/unit/stream_alert_athena_partition_refresh/test_main.py @@ -198,7 +198,11 @@ def setup(self): self.mock_sqs.start() sqs = boto3.resource('sqs', region_name=TEST_REGION) - self.queue = sqs.create_queue(QueueName=StreamAlertSQSClient.QUEUENAME) + + prefix = CONFIG_DATA['global']['account']['prefix'] + name = StreamAlertSQSClient.DEFAULT_QUEUE_NAME.format(prefix) + + self.queue = sqs.create_queue(QueueName=name) self.client = StreamAlertSQSClient(CONFIG_DATA) # Create a fake s3 notification message to send diff --git a/tests/unit/stream_alert_cli/terraform/test_athena.py b/tests/unit/stream_alert_cli/terraform/test_athena.py index d75e61b76..b62da4199 100644 --- a/tests/unit/stream_alert_cli/terraform/test_athena.py +++ b/tests/unit/stream_alert_cli/terraform/test_athena.py @@ -42,10 +42,18 @@ def test_generate_athena(): 'source_object_key': 'lambda/athena/source.zip', 'third_party_libraries': [] } + + prefix = CONFIG['global']['account']['prefix'] + expected_athena_config = { 'module': { 'stream_alert_athena': { + 's3_logging_bucket': '{}.streamalert.s3-logging'.format(prefix), 'source': 'modules/tf_stream_alert_athena', + 'database_name': '{}_streamalert'.format(prefix), + 'queue_name': + '{}_streamalert_athena_data_bucket_notifications'.format(prefix), + 'results_bucket': '{}.streamalert.athena-results'.format(prefix), 'current_version': '$LATEST', 'enable_metrics': False, 'lambda_handler': 'main.handler',