diff --git a/docs/source/rules.rst b/docs/source/rules.rst index 588218af3..d0ecd2643 100644 --- a/docs/source/rules.rst +++ b/docs/source/rules.rst @@ -153,6 +153,23 @@ Examples: req_subkeys={'columns':['port', 'protocol']}) ... +context +~~~~~~~~~~~ + +``context`` is an optional field to pass extra instructions to the alert processor on how to route the alert. It can be particularly helpful to pass data to an output. + +Example: + +.. code-block:: python + + # Context provided to the pagerduty-incident output + # with instructions to assign the incident to a user. + + @rule(logs=['osquery:differential'], + outputs=['pagerduty', 'aws-s3'], + context={'pagerduty-incident':{'assigned_user': 'valid_user'}}) + ... + Helpers ------- diff --git a/manage.py b/manage.py index 00406cdf6..43219e1bc 100755 --- a/manage.py +++ b/manage.py @@ -102,7 +102,8 @@ def _add_output_subparser(subparsers): output_parser.add_argument( '--service', choices=[ - 'aws-firehose', 'aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2', 'phantom', 'slack' + 'aws-firehose', 'aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2', + 'pagerduty-incident', 'phantom', 'slack' ], required=True, help=ARGPARSE_SUPPRESS) diff --git a/stream_alert/alert_processor/helpers.py b/stream_alert/alert_processor/helpers.py index d79f7af1a..885cd35e9 100644 --- a/stream_alert/alert_processor/helpers.py +++ b/stream_alert/alert_processor/helpers.py @@ -38,7 +38,8 @@ def validate_alert(alert): 'log_source', 'outputs', 'source_service', - 'source_entity' + 'source_entity', + 'context' } if not set(alert.keys()) == alert_keys: LOGGER.error('The alert object must contain the following keys: %s', @@ -53,6 +54,11 @@ def validate_alert(alert): LOGGER.error('The alert record must be a map (dict)') return False + elif key == 'context': + if not isinstance(alert['context'], dict): + LOGGER.error('The alert context must be a map (dict)') + return False + elif key == 'outputs': if not isinstance(alert[key], list): LOGGER.error( diff 
--git a/stream_alert/alert_processor/outputs.py b/stream_alert/alert_processor/outputs.py index 2e67acd67..1d5631d43 100644 --- a/stream_alert/alert_processor/outputs.py +++ b/stream_alert/alert_processor/outputs.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +# pylint: disable=too-many-lines from abc import abstractmethod import cgi from collections import OrderedDict @@ -64,9 +65,7 @@ def _get_default_properties(cls): Returns: dict: Contains various default items for this output (ie: url) """ - return { - 'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json' - } + return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'} def get_user_defined_properties(self): """Get properties that must be asssigned by the user when configuring a new PagerDuty @@ -133,9 +132,7 @@ def _get_default_properties(cls): Returns: dict: Contains various default items for this output (ie: url) """ - return { - 'url': 'https://events.pagerduty.com/v2/enqueue' - } + return {'url': 'https://events.pagerduty.com/v2/enqueue'} def get_user_defined_properties(self): """Get properties that must be asssigned by the user when configuring a new PagerDuty @@ -198,6 +195,260 @@ def dispatch(self, **kwargs): return self._log_status(success) +@output +class PagerDutyIncidentOutput(StreamOutputBase): + """PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2""" + __service__ = 'pagerduty-incident' + INCIDENTS_ENDPOINT = 'incidents' + USERS_ENDPOINT = 'users' + POLICIES_ENDPOINT = 'escalation_policies' + SERVICES_ENDPOINT = 'services' + + def __init__(self, *args, **kwargs): + StreamOutputBase.__init__(self, *args, **kwargs) + self._base_url = None + self._headers = None + self._escalation_policy = None + + @classmethod + def _get_default_properties(cls): + """Get the standard url used for PagerDuty Incidents API v2. 
This value is the same for + everyone, so is hard-coded here and does not need to be configured by the user + + Returns: + dict: Contains various default items for this output (ie: url) + """ + return {'api': 'https://api.pagerduty.com'} + + def get_user_defined_properties(self): + """Get properties that must be assigned by the user when configuring a new PagerDuty + event output. This should be sensitive or unique information for this use-case that + needs to come from the user. + + Every output should return a dict that contains a 'descriptor' with a description of the + integration being configured. + + PagerDuty also requires a routing_key that represents this integration. This + value should be masked during input and is a credential requirement. + + Returns: + OrderedDict: Contains various OutputProperty items + """ + return OrderedDict([ + ('descriptor', + OutputProperty(description='a short and unique descriptor for this ' + 'PagerDuty integration')), + ('token', + OutputProperty(description='the token for this PagerDuty integration', + mask_input=True, + cred_requirement=True)), + ('service_key', + OutputProperty(description='the service key for this PagerDuty integration', + mask_input=True, + cred_requirement=True)), + ('escalation_policy', + OutputProperty(description='the name of the default escalation policy')) + ]) + + @staticmethod + def _get_endpoint(base_url, endpoint): + """Helper to get the full url for a PagerDuty Incidents endpoint. + + Args: + base_url (str): Base URL for the API + endpoint (str): Endpoint that we want the full URL for + + Returns: + str: Full URL of the provided endpoint + """ + return os.path.join(base_url, endpoint) + + def _check_exists_get_id(self, filter_str, url, target_key): + """Generic method to run a search in the PagerDuty REST API and return the id + of the first occurrence from the results. 
+ + Args: + filter_str (str): The query filter to search for in the API + url (str): The url to send the requests to in the API + target_key (str): The key to extract in the returned results + + Returns: + str: ID of the targeted element that matches the provided filter or + False if a matching element does not exists. + """ + params = { + 'query': '{}'.format(filter_str) + } + resp = self._get_request(url, params, self._headers, False) + + if not self._check_http_response(resp): + return False + + response = resp.json() + if not response: + return False + + # If there are results, get the first occurence from the list + return response[target_key][0]['id'] if target_key in response else False + + def _user_verify(self, user): + """Method to verify the existance of an user with the API + + Args: + user (str): User to query about in the API. + + Returns: + dict or False: JSON object be used in the API call, containing the user_id + and user_reference. False if user is not found + """ + users_url = self._get_endpoint(self._base_url, self.USERS_ENDPOINT) + + return self._item_verify(users_url, user, self.USERS_ENDPOINT, 'user_reference') + + def _policy_verify(self, policy, default_policy): + """Method to verify the existance of a escalation policy with the API + + Args: + policy (str): Escalation policy to query about in the API + default_policy (str): Escalation policy to use if the first one is not verified + + Returns: + dict: JSON object be used in the API call, containing the policy_id + and escalation_policy_reference + """ + policies_url = self._get_endpoint(self._base_url, self.POLICIES_ENDPOINT) + + verified = self._item_verify(policies_url, policy, self.POLICIES_ENDPOINT, + 'escalation_policy_reference') + + # If the escalation policy provided is not verified in the API, use the default + if verified: + return verified + + return self._item_verify(policies_url, default_policy, self.POLICIES_ENDPOINT, + 'escalation_policy_reference') + + def 
_service_verify(self, service): + """Method to verify the existance of a service with the API + + Args: + service (str): Service to query about in the API + + Returns: + dict: JSON object be used in the API call, containing the service_id + and the service_reference + """ + services_url = self._get_endpoint(self._base_url, self.SERVICES_ENDPOINT) + + return self._item_verify(services_url, service, self.SERVICES_ENDPOINT, 'service_reference') + + def _item_verify(self, item_url, item_str, item_key, item_type): + """Method to verify the existance of an item with the API + + Args: + item_url (str): URL of the API Endpoint within the API to query + item_str (str): Service to query about in the API + item_key (str): Key to be extracted from search results + item_type (str): Type of item reference to be returned + + Returns: + dict: JSON object be used in the API call, containing the item id + and the item reference, or False if it fails + """ + item_id = self._check_exists_get_id(item_str, item_url, item_key) + if not item_id: + LOGGER.info('%s not found in %s, %s', item_str, item_key, self.__service__) + return False + + return { + 'id': item_id, + 'type': item_type + } + + def _incident_assignment(self, context): + """Method to determine if the incident gets assigned to a user or an escalation policy + + Args: + context (dict): Context provided in the alert record + + Returns: + tuple: assigned_key (str), assigned_value (dict to assign incident to an escalation + policy or array of dicts to assign incident to users) + """ + # Check if a user to assign the incident is provided + user_to_assign = context.get('assigned_user', False) + + # If provided, verify the user and get the id from API + if user_to_assign: + user_assignee = self._user_verify(user_to_assign) + # User is verified, return tuple + if user_assignee: + return 'assignments', [{'assignee': user_assignee}] + + # If escalation policy was not provided, use default one + policy_to_assign = 
context.get('assigned_policy', self._escalation_policy) + + # Verify escalation policy, return tuple + return 'escalation_policy', self._policy_verify(policy_to_assign, self._escalation_policy) + + def dispatch(self, **kwargs): + """Send incident to Pagerduty Incidents API v2 + + Keyword Args: + **kwargs: consists of any combination of the following items: + descriptor (str): Service descriptor (ie: slack channel, pd integration) + rule_name (str): Name of the triggered rule + alert (dict): Alert relevant to the triggered rule + alert['context'] (dict): Provides user or escalation policy + """ + creds = self._load_creds(kwargs['descriptor']) + if not creds: + return self._log_status(False) + + # Preparing headers for API calls + self._headers = { + 'Authorization': 'Token token={}'.format(creds['token']), + 'Accept': 'application/vnd.pagerduty+json;version=2' + } + + # Cache base_url + self._base_url = creds['api'] + + # Cache default escalation policy + self._escalation_policy = creds['escalation_policy'] + + # Extracting context data to assign the incident + rule_context = kwargs['alert'].get('context', {}) + if rule_context: + rule_context = rule_context.get(self.__service__, {}) + + # Incident assignment goes in this order: + # Provided user -> provided policy -> default policy + assigned_key, assigned_value = self._incident_assignment(rule_context) + + # Start preparing the incident JSON blob to be sent to the API + incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name']) + incident_body = { + 'type': 'incident_body', + 'details': kwargs['alert']['rule_description'] + } + # We need to get the service id from the API + incident_service = self._service_verify(creds['service_key']) + incident = { + 'incident': { + 'type': 'incident', + 'title': incident_title, + 'service': incident_service, + 'body': incident_body + }, + assigned_key: assigned_value + } + incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT) + 
resp = self._post_request(incidents_url, incident, self._headers, True) + success = self._check_http_response(resp) + + return self._log_status(success) + @output class PhantomOutput(StreamOutputBase): """PhantomOutput handles all alert dispatching for Phantom""" diff --git a/stream_alert/rule_processor/rules_engine.py b/stream_alert/rule_processor/rules_engine.py index c51ee70e0..9c177aa23 100644 --- a/stream_alert/rule_processor/rules_engine.py +++ b/stream_alert/rule_processor/rules_engine.py @@ -28,7 +28,8 @@ 'datatypes', 'logs', 'outputs', - 'req_subkeys']) + 'req_subkeys', + 'context']) class StreamRules(object): @@ -70,6 +71,7 @@ def decorator(rule): matchers = opts.get('matchers') datatypes = opts.get('datatypes') req_subkeys = opts.get('req_subkeys') + context = opts.get('context', {}) if not (logs or datatypes): LOGGER.error( @@ -92,7 +94,8 @@ def decorator(rule): datatypes, logs, outputs, - req_subkeys) + req_subkeys, + context) return rule return decorator @@ -387,7 +390,8 @@ def process(cls, input_payload): 'log_type': payload.type, 'outputs': rule.outputs, 'source_service': payload.service(), - 'source_entity': payload.entity} + 'source_entity': payload.entity, + 'context': rule.context} alerts.append(alert) return alerts diff --git a/stream_alert_cli/test.py b/stream_alert_cli/test.py index 6f235f768..c03bff054 100644 --- a/stream_alert_cli/test.py +++ b/stream_alert_cli/test.py @@ -762,6 +762,14 @@ def setup_outputs(self, alert): helpers.put_mock_creds(output_name, creds, self.secrets_bucket, 'us-east-1', self.kms_alias) + elif service == 'pagerduty-incident': + output_name = '{}/{}'.format(service, descriptor) + creds = {'token': '247b97499078a015cc6c586bc0a92de6', + 'service_key': '247b97499078a015cc6c586bc0a92de6', + 'escalation_policy': '247b97499078a015cc6c586bc0a92de6'} + helpers.put_mock_creds(output_name, creds, self.secrets_bucket, + 'us-east-1', self.kms_alias) + elif service == 'phantom': output_name = '{}/{}'.format(service, descriptor) 
creds = {'ph_auth_token': '6c586bc047b9749a92de29078a015cc6', diff --git a/tests/unit/stream_alert_alert_processor/helpers.py b/tests/unit/stream_alert_alert_processor/helpers.py index 9f09816bb..71ac635c9 100644 --- a/tests/unit/stream_alert_alert_processor/helpers.py +++ b/tests/unit/stream_alert_alert_processor/helpers.py @@ -72,7 +72,15 @@ def get_random_alert(key_count, rule_name, omit_rule_desc=False): return alert -def get_alert(index=0): +def get_alert(index=0, context=None): + """This function generates a sample alert for testing purposes + + Args: + index (int): test_index value (0 by default) + context(dict): context dictionary (None by default) + """ + context = context or {} + return { 'record': { 'test_index': index, @@ -90,6 +98,7 @@ def get_alert(index=0): 'outputs': [ 'slack:unit_test_channel' ], + 'context': context, 'source_service': 's3', 'source_entity': 'corp-prefix.prod.cb.region', 'log_type': 'json', diff --git a/tests/unit/stream_alert_alert_processor/test_outputs.py b/tests/unit/stream_alert_alert_processor/test_outputs.py index 33ff7c39d..e0a049c4b 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs.py @@ -14,11 +14,12 @@ limitations under the License. 
""" # pylint: disable=protected-access +# pylint: disable=too-many-lines from collections import Counter, OrderedDict import json import boto3 -from mock import call, patch +from mock import call, patch, PropertyMock from moto import mock_s3, mock_kms, mock_lambda, mock_kinesis from nose.tools import ( assert_equal, @@ -268,6 +269,456 @@ def test_dispatch_bad_descriptor(self, log_error_mock): log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) +class TestPagerDutyIncidentOutput(object): + """Test class for PagerDutyIncidentOutput""" + @classmethod + def setup_class(cls): + """Setup the class before any methods""" + cls.__service = 'pagerduty-incident' + cls.__descriptor = 'unit_test_pagerduty-incident' + cls.__backup_method = None + cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, + REGION, + FUNCTION_NAME, + CONFIG) + + @classmethod + def teardown_class(cls): + """Teardown the class after all methods""" + cls.__dispatcher = None + + def test_get_default_properties(self): + """Get Default Properties - PagerDutyIncidentOutput""" + props = self.__dispatcher._get_default_properties() + assert_equal(len(props), 1) + assert_equal(props['api'], + 'https://api.pagerduty.com') + + def test_get_endpoint(self): + """Get Endpoint - PagerDutyIncidentOutput""" + props = self.__dispatcher._get_default_properties() + endpoint = self.__dispatcher._get_endpoint(props['api'], 'testtest') + assert_equal(endpoint, + 'https://api.pagerduty.com/testtest') + + def _setup_dispatch(self, context=None): + """Helper for setting up PagerDutyIncidentOutput dispatch""" + remove_temp_secrets() + + # Cache the _get_default_properties and set it to return None + self.__backup_method = self.__dispatcher._get_default_properties + self.__dispatcher._get_default_properties = lambda: None + + output_name = self.__dispatcher.output_cred_name(self.__descriptor) + + creds = {'api': 'https://api.pagerduty.com', + 'token': 'mocked_token', + 'service_key': 
'mocked_service_key', + 'escalation_policy': 'mocked_escalation_policy'} + + put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) + + return get_alert(0, context) + + def _teardown_dispatch(self): + """Replace method with cached method""" + self.__dispatcher._get_default_properties = self.__backup_method + + @patch('requests.get') + def test_check_exists_get_id(self, get_mock): + """Check Exists Get Id - PagerDutyIncidentOutput""" + # /check + get_mock.return_value.status_code = 200 + json_check = json.loads('{"check": [{"id": "checked_id"}]}') + get_mock.return_value.json.return_value = json_check + + checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + assert_equal(checked, 'checked_id') + + @patch('requests.get') + def test_check_exists_get_id_fail(self, get_mock): + """Check Exists Get Id Fail - PagerDutyIncidentOutput""" + # /check + get_mock.return_value.status_code = 200 + json_check = json.loads('{}') + get_mock.return_value.json.return_value = json_check + + checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + assert_false(checked) + + @patch('requests.get') + def test_user_verify_success(self, get_mock): + """User Verify Success - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"users": [{"id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_check + + user_verified = self.__dispatcher._user_verify('valid_user') + assert_equal(user_verified['id'], 'verified_user_id') + assert_equal(user_verified['type'], 'user_reference') + + @patch('requests.get') + def test_user_verify_fail(self, get_mock): + """User Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_users": [{"not_id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_check + + user_verified = self.__dispatcher._user_verify('valid_user') + 
assert_false(user_verified) + + @patch('requests.get') + def test_policy_verify_success_no_default(self, get_mock): + """Policy Verify Success (No Default) - PagerDutyIncidentOutput""" + # /escalation_policies + get_mock.return_value.status_code = 200 + json_check = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + get_mock.return_value.json.return_value = json_check + + policy_verified = self.__dispatcher._policy_verify('valid_policy', '') + assert_equal(policy_verified['id'], 'good_policy_id') + assert_equal(policy_verified['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_policy_verify_success_default(self, get_mock): + """Policy Verify Success (Default) - PagerDutyIncidentOutput""" + # /escalation_policies + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_check_bad = json.loads('{"no_escalation_policies": [{"id": "bad_policy_id"}]}') + json_check_good = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + get_mock.return_value.json.side_effect = [json_check_bad, json_check_good] + + policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + assert_equal(policy_verified['id'], 'good_policy_id') + assert_equal(policy_verified['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_policy_verify_fail_default(self, get_mock): + """Policy Verify Fail (Default) - PagerDutyIncidentOutput""" + # /not_escalation_policies + type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400]) + json_check_bad = json.loads('{"escalation_policies": [{"id": "bad_policy_id"}]}') + json_check_bad_default = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + get_mock.return_value.json.side_effect = [json_check_bad, json_check_bad_default] + policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + assert_false(policy_verified) + + @patch('requests.get') + def 
test_policy_verify_fail_no_default(self, get_mock): + """Policy Verify Fail (No Default) - PagerDutyIncidentOutput""" + # /not_escalation_policies + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_escalation_policies": [{"not_id": "verified_policy_id"}]}') + get_mock.return_value.json.return_value = json_check + + policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + assert_false(policy_verified) + + @patch('requests.get') + def test_service_verify_success(self, get_mock): + """Service Verify Success - PagerDutyIncidentOutput""" + # /services + get_mock.return_value.status_code = 200 + json_check = json.loads('{"services": [{"id": "verified_service_id"}]}') + get_mock.return_value.json.return_value = json_check + + service_verified = self.__dispatcher._service_verify('valid_service') + assert_equal(service_verified['id'], 'verified_service_id') + assert_equal(service_verified['type'], 'service_reference') + + @patch('requests.get') + def test_service_verify_fail(self, get_mock): + """Service Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_services": [{"not_id": "verified_service_id"}]}') + get_mock.return_value.json.return_value = json_check + + service_verified = self.__dispatcher._service_verify('valid_service') + assert_false(service_verified) + + @patch('requests.get') + def test_item_verify_success(self, get_mock): + """Item Verify Success - PagerDutyIncidentOutput""" + # /items + get_mock.return_value.status_code = 200 + json_check = json.loads('{"items": [{"id": "verified_item_id"}]}') + get_mock.return_value.json.return_value = json_check + + item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', + 'items', 'item_reference') + assert_equal(item_verified['id'], 'verified_item_id') + assert_equal(item_verified['type'], 'item_reference') + + @patch('requests.get') + def test_incident_assignment_user(self, 
get_mock): + """Incident Assignment User - PagerDutyIncidentOutput""" + context = {'assigned_user': 'user_to_assign'} + get_mock.return_value.status_code = 200 + json_user = json.loads('{"users": [{"id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_user + + assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + + assert_equal(assigned_key, 'assignments') + assert_equal(assigned_value[0]['assignee']['id'], 'verified_user_id') + assert_equal(assigned_value[0]['assignee']['type'], 'user_reference') + + @patch('requests.get') + def test_incident_assignment_policy_no_default(self, get_mock): + """Incident Assignment Policy (No Default) - PagerDutyIncidentOutput""" + context = {'assigned_policy': 'policy_to_assign'} + get_mock.return_value.status_code = 200 + json_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + get_mock.return_value.json.return_value = json_policy + + assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + + assert_equal(assigned_key, 'escalation_policy') + assert_equal(assigned_value['id'], 'verified_policy_id') + assert_equal(assigned_value['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_incident_assignment_policy_default(self, get_mock): + """Incident Assignment Policy (Default) - PagerDutyIncidentOutput""" + context = {'assigned_policy': 'bad_invalid_policy_to_assign'} + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_bad_policy = json.loads('{"not_escalation_policies": [{"id": "bad_policy_id"}]}') + json_good_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy] + + assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + + assert_equal(assigned_key, 'escalation_policy') + assert_equal(assigned_value['id'], 'verified_policy_id') + 
assert_equal(assigned_value['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_item_verify_fail(self, get_mock): + """Item Verify Fail - PagerDutyIncidentOutput""" + # /not_items + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_items": [{"not_id": "verified_item_id"}]}') + get_mock.return_value.json.return_value = json_check + + item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', + 'items', 'item_reference') + assert_false(item_verified) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_good_user(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Good User""" + ctx = { + 'pagerduty-incident': { + 'assigned_user': 'valid_user' + } + } + alert = self._setup_dispatch(context=ctx) + + # /users, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_user = json.loads('{"users": [{"id": "valid_user_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_user, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_good_policy(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Good Policy""" + ctx = { + 'pagerduty-incident': { + 'assigned_policy': 'valid_policy' + } + } + alert = self._setup_dispatch(context=ctx) + + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_policy = 
json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_bad_user(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Bad User""" + ctx = { + 'pagerduty-incident': { + 'assigned_user': 'invalid_user' + } + } + alert = self._setup_dispatch(context=ctx) + + # /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) + json_user = json.loads('{"not_users": [{"id": "user_id"}]}') + json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_no_context(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - No Context""" + alert = self._setup_dispatch() + + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_policy = 
json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_error_mock): + """PagerDutyIncidentOutput dispatch failure - No User, Bad Policy, Bad Service""" + alert = self._setup_dispatch() + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400, 400]) + json_empty = json.loads('{}') + get_mock.return_value.json.side_effect = [json_empty, json_empty, json_empty] + + # /incidents + post_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Bad Policy""" + ctx = { + 'pagerduty-incident': { + 'assigned_policy': 'valid_policy' + } + } + alert = self._setup_dispatch(context=ctx) + # /escalation_policies, /services + get_mock.return_value.side_effect = [400, 200, 200] + type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 200, 200]) + json_bad_policy = json.loads('{}') + json_good_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = 
json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_error_mock): + """PagerDutyIncidentOutput dispatch - Bad Dispatch""" + alert = self._setup_dispatch() + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.error') + @mock_s3 + @mock_kms + def test_dispatch_bad_descriptor(self, log_error_mock): + """PagerDutyIncidentOutput dispatch - Bad Descriptor""" + alert = self._setup_dispatch() + self.__dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + @mock_s3 @mock_kms @@ -410,7 +861,7 @@ def test_dispatch_failure(self, post_mock, get_mock, log_mock): get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = json.loads('{"count": 0, "data": []}') # _setup_container, dispatch - 
post_mock.return_value.status_code.side_effect = [200, 400] + type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400]) json_id = json.loads('{"id": 1948}') json_error = json.loads('{"message": "error message", "errors": ["error1"]}') post_mock.return_value.json.return_value.side_effect = [json_id, json_error] diff --git a/tests/unit/stream_alert_rule_processor/test_rules_engine.py b/tests/unit/stream_alert_rule_processor/test_rules_engine.py index d882cbb39..14f61b896 100644 --- a/tests/unit/stream_alert_rule_processor/test_rules_engine.py +++ b/tests/unit/stream_alert_rule_processor/test_rules_engine.py @@ -101,11 +101,13 @@ def alert_format_test(rec): # pylint: disable=unused-variable 'log_source', 'outputs', 'source_service', - 'source_entity' + 'source_entity', + 'context' } assert_items_equal(alerts[0].keys(), alert_keys) assert_is_instance(alerts[0]['record'], dict) assert_is_instance(alerts[0]['outputs'], list) + assert_is_instance(alerts[0]['context'], dict) # test alert fields assert_is_instance(alerts[0]['rule_name'], str) @@ -207,7 +209,8 @@ def cloudtrail_us_east_logs(rec): datatypes=[], logs=['test_log_type_json_nested'], outputs=['s3:sample_bucket'], - req_subkeys={'requestParameters': ['program']} + req_subkeys={'requestParameters': ['program']}, + context={} ) data = json.dumps({