diff --git a/stream_alert/rule_processor/rules_engine.py b/stream_alert/rule_processor/rules_engine.py index 0c4483601..e777b8dce 100644 --- a/stream_alert/rule_processor/rules_engine.py +++ b/stream_alert/rule_processor/rules_engine.py @@ -284,7 +284,10 @@ def process_rule(record, rule): (bool): The return function of the rule """ try: - rule_result = rule.rule_function(record) + if rule.context: + rule_result = rule.rule_function(record, rule.context) + else: + rule_result = rule.rule_function(record) except Exception: # pylint: disable=broad-except rule_result = False LOGGER.exception( diff --git a/tests/unit/stream_alert_rule_processor/test_rules_engine.py b/tests/unit/stream_alert_rule_processor/test_rules_engine.py index 3210618ba..78f429c57 100644 --- a/tests/unit/stream_alert_rule_processor/test_rules_engine.py +++ b/tests/unit/stream_alert_rule_processor/test_rules_engine.py @@ -852,3 +852,39 @@ def match_ipaddress(_): # pylint: disable=unused-variable payload = load_and_classify_payload(toggled_config, service, entity, raw_record) assert_equal(len(new_rules_engine.process(payload)), 1) + + def test_rule_modify_context(self): + """Rules Engine - Testing Context Modification""" + @rule(logs=['test_log_type_json_nested_with_data'], + outputs=['s3:sample_bucket'], + context={'assigned_user': 'not_set', 'assigned_policy': 'not_set2'}) + def modify_context_test(rec, context): # pylint: disable=unused-variable + """Modify context rule""" + context['assigned_user'] = 'valid_user' + context['assigned_policy'] = 'valid_policy' + return rec['application'] == 'web-app' + + kinesis_data = json.dumps({ + 'date': 'Dec 01 2016', + 'unixtime': '1483139547', + 'host': 'host1.web.prod.net', + 'application': 'web-app', + 'environment': 'prod', + 'data': { + 'category': 'web-server', + 'type': '1', + 'source': 'eu' + } + }) + + # prepare the payloads + service, entity = 'kinesis', 'test_kinesis_stream' + raw_record = make_kinesis_raw_record(entity, kinesis_data) + payload = load_and_classify_payload(self.config, service, entity, raw_record) + + # process payloads + alerts = self.rules_engine.process(payload) + + # alert tests + assert_equal(alerts[0]['context']['assigned_user'], 'valid_user') + assert_equal(alerts[0]['context']['assigned_policy'], 'valid_policy')