Skip to content

Commit

Permalink
Splitting dispatch and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
javier_marcos committed Nov 14, 2017
1 parent f2ca160 commit 33c0f77
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 26 deletions.
114 changes: 88 additions & 26 deletions stream_alert/alert_processor/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
# pylint: disable=too-many-lines
from abc import abstractmethod
import cgi
from collections import OrderedDict
Expand Down Expand Up @@ -256,12 +257,12 @@ def _get_endpoint(base_url, endpoint):
"""
return os.path.join(base_url, endpoint)

def _check_exists_get_id(self, filter_str, target_url, headers, target_key):
def _check_exists_get_id(self, filter_str, url, headers, target_key):
"""Generic method to run a search in the PagerDuty REST API and return the id
of the first occurence from the results.
Args:
filter (str): The query filter to search for in the API
filter_str (str): The query filter to search for in the API
url (str): The url to send the requests to in the API
headers (dict): A dictionary containing header parameters
target_key (str): The key to extract in the returned results
Expand All @@ -273,14 +274,90 @@ def _check_exists_get_id(self, filter_str, target_url, headers, target_key):
params = {
'query': '"{}"'.format(filter_str)
}
resp = self._get_request(target_url, params, headers, False)
resp = self._get_request(url, params, headers, False)
if not self._check_http_response(resp):
return False

response = resp.json()

# If there are results, get the first occurence from the list
return response and response.get(target_key)[0]['id']
target = response.get(target_key, False)
if response and target:
return target[0]['id']

return False

def _user_verify(self, api_url, user, headers):
"""Method to verify the existance of an user with the API
Args:
api_url (str): Base URL of the API
user (str): User to query about in the API.
headers (dict): Headers used for API authentication
Returns:
dict or False: JSON object be used in the API call, containing the user_id
and user_reference. False if user is not found
"""
users_url = self._get_endpoint(api_url, self.USERS_ENDPOINT)
user_id = self._check_exists_get_id(user, users_url, headers,
self.USERS_ENDPOINT)
if not user_id:
LOGGER.info('User[%s] not found in %s', user, self.__service__)
return False

return {
'id': user_id,
'type': 'user_reference'
}

def _policy_verify(self, api_url, policy, headers):
"""Method to verify the existance of a escalation policy with the API
Args:
api_url (str): Base URL of the API
policy (str): Escalation policy to query about in the API
headers (dict): Headers used for API authentication
Returns:
dict: JSON object be used in the API call, containing the policy_id
and escalation_policy_reference
"""
policies_url = self._get_endpoint(api_url, self.POLICIES_ENDPOINT)
policy_id = self._check_exists_get_id(policy, policies_url, headers,
self.POLICIES_ENDPOINT)
if not policy_id:
LOGGER.info('Escalation Policy[%s] not found in %s', policy, self.__service__)
return False

return {
'id': policy_id,
'type': 'escalation_policy_reference'
}

def _service_verify(self, api_url, service, headers):
"""Method to verify the existance of a service with the API
Args:
api_url (str): Base URL of the API
service (str): Service to query about in the API
headers (dict): Headers used for API authentication
Returns:
dict: JSON object be used in the API call, containing the service_id
and the service_reference
"""
services_url = self._get_endpoint(api_url, self.SERVICES_ENDPOINT)
service_id = self._check_exists_get_id(service, services_url, headers,
self.SERVICES_ENDPOINT)
if not service_id:
LOGGER.info('Service[%s] not found in %s', service, self.__service__)
return False

return {
'id': service_id,
'type': 'service_reference'
}

def dispatch(self, **kwargs):
"""Send incident to Pagerduty Incidents API v2
Expand Down Expand Up @@ -313,29 +390,20 @@ def dispatch(self, **kwargs):
# Incident assignment goes in this order:
# Provided user -> provided policy -> default policy
if user_to_assign:
users_url = self._get_endpoint(creds['api'], self.USERS_ENDPOINT)
user_id = self._check_exists_get_id(user_to_assign, users_url, headers,
self.USERS_ENDPOINT)
if user_id:
user_assignee = self._user_verify(creds['api'], user_to_assign, headers)
if user_assignee:
assigned_key = 'assignments'
assigned_value = [{
'assignee' : {
'id': user_id,
'type': 'user_reference'}
}]
'assignee' : user_assignee}]
else:
user_to_assign = user_assignee

policy_to_assign = creds['escalation_policy']
if not user_to_assign and rule_context.get('assigned_policy'):
policy_to_assign = rule_context.get('assigned_policy')

policies_url = self._get_endpoint(creds['api'], self.POLICIES_ENDPOINT)
policy_id = self._check_exists_get_id(policy_to_assign, policies_url, headers,
self.POLICIES_ENDPOINT)
assigned_key = 'escalation_policy'
assigned_value = {
'id': policy_id,
'type': 'escalation_policy_reference'
}
assigned_value = self._policy_verify(creds['api'], policy_to_assign, headers)

# Start preparing the incident JSON blob to be sent to the API
incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name'])
Expand All @@ -344,13 +412,7 @@ def dispatch(self, **kwargs):
'details': kwargs['alert']['rule_description']
}
# We need to get the service id from the API
services_url = self._get_endpoint(creds['api'], self.SERVICES_ENDPOINT)
service_id = self._check_exists_get_id(creds['service_key'], services_url, headers,
self.SERVICES_ENDPOINT)
incident_service = {
'id': service_id,
'type': 'service_reference'
}
incident_service = self._service_verify(creds['api'], creds['service_key'], headers)
incident = {
'incident': {
'type': 'incident',
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/stream_alert_alert_processor/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,79 @@ def _teardown_dispatch(self):
"""Replace method with cached method"""
self.__dispatcher._get_default_properties = self.__backup_method

@patch('requests.get')
def test_check_exists_get_id(self, get_mock):
"""Check Exists Get Id - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"check": [{"id": "checked_id"}]}')
get_mock.return_value.json.return_value = json_check

checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', {}, 'check')
assert_equal(checked, 'checked_id')

@patch('requests.get')
def test_user_verify_success(self, get_mock):
"""User Verify Success - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"users": [{"id": "verified_user_id"}]}')
get_mock.return_value.json.return_value = json_check

user_verified = self.__dispatcher._user_verify('http://mock_url', 'valid_user', {})
assert_equal(user_verified['id'], 'verified_user_id')
assert_equal(user_verified['type'], 'user_reference')

@patch('requests.get')
def test_user_verify_fail(self, get_mock):
"""User Verify Fail - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"not_users": [{"not_id": "verified_user_id"}]}')
get_mock.return_value.json.return_value = json_check

user_verified = self.__dispatcher._user_verify('http://mock_url', 'valid_user', {})
assert_false(user_verified)

@patch('requests.get')
def test_policy_verify_success(self, get_mock):
"""Policy Verify Success - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}')
get_mock.return_value.json.return_value = json_check

policy_verified = self.__dispatcher._policy_verify('http://mock_url', 'valid_policy', {})
assert_equal(policy_verified['id'], 'verified_policy_id')
assert_equal(policy_verified['type'], 'escalation_policy_reference')

@patch('requests.get')
def test_policy_verify_fail(self, get_mock):
"""Policy Verify Fail - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"not_escalation_policies": [{"not_id": "verified_policy_id"}]}')
get_mock.return_value.json.return_value = json_check

policy_verified = self.__dispatcher._policy_verify('http://mock_url', 'valid_policy', {})
assert_false(policy_verified)

@patch('requests.get')
def test_service_verify_success(self, get_mock):
"""Service Verify Success - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"services": [{"id": "verified_service_id"}]}')
get_mock.return_value.json.return_value = json_check

service_verified = self.__dispatcher._service_verify('http://mock_url', 'valid_user', {})
assert_equal(service_verified['id'], 'verified_service_id')
assert_equal(service_verified['type'], 'service_reference')

@patch('requests.get')
def test_service_verify_fail(self, get_mock):
"""Service Verify Fail - PagerDutyIncidentOutput"""
get_mock.return_value.status_code = 200
json_check = json.loads('{"not_services": [{"not_id": "verified_service_id"}]}')
get_mock.return_value.json.return_value = json_check

service_verified = self.__dispatcher._service_verify('http://mock_url', 'valid_user', {})
assert_false(service_verified)

@patch('logging.Logger.info')
@patch('requests.post')
@patch('requests.get')
Expand Down

0 comments on commit 33c0f77

Please sign in to comment.