diff --git a/docs/source/rules.rst b/docs/source/rules.rst
index 588218af3..183f44fa1 100644
--- a/docs/source/rules.rst
+++ b/docs/source/rules.rst
@@ -133,6 +133,14 @@ req_subkeys
 This feature should be avoided, but it is useful if you defined a loose schema to trade flexibility for safety; see `Schemas `_.
 
+context
+~~~~~~~
+
+``context`` is an optional argument that defines an extra field of information to pass along
+inside the alert record without affecting schemas. It can be particularly helpful for passing
+data to an output for use within the output processing code.
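+
+For example, a rule can use ``context`` to hand assignment details to the ``pagerduty-incident``
+output (a hypothetical sketch; the log source, output descriptor, and field values are
+illustrative only):
+
+.. code-block:: python
+
+  @rule(logs=['osquery:differential'],
+        outputs=['pagerduty-incident:sample_integration'],
+        context={'pagerduty-incident': {'assigned_user': 'valid_user'}})
+  def assigned_incident_example(rec):
+      return rec['action'] == 'added'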
+
+
 Examples:
 
 .. code-block:: python
diff --git a/manage.py b/manage.py
index 3b309c163..46dab8d1e 100755
--- a/manage.py
+++ b/manage.py
@@ -102,7 +102,8 @@ def _add_output_subparser(subparsers):
     # Output service options
     output_parser.add_argument(
         '--service',
-        choices=['aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2', 'phantom', 'slack'],
+        choices=['aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2',
+                 'pagerduty-incident', 'phantom', 'slack'],
         required=True,
         help=ARGPARSE_SUPPRESS
     )
diff --git a/stream_alert/alert_processor/helpers.py b/stream_alert/alert_processor/helpers.py
index d79f7af1a..885cd35e9 100644
--- a/stream_alert/alert_processor/helpers.py
+++ b/stream_alert/alert_processor/helpers.py
@@ -38,7 +38,8 @@ def validate_alert(alert):
         'log_source',
         'outputs',
         'source_service',
-        'source_entity'
+        'source_entity',
+        'context'
     }
     if not set(alert.keys()) == alert_keys:
         LOGGER.error('The alert object must contain the following keys: %s',
@@ -53,6 +54,11 @@ def validate_alert(alert):
             LOGGER.error('The alert record must be a map (dict)')
             return False
 
+        elif key == 'context':
+            if not isinstance(alert['context'], dict):
+                LOGGER.error('The alert context must be a map (dict)')
+                return False
+
         elif key == 'outputs':
             if not isinstance(alert[key], list):
                 LOGGER.error(
diff --git a/stream_alert/alert_processor/outputs.py b/stream_alert/alert_processor/outputs.py
index 2e67acd67..8133747f7 100644
--- a/stream_alert/alert_processor/outputs.py
+++ b/stream_alert/alert_processor/outputs.py
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
+# pylint: disable=too-many-lines
 from abc import abstractmethod
 import cgi
 from collections import OrderedDict
@@ -198,6 +199,171 @@ def dispatch(self, **kwargs):
         return self._log_status(success)
 
 
+@output
+class PagerDutyIncidentOutput(StreamOutputBase):
+    """PagerDutyIncidentOutput handles all alert dispatching for the PagerDuty Incidents API v2"""
+    __service__ = 'pagerduty-incident'
+    INCIDENTS_ENDPOINT = 'incidents'
+    USERS_ENDPOINT = 'users'
+    POLICIES_ENDPOINT = 'escalation_policies'
+    SERVICES_ENDPOINT = 'services'
+
+    @classmethod
+    def _get_default_properties(cls):
+        """Get the standard url used for the PagerDuty Incidents API v2. This value is the same
+        for everyone, so it is hard-coded here and does not need to be configured by the user.
+
+        Returns:
+            dict: Contains various default items for this output (ie: url)
+        """
+        return {
+            'api': 'https://api.pagerduty.com'
+        }
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new PagerDuty
+        incident output. This should be sensitive or unique information for this use-case that
+        needs to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        PagerDuty also requires a token and a service_key that represent this integration.
+        These values should be masked during input and are credential requirements.
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(description='a short and unique descriptor for this '
+                                        'PagerDuty integration')),
+            ('token',
+             OutputProperty(description='the token for this PagerDuty integration',
+                            mask_input=True,
+                            cred_requirement=True)),
+            ('service_key',
+             OutputProperty(description='the service key for this PagerDuty integration',
+                            mask_input=True,
+                            cred_requirement=True)),
+            ('escalation_policy',
+             OutputProperty(description='the name of the default escalation policy'))
+        ])
+
+    def _check_exists_get_id(self, filter_str, target_url, headers, target_key):
+        """Generic method to run a search in the PagerDuty REST API and return the id
+        of the first occurrence in the results.
+
+        Args:
+            filter_str (str): The query filter to search for in the API
+            target_url (str): The url to send the requests to in the API
+            headers (dict): A dictionary containing header parameters
+            target_key (str): The key to extract from the returned results
+
+        Returns:
+            str: ID of the first element that matches the provided filter, or
+                False if no matching element exists.
+        """
+        params = {
+            'query': '"{}"'.format(filter_str)
+        }
+        resp = self._get_request(target_url, params, headers, False)
+        if not self._check_http_response(resp):
+            return False
+
+        response = resp.json()
+
+        # If there are results, get the first occurrence from the list;
+        # guard the lookup so a missing key does not raise a TypeError
+        elements = response.get(target_key) if response else None
+        return elements[0]['id'] if elements else False
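+
+    # Hypothetical usage sketch: with headers built from the stored credentials, a call such as
+    #   self._check_exists_get_id('some_user', 'https://api.pagerduty.com/users',
+    #                             headers, self.USERS_ENDPOINT)
+    # issues GET /users?query="some_user" and returns the id of the first matching user,
+    # or False when the request fails or nothing matches.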
+
+    def dispatch(self, **kwargs):
+        """Send an incident to the PagerDuty Incidents API v2
+
+        Args:
+            **kwargs: consists of any combination of the following items:
+                descriptor (str): Service descriptor (ie: slack channel, pd integration)
+                rule_name (str): Name of the triggered rule
+                alert (dict): Alert relevant to the triggered rule
+        """
+        creds = self._load_creds(kwargs['descriptor'])
+        if not creds:
+            return self._log_status(False)
+
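+        # Extract the context data used to assign the incident; for this output the alert
+        # context is expected to look like (hypothetical values):
+        #   {'pagerduty-incident': {'assigned_user': 'some_user',
+        #                           'assigned_policy': 'some_policy'}}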
+        rule_context = kwargs['alert'].get('context', {}).get(self.__service__, {})
+
+        headers = {
+            'Authorization': 'Token token={}'.format(creds['token']),
+            'Accept': 'application/vnd.pagerduty+json;version=2'
+        }
+        # Check if a user to assign the incident is provided
+        user_to_assign = rule_context.get('assigned_user')
+
+        # Incident assignment goes in this order:
+        # Provided user -> provided policy -> default policy
+        if user_to_assign:
+            users_url = os.path.join(creds['api'], self.USERS_ENDPOINT)
+            user_id = self._check_exists_get_id(user_to_assign,
+                                                users_url, headers, self.USERS_ENDPOINT)
+            if user_id:
+                assigned_key = 'assignments'
+                assigned_value = [{
+                    'assignee': {
+                        'id': user_id,
+                        'type': 'user_reference'}
+                }]
+            # If the user retrieval did not succeed, default to policies
+            else:
+                user_to_assign = False
+
+        if not user_to_assign:
+            if rule_context.get('assigned_policy'):
+                policy_to_assign = rule_context.get('assigned_policy')
+            else:
+                policy_to_assign = creds['escalation_policy']
+
+            policies_url = os.path.join(creds['api'], self.POLICIES_ENDPOINT)
+            policy_id = self._check_exists_get_id(policy_to_assign,
+                                                  policies_url, headers,
+                                                  self.POLICIES_ENDPOINT)
+            assigned_key = 'escalation_policy'
+            assigned_value = {
+                'id': policy_id,
+                'type': 'escalation_policy_reference'
+            }
+
+        # Start preparing the incident JSON blob to be sent to the API
+        incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name'])
+        incident_body = {
+            'type': '',
+            'details': ''
+        }
+        # We need to get the service id from the API
+        services_url = os.path.join(creds['api'], self.SERVICES_ENDPOINT)
+        service_id = self._check_exists_get_id(creds['service_key'],
+                                               services_url, headers, self.SERVICES_ENDPOINT)
+        incident_service = {
+            'id': service_id,
+            'type': 'service_reference'
+        }
+        incident_priority = {
+            'id': '',
+            'type': 'priority_reference'
+        }
+        incident = {
+            'incident': {
+                'type': 'incident',
+                'title': incident_title,
+                'service': incident_service,
+                'priority': incident_priority,
+                'body': incident_body
+            },
+            assigned_key: assigned_value
+        }
+        incidents_url = os.path.join(creds['api'], self.INCIDENTS_ENDPOINT)
+        resp = self._post_request(incidents_url, incident, headers, True)
+        success = self._check_http_response(resp)
+
+        return self._log_status(success)
+
+
 @output
 class PhantomOutput(StreamOutputBase):
     """PhantomOutput handles all alert dispatching for Phantom"""
diff --git a/stream_alert/rule_processor/rules_engine.py b/stream_alert/rule_processor/rules_engine.py
index c51ee70e0..9c177aa23 100644
--- a/stream_alert/rule_processor/rules_engine.py
+++ b/stream_alert/rule_processor/rules_engine.py
@@ -28,7 +28,8 @@
                               'datatypes',
                               'logs',
                               'outputs',
-                              'req_subkeys'])
+                              'req_subkeys',
+                              'context'])
 
 
 class StreamRules(object):
@@ -70,6 +71,7 @@ def decorator(rule):
             matchers = opts.get('matchers')
             datatypes = opts.get('datatypes')
             req_subkeys = opts.get('req_subkeys')
+            context = opts.get('context', {})
 
             if not (logs or datatypes):
                 LOGGER.error(
@@ -92,7 +94,8 @@ def decorator(rule):
                                          datatypes,
                                          logs,
                                          outputs,
-                                         req_subkeys)
+                                         req_subkeys,
+                                         context)
             return rule
         return decorator
@@ -387,7 +390,8 @@ def process(cls, input_payload):
                          'log_type': payload.type,
                          'outputs': rule.outputs,
                          'source_service': payload.service(),
-                         'source_entity': payload.entity}
+                         'source_entity': payload.entity,
+                         'context': rule.context}
 
                 alerts.append(alert)
 
         return alerts
diff --git a/tests/unit/stream_alert_alert_processor/helpers.py b/tests/unit/stream_alert_alert_processor/helpers.py
index 9f09816bb..35d005a61 100644
--- a/tests/unit/stream_alert_alert_processor/helpers.py
+++ b/tests/unit/stream_alert_alert_processor/helpers.py
@@ -90,6 +90,12 @@ def get_alert(index=0):
         'outputs': [
             'slack:unit_test_channel'
         ],
+        'context': {
+            'pagerduty-incident': {
+                'assigned_user': 'valid_user',
+                'assigned_policy': 'valid_policy'
+            }
+        },
         'source_service': 's3',
         'source_entity': 'corp-prefix.prod.cb.region',
         'log_type': 'json',
diff --git a/tests/unit/stream_alert_alert_processor/test_outputs.py b/tests/unit/stream_alert_alert_processor/test_outputs.py
index 33ff7c39d..794c68a18 100644
--- a/tests/unit/stream_alert_alert_processor/test_outputs.py
+++ b/tests/unit/stream_alert_alert_processor/test_outputs.py
@@ -190,7 +190,7 @@ def teardown_class(cls):
         cls.__dispatcher = None
 
     def test_get_default_properties(self):
-        """Get Default Properties - PagerDuty"""
+        """Get Default Properties - PagerDutyOutputV2"""
         props = self.__dispatcher._get_default_properties()
         assert_equal(len(props), 1)
         assert_equal(props['url'],
@@ -268,6 +268,144 @@ def test_dispatch_bad_descriptor(self, log_error_mock):
         log_error_mock.assert_called_with('Failed to send alert to %s', self.__service)
 
 
+class TestPagerDutyIncidentOutput(object):
+    """Test class for PagerDutyIncidentOutput"""
+    @classmethod
+    def setup_class(cls):
+        """Setup the class before any methods"""
+        cls.__service = 'pagerduty-incident'
+        cls.__descriptor = 'unit_test_pagerduty-incident'
+        cls.__backup_method = None
+        cls.__dispatcher = outputs.get_output_dispatcher(cls.__service,
+                                                         REGION,
+                                                         FUNCTION_NAME,
+                                                         CONFIG)
+
+    @classmethod
+    def teardown_class(cls):
+        """Teardown the class after all methods"""
+        cls.__dispatcher = None
+
+    def test_get_default_properties(self):
+        """Get Default Properties - PagerDutyIncidentOutput"""
+        props = self.__dispatcher._get_default_properties()
+        assert_equal(len(props), 1)
+        assert_equal(props['api'],
+                     'https://api.pagerduty.com')
+
+    def _setup_dispatch(self):
+        """Helper for setting up PagerDutyIncidentOutput dispatch"""
+        remove_temp_secrets()
+
+        # Cache the _get_default_properties and set it to return None
+        self.__backup_method = self.__dispatcher._get_default_properties
+        self.__dispatcher._get_default_properties = lambda: None
+
+        output_name = self.__dispatcher.output_cred_name(self.__descriptor)
+
+        creds = {'api': 'https://api.pagerduty.com',
+                 'token': 'mocked_token',
+                 'service_key': 'mocked_service_key',
+                 'escalation_policy': 'mocked_escalation_policy'}
+
+        put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS)
+
+        return get_alert()
+
+    def _teardown_dispatch(self):
+        """Replace method with cached method"""
+        self.__dispatcher._get_default_properties = self.__backup_method
+
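+    # Note on the dispatch tests below: with a valid assigned_user in the alert context,
+    # dispatch() should issue two GETs (user lookup, then service lookup) before POSTing
+    # to /incidents; when the user lookup finds nothing it falls back to an escalation
+    # policy lookup, adding a third GET.
+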
+    @patch('logging.Logger.info')
+    @patch('requests.post')
+    @patch('requests.get')
+    @mock_s3
+    @mock_kms
+    def test_dispatch_success(self, get_mock, post_mock, log_info_mock):
+        """PagerDutyIncidentOutput dispatch success"""
+        alert = self._setup_dispatch()
+
+        # User and service lookups
+        get_mock.return_value.status_code = 200
+        json_user = json.loads('{"users": [{"id": "user_id"}]}')
+        json_service = json.loads('{"services": [{"id": "service_id"}]}')
+        get_mock.return_value.json.side_effect = [json_user, json_service]
+
+        # /incidents
+        post_mock.return_value.status_code = 200
+
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        self._teardown_dispatch()
+
+        log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+    @patch('logging.Logger.info')
+    @patch('requests.post')
+    @patch('requests.get')
+    @mock_s3
+    @mock_kms
+    def test_dispatch_success_bad_user(self, get_mock, post_mock, log_info_mock):
+        """PagerDutyIncidentOutput dispatch success - Bad User"""
+        alert = self._setup_dispatch()
+
+        # User, fallback policy and service lookups
+        get_mock.return_value.status_code = 200
+        json_user = json.loads('{"not_users": [{"id": "user_id"}]}')
+        json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}')
+        json_service = json.loads('{"services": [{"id": "service_id"}]}')
+        get_mock.return_value.json.side_effect = [json_user, json_policy, json_service]
+
+        # /incidents
+        post_mock.return_value.status_code = 200
+
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        self._teardown_dispatch()
+
+        log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+    @patch('logging.Logger.error')
+    @patch('requests.post')
+    @patch('requests.get')
+    @mock_s3
+    @mock_kms
+    def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_error_mock):
+        """PagerDutyIncidentOutput dispatch failure - Bad User, Bad Policy, Bad Service"""
+        alert = self._setup_dispatch()
+
+        # User, policy and service lookups all fail
+        get_mock.return_value.status_code = 400
+        json_empty = json.loads('{}')
+        get_mock.return_value.json.side_effect = [json_empty, json_empty, json_empty]
+
+        # /incidents
+        post_mock.return_value.status_code = 400
+
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        self._teardown_dispatch()
+
+        log_error_mock.assert_called_with('Failed to send alert to %s', self.__service)
+
+    @patch('logging.Logger.error')
+    @mock_s3
+    @mock_kms
+    def test_dispatch_bad_descriptor(self, log_error_mock):
+        """PagerDutyIncidentOutput dispatch bad descriptor"""
+        alert = self._setup_dispatch()
+        self.__dispatcher.dispatch(descriptor='bad_descriptor',
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        self._teardown_dispatch()
+
+        log_error_mock.assert_called_with('Failed to send alert to %s', self.__service)
+
+
 @mock_s3
 @mock_kms
diff --git a/tests/unit/stream_alert_rule_processor/test_rules_engine.py b/tests/unit/stream_alert_rule_processor/test_rules_engine.py
index d882cbb39..14f61b896 100644
--- a/tests/unit/stream_alert_rule_processor/test_rules_engine.py
+++ b/tests/unit/stream_alert_rule_processor/test_rules_engine.py
@@ -101,11 +101,13 @@ def alert_format_test(rec): # pylint: disable=unused-variable
             'log_source',
             'outputs',
             'source_service',
-            'source_entity'
+            'source_entity',
+            'context'
         }
         assert_items_equal(alerts[0].keys(), alert_keys)
         assert_is_instance(alerts[0]['record'], dict)
         assert_is_instance(alerts[0]['outputs'], list)
+        assert_is_instance(alerts[0]['context'], dict)
 
         # test alert fields
         assert_is_instance(alerts[0]['rule_name'], str)
@@ -207,7 +209,8 @@ def cloudtrail_us_east_logs(rec):
             datatypes=[],
             logs=['test_log_type_json_nested'],
             outputs=['s3:sample_bucket'],
-            req_subkeys={'requestParameters': ['program']}
+            req_subkeys={'requestParameters': ['program']},
+            context={}
         )
 
         data = json.dumps({