Skip to content

Commit

Permalink
Merge pull request #467 from airbnb/javier-streamalert-pagerduty-inci…
Browse files Browse the repository at this point in the history
…dents

[output] Adding new output for PagerDuty Incidents
  • Loading branch information
javuto authored Nov 16, 2017
2 parents bc3391c + 056c983 commit c1b293a
Show file tree
Hide file tree
Showing 9 changed files with 766 additions and 16 deletions.
17 changes: 17 additions & 0 deletions docs/source/rules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ Examples:
req_subkeys={'columns':['port', 'protocol']})
...
context
~~~~~~~~~~~

``context`` is an optional field to pass extra instructions to the alert processor on how to route the alert. It can be particulary helpful to pass data to an output.

Example:

.. code-block:: python
# Context provided to the pagerduty-incident output
# with instructions to assign the incident to a user.
@rule(logs=['osquery:differential'],
outputs=['pagerduty', 'aws-s3'],
context={'pagerduty-incident':{'assigned_user': 'valid_user'}})
...
Helpers
-------
Expand Down
3 changes: 2 additions & 1 deletion manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def _add_output_subparser(subparsers):
output_parser.add_argument(
'--service',
choices=[
'aws-firehose', 'aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2', 'phantom', 'slack'
'aws-firehose', 'aws-lambda', 'aws-s3', 'pagerduty', 'pagerduty-v2',
'pagerduty-incident', 'phantom', 'slack'
],
required=True,
help=ARGPARSE_SUPPRESS)
Expand Down
8 changes: 7 additions & 1 deletion stream_alert/alert_processor/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def validate_alert(alert):
'log_source',
'outputs',
'source_service',
'source_entity'
'source_entity',
'context'
}
if not set(alert.keys()) == alert_keys:
LOGGER.error('The alert object must contain the following keys: %s',
Expand All @@ -53,6 +54,11 @@ def validate_alert(alert):
LOGGER.error('The alert record must be a map (dict)')
return False

elif key == 'context':
if not isinstance(alert['context'], dict):
LOGGER.error('The alert context must be a map (dict)')
return False

elif key == 'outputs':
if not isinstance(alert[key], list):
LOGGER.error(
Expand Down
263 changes: 257 additions & 6 deletions stream_alert/alert_processor/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
# pylint: disable=too-many-lines
from abc import abstractmethod
import cgi
from collections import OrderedDict
Expand Down Expand Up @@ -64,9 +65,7 @@ def _get_default_properties(cls):
Returns:
dict: Contains various default items for this output (ie: url)
"""
return {
'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'
}
return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'}

def get_user_defined_properties(self):
"""Get properties that must be asssigned by the user when configuring a new PagerDuty
Expand Down Expand Up @@ -133,9 +132,7 @@ def _get_default_properties(cls):
Returns:
dict: Contains various default items for this output (ie: url)
"""
return {
'url': 'https://events.pagerduty.com/v2/enqueue'
}
return {'url': 'https://events.pagerduty.com/v2/enqueue'}

def get_user_defined_properties(self):
"""Get properties that must be asssigned by the user when configuring a new PagerDuty
Expand Down Expand Up @@ -198,6 +195,260 @@ def dispatch(self, **kwargs):

return self._log_status(success)

@output
class PagerDutyIncidentOutput(StreamOutputBase):
"""PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2"""
__service__ = 'pagerduty-incident'
INCIDENTS_ENDPOINT = 'incidents'
USERS_ENDPOINT = 'users'
POLICIES_ENDPOINT = 'escalation_policies'
SERVICES_ENDPOINT = 'services'

def __init__(self, *args, **kwargs):
StreamOutputBase.__init__(self, *args, **kwargs)
self._base_url = None
self._headers = None
self._escalation_policy = None

@classmethod
def _get_default_properties(cls):
"""Get the standard url used for PagerDuty Incidents API v2. This value the same for
everyone, so is hard-coded here and does not need to be configured by the user
Returns:
dict: Contains various default items for this output (ie: url)
"""
return {'api': 'https://api.pagerduty.com'}

def get_user_defined_properties(self):
"""Get properties that must be asssigned by the user when configuring a new PagerDuty
event output. This should be sensitive or unique information for this use-case that
needs to come from the user.
Every output should return a dict that contains a 'descriptor' with a description of the
integration being configured.
PagerDuty also requires a routing_key that represents this integration. This
value should be masked during input and is a credential requirement.
Returns:
OrderedDict: Contains various OutputProperty items
"""
return OrderedDict([
('descriptor',
OutputProperty(description='a short and unique descriptor for this '
'PagerDuty integration')),
('token',
OutputProperty(description='the token for this PagerDuty integration',
mask_input=True,
cred_requirement=True)),
('service_key',
OutputProperty(description='the service key for this PagerDuty integration',
mask_input=True,
cred_requirement=True)),
('escalation_policy',
OutputProperty(description='the name of the default escalation policy'))
])

@staticmethod
def _get_endpoint(base_url, endpoint):
"""Helper to get the full url for a PagerDuty Incidents endpoint.
Args:
base_url (str): Base URL for the API
endpoint (str): Endpoint that we want the full URL for
Returns:
str: Full URL of the provided endpoint
"""
return os.path.join(base_url, endpoint)

def _check_exists_get_id(self, filter_str, url, target_key):
"""Generic method to run a search in the PagerDuty REST API and return the id
of the first occurence from the results.
Args:
filter_str (str): The query filter to search for in the API
url (str): The url to send the requests to in the API
target_key (str): The key to extract in the returned results
Returns:
str: ID of the targeted element that matches the provided filter or
False if a matching element does not exists.
"""
params = {
'query': '{}'.format(filter_str)
}
resp = self._get_request(url, params, self._headers, False)

if not self._check_http_response(resp):
return False

response = resp.json()
if not response:
return False

# If there are results, get the first occurence from the list
return response[target_key][0]['id'] if target_key in response else False

def _user_verify(self, user):
"""Method to verify the existance of an user with the API
Args:
user (str): User to query about in the API.
Returns:
dict or False: JSON object be used in the API call, containing the user_id
and user_reference. False if user is not found
"""
users_url = self._get_endpoint(self._base_url, self.USERS_ENDPOINT)

return self._item_verify(users_url, user, self.USERS_ENDPOINT, 'user_reference')

def _policy_verify(self, policy, default_policy):
"""Method to verify the existance of a escalation policy with the API
Args:
policy (str): Escalation policy to query about in the API
default_policy (str): Escalation policy to use if the first one is not verified
Returns:
dict: JSON object be used in the API call, containing the policy_id
and escalation_policy_reference
"""
policies_url = self._get_endpoint(self._base_url, self.POLICIES_ENDPOINT)

verified = self._item_verify(policies_url, policy, self.POLICIES_ENDPOINT,
'escalation_policy_reference')

# If the escalation policy provided is not verified in the API, use the default
if verified:
return verified

return self._item_verify(policies_url, default_policy, self.POLICIES_ENDPOINT,
'escalation_policy_reference')

def _service_verify(self, service):
"""Method to verify the existance of a service with the API
Args:
service (str): Service to query about in the API
Returns:
dict: JSON object be used in the API call, containing the service_id
and the service_reference
"""
services_url = self._get_endpoint(self._base_url, self.SERVICES_ENDPOINT)

return self._item_verify(services_url, service, self.SERVICES_ENDPOINT, 'service_reference')

def _item_verify(self, item_url, item_str, item_key, item_type):
"""Method to verify the existance of an item with the API
Args:
item_url (str): URL of the API Endpoint within the API to query
item_str (str): Service to query about in the API
item_key (str): Key to be extracted from search results
item_type (str): Type of item reference to be returned
Returns:
dict: JSON object be used in the API call, containing the item id
and the item reference, or False if it fails
"""
item_id = self._check_exists_get_id(item_str, item_url, item_key)
if not item_id:
LOGGER.info('%s not found in %s, %s', item_str, item_key, self.__service__)
return False

return {
'id': item_id,
'type': item_type
}

def _incident_assignment(self, context):
"""Method to determine if the incident gets assigned to a user or an escalation policy
Args:
context (dict): Context provided in the alert record
Returns:
tuple: assigned_key (str), assigned_value (dict to assign incident to an escalation
policy or array of dicts to assign incident to users)
"""
# Check if a user to assign the incident is provided
user_to_assign = context.get('assigned_user', False)

# If provided, verify the user and get the id from API
if user_to_assign:
user_assignee = self._user_verify(user_to_assign)
# User is verified, return tuple
if user_assignee:
return 'assignments', [{'assignee': user_assignee}]

# If escalation policy was not provided, use default one
policy_to_assign = context.get('assigned_policy', self._escalation_policy)

# Verify escalation policy, return tuple
return 'escalation_policy', self._policy_verify(policy_to_assign, self._escalation_policy)

def dispatch(self, **kwargs):
"""Send incident to Pagerduty Incidents API v2
Keyword Args:
**kwargs: consists of any combination of the following items:
descriptor (str): Service descriptor (ie: slack channel, pd integration)
rule_name (str): Name of the triggered rule
alert (dict): Alert relevant to the triggered rule
alert['context'] (dict): Provides user or escalation policy
"""
creds = self._load_creds(kwargs['descriptor'])
if not creds:
return self._log_status(False)

# Preparing headers for API calls
self._headers = {
'Authorization': 'Token token={}'.format(creds['token']),
'Accept': 'application/vnd.pagerduty+json;version=2'
}

# Cache base_url
self._base_url = creds['api']

# Cache default escalation policy
self._escalation_policy = creds['escalation_policy']

# Extracting context data to assign the incident
rule_context = kwargs['alert'].get('context', {})
if rule_context:
rule_context = rule_context.get(self.__service__, {})

# Incident assignment goes in this order:
# Provided user -> provided policy -> default policy
assigned_key, assigned_value = self._incident_assignment(rule_context)

# Start preparing the incident JSON blob to be sent to the API
incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name'])
incident_body = {
'type': 'incident_body',
'details': kwargs['alert']['rule_description']
}
# We need to get the service id from the API
incident_service = self._service_verify(creds['service_key'])
incident = {
'incident': {
'type': 'incident',
'title': incident_title,
'service': incident_service,
'body': incident_body
},
assigned_key: assigned_value
}
incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT)
resp = self._post_request(incidents_url, incident, self._headers, True)
success = self._check_http_response(resp)

return self._log_status(success)

@output
class PhantomOutput(StreamOutputBase):
"""PhantomOutput handles all alert dispatching for Phantom"""
Expand Down
10 changes: 7 additions & 3 deletions stream_alert/rule_processor/rules_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
'datatypes',
'logs',
'outputs',
'req_subkeys'])
'req_subkeys',
'context'])


class StreamRules(object):
Expand Down Expand Up @@ -70,6 +71,7 @@ def decorator(rule):
matchers = opts.get('matchers')
datatypes = opts.get('datatypes')
req_subkeys = opts.get('req_subkeys')
context = opts.get('context', {})

if not (logs or datatypes):
LOGGER.error(
Expand All @@ -92,7 +94,8 @@ def decorator(rule):
datatypes,
logs,
outputs,
req_subkeys)
req_subkeys,
context)
return rule
return decorator

Expand Down Expand Up @@ -387,7 +390,8 @@ def process(cls, input_payload):
'log_type': payload.type,
'outputs': rule.outputs,
'source_service': payload.service(),
'source_entity': payload.entity}
'source_entity': payload.entity,
'context': rule.context}
alerts.append(alert)

return alerts
8 changes: 8 additions & 0 deletions stream_alert_cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,14 @@ def setup_outputs(self, alert):
helpers.put_mock_creds(output_name, creds, self.secrets_bucket,
'us-east-1', self.kms_alias)

elif service == 'pagerduty-incident':
output_name = '{}/{}'.format(service, descriptor)
creds = {'token': '247b97499078a015cc6c586bc0a92de6',
'service_key': '247b97499078a015cc6c586bc0a92de6',
'escalation_policy': '247b97499078a015cc6c586bc0a92de6'}
helpers.put_mock_creds(output_name, creds, self.secrets_bucket,
'us-east-1', self.kms_alias)

elif service == 'phantom':
output_name = '{}/{}'.format(service, descriptor)
creds = {'ph_auth_token': '6c586bc047b9749a92de29078a015cc6',
Expand Down
Loading

0 comments on commit c1b293a

Please sign in to comment.