diff --git a/stream_alert/rule_processor/rules_engine.py b/stream_alert/rule_processor/rules_engine.py
index 80c48e698..f3af5267d 100644
--- a/stream_alert/rule_processor/rules_engine.py
+++ b/stream_alert/rule_processor/rules_engine.py
@@ -22,23 +22,47 @@
 from stream_alert.shared import NORMALIZATION_KEY, resources
 from stream_alert.shared.alert import Alert
 from stream_alert.shared.rule import import_folders, Rule
+from stream_alert.shared.lookup_tables import LookupTables
 from stream_alert.shared.rule_table import RuleTable
-
 _IGNORE_KEYS = {StreamThreatIntel.IOC_KEY, NORMALIZATION_KEY}
 
 
 class RulesEngine(object):
     """Class to act as a rules engine that processes rules"""
     _RULE_TABLE_LAST_REFRESH = datetime(year=1970, month=1, day=1)
+    _LOOKUP_TABLES = {}
     _RULE_TABLE = None
 
     def __init__(self, config, *rule_paths):
         """Initialize a RulesEngine instance to cache a StreamThreatIntel instance."""
+        self._threat_intel = StreamThreatIntel.load_from_config(config)
         self._required_outputs_set = resources.get_required_outputs()
         import_folders(*rule_paths)
         self._load_rule_table(config)
+        lookup_tables = LookupTables.load_lookup_tables(config)
+        if lookup_tables:
+            RulesEngine._LOOKUP_TABLES = lookup_tables.download_s3_objects()
+
+    @classmethod
+    def get_lookup_table(cls, table_name):
+        """Return a lookup table by table name
+
+        The Rule Processor supports loading arbitrary JSON files from S3 buckets
+        into memory for quick reference while writing rules. This information is
+        stored in the class variable `_LOOKUP_TABLES`, which is a dictionary. The
+        JSON file name without its extension is the key name (a.k.a. the
+        table_name), and the JSON content is the value.
+
+        Args:
+            table_name (str): Lookup table name. It is also the JSON file name
+                without its extension.
+
+        Returns:
+            dict: A dictionary containing the lookup table information.
+        """
+        return cls._LOOKUP_TABLES.get(table_name)
 
     @classmethod
     def _load_rule_table(cls, config):
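For reviewers, here is a minimal sketch of how a rule would consume the new classmethod. The rule, its log source, and the `domain_allowlist` table (i.e. a hypothetical `domain_allowlist.json` object in one of the configured buckets) are illustrative only and not part of this change:

```python
from stream_alert.rule_processor.rules_engine import RulesEngine
from stream_alert.shared.rule import rule


@rule(logs=['osquery:differential'],    # hypothetical log source
      outputs=['aws-firehose:alerts'])
def unknown_domain(record):             # hypothetical rule
    """Alert when a queried domain is not in the 'domain_allowlist' lookup table"""
    # get_lookup_table() returns the parsed dict from domain_allowlist.json,
    # or None if the table was never configured/downloaded
    allowlist = RulesEngine.get_lookup_table('domain_allowlist') or {}
    return record.get('domain') not in allowlist
```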
+ """ + + _lookup_tables = {} + + for bucket, files in self._buckets_info.iteritems(): + for json_file in files: + table_name = os.path.splitext(json_file)[0] + try: + start_time = time.time() + s3_object = self._s3_client.Object(bucket, json_file).get() + size_kb = round(s3_object.get('ContentLength') / 1024.0, 2) + size_mb = round(size_kb / 1024.0, 2) + display_size = '{}MB'.format(size_mb) if size_mb else '{}KB'.format(size_kb) + LOGGER.info('Downloaded S3 file size %s and updated lookup table %s', + display_size, table_name) + _lookup_tables[table_name] = json.loads(s3_object.get('Body').read()) + except ClientError as err: + LOGGER.error('Encounterred error while downloading %s from %s, %s', + json_file, bucket, err.response['Error']['Message']) + return _lookup_tables + + total_time = time.time() - start_time + LOGGER.info('Downloaded S3 file %s seconds', round(total_time, 2)) + + return _lookup_tables + + @classmethod + def load_lookup_tables(cls, config): + """Load arbitrary json files to memory from S3 buckets when lookup table enabled + + The lookup tables will also be refreshed based on "cache_refresh_minutes" setting + in the config. + + Args: + config (dict): Loaded configuration from 'conf/' directory + + Returns: + Return False if lookup table enabled or missing config. Otherwise, it + will return an instance of LookupTables class. + """ + lookup_tables = config['global']['infrastructure'].get('lookup_tables') + if not (lookup_tables and lookup_tables.get('enabled', False)): + return False + + buckets_info = lookup_tables.get('buckets') + if not buckets_info: + LOGGER.error('Buckets not defined') + return False + + lookup_refresh_interval = lookup_tables.get('cache_refresh_minutes', 10) + now = datetime.utcnow() + refresh_delta = timedelta(minutes=lookup_refresh_interval) + needs_refresh = cls._LOOKUP_TABLES_LAST_REFRESH + refresh_delta < now + if not needs_refresh: + LOGGER.debug('lookup tables do not need refresh (last refresh time: %s; ' + 'current time: %s)', cls._LOOKUP_TABLES_LAST_REFRESH, now) + return False + + LOGGER.info('Refreshing lookup tables (last refresh time: %s; current time: %s)', + cls._LOOKUP_TABLES_LAST_REFRESH, now) + + cls._LOOKUP_TABLES_LAST_REFRESH = now + + return cls(buckets_info) diff --git a/stream_alert_cli/helpers.py b/stream_alert_cli/helpers.py index 6462645a8..97d72d750 100644 --- a/stream_alert_cli/helpers.py +++ b/stream_alert_cli/helpers.py @@ -663,6 +663,29 @@ def put_mock_s3_object(bucket, key, data, region): s3_client.put_object(Body=data, Bucket=bucket, Key=key, ServerSideEncryption='AES256') +def mock_s3_bucket(config): + """Mock S3 bucket for lookup tables testing""" + region = config['global']['account']['region'] + lookup_tables_config = config['global']['infrastructure'].get('lookup_tables') + if lookup_tables_config: + buckets_info = lookup_tables_config.get( + 'buckets', {'test_buckets': ['foo.json', 'bar.json']} + ) + else: + buckets_info = {'test_buckets': ['foo.json', 'bar.json']} + + for bucket, files in buckets_info.iteritems(): + for json_file in files: + test_json_file = os.path.join('tests/integration/fixtures', json_file) + if os.path.isfile(test_json_file): + data = open(test_json_file, 'r') + else: + data = json.dumps({'key': 'value'}) + put_mock_s3_object(bucket, json_file, data, region) + + if isinstance(data, file): + data.close() + def mock_me(context): """Decorator function for wrapping framework in mock calls diff --git a/stream_alert_cli/test.py b/stream_alert_cli/test.py index 3886f14d7..aa16c1931 100644 --- 
diff --git a/stream_alert_cli/helpers.py b/stream_alert_cli/helpers.py
index 6462645a8..97d72d750 100644
--- a/stream_alert_cli/helpers.py
+++ b/stream_alert_cli/helpers.py
@@ -663,6 +663,29 @@ def put_mock_s3_object(bucket, key, data, region):
     s3_client.put_object(Body=data, Bucket=bucket, Key=key, ServerSideEncryption='AES256')
 
 
+def mock_s3_bucket(config):
+    """Mock the S3 buckets used for lookup tables testing"""
+    region = config['global']['account']['region']
+    lookup_tables_config = config['global']['infrastructure'].get('lookup_tables')
+    if lookup_tables_config:
+        buckets_info = lookup_tables_config.get(
+            'buckets', {'test_buckets': ['foo.json', 'bar.json']}
+        )
+    else:
+        buckets_info = {'test_buckets': ['foo.json', 'bar.json']}
+
+    for bucket, files in buckets_info.iteritems():
+        for json_file in files:
+            test_json_file = os.path.join('tests/integration/fixtures', json_file)
+            if os.path.isfile(test_json_file):
+                data = open(test_json_file, 'r')
+            else:
+                data = json.dumps({'key': 'value'})
+            put_mock_s3_object(bucket, json_file, data, region)
+
+            if isinstance(data, file):
+                data.close()
+
+
 def mock_me(context):
     """Decorator function for wrapping framework in mock calls
diff --git a/stream_alert_cli/test.py b/stream_alert_cli/test.py
index 3886f14d7..aa16c1931 100644
--- a/stream_alert_cli/test.py
+++ b/stream_alert_cli/test.py
@@ -1041,6 +1041,8 @@ def run_tests(options, context):
     # Run the rule processor for all rules or designated rule set
     if context.mocked:
         helpers.setup_mock_alerts_table(alerts_table)
+        # Mock the S3 buckets for lookup tables testing
+        helpers.mock_s3_bucket(config)
 
     rule_proc_tester = RuleProcessorTester(context, config, test_rules)
     alert_proc_tester = AlertProcessorTester(config, context)
diff --git a/tests/unit/conf/global.json b/tests/unit/conf/global.json
index 0435825bc..ab6a9f946 100644
--- a/tests/unit/conf/global.json
+++ b/tests/unit/conf/global.json
@@ -14,6 +14,16 @@
       "read_capacity": 5,
       "write_capacity": 5
     },
+    "lookup_tables": {
+      "buckets": {
+        "bucket_name": [
+          "foo.json",
+          "bar.json"
+        ]
+      },
+      "cache_refresh_minutes": 10,
+      "enabled": false
+    },
     "monitoring": {
       "create_sns_topic": true
     },
diff --git a/tests/unit/stream_alert_rule_processor/test_rules_engine.py b/tests/unit/stream_alert_rule_processor/test_rules_engine.py
index 6c0f33d31..9db2b0e46 100644
--- a/tests/unit/stream_alert_rule_processor/test_rules_engine.py
+++ b/tests/unit/stream_alert_rule_processor/test_rules_engine.py
@@ -13,12 +13,13 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-# pylint: disable=no-self-use,protected-access,attribute-defined-outside-init
+# pylint: disable=no-self-use,protected-access,attribute-defined-outside-init,too-many-lines
 from datetime import datetime, timedelta
 import json
 import os
 
 from mock import Mock, patch
+from moto import mock_s3
 from nose.tools import (
     assert_equal,
     assert_false,
@@ -27,12 +28,14 @@
     assert_true,
 )
 
+from stream_alert_cli.helpers import put_mock_s3_object
 from stream_alert.rule_processor.parsers import get_parser
 from stream_alert.rule_processor.rules_engine import RulesEngine
 from stream_alert.shared import NORMALIZATION_KEY
 from stream_alert.shared.config import load_config
 from stream_alert.shared.rule import disable, matcher, Matcher, rule, Rule
 from stream_alert.shared.rule_table import RuleTable
+from stream_alert.shared.lookup_tables import LookupTables
 
 from tests.unit.stream_alert_rule_processor.test_helpers import (
     load_and_classify_payload,
@@ -54,8 +57,14 @@
         Rule._rules.clear()
         self.config = load_config('tests/unit/conf')
         self.config['global']['threat_intel']['enabled'] = False
+        self.config['global']['infrastructure']['lookup_tables']['enabled'] = False
         self.rules_engine = RulesEngine(self.config)
 
+    def teardown(self):
+        """Reset lookup tables state after each test"""
+        RulesEngine._LOOKUP_TABLES = {}
+        LookupTables._LOOKUP_TABLES_LAST_REFRESH = datetime(year=1970, month=1, day=1)
+
     def test_basic_rule_matcher_process(self):
         """Rules Engine - Basic Rule/Matcher"""
         @matcher
@@ -945,3 +954,42 @@
         # alert tests
         assert_equal(list(alerts[0].outputs)[0], 'aws-firehose:alerts')
+
+    @patch('logging.Logger.error')
+    def test_load_lookup_tables_missing_config(self, mock_logger):
+        """Rules Engine - Load lookup tables with missing config"""
+        self.config['global']['infrastructure'].pop('lookup_tables')
+        _ = RulesEngine(self.config)
+        assert_equal(RulesEngine._LOOKUP_TABLES, {})
+        assert_equal(LookupTables._LOOKUP_TABLES_LAST_REFRESH,
+                     datetime(year=1970, month=1, day=1))
+        assert_equal(RulesEngine.get_lookup_table('table_name'), None)
+
+        self.config['global']['infrastructure']['lookup_tables'] = {
+            'cache_refresh_minutes': 10,
+            'enabled': True
+        }
+        _ = RulesEngine(self.config)
+        mock_logger.assert_called_with('Buckets not defined')
+
+    @patch('logging.Logger.debug')
+    def test_load_lookup_tables(self, mock_logger):
+        """Rules Engine - Load lookup table"""
+        s3_mock = mock_s3()
+        s3_mock.start()
+        put_mock_s3_object(
+            'bucket_name', 'foo.json', json.dumps({'key1': 'value1'}), 'us-east-1'
+        )
+        put_mock_s3_object(
+            'bucket_name', 'bar.json', json.dumps({'key2': 'value2'}), 'us-east-1'
+        )
+        self.config['global']['infrastructure']['lookup_tables']['enabled'] = True
+        _ = RulesEngine(self.config)
+        assert_equal(RulesEngine.get_lookup_table('foo'), {'key1': 'value1'})
+        assert_equal(RulesEngine.get_lookup_table('bar'), {'key2': 'value2'})
+        assert_equal(RulesEngine.get_lookup_table('not_exist'), None)
+
+        _ = RulesEngine(self.config)
+        mock_logger.assert_called()
+
+        s3_mock.stop()
diff --git a/tests/unit/stream_alert_shared/test_lookup_tables.py b/tests/unit/stream_alert_shared/test_lookup_tables.py
new file mode 100644
index 000000000..e64bdb2be
--- /dev/null
+++ b/tests/unit/stream_alert_shared/test_lookup_tables.py
@@ -0,0 +1,114 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from datetime import datetime
+import json
+import os
+
+from mock import patch
+from moto import mock_s3
+from nose.tools import assert_equal
+
+from stream_alert_cli.helpers import put_mock_s3_object
+from stream_alert.shared.config import load_config
+from stream_alert.shared.lookup_tables import LookupTables
+
+
+# pylint: disable=protected-access
+class TestLookupTables(object):
+    """Test LookupTables class"""
+    def __init__(self):
+        self.buckets_info = {'bucket_name': ['foo.json', 'bar.json']}
+        self.region = 'us-east-1'
+
+    def setup(self):
+        """LookupTables - Setup S3 bucket mocking"""
+        # pylint: disable=attribute-defined-outside-init
+        self.config = load_config('tests/unit/conf')
+        self.lookup_tables = LookupTables(self.buckets_info)
+        self.s3_mock = mock_s3()
+        self.s3_mock.start()
+        for bucket, files in self.buckets_info.iteritems():
+            for json_file in files:
+                put_mock_s3_object(
+                    bucket,
+                    json_file,
+                    json.dumps({
+                        '{}_key'.format(bucket): '{}_value'.format(os.path.splitext(json_file)[0])
+                    }),
+                    self.region
+                )
+
+    def teardown(self):
+        """LookupTables - Stop S3 bucket mocking"""
+        self.s3_mock.stop()
+        LookupTables._LOOKUP_TABLES_LAST_REFRESH = datetime(year=1970, month=1, day=1)
+
+    def test_download_s3_object(self):
+        """LookupTables - Download S3 objects"""
+        result = self.lookup_tables.download_s3_objects()
+        # Use sorted() since dict key ordering is not guaranteed
+        assert_equal(sorted(result.keys()), ['bar', 'foo'])
+        expected_result = {
+            'foo': {'bucket_name_key': 'foo_value'},
+            'bar': {'bucket_name_key': 'bar_value'}
+        }
+        assert_equal(result, expected_result)
+
+    @patch('logging.Logger.error')
+    def test_download_s3_object_bucket_exception(self, mock_logger):  # pylint: disable=no-self-use
+        """LookupTables - S3 bucket doesn't exist"""
+        lookup_tables = LookupTables({'wrong_bucket': ['foo.json']})
+        lookup_tables.download_s3_objects()
+        mock_logger.assert_called_with(
+            'Encountered error while downloading %s from %s, %s',
+            'foo.json',
+            'wrong_bucket',
+            'The specified bucket does not exist'
+        )
+
+    def test_download_s3_object_file_exception(self):  # pylint: disable=no-self-use
+        """LookupTables - S3 file doesn't exist"""
+        lookup_tables = LookupTables({'bucket_name': ['wrong_file']})
+        lookup_tables.download_s3_objects()
+
+    @patch('logging.Logger.error')
+    def test_load_lookup_tables_missing_config(self, mock_logger):
+        """LookupTables - Load lookup tables with missing config"""
+        # Remove the lookup_tables config for this test case
+        self.config['global']['infrastructure'].pop('lookup_tables')
+        lookup_tables = LookupTables.load_lookup_tables(self.config)
+        assert_equal(lookup_tables, False)
+        assert_equal(LookupTables._LOOKUP_TABLES_LAST_REFRESH,
+                     datetime(year=1970, month=1, day=1))
+
+        self.config['global']['infrastructure']['lookup_tables'] = {
+            'cache_refresh_minutes': 10,
+            'enabled': True
+        }
+        lookup_tables = LookupTables.load_lookup_tables(self.config)
+        mock_logger.assert_called_with('Buckets not defined')
+
+    @patch('logging.Logger.debug')
+    def test_load_lookup_tables(self, mock_logger):
+        """LookupTables - Load lookup table"""
+        self.config['global']['infrastructure']['lookup_tables']['enabled'] = True
+        lookup_tables = LookupTables.load_lookup_tables(self.config)
+        result = lookup_tables.download_s3_objects()
+
+        assert_equal(result.get('foo'), {'bucket_name_key': 'foo_value'})
+        assert_equal(result.get('bar'), {'bucket_name_key': 'bar_value'})
+        assert_equal(result.get('not_exist'), None)
+
+        LookupTables.load_lookup_tables(self.config)
+        mock_logger.assert_called()
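End to end, the flow these unit tests exercise looks roughly like the sketch below when run outside of moto. The bucket contents are hypothetical, and this assumes lookup tables are enabled in the deployed config:

```python
from stream_alert.rule_processor.rules_engine import RulesEngine
from stream_alert.shared.config import load_config

config = load_config('conf/')
config['global']['infrastructure']['lookup_tables']['enabled'] = True

# __init__ calls LookupTables.load_lookup_tables() and, when a refresh is due,
# downloads every configured JSON file into RulesEngine._LOOKUP_TABLES
engine = RulesEngine(config)

# Table names are the S3 object keys minus the '.json' extension
foo_table = RulesEngine.get_lookup_table('foo')      # contents of foo.json
missing = RulesEngine.get_lookup_table('not_exist')  # None for unknown tables
```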