Skip to content

Commit

Permalink
More tests and addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
javier_marcos committed Nov 13, 2017
1 parent b1966aa commit 5988364
Show file tree
Hide file tree
Showing 4 changed files with 481 additions and 7 deletions.
17 changes: 17 additions & 0 deletions docs/source/rules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ Examples:
req_subkeys={'columns':['port', 'protocol']})
...
context
~~~~~~~~~~~

``context`` is an optional field to pass extra instructions to the alert processor on how to route the alert. It can be particulary helpful to pass data to an output.

Example:

.. code-block:: python
# Context provided to the pagerduty-incident output
# with instructions to assign the incident to a user.
@rule(logs=['osquery:differential'],
outputs=['pagerduty', 'aws-s3'],
context={'pagerduty-incident':{'assigned_user': 'valid_user'}})
...
Helpers
-------
Expand Down
189 changes: 183 additions & 6 deletions stream_alert/alert_processor/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def _get_default_properties(cls):
Returns:
dict: Contains various default items for this output (ie: url)
"""
return {
'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'
}
return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'}

def get_user_defined_properties(self):
"""Get properties that must be asssigned by the user when configuring a new PagerDuty
Expand Down Expand Up @@ -133,9 +131,7 @@ def _get_default_properties(cls):
Returns:
dict: Contains various default items for this output (ie: url)
"""
return {
'url': 'https://events.pagerduty.com/v2/enqueue'
}
return {'url': 'https://events.pagerduty.com/v2/enqueue'}

def get_user_defined_properties(self):
"""Get properties that must be asssigned by the user when configuring a new PagerDuty
Expand Down Expand Up @@ -198,6 +194,187 @@ def dispatch(self, **kwargs):

return self._log_status(success)

@output
class PagerDutyIncidentOutput(StreamOutputBase):
"""PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2"""
__service__ = 'pagerduty-incident'
INCIDENTS_ENDPOINT = 'incidents'
USERS_ENDPOINT = 'users'
POLICIES_ENDPOINT = 'escalation_policies'
SERVICES_ENDPOINT = 'services'

@classmethod
def _get_default_properties(cls):
"""Get the standard url used for PagerDuty Incidents API v2. This value the same for
everyone, so is hard-coded here and does not need to be configured by the user
Returns:
dict: Contains various default items for this output (ie: url)
"""
return {'api': 'https://api.pagerduty.com'}

def get_user_defined_properties(self):
"""Get properties that must be asssigned by the user when configuring a new PagerDuty
event output. This should be sensitive or unique information for this use-case that
needs to come from the user.
Every output should return a dict that contains a 'descriptor' with a description of the
integration being configured.
PagerDuty also requires a routing_key that represents this integration. This
value should be masked during input and is a credential requirement.
Returns:
OrderedDict: Contains various OutputProperty items
"""
return OrderedDict([
('descriptor',
OutputProperty(description='a short and unique descriptor for this '
'PagerDuty integration')),
('token',
OutputProperty(description='the token for this PagerDuty integration',
mask_input=True,
cred_requirement=True)),
('service_key',
OutputProperty(description='the service key for this PagerDuty integration',
mask_input=True,
cred_requirement=True)),
('escalation_policy',
OutputProperty(description='the name of the default escalation policy'))
])

@staticmethod
def _get_endpoint(base_url, endpoint):
"""Helper to get the full url for a PagerDuty Incidents endpoint.
Args:
base_url (str): Base URL for the API
endpoint (str): Endpoint that we want the full URL for
Returns:
str: Full URL of the provided endpoint
"""
return os.path.join(base_url, endpoint)

def _check_exists_get_id(self, filter_str, target_url, headers, target_key):
"""Generic method to run a search in the PagerDuty REST API and return the id
of the first occurence from the results.
Args:
filter (str): The query filter to search for in the API
url (str): The url to send the requests to in the API
headers (dict): A dictionary containing header parameters
target_key (str): The key to extract in the returned results
Returns:
str: ID of the targeted element that matches the provided filter or
False if a matching element does not exists.
"""
params = {
'query': '"{}"'.format(filter_str)
}
resp = self._get_request(target_url, params, headers, False)
if not self._check_http_response(resp):
return False

