From 33c0f770be61beebff95c8a8c080ff2e54ba2db7 Mon Sep 17 00:00:00 2001 From: javier_marcos Date: Tue, 14 Nov 2017 07:30:38 -0800 Subject: [PATCH] Splitting dispatch and more tests --- stream_alert/alert_processor/outputs.py | 114 ++++++++++++++---- .../test_outputs.py | 73 +++++++++++ 2 files changed, 161 insertions(+), 26 deletions(-) diff --git a/stream_alert/alert_processor/outputs.py b/stream_alert/alert_processor/outputs.py index 5649c5a0c..736d6c68c 100644 --- a/stream_alert/alert_processor/outputs.py +++ b/stream_alert/alert_processor/outputs.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +# pylint: disable=too-many-lines from abc import abstractmethod import cgi from collections import OrderedDict @@ -256,12 +257,12 @@ def _get_endpoint(base_url, endpoint): """ return os.path.join(base_url, endpoint) - def _check_exists_get_id(self, filter_str, target_url, headers, target_key): + def _check_exists_get_id(self, filter_str, url, headers, target_key): """Generic method to run a search in the PagerDuty REST API and return the id of the first occurence from the results. Args: - filter (str): The query filter to search for in the API + filter_str (str): The query filter to search for in the API url (str): The url to send the requests to in the API headers (dict): A dictionary containing header parameters target_key (str): The key to extract in the returned results @@ -273,14 +274,90 @@ def _check_exists_get_id(self, filter_str, target_url, headers, target_key): params = { 'query': '"{}"'.format(filter_str) } - resp = self._get_request(target_url, params, headers, False) + resp = self._get_request(url, params, headers, False) if not self._check_http_response(resp): return False response = resp.json() # If there are results, get the first occurence from the list - return response and response.get(target_key)[0]['id'] + target = response.get(target_key, False) + if response and target: + return target[0]['id'] + + return False + + def _user_verify(self, api_url, user, headers): + """Method to verify the existance of an user with the API + + Args: + api_url (str): Base URL of the API + user (str): User to query about in the API. + headers (dict): Headers used for API authentication + + Returns: + dict or False: JSON object be used in the API call, containing the user_id + and user_reference. False if user is not found + """ + users_url = self._get_endpoint(api_url, self.USERS_ENDPOINT) + user_id = self._check_exists_get_id(user, users_url, headers, + self.USERS_ENDPOINT) + if not user_id: + LOGGER.info('User[%s] not found in %s', user, self.__service__) + return False + + return { + 'id': user_id, + 'type': 'user_reference' + } + + def _policy_verify(self, api_url, policy, headers): + """Method to verify the existance of a escalation policy with the API + + Args: + api_url (str): Base URL of the API + policy (str): Escalation policy to query about in the API + headers (dict): Headers used for API authentication + + Returns: + dict: JSON object be used in the API call, containing the policy_id + and escalation_policy_reference + """ + policies_url = self._get_endpoint(api_url, self.POLICIES_ENDPOINT) + policy_id = self._check_exists_get_id(policy, policies_url, headers, + self.POLICIES_ENDPOINT) + if not policy_id: + LOGGER.info('Escalation Policy[%s] not found in %s', policy, self.__service__) + return False + + return { + 'id': policy_id, + 'type': 'escalation_policy_reference' + } + + def _service_verify(self, api_url, service, headers): + """Method to verify the existance of a service with the API + + Args: + api_url (str): Base URL of the API + service (str): Service to query about in the API + headers (dict): Headers used for API authentication + + Returns: + dict: JSON object be used in the API call, containing the service_id + and the service_reference + """ + services_url = self._get_endpoint(api_url, self.SERVICES_ENDPOINT) + service_id = self._check_exists_get_id(service, services_url, headers, + self.SERVICES_ENDPOINT) + if not service_id: + LOGGER.info('Service[%s] not found in %s', service, self.__service__) + return False + + return { + 'id': service_id, + 'type': 'service_reference' + } def dispatch(self, **kwargs): """Send incident to Pagerduty Incidents API v2 @@ -313,29 +390,20 @@ def dispatch(self, **kwargs): # Incident assignment goes in this order: # Provided user -> provided policy -> default policy if user_to_assign: - users_url = self._get_endpoint(creds['api'], self.USERS_ENDPOINT) - user_id = self._check_exists_get_id(user_to_assign, users_url, headers, - self.USERS_ENDPOINT) - if user_id: + user_assignee = self._user_verify(creds['api'], user_to_assign, headers) + if user_assignee: assigned_key = 'assignments' assigned_value = [{ - 'assignee' : { - 'id': user_id, - 'type': 'user_reference'} - }] + 'assignee' : user_assignee}] + else: + user_to_assign = user_assignee policy_to_assign = creds['escalation_policy'] if not user_to_assign and rule_context.get('assigned_policy'): policy_to_assign = rule_context.get('assigned_policy') - policies_url = self._get_endpoint(creds['api'], self.POLICIES_ENDPOINT) - policy_id = self._check_exists_get_id(policy_to_assign, policies_url, headers, - self.POLICIES_ENDPOINT) assigned_key = 'escalation_policy' - assigned_value = { - 'id': policy_id, - 'type': 'escalation_policy_reference' - } + assigned_value = self._policy_verify(creds['api'], policy_to_assign, headers) # Start preparing the incident JSON blob to be sent to the API incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name']) @@ -344,13 +412,7 @@ def dispatch(self, **kwargs): 'details': kwargs['alert']['rule_description'] } # We need to get the service id from the API - services_url = self._get_endpoint(creds['api'], self.SERVICES_ENDPOINT) - service_id = self._check_exists_get_id(creds['service_key'], services_url, headers, - self.SERVICES_ENDPOINT) - incident_service = { - 'id': service_id, - 'type': 'service_reference' - } + incident_service = self._service_verify(creds['api'], creds['service_key'], headers) incident = { 'incident': { 'type': 'incident', diff --git a/tests/unit/stream_alert_alert_processor/test_outputs.py b/tests/unit/stream_alert_alert_processor/test_outputs.py index bc77365cc..6fcfc0d44 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs.py @@ -324,6 +324,79 @@ def _teardown_dispatch(self): """Replace method with cached method""" self.__dispatcher._get_default_properties = self.__backup_method + @patch('requests.get') + def test_check_exists_get_id(self, get_mock): + """Check Exists Get Id - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"check": [{"id": "checked_id"}]}') + get_mock.return_value.json.return_value = json_check + + checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', {}, 'check') + assert_equal(checked, 'checked_id') + + @patch('requests.get') + def test_user_verify_success(self, get_mock): + """User Verify Success - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"users": [{"id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_check + + user_verified = self.__dispatcher._user_verify('http://mock_url', 'valid_user', {}) + assert_equal(user_verified['id'], 'verified_user_id') + assert_equal(user_verified['type'], 'user_reference') + + @patch('requests.get') + def test_user_verify_fail(self, get_mock): + """User Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_users": [{"not_id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_check + + user_verified = self.__dispatcher._user_verify('http://mock_url', 'valid_user', {}) + assert_false(user_verified) + + @patch('requests.get') + def test_policy_verify_success(self, get_mock): + """Policy Verify Success - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + get_mock.return_value.json.return_value = json_check + + policy_verified = self.__dispatcher._policy_verify('http://mock_url', 'valid_policy', {}) + assert_equal(policy_verified['id'], 'verified_policy_id') + assert_equal(policy_verified['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_policy_verify_fail(self, get_mock): + """Policy Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_escalation_policies": [{"not_id": "verified_policy_id"}]}') + get_mock.return_value.json.return_value = json_check + + policy_verified = self.__dispatcher._policy_verify('http://mock_url', 'valid_policy', {}) + assert_false(policy_verified) + + @patch('requests.get') + def test_service_verify_success(self, get_mock): + """Service Verify Success - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"services": [{"id": "verified_service_id"}]}') + get_mock.return_value.json.return_value = json_check + + service_verified = self.__dispatcher._service_verify('http://mock_url', 'valid_user', {}) + assert_equal(service_verified['id'], 'verified_service_id') + assert_equal(service_verified['type'], 'service_reference') + + @patch('requests.get') + def test_service_verify_fail(self, get_mock): + """Service Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_services": [{"not_id": "verified_service_id"}]}') + get_mock.return_value.json.return_value = json_check + + service_verified = self.__dispatcher._service_verify('http://mock_url', 'valid_user', {}) + assert_false(service_verified) + @patch('logging.Logger.info') @patch('requests.post') @patch('requests.get')