From 6f28d4e5652bd111fa3d822dbc3bfe794468191a Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Fri, 14 Sep 2018 16:15:45 -0700 Subject: [PATCH] caching athena client --- stream_alert/athena_partition_refresh/main.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/stream_alert/athena_partition_refresh/main.py b/stream_alert/athena_partition_refresh/main.py index 7a5c1d7db..6c27738c8 100644 --- a/stream_alert/athena_partition_refresh/main.py +++ b/stream_alert/athena_partition_refresh/main.py @@ -45,6 +45,8 @@ class AthenaRefresher(object): STREAMALERT_DATABASE = '{}_streamalert' ATHENA_S3_PREFIX = 'athena_partition_refresh' + _ATHENA_CLIENT = None + def __init__(self): config = load_config(include={'lambda.json', 'global.json'}) prefix = config['global']['account']['prefix'] @@ -63,14 +65,21 @@ def __init__(self): 's3://{}.streamalert.athena-results'.format(prefix) ) - self._athena_client = AthenaClient( - db_name, - results_bucket, - self.ATHENA_S3_PREFIX - ) - self._s3_buckets_and_keys = defaultdict(set) + self._create_client(db_name, results_bucket) + + @classmethod + def _create_client(cls, db_name, results_bucket): + if cls._ATHENA_CLIENT: + return # Client already created/cached + + cls._ATHENA_CLIENT = AthenaClient(db_name, results_bucket, cls.ATHENA_S3_PREFIX) + + # Check if the database exists when the client is created + if not cls._ATHENA_CLIENT.check_database_exists(): + raise AthenaRefreshError('The \'{}\' database does not exist'.format(db_name)) + def _get_partitions_from_keys(self): """Get the partitions that need to be added for the Athena tables @@ -151,7 +160,7 @@ def _add_partitions(self): athena_table=athena_table, partition_statement=partition_statement)) - success = self._athena_client.run_query(query=query) + success = self._ATHENA_CLIENT.run_query(query=query) if not success: raise AthenaRefreshError( 'The add hive partition query has failed:\n{}'.format(query) @@ -169,11 +178,6 @@ def run(self, event): should contain one (or maybe more) S3 bucket notification message. """ # Check that the database being used exists before running queries - if not self._athena_client.check_database_exists(): - raise AthenaRefreshError( - 'The \'{}\' database does not exist'.format(self._athena_client.database) - ) - for sqs_rec in event['Records']: LOGGER.debug('Processing event with message ID \'%s\' and SentTimestamp %s', sqs_rec['messageId'],