response = resp.json()

# If there are results, get the first occurence from the list
return response and response.get(target_key)[0]['id']

def dispatch(self, **kwargs):
"""Send incident to Pagerduty Incidents API v2
Keyword Args:
**kwargs: consists of any combination of the following items:
descriptor (str): Service descriptor (ie: slack channel, pd integration)
rule_name (str): Name of the triggered rule
alert (dict): Alert relevant to the triggered rule
alert['context'] (dict): Provides user or escalation policy
"""
creds = self._load_creds(kwargs['descriptor'])
if not creds:
return self._log_status(False)

# Preparing headers for API calls
headers = {
'Authorization': 'Token token={}'.format(creds['token']),
'Accept': 'application/vnd.pagerduty+json;version=2'
}

# Extracting context data to assign the incident
rule_context = kwargs['alert'].get('context', {})
if rule_context:
rule_context = rule_context[self.__service__]

# Check if a user to assign the incident is provided
user_to_assign = rule_context.get('assigned_user', False)

# Incident assignment goes in this order:
# Provided user -> provided policy -> default policy
if user_to_assign:
users_url = os.path.join(creds['api'], self.USERS_ENDPOINT)
user_id = self._check_exists_get_id(user_to_assign,
users_url, headers, self.USERS_ENDPOINT)
if user_id:
assigned_key = 'assignments'
assigned_value = [{
'assignee' : {
'id': '',
'type': 'user_reference'}
}]
# If the user retrieval did not succeed, default to policies
else:
user_to_assign = False

if not user_to_assign and rule_context.get('assigned_policy'):
policy_to_assign = rule_context.get('assigned_policy')
else:
policy_to_assign = creds['escalation_policy']

policies_url = os.path.join(creds['api'], self.POLICIES_ENDPOINT)
policy_id = self._check_exists_get_id(policy_to_assign,
policies_url, headers, self.POLICIES_ENDPOINT)
assigned_key = 'escalation_policy'
assigned_value = {
'id': policy_id,
'type': 'escalation_policy_reference'
}

# Start preparing the incident JSON blob to be sent to the API
incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name'])
incident_body = {
'type': '',
'details': ''
}
# We need to get the service id from the API
services_url = os.path.join(creds['api'], self.SERVICES_ENDPOINT)
service_id = self._check_exists_get_id(creds['service_key'],
services_url, headers, self.SERVICES_ENDPOINT)
incident_service = {
'id': service_id,
'type': 'service_reference'
}
incident_priority = {
'id': '',
'type': 'priority_reference'
}
incident = {
'incident': {
'type': 'incident',
'title': incident_title,
'service': incident_service,
'priority': incident_priority,
'body': incident_body
},
assigned_key: assigned_value
}
incidents_url = os.path.join(creds['api'], self.INCIDENTS_ENDPOINT)
resp = self._post_request(incidents_url, incident, None, True)
success = self._check_http_response(resp)

return self._log_status(success)

@output
class PhantomOutput(StreamOutputBase):
"""PhantomOutput handles all alert dispatching for Phantom"""
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/stream_alert_alert_processor/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,18 @@ def get_random_alert(key_count, rule_name, omit_rule_desc=False):
return alert


def get_alert(index=0):
def get_alert(index=0, context=None):
"""This function generates a sample alert for testing purposes
Args:
index (int): test_index value (0 by default)
context(dict): context dictionary (empty by default)
"""
if not context:
ctx = {}
else:
ctx = context

return {
'record': {
'test_index': index,
Expand All @@ -90,6 +101,7 @@ def get_alert(index=0):
'outputs': [
'slack:unit_test_channel'
],
'context': ctx,
'source_service': 's3',
'source_entity': 'corp-prefix.prod.cb.region',
'log_type': 'json',
Expand Down
Loading

0 comments on commit 5988364

Please sign in to comment.