diff --git a/manage.py b/manage.py
index 00406cdf6..43219e1bc 100755
--- a/manage.py
+++ b/manage.py
@@ -102,7 +102,8 @@ def _add_output_subparser(subparsers):
     output_parser.add_argument(
         '--service',
         choices=[
-            'aws-firehose', 'aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2', 'phantom', 'slack'
+            'aws-firehose', 'aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2',
+            'pagerduty-incident', 'phantom', 'slack'
         ],
         required=True,
         help=ARGPARSE_SUPPRESS)
diff --git a/stream_alert/alert_processor/helpers.py b/stream_alert/alert_processor/helpers.py
index d79f7af1a..885cd35e9 100644
--- a/stream_alert/alert_processor/helpers.py
+++ b/stream_alert/alert_processor/helpers.py
@@ -38,7 +38,8 @@ def validate_alert(alert):
         'log_source',
         'outputs',
         'source_service',
-        'source_entity'
+        'source_entity',
+        'context'
     }
     if not set(alert.keys()) == alert_keys:
         LOGGER.error('The alert object must contain the following keys: %s',
@@ -53,6 +54,11 @@ def validate_alert(alert):
             LOGGER.error('The alert record must be a map (dict)')
             return False

+        elif key == 'context':
+            if not isinstance(alert['context'], dict):
+                LOGGER.error('The alert context must be a map (dict)')
+                return False
+
         elif key == 'outputs':
             if not isinstance(alert[key], list):
                 LOGGER.error(
diff --git a/stream_alert/alert_processor/outputs.py b/stream_alert/alert_processor/outputs.py
index 58b71f19c..5649c5a0c 100644
--- a/stream_alert/alert_processor/outputs.py
+++ b/stream_alert/alert_processor/outputs.py
@@ -313,28 +313,24 @@ def dispatch(self, **kwargs):
         # Incident assignment goes in this order:
         # Provided user -> provided policy -> default policy
         if user_to_assign:
-            users_url = os.path.join(creds['api'], self.USERS_ENDPOINT)
-            user_id = self._check_exists_get_id(user_to_assign,
-                                                users_url, headers, self.USERS_ENDPOINT)
+            users_url = self._get_endpoint(creds['api'], self.USERS_ENDPOINT)
+            user_id = self._check_exists_get_id(user_to_assign, users_url, headers,
+                                                self.USERS_ENDPOINT)
             if user_id:
                 assigned_key = 'assignments'
                 assigned_value = [{
                     'assignee' : {
-                        'id': '',
+                        'id': user_id,
                         'type': 'user_reference'}
                 }]
-            # If the user retrieval did not succeed, default to policies
-            else:
-                user_to_assign = False
+        policy_to_assign = creds['escalation_policy']

         if not user_to_assign and rule_context.get('assigned_policy'):
             policy_to_assign = rule_context.get('assigned_policy')
-        else:
-            policy_to_assign = creds['escalation_policy']

-        policies_url = os.path.join(creds['api'], self.POLICIES_ENDPOINT)
-        policy_id = self._check_exists_get_id(policy_to_assign,
-                                              policies_url, headers, self.POLICIES_ENDPOINT)
+        policies_url = self._get_endpoint(creds['api'], self.POLICIES_ENDPOINT)
+        policy_id = self._check_exists_get_id(policy_to_assign, policies_url, headers,
+                                              self.POLICIES_ENDPOINT)
         assigned_key = 'escalation_policy'
         assigned_value = {
             'id': policy_id,
@@ -344,32 +340,27 @@ def dispatch(self, **kwargs):
         # Start preparing the incident JSON blob to be sent to the API
         incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name'])
         incident_body = {
-            'type': '',
-            'details': ''
+            'type': 'incident_body',
+            'details': kwargs['alert']['rule_description']
         }
         # We need to get the service id from the API
-        services_url = os.path.join(creds['api'], self.SERVICES_ENDPOINT)
-        service_id = self._check_exists_get_id(creds['service_key'],
-                                               services_url, headers, self.SERVICES_ENDPOINT)
+        services_url = self._get_endpoint(creds['api'], self.SERVICES_ENDPOINT)
+        service_id = self._check_exists_get_id(creds['service_key'], services_url, headers,
+                                               self.SERVICES_ENDPOINT)
         incident_service = {
             'id': service_id,
             'type': 'service_reference'
         }
-        incident_priority = {
-            'id': '',
-            'type': 'priority_reference'
-        }
         incident = {
             'incident': {
                 'type': 'incident',
                 'title': incident_title,
                 'service': incident_service,
-                'priority': incident_priority,
                 'body': incident_body
             },
             assigned_key: assigned_value
         }
-        incidents_url = os.path.join(creds['api'], self.INCIDENTS_ENDPOINT)
+        incidents_url = self._get_endpoint(creds['api'], self.INCIDENTS_ENDPOINT)
         resp = self._post_request(incidents_url, incident, None, True)

         success = self._check_http_response(resp)
diff --git a/stream_alert/rule_processor/rules_engine.py b/stream_alert/rule_processor/rules_engine.py
index c51ee70e0..9c177aa23 100644
--- a/stream_alert/rule_processor/rules_engine.py
+++ b/stream_alert/rule_processor/rules_engine.py
@@ -28,7 +28,8 @@
                   'datatypes',
                   'logs',
                   'outputs',
-                  'req_subkeys'])
+                  'req_subkeys',
+                  'context'])

 class StreamRules(object):
@@ -70,6 +71,7 @@ def decorator(rule):
             matchers = opts.get('matchers')
             datatypes = opts.get('datatypes')
             req_subkeys = opts.get('req_subkeys')
+            context = opts.get('context', {})

             if not (logs or datatypes):
                 LOGGER.error(
@@ -92,7 +94,8 @@ def decorator(rule):
                              datatypes,
                              logs,
                              outputs,
-                             req_subkeys)
+                             req_subkeys,
+                             context)
             return rule
         return decorator
@@ -387,7 +390,8 @@ def process(cls, input_payload):
                          'log_type': payload.type,
                          'outputs': rule.outputs,
                          'source_service': payload.service(),
-                         'source_entity': payload.entity}
+                         'source_entity': payload.entity,
+                         'context': rule.context}
                 alerts.append(alert)

         return alerts
diff --git a/tests/unit/stream_alert_rule_processor/test_rules_engine.py b/tests/unit/stream_alert_rule_processor/test_rules_engine.py
index d882cbb39..14f61b896 100644
--- a/tests/unit/stream_alert_rule_processor/test_rules_engine.py
+++ b/tests/unit/stream_alert_rule_processor/test_rules_engine.py
@@ -101,11 +101,13 @@ def alert_format_test(rec):  # pylint: disable=unused-variable
             'log_source',
             'outputs',
             'source_service',
-            'source_entity'
+            'source_entity',
+            'context'
         }
         assert_items_equal(alerts[0].keys(), alert_keys)
         assert_is_instance(alerts[0]['record'], dict)
         assert_is_instance(alerts[0]['outputs'], list)
+        assert_is_instance(alerts[0]['context'], dict)

         # test alert fields
         assert_is_instance(alerts[0]['rule_name'], str)
@@ -207,7 +209,8 @@ def cloudtrail_us_east_logs(rec):
             datatypes=[],
             logs=['test_log_type_json_nested'],
             outputs=['s3:sample_bucket'],
-            req_subkeys={'requestParameters': ['program']}
+            req_subkeys={'requestParameters': ['program']},
+            context={}
         )

         data = json.dumps({