From f6d4ffe80b9c6e0d408978fdfb5e9960e054e3af Mon Sep 17 00:00:00 2001
From: Ryan Deivert
Date: Fri, 17 Nov 2017 22:16:00 -0800
Subject: [PATCH 1/9] [lambda][alert] breaking output classes out into
 different files

---
 stream_alert/alert_processor/main.py          |    9 +-
 stream_alert/alert_processor/outputs.py       | 1353 -----------------
 .../alert_processor/outputs/__init__.py       |   18 +
 stream_alert/alert_processor/outputs/aws.py   |  307 ++
 stream_alert/alert_processor/outputs/jira.py  |  308 ++
 .../{ => outputs}/output_base.py              |   39 +-
 .../alert_processor/outputs/pagerduty.py      |  421 +++++
 .../alert_processor/outputs/phantom.py        |  155 ++
 stream_alert/alert_processor/outputs/slack.py |  228 +++
 stream_alert_cli/runner.py                    |   10 +-
 10 files changed, 1489 insertions(+), 1359 deletions(-)
 delete mode 100644 stream_alert/alert_processor/outputs.py
 create mode 100644 stream_alert/alert_processor/outputs/__init__.py
 create mode 100644 stream_alert/alert_processor/outputs/aws.py
 create mode 100644 stream_alert/alert_processor/outputs/jira.py
 rename stream_alert/alert_processor/{ => outputs}/output_base.py (89%)
 create mode 100644 stream_alert/alert_processor/outputs/pagerduty.py
 create mode 100644 stream_alert/alert_processor/outputs/phantom.py
 create mode 100644 stream_alert/alert_processor/outputs/slack.py

diff --git a/stream_alert/alert_processor/main.py b/stream_alert/alert_processor/main.py
index 4fa528878..bb7eea0c3 100644
--- a/stream_alert/alert_processor/main.py
+++ b/stream_alert/alert_processor/main.py
@@ -18,7 +18,7 @@
 from stream_alert.alert_processor import LOGGER
 from stream_alert.alert_processor.helpers import validate_alert
-from stream_alert.alert_processor.outputs import get_output_dispatcher
+from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput
 from stream_alert.shared import NORMALIZATION_KEY
 
 
@@ -101,7 +101,12 @@
             continue
 
         # Retrieve the proper class to handle dispatching the alerts of this services
-        output_dispatcher = get_output_dispatcher(service, region, function_name, config)
+        output_dispatcher = StreamAlertOutput.get_output_dispatcher(
+            service,
+            region,
+            function_name,
+            config
+        )
 
         if not output_dispatcher:
             continue
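Note on the call-site change above: dispatchers are now resolved through the StreamAlertOutput class method instead of the old module-level get_output_dispatcher() function, and registration happens as a side effect of importing the outputs package (see the new __init__.py later in this patch). A minimal usage sketch follows; the service name, region, Lambda function name, and the config/alert placeholders are illustrative assumptions, not part of this patch:

    import stream_alert.alert_processor.outputs  # importing the package registers every dispatcher

    from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput

    config = {}  # normally the outputs config loaded from disk
    alert = {'rule_name': 'example_rule', 'rule_description': '', 'record': {}}

    dispatcher = StreamAlertOutput.get_output_dispatcher(
        'slack',        # service portion of a configured output such as 'slack:sample-channel'
        'us-east-1',    # AWS region the alert processor runs in
        'prefix_prod_streamalert_alert_processor',  # hypothetical Lambda function name
        config
    )

    if dispatcher:
        dispatcher.dispatch(descriptor='sample-channel',
                            rule_name=alert['rule_name'],
                            alert=alert)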
diff --git a/stream_alert/alert_processor/outputs.py b/stream_alert/alert_processor/outputs.py
deleted file mode 100644
index 96baaaad5..000000000
--- a/stream_alert/alert_processor/outputs.py
+++ /dev/null
@@ -1,1353 +0,0 @@
-"""
-Copyright 2017-present, Airbnb Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-# pylint: disable=too-many-lines
-from abc import abstractmethod
-import cgi
-from collections import OrderedDict
-from datetime import datetime
-import json
-import os
-import uuid
-
-import backoff
-from botocore.exceptions import ClientError
-import boto3
-
-from stream_alert.alert_processor import LOGGER
-from stream_alert.alert_processor.output_base import OutputProperty, StreamOutputBase
-from stream_alert.shared.backoff_handlers import (
-    backoff_handler,
-    success_handler,
-    giveup_handler
-)
-
-# STREAM_OUTPUTS will contain each subclass of the StreamOutputBase
-# All included subclasses are designated using the '@output' class decorator
-# The keys are the name of the service and the value is the class itself
-# {cls.__service__: <class>}
-
-# pylint: disable=too-many-lines
-STREAM_OUTPUTS = {}
-
-
-def output(cls):
-    """Class decorator to register all stream outputs"""
-    STREAM_OUTPUTS[cls.__service__] = cls
-
-
-def get_output_dispatcher(service, region, function_name, config):
-    """Returns the subclass that should handle this particular service"""
-    try:
-        return STREAM_OUTPUTS[service](region, function_name, config)
-    except KeyError:
-        LOGGER.error('designated output service [%s] does not exist', service)
-
-@output
-class PagerDutyOutput(StreamOutputBase):
-    """PagerDutyOutput handles all alert dispatching for PagerDuty Events API v1"""
-    __service__ = 'pagerduty'
-
-    @classmethod
-    def _get_default_properties(cls):
-        """Get the standard url used for PagerDuty. This value the same for everyone, so
-        is hard-coded here and does not need to be configured by the user
-        Returns:
-            dict: Contains various default items for this output (ie: url)
-        """
-        return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'}
-
-    def get_user_defined_properties(self):
-        """Get properties that must be asssigned by the user when configuring a new PagerDuty
-        output. This should be sensitive or unique information for this use-case that needs
-        to come from the user.
-        Every output should return a dict that contains a 'descriptor' with a description of the
-        integration being configured.
-        PagerDuty also requires a service_key that represnts this integration. This
-        value should be masked during input and is a credential requirement.
- Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this ' - 'PagerDuty integration')), - ('service_key', - OutputProperty(description='the service key for this PagerDuty integration', - mask_input=True, - cred_requirement=True)) - ]) - - def dispatch(self, **kwargs): - """Send alert to Pagerduty - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - creds = self._load_creds(kwargs['descriptor']) - if not creds: - return self._log_status(False) - - message = 'StreamAlert Rule Triggered - {}'.format(kwargs['rule_name']) - rule_desc = kwargs['alert']['rule_description'] - details = { - 'rule_description': rule_desc, - 'record': kwargs['alert']['record'] - } - data = { - 'service_key': creds['service_key'], - 'event_type': 'trigger', - 'description': message, - 'details': details, - 'client': 'StreamAlert' - } - - resp = self._post_request(creds['url'], data, None, True) - success = self._check_http_response(resp) - - return self._log_status(success) - -@output -class PagerDutyOutputV2(StreamOutputBase): - """PagerDutyOutput handles all alert dispatching for PagerDuty Events API v2""" - __service__ = 'pagerduty-v2' - - @classmethod - def _get_default_properties(cls): - """Get the standard url used for PagerDuty Events API v2. This value the same for - everyone, so is hard-coded here and does not need to be configured by the user - - Returns: - dict: Contains various default items for this output (ie: url) - """ - return {'url': 'https://events.pagerduty.com/v2/enqueue'} - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new PagerDuty - event output. This should be sensitive or unique information for this use-case that - needs to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - PagerDuty also requires a routing_key that represents this integration. This - value should be masked during input and is a credential requirement. 
- - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this ' - 'PagerDuty integration')), - ('routing_key', - OutputProperty(description='the routing key for this PagerDuty integration', - mask_input=True, - cred_requirement=True)) - ]) - - def dispatch(self, **kwargs): - """Send alert to Pagerduty - - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - creds = self._load_creds(kwargs['descriptor']) - if not creds: - return self._log_status(False) - - summary = 'StreamAlert Rule Triggered - {}'.format(kwargs['rule_name']) - - details = { - 'rule_description': kwargs['alert']['rule_description'], - 'record': kwargs['alert']['record'] - } - payload = { - 'summary': summary, - 'source': kwargs['alert']['log_source'], - 'severity': 'critical', - 'custom_details': details - } - data = { - 'routing_key': creds['routing_key'], - 'payload': payload, - 'event_action': 'trigger', - 'client': 'StreamAlert' - } - - resp = self._post_request(creds['url'], data, None, True) - success = self._check_http_response(resp) - - return self._log_status(success) - -@output -class PagerDutyIncidentOutput(StreamOutputBase): - """PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2""" - __service__ = 'pagerduty-incident' - INCIDENTS_ENDPOINT = 'incidents' - USERS_ENDPOINT = 'users' - POLICIES_ENDPOINT = 'escalation_policies' - SERVICES_ENDPOINT = 'services' - - def __init__(self, *args, **kwargs): - StreamOutputBase.__init__(self, *args, **kwargs) - self._base_url = None - self._headers = None - self._escalation_policy = None - - @classmethod - def _get_default_properties(cls): - """Get the standard url used for PagerDuty Incidents API v2. This value the same for - everyone, so is hard-coded here and does not need to be configured by the user - - Returns: - dict: Contains various default items for this output (ie: url) - """ - return {'api': 'https://api.pagerduty.com'} - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new PagerDuty - event output. This should be sensitive or unique information for this use-case that - needs to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - PagerDuty also requires a routing_key that represents this integration. This - value should be masked during input and is a credential requirement. 
- - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this ' - 'PagerDuty integration')), - ('token', - OutputProperty(description='the token for this PagerDuty integration', - mask_input=True, - cred_requirement=True)), - ('service_key', - OutputProperty(description='the service key for this PagerDuty integration', - mask_input=True, - cred_requirement=True)), - ('escalation_policy', - OutputProperty(description='the name of the default escalation policy')), - ('email_from', - OutputProperty(description='valid user email from the PagerDuty ' - 'account linked to the token', - cred_requirement=True)) - ]) - - @staticmethod - def _get_endpoint(base_url, endpoint): - """Helper to get the full url for a PagerDuty Incidents endpoint. - - Args: - base_url (str): Base URL for the API - endpoint (str): Endpoint that we want the full URL for - - Returns: - str: Full URL of the provided endpoint - """ - return os.path.join(base_url, endpoint) - - def _check_exists(self, filter_str, url, target_key, get_id=True): - """Generic method to run a search in the PagerDuty REST API and return the id - of the first occurence from the results. - - Args: - filter_str (str): The query filter to search for in the API - url (str): The url to send the requests to in the API - target_key (str): The key to extract in the returned results - get_id (boolean): Whether to generate a dict with result and reference - - Returns: - str: ID of the targeted element that matches the provided filter or - True/False whether a matching element exists or not. - """ - params = { - 'query': '{}'.format(filter_str) - } - resp = self._get_request(url, params, self._headers, False) - - if not self._check_http_response(resp): - return False - - response = resp.json() - if not response: - return False - - if not get_id: - return True - - # If there are results, get the first occurence from the list - return response[target_key][0]['id'] if target_key in response else False - - def _user_verify(self, user, get_id=True): - """Method to verify the existance of an user with the API - - Args: - user (str): User to query about in the API. - get_id (boolean): Whether to generate a dict with result and reference - - Returns: - dict or False: JSON object be used in the API call, containing the user_id - and user_reference. 
False if user is not found - """ - return self._item_verify(user, self.USERS_ENDPOINT, 'user_reference', get_id) - - def _policy_verify(self, policy, default_policy): - """Method to verify the existance of a escalation policy with the API - - Args: - policy (str): Escalation policy to query about in the API - default_policy (str): Escalation policy to use if the first one is not verified - - Returns: - dict: JSON object be used in the API call, containing the policy_id - and escalation_policy_reference - """ - verified = self._item_verify(policy, self.POLICIES_ENDPOINT, 'escalation_policy_reference') - - # If the escalation policy provided is not verified in the API, use the default - if verified: - return verified - - return self._item_verify(default_policy, self.POLICIES_ENDPOINT, - 'escalation_policy_reference') - - def _service_verify(self, service): - """Method to verify the existance of a service with the API - - Args: - service (str): Service to query about in the API - - Returns: - dict: JSON object be used in the API call, containing the service_id - and the service_reference - """ - return self._item_verify(service, self.SERVICES_ENDPOINT, 'service_reference') - - def _item_verify(self, item_str, item_key, item_type, get_id=True): - """Method to verify the existance of an item with the API - - Args: - item_str (str): Service to query about in the API - item_key (str): Endpoint/key to be extracted from search results - item_type (str): Type of item reference to be returned - get_id (boolean): Whether to generate a dict with result and reference - - Returns: - dict: JSON object be used in the API call, containing the item id - and the item reference, True if it just exists or False if it fails - """ - item_url = self._get_endpoint(self._base_url, item_key) - item_id = self._check_exists(item_str, item_url, item_key, get_id) - if not item_id: - LOGGER.info('%s not found in %s, %s', item_str, item_key, self.__service__) - return False - - if get_id: - return {'id': item_id, 'type': item_type} - - return item_id - - def _incident_assignment(self, context): - """Method to determine if the incident gets assigned to a user or an escalation policy - - Args: - context (dict): Context provided in the alert record - - Returns: - tuple: assigned_key (str), assigned_value (dict to assign incident to an escalation - policy or array of dicts to assign incident to users) - """ - # Check if a user to assign the incident is provided - user_to_assign = context.get('assigned_user', False) - - # If provided, verify the user and get the id from API - if user_to_assign: - user_assignee = self._user_verify(user_to_assign) - # User is verified, return tuple - if user_assignee: - return 'assignments', [{'assignee': user_assignee}] - - # If escalation policy was not provided, use default one - policy_to_assign = context.get('assigned_policy', self._escalation_policy) - - # Verify escalation policy, return tuple - return 'escalation_policy', self._policy_verify(policy_to_assign, self._escalation_policy) - - def dispatch(self, **kwargs): - """Send incident to Pagerduty Incidents API v2 - - Keyword Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - alert['context'] (dict): Provides user or escalation policy - """ - creds = self._load_creds(kwargs['descriptor']) - if not creds: - return self._log_status(False) - - # 
Cache base_url - self._base_url = creds['api'] - - # Preparing headers for API calls - self._headers = { - 'Authorization': 'Token token={}'.format(creds['token']), - 'Accept': 'application/vnd.pagerduty+json;version=2' - } - - # Get user email to be added as From header and verify - user_email = creds['email_from'] - if not self._user_verify(user_email, False): - LOGGER.error('Could not verify header From: %s, %s', user_email, self.__service__) - return self._log_status(False) - - # Add From to the headers after verifying - self._headers['From'] = user_email - - # Cache default escalation policy - self._escalation_policy = creds['escalation_policy'] - - # Extracting context data to assign the incident - rule_context = kwargs['alert'].get('context', {}) - if rule_context: - rule_context = rule_context.get(self.__service__, {}) - - # Incident assignment goes in this order: - # Provided user -> provided policy -> default policy - assigned_key, assigned_value = self._incident_assignment(rule_context) - - # Start preparing the incident JSON blob to be sent to the API - incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name']) - incident_body = { - 'type': 'incident_body', - 'details': kwargs['alert']['rule_description'] - } - # We need to get the service id from the API - incident_service = self._service_verify(creds['service_key']) - incident = { - 'incident': { - 'type': 'incident', - 'title': incident_title, - 'service': incident_service, - 'body': incident_body, - assigned_key: assigned_value - } - } - incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT) - resp = self._post_request(incidents_url, incident, self._headers, True) - success = self._check_http_response(resp) - - return self._log_status(success) - -@output -class PhantomOutput(StreamOutputBase): - """PhantomOutput handles all alert dispatching for Phantom""" - __service__ = 'phantom' - CONTAINER_ENDPOINT = 'rest/container' - ARTIFACT_ENDPOINT = 'rest/artifact' - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new Phantom - output. This should be sensitive or unique information for this use-case that needs - to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - Phantom also requires a ph_auth_token that represnts an authorization token for this - integration and a user provided url to use for alert dispatching. These values should be - masked during input and are credential requirements. 
- - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this ' - 'Phantom integration')), - ('ph_auth_token', - OutputProperty(description='the auth token for this Phantom integration', - mask_input=True, - cred_requirement=True)), - ('url', - OutputProperty(description='the endpoint url for this Phantom integration', - mask_input=True, - cred_requirement=True)) - ]) - - def _check_container_exists(self, rule_name, container_url, headers): - """Check to see if a Phantom container already exists for this rule - - Args: - rule_name (str): The name of the rule that triggered the alert - container_url (str): The constructed container url for this Phantom instance - headers (dict): A dictionary containing header parameters - - Returns: - int: ID of an existing Phantom container for this rule where the alerts - will be sent or False if a matching container does not yet exists - """ - # Limit the query to 1 page, since we only care if one container exists with - # this name. - params = { - '_filter_name': '"{}"'.format(rule_name), - 'page_size': 1 - } - resp = self._get_request(container_url, params, headers, False) - if not self._check_http_response(resp): - return False - - response = resp.json() - - # If the count == 0 then we know there are no containers with this name and this - # will evaluate to False. Otherwise there is at least one item in the list - # of 'data' with a container id we can use - return response and response.get('count') and response.get('data')[0]['id'] - - def _setup_container(self, rule_name, rule_description, base_url, headers): - """Establish a Phantom container to write the alerts to. This checks to see - if an appropriate containers exists first and returns the ID if so. 
- - Args: - rule_name (str): The name of the rule that triggered the alert - base_url (str): The base url for this Phantom instance - headers (dict): A dictionary containing header parameters - - Returns: - int: ID of the Phantom container where the alerts will be sent - or False if there is an issue getting the container id - """ - container_url = os.path.join(base_url, self.CONTAINER_ENDPOINT) - - # Check to see if there is a container already created for this rule name - existing_id = self._check_container_exists(rule_name, container_url, headers) - if existing_id: - return existing_id - - # Try to use the rule_description from the rule as the container description - ph_container = {'name': rule_name, 'description': rule_description} - resp = self._post_request(container_url, ph_container, headers, False) - - if not self._check_http_response(resp): - return False - - response = resp.json() - - return response and response.get('id') - - def dispatch(self, **kwargs): - """Send alert to Phantom - - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - creds = self._load_creds(kwargs['descriptor']) - if not creds: - return self._log_status(False) - - headers = {"ph-auth-token": creds['ph_auth_token']} - rule_desc = kwargs['alert']['rule_description'] - container_id = self._setup_container(kwargs['rule_name'], rule_desc, - creds['url'], headers) - - LOGGER.debug('sending alert to Phantom container with id %s', container_id) - - success = False - if container_id: - artifact = {'cef': kwargs['alert']['record'], - 'container_id': container_id, - 'data': kwargs['alert'], - 'name': 'Phantom Artifact', - 'label': 'Alert'} - artifact_url = os.path.join(creds['url'], self.ARTIFACT_ENDPOINT) - resp = self._post_request(artifact_url, artifact, headers, False) - - success = self._check_http_response(resp) - - return self._log_status(success) - - -@output -class SlackOutput(StreamOutputBase): - """SlackOutput handles all alert dispatching for Slack""" - __service__ = 'slack' - # Slack recommends no messages larger than 4000 bytes. This does not account for unicode - MAX_MESSAGE_SIZE = 4000 - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new Slack - output. This should be sensitive or unique information for this use-case that needs - to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - Slack also requires a user provided 'webhook' url that is comprised of the slack api url - and the unique integration key for this output. This value should be should be masked - during input and is a credential requirement. - - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this Slack integration ' - '(ie: channel, group, etc)')), - ('url', - OutputProperty(description='the full Slack webhook url, including the secret', - mask_input=True, - cred_requirement=True)) - ]) - - @classmethod - def _format_message(cls, rule_name, alert): - """Format the message to be sent to slack. 
- - Args: - rule_name (str): The name of the rule that triggered the alert - alert: Alert relevant to the triggered rule - - Returns: - dict: message with attachments to send to Slack. - The message will look like: - StreamAlert Rule Triggered: rule_name - Rule Description: - This will be the docstring from the rule, sent as the rule_description - - Record (Part 1 of 2): - ... - """ - # Convert the alert we have to a nicely formatted string for slack - alert_text = '\n'.join(cls._json_to_slack_mrkdwn(alert['record'], 0)) - # Slack requires escaping the characters: '&', '>' and '<' and cgi does just that - alert_text = cgi.escape(alert_text) - messages = [] - index = cls.MAX_MESSAGE_SIZE - while alert_text != '': - if len(alert_text) <= index: - messages.append(alert_text) - break - - # Find the closest line break prior to this index - while index > 1 and alert_text[index] != '\n': - index -= 1 - - # Append the message part up until this index, and move to the next chunk - messages.append(alert_text[:index]) - alert_text = alert_text[index+1:] - - index = cls.MAX_MESSAGE_SIZE - - header_text = '*StreamAlert Rule Triggered: {}*'.format(rule_name) - full_message = { - 'text': header_text, - 'mrkdwn': True, - 'attachments': [] - } - - for index, message in enumerate(messages): - title = 'Record:' - if len(messages) > 1: - title = 'Record (Part {} of {}):'.format(index+1, len(messages)) - rule_desc = '' - # Only print the rule description on the first attachment - if index == 0: - rule_desc = alert['rule_description'] - rule_desc = '*Rule Description:*\n{}\n'.format(rule_desc) - - # Add this attachemnt to the full message array of attachments - full_message['attachments'].append({ - 'fallback': header_text, - 'color': '#b22222', - 'pretext': rule_desc, - 'title': title, - 'text': message, - 'mrkdwn_in': ['text', 'pretext'] - }) - - # Return the json dict payload to be sent to slack - return full_message - - @classmethod - def _json_to_slack_mrkdwn(cls, json_values, indent_count): - """Translate a json object into a more human-readable list of lines - This will handle recursion of all nested maps and lists within the object - - Args: - json_values: variant to be translated (could be json map, list, etc) - indent_count (int): Number of tabs to prefix each line with - - Returns: - list: strings that have been properly tabbed and formatted for printing - """ - tab = '\t' - all_lines = [] - if isinstance(json_values, dict): - all_lines = cls._json_map_to_text(json_values, tab, indent_count) - elif isinstance(json_values, list): - all_lines = cls._json_list_to_text(json_values, tab, indent_count) - else: - all_lines.append('{}'.format(json_values)) - - return all_lines - - @classmethod - def _json_map_to_text(cls, json_values, tab, indent_count): - """Translate a map from json (dict) into a more human-readable list of lines - This will handle recursion of all nested maps and lists within the map - - Args: - json_values (dict): dictionary to be iterated over and formatted - tab (str): string value to use for indentation - indent_count (int): Number of tabs to prefix each line with - - Returns: - list: strings that have been properly tabbed and formatted for printing - """ - all_lines = [] - for key, value in json_values.iteritems(): - if isinstance(value, (dict, list)) and value: - all_lines.append('{}*{}:*'.format(tab*indent_count, key)) - all_lines.extend(cls._json_to_slack_mrkdwn(value, indent_count+1)) - else: - new_lines = cls._json_to_slack_mrkdwn(value, indent_count+1) - if len(new_lines) == 1: 
- all_lines.append('{}*{}:* {}'.format(tab*indent_count, key, new_lines[0])) - elif new_lines: - all_lines.append('{}*{}:*'.format(tab*indent_count, key)) - all_lines.extend(new_lines) - else: - all_lines.append('{}*{}:* {}'.format(tab*indent_count, key, value)) - - return all_lines - - @classmethod - def _json_list_to_text(cls, json_values, tab, indent_count): - """Translate a list from json into a more human-readable list of lines - This will handle recursion of all nested maps and lists within the list - - Args: - json_values (list): list to be iterated over and formatted - tab (str): string value to use for indentation - indent_count (int): Number of tabs to prefix each line with - - Returns: - list: strings that have been properly tabbed and formatted for printing - """ - all_lines = [] - for index, value in enumerate(json_values): - if isinstance(value, (dict, list)) and value: - all_lines.append('{}*[{}]*'.format(tab*indent_count, index+1)) - all_lines.extend(cls._json_to_slack_mrkdwn(value, indent_count+1)) - else: - new_lines = cls._json_to_slack_mrkdwn(value, indent_count+1) - if len(new_lines) == 1: - all_lines.append('{}*[{}]* {}'.format(tab*indent_count, index+1, new_lines[0])) - elif new_lines: - all_lines.append('{}*[{}]*'.format(tab*indent_count, index+1)) - all_lines.extend(new_lines) - else: - all_lines.append('{}*[{}]* {}'.format(tab*indent_count, index+1, value)) - - return all_lines - - def dispatch(self, **kwargs): - """Send alert text to Slack - - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - creds = self._load_creds(kwargs['descriptor']) - if not creds: - return self._log_status(False) - - slack_message = self._format_message(kwargs['rule_name'], kwargs['alert']) - - resp = self._post_request(creds['url'], slack_message) - success = self._check_http_response(resp) - - return self._log_status(success) - - -class AWSOutput(StreamOutputBase): - """Subclass to be inherited from for all AWS service outputs""" - def format_output_config(self, service_config, values): - """Format the output configuration for this AWS service to be written to disk - - AWS services are stored as a dictionary within the config instead of a list so - we have access to the AWS value (arn/bucket name/etc) for Terraform - - Args: - service_config (dict): The actual outputs config that has been read in - values (OrderedDict): Contains all the OutputProperty items for this service - - Returns: - dict{: }: Updated dictionary of descriptors and - values for this AWS service needed for the output configuration - NOTE: S3 requires the bucket name, not an arn, for this value. 
- Instead of implementing this differently in subclasses, all AWSOutput - subclasses should use a generic 'aws_value' to store the value for the - descriptor used in configuration - """ - return dict(service_config.get(self.__service__, {}), - **{values['descriptor'].value: values['aws_value'].value}) - - @abstractmethod - def dispatch(self, **kwargs): - """Placeholder for implementation in the subclasses""" - pass - - -@output -class KinesisFirehoseOutput(AWSOutput): - """High throughput Alert delivery to AWS S3""" - MAX_RECORD_SIZE = 1000 * 1000 - MAX_BACKOFF_ATTEMPTS = 3 - - __service__ = 'aws-firehose' - __aws_client__ = None - - def get_user_defined_properties(self): - """Properties asssigned by the user when configuring a new Firehose output - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty( - description='a short and unique descriptor for this Firehose Delivery Stream')), - ('aws_value', - OutputProperty(description='the Firehose Delivery Stream name')) - ]) - - def dispatch(self, **kwargs): - """Send alert to a Kinesis Firehose Delivery Stream - - Keyword Args: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - - Returns: - bool: Indicates a successful or failed dispatch of the alert - """ - @backoff.on_exception(backoff.fibo, - ClientError, - max_tries=self.MAX_BACKOFF_ATTEMPTS, - jitter=backoff.full_jitter, - on_backoff=backoff_handler, - on_success=success_handler, - on_giveup=giveup_handler) - def _firehose_request_wrapper(json_alert, delivery_stream): - """Make the PutRecord request to Kinesis Firehose with backoff - - Args: - json_alert (str): The JSON dumped alert body - delivery_stream (str): The Firehose Delivery Stream to send to - - Returns: - dict: Firehose response in the format below - {'RecordId': 'string'} - """ - return self.__aws_client__.put_record(DeliveryStreamName=delivery_stream, - Record={'Data': json_alert}) - - if self.__aws_client__ is None: - self.__aws_client__ = boto3.client('firehose', region_name=self.region) - - json_alert = json.dumps(kwargs['alert'], separators=(',', ':')) + '\n' - if len(json_alert) > self.MAX_RECORD_SIZE: - LOGGER.error('Alert too large to send to Firehose: \n%s...', json_alert[0:1000]) - return False - - delivery_stream = self.config[self.__service__][kwargs['descriptor']] - LOGGER.info('Sending alert [%s] to aws-firehose:%s', - kwargs['rule_name'], - delivery_stream) - - resp = _firehose_request_wrapper(json_alert, delivery_stream) - - if resp.get('RecordId'): - LOGGER.info('Alert [%s] successfully sent to aws-firehose:%s with RecordId:%s', - kwargs['rule_name'], - delivery_stream, - resp['RecordId']) - - return self._log_status(resp) - - -@output -class S3Output(AWSOutput): - """S3Output handles all alert dispatching for AWS S3""" - __service__ = 'aws-s3' - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new S3 - output. This should be sensitive or unique information for this use-case that needs - to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - S3 also requires a user provided bucket name to be used for this service output. 
This - value should not be masked during input and is not a credential requirement - that needs encrypted. - - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty( - description='a short and unique descriptor for this S3 bucket (ie: bucket name)')), - ('aws_value', - OutputProperty(description='the AWS S3 bucket name to use for this S3 configuration')) - ]) - - def dispatch(self, **kwargs): - """Send alert to an S3 bucket - - Organizes alert into the following folder structure: - service/entity/rule_name/datetime.json - The alert gets dumped to a JSON string - - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - alert = kwargs['alert'] - service = alert['source_service'] - entity = alert['source_entity'] - - current_date = datetime.now() - - s3_alert = alert - # JSON dump the alert to retain a consistent alerts schema across log types. - # This will get replaced by a UUID which references a record in a - # different table in the future. - s3_alert['record'] = json.dumps(s3_alert['record']) - alert_string = json.dumps(s3_alert) - - bucket = self.config[self.__service__][kwargs['descriptor']] - - # Prefix with alerts to account for generic non-streamalert buckets - # Produces the following key format: - # alerts/dt=2017-01-25-00/kinesis_my-stream_my-rule_uuid.json - # Keys need to be unique to avoid object overwriting - key = 'alerts/dt={}/{}_{}_{}_{}.json'.format( - current_date.strftime('%Y-%m-%d-%H'), - service, - entity, - alert['rule_name'], - uuid.uuid4() - ) - - LOGGER.debug('Sending alert to S3 bucket %s with key %s', bucket, key) - - client = boto3.client('s3', region_name=self.region) - resp = client.put_object(Body=alert_string, - Bucket=bucket, - Key=key) - - return self._log_status(resp) - - -@output -class LambdaOutput(AWSOutput): - """LambdaOutput handles all alert dispatching to AWS Lambda""" - __service__ = 'aws-lambda' - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new Lambda - output. This should be sensitive or unique information for this use-case that needs - to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - Sending to Lambda also requires a user provided Lambda function name and optional qualifier - (if applicabale for the user's use case). A fully-qualified AWS ARN is also acceptable for - this value. This value should not be masked during input and is not a credential requirement - that needs encrypted. 
- - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this Lambda function ' - 'configuration (ie: abbreviated name)')), - ('aws_value', - OutputProperty(description='the AWS Lambda function name, with the optional ' - 'qualifier (aka \'alias\'), to use for this ' - 'configuration (ie: output_function:qualifier)', - input_restrictions={' '})), - ]) - - def dispatch(self, **kwargs): - """Send alert to a Lambda function - - The alert gets dumped to a JSON string to be sent to the Lambda function - - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - alert = kwargs['alert'] - alert_string = json.dumps(alert['record']) - function_name = self.config[self.__service__][kwargs['descriptor']] - - # Check to see if there is an optional qualifier included here - # Acceptable values for the output configuration are the full ARN, - # a function name followed by a qualifier, or just a function name: - # 'arn:aws:lambda:aws-region:acct-id:function:function-name:prod' - # 'function-name:prod' - # 'function-name' - # Checking the length of the list for 2 or 8 should account for all - # times a qualifier is provided. - parts = function_name.split(':') - if len(parts) == 2 or len(parts) == 8: - function = parts[-2] - qualifier = parts[-1] - else: - function = parts[-1] - qualifier = None - - LOGGER.debug('Sending alert to Lambda function %s', function_name) - - client = boto3.client('lambda', region_name=self.region) - # Use the qualifier if it's available. Passing an empty qualifier in - # with `Qualifier=''` or `Qualifier=None` does not work and thus we - # have to perform different calls to client.invoke(). - if qualifier: - resp = client.invoke(FunctionName=function, - InvocationType='Event', - Payload=alert_string, - Qualifier=qualifier) - else: - resp = client.invoke(FunctionName=function, - InvocationType='Event', - Payload=alert_string) - - return self._log_status(resp) - -@output -class JiraOutput(StreamOutputBase): - """JiraOutput handles all alert dispatching for Jira""" - __service__ = 'jira' - - DEFAULT_HEADERS = {'Content-Type': 'application/json'} - LOGIN_ENDPOINT = 'rest/auth/1/session' - SEARCH_ENDPOINT = 'rest/api/2/search' - ISSUE_ENDPOINT = 'rest/api/2/issue' - COMMENT_ENDPOINT = 'rest/api/2/issue/{}/comment' - - def __init__(self, *args, **kwargs): - StreamOutputBase.__init__(self, *args, **kwargs) - self._base_url = None - self._auth_cookie = None - - def get_user_defined_properties(self): - """Get properties that must be asssigned by the user when configuring a new Jira - output. This should be sensitive or unique information for this use-case that needs - to come from the user. - - Every output should return a dict that contains a 'descriptor' with a description of the - integration being configured. - - Jira requires a username, password, URL, project key, and issue type for alert dispatching. - These values should be masked during input and are credential requirements. - - An additional parameter 'aggregate' is used to determine if alerts are aggregated into a - single Jira issue based on the StreamAlert rule. 
- - Returns: - OrderedDict: Contains various OutputProperty items - """ - return OrderedDict([ - ('descriptor', - OutputProperty(description='a short and unique descriptor for this ' - 'Jira integration')), - ('username', - OutputProperty(description='the Jira username', - mask_input=True, - cred_requirement=True)), - ('password', - OutputProperty(description='the Jira password', - mask_input=True, - cred_requirement=True)), - ('url', - OutputProperty(description='the Jira url', - mask_input=True, - cred_requirement=True)), - ('project_key', - OutputProperty(description='the Jira project key', - mask_input=False, - cred_requirement=True)), - ('issue_type', - OutputProperty(description='the Jira issue type', - mask_input=False, - cred_requirement=True)), - ('aggregate', - OutputProperty(description='the Jira aggregation behavior to aggregate ' - 'alerts by rule name (yes/no)', - mask_input=False, - cred_requirement=True)) - ]) - - @classmethod - def _get_default_headers(cls): - """Class method to be used to pass the default headers""" - return cls.DEFAULT_HEADERS.copy() - - def _get_headers(self): - """Instance method used to pass the default headers plus the auth cookie""" - return dict(self._get_default_headers(), **{'cookie': self._auth_cookie}) - - def _search_jira(self, jql, fields=None, max_results=100, validate_query=True): - """Search Jira for issues using a JQL query - - Args: - jql (str): The JQL query - fields (list): List of fields to return for each issue - max_results (int): Maximum number of results to return - validate_query (bool): Whether to validate the JQL query or not - - Returns: - list: list of issues matching JQL query - """ - search_url = os.path.join(self._base_url, self.SEARCH_ENDPOINT) - params = { - 'jql': jql, - 'maxResults': max_results, - 'validateQuery': validate_query, - 'fields': fields - } - - resp = self._get_request(search_url, - params=params, - headers=self._get_headers(), - verify=False) - - success = self._check_http_response(resp) - if not success: - return [] - - return resp.json()['issues'] - - def _create_comment(self, issue_id, comment): - """Add a comment to an existing issue - - Args: - issue_id (str): The existing issue ID or key - comment (str): The body of the comment - - Returns: - int: ID of the created comment or False if unsuccessful - """ - comment_url = os.path.join(self._base_url, self.COMMENT_ENDPOINT.format(issue_id)) - resp = self._post_request(comment_url, - data={'body': comment}, - headers=self._get_headers(), - verify=False) - - success = self._check_http_response(resp) - if not success: - return False - - return resp.json()['id'] - - def _get_comments(self, issue_id): - """Get all comments for an existing Jira issue - - Args: - issue_id (str): The existing issue ID or key - - Returns: - list: List of comments associated with a Jira issue - """ - comment_url = os.path.join(self._base_url, self.COMMENT_ENDPOINT.format(issue_id)) - resp = self._get_request(comment_url, - headers=self._get_headers(), - verify=False) - - success = self._check_http_response(resp) - if not success: - return [] - - return resp.json()['comments'] - - def _get_existing_issue(self, issue_summary, project_key): - """Find an existing Jira issue based on the issue summary - - Args: - issue_summary (str): The Jira issue summary - project_key (str): The Jira project to search - - Returns: - int: ID of the found issue or False if existing issue does not exist - """ - jql = 'summary ~ "{}" and project="{}"'.format(issue_summary, project_key) - resp = 
self._search_jira(jql, fields=['id', 'summary'], max_results=1) - jira_id = False - - try: - jira_id = int(resp[0]['id']) - except (IndexError, KeyError): - LOGGER.debug('Existing Jira issue not found') - - return jira_id - - def _create_issue(self, issue_name, project_key, issue_type, description): - """Create a Jira issue to write alerts to. Alert is written to the "description" - field of an issue. - - Args: - issue_name (str): The name of the Jira issue - project_key (str): The Jira project key which issues will be associated with - issue_type (str): The type of issue being created - description (str): The body of text which describes the issue - - Returns: - int: ID of the created issue or False if unsuccessful - """ - issue_url = os.path.join(self._base_url, self.ISSUE_ENDPOINT) - issue_body = { - 'fields': { - 'project': { - 'key': project_key - }, - 'summary': issue_name, - 'description': description, - 'issuetype': { - 'name': issue_type - } - } - } - - resp = self._post_request(issue_url, - data=issue_body, - headers=self._get_headers(), - verify=False) - - success = self._check_http_response(resp) - if not success: - return False - - return resp.json()['id'] - - def _establish_session(self, username, password): - """Establish a cookie based Jira session via basic user auth. - - Args: - username (str): The Jira username used for establishing the session - password (str): The Jira password used for establishing the session - - Returns: - str: Header value intended to be passed with every subsequent Jira request - or False if unsuccessful - """ - login_url = os.path.join(self._base_url, self.LOGIN_ENDPOINT) - auth_info = {'username': username, 'password': password} - - resp = self._post_request(login_url, - data=auth_info, - headers=self._get_default_headers(), - verify=False) - - success = self._check_http_response(resp) - if not success: - LOGGER.error("Failed to authenticate to Jira") - return False - resp_dict = resp.json() - - return '{}={}'.format(resp_dict['session']['name'], - resp_dict['session']['value']) - - def dispatch(self, **kwargs): - """Send alert to Jira - - Args: - **kwargs: consists of any combination of the following items: - descriptor (str): Service descriptor (ie: slack channel, pd integration) - rule_name (str): Name of the triggered rule - alert (dict): Alert relevant to the triggered rule - """ - creds = self._load_creds(kwargs['descriptor']) - if not creds: - return self._log_status(False) - - issue_id = None - comment_id = None - issue_summary = 'StreamAlert {}'.format(kwargs['rule_name']) - alert_body = '{{code:JSON}}{}{{code}}'.format(json.dumps(kwargs['alert'])) - self._base_url = creds['url'] - self._auth_cookie = self._establish_session(creds['username'], creds['password']) - - # Validate successful authentication - if not self._auth_cookie: - return self._log_status(False) - - # If aggregation is enabled, attempt to add alert to an existing issue. If a - # failure occurs in this block, creation of a new Jira issue will be attempted. - if creds.get('aggregate', '').lower() == 'yes': - issue_id = self._get_existing_issue(issue_summary, creds['project_key']) - if issue_id: - comment_id = self._create_comment(issue_id, alert_body) - if comment_id: - LOGGER.debug('Sending alert to an existing Jira issue %s with comment %s', - issue_id, - comment_id) - return self._log_status(True) - else: - LOGGER.error('Encountered an error when adding alert to existing ' - 'Jira issue %s. 
Attempting to create new Jira issue.',
-                             issue_id)
-
-        # Create a new Jira issue
-        issue_id = self._create_issue(issue_summary,
-                                      creds['project_key'],
-                                      creds['issue_type'],
-                                      alert_body)
-        if issue_id:
-            LOGGER.debug('Sending alert to a new Jira issue %s', issue_id)
-
-        return self._log_status(issue_id or comment_id)
diff --git a/stream_alert/alert_processor/outputs/__init__.py b/stream_alert/alert_processor/outputs/__init__.py
new file mode 100644
index 000000000..b29cdf550
--- /dev/null
+++ b/stream_alert/alert_processor/outputs/__init__.py
@@ -0,0 +1,18 @@
+"""Import each output module so its OutputDispatcher subclasses register themselves."""
+import importlib
+import os
+
+# Import all files containing subclasses of OutputDispatcher, skipping the common base class
+for output_file in os.listdir(os.path.dirname(__file__)):
+    # Skip the common base file and any non-py files
+    if output_file.startswith(('__init__', 'output_base')) or not output_file.endswith('.py'):
+        continue
+
+    full_import = [
+        'stream_alert',
+        'alert_processor',
+        'outputs',
+        os.path.splitext(output_file)[0]
+    ]
+
+    importlib.import_module('.'.join(full_import))
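The StreamAlertOutput decorator and lookup used throughout the new modules live in the renamed outputs/output_base.py, which this patch does not reproduce (89% rename). The following is a minimal sketch of the registry pattern it implies, mirroring the deleted @output decorator and get_output_dispatcher() shown above; only the decorator usage and the get_output_dispatcher classmethod appear in the diff, so the _outputs attribute and the __new__ mechanics are assumptions:

    from stream_alert.alert_processor import LOGGER

    class StreamAlertOutput(object):
        """Class decorator that maps each __service__ name to its dispatcher class"""
        _outputs = {}

        def __new__(cls, output_class):
            # Used as '@StreamAlertOutput' above a subclass; returning the class
            # itself (not an instance) makes this act as a registering decorator
            cls._outputs[output_class.__service__] = output_class
            return output_class

        @classmethod
        def get_output_dispatcher(cls, service, region, function_name, config):
            """Return an instance of the dispatcher registered for this service"""
            try:
                return cls._outputs[service](region, function_name, config)
            except KeyError:
                LOGGER.error('designated output service [%s] does not exist', service)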
diff --git a/stream_alert/alert_processor/outputs/aws.py b/stream_alert/alert_processor/outputs/aws.py
new file mode 100644
index 000000000..418adb086
--- /dev/null
+++ b/stream_alert/alert_processor/outputs/aws.py
@@ -0,0 +1,307 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from abc import abstractmethod
+from collections import OrderedDict
+from datetime import datetime
+import json
+import uuid
+
+import backoff
+from botocore.exceptions import ClientError
+import boto3
+
+from stream_alert.alert_processor import LOGGER
+from stream_alert.alert_processor.outputs.output_base import (
+    OutputDispatcher,
+    OutputProperty,
+    StreamAlertOutput
+)
+from stream_alert.shared.backoff_handlers import (
+    backoff_handler,
+    success_handler,
+    giveup_handler
+)
+
+
+class AWSOutput(OutputDispatcher):
+    """Subclass to be inherited from for all AWS service outputs"""
+    def format_output_config(self, service_config, values):
+        """Format the output configuration for this AWS service to be written to disk
+
+        AWS services are stored as a dictionary within the config instead of a list so
+        we have access to the AWS value (arn/bucket name/etc) for Terraform
+
+        Args:
+            service_config (dict): The actual outputs config that has been read in
+            values (OrderedDict): Contains all the OutputProperty items for this service
+
+        Returns:
+            dict{<string>: <string>}: Updated dictionary of descriptors and
+                values for this AWS service needed for the output configuration
+                NOTE: S3 requires the bucket name, not an arn, for this value.
+                    Instead of implementing this differently in subclasses, all AWSOutput
+                    subclasses should use a generic 'aws_value' to store the value for the
+                    descriptor used in configuration
+        """
+        return dict(service_config.get(self.__service__, {}),
+                    **{values['descriptor'].value: values['aws_value'].value})
+
+    @abstractmethod
+    def dispatch(self, **kwargs):
+        """Placeholder for implementation in the subclasses"""
+        pass
+
+
+@StreamAlertOutput
+class KinesisFirehoseOutput(AWSOutput):
+    """High throughput Alert delivery to AWS S3"""
+    MAX_RECORD_SIZE = 1000 * 1000
+    MAX_BACKOFF_ATTEMPTS = 3
+
+    __service__ = 'aws-firehose'
+    __aws_client__ = None
+
+    def get_user_defined_properties(self):
+        """Properties assigned by the user when configuring a new Firehose output
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(
+                 description='a short and unique descriptor for this Firehose Delivery Stream')),
+            ('aws_value',
+             OutputProperty(description='the Firehose Delivery Stream name'))
+        ])
+
+    def dispatch(self, **kwargs):
+        """Send alert to a Kinesis Firehose Delivery Stream
+
+        Keyword Args:
+            descriptor (str): Service descriptor (ie: slack channel, pd integration)
+            rule_name (str): Name of the triggered rule
+            alert (dict): Alert relevant to the triggered rule
+
+        Returns:
+            bool: Indicates a successful or failed dispatch of the alert
+        """
+        @backoff.on_exception(backoff.fibo,
+                              ClientError,
+                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
+                              jitter=backoff.full_jitter,
+                              on_backoff=backoff_handler,
+                              on_success=success_handler,
+                              on_giveup=giveup_handler)
+        def _firehose_request_wrapper(json_alert, delivery_stream):
+            """Make the PutRecord request to Kinesis Firehose with backoff
+
+            Args:
+                json_alert (str): The JSON dumped alert body
+                delivery_stream (str): The Firehose Delivery Stream to send to
+
+            Returns:
+                dict: Firehose response in the format below
+                    {'RecordId': 'string'}
+            """
+            return self.__aws_client__.put_record(DeliveryStreamName=delivery_stream,
+                                                  Record={'Data': json_alert})
+
+        if self.__aws_client__ is None:
+            self.__aws_client__ = boto3.client('firehose', region_name=self.region)
+
+        json_alert = json.dumps(kwargs['alert'], separators=(',', ':')) + '\n'
+        if len(json_alert) > self.MAX_RECORD_SIZE:
+            LOGGER.error('Alert too large to send to Firehose: \n%s...', json_alert[0:1000])
+            return False
+
+        delivery_stream = self.config[self.__service__][kwargs['descriptor']]
+        LOGGER.info('Sending alert [%s] to aws-firehose:%s',
+                    kwargs['rule_name'],
+                    delivery_stream)
+
+        resp = _firehose_request_wrapper(json_alert, delivery_stream)
+
+        if resp.get('RecordId'):
+            LOGGER.info('Alert [%s] successfully sent to aws-firehose:%s with RecordId:%s',
+                        kwargs['rule_name'],
+                        delivery_stream,
+                        resp['RecordId'])
+
+        return self._log_status(resp)
+
+
+@StreamAlertOutput
+class S3Output(AWSOutput):
+    """S3Output handles all alert dispatching for AWS S3"""
+    __service__ = 'aws-s3'
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new S3
+        output. This should be sensitive or unique information for this use-case that needs
+        to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        S3 also requires a user provided bucket name to be used for this service output. This
+        value should not be masked during input and is not a credential requirement
+        that needs to be encrypted.
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(
+                 description='a short and unique descriptor for this S3 bucket (ie: bucket name)')),
+            ('aws_value',
+             OutputProperty(description='the AWS S3 bucket name to use for this S3 configuration'))
+        ])
+
+    def dispatch(self, **kwargs):
+        """Send alert to an S3 bucket
+
+        Organizes alert into the following folder structure:
+            service/entity/rule_name/datetime.json
+        The alert gets dumped to a JSON string
+
+        Args:
+            **kwargs: consists of any combination of the following items:
+                descriptor (str): Service descriptor (ie: slack channel, pd integration)
+                rule_name (str): Name of the triggered rule
+                alert (dict): Alert relevant to the triggered rule
+        """
+        alert = kwargs['alert']
+        service = alert['source_service']
+        entity = alert['source_entity']
+
+        current_date = datetime.now()
+
+        s3_alert = alert
+        # JSON dump the alert to retain a consistent alerts schema across log types.
+        # This will get replaced by a UUID which references a record in a
+        # different table in the future.
+        s3_alert['record'] = json.dumps(s3_alert['record'])
+        alert_string = json.dumps(s3_alert)
+
+        bucket = self.config[self.__service__][kwargs['descriptor']]
+
+        # Prefix with alerts to account for generic non-streamalert buckets
+        # Produces the following key format:
+        #   alerts/dt=2017-01-25-00/kinesis_my-stream_my-rule_uuid.json
+        # Keys need to be unique to avoid object overwriting
+        key = 'alerts/dt={}/{}_{}_{}_{}.json'.format(
+            current_date.strftime('%Y-%m-%d-%H'),
+            service,
+            entity,
+            alert['rule_name'],
+            uuid.uuid4()
+        )
+
+        LOGGER.debug('Sending alert to S3 bucket %s with key %s', bucket, key)
+
+        client = boto3.client('s3', region_name=self.region)
+        resp = client.put_object(Body=alert_string,
+                                 Bucket=bucket,
+                                 Key=key)
+
+        return self._log_status(resp)
+
+
+@StreamAlertOutput
+class LambdaOutput(AWSOutput):
+    """LambdaOutput handles all alert dispatching to AWS Lambda"""
+    __service__ = 'aws-lambda'
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new Lambda
+        output. This should be sensitive or unique information for this use-case that needs
+        to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        Sending to Lambda also requires a user provided Lambda function name and optional qualifier
+        (if applicable for the user's use case). A fully-qualified AWS ARN is also acceptable for
+        this value. This value should not be masked during input and is not a credential requirement
+        that needs to be encrypted.
+ + Returns: + OrderedDict: Contains various OutputProperty items + """ + return OrderedDict([ + ('descriptor', + OutputProperty(description='a short and unique descriptor for this Lambda function ' + 'configuration (ie: abbreviated name)')), + ('aws_value', + OutputProperty(description='the AWS Lambda function name, with the optional ' + 'qualifier (aka \'alias\'), to use for this ' + 'configuration (ie: output_function:qualifier)', + input_restrictions={' '})), + ]) + + def dispatch(self, **kwargs): + """Send alert to a Lambda function + + The alert gets dumped to a JSON string to be sent to the Lambda function + + Args: + **kwargs: consists of any combination of the following items: + descriptor (str): Service descriptor (ie: slack channel, pd integration) + rule_name (str): Name of the triggered rule + alert (dict): Alert relevant to the triggered rule + """ + alert = kwargs['alert'] + alert_string = json.dumps(alert['record']) + function_name = self.config[self.__service__][kwargs['descriptor']] + + # Check to see if there is an optional qualifier included here + # Acceptable values for the output configuration are the full ARN, + # a function name followed by a qualifier, or just a function name: + # 'arn:aws:lambda:aws-region:acct-id:function:function-name:prod' + # 'function-name:prod' + # 'function-name' + # Checking the length of the list for 2 or 8 should account for all + # times a qualifier is provided. + parts = function_name.split(':') + if len(parts) == 2 or len(parts) == 8: + function = parts[-2] + qualifier = parts[-1] + else: + function = parts[-1] + qualifier = None + + LOGGER.debug('Sending alert to Lambda function %s', function_name) + + client = boto3.client('lambda', region_name=self.region) + # Use the qualifier if it's available. Passing an empty qualifier in + # with `Qualifier=''` or `Qualifier=None` does not work and thus we + # have to perform different calls to client.invoke(). + if qualifier: + resp = client.invoke(FunctionName=function, + InvocationType='Event', + Payload=alert_string, + Qualifier=qualifier) + else: + resp = client.invoke(FunctionName=function, + InvocationType='Event', + Payload=alert_string) + + return self._log_status(resp) diff --git a/stream_alert/alert_processor/outputs/jira.py b/stream_alert/alert_processor/outputs/jira.py new file mode 100644 index 000000000..537603675 --- /dev/null +++ b/stream_alert/alert_processor/outputs/jira.py @@ -0,0 +1,308 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +from collections import OrderedDict +import json +import os + +from stream_alert.alert_processor import LOGGER +from stream_alert.alert_processor.outputs.output_base import ( + OutputDispatcher, + OutputProperty, + StreamAlertOutput +) + +@StreamAlertOutput +class JiraOutput(OutputDispatcher): + """JiraOutput handles all alert dispatching for Jira""" + __service__ = 'jira' + + DEFAULT_HEADERS = {'Content-Type': 'application/json'} + LOGIN_ENDPOINT = 'rest/auth/1/session' + SEARCH_ENDPOINT = 'rest/api/2/search' + ISSUE_ENDPOINT = 'rest/api/2/issue' + COMMENT_ENDPOINT = 'rest/api/2/issue/{}/comment' + + def __init__(self, *args, **kwargs): + OutputDispatcher.__init__(self, *args, **kwargs) + self._base_url = None + self._auth_cookie = None + + def get_user_defined_properties(self): + """Get properties that must be asssigned by the user when configuring a new Jira + output. This should be sensitive or unique information for this use-case that needs + to come from the user. + + Every output should return a dict that contains a 'descriptor' with a description of the + integration being configured. + + Jira requires a username, password, URL, project key, and issue type for alert dispatching. + These values should be masked during input and are credential requirements. + + An additional parameter 'aggregate' is used to determine if alerts are aggregated into a + single Jira issue based on the StreamAlert rule. + + Returns: + OrderedDict: Contains various OutputProperty items + """ + return OrderedDict([ + ('descriptor', + OutputProperty(description='a short and unique descriptor for this ' + 'Jira integration')), + ('username', + OutputProperty(description='the Jira username', + mask_input=True, + cred_requirement=True)), + ('password', + OutputProperty(description='the Jira password', + mask_input=True, + cred_requirement=True)), + ('url', + OutputProperty(description='the Jira url', + mask_input=True, + cred_requirement=True)), + ('project_key', + OutputProperty(description='the Jira project key', + mask_input=False, + cred_requirement=True)), + ('issue_type', + OutputProperty(description='the Jira issue type', + mask_input=False, + cred_requirement=True)), + ('aggregate', + OutputProperty(description='the Jira aggregation behavior to aggregate ' + 'alerts by rule name (yes/no)', + mask_input=False, + cred_requirement=True)) + ]) + + @classmethod + def _get_default_headers(cls): + """Class method to be used to pass the default headers""" + return cls.DEFAULT_HEADERS.copy() + + def _get_headers(self): + """Instance method used to pass the default headers plus the auth cookie""" + return dict(self._get_default_headers(), **{'cookie': self._auth_cookie}) + + def _search_jira(self, jql, fields=None, max_results=100, validate_query=True): + """Search Jira for issues using a JQL query + + Args: + jql (str): The JQL query + fields (list): List of fields to return for each issue + max_results (int): Maximum number of results to return + validate_query (bool): Whether to validate the JQL query or not + + Returns: + list: list of issues matching JQL query + """ + search_url = os.path.join(self._base_url, self.SEARCH_ENDPOINT) + params = { + 'jql': jql, + 'maxResults': max_results, + 'validateQuery': validate_query, + 'fields': fields + } + + resp = self._get_request(search_url, + params=params, + headers=self._get_headers(), + verify=False) + + success = self._check_http_response(resp) + if not success: + return [] + + return resp.json()['issues'] + + def _create_comment(self, issue_id, comment): 
+ """Add a comment to an existing issue + + Args: + issue_id (str): The existing issue ID or key + comment (str): The body of the comment + + Returns: + int: ID of the created comment or False if unsuccessful + """ + comment_url = os.path.join(self._base_url, self.COMMENT_ENDPOINT.format(issue_id)) + resp = self._post_request(comment_url, + data={'body': comment}, + headers=self._get_headers(), + verify=False) + + success = self._check_http_response(resp) + if not success: + return False + + return resp.json()['id'] + + def _get_comments(self, issue_id): + """Get all comments for an existing Jira issue + + Args: + issue_id (str): The existing issue ID or key + + Returns: + list: List of comments associated with a Jira issue + """ + comment_url = os.path.join(self._base_url, self.COMMENT_ENDPOINT.format(issue_id)) + resp = self._get_request(comment_url, + headers=self._get_headers(), + verify=False) + + success = self._check_http_response(resp) + if not success: + return [] + + return resp.json()['comments'] + + def _get_existing_issue(self, issue_summary, project_key): + """Find an existing Jira issue based on the issue summary + + Args: + issue_summary (str): The Jira issue summary + project_key (str): The Jira project to search + + Returns: + int: ID of the found issue or False if existing issue does not exist + """ + jql = 'summary ~ "{}" and project="{}"'.format(issue_summary, project_key) + resp = self._search_jira(jql, fields=['id', 'summary'], max_results=1) + jira_id = False + + try: + jira_id = int(resp[0]['id']) + except (IndexError, KeyError): + LOGGER.debug('Existing Jira issue not found') + + return jira_id + + def _create_issue(self, issue_name, project_key, issue_type, description): + """Create a Jira issue to write alerts to. Alert is written to the "description" + field of an issue. + + Args: + issue_name (str): The name of the Jira issue + project_key (str): The Jira project key which issues will be associated with + issue_type (str): The type of issue being created + description (str): The body of text which describes the issue + + Returns: + int: ID of the created issue or False if unsuccessful + """ + issue_url = os.path.join(self._base_url, self.ISSUE_ENDPOINT) + issue_body = { + 'fields': { + 'project': { + 'key': project_key + }, + 'summary': issue_name, + 'description': description, + 'issuetype': { + 'name': issue_type + } + } + } + + resp = self._post_request(issue_url, + data=issue_body, + headers=self._get_headers(), + verify=False) + + success = self._check_http_response(resp) + if not success: + return False + + return resp.json()['id'] + + def _establish_session(self, username, password): + """Establish a cookie based Jira session via basic user auth. 
+ + Args: + username (str): The Jira username used for establishing the session + password (str): The Jira password used for establishing the session + + Returns: + str: Header value intended to be passed with every subsequent Jira request + or False if unsuccessful + """ + login_url = os.path.join(self._base_url, self.LOGIN_ENDPOINT) + auth_info = {'username': username, 'password': password} + + resp = self._post_request(login_url, + data=auth_info, + headers=self._get_default_headers(), + verify=False) + + success = self._check_http_response(resp) + if not success: + LOGGER.error("Failed to authenticate to Jira") + return False + resp_dict = resp.json() + + return '{}={}'.format(resp_dict['session']['name'], + resp_dict['session']['value']) + + def dispatch(self, **kwargs): + """Send alert to Jira + + Args: + **kwargs: consists of any combination of the following items: + descriptor (str): Service descriptor (ie: slack channel, pd integration) + rule_name (str): Name of the triggered rule + alert (dict): Alert relevant to the triggered rule + """ + creds = self._load_creds(kwargs['descriptor']) + if not creds: + return self._log_status(False) + + issue_id = None + comment_id = None + issue_summary = 'StreamAlert {}'.format(kwargs['rule_name']) + alert_body = '{{code:JSON}}{}{{code}}'.format(json.dumps(kwargs['alert'])) + self._base_url = creds['url'] + self._auth_cookie = self._establish_session(creds['username'], creds['password']) + + # Validate successful authentication + if not self._auth_cookie: + return self._log_status(False) + + # If aggregation is enabled, attempt to add alert to an existing issue. If a + # failure occurs in this block, creation of a new Jira issue will be attempted. + if creds.get('aggregate', '').lower() == 'yes': + issue_id = self._get_existing_issue(issue_summary, creds['project_key']) + if issue_id: + comment_id = self._create_comment(issue_id, alert_body) + if comment_id: + LOGGER.debug('Sending alert to an existing Jira issue %s with comment %s', + issue_id, + comment_id) + return self._log_status(True) + else: + LOGGER.error('Encountered an error when adding alert to existing ' + 'Jira issue %s. 
Attempting to create new Jira issue.', + issue_id) + + # Create a new Jira issue + issue_id = self._create_issue(issue_summary, + creds['project_key'], + creds['issue_type'], + alert_body) + if issue_id: + LOGGER.debug('Sending alert to a new Jira issue %s', issue_id) + + return self._log_status(issue_id or comment_id) diff --git a/stream_alert/alert_processor/output_base.py b/stream_alert/alert_processor/outputs/output_base.py similarity index 89% rename from stream_alert/alert_processor/output_base.py rename to stream_alert/alert_processor/outputs/output_base.py index e6472801d..937572409 100644 --- a/stream_alert/alert_processor/output_base.py +++ b/stream_alert/alert_processor/outputs/output_base.py @@ -36,7 +36,44 @@ class OutputRequestFailure(Exception): """OutputRequestFailure handles any HTTP failures""" -class StreamOutputBase(object): +class StreamAlertOutput(object): + """Class to be used as a decorator to register all OutputDispatcher subclasses""" + _outputs = {} + + def __new__(cls, output): + StreamAlertOutput._outputs[output.__service__] = output + return output + + @classmethod + def get_output_dispatcher(cls, service, region, function_name, config): + """Returns the subclass that should handle this particular service + + Args: + service (str): The service identifier for this output + region (str): The AWS region to use for some output types + function_name (str): The invoking AWS Lambda function name + config (dict): The loaded output configuration dict + + Returns: + OutputDispatcher: Subclass of OutputDispatcher to use for sending alerts + """ + try: + return cls._outputs[service](region, function_name, config) + except KeyError: + LOGGER.error('Designated output service [%s] does not exist', service) + + @classmethod + def get_all_outputs(cls): + """Return a copy of the cache containing all of the output subclasses + + Returns: + dict: Cached dictionary of all registered StreamAlertOutputs where + the key is the service and the value is the class object + """ + return cls._outputs.copy() + + +class OutputDispatcher(object): """StreamOutputBase is the base class to handle routing alerts to outputs Public methods: diff --git a/stream_alert/alert_processor/outputs/pagerduty.py b/stream_alert/alert_processor/outputs/pagerduty.py new file mode 100644 index 000000000..d1b03ee14 --- /dev/null +++ b/stream_alert/alert_processor/outputs/pagerduty.py @@ -0,0 +1,421 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from collections import OrderedDict +import os + +from stream_alert.alert_processor import LOGGER +from stream_alert.alert_processor.outputs.output_base import ( + OutputDispatcher, + OutputProperty, + StreamAlertOutput +) + + +@StreamAlertOutput +class PagerDutyOutput(OutputDispatcher): + """PagerDutyOutput handles all alert dispatching for PagerDuty Events API v1""" + __service__ = 'pagerduty' + + @classmethod + def _get_default_properties(cls): + """Get the standard url used for PagerDuty. 
This value is the same for everyone, so
+        is hard-coded here and does not need to be configured by the user
+        Returns:
+            dict: Contains various default items for this output (ie: url)
+        """
+        return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'}
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new PagerDuty
+        output. This should be sensitive or unique information for this use-case that needs
+        to come from the user.
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+        PagerDuty also requires a service_key that represents this integration. This
+        value should be masked during input and is a credential requirement.
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(description='a short and unique descriptor for this '
+                                        'PagerDuty integration')),
+            ('service_key',
+             OutputProperty(description='the service key for this PagerDuty integration',
+                            mask_input=True,
+                            cred_requirement=True))
+        ])
+
+    def dispatch(self, **kwargs):
+        """Send alert to PagerDuty
+        Args:
+            **kwargs: consists of any combination of the following items:
+                descriptor (str): Service descriptor (ie: slack channel, pd integration)
+                rule_name (str): Name of the triggered rule
+                alert (dict): Alert relevant to the triggered rule
+        """
+        creds = self._load_creds(kwargs['descriptor'])
+        if not creds:
+            return self._log_status(False)
+
+        message = 'StreamAlert Rule Triggered - {}'.format(kwargs['rule_name'])
+        rule_desc = kwargs['alert']['rule_description']
+        details = {
+            'rule_description': rule_desc,
+            'record': kwargs['alert']['record']
+        }
+        data = {
+            'service_key': creds['service_key'],
+            'event_type': 'trigger',
+            'description': message,
+            'details': details,
+            'client': 'StreamAlert'
+        }
+
+        resp = self._post_request(creds['url'], data, None, True)
+        success = self._check_http_response(resp)
+
+        return self._log_status(success)
+
+@StreamAlertOutput
+class PagerDutyOutputV2(OutputDispatcher):
+    """PagerDutyOutputV2 handles all alert dispatching for PagerDuty Events API v2"""
+    __service__ = 'pagerduty-v2'
+
+    @classmethod
+    def _get_default_properties(cls):
+        """Get the standard url used for PagerDuty Events API v2. This value is the same for
+        everyone, so is hard-coded here and does not need to be configured by the user
+
+        Returns:
+            dict: Contains various default items for this output (ie: url)
+        """
+        return {'url': 'https://events.pagerduty.com/v2/enqueue'}
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new PagerDuty
+        event output. This should be sensitive or unique information for this use-case that
+        needs to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        PagerDuty also requires a routing_key that represents this integration. This
+        value should be masked during input and is a credential requirement.
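+
+        For illustration, the routing_key is sent with every v2 event in a payload shaped
+        roughly like the following (values are placeholders):
+            {'routing_key': '<routing_key>', 'event_action': 'trigger',
+             'payload': {'summary': '...', 'source': '...', 'severity': 'critical'}}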
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(description='a short and unique descriptor for this '
+                                        'PagerDuty integration')),
+            ('routing_key',
+             OutputProperty(description='the routing key for this PagerDuty integration',
+                            mask_input=True,
+                            cred_requirement=True))
+        ])
+
+    def dispatch(self, **kwargs):
+        """Send alert to PagerDuty
+
+        Args:
+            **kwargs: consists of any combination of the following items:
+                descriptor (str): Service descriptor (ie: slack channel, pd integration)
+                rule_name (str): Name of the triggered rule
+                alert (dict): Alert relevant to the triggered rule
+        """
+        creds = self._load_creds(kwargs['descriptor'])
+        if not creds:
+            return self._log_status(False)
+
+        summary = 'StreamAlert Rule Triggered - {}'.format(kwargs['rule_name'])
+
+        details = {
+            'rule_description': kwargs['alert']['rule_description'],
+            'record': kwargs['alert']['record']
+        }
+        payload = {
+            'summary': summary,
+            'source': kwargs['alert']['log_source'],
+            'severity': 'critical',
+            'custom_details': details
+        }
+        data = {
+            'routing_key': creds['routing_key'],
+            'payload': payload,
+            'event_action': 'trigger',
+            'client': 'StreamAlert'
+        }
+
+        resp = self._post_request(creds['url'], data, None, True)
+        success = self._check_http_response(resp)
+
+        return self._log_status(success)
+
+@StreamAlertOutput
+class PagerDutyIncidentOutput(OutputDispatcher):
+    """PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2"""
+    __service__ = 'pagerduty-incident'
+    INCIDENTS_ENDPOINT = 'incidents'
+    USERS_ENDPOINT = 'users'
+    POLICIES_ENDPOINT = 'escalation_policies'
+    SERVICES_ENDPOINT = 'services'
+
+    def __init__(self, *args, **kwargs):
+        OutputDispatcher.__init__(self, *args, **kwargs)
+        self._base_url = None
+        self._headers = None
+        self._escalation_policy = None
+
+    @classmethod
+    def _get_default_properties(cls):
+        """Get the standard url used for PagerDuty Incidents API v2. This value is the same for
+        everyone, so is hard-coded here and does not need to be configured by the user
+
+        Returns:
+            dict: Contains various default items for this output (ie: url)
+        """
+        return {'api': 'https://api.pagerduty.com'}
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new PagerDuty
+        incident output. This should be sensitive or unique information for this use-case that
+        needs to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        PagerDuty also requires a user token and a service key for this integration. These
+        values should be masked during input and are credential requirements.
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(description='a short and unique descriptor for this '
+                                        'PagerDuty integration')),
+            ('token',
+             OutputProperty(description='the token for this PagerDuty integration',
+                            mask_input=True,
+                            cred_requirement=True)),
+            ('service_key',
+             OutputProperty(description='the service key for this PagerDuty integration',
+                            mask_input=True,
+                            cred_requirement=True)),
+            ('escalation_policy',
+             OutputProperty(description='the name of the default escalation policy'))
+        ])
+
+    @staticmethod
+    def _get_endpoint(base_url, endpoint):
+        """Helper to get the full url for a PagerDuty Incidents endpoint.
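+        For example (arguments shown are illustrative):
+            _get_endpoint('https://api.pagerduty.com', 'incidents')
+            # => 'https://api.pagerduty.com/incidents'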
+
+        Args:
+            base_url (str): Base URL for the API
+            endpoint (str): Endpoint that we want the full URL for
+
+        Returns:
+            str: Full URL of the provided endpoint
+        """
+        return os.path.join(base_url, endpoint)
+
+    def _check_exists_get_id(self, filter_str, url, target_key):
+        """Generic method to run a search in the PagerDuty REST API and return the id
+        of the first occurrence from the results.
+
+        Args:
+            filter_str (str): The query filter to search for in the API
+            url (str): The url to send the requests to in the API
+            target_key (str): The key to extract in the returned results
+
+        Returns:
+            str: ID of the targeted element that matches the provided filter or
+                False if a matching element does not exist.
+        """
+        params = {
+            'query': '{}'.format(filter_str)
+        }
+        resp = self._get_request(url, params, self._headers, False)
+
+        if not self._check_http_response(resp):
+            return False
+
+        response = resp.json()
+        if not response:
+            return False
+
+        # If there are results, get the first occurrence from the list
+        return response[target_key][0]['id'] if target_key in response else False
+
+    def _user_verify(self, user):
+        """Method to verify the existence of a user with the API
+
+        Args:
+            user (str): User to query about in the API.
+
+        Returns:
+            dict or False: JSON object to be used in the API call, containing the user_id
+                and user_reference. False if user is not found
+        """
+        users_url = self._get_endpoint(self._base_url, self.USERS_ENDPOINT)
+
+        return self._item_verify(users_url, user, self.USERS_ENDPOINT, 'user_reference')
+
+    def _policy_verify(self, policy, default_policy):
+        """Method to verify the existence of an escalation policy with the API
+
+        Args:
+            policy (str): Escalation policy to query about in the API
+            default_policy (str): Escalation policy to use if the first one is not verified
+
+        Returns:
+            dict: JSON object to be used in the API call, containing the policy_id
+                and escalation_policy_reference
+        """
+        policies_url = self._get_endpoint(self._base_url, self.POLICIES_ENDPOINT)
+
+        verified = self._item_verify(policies_url, policy, self.POLICIES_ENDPOINT,
+                                     'escalation_policy_reference')
+
+        # If the escalation policy provided is not verified in the API, use the default
+        if verified:
+            return verified
+
+        return self._item_verify(policies_url, default_policy, self.POLICIES_ENDPOINT,
+                                 'escalation_policy_reference')
+
+    def _service_verify(self, service):
+        """Method to verify the existence of a service with the API
+
+        Args:
+            service (str): Service to query about in the API
+
+        Returns:
+            dict: JSON object to be used in the API call, containing the service_id
+                and the service_reference
+        """
+        services_url = self._get_endpoint(self._base_url, self.SERVICES_ENDPOINT)
+
+        return self._item_verify(services_url, service, self.SERVICES_ENDPOINT, 'service_reference')
+
+    def _item_verify(self, item_url, item_str, item_key, item_type):
+        """Method to verify the existence of an item with the API
+
+        Args:
+            item_url (str): URL of the API Endpoint within the API to query
+            item_str (str): Service to query about in the API
+            item_key (str): Key to be extracted from search results
+            item_type (str): Type of item reference to be returned
+
+        Returns:
+            dict: JSON object to be used in the API call, containing the item id
+                and the item reference, or False if it fails
+        """
+        item_id = self._check_exists_get_id(item_str, item_url, item_key)
+        if not item_id:
+            LOGGER.info('%s not found in %s, %s', item_str, item_key, self.__service__)
+            return False
+
+        return {
+            'id': item_id,
+            'type':
item_type + } + + def _incident_assignment(self, context): + """Method to determine if the incident gets assigned to a user or an escalation policy + + Args: + context (dict): Context provided in the alert record + + Returns: + tuple: assigned_key (str), assigned_value (dict to assign incident to an escalation + policy or array of dicts to assign incident to users) + """ + # Check if a user to assign the incident is provided + user_to_assign = context.get('assigned_user', False) + + # If provided, verify the user and get the id from API + if user_to_assign: + user_assignee = self._user_verify(user_to_assign) + # User is verified, return tuple + if user_assignee: + return 'assignments', [{'assignee': user_assignee}] + + # If escalation policy was not provided, use default one + policy_to_assign = context.get('assigned_policy', self._escalation_policy) + + # Verify escalation policy, return tuple + return 'escalation_policy', self._policy_verify(policy_to_assign, self._escalation_policy) + + def dispatch(self, **kwargs): + """Send incident to Pagerduty Incidents API v2 + + Keyword Args: + **kwargs: consists of any combination of the following items: + descriptor (str): Service descriptor (ie: slack channel, pd integration) + rule_name (str): Name of the triggered rule + alert (dict): Alert relevant to the triggered rule + alert['context'] (dict): Provides user or escalation policy + """ + creds = self._load_creds(kwargs['descriptor']) + if not creds: + return self._log_status(False) + + # Preparing headers for API calls + self._headers = { + 'Authorization': 'Token token={}'.format(creds['token']), + 'Accept': 'application/vnd.pagerduty+json;version=2' + } + + # Cache base_url + self._base_url = creds['api'] + + # Cache default escalation policy + self._escalation_policy = creds['escalation_policy'] + + # Extracting context data to assign the incident + rule_context = kwargs['alert'].get('context', {}) + if rule_context: + rule_context = rule_context.get(self.__service__, {}) + + # Incident assignment goes in this order: + # Provided user -> provided policy -> default policy + assigned_key, assigned_value = self._incident_assignment(rule_context) + + # Start preparing the incident JSON blob to be sent to the API + incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(kwargs['rule_name']) + incident_body = { + 'type': 'incident_body', + 'details': kwargs['alert']['rule_description'] + } + # We need to get the service id from the API + incident_service = self._service_verify(creds['service_key']) + incident = { + 'incident': { + 'type': 'incident', + 'title': incident_title, + 'service': incident_service, + 'body': incident_body + }, + assigned_key: assigned_value + } + incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT) + resp = self._post_request(incidents_url, incident, self._headers, True) + success = self._check_http_response(resp) + + return self._log_status(success) diff --git a/stream_alert/alert_processor/outputs/phantom.py b/stream_alert/alert_processor/outputs/phantom.py new file mode 100644 index 000000000..60e24fc64 --- /dev/null +++ b/stream_alert/alert_processor/outputs/phantom.py @@ -0,0 +1,155 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from collections import OrderedDict
+import os
+
+from stream_alert.alert_processor import LOGGER
+from stream_alert.alert_processor.outputs.output_base import (
+    OutputDispatcher,
+    OutputProperty,
+    StreamAlertOutput
+)
+
+
+@StreamAlertOutput
+class PhantomOutput(OutputDispatcher):
+    """PhantomOutput handles all alert dispatching for Phantom"""
+    __service__ = 'phantom'
+    CONTAINER_ENDPOINT = 'rest/container'
+    ARTIFACT_ENDPOINT = 'rest/artifact'
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new Phantom
+        output. This should be sensitive or unique information for this use-case that needs
+        to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        Phantom also requires a ph_auth_token that represents an authorization token for this
+        integration and a user provided url to use for alert dispatching. These values should be
+        masked during input and are credential requirements.
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(description='a short and unique descriptor for this '
+                                        'Phantom integration')),
+            ('ph_auth_token',
+             OutputProperty(description='the auth token for this Phantom integration',
+                            mask_input=True,
+                            cred_requirement=True)),
+            ('url',
+             OutputProperty(description='the endpoint url for this Phantom integration',
+                            mask_input=True,
+                            cred_requirement=True))
+        ])
+
+    def _check_container_exists(self, rule_name, container_url, headers):
+        """Check to see if a Phantom container already exists for this rule
+
+        Args:
+            rule_name (str): The name of the rule that triggered the alert
+            container_url (str): The constructed container url for this Phantom instance
+            headers (dict): A dictionary containing header parameters
+
+        Returns:
+            int: ID of an existing Phantom container for this rule where the alerts
+                will be sent or False if a matching container does not yet exist
+        """
+        # Limit the query to 1 page, since we only care if one container exists with
+        # this name.
+        params = {
+            '_filter_name': '"{}"'.format(rule_name),
+            'page_size': 1
+        }
+        resp = self._get_request(container_url, params, headers, False)
+        if not self._check_http_response(resp):
+            return False
+
+        response = resp.json()
+
+        # If the count == 0 then we know there are no containers with this name and this
+        # will evaluate to False. Otherwise there is at least one item in the list
+        # of 'data' with a container id we can use
+        return response and response.get('count') and response.get('data')[0]['id']
+
+    def _setup_container(self, rule_name, rule_description, base_url, headers):
+        """Establish a Phantom container to write the alerts to. This checks to see
+        if an appropriate container exists first and returns the ID if so.
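+        Roughly, the flow against the Phantom REST API looks like this (illustrative only):
+            GET  rest/container?_filter_name="<rule_name>"&page_size=1  -> reuse existing ID
+            POST rest/container {'name': <rule_name>, ...}              -> create a new ID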
+
+        Args:
+            rule_name (str): The name of the rule that triggered the alert
+            rule_description (str): The description of the rule that triggered the alert
+            base_url (str): The base url for this Phantom instance
+            headers (dict): A dictionary containing header parameters
+
+        Returns:
+            int: ID of the Phantom container where the alerts will be sent
+                or False if there is an issue getting the container id
+        """
+        container_url = os.path.join(base_url, self.CONTAINER_ENDPOINT)
+
+        # Check to see if there is a container already created for this rule name
+        existing_id = self._check_container_exists(rule_name, container_url, headers)
+        if existing_id:
+            return existing_id
+
+        # Try to use the rule_description from the rule as the container description
+        ph_container = {'name': rule_name, 'description': rule_description}
+        resp = self._post_request(container_url, ph_container, headers, False)
+
+        if not self._check_http_response(resp):
+            return False
+
+        response = resp.json()
+
+        return response and response.get('id')
+
+    def dispatch(self, **kwargs):
+        """Send alert to Phantom
+
+        Args:
+            **kwargs: consists of any combination of the following items:
+                descriptor (str): Service descriptor (ie: slack channel, pd integration)
+                rule_name (str): Name of the triggered rule
+                alert (dict): Alert relevant to the triggered rule
+        """
+        creds = self._load_creds(kwargs['descriptor'])
+        if not creds:
+            return self._log_status(False)
+
+        headers = {"ph-auth-token": creds['ph_auth_token']}
+        rule_desc = kwargs['alert']['rule_description']
+        container_id = self._setup_container(kwargs['rule_name'], rule_desc,
+                                             creds['url'], headers)
+
+        LOGGER.debug('sending alert to Phantom container with id %s', container_id)
+
+        success = False
+        if container_id:
+            artifact = {'cef': kwargs['alert']['record'],
+                        'container_id': container_id,
+                        'data': kwargs['alert'],
+                        'name': 'Phantom Artifact',
+                        'label': 'Alert'}
+            artifact_url = os.path.join(creds['url'], self.ARTIFACT_ENDPOINT)
+            resp = self._post_request(artifact_url, artifact, headers, False)
+
+            success = self._check_http_response(resp)
+
+        return self._log_status(success)
diff --git a/stream_alert/alert_processor/outputs/slack.py b/stream_alert/alert_processor/outputs/slack.py
new file mode 100644
index 000000000..f1084ca54
--- /dev/null
+++ b/stream_alert/alert_processor/outputs/slack.py
@@ -0,0 +1,228 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import cgi
+from collections import OrderedDict
+
+from stream_alert.alert_processor.outputs.output_base import (
+    OutputDispatcher,
+    OutputProperty,
+    StreamAlertOutput
+)
+
+
+@StreamAlertOutput
+class SlackOutput(OutputDispatcher):
+    """SlackOutput handles all alert dispatching for Slack"""
+    __service__ = 'slack'
+    # Slack recommends no messages larger than 4000 bytes. This does not account for unicode
+    MAX_MESSAGE_SIZE = 4000
+
+    def get_user_defined_properties(self):
+        """Get properties that must be assigned by the user when configuring a new Slack
+        output.
This should be sensitive or unique information for this use-case that needs
+        to come from the user.
+
+        Every output should return a dict that contains a 'descriptor' with a description of the
+        integration being configured.
+
+        Slack also requires a user provided 'webhook' url that is composed of the slack api url
+        and the unique integration key for this output. This value should be masked
+        during input and is a credential requirement.
+
+        Returns:
+            OrderedDict: Contains various OutputProperty items
+        """
+        return OrderedDict([
+            ('descriptor',
+             OutputProperty(description='a short and unique descriptor for this Slack integration '
+                                        '(ie: channel, group, etc)')),
+            ('url',
+             OutputProperty(description='the full Slack webhook url, including the secret',
+                            mask_input=True,
+                            cred_requirement=True))
+        ])
+
+    @classmethod
+    def _format_message(cls, rule_name, alert):
+        """Format the message to be sent to slack.
+
+        Args:
+            rule_name (str): The name of the rule that triggered the alert
+            alert (dict): Alert relevant to the triggered rule
+
+        Returns:
+            dict: message with attachments to send to Slack.
+                The message will look like:
+                    StreamAlert Rule Triggered: rule_name
+                    Rule Description:
+                    This will be the docstring from the rule, sent as the rule_description
+
+                    Record (Part 1 of 2):
+                    ...
+        """
+        # Convert the alert we have to a nicely formatted string for slack
+        alert_text = '\n'.join(cls._json_to_slack_mrkdwn(alert['record'], 0))
+        # Slack requires escaping the characters: '&', '>' and '<' and cgi does just that
+        alert_text = cgi.escape(alert_text)
+        messages = []
+        index = cls.MAX_MESSAGE_SIZE
+        while alert_text != '':
+            if len(alert_text) <= index:
+                messages.append(alert_text)
+                break
+
+            # Find the closest line break prior to this index
+            while index > 1 and alert_text[index] != '\n':
+                index -= 1
+
+            # Append the message part up until this index, and move to the next chunk
+            messages.append(alert_text[:index])
+            alert_text = alert_text[index+1:]
+
+            index = cls.MAX_MESSAGE_SIZE
+
+        header_text = '*StreamAlert Rule Triggered: {}*'.format(rule_name)
+        full_message = {
+            'text': header_text,
+            'mrkdwn': True,
+            'attachments': []
+        }
+
+        for index, message in enumerate(messages):
+            title = 'Record:'
+            if len(messages) > 1:
+                title = 'Record (Part {} of {}):'.format(index+1, len(messages))
+            rule_desc = ''
+            # Only print the rule description on the first attachment
+            if index == 0:
+                rule_desc = alert['rule_description']
+                rule_desc = '*Rule Description:*\n{}\n'.format(rule_desc)
+
+            # Add this attachment to the full message array of attachments
+            full_message['attachments'].append({
+                'fallback': header_text,
+                'color': '#b22222',
+                'pretext': rule_desc,
+                'title': title,
+                'text': message,
+                'mrkdwn_in': ['text', 'pretext']
+            })
+
+        # Return the json dict payload to be sent to slack
+        return full_message
+
+    @classmethod
+    def _json_to_slack_mrkdwn(cls, json_values, indent_count):
+        """Translate a json object into a more human-readable list of lines
+        This will handle recursion of all nested maps and lists within the object
+
+        Args:
+            json_values: variant to be translated (could be json map, list, etc)
+            indent_count (int): Number of tabs to prefix each line with
+
+        Returns:
+            list: strings that have been properly tabbed and formatted for printing
+        """
+        tab = '\t'
+        all_lines = []
+        if isinstance(json_values, dict):
+            all_lines = cls._json_map_to_text(json_values, tab, indent_count)
+        elif isinstance(json_values, list):
+            all_lines = cls._json_list_to_text(json_values,
tab, indent_count) + else: + all_lines.append('{}'.format(json_values)) + + return all_lines + + @classmethod + def _json_map_to_text(cls, json_values, tab, indent_count): + """Translate a map from json (dict) into a more human-readable list of lines + This will handle recursion of all nested maps and lists within the map + + Args: + json_values (dict): dictionary to be iterated over and formatted + tab (str): string value to use for indentation + indent_count (int): Number of tabs to prefix each line with + + Returns: + list: strings that have been properly tabbed and formatted for printing + """ + all_lines = [] + for key, value in json_values.iteritems(): + if isinstance(value, (dict, list)) and value: + all_lines.append('{}*{}:*'.format(tab*indent_count, key)) + all_lines.extend(cls._json_to_slack_mrkdwn(value, indent_count+1)) + else: + new_lines = cls._json_to_slack_mrkdwn(value, indent_count+1) + if len(new_lines) == 1: + all_lines.append('{}*{}:* {}'.format(tab*indent_count, key, new_lines[0])) + elif new_lines: + all_lines.append('{}*{}:*'.format(tab*indent_count, key)) + all_lines.extend(new_lines) + else: + all_lines.append('{}*{}:* {}'.format(tab*indent_count, key, value)) + + return all_lines + + @classmethod + def _json_list_to_text(cls, json_values, tab, indent_count): + """Translate a list from json into a more human-readable list of lines + This will handle recursion of all nested maps and lists within the list + + Args: + json_values (list): list to be iterated over and formatted + tab (str): string value to use for indentation + indent_count (int): Number of tabs to prefix each line with + + Returns: + list: strings that have been properly tabbed and formatted for printing + """ + all_lines = [] + for index, value in enumerate(json_values): + if isinstance(value, (dict, list)) and value: + all_lines.append('{}*[{}]*'.format(tab*indent_count, index+1)) + all_lines.extend(cls._json_to_slack_mrkdwn(value, indent_count+1)) + else: + new_lines = cls._json_to_slack_mrkdwn(value, indent_count+1) + if len(new_lines) == 1: + all_lines.append('{}*[{}]* {}'.format(tab*indent_count, index+1, new_lines[0])) + elif new_lines: + all_lines.append('{}*[{}]*'.format(tab*indent_count, index+1)) + all_lines.extend(new_lines) + else: + all_lines.append('{}*[{}]* {}'.format(tab*indent_count, index+1, value)) + + return all_lines + + def dispatch(self, **kwargs): + """Send alert text to Slack + + Args: + **kwargs: consists of any combination of the following items: + descriptor (str): Service descriptor (ie: slack channel, pd integration) + rule_name (str): Name of the triggered rule + alert (dict): Alert relevant to the triggered rule + """ + creds = self._load_creds(kwargs['descriptor']) + if not creds: + return self._log_status(False) + + slack_message = self._format_message(kwargs['rule_name'], kwargs['alert']) + + resp = self._post_request(creds['url'], slack_message) + success = self._check_http_response(resp) + + return self._log_status(success) diff --git a/stream_alert_cli/runner.py b/stream_alert_cli/runner.py index d50f842e0..53697660f 100644 --- a/stream_alert_cli/runner.py +++ b/stream_alert_cli/runner.py @@ -14,7 +14,7 @@ limitations under the License. 
""" from app_integrations.apps.app_base import StreamAlertApp -from stream_alert.alert_processor.outputs import get_output_dispatcher +from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput from stream_alert_cli.apps import save_app_auth_info from stream_alert_cli.athena.handler import athena_handler from stream_alert_cli.config import CLIConfig @@ -108,8 +108,12 @@ def configure_output(options): kms_key_alias = kms_key_alias.split('/')[1] # Retrieve the proper service class to handle dispatching the alerts of this services - output = get_output_dispatcher(options.service, region, prefix, - config_outputs.load_outputs_config()) + output = StreamAlertOutput.get_output_dispatcher( + options.service, + region, + prefix, + config_outputs.load_outputs_config() + ) # If an output for this service has not been defined, the error is logged # prior to this From 9735069ad3907aa4098d92885ff56f8fcb6de1e8 Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Sat, 18 Nov 2017 13:59:48 -0800 Subject: [PATCH 2/9] [tests] output unit test updates to go with broken out code --- stream_alert/alert_processor/main.py | 17 +- stream_alert/alert_processor/outputs/aws.py | 16 +- stream_alert/alert_processor/outputs/jira.py | 3 +- .../alert_processor/outputs/output_base.py | 30 +- .../alert_processor/outputs/pagerduty.py | 9 +- .../alert_processor/outputs/phantom.py | 3 +- stream_alert/alert_processor/outputs/slack.py | 3 +- stream_alert_cli/runner.py | 7 +- .../stream_alert_alert_processor/helpers.py | 18 +- .../test_helpers.py | 7 +- .../stream_alert_alert_processor/test_main.py | 50 +- .../test_output_base.py | 177 -- .../test_outputs.py | 1527 ----------------- .../test_outputs/__init__.py | 21 + .../test_outputs/test_aws.py | 215 +++ .../test_outputs/test_jira.py | 199 +++ .../test_outputs/test_output_base.py | 207 +++ .../test_outputs/test_pagerduty.py | 675 ++++++++ .../test_outputs/test_phantom.py | 205 +++ .../test_outputs/test_slack.py | 235 +++ tests/unit/stream_alert_cli/test_outputs.py | 2 +- 21 files changed, 1833 insertions(+), 1793 deletions(-) delete mode 100644 tests/unit/stream_alert_alert_processor/test_output_base.py delete mode 100644 tests/unit/stream_alert_alert_processor/test_outputs.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/__init__.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py diff --git a/stream_alert/alert_processor/main.py b/stream_alert/alert_processor/main.py index bb7eea0c3..3efdedecd 100644 --- a/stream_alert/alert_processor/main.py +++ b/stream_alert/alert_processor/main.py @@ -101,23 +101,18 @@ def run(alert, region, function_name, config): continue # Retrieve the proper class to handle dispatching the alerts of this services - output_dispatcher = StreamAlertOutput.get_output_dispatcher( - service, - region, - function_name, - config - ) - - if not output_dispatcher: + dispatcher = StreamAlertOutput.create_dispatcher(service, region, function_name, config) + + if not dispatcher: continue LOGGER.debug('Sending alert to %s:%s', service, descriptor) sent = 
False
         try:
-            sent = output_dispatcher.dispatch(descriptor=descriptor,
-                                              rule_name=alert['rule_name'],
-                                              alert=alert)
+            sent = dispatcher.dispatch(descriptor=descriptor,
+                                       rule_name=alert['rule_name'],
+                                       alert=alert)
         except Exception as err:  # pylint: disable=broad-except
             LOGGER.exception('An error occurred while sending alert '
diff --git a/stream_alert/alert_processor/outputs/aws.py b/stream_alert/alert_processor/outputs/aws.py
index 418adb086..c144ad293 100644
--- a/stream_alert/alert_processor/outputs/aws.py
+++ b/stream_alert/alert_processor/outputs/aws.py
@@ -38,7 +38,9 @@
 class AWSOutput(OutputDispatcher):
     """Subclass to be inherited from for all AWS service outputs"""
 
-    def format_output_config(self, service_config, values):
+
+    @classmethod
+    def format_output_config(cls, service_config, values):
         """Format the output configuration for this AWS service to be written to disk
 
         AWS services are stored as a dictionary within the config instead of a list so
@@ -56,13 +58,12 @@ def format_output_config(self, service_config, values):
         subclasses should use a generic 'aws_value' to store the value for the
         descriptor used in configuration
         """
-        return dict(service_config.get(self.__service__, {}),
+        return dict(service_config.get(cls.__service__, {}),
                     **{values['descriptor'].value: values['aws_value'].value})
 
     @abstractmethod
     def dispatch(self, **kwargs):
         """Placeholder for implementation in the subclasses"""
-        pass
 
 
 @StreamAlertOutput
@@ -74,7 +75,8 @@ class KinesisFirehoseOutput(AWSOutput):
     __service__ = 'aws-firehose'
     __aws_client__ = None
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Properties assigned by the user when configuring a new Firehose output
 
         Every output should return a dict that contains a 'descriptor' with a description of the
@@ -152,7 +154,8 @@ class S3Output(AWSOutput):
     """S3Output handles all alert dispatching for AWS S3"""
     __service__ = 'aws-s3'
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new S3
         output. This should be sensitive or unique information for this use-case that needs
         to come from the user.
@@ -230,7 +233,8 @@ class LambdaOutput(AWSOutput):
     """LambdaOutput handles all alert dispatching to AWS Lambda"""
     __service__ = 'aws-lambda'
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new Lambda
         output. This should be sensitive or unique information for this use-case that needs
         to come from the user.
diff --git a/stream_alert/alert_processor/outputs/jira.py b/stream_alert/alert_processor/outputs/jira.py
index 537603675..bf2b9f12f 100644
--- a/stream_alert/alert_processor/outputs/jira.py
+++ b/stream_alert/alert_processor/outputs/jira.py
@@ -40,7 +40,8 @@ def __init__(self, *args, **kwargs):
         self._base_url = None
         self._auth_cookie = None
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new Jira
         output. This should be sensitive or unique information for this use-case that needs
         to come from the user.
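
The registration and lookup flow introduced by this refactor can be exercised roughly as
follows. This is a minimal sketch, assuming the outputs package is importable; the service
name, region, function name, and config values below are illustrative only:

    from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput
    # Importing an output module runs its @StreamAlertOutput decorator, which
    # registers the class in StreamAlertOutput._outputs keyed by __service__
    import stream_alert.alert_processor.outputs.slack  # registration side effect

    # Class-level lookup, used by the CLI where no instance is needed
    dispatcher_cls = StreamAlertOutput.get_dispatcher('slack')
    props = dispatcher_cls.get_user_defined_properties()  # now a classmethod

    # Instantiated dispatcher, used by the Lambda handler to actually send alerts
    dispatcher = StreamAlertOutput.create_dispatcher(
        'slack', 'us-east-1', 'prefix_streamalert_alert_processor',
        {'slack': ['unit_test_channel']})

Splitting get_dispatcher out of create_dispatcher lets the CLI fetch the class (and its
properties) without constructing an instance, while the alert processor still gets a
ready-to-use dispatcher from a single call.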
diff --git a/stream_alert/alert_processor/outputs/output_base.py b/stream_alert/alert_processor/outputs/output_base.py
index 937572409..7f6f3d724 100644
--- a/stream_alert/alert_processor/outputs/output_base.py
+++ b/stream_alert/alert_processor/outputs/output_base.py
@@ -45,7 +45,7 @@ def __new__(cls, output):
         return output
 
     @classmethod
-    def get_output_dispatcher(cls, service, region, function_name, config):
+    def create_dispatcher(cls, service, region, function_name, config):
         """Returns the subclass that should handle this particular service
 
         Args:
@@ -54,11 +54,27 @@ def create_dispatcher(cls, service, region, function_name, config):
             function_name (str): The invoking AWS Lambda function name
             config (dict): The loaded output configuration dict
 
+        Returns:
+            OutputDispatcher: Subclass of OutputDispatcher to use for sending alerts
+        """
+        dispatcher = cls.get_dispatcher(service)
+        if not dispatcher:
+            return False
+
+        return dispatcher(region, function_name, config)
+
+    @classmethod
+    def get_dispatcher(cls, service):
+        """Returns the subclass that should handle this particular service
+
+        Args:
+            service (str): The service identifier for this output
+
         Returns:
             OutputDispatcher: Subclass of OutputDispatcher to use for sending alerts
         """
         try:
-            return cls._outputs[service](region, function_name, config)
+            return cls._outputs[service]
         except KeyError:
             LOGGER.error('Designated output service [%s] does not exist', service)
 
@@ -74,7 +90,7 @@ def get_all_outputs(cls):
 
 
 class OutputDispatcher(object):
-    """StreamOutputBase is the base class to handle routing alerts to outputs
+    """OutputDispatcher is the base class to handle routing alerts to outputs
 
     Public methods:
         get_secrets_bucket_name: returns the name of the s3 bucket for secrets that
@@ -305,7 +321,8 @@ def output_cred_name(self, descriptor):
 
         return cred_name
 
-    def format_output_config(self, service_config, values):
+    @classmethod
+    def format_output_config(cls, service_config, values):
         """Add this descriptor to the list of descriptors for this service
 
         If the service doesn't exist, a new entry is added to an empty list
@@ -315,10 +332,11 @@ def format_output_config(self, service_config, values):
         Returns:
             [list] List of descriptors for this service
         """
-        return service_config.get(self.__service__, []) + [values['descriptor'].value]
+        return service_config.get(cls.__service__, []) + [values['descriptor'].value]
 
+    @classmethod
     @abstractmethod
-    def get_user_defined_properties(self):
+    def get_user_defined_properties(cls):
         """Base method for retrieving properties that must be assigned by the user when
         configuring a new output for this service. This should include any information that
         is sensitive or use-case specific. For instance, if the url needed for this integration
diff --git a/stream_alert/alert_processor/outputs/pagerduty.py b/stream_alert/alert_processor/outputs/pagerduty.py
index d1b03ee14..b4f1f8767 100644
--- a/stream_alert/alert_processor/outputs/pagerduty.py
+++ b/stream_alert/alert_processor/outputs/pagerduty.py
@@ -38,7 +38,8 @@ def _get_default_properties(cls):
         """
         return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'}
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new PagerDuty
         output. This should be sensitive or unique information for this use-case that needs
         to come from the user.
@@ -105,7 +106,8 @@ def _get_default_properties(cls):
         """
         return {'url': 'https://events.pagerduty.com/v2/enqueue'}
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new PagerDuty
         event output. This should be sensitive or unique information for this use-case that
         needs to come from the user.
@@ -191,7 +193,8 @@ def _get_default_properties(cls):
         """
         return {'api': 'https://api.pagerduty.com'}
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new PagerDuty
         incident output. This should be sensitive or unique information for this use-case that
         needs to come from the user.
diff --git a/stream_alert/alert_processor/outputs/phantom.py b/stream_alert/alert_processor/outputs/phantom.py
index 60e24fc64..0eba6e3d0 100644
--- a/stream_alert/alert_processor/outputs/phantom.py
+++ b/stream_alert/alert_processor/outputs/phantom.py
@@ -31,7 +31,8 @@ class PhantomOutput(OutputDispatcher):
     CONTAINER_ENDPOINT = 'rest/container'
     ARTIFACT_ENDPOINT = 'rest/artifact'
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new Phantom
         output. This should be sensitive or unique information for this use-case that needs
         to come from the user.
diff --git a/stream_alert/alert_processor/outputs/slack.py b/stream_alert/alert_processor/outputs/slack.py
index f1084ca54..e60bd9fd1 100644
--- a/stream_alert/alert_processor/outputs/slack.py
+++ b/stream_alert/alert_processor/outputs/slack.py
@@ -30,7 +30,8 @@ class SlackOutput(OutputDispatcher):
     # Slack recommends no messages larger than 4000 bytes. This does not account for unicode
     MAX_MESSAGE_SIZE = 4000
 
-    def get_user_defined_properties(self):
+    @classmethod
+    def get_user_defined_properties(cls):
         """Get properties that must be assigned by the user when configuring a new Slack
         output. This should be sensitive or unique information for this use-case that needs
         to come from the user.
diff --git a/stream_alert_cli/runner.py b/stream_alert_cli/runner.py
index 53697660f..5e206cca4 100644
--- a/stream_alert_cli/runner.py
+++ b/stream_alert_cli/runner.py
@@ -108,12 +108,7 @@ def configure_output(options):
         kms_key_alias = kms_key_alias.split('/')[1]
 
     # Retrieve the proper service class to handle dispatching the alerts of this services
-    output = StreamAlertOutput.get_output_dispatcher(
-        options.service,
-        region,
-        prefix,
-        config_outputs.load_outputs_config()
-    )
+    output = StreamAlertOutput.get_dispatcher(options.service)
 
     # If an output for this service has not been defined, the error is logged
     # prior to this
diff --git a/tests/unit/stream_alert_alert_processor/helpers.py b/tests/unit/stream_alert_alert_processor/helpers.py
index 71ac635c9..0030030e9 100644
--- a/tests/unit/stream_alert_alert_processor/helpers.py
+++ b/tests/unit/stream_alert_alert_processor/helpers.py
@@ -14,7 +14,6 @@
 limitations under the License.
""" from collections import OrderedDict -import json import os import random import shutil @@ -25,16 +24,6 @@ from tests.unit.stream_alert_alert_processor import FUNCTION_NAME, REGION -def construct_event(count): - """Helper to construct a valid test 'event' with an arbitrary number of records""" - event = {'Records': []} - for index in range(count): - event['Records'] = event['Records'] + \ - [{'Sns': {'Message': json.dumps(get_alert(index))}}] - - return event - - def get_mock_context(): """Create a fake context object using Mock""" arn = 'arn:aws:lambda:{}:555555555555:function:{}:production' @@ -72,18 +61,15 @@ def get_random_alert(key_count, rule_name, omit_rule_desc=False): return alert -def get_alert(index=0, context=None): +def get_alert(context=None): """This function generates a sample alert for testing purposes Args: index (int): test_index value (0 by default) context(dict): context dictionary (None by default) """ - context = context or {} - return { 'record': { - 'test_index': index, 'compressed_size': '9982', 'timestamp': '1496947381.18', 'node_id': '1', @@ -98,7 +84,7 @@ def get_alert(index=0, context=None): 'outputs': [ 'slack:unit_test_channel' ], - 'context': context, + 'context': context or dict(), 'source_service': 's3', 'source_entity': 'corp-prefix.prod.cb.region', 'log_type': 'json', diff --git a/tests/unit/stream_alert_alert_processor/test_helpers.py b/tests/unit/stream_alert_alert_processor/test_helpers.py index 1257a48ef..c1b9ad0a6 100644 --- a/tests/unit/stream_alert_alert_processor/test_helpers.py +++ b/tests/unit/stream_alert_alert_processor/test_helpers.py @@ -21,11 +21,8 @@ def test_valid_alert(): """Alert Processor Input Validation - Valid Alert Structure""" - # Default valid alert to test - valid_alert = get_alert() - # Test with a valid alert structure - assert_true(validate_alert(valid_alert)) + assert_true(validate_alert(get_alert())) def test_valid_alert_type(): @@ -38,7 +35,7 @@ def test_alert_keys(): # Default valid alert to be modified missing_alert_key = get_alert() - # Alter 'metadata' keys to break validation (not all required keys) + # Alter keys to break validation (not all required keys) missing_alert_key.pop('rule_name') # Test with invalid metadata keys diff --git a/tests/unit/stream_alert_alert_processor/test_main.py b/tests/unit/stream_alert_alert_processor/test_main.py index 2b783afa9..e15c9a06b 100644 --- a/tests/unit/stream_alert_alert_processor/test_main.py +++ b/tests/unit/stream_alert_alert_processor/test_main.py @@ -16,15 +16,13 @@ # pylint: disable=protected-access from collections import OrderedDict import json -import os from mock import call, mock_open, patch from nose.tools import ( assert_equal, assert_is_instance, assert_list_equal, - assert_true, - with_setup + assert_true ) import stream_alert.alert_processor as ap @@ -100,7 +98,7 @@ def test_sort_dict_recursive(): @patch('requests.post') @patch('stream_alert.alert_processor.main._load_output_config') -@patch('stream_alert.alert_processor.output_base.StreamOutputBase._load_creds') +@patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher._load_creds') def test_running_success(creds_mock, config_mock, get_mock): """Alert Processor run handler - success""" config_mock.return_value = _load_output_config('tests/unit/conf/outputs.json') @@ -137,14 +135,13 @@ def test_running_bad_output(config_mock, log_mock): handler(alert, context) - log_mock.assert_called_with( - 'The output \'%s\' does not exist!', 'slakc:test') + log_mock.assert_called_with('The output 
\'%s\' does not exist!', 'slakc:test') @patch('stream_alert.alert_processor.main._load_output_config') -@patch('stream_alert.alert_processor.main.get_output_dispatcher') +@patch('stream_alert.alert_processor.outputs.output_base.StreamAlertOutput.get_dispatcher') def test_running_no_dispatcher(dispatch_mock, config_mock): - """Alert Processor run handler - no dispatcher""" + """Alert Processor - Run Handler With No Dispatcher""" config_mock.return_value = _load_output_config('tests/unit/conf/outputs.json') dispatch_mock.return_value = None @@ -160,10 +157,10 @@ def test_running_no_dispatcher(dispatch_mock, config_mock): @patch('logging.Logger.exception') @patch('requests.get') @patch('stream_alert.alert_processor.main._load_output_config') -@patch('stream_alert.alert_processor.main.get_output_dispatcher') -@patch('stream_alert.alert_processor.output_base.StreamOutputBase._load_creds') +@patch('stream_alert.alert_processor.outputs.output_base.StreamAlertOutput.create_dispatcher') +@patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher._load_creds') def test_running_exception_occurred(creds_mock, dispatch_mock, config_mock, get_mock, log_mock): - """Alert Processor run handler - exception occurred""" + """Alert Processor - Run Handler, Exception Occurred""" # Use TypeError as the mock's side_effect err = TypeError('bad error') creds_mock.return_value = {'url': 'mock.url'} @@ -181,38 +178,27 @@ def test_running_exception_occurred(creds_mock, dispatch_mock, config_mock, get_ 'to %s:%s: %s. alert:\n%s', 'slack', 'unit_test_channel', err, json.dumps(alert, indent=2)) -def _teardown_env(): - """Helper method to reset environment variables""" - if 'LOGGER_LEVEL' in os.environ: - del os.environ['LOGGER_LEVEL'] - -@with_setup(setup=None, teardown=_teardown_env) @patch('stream_alert.alert_processor.LOGGER.error') def test_init_logging_bad(log_mock): """Alert Processor Init - Logging, Bad Level""" - level = 'IFNO' + with patch.dict('os.environ', {'LOGGER_LEVEL': 'IFNO'}): - os.environ['LOGGER_LEVEL'] = level + # Force reload the alert_processor package to trigger the init + reload(ap) - # Force reload the alert_processor package to trigger the init - reload(ap) + message = str(call('Defaulting to INFO logging: %s', + ValueError('Unknown level: \'IFNO\'',))) - message = str(call('Defaulting to INFO logging: %s', - ValueError('Unknown level: \'IFNO\'',))) + assert_equal(str(log_mock.call_args_list[0]), message) - assert_equal(str(log_mock.call_args_list[0]), message) - -@with_setup(setup=None, teardown=_teardown_env) @patch('stream_alert.alert_processor.LOGGER.setLevel') def test_init_logging_int_level(log_mock): """Alert Processor Init - Logging, Integer Level""" - level = '10' - - os.environ['LOGGER_LEVEL'] = level + with patch.dict('os.environ', {'LOGGER_LEVEL': '10'}): - # Force reload the alert_processor package to trigger the init - reload(ap) + # Force reload the alert_processor package to trigger the init + reload(ap) - log_mock.assert_called_with(10) + log_mock.assert_called_with(10) diff --git a/tests/unit/stream_alert_alert_processor/test_output_base.py b/tests/unit/stream_alert_alert_processor/test_output_base.py deleted file mode 100644 index 35ce55817..000000000 --- a/tests/unit/stream_alert_alert_processor/test_output_base.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -Copyright 2017-present, Airbnb Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
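The reworked logging tests above swap the old with_setup/teardown pair for mock's patch.dict, which snapshots os.environ on entry and restores it on exit, so a stray LOGGER_LEVEL can no longer leak between tests. A small sketch of that behavior, assuming LOGGER_LEVEL is not already set in the surrounding environment:

    import os

    from mock import patch

    with patch.dict('os.environ', {'LOGGER_LEVEL': 'IFNO'}):
        # The patched value is visible only inside the block
        assert os.environ['LOGGER_LEVEL'] == 'IFNO'

    # On exit the original environment is restored
    assert 'LOGGER_LEVEL' not in os.environ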
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -# pylint: disable=abstract-class-instantiated,protected-access -import os - -from mock import patch -from moto import mock_kms, mock_s3 -from nose.tools import assert_equal, assert_is_not_none - -from stream_alert.alert_processor.output_base import OutputProperty, StreamOutputBase -from stream_alert_cli.helpers import encrypt_with_kms, put_mock_creds, put_mock_s3_object -from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION -from tests.unit.stream_alert_alert_processor.helpers import remove_temp_secrets - -# Remove all abstractmethods from __abstractmethods__ so we can -# instantiate StreamOutputBase for testing -StreamOutputBase.__abstractmethods__ = frozenset() -StreamOutputBase.__service__ = 'test_service' - - -def test_output_property_default(): - """OutputProperty defaults""" - prop = OutputProperty() - - assert_equal(prop.description, '') - assert_equal(prop.value, '') - assert_equal(prop.input_restrictions, {' ', ':'}) - assert_equal(prop.mask_input, False) - assert_equal(prop.cred_requirement, False) - - -class TestStreamOutputBase(object): - """Test class for StreamOutputBase - - Perform various tests for methods inherited by all output classes - """ - __dispatcher = None - __descriptor = 'desc_test' - - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__dispatcher = StreamOutputBase(REGION, FUNCTION_NAME, CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def test_local_temp_dir(self): - """StreamOutputBase Local Temp Dir""" - temp_dir = self.__dispatcher._local_temp_dir() - assert_equal(temp_dir.split('/')[-1], 'stream_alert_secrets') - - def test_get_secrets_bucket_name(self): - """StreamOutputBase Get Secrets Bucket Name""" - bucket_name = self.__dispatcher._get_secrets_bucket_name(FUNCTION_NAME) - assert_equal(bucket_name, 'corp-prefix.streamalert.secrets') - - def test_output_cred_name(self): - """StreamOutputBase Output Cred Name""" - output_name = self.__dispatcher.output_cred_name('creds') - assert_equal(output_name, 'test_service/creds') - - @mock_s3 - def test_get_creds_from_s3(self): - """StreamOutputBase Get Creds From S3""" - descriptor = 'test_descriptor' - test_data = 'credential test string' - - bucket_name = self.__dispatcher.secrets_bucket - key = self.__dispatcher.output_cred_name(descriptor) - - local_cred_location = os.path.join(self.__dispatcher._local_temp_dir(), key) - - put_mock_s3_object(bucket_name, key, test_data, REGION) - - self.__dispatcher._get_creds_from_s3(local_cred_location, descriptor) - - with open(local_cred_location) as creds: - line = creds.readline() - - assert_equal(line, test_data) - - @mock_kms - def test_kms_decrypt(self): - """StreamOutputBase KMS Decrypt""" - test_data = 'data to encrypt' - encrypted = encrypt_with_kms(test_data, REGION, KMS_ALIAS) - decrypted = self.__dispatcher._kms_decrypt(encrypted) - - assert_equal(decrypted, test_data) - - @patch('logging.Logger.info') - def test_log_status_success(self, log_mock): - """StreamOutputBase Log status success""" - 
self.__dispatcher._log_status(True) - log_mock.assert_called_with('Successfully sent alert to %s', 'test_service') - - @patch('logging.Logger.error') - def test_log_status_failed(self, log_mock): - """StreamOutputBase Log status failed""" - self.__dispatcher._log_status(False) - log_mock.assert_called_with('Failed to send alert to %s', 'test_service') - - @patch('requests.Response') - def test_check_http_response(self, mock_response): - """StreamOutputBase Check HTTP Response""" - # Test with a good response code - mock_response.status_code = 200 - result = self.__dispatcher._check_http_response(mock_response) - assert_equal(result, True) - - # Test with a bad response code - mock_response.status_code = 440 - result = self.__dispatcher._check_http_response(mock_response) - assert_equal(result, False) - - @mock_s3 - @mock_kms - def test_load_creds(self): - """Load Credentials""" - remove_temp_secrets() - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'http://www.foo.bar/test', - 'token': 'token_to_encrypt'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - loaded_creds = self.__dispatcher._load_creds(self.__descriptor) - - assert_is_not_none(loaded_creds) - assert_equal(len(loaded_creds), 2) - assert_equal(loaded_creds['url'], u'http://www.foo.bar/test') - assert_equal(loaded_creds['token'], u'token_to_encrypt') - - -class TestFormatOutputConfig(object): - """Test class for Output Config formatting""" - __cached_name = StreamOutputBase.__service__ - - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - # Switch out the test service to one that is in the outputs.json file - StreamOutputBase.__service__ = 'slack' - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - StreamOutputBase.__service__ = cls.__cached_name - - @staticmethod - def test_format_output_config(): - """Format Output Config""" - props = {'descriptor': OutputProperty('test_desc', 'test_channel')} - - formatted = StreamOutputBase(REGION, FUNCTION_NAME, - CONFIG).format_output_config(CONFIG, props) - - assert_equal(len(formatted), 2) - assert_equal(formatted[0], 'unit_test_channel') - assert_equal(formatted[1], 'test_channel') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs.py b/tests/unit/stream_alert_alert_processor/test_outputs.py deleted file mode 100644 index e59d6bbb0..000000000 --- a/tests/unit/stream_alert_alert_processor/test_outputs.py +++ /dev/null @@ -1,1527 +0,0 @@ -""" -Copyright 2017-present, Airbnb Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
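The monolithic test_outputs.py whose removal begins here is replaced later in this patch by per-service modules under test_outputs/, but the moto-backed credential round trip these deleted tests exercised carries over unchanged. A hedged sketch of that flow, reusing helpers that appear in the tests (put_mock_creds, output_cred_name, secrets_bucket, _load_creds); the round_trip_creds wrapper is illustrative, and under the mocked S3 and KMS backends no real AWS calls are made:

    from moto import mock_kms, mock_s3

    from stream_alert_cli.helpers import put_mock_creds

    @mock_s3
    @mock_kms
    def round_trip_creds(dispatcher, descriptor, region, kms_alias):
        """Store KMS-encrypted creds in mocked S3, then load them back"""
        creds = {'url': 'http://www.foo.bar/test', 'token': 'token_to_encrypt'}
        output_name = dispatcher.output_cred_name(descriptor)
        put_mock_creds(output_name, creds, dispatcher.secrets_bucket,
                       region, kms_alias)
        return dispatcher._load_creds(descriptor)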
-""" -# pylint: disable=protected-access -# pylint: disable=too-many-lines -from collections import Counter, OrderedDict - -import json -import boto3 -from mock import call, patch, PropertyMock -from moto import mock_s3, mock_kms, mock_lambda, mock_kinesis -from nose.tools import ( - assert_equal, - assert_false, - assert_is_none, - assert_is_not_none, - assert_set_equal, - assert_true -) - -from stream_alert.alert_processor import outputs -from stream_alert.alert_processor.output_base import OutputProperty -from stream_alert_cli.helpers import create_lambda_function, put_mock_creds -from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION -from tests.unit.stream_alert_alert_processor.helpers import ( - get_random_alert, - get_alert, - remove_temp_secrets -) - - -def test_existing_get_output_dispatcher(): - """Get output dispatcher - existing""" - service = 'aws-s3' - dispatcher = outputs.get_output_dispatcher( - service, REGION, FUNCTION_NAME, CONFIG) - assert_is_not_none(dispatcher) - - -def test_nonexistent_get_output_dispatcher(): - """Get output dispatcher - nonexistent""" - nonexistent_service = 'aws-s4' - dispatcher = outputs.get_output_dispatcher(nonexistent_service, - REGION, - FUNCTION_NAME, - CONFIG) - assert_is_none(dispatcher) - - -@patch('logging.Logger.error') -def test_get_output_dispatcher_logging(log_mock): - """Get output dispatcher - log error""" - bad_service = 'bad-output' - outputs.get_output_dispatcher(bad_service, REGION, FUNCTION_NAME, CONFIG) - log_mock.assert_called_with( - 'designated output service [%s] does not exist', - bad_service) - - -def test_user_defined_properties(): - """Get user defined properties""" - for output in outputs.STREAM_OUTPUTS.values(): - props = output(REGION, FUNCTION_NAME, CONFIG).get_user_defined_properties() - # The user defined properties should at a minimum contain a descriptor - assert_is_not_none(props.get('descriptor')) - -class TestPagerDutyOutput(object): - """Test class for PagerDutyOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'pagerduty' - cls.__descriptor = 'unit_test_pagerduty' - cls.__backup_method = None - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def test_get_default_properties(self): - """Get Default Properties - PagerDuty""" - props = self.__dispatcher._get_default_properties() - assert_equal(len(props), 1) - assert_equal(props['url'], - 'https://events.pagerduty.com/generic/2010-04-15/create_event.json') - - def _setup_dispatch(self): - """Helper for setting up PagerDutyOutput dispatch""" - remove_temp_secrets() - - # Cache the _get_default_properties and set it to return None - self.__backup_method = self.__dispatcher._get_default_properties - self.__dispatcher._get_default_properties = lambda: None - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'http://pagerduty.foo.bar/create_event.json', - 'service_key': 'mocked_service_key'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() - - def _teardown_dispatch(self): - """Replace method with cached method""" - self.__dispatcher._get_default_properties = self.__backup_method - - @patch('logging.Logger.info') - @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_success(self, 
post_mock, log_info_mock): - """PagerDutyOutput dispatch success""" - alert = self._setup_dispatch() - post_mock.return_value.status_code = 200 - post_mock.return_value.text = '' - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_failure(self, post_mock, log_error_mock): - """PagerDutyOutput dispatch failure""" - alert = self._setup_dispatch() - post_mock.return_value.text = '{"message": "error message", "errors": ["error1"]}' - post_mock.return_value.status_code = 400 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """PagerDutyOutput dispatch bad descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - -class TestPagerDutyOutputV2(object): - """Test class for PagerDutyOutputV2""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'pagerduty-v2' - cls.__descriptor = 'unit_test_pagerduty-v2' - cls.__backup_method = None - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def test_get_default_properties(self): - """Get Default Properties - PagerDuty""" - props = self.__dispatcher._get_default_properties() - assert_equal(len(props), 1) - assert_equal(props['url'], - 'https://events.pagerduty.com/v2/enqueue') - - def _setup_dispatch(self): - """Helper for setting up PagerDutyOutputV2 dispatch""" - remove_temp_secrets() - - # Cache the _get_default_properties and set it to return None - self.__backup_method = self.__dispatcher._get_default_properties - self.__dispatcher._get_default_properties = lambda: None - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'http://pagerduty.foo.bar/create_event.json', - 'routing_key': 'mocked_routing_key'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() - - def _teardown_dispatch(self): - """Replace method with cached method""" - self.__dispatcher._get_default_properties = self.__backup_method - - @patch('logging.Logger.info') - @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_success(self, post_mock, log_info_mock): - """PagerDutyOutputV2 dispatch success""" - alert = self._setup_dispatch() - post_mock.return_value.status_code = 200 - post_mock.return_value.text = '' - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_failure(self, post_mock, log_error_mock): - """PagerDutyOutputV2 dispatch failure""" - alert = 
self._setup_dispatch() - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') - post_mock.return_value.json.return_value = json_error - post_mock.return_value.status_code = 400 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """PagerDutyOutputV2 dispatch bad descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - -class TestPagerDutyIncidentOutput(object): - """Test class for PagerDutyIncidentOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'pagerduty-incident' - cls.__descriptor = 'unit_test_pagerduty-incident' - cls.__backup_method = None - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def test_get_default_properties(self): - """Get Default Properties - PagerDutyIncidentOutput""" - props = self.__dispatcher._get_default_properties() - assert_equal(len(props), 1) - assert_equal(props['api'], - 'https://api.pagerduty.com') - - def test_get_endpoint(self): - """Get Endpoint - PagerDutyIncidentOutput""" - props = self.__dispatcher._get_default_properties() - endpoint = self.__dispatcher._get_endpoint(props['api'], 'testtest') - assert_equal(endpoint, - 'https://api.pagerduty.com/testtest') - - def _setup_dispatch(self, context=None): - """Helper for setting up PagerDutyIncidentOutput dispatch""" - remove_temp_secrets() - - # Cache the _get_default_properties and set it to return None - self.__backup_method = self.__dispatcher._get_default_properties - self.__dispatcher._get_default_properties = lambda: None - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'api': 'https://api.pagerduty.com', - 'token': 'mocked_token', - 'service_key': 'mocked_service_key', - 'escalation_policy': 'mocked_escalation_policy', - 'email_from': 'email@domain.com'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert(0, context) - - def _teardown_dispatch(self): - """Replace method with cached method""" - self.__dispatcher._get_default_properties = self.__backup_method - - @patch('requests.get') - def test_check_exists_get_id(self, get_mock): - """Check Exists Get Id - PagerDutyIncidentOutput""" - # /check - get_mock.return_value.status_code = 200 - json_check = json.loads('{"check": [{"id": "checked_id"}]}') - get_mock.return_value.json.return_value = json_check - - checked = self.__dispatcher._check_exists('filter', 'http://mock_url', 'check') - assert_equal(checked, 'checked_id') - - @patch('requests.get') - def test_check_exists_get_id_fail(self, get_mock): - """Check Exists Get Id Fail - PagerDutyIncidentOutput""" - # /check - get_mock.return_value.status_code = 200 - json_check = json.loads('{}') - get_mock.return_value.json.return_value = json_check - - checked = self.__dispatcher._check_exists('filter', 'http://mock_url', 'check') - assert_false(checked) - - 
@patch('requests.get') - def test_check_exists_no_get_id(self, get_mock): - """Check Exists No Get Id - PagerDutyIncidentOutput""" - # /check - get_mock.return_value.status_code = 200 - json_check = json.loads('{"check": [{"id": "checked_id"}]}') - get_mock.return_value.json.return_value = json_check - - assert_true(self.__dispatcher._check_exists('filter', 'http://mock_url', 'check', False)) - - @patch('requests.get') - def test_user_verify_success(self, get_mock): - """User Verify Success - PagerDutyIncidentOutput""" - get_mock.return_value.status_code = 200 - json_check = json.loads('{"users": [{"id": "verified_user_id"}]}') - get_mock.return_value.json.return_value = json_check - - user_verified = self.__dispatcher._user_verify('valid_user') - assert_equal(user_verified['id'], 'verified_user_id') - assert_equal(user_verified['type'], 'user_reference') - - @patch('requests.get') - def test_user_verify_fail(self, get_mock): - """User Verify Fail - PagerDutyIncidentOutput""" - get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_users": [{"not_id": "verified_user_id"}]}') - get_mock.return_value.json.return_value = json_check - - user_verified = self.__dispatcher._user_verify('valid_user') - assert_false(user_verified) - - @patch('requests.get') - def test_policy_verify_success_no_default(self, get_mock): - """Policy Verify Success (No Default) - PagerDutyIncidentOutput""" - # /escalation_policies - get_mock.return_value.status_code = 200 - json_check = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') - get_mock.return_value.json.return_value = json_check - - policy_verified = self.__dispatcher._policy_verify('valid_policy', '') - assert_equal(policy_verified['id'], 'good_policy_id') - assert_equal(policy_verified['type'], 'escalation_policy_reference') - - @patch('requests.get') - def test_policy_verify_success_default(self, get_mock): - """Policy Verify Success (Default) - PagerDutyIncidentOutput""" - # /escalation_policies - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_check_bad = json.loads('{"no_escalation_policies": [{"id": "bad_policy_id"}]}') - json_check_good = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') - get_mock.return_value.json.side_effect = [json_check_bad, json_check_good] - - policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') - assert_equal(policy_verified['id'], 'good_policy_id') - assert_equal(policy_verified['type'], 'escalation_policy_reference') - - @patch('requests.get') - def test_policy_verify_fail_default(self, get_mock): - """Policy Verify Fail (Default) - PagerDutyIncidentOutput""" - # /not_escalation_policies - type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400]) - json_check_bad = json.loads('{"escalation_policies": [{"id": "bad_policy_id"}]}') - json_check_bad_default = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') - get_mock.return_value.json.side_effect = [json_check_bad, json_check_bad_default] - policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') - assert_false(policy_verified) - - @patch('requests.get') - def test_policy_verify_fail_no_default(self, get_mock): - """Policy Verify Fail (No Default) - PagerDutyIncidentOutput""" - # /not_escalation_policies - get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_escalation_policies": [{"not_id": "verified_policy_id"}]}') - get_mock.return_value.json.return_value = json_check - - 
policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') - assert_false(policy_verified) - - @patch('requests.get') - def test_service_verify_success(self, get_mock): - """Service Verify Success - PagerDutyIncidentOutput""" - # /services - get_mock.return_value.status_code = 200 - json_check = json.loads('{"services": [{"id": "verified_service_id"}]}') - get_mock.return_value.json.return_value = json_check - - service_verified = self.__dispatcher._service_verify('valid_service') - assert_equal(service_verified['id'], 'verified_service_id') - assert_equal(service_verified['type'], 'service_reference') - - @patch('requests.get') - def test_service_verify_fail(self, get_mock): - """Service Verify Fail - PagerDutyIncidentOutput""" - get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_services": [{"not_id": "verified_service_id"}]}') - get_mock.return_value.json.return_value = json_check - - service_verified = self.__dispatcher._service_verify('valid_service') - assert_false(service_verified) - - @patch('requests.get') - def test_item_verify_success(self, get_mock): - """Item Verify Success - PagerDutyIncidentOutput""" - # /items - get_mock.return_value.status_code = 200 - json_check = json.loads('{"items": [{"id": "verified_item_id"}]}') - get_mock.return_value.json.return_value = json_check - - item_verified = self.__dispatcher._item_verify('valid_item', 'items', 'item_reference') - assert_equal(item_verified['id'], 'verified_item_id') - assert_equal(item_verified['type'], 'item_reference') - - @patch('requests.get') - def test_item_verify_no_get_id_success(self, get_mock): - """Item Verify No Get Id Success - PagerDutyIncidentOutput""" - # /items - get_mock.return_value.status_code = 200 - json_check = json.loads('{"items": [{"id": "verified_item_id"}]}') - get_mock.return_value.json.return_value = json_check - - assert_true(self.__dispatcher._item_verify('valid_item', 'items', 'item_reference', False)) - - @patch('requests.get') - def test_incident_assignment_user(self, get_mock): - """Incident Assignment User - PagerDutyIncidentOutput""" - context = {'assigned_user': 'user_to_assign'} - get_mock.return_value.status_code = 200 - json_user = json.loads('{"users": [{"id": "verified_user_id"}]}') - get_mock.return_value.json.return_value = json_user - - assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) - - assert_equal(assigned_key, 'assignments') - assert_equal(assigned_value[0]['assignee']['id'], 'verified_user_id') - assert_equal(assigned_value[0]['assignee']['type'], 'user_reference') - - @patch('requests.get') - def test_incident_assignment_policy_no_default(self, get_mock): - """Incident Assignment Policy (No Default) - PagerDutyIncidentOutput""" - context = {'assigned_policy': 'policy_to_assign'} - get_mock.return_value.status_code = 200 - json_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') - get_mock.return_value.json.return_value = json_policy - - assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) - - assert_equal(assigned_key, 'escalation_policy') - assert_equal(assigned_value['id'], 'verified_policy_id') - assert_equal(assigned_value['type'], 'escalation_policy_reference') - - @patch('requests.get') - def test_incident_assignment_policy_default(self, get_mock): - """Incident Assignment Policy (Default) - PagerDutyIncidentOutput""" - context = {'assigned_policy': 'bad_invalid_policy_to_assign'} - type(get_mock.return_value).status_code = 
PropertyMock(side_effect=[200, 200]) - json_bad_policy = json.loads('{"not_escalation_policies": [{"id": "bad_policy_id"}]}') - json_good_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') - get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy] - - assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) - - assert_equal(assigned_key, 'escalation_policy') - assert_equal(assigned_value['id'], 'verified_policy_id') - assert_equal(assigned_value['type'], 'escalation_policy_reference') - - @patch('requests.get') - def test_item_verify_fail(self, get_mock): - """Item Verify Fail - PagerDutyIncidentOutput""" - # /not_items - get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_items": [{"not_id": "verified_item_id"}]}') - get_mock.return_value.json.return_value = json_check - - item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', - 'items', 'item_reference') - assert_false(item_verified) - - @patch('logging.Logger.info') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_good_user(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Good User""" - ctx = { - 'pagerduty-incident': { - 'assigned_user': 'valid_user' - } - } - alert = self._setup_dispatch(context=ctx) - - # /users, /users, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) - json_user = json.loads('{"users": [{"id": "valid_user_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') - get_mock.return_value.json.side_effect = [json_user, json_user, json_service] - - # /incidents - post_mock.return_value.status_code = 200 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.info') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_good_policy(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Good Policy""" - ctx = { - 'pagerduty-incident': { - 'assigned_policy': 'valid_policy' - } - } - alert = self._setup_dispatch(context=ctx) - - # /users, /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) - json_user = json.loads('{"users": [{"id": "user_id"}]}') - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') - get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] - - # /incidents - post_mock.return_value.status_code = 200 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.info') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_bad_user(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Bad User""" - ctx = { - 'pagerduty-incident': { - 'assigned_user': 'invalid_user' - } - } - alert = self._setup_dispatch(context=ctx) - - # /users, /users, /escalation_policies, /services - type(get_mock.return_value).status_code = 
PropertyMock(side_effect=[200, 200, 200, 200]) - json_user = json.loads('{"users": [{"id": "user_id"}]}') - json_not_user = json.loads('{"not_users": [{"id": "user_id"}]}') - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') - get_mock.return_value.json.side_effect = [json_user, json_not_user, - json_policy, json_service] - - # /incidents - post_mock.return_value.status_code = 200 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.info') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_no_context(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - No Context""" - alert = self._setup_dispatch() - - # /users, /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) - json_user = json.loads('{"users": [{"id": "user_id"}]}') - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') - get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] - - # /incidents - post_mock.return_value.status_code = 200 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_error_mock): - """PagerDutyIncidentOutput dispatch failure - No User, Bad Policy, Bad Service""" - alert = self._setup_dispatch() - # /users, /users, /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 400, 400]) - json_user = json.loads('{"users": [{"id": "user_id"}]}') - json_empty = json.loads('{}') - get_mock.return_value.json.side_effect = [json_user, json_empty, json_empty, json_empty] - - # /incidents - post_mock.return_value.status_code = 400 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.info') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Bad Policy""" - ctx = { - 'pagerduty-incident': { - 'assigned_policy': 'valid_policy' - } - } - alert = self._setup_dispatch(context=ctx) - # /users, /escalation_policies, /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 200, 200]) - json_user = json.loads('{"users": [{"id": "user_id"}]}') - json_bad_policy = json.loads('{}') - json_good_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') - get_mock.return_value.json.side_effect = [json_user, json_bad_policy, - json_good_policy, json_service] - - # /incidents - post_mock.return_value.status_code = 200 - - 
self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.post') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_error_mock): - """PagerDutyIncidentOutput dispatch - Bad Dispatch""" - alert = self._setup_dispatch() - # /users, /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) - json_user = json.loads('{"users": [{"id": "user_id"}]}') - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') - get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] - - # /incidents - post_mock.return_value.status_code = 400 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_bad_email(self, get_mock, log_error_mock): - """PagerDutyIncidentOutput dispatch - Bad Email""" - alert = self._setup_dispatch() - # /users, /escalation_policies, /services - get_mock.return_value.status_code = 400 - json_user = json.loads('{"not_users": [{"id": "no_user_id"}]}') - get_mock.return_value.json.return_value = json_user - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """PagerDutyIncidentOutput dispatch - Bad Descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - -@mock_s3 -@mock_kms -class TestPhantomOutput(object): - """Test class for PhantomOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'phantom' - cls.__descriptor = 'unit_test_phantom' - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def _setup_dispatch(self, url): - """Helper for setting up PhantomOutput dispatch""" - remove_temp_secrets() - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': url, - 'ph_auth_token': 'mocked_auth_token'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() - - @patch('logging.Logger.info') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_existing_container(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch success, existing container""" - alert = self._setup_dispatch('http://phantom.foo.bar') - # _check_container_exists - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"count": 1, "data": [{"id": 1948}]}') - # dispatch - 
post_mock.return_value.status_code = 200 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.info') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_new_container(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch success, new container""" - alert = self._setup_dispatch('http://phantom.foo.bar') - # _check_container_exists - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"count": 0, "data": []}') - # _setup_container, dispatch - post_mock.return_value.status_code = 200 - post_mock.return_value.json.return_value = json.loads('{"id": 1948}') - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_container_failure(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch failure (setup container)""" - alert = self._setup_dispatch('http://phantom.foo.bar') - # _check_container_exists - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"count": 0, "data": []}') - # _setup_container - post_mock.return_value.status_code = 400 - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') - post_mock.return_value.json.return_value = json_error - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_check_container_error(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch decode error (check container)""" - alert = self._setup_dispatch('http://phantom.foo.bar') - # _check_container_exists - get_mock.return_value.status_code = 200 - get_mock.return_value.text = '{}' - # _setup_container - post_mock.return_value.status_code = 400 - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') - post_mock.return_value.json.return_value = json_error - - result = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Failed to send alert to %s', self.__service) - assert_equal(result, False) - - @patch('logging.Logger.error') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_setup_container_error(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch decode error (setup container)""" - alert = self._setup_dispatch('http://phantom.foo.bar') - # _check_container_exists - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"count": 0, "data": []}') - # _setup_container - post_mock.return_value.status_code = 200 - post_mock.return_value.json.return_value = json.loads('{}') - - - result = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Failed to send alert to %s', self.__service) - assert_equal(result, False) - - @patch('logging.Logger.error') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_failure(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch 
failure (artifact)""" - alert = self._setup_dispatch('http://phantom.foo.bar') - # _check_container_exists - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"count": 0, "data": []}') - # _setup_container, dispatch - type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400]) - json_id = json.loads('{"id": 1948}') - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') - post_mock.return_value.json.return_value.side_effect = [json_id, json_error] - - result = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Failed to send alert to %s', self.__service) - assert_equal(result, False) - - @patch('logging.Logger.error') - def test_dispatch_bad_descriptor(self, log_error_mock): - """PhantomOutput dispatch bad descriptor""" - alert = self._setup_dispatch('http://phantom.foo.bar') - result = self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - assert_equal(result, False) - - @patch('stream_alert.alert_processor.output_base.StreamOutputBase._get_request') - @patch('stream_alert.alert_processor.output_base.StreamOutputBase._post_request') - def test_dispatch_container_query(self, post_mock, get_mock): - """PhantomOutput - Container Query URL""" - alert = self._setup_dispatch('http://phantom.foo.bar') - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - full_url = 'http://phantom.foo.bar/rest/container' - params = {'_filter_name': '"rule_name"', 'page_size': 1} - headers = {'ph-auth-token': 'mocked_auth_token'} - get_mock.assert_has_calls([call(full_url, params, headers, False)]) - rule_description = 'Info about this rule and what actions to take' - ph_container = {'name': 'rule_name', 'description': rule_description} - post_mock.assert_has_calls([call(full_url, ph_container, headers, False)]) - - -class TestSlackOutput(object): - """Test class for SlackOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'slack' - cls.__descriptor = 'unit_test_channel' - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def test_format_message_single(self): - """Format Single Message - Slack""" - rule_name = 'test_rule_single' - alert = get_random_alert(25, rule_name) - loaded_message = self.__dispatcher._format_message(rule_name, alert) - - # tests - assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) - assert_equal( - loaded_message['text'], - '*StreamAlert Rule Triggered: test_rule_single*') - assert_equal(len(loaded_message['attachments']), 1) - - def test_format_message_mutliple(self): - """Format Multi-Message - Slack""" - rule_name = 'test_rule_multi-part' - alert = get_random_alert(30, rule_name) - loaded_message = self.__dispatcher._format_message(rule_name, alert) - - # tests - assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) - assert_equal( - loaded_message['text'], - '*StreamAlert Rule Triggered: test_rule_multi-part*') - assert_equal(len(loaded_message['attachments']), 2) - assert_equal(loaded_message['attachments'][1] - ['text'].split('\n')[3][1:7], '000028') - - def 
test_format_message_default_rule_description(self): - """Format Message Default Rule Description - Slack""" - rule_name = 'test_empty_rule_description' - alert = get_random_alert(10, rule_name, True) - loaded_message = self.__dispatcher._format_message(rule_name, alert) - - # tests - default_rule_description = '*Rule Description:*\nNo rule description provided\n' - assert_equal( - loaded_message['attachments'][0]['pretext'], - default_rule_description) - - def test_json_to_slack_mrkdwn_str(self): - """JSON to Slack mrkdwn - simple str""" - simple_str = 'value to format' - result = self.__dispatcher._json_to_slack_mrkdwn(simple_str, 0) - - assert_equal(len(result), 1) - assert_equal(result[0], simple_str) - - def test_json_to_slack_mrkdwn_dict(self): - """JSON to Slack mrkdwn - simple dict""" - simple_dict = OrderedDict([('test_key_01', 'test_value_01'), - ('test_key_02', 'test_value_02')]) - result = self.__dispatcher._json_to_slack_mrkdwn(simple_dict, 0) - - assert_equal(len(result), 2) - assert_equal(result[1], '*test_key_02:* test_value_02') - - def test_json_to_slack_mrkdwn_nested_dict(self): - """JSON to Slack mrkdwn - nested dict""" - nested_dict = OrderedDict([ - ('root_key_01', 'root_value_01'), - ('root_02', 'root_value_02'), - ('root_nested_01', OrderedDict([ - ('nested_key_01', 100), - ('nested_key_02', 200), - ('nested_nested_01', OrderedDict([ - ('nested_nested_key_01', 300) - ])) - ])) - ]) - result = self.__dispatcher._json_to_slack_mrkdwn(nested_dict, 0) - assert_equal(len(result), 7) - assert_equal(result[2], '*root_nested_01:*') - assert_equal(Counter(result[4])['\t'], 1) - assert_equal(Counter(result[6])['\t'], 2) - - def test_json_to_slack_mrkdwn_list(self): - """JSON to Slack mrkdwn - simple list""" - simple_list = ['test_value_01', 'test_value_02'] - result = self.__dispatcher._json_to_slack_mrkdwn(simple_list, 0) - - assert_equal(len(result), 2) - assert_equal(result[0], '*[1]* test_value_01') - assert_equal(result[1], '*[2]* test_value_02') - - def test_json_to_slack_mrkdwn_multi_nested(self): - """JSON to Slack mrkdwn - multi type nested""" - nested_dict = OrderedDict([ - ('root_key_01', 'root_value_01'), - ('root_02', 'root_value_02'), - ('root_nested_01', OrderedDict([ - ('nested_key_01', 100), - ('nested_key_02', 200), - ('nested_nested_01', OrderedDict([ - ('nested_nested_key_01', [ - 6161, - 1051, - 51919 - ]) - ])) - ])) - ]) - result = self.__dispatcher._json_to_slack_mrkdwn(nested_dict, 0) - assert_equal(len(result), 10) - assert_equal(result[2], '*root_nested_01:*') - assert_equal(Counter(result[4])['\t'], 1) - assert_equal(result[-1], '\t\t\t*[3]* 51919') - - def test_json_list_to_text(self): - """JSON list to text""" - simple_list = ['test_value_01', 'test_value_02'] - result = self.__dispatcher._json_list_to_text(simple_list, '\t', 0) - - assert_equal(len(result), 2) - assert_equal(result[0], '*[1]* test_value_01') - assert_equal(result[1], '*[2]* test_value_02') - - def test_json_map_to_text(self): - """JSON map to text""" - simple_dict = OrderedDict([('test_key_01', 'test_value_01'), - ('test_key_02', 'test_value_02')]) - result = self.__dispatcher._json_map_to_text(simple_dict, '\t', 0) - - assert_equal(len(result), 2) - assert_equal(result[1], '*test_key_02:* test_value_02') - - def _setup_dispatch(self): - """Helper for setting up SlackOutput dispatch""" - remove_temp_secrets() - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'https://api.slack.com/web-hook-key'} - - put_mock_creds(output_name, creds, 
self.__dispatcher.secrets_bucket, - REGION, KMS_ALIAS) - - return get_alert() - - @patch('logging.Logger.info') - @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_success(self, url_mock, log_info_mock): - """SlackOutput dispatch success""" - alert = self._setup_dispatch() - url_mock.return_value.status_code = 200 - url_mock.return_value.json.return_value = json.loads('{}') - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.error') - @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_failure(self, url_mock, log_error_mock): - """SlackOutput dispatch failure""" - alert = self._setup_dispatch() - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') - url_mock.return_value.json.return_value = json_error - url_mock.return_value.status_code = 400 - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """SlackOutput dispatch bad descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - - -class TestAWSOutput(object): - """Test class for AWSOutput Base""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - # pylint: disable=abstract-class-instantiated - cls.__abstractmethods_cache = outputs.AWSOutput.__abstractmethods__ - outputs.AWSOutput.__abstractmethods__ = frozenset() - cls.__dispatcher = outputs.AWSOutput(REGION, FUNCTION_NAME, CONFIG) - cls.__dispatcher.__service__ = 'aws-s3' - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - outputs.AWSOutput.__abstractmethods__ = cls.__abstractmethods_cache - cls.__dispatcher = None - - def test_aws_format_output_config(self): - """AWSOutput format output config""" - props = { - 'descriptor': OutputProperty( - 'short_descriptor', - 'descriptor_value'), - 'aws_value': OutputProperty( - 'unique arn value, bucket, etc', - 'bucket.value')} - - formatted_config = self.__dispatcher.format_output_config(CONFIG, props) - - assert_equal(len(formatted_config), 2) - assert_is_not_none(formatted_config.get('descriptor_value')) - assert_is_not_none(formatted_config.get('unit_test_bucket')) - - def test_dispatch(self): - """AWSOutput dispatch pass""" - passed = self.__dispatcher.dispatch() - assert_is_none(passed) - - -class TestS3Ouput(object): - """Test class for S3Output""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'aws-s3' - cls.__descriptor = 'unit_test_bucket' - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.dispatcher = None - - def test_locals(self): - """S3Output local variables""" - assert_equal(self.__dispatcher.__class__.__name__, 'S3Output') - assert_equal(self.__dispatcher.__service__, self.__service) - - def _setup_dispatch(self): - """Helper for setting up S3Output dispatch""" - bucket = CONFIG[self.__service][self.__descriptor] - boto3.client('s3', 
region_name=REGION).create_bucket(Bucket=bucket) - - return get_alert() - - @patch('logging.Logger.info') - @mock_s3 - def test_dispatch(self, log_mock): - """S3Output dispatch""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - -class TestFirehoseOutput(object): - """Test class for AWS Kinesis Firehose""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'aws-firehose' - cls.__descriptor = 'unit_test_delivery_stream' - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.dispatcher = None - - def test_locals(self): - """Output local variables - Kinesis Firehose""" - assert_equal(self.__dispatcher.__class__.__name__, 'KinesisFirehoseOutput') - assert_equal(self.__dispatcher.__service__, self.__service) - - def _setup_dispatch(self): - """Helper for setting up S3Output dispatch""" - delivery_stream = CONFIG[self.__service][self.__descriptor] - boto3.client('firehose', region_name=REGION).create_delivery_stream( - DeliveryStreamName=delivery_stream, - S3DestinationConfiguration={ - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::unit_test', - 'Prefix': '/', - 'BufferingHints': { - 'SizeInMBs': 128, - 'IntervalInSeconds': 128 - }, - 'CompressionFormat': 'GZIP', - } - ) - - return get_alert() - - @patch('logging.Logger.info') - @mock_kinesis - def test_dispatch(self, log_mock): - """Output Dispatch - Kinesis Firehose""" - alert = self._setup_dispatch() - resp = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - assert_true(resp) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @mock_kinesis - def test_dispatch_ignore_large_payload(self): - """Output Dispatch - Kinesis Firehose with Large Payload""" - alert = self._setup_dispatch() - alert['record'] = 'test' * 1000 * 1000 - resp = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - assert_false(resp) - - -class TestLambdaOuput(object): - """Test class for LambdaOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'aws-lambda' - cls.__descriptor = 'unit_test_lambda' - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.dispatcher = None - - def test_locals(self): - """LambdaOutput local variables""" - assert_equal(self.__dispatcher.__class__.__name__, 'LambdaOutput') - assert_equal(self.__dispatcher.__service__, self.__service) - - def _setup_dispatch(self, alt_descriptor=''): - """Helper for setting up LambdaOutput dispatch""" - function_name = CONFIG[self.__service][alt_descriptor or self.__descriptor] - create_lambda_function(function_name, REGION) - return get_alert() - - @mock_lambda - @patch('logging.Logger.info') - def test_dispatch(self, log_mock): - """LambdaOutput dispatch""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @mock_lambda - 
@patch('logging.Logger.info') - def test_dispatch_with_qualifier(self, log_mock): - """LambdaOutput dispatch with qualifier""" - alt_descriptor = '{}_qual'.format(self.__descriptor) - alert = self._setup_dispatch(alt_descriptor) - self.__dispatcher.dispatch(descriptor=alt_descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - -@mock_s3 -@mock_kms -class TestJiraOutput(object): - """Test class for PhantomOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'jira' - cls.__descriptor = 'unit_test_jira' - cls.__dispatcher = outputs.get_output_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def _setup_dispatch(self): - """Helper for setting up JiraOutput dispatch""" - remove_temp_secrets() - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'username': 'jira@foo.bar', - 'password': 'jirafoobar', - 'url': 'jira.foo.bar', - 'project_key': 'foobar', - 'issue_type': 'Task', - 'aggregate': 'yes'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() - - @patch('logging.Logger.info') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_issue_new(self, post_mock, get_mock, log_mock): - """JiraOutput dispatch success, new issue""" - alert = self._setup_dispatch() - # setup the request to not find an existing issue - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"issues":[]}') - # setup the auth and successful creation responses - auth_resp = '{"session": {"name": "cookie_name", "value": "cookie_value"}}' - post_mock.return_value.status_code = 200 - post_mock.return_value.json.side_effect = [json.loads(auth_resp), - json.loads('{"id": 5000}')] - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('logging.Logger.info') - @patch('requests.get') - @patch('requests.post') - def test_dispatch_issue_existing(self, post_mock, get_mock, log_mock): - """JiraOutput dispatch success, existing issue""" - alert = self._setup_dispatch() - # setup the request to find an existing issue - get_mock.return_value.status_code = 200 - existing_issues = '{"issues": [{"fields": {"summary": "Bogus"}, "id": "5000"}]}' - get_mock.return_value.json.return_value = json.loads(existing_issues) - auth_resp = '{"session": {"name": "cookie_name", "value": "cookie_value"}}' - # setup the auth and successful creation responses - post_mock.return_value.status_code = 200 - post_mock.return_value.json.side_effect = [json.loads(auth_resp), - json.loads('{"id": 5000}')] - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) - - @patch('requests.get') - def test_get_comments_success(self, get_mock): - """JiraOutput get comments success""" - # setup successful get comments response - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"comments": [{},{}]}') - - self.__dispatcher._load_creds('jira') - resp = self.__dispatcher._get_comments('5000') - assert_equal(resp, [{}, {}]) - - @patch('requests.get') - def 
test_get_comments_failure(self, get_mock): - """JiraOutput get comments failure""" - # setup successful get comments response - get_mock.return_value.status_code = 400 - - self.__dispatcher._load_creds('jira') - resp = self.__dispatcher._get_comments('5000') - assert_equal(resp, []) - - @patch('requests.get') - def test_search_failure(self, get_mock): - """JiraOutput search failure""" - # setup successful get comments response - get_mock.return_value.status_code = 400 - - self.__dispatcher._load_creds('jira') - resp = self.__dispatcher._search_jira('foobar') - assert_equal(resp, []) - - @patch('logging.Logger.error') - @patch('requests.post') - def test_auth_failure(self, post_mock, log_mock): - """JiraOutput auth failure""" - alert = self._setup_dispatch() - - # setup unsuccesful auth response - post_mock.return_value.status_code = 400 - post_mock.return_value.content = '{}' - post_mock.return_value.json.return_value = json.loads('{}') - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_has_calls([call('Encountered an error while sending to %s:\n%s', - 'jira', '{}'), - call('Failed to authenticate to Jira'), - call('Failed to send alert to %s', self.__service)]) - - @patch('logging.Logger.error') - @patch('requests.get') - @patch('requests.post') - def test_issue_creation_failure(self, post_mock, get_mock, log_mock): - """JiraOutput issue creation failure""" - alert = self._setup_dispatch() - # setup the successful search response - no results - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = json.loads('{"issues": []}') - # setup successful auth response and failed issue creation - auth_resp = '{"session": {"name": "cookie_name", "value": "cookie_value"}}' - type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400]) - post_mock.return_value.content = '{}' - post_mock.return_value.json.side_effect = [json.loads(auth_resp), - json.loads('{}')] - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_has_calls([call('Encountered an error while sending to %s:\n%s', - self.__service, '{}'), - call('Failed to send alert to %s', self.__service)]) - - @patch('logging.Logger.error') - @patch('requests.get') - @patch('requests.post') - def test_comment_creation_failure(self, post_mock, get_mock, log_mock): - """JiraOutput comment creation failure""" - alert = self._setup_dispatch() - # setup successful search response - get_mock.return_value.status_code = 200 - existing_issues = '{"issues": [{"fields": {"summary": "Bogus"}, "id": "5000"}]}' - get_mock.return_value.json.return_value = json.loads(existing_issues) - auth_resp = '{"session": {"name": "cookie_name", "value": "cookie_value"}}' - # setup successful auth, failed comment creation, and successful issue creation - type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 200]) - post_mock.return_value.content = '{}' - post_mock.return_value.json.side_effect = [json.loads(auth_resp), - json.loads('{"id": 6000}')] - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Encountered an error when adding alert to existing Jira ' - 'issue %s. 
Attempting to create new Jira issue.', 5000)
diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/__init__.py b/tests/unit/stream_alert_alert_processor/test_outputs/__init__.py
new file mode 100644
index 000000000..48564fa87
--- /dev/null
+++ b/tests/unit/stream_alert_alert_processor/test_outputs/__init__.py
@@ -0,0 +1,21 @@
+'''
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
+from stream_alert.alert_processor.main import _load_output_config
+
+REGION = 'us-east-1'
+FUNCTION_NAME = 'corp-prefix_prod_streamalert_alert_processor'
+CONFIG = _load_output_config('tests/unit/conf/outputs.json')
+KMS_ALIAS = 'alias/stream_alert_secrets_test'
diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py
new file mode 100644
index 000000000..3ce2b19b0
--- /dev/null
+++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py
@@ -0,0 +1,215 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+# pylint: disable=abstract-class-instantiated,attribute-defined-outside-init
+import boto3
+from mock import patch
+from moto import mock_s3, mock_lambda, mock_kinesis
+from nose.tools import (
+    assert_equal,
+    assert_false,
+    assert_is_not_none,
+    assert_true
+)
+
+from stream_alert.alert_processor.outputs.output_base import (
+    OutputProperty,
+    StreamAlertOutput
+)
+from stream_alert.alert_processor.outputs import aws
+from stream_alert_cli.helpers import create_lambda_function
+from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, REGION
+from tests.unit.stream_alert_alert_processor.helpers import get_alert
+
+
+class TestAWSOutput(object):
+    """Test class for AWSOutput Base"""
+
+    @staticmethod
+    @patch.object(aws.AWSOutput, '__service__', 'aws-s3')
+    def test_aws_format_output_config():
+        """AWSOutput - Format Output Config"""
+        props = {
+            'descriptor': OutputProperty(
+                'short_descriptor',
+                'descriptor_value'),
+            'aws_value': OutputProperty(
+                'unique arn value, bucket, etc',
+                'bucket.value')}
+
+        formatted_config = aws.AWSOutput.format_output_config(CONFIG, props)
+
+        assert_equal(len(formatted_config), 2)
+        assert_is_not_none(formatted_config.get('descriptor_value'))
+        assert_is_not_none(formatted_config.get('unit_test_bucket'))
+
+
+class TestS3Output(object):
+    """Test class for S3Output"""
+    @classmethod
+    def setup_class(cls):
+        """Setup the class before any methods"""
+        cls.__service = 'aws-s3'
+        cls.__descriptor = 'unit_test_bucket'
+        cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service,
+                                                               REGION,
+                                                               FUNCTION_NAME,
+                                                               CONFIG)
+
+    @classmethod
+    def teardown_class(cls):
+        """Teardown the class after all methods"""
+        cls.__dispatcher = None
+
+    def test_locals(self):
+        """S3Output local variables"""
+        assert_equal(self.__dispatcher.__class__.__name__, 'S3Output')
+        assert_equal(self.__dispatcher.__service__, self.__service)
+
+    def _setup_dispatch(self):
+        """Helper for setting up S3Output dispatch"""
+        bucket = CONFIG[self.__service][self.__descriptor]
+        boto3.client('s3', region_name=REGION).create_bucket(Bucket=bucket)
+
+        return get_alert()
+
+    @patch('logging.Logger.info')
+    @mock_s3
+    def test_dispatch(self, log_mock):
+        """S3Output dispatch"""
+        alert = self._setup_dispatch()
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+
+class TestFirehoseOutput(object):
+    """Test class for AWS Kinesis Firehose"""
+    @classmethod
+    def setup_class(cls):
+        """Setup the class before any methods"""
+        cls.__service = 'aws-firehose'
+        cls.__descriptor = 'unit_test_delivery_stream'
+        cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service,
+                                                               REGION,
+                                                               FUNCTION_NAME,
+                                                               CONFIG)
+
+    @classmethod
+    def teardown_class(cls):
+        """Teardown the class after all methods"""
+        cls.__dispatcher = None
+
+    def test_locals(self):
+        """Output local variables - Kinesis Firehose"""
+        assert_equal(self.__dispatcher.__class__.__name__, 'KinesisFirehoseOutput')
+        assert_equal(self.__dispatcher.__service__, self.__service)
+
+    def _setup_dispatch(self):
+        """Helper for setting up KinesisFirehoseOutput dispatch"""
+        delivery_stream = CONFIG[self.__service][self.__descriptor]
+        boto3.client('firehose', region_name=REGION).create_delivery_stream(
+            DeliveryStreamName=delivery_stream,
+            S3DestinationConfiguration={
+                'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role',
+                'BucketARN': 'arn:aws:s3:::unit_test',
+                'Prefix': '/',
+                'BufferingHints': {
+                    'SizeInMBs': 128,
+                    'IntervalInSeconds': 128
+                },
+                'CompressionFormat': 'GZIP',
+            }
+        )
+
+        return get_alert()
+
+    @patch('logging.Logger.info')
+    @mock_kinesis
+    def test_dispatch(self, log_mock):
+        """Output Dispatch - Kinesis Firehose"""
+        alert = self._setup_dispatch()
+        resp = self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                          rule_name='rule_name',
+                                          alert=alert)
+
+        assert_true(resp)
+        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+    @mock_kinesis
+    def test_dispatch_ignore_large_payload(self):
+        """Output Dispatch - Kinesis Firehose with Large Payload"""
+        alert = self._setup_dispatch()
+        alert['record'] = 'test' * 1000 * 1000
+        resp = self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                          rule_name='rule_name',
+                                          alert=alert)
+
+        assert_false(resp)
+
+
+class TestLambdaOutput(object):
+    """Test class for LambdaOutput"""
+    @classmethod
+    def setup_class(cls):
+        """Setup the class before any methods"""
+        cls.__service = 'aws-lambda'
+        cls.__descriptor = 'unit_test_lambda'
+        cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service,
+                                                               REGION,
+                                                               FUNCTION_NAME,
+                                                               CONFIG)
+
+    @classmethod
+    def teardown_class(cls):
+        """Teardown the class after all methods"""
+        cls.__dispatcher = None
+
+    def test_locals(self):
+        """LambdaOutput local variables"""
+        assert_equal(self.__dispatcher.__class__.__name__, 'LambdaOutput')
+        assert_equal(self.__dispatcher.__service__, self.__service)
+
+    def _setup_dispatch(self, alt_descriptor=''):
+        """Helper for setting up LambdaOutput dispatch"""
+        function_name = CONFIG[self.__service][alt_descriptor or self.__descriptor]
+        create_lambda_function(function_name, REGION)
+        return get_alert()
+
+    @mock_lambda
+    @patch('logging.Logger.info')
+    def test_dispatch(self, log_mock):
+        """LambdaOutput dispatch"""
+        alert = self._setup_dispatch()
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+    @mock_lambda
+    @patch('logging.Logger.info')
+    def test_dispatch_with_qualifier(self, log_mock):
+        """LambdaOutput dispatch with qualifier"""
+        alt_descriptor = '{}_qual'.format(self.__descriptor)
+        alert = self._setup_dispatch(alt_descriptor)
+        self.__dispatcher.dispatch(descriptor=alt_descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py
new file mode 100644
index 000000000..141f4c62b
--- /dev/null
+++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py
@@ -0,0 +1,199 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" +# pylint: disable=protected-access +from mock import call, patch, PropertyMock +from moto import mock_s3, mock_kms +from nose.tools import assert_equal + +from stream_alert.alert_processor.outputs.jira import JiraOutput +from stream_alert_cli.helpers import put_mock_creds +from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION +from tests.unit.stream_alert_alert_processor.helpers import ( + get_alert, + remove_temp_secrets +) + + +@mock_s3 +@mock_kms +class TestJiraOutput(object): + """Test class for PhantomOutput""" + @classmethod + def setup_class(cls): + """Setup the class before any methods""" + cls.__service = 'jira' + cls.__descriptor = 'unit_test_jira' + cls.__dispatcher = JiraOutput(REGION, FUNCTION_NAME, CONFIG) + + @classmethod + def teardown_class(cls): + """Teardown the class after all methods""" + cls.__dispatcher = None + + def _setup_dispatch(self): + """Helper for setting up JiraOutput dispatch""" + remove_temp_secrets() + + output_name = self.__dispatcher.output_cred_name(self.__descriptor) + + creds = {'username': 'jira@foo.bar', + 'password': 'jirafoobar', + 'url': 'jira.foo.bar', + 'project_key': 'foobar', + 'issue_type': 'Task', + 'aggregate': 'yes'} + + put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) + + return get_alert() + + @patch('logging.Logger.info') + @patch('requests.get') + @patch('requests.post') + def test_dispatch_issue_new(self, post_mock, get_mock, log_mock): + """JiraOutput dispatch success, new issue""" + alert = self._setup_dispatch() + # setup the request to not find an existing issue + get_mock.return_value.status_code = 200 + get_mock.return_value.json.return_value = {'issues': []} + # setup the auth and successful creation responses + auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} + post_mock.return_value.status_code = 200 + post_mock.return_value.json.side_effect = [auth_resp, {'id': 5000}] + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.get') + @patch('requests.post') + def test_dispatch_issue_existing(self, post_mock, get_mock, log_mock): + """JiraOutput dispatch success, existing issue""" + alert = self._setup_dispatch() + # setup the request to find an existing issue + get_mock.return_value.status_code = 200 + existing_issues = {'issues': [{'fields': {'summary': 'Bogus'}, 'id': '5000'}]} + get_mock.return_value.json.return_value = existing_issues + auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} + # setup the auth and successful creation responses + post_mock.return_value.status_code = 200 + post_mock.return_value.json.side_effect = [auth_resp, {'id': 5000}] + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('requests.get') + def test_get_comments_success(self, get_mock): + """JiraOutput get comments success""" + # setup successful get comments response + get_mock.return_value.status_code = 200 + get_mock.return_value.json.return_value = {'comments': [{}, {}]} + + self.__dispatcher._load_creds('jira') + resp = self.__dispatcher._get_comments('5000') + assert_equal(resp, [{}, {}]) + + @patch('requests.get') + def test_get_comments_failure(self, get_mock): + """JiraOutput get comments 
failure""" + # setup successful get comments response + get_mock.return_value.status_code = 400 + + self.__dispatcher._load_creds('jira') + resp = self.__dispatcher._get_comments('5000') + assert_equal(resp, []) + + @patch('requests.get') + def test_search_failure(self, get_mock): + """JiraOutput search failure""" + # setup successful get comments response + get_mock.return_value.status_code = 400 + + self.__dispatcher._load_creds('jira') + resp = self.__dispatcher._search_jira('foobar') + assert_equal(resp, []) + + @patch('logging.Logger.error') + @patch('requests.post') + def test_auth_failure(self, post_mock, log_mock): + """JiraOutput auth failure""" + alert = self._setup_dispatch() + + # setup unsuccesful auth response + post_mock.return_value.status_code = 400 + post_mock.return_value.content = '{}' + post_mock.return_value.json.return_value = dict() + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_mock.assert_has_calls([call('Encountered an error while sending to %s:\n%s', + 'jira', '{}'), + call('Failed to authenticate to Jira'), + call('Failed to send alert to %s', self.__service)]) + + @patch('logging.Logger.error') + @patch('requests.get') + @patch('requests.post') + def test_issue_creation_failure(self, post_mock, get_mock, log_mock): + """JiraOutput issue creation failure""" + alert = self._setup_dispatch() + # setup the successful search response - no results + get_mock.return_value.status_code = 200 + get_mock.return_value.json.return_value = {'issues': []} + # setup successful auth response and failed issue creation + auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} + type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400]) + post_mock.return_value.content = '{}' + post_mock.return_value.json.side_effect = [auth_resp, dict()] + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_mock.assert_has_calls([call('Encountered an error while sending to %s:\n%s', + self.__service, '{}'), + call('Failed to send alert to %s', self.__service)]) + + @patch('logging.Logger.error') + @patch('requests.get') + @patch('requests.post') + def test_comment_creation_failure(self, post_mock, get_mock, log_mock): + """JiraOutput comment creation failure""" + alert = self._setup_dispatch() + # setup successful search response + get_mock.return_value.status_code = 200 + existing_issues = {'issues': [{'fields': {'summary': 'Bogus'}, 'id': '5000'}]} + get_mock.return_value.json.return_value = existing_issues + auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} + # setup successful auth, failed comment creation, and successful issue creation + type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 200]) + post_mock.return_value.content = '{}' + post_mock.return_value.json.side_effect = [auth_resp, {'id': 6000}] + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_mock.assert_called_with('Encountered an error when adding alert to existing Jira ' + 'issue %s. 
Attempting to create new Jira issue.', 5000) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py new file mode 100644 index 000000000..6255a914b --- /dev/null +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py @@ -0,0 +1,207 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=abstract-class-instantiated,protected-access,attribute-defined-outside-init +import os + +from mock import patch +from moto import mock_kms, mock_s3 +from nose.tools import ( + assert_equal, + assert_is_instance, + assert_is_not_none, + assert_is_none, + assert_items_equal +) + +from stream_alert.alert_processor.outputs.output_base import ( + OutputDispatcher, + OutputProperty, + StreamAlertOutput +) +from stream_alert.alert_processor.outputs.aws import S3Output +from stream_alert_cli.helpers import encrypt_with_kms, put_mock_creds, put_mock_s3_object +from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION +from tests.unit.stream_alert_alert_processor.helpers import remove_temp_secrets + + +def test_output_property_default(): + """OutputProperty defaults""" + prop = OutputProperty() + + assert_equal(prop.description, '') + assert_equal(prop.value, '') + assert_equal(prop.input_restrictions, {' ', ':'}) + assert_equal(prop.mask_input, False) + assert_equal(prop.cred_requirement, False) + + +def test_get_dispatcher_good(): + """StreamAlertOutput - Get Valid Dispatcher""" + dispatcher = StreamAlertOutput.get_dispatcher('aws-s3') + assert_is_not_none(dispatcher) + + +@patch('logging.Logger.error') +def test_get_dispatcher_bad(log_mock): + """StreamAlertOutput - Get Invalid Dispatcher""" + dispatcher = StreamAlertOutput.get_dispatcher('aws-s4') + assert_is_none(dispatcher) + log_mock.assert_called_with('Designated output service [%s] does not exist', 'aws-s4') + + +def test_create_dispatcher(): + """StreamAlertOutput - Create Dispatcher""" + dispatcher = StreamAlertOutput.create_dispatcher( + 'aws-s3', + REGION, + FUNCTION_NAME, + CONFIG + ) + assert_is_instance(dispatcher, S3Output) + + +def test_user_defined_properties(): + """OutputDispatcher - User Defined Properties""" + for output in StreamAlertOutput.get_all_outputs().values(): + props = output.get_user_defined_properties() + # The user defined properties should at a minimum contain a descriptor + assert_is_not_none(props.get('descriptor')) + +def test_output_loading(): + """OutputDispatcher - Loading Output Classes""" + loaded_outputs = set(StreamAlertOutput.get_all_outputs()) + # Add new outputs to this list to make sure they're loaded properly + expected_outputs = { + 'aws-firehose', + 'aws-lambda', + 'aws-s3', + 'jira', + 'pagerduty', + 'pagerduty-v2', + 'pagerduty-incident', + 'phantom', + 'slack' + } + assert_items_equal(loaded_outputs, expected_outputs) + + +@patch.object(OutputDispatcher, '__service__', 'test_service') +class 
TestOutputDispatcher(object): + """Test class for OutputDispatcher""" + + @patch.object(OutputDispatcher, '__abstractmethods__', frozenset()) + def setup(self): + """Setup before each method""" + self._dispatcher = OutputDispatcher(REGION, FUNCTION_NAME, CONFIG) + self._descriptor = 'desc_test' + + def test_local_temp_dir(self): + """OutputDispatcher - Local Temp Dir""" + temp_dir = self._dispatcher._local_temp_dir() + assert_equal(temp_dir.split('/')[-1], 'stream_alert_secrets') + + def test_get_secrets_bucket_name(self): + """OutputDispatcher - Get Secrets Bucket Name""" + bucket_name = self._dispatcher._get_secrets_bucket_name(FUNCTION_NAME) + assert_equal(bucket_name, 'corp-prefix.streamalert.secrets') + + def test_output_cred_name(self): + """OutputDispatcher - Output Cred Name""" + output_name = self._dispatcher.output_cred_name('creds') + assert_equal(output_name, 'test_service/creds') + + @mock_s3 + def test_get_creds_from_s3(self): + """OutputDispatcher - Get Creds From S3""" + test_data = 'credential test string' + + bucket_name = self._dispatcher.secrets_bucket + key = self._dispatcher.output_cred_name(self._descriptor) + + local_cred_location = os.path.join(self._dispatcher._local_temp_dir(), key) + + put_mock_s3_object(bucket_name, key, test_data, REGION) + + self._dispatcher._get_creds_from_s3(local_cred_location, self._descriptor) + + with open(local_cred_location) as creds: + line = creds.readline() + + assert_equal(line, test_data) + + @mock_kms + def test_kms_decrypt(self): + """OutputDispatcher - KMS Decrypt""" + test_data = 'data to encrypt' + encrypted = encrypt_with_kms(test_data, REGION, KMS_ALIAS) + decrypted = self._dispatcher._kms_decrypt(encrypted) + + assert_equal(decrypted, test_data) + + @patch('logging.Logger.info') + def test_log_status_success(self, log_mock): + """OutputDispatcher - Log status success""" + self._dispatcher._log_status(True) + log_mock.assert_called_with('Successfully sent alert to %s', 'test_service') + + @patch('logging.Logger.error') + def test_log_status_failed(self, log_mock): + """OutputDispatcher - Log status failed""" + self._dispatcher._log_status(False) + log_mock.assert_called_with('Failed to send alert to %s', 'test_service') + + @patch('requests.Response') + def test_check_http_response(self, mock_response): + """OutputDispatcher - Check HTTP Response""" + # Test with a good response code + mock_response.status_code = 200 + result = self._dispatcher._check_http_response(mock_response) + assert_equal(result, True) + + # Test with a bad response code + mock_response.status_code = 440 + result = self._dispatcher._check_http_response(mock_response) + assert_equal(result, False) + + @mock_s3 + @mock_kms + def test_load_creds(self): + """OutputDispatcher - Load Credentials""" + remove_temp_secrets() + output_name = self._dispatcher.output_cred_name(self._descriptor) + + creds = {'url': 'http://www.foo.bar/test', + 'token': 'token_to_encrypt'} + + put_mock_creds(output_name, creds, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) + + loaded_creds = self._dispatcher._load_creds(self._descriptor) + + assert_is_not_none(loaded_creds) + assert_equal(len(loaded_creds), 2) + assert_equal(loaded_creds['url'], u'http://www.foo.bar/test') + assert_equal(loaded_creds['token'], u'token_to_encrypt') + + def test_format_output_config(self): + """OutputDispatcher - Format Output Config""" + with patch.object(OutputDispatcher, '__service__', 'slack'): + props = {'descriptor': OutputProperty('test_desc', 'test_channel')} + + formatted = 
self._dispatcher.format_output_config(CONFIG, props) + + assert_equal(len(formatted), 2) + assert_equal(formatted[0], 'unit_test_channel') + assert_equal(formatted[1], 'test_channel') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py new file mode 100644 index 000000000..fe8764c12 --- /dev/null +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py @@ -0,0 +1,675 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=protected-access +import json +from mock import patch, PropertyMock +from moto import mock_s3, mock_kms +from nose.tools import ( + assert_equal, + assert_false +) + +from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput +from stream_alert_cli.helpers import put_mock_creds +from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION +from tests.unit.stream_alert_alert_processor.helpers import ( + get_alert, + remove_temp_secrets +) + + +class TestPagerDutyOutput(object): + """Test class for PagerDutyOutput""" + @classmethod + def setup_class(cls): + """Setup the class before any methods""" + cls.__service = 'pagerduty' + cls.__descriptor = 'unit_test_pagerduty' + cls.__backup_method = None + cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, + REGION, + FUNCTION_NAME, + CONFIG) + + @classmethod + def teardown_class(cls): + """Teardown the class after all methods""" + cls.__dispatcher = None + + def test_get_default_properties(self): + """Get Default Properties - PagerDuty""" + props = self.__dispatcher._get_default_properties() + assert_equal(len(props), 1) + assert_equal(props['url'], + 'https://events.pagerduty.com/generic/2010-04-15/create_event.json') + + def _setup_dispatch(self): + """Helper for setting up PagerDutyOutput dispatch""" + remove_temp_secrets() + + # Cache the _get_default_properties and set it to return None + self.__backup_method = self.__dispatcher._get_default_properties + self.__dispatcher._get_default_properties = lambda: None + + output_name = self.__dispatcher.output_cred_name(self.__descriptor) + + creds = {'url': 'http://pagerduty.foo.bar/create_event.json', + 'service_key': 'mocked_service_key'} + + put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) + + return get_alert() + + def _teardown_dispatch(self): + """Replace method with cached method""" + self.__dispatcher._get_default_properties = self.__backup_method + + @patch('logging.Logger.info') + @patch('requests.post') + @mock_s3 + @mock_kms + def test_dispatch_success(self, post_mock, log_info_mock): + """PagerDutyOutput dispatch success""" + alert = self._setup_dispatch() + post_mock.return_value.status_code = 200 + post_mock.return_value.text = '' + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + 
log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @mock_s3 + @mock_kms + def test_dispatch_failure(self, post_mock, log_error_mock): + """PagerDutyOutput dispatch failure""" + alert = self._setup_dispatch() + post_mock.return_value.text = '{"message": "error message", "errors": ["error1"]}' + post_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.error') + @mock_s3 + @mock_kms + def test_dispatch_bad_descriptor(self, log_error_mock): + """PagerDutyOutput dispatch bad descriptor""" + alert = self._setup_dispatch() + self.__dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + +class TestPagerDutyOutputV2(object): + """Test class for PagerDutyOutputV2""" + @classmethod + def setup_class(cls): + """Setup the class before any methods""" + cls.__service = 'pagerduty-v2' + cls.__descriptor = 'unit_test_pagerduty-v2' + cls.__backup_method = None + cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, + REGION, + FUNCTION_NAME, + CONFIG) + + @classmethod + def teardown_class(cls): + """Teardown the class after all methods""" + cls.__dispatcher = None + + def test_get_default_properties(self): + """Get Default Properties - PagerDuty""" + props = self.__dispatcher._get_default_properties() + assert_equal(len(props), 1) + assert_equal(props['url'], + 'https://events.pagerduty.com/v2/enqueue') + + def _setup_dispatch(self): + """Helper for setting up PagerDutyOutputV2 dispatch""" + remove_temp_secrets() + + # Cache the _get_default_properties and set it to return None + self.__backup_method = self.__dispatcher._get_default_properties + self.__dispatcher._get_default_properties = lambda: None + + output_name = self.__dispatcher.output_cred_name(self.__descriptor) + + creds = {'url': 'http://pagerduty.foo.bar/create_event.json', + 'routing_key': 'mocked_routing_key'} + + put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) + + return get_alert() + + def _teardown_dispatch(self): + """Replace method with cached method""" + self.__dispatcher._get_default_properties = self.__backup_method + + @patch('logging.Logger.info') + @patch('requests.post') + @mock_s3 + @mock_kms + def test_dispatch_success(self, post_mock, log_info_mock): + """PagerDutyOutputV2 dispatch success""" + alert = self._setup_dispatch() + post_mock.return_value.status_code = 200 + post_mock.return_value.text = '' + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @mock_s3 + @mock_kms + def test_dispatch_failure(self, post_mock, log_error_mock): + """PagerDutyOutputV2 dispatch failure""" + alert = self._setup_dispatch() + json_error = json.loads('{"message": "error message", "errors": ["error1"]}') + post_mock.return_value.json.return_value = json_error + post_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + 
self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.error') + @mock_s3 + @mock_kms + def test_dispatch_bad_descriptor(self, log_error_mock): + """PagerDutyOutputV2 dispatch bad descriptor""" + alert = self._setup_dispatch() + self.__dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + +class TestPagerDutyIncidentOutput(object): + """Test class for PagerDutyIncidentOutput""" + @classmethod + def setup_class(cls): + """Setup the class before any methods""" + cls.__service = 'pagerduty-incident' + cls.__descriptor = 'unit_test_pagerduty-incident' + cls.__backup_method = None + cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, + REGION, + FUNCTION_NAME, + CONFIG) + + @classmethod + def teardown_class(cls): + """Teardown the class after all methods""" + cls.__dispatcher = None + + def test_get_default_properties(self): + """Get Default Properties - PagerDutyIncidentOutput""" + props = self.__dispatcher._get_default_properties() + assert_equal(len(props), 1) + assert_equal(props['api'], + 'https://api.pagerduty.com') + + def test_get_endpoint(self): + """Get Endpoint - PagerDutyIncidentOutput""" + props = self.__dispatcher._get_default_properties() + endpoint = self.__dispatcher._get_endpoint(props['api'], 'testtest') + assert_equal(endpoint, + 'https://api.pagerduty.com/testtest') + + def _setup_dispatch(self, context=None): + """Helper for setting up PagerDutyIncidentOutput dispatch""" + remove_temp_secrets() + + # Cache the _get_default_properties and set it to return None + self.__backup_method = self.__dispatcher._get_default_properties + self.__dispatcher._get_default_properties = lambda: None + + output_name = self.__dispatcher.output_cred_name(self.__descriptor) + + creds = {'api': 'https://api.pagerduty.com', + 'token': 'mocked_token', + 'service_key': 'mocked_service_key', + 'escalation_policy': 'mocked_escalation_policy'} + + put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) + + return get_alert(context) + + def _teardown_dispatch(self): + """Replace method with cached method""" + self.__dispatcher._get_default_properties = self.__backup_method + + @patch('requests.get') + def test_check_exists_get_id(self, get_mock): + """Check Exists Get Id - PagerDutyIncidentOutput""" + # /check + get_mock.return_value.status_code = 200 + json_check = json.loads('{"check": [{"id": "checked_id"}]}') + get_mock.return_value.json.return_value = json_check + + checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + assert_equal(checked, 'checked_id') + + @patch('requests.get') + def test_check_exists_get_id_fail(self, get_mock): + """Check Exists Get Id Fail - PagerDutyIncidentOutput""" + # /check + get_mock.return_value.status_code = 200 + json_check = json.loads('{}') + get_mock.return_value.json.return_value = json_check + + checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + assert_false(checked) + + @patch('requests.get') + def test_user_verify_success(self, get_mock): + """User Verify Success - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"users": [{"id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_check + + user_verified = 
self.__dispatcher._user_verify('valid_user') + assert_equal(user_verified['id'], 'verified_user_id') + assert_equal(user_verified['type'], 'user_reference') + + @patch('requests.get') + def test_user_verify_fail(self, get_mock): + """User Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_users": [{"not_id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_check + + user_verified = self.__dispatcher._user_verify('valid_user') + assert_false(user_verified) + + @patch('requests.get') + def test_policy_verify_success_no_default(self, get_mock): + """Policy Verify Success (No Default) - PagerDutyIncidentOutput""" + # /escalation_policies + get_mock.return_value.status_code = 200 + json_check = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + get_mock.return_value.json.return_value = json_check + + policy_verified = self.__dispatcher._policy_verify('valid_policy', '') + assert_equal(policy_verified['id'], 'good_policy_id') + assert_equal(policy_verified['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_policy_verify_success_default(self, get_mock): + """Policy Verify Success (Default) - PagerDutyIncidentOutput""" + # /escalation_policies + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_check_bad = json.loads('{"no_escalation_policies": [{"id": "bad_policy_id"}]}') + json_check_good = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + get_mock.return_value.json.side_effect = [json_check_bad, json_check_good] + + policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + assert_equal(policy_verified['id'], 'good_policy_id') + assert_equal(policy_verified['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_policy_verify_fail_default(self, get_mock): + """Policy Verify Fail (Default) - PagerDutyIncidentOutput""" + # /not_escalation_policies + type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400]) + json_check_bad = json.loads('{"escalation_policies": [{"id": "bad_policy_id"}]}') + json_check_bad_default = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + get_mock.return_value.json.side_effect = [json_check_bad, json_check_bad_default] + policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + assert_false(policy_verified) + + @patch('requests.get') + def test_policy_verify_fail_no_default(self, get_mock): + """Policy Verify Fail (No Default) - PagerDutyIncidentOutput""" + # /not_escalation_policies + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_escalation_policies": [{"not_id": "verified_policy_id"}]}') + get_mock.return_value.json.return_value = json_check + + policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + assert_false(policy_verified) + + @patch('requests.get') + def test_service_verify_success(self, get_mock): + """Service Verify Success - PagerDutyIncidentOutput""" + # /services + get_mock.return_value.status_code = 200 + json_check = json.loads('{"services": [{"id": "verified_service_id"}]}') + get_mock.return_value.json.return_value = json_check + + service_verified = self.__dispatcher._service_verify('valid_service') + assert_equal(service_verified['id'], 'verified_service_id') + assert_equal(service_verified['type'], 'service_reference') + + @patch('requests.get') + def test_service_verify_fail(self, get_mock): 
+ """Service Verify Fail - PagerDutyIncidentOutput""" + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_services": [{"not_id": "verified_service_id"}]}') + get_mock.return_value.json.return_value = json_check + + service_verified = self.__dispatcher._service_verify('valid_service') + assert_false(service_verified) + + @patch('requests.get') + def test_item_verify_success(self, get_mock): + """Item Verify Success - PagerDutyIncidentOutput""" + # /items + get_mock.return_value.status_code = 200 + json_check = json.loads('{"items": [{"id": "verified_item_id"}]}') + get_mock.return_value.json.return_value = json_check + + item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', + 'items', 'item_reference') + assert_equal(item_verified['id'], 'verified_item_id') + assert_equal(item_verified['type'], 'item_reference') + + @patch('requests.get') + def test_incident_assignment_user(self, get_mock): + """Incident Assignment User - PagerDutyIncidentOutput""" + context = {'assigned_user': 'user_to_assign'} + get_mock.return_value.status_code = 200 + json_user = json.loads('{"users": [{"id": "verified_user_id"}]}') + get_mock.return_value.json.return_value = json_user + + assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + + assert_equal(assigned_key, 'assignments') + assert_equal(assigned_value[0]['assignee']['id'], 'verified_user_id') + assert_equal(assigned_value[0]['assignee']['type'], 'user_reference') + + @patch('requests.get') + def test_incident_assignment_policy_no_default(self, get_mock): + """Incident Assignment Policy (No Default) - PagerDutyIncidentOutput""" + context = {'assigned_policy': 'policy_to_assign'} + get_mock.return_value.status_code = 200 + json_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + get_mock.return_value.json.return_value = json_policy + + assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + + assert_equal(assigned_key, 'escalation_policy') + assert_equal(assigned_value['id'], 'verified_policy_id') + assert_equal(assigned_value['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_incident_assignment_policy_default(self, get_mock): + """Incident Assignment Policy (Default) - PagerDutyIncidentOutput""" + context = {'assigned_policy': 'bad_invalid_policy_to_assign'} + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_bad_policy = json.loads('{"not_escalation_policies": [{"id": "bad_policy_id"}]}') + json_good_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy] + + assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + + assert_equal(assigned_key, 'escalation_policy') + assert_equal(assigned_value['id'], 'verified_policy_id') + assert_equal(assigned_value['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_item_verify_fail(self, get_mock): + """Item Verify Fail - PagerDutyIncidentOutput""" + # /not_items + get_mock.return_value.status_code = 200 + json_check = json.loads('{"not_items": [{"not_id": "verified_item_id"}]}') + get_mock.return_value.json.return_value = json_check + + item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', + 'items', 'item_reference') + assert_false(item_verified) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms 
+ def test_dispatch_success_good_user(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Good User""" + ctx = { + 'pagerduty-incident': { + 'assigned_user': 'valid_user' + } + } + alert = self._setup_dispatch(context=ctx) + + # /users, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_user = json.loads('{"users": [{"id": "valid_user_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_user, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_good_policy(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Good Policy""" + ctx = { + 'pagerduty-incident': { + 'assigned_policy': 'valid_policy' + } + } + alert = self._setup_dispatch(context=ctx) + + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_bad_user(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Bad User""" + ctx = { + 'pagerduty-incident': { + 'assigned_user': 'invalid_user' + } + } + alert = self._setup_dispatch(context=ctx) + + # /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) + json_user = json.loads('{"not_users": [{"id": "user_id"}]}') + json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_no_context(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - No Context""" + alert = self._setup_dispatch() + + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_policy, json_service] + + # /incidents + 
post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_error_mock): + """PagerDutyIncidentOutput dispatch failure - No User, Bad Policy, Bad Service""" + alert = self._setup_dispatch() + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400, 400]) + json_empty = json.loads('{}') + get_mock.return_value.json.side_effect = [json_empty, json_empty, json_empty] + + # /incidents + post_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.info') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_info_mock): + """PagerDutyIncidentOutput dispatch success - Bad Policy""" + ctx = { + 'pagerduty-incident': { + 'assigned_policy': 'valid_policy' + } + } + alert = self._setup_dispatch(context=ctx) + # /escalation_policies, /services + get_mock.return_value.side_effect = [400, 200, 200] + type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 200, 200]) + json_bad_policy = json.loads('{}') + json_good_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 200 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @patch('requests.get') + @mock_s3 + @mock_kms + def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_error_mock): + """PagerDutyIncidentOutput dispatch - Bad Dispatch""" + alert = self._setup_dispatch() + # /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') + json_service = json.loads('{"services": [{"id": "service_id"}]}') + get_mock.return_value.json.side_effect = [json_policy, json_service] + + # /incidents + post_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.error') + @mock_s3 + @mock_kms + def test_dispatch_bad_descriptor(self, log_error_mock): + """PagerDutyIncidentOutput dispatch - Bad Descriptor""" + alert = self._setup_dispatch() + self.__dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=alert) + + self._teardown_dispatch() + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) diff --git 
a/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py
new file mode 100644
index 000000000..12db28cd9
--- /dev/null
+++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py
@@ -0,0 +1,205 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+# pylint: disable=protected-access
+from mock import call, patch, PropertyMock
+from moto import mock_s3, mock_kms
+from nose.tools import assert_equal
+
+from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput
+from stream_alert_cli.helpers import put_mock_creds
+from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION
+from tests.unit.stream_alert_alert_processor.helpers import get_alert, remove_temp_secrets
+
+
+@mock_s3
+@mock_kms
+class TestPhantomOutput(object):
+    """Test class for PhantomOutput"""
+    @classmethod
+    def setup_class(cls):
+        """Setup the class before any methods"""
+        cls.__service = 'phantom'
+        cls.__descriptor = 'unit_test_phantom'
+        cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service,
+                                                               REGION,
+                                                               FUNCTION_NAME,
+                                                               CONFIG)
+
+    @classmethod
+    def teardown_class(cls):
+        """Teardown the class after all methods"""
+        cls.__dispatcher = None
+
+    def _setup_dispatch(self, url):
+        """Helper for setting up PhantomOutput dispatch"""
+        remove_temp_secrets()
+
+        output_name = self.__dispatcher.output_cred_name(self.__descriptor)
+
+        creds = {'url': url,
+                 'ph_auth_token': 'mocked_auth_token'}
+
+        put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS)
+
+        return get_alert()
+
+    @patch('logging.Logger.info')
+    @patch('requests.get')
+    @patch('requests.post')
+    def test_dispatch_existing_container(self, post_mock, get_mock, log_mock):
+        """PhantomOutput dispatch success, existing container"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        # _check_container_exists
+        get_mock.return_value.status_code = 200
+        get_mock.return_value.json.return_value = {'count': 1, 'data': [{'id': 1948}]}
+        # dispatch
+        post_mock.return_value.status_code = 200
+
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+    @patch('logging.Logger.info')
+    @patch('requests.get')
+    @patch('requests.post')
+    def test_dispatch_new_container(self, post_mock, get_mock, log_mock):
+        """PhantomOutput dispatch success, new container"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        # _check_container_exists
+        get_mock.return_value.status_code = 200
+        get_mock.return_value.json.return_value = {'count': 0, 'data': []}
+        # _setup_container, dispatch
+        post_mock.return_value.status_code = 200
+        post_mock.return_value.json.return_value = {'id': 1948}
+
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+
+    @patch('logging.Logger.error')
+    @patch('requests.get')
+    @patch('requests.post')
+    def test_dispatch_container_failure(self, post_mock, get_mock, log_mock):
+        """PhantomOutput dispatch failure (setup container)"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        # _check_container_exists
+        get_mock.return_value.status_code = 200
+        get_mock.return_value.json.return_value = {'count': 0, 'data': []}
+        # _setup_container
+        post_mock.return_value.status_code = 400
+        json_error = {'message': 'error message', 'errors': ['error1']}
+        post_mock.return_value.json.return_value = json_error
+
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
+
+    @patch('logging.Logger.error')
+    @patch('requests.get')
+    @patch('requests.post')
+    def test_dispatch_check_container_error(self, post_mock, get_mock, log_mock):
+        """PhantomOutput dispatch decode error (check container)"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        # _check_container_exists
+        get_mock.return_value.status_code = 200
+        get_mock.return_value.text = '{}'
+        # _setup_container
+        post_mock.return_value.status_code = 400
+        json_error = {'message': 'error message', 'errors': ['error1']}
+        post_mock.return_value.json.return_value = json_error
+
+        result = self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                            rule_name='rule_name',
+                                            alert=alert)
+
+        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
+        assert_equal(result, False)
+
+    @patch('logging.Logger.error')
+    @patch('requests.get')
+    @patch('requests.post')
+    def test_dispatch_setup_container_error(self, post_mock, get_mock, log_mock):
+        """PhantomOutput dispatch decode error (setup container)"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        # _check_container_exists
+        get_mock.return_value.status_code = 200
+        get_mock.return_value.json.return_value = {'count': 0, 'data': []}
+        # _setup_container
+        post_mock.return_value.status_code = 200
+        post_mock.return_value.json.return_value = dict()
+
+        result = self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                            rule_name='rule_name',
+                                            alert=alert)
+
+        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
+        assert_equal(result, False)
+
+    @patch('logging.Logger.error')
+    @patch('requests.get')
+    @patch('requests.post')
+    def test_dispatch_failure(self, post_mock, get_mock, log_mock):
+        """PhantomOutput dispatch failure (artifact)"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        # _check_container_exists
+        get_mock.return_value.status_code = 200
+        get_mock.return_value.json.return_value = {'count': 0, 'data': []}
+        # _setup_container, dispatch
+        type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400])
+        json_error = {'message': 'error message', 'errors': ['error1']}
+        post_mock.return_value.json.side_effect = [{'id': 1948}, json_error]
+
+        result = self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                            rule_name='rule_name',
+                                            alert=alert)
+
+        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
+        assert_equal(result, False)
+
+    @patch('logging.Logger.error')
+    def test_dispatch_bad_descriptor(self, log_error_mock):
+        """PhantomOutput dispatch bad descriptor"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        result = self.__dispatcher.dispatch(descriptor='bad_descriptor',
+                                            rule_name='rule_name',
+                                            alert=alert)
+
+        log_error_mock.assert_called_with('Failed to send alert to %s', self.__service)
+        assert_equal(result, False)
+
+    @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher._get_request')
+    @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher._post_request')
+    def test_dispatch_container_query(self, post_mock, get_mock):
+        """PhantomOutput - Container Query URL"""
+        alert = self._setup_dispatch('http://phantom.foo.bar')
+        self.__dispatcher.dispatch(descriptor=self.__descriptor,
+                                   rule_name='rule_name',
+                                   alert=alert)
+
+        full_url = 'http://phantom.foo.bar/rest/container'
+        params = {'_filter_name': '"rule_name"', 'page_size': 1}
+        headers = {'ph-auth-token': 'mocked_auth_token'}
+        get_mock.assert_has_calls([call(full_url, params, headers, False)])
+        rule_description = 'Info about this rule and what actions to take'
+        ph_container = {'name': 'rule_name', 'description': rule_description}
+        post_mock.assert_has_calls([call(full_url, ph_container, headers, False)])
diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py
new file mode 100644
index 000000000..8ef3a9aac
--- /dev/null
+++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py
@@ -0,0 +1,235 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" +# pylint: disable=protected-access +from collections import Counter, OrderedDict +import json +from mock import patch +from moto import mock_s3, mock_kms +from nose.tools import ( + assert_equal, + assert_set_equal +) + +from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput +from stream_alert_cli.helpers import put_mock_creds +from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION +from tests.unit.stream_alert_alert_processor.helpers import ( + get_random_alert, + get_alert, + remove_temp_secrets +) + + +class TestSlackOutput(object): + """Test class for SlackOutput""" + @classmethod + def setup_class(cls): + """Setup the class before any methods""" + cls.__service = 'slack' + cls.__descriptor = 'unit_test_channel' + cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, + REGION, + FUNCTION_NAME, + CONFIG) + + @classmethod + def teardown_class(cls): + """Teardown the class after all methods""" + cls.__dispatcher = None + + def test_format_message_single(self): + """Format Single Message - Slack""" + rule_name = 'test_rule_single' + alert = get_random_alert(25, rule_name) + loaded_message = self.__dispatcher._format_message(rule_name, alert) + + # tests + assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) + assert_equal( + loaded_message['text'], + '*StreamAlert Rule Triggered: test_rule_single*') + assert_equal(len(loaded_message['attachments']), 1) + + def test_format_message_mutliple(self): + """Format Multi-Message - Slack""" + rule_name = 'test_rule_multi-part' + alert = get_random_alert(30, rule_name) + loaded_message = self.__dispatcher._format_message(rule_name, alert) + + # tests + assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) + assert_equal( + loaded_message['text'], + '*StreamAlert Rule Triggered: test_rule_multi-part*') + assert_equal(len(loaded_message['attachments']), 2) + assert_equal(loaded_message['attachments'][1] + ['text'].split('\n')[3][1:7], '000028') + + def test_format_message_default_rule_description(self): + """Format Message Default Rule Description - Slack""" + rule_name = 'test_empty_rule_description' + alert = get_random_alert(10, rule_name, True) + loaded_message = self.__dispatcher._format_message(rule_name, alert) + + # tests + default_rule_description = '*Rule Description:*\nNo rule description provided\n' + assert_equal( + loaded_message['attachments'][0]['pretext'], + default_rule_description) + + def test_json_to_slack_mrkdwn_str(self): + """JSON to Slack mrkdwn - simple str""" + simple_str = 'value to format' + result = self.__dispatcher._json_to_slack_mrkdwn(simple_str, 0) + + assert_equal(len(result), 1) + assert_equal(result[0], simple_str) + + def test_json_to_slack_mrkdwn_dict(self): + """JSON to Slack mrkdwn - simple dict""" + simple_dict = OrderedDict([('test_key_01', 'test_value_01'), + ('test_key_02', 'test_value_02')]) + result = self.__dispatcher._json_to_slack_mrkdwn(simple_dict, 0) + + assert_equal(len(result), 2) + assert_equal(result[1], '*test_key_02:* test_value_02') + + def test_json_to_slack_mrkdwn_nested_dict(self): + """JSON to Slack mrkdwn - nested dict""" + nested_dict = OrderedDict([ + ('root_key_01', 'root_value_01'), + ('root_02', 'root_value_02'), + ('root_nested_01', OrderedDict([ + ('nested_key_01', 100), + ('nested_key_02', 200), + ('nested_nested_01', OrderedDict([ + ('nested_nested_key_01', 300) + ])) + ])) + ]) + result = self.__dispatcher._json_to_slack_mrkdwn(nested_dict, 
0) + assert_equal(len(result), 7) + assert_equal(result[2], '*root_nested_01:*') + assert_equal(Counter(result[4])['\t'], 1) + assert_equal(Counter(result[6])['\t'], 2) + + def test_json_to_slack_mrkdwn_list(self): + """JSON to Slack mrkdwn - simple list""" + simple_list = ['test_value_01', 'test_value_02'] + result = self.__dispatcher._json_to_slack_mrkdwn(simple_list, 0) + + assert_equal(len(result), 2) + assert_equal(result[0], '*[1]* test_value_01') + assert_equal(result[1], '*[2]* test_value_02') + + def test_json_to_slack_mrkdwn_multi_nested(self): + """JSON to Slack mrkdwn - multi type nested""" + nested_dict = OrderedDict([ + ('root_key_01', 'root_value_01'), + ('root_02', 'root_value_02'), + ('root_nested_01', OrderedDict([ + ('nested_key_01', 100), + ('nested_key_02', 200), + ('nested_nested_01', OrderedDict([ + ('nested_nested_key_01', [ + 6161, + 1051, + 51919 + ]) + ])) + ])) + ]) + result = self.__dispatcher._json_to_slack_mrkdwn(nested_dict, 0) + assert_equal(len(result), 10) + assert_equal(result[2], '*root_nested_01:*') + assert_equal(Counter(result[4])['\t'], 1) + assert_equal(result[-1], '\t\t\t*[3]* 51919') + + def test_json_list_to_text(self): + """JSON list to text""" + simple_list = ['test_value_01', 'test_value_02'] + result = self.__dispatcher._json_list_to_text(simple_list, '\t', 0) + + assert_equal(len(result), 2) + assert_equal(result[0], '*[1]* test_value_01') + assert_equal(result[1], '*[2]* test_value_02') + + def test_json_map_to_text(self): + """JSON map to text""" + simple_dict = OrderedDict([('test_key_01', 'test_value_01'), + ('test_key_02', 'test_value_02')]) + result = self.__dispatcher._json_map_to_text(simple_dict, '\t', 0) + + assert_equal(len(result), 2) + assert_equal(result[1], '*test_key_02:* test_value_02') + + def _setup_dispatch(self): + """Helper for setting up SlackOutput dispatch""" + remove_temp_secrets() + + output_name = self.__dispatcher.output_cred_name(self.__descriptor) + + creds = {'url': 'https://api.slack.com/web-hook-key'} + + put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, + REGION, KMS_ALIAS) + + return get_alert() + + @patch('logging.Logger.info') + @patch('requests.post') + @mock_s3 + @mock_kms + def test_dispatch_success(self, url_mock, log_info_mock): + """SlackOutput dispatch success""" + alert = self._setup_dispatch() + url_mock.return_value.status_code = 200 + url_mock.return_value.json.return_value = json.loads('{}') + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @patch('logging.Logger.error') + @patch('requests.post') + @mock_s3 + @mock_kms + def test_dispatch_failure(self, url_mock, log_error_mock): + """SlackOutput dispatch failure""" + alert = self._setup_dispatch() + json_error = json.loads('{"message": "error message", "errors": ["error1"]}') + url_mock.return_value.json.return_value = json_error + url_mock.return_value.status_code = 400 + + self.__dispatcher.dispatch(descriptor=self.__descriptor, + rule_name='rule_name', + alert=alert) + + log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + + @patch('logging.Logger.error') + @mock_s3 + @mock_kms + def test_dispatch_bad_descriptor(self, log_error_mock): + """SlackOutput dispatch bad descriptor""" + alert = self._setup_dispatch() + self.__dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=alert) + + log_error_mock.assert_called_with('Failed 
to send alert to %s', self.__service)
diff --git a/tests/unit/stream_alert_cli/test_outputs.py b/tests/unit/stream_alert_cli/test_outputs.py
index 2b1fe4ac2..51aab701d 100644
--- a/tests/unit/stream_alert_cli/test_outputs.py
+++ b/tests/unit/stream_alert_cli/test_outputs.py
@@ -19,7 +19,7 @@
 from moto import mock_kms, mock_s3
 from nose.tools import assert_false, assert_list_equal, assert_true, raises
 
-from stream_alert.alert_processor.output_base import OutputProperty
+from stream_alert.alert_processor.outputs.output_base import OutputProperty
 from stream_alert_cli.outputs import (
     encrypt_and_push_creds_to_s3,
     load_config,

From 12fdc55b9ec9afd91bbf0e8ebfc4fdd051f26df2 Mon Sep 17 00:00:00 2001
From: Ryan Deivert
Date: Mon, 20 Nov 2017 10:49:42 -0800
Subject: [PATCH 3/9] [tests] simplifying pagerduty output tests

---
 .../stream_alert_alert_processor/helpers.py   |   9 +-
 .../test_outputs/test_pagerduty.py            | 596 +++++++-----------
 2 files changed, 222 insertions(+), 383 deletions(-)

diff --git a/tests/unit/stream_alert_alert_processor/helpers.py b/tests/unit/stream_alert_alert_processor/helpers.py
index 0030030e9..980642452 100644
--- a/tests/unit/stream_alert_alert_processor/helpers.py
+++ b/tests/unit/stream_alert_alert_processor/helpers.py
@@ -93,6 +93,9 @@ def get_alert(context=None):
 
 
 def remove_temp_secrets():
-    """"Blow away the stream_alert_secrets directory in temp"""
-    secrets_dir = os.path.join(tempfile.gettempdir(), "stream_alert_secrets")
-    shutil.rmtree(secrets_dir)
+    """Remove the local secrets directory that may be left from previous runs"""
+    secrets_dir = os.path.join(tempfile.gettempdir(), 'stream_alert_secrets')
+
+    # Check if the folder exists, and remove it if it does
+    if os.path.isdir(secrets_dir):
+        shutil.rmtree(secrets_dir)
diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py
index fe8764c12..0edc1a210 100644
--- a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py
+++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py
@@ -13,16 +13,16 @@
 See the License for the specific language governing permissions and
 limitations under the License. 
""" -# pylint: disable=protected-access -import json +# pylint: disable=protected-access,attribute-defined-outside-init from mock import patch, PropertyMock from moto import mock_s3, mock_kms -from nose.tools import ( - assert_equal, - assert_false -) +from nose.tools import assert_equal, assert_false, assert_true -from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput +from stream_alert.alert_processor.outputs.pagerduty import ( + PagerDutyOutput, + PagerDutyOutputV2, + PagerDutyIncidentOutput +) from stream_alert_cli.helpers import put_mock_creds from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION from tests.unit.stream_alert_alert_processor.helpers import ( @@ -31,388 +31,281 @@ ) +@mock_s3 +@mock_kms class TestPagerDutyOutput(object): """Test class for PagerDutyOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'pagerduty' - cls.__descriptor = 'unit_test_pagerduty' - cls.__backup_method = None - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None + DESCRIPTOR = 'unit_test_pagerduty' + SERVICE = 'pagerduty' + CREDS = {'url': 'http://pagerduty.foo.bar/create_event.json', + 'service_key': 'mocked_service_key'} + + def setup(self): + """Setup before each method""" + self._dispatcher = PagerDutyOutput(REGION, FUNCTION_NAME, CONFIG) + remove_temp_secrets() + output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) + put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) def test_get_default_properties(self): - """Get Default Properties - PagerDuty""" - props = self.__dispatcher._get_default_properties() + """PagerDutyOutput - Get Default Properties""" + props = self._dispatcher._get_default_properties() assert_equal(len(props), 1) assert_equal(props['url'], 'https://events.pagerduty.com/generic/2010-04-15/create_event.json') - def _setup_dispatch(self): - """Helper for setting up PagerDutyOutput dispatch""" - remove_temp_secrets() - - # Cache the _get_default_properties and set it to return None - self.__backup_method = self.__dispatcher._get_default_properties - self.__dispatcher._get_default_properties = lambda: None - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'http://pagerduty.foo.bar/create_event.json', - 'service_key': 'mocked_service_key'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() - - def _teardown_dispatch(self): - """Replace method with cached method""" - self.__dispatcher._get_default_properties = self.__backup_method - @patch('logging.Logger.info') @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_success(self, post_mock, log_info_mock): - """PagerDutyOutput dispatch success""" - alert = self._setup_dispatch() + def test_dispatch_success(self, post_mock, log_mock): + """PagerDutyOutput - Dispatch Success""" post_mock.return_value.status_code = 200 - post_mock.return_value.text = '' - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - self._teardown_dispatch() - - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + 
log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.error') @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_failure(self, post_mock, log_error_mock): - """PagerDutyOutput dispatch failure""" - alert = self._setup_dispatch() - post_mock.return_value.text = '{"message": "error message", "errors": ["error1"]}' + def test_dispatch_failure(self, post_mock, log_mock): + """PagerDutyOutput - Dispatch Failure, Bad Request""" post_mock.return_value.status_code = 400 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """PagerDutyOutput dispatch bad descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) + def test_dispatch_bad_descriptor(self, log_mock): + """PagerDutyOutput - Dispatch Failure, Bad Descriptor""" + assert_false(self._dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=get_alert())) - self._teardown_dispatch() + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) +@mock_s3 +@mock_kms class TestPagerDutyOutputV2(object): """Test class for PagerDutyOutputV2""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'pagerduty-v2' - cls.__descriptor = 'unit_test_pagerduty-v2' - cls.__backup_method = None - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None + DESCRIPTOR = 'unit_test_pagerduty-v2' + SERVICE = 'pagerduty-v2' + CREDS = {'url': 'http://pagerduty.foo.bar/create_event.json', + 'routing_key': 'mocked_routing_key'} + + def setup(self): + """Setup before each method""" + self._dispatcher = PagerDutyOutputV2(REGION, FUNCTION_NAME, CONFIG) + remove_temp_secrets() + output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) + put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) def test_get_default_properties(self): - """Get Default Properties - PagerDuty""" - props = self.__dispatcher._get_default_properties() + """PagerDutyOutputV2 - Get Default Properties""" + props = self._dispatcher._get_default_properties() assert_equal(len(props), 1) - assert_equal(props['url'], - 'https://events.pagerduty.com/v2/enqueue') - - def _setup_dispatch(self): - """Helper for setting up PagerDutyOutputV2 dispatch""" - remove_temp_secrets() - - # Cache the _get_default_properties and set it to return None - self.__backup_method = self.__dispatcher._get_default_properties - self.__dispatcher._get_default_properties = lambda: None - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'http://pagerduty.foo.bar/create_event.json', - 'routing_key': 'mocked_routing_key'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return 
get_alert() - - def _teardown_dispatch(self): - """Replace method with cached method""" - self.__dispatcher._get_default_properties = self.__backup_method + assert_equal(props['url'], 'https://events.pagerduty.com/v2/enqueue') @patch('logging.Logger.info') @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_success(self, post_mock, log_info_mock): - """PagerDutyOutputV2 dispatch success""" - alert = self._setup_dispatch() + def test_dispatch_success(self, post_mock, log_mock): + """PagerDutyOutputV2 - Dispatch Success""" post_mock.return_value.status_code = 200 - post_mock.return_value.text = '' - - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - self._teardown_dispatch() + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.error') @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_failure(self, post_mock, log_error_mock): - """PagerDutyOutputV2 dispatch failure""" - alert = self._setup_dispatch() - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') + def test_dispatch_failure(self, post_mock, log_mock): + """PagerDutyOutputV2 - Dispatch Failure, Bad Request""" + json_error = {'message': 'error message', 'errors': ['error1']} post_mock.return_value.json.return_value = json_error post_mock.return_value.status_code = 400 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """PagerDutyOutputV2 dispatch bad descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) + def test_dispatch_bad_descriptor(self, log_mock): + """PagerDutyOutputV2 - Dispatch Failure, Bad Descriptor""" + assert_false(self._dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=get_alert())) - self._teardown_dispatch() + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) +@mock_s3 +@mock_kms class TestPagerDutyIncidentOutput(object): """Test class for PagerDutyIncidentOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'pagerduty-incident' - cls.__descriptor = 'unit_test_pagerduty-incident' - cls.__backup_method = None - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None + DESCRIPTOR = 'unit_test_pagerduty-incident' + SERVICE = 'pagerduty-incident' + CREDS = {'api': 'https://api.pagerduty.com', + 'token': 'mocked_token', + 'service_key': 'mocked_service_key', + 'escalation_policy': 'mocked_escalation_policy'} + + def setup(self): + """Setup before each method""" + self._dispatcher = 
PagerDutyIncidentOutput(REGION, FUNCTION_NAME, CONFIG) + self._dispatcher._base_url = self.CREDS['api'] + remove_temp_secrets() + output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) + put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) def test_get_default_properties(self): - """Get Default Properties - PagerDutyIncidentOutput""" - props = self.__dispatcher._get_default_properties() + """PagerDutyIncidentOutput - Get Default Properties""" + props = self._dispatcher._get_default_properties() assert_equal(len(props), 1) - assert_equal(props['api'], - 'https://api.pagerduty.com') + assert_equal(props['api'], 'https://api.pagerduty.com') def test_get_endpoint(self): - """Get Endpoint - PagerDutyIncidentOutput""" - props = self.__dispatcher._get_default_properties() - endpoint = self.__dispatcher._get_endpoint(props['api'], 'testtest') - assert_equal(endpoint, - 'https://api.pagerduty.com/testtest') - - def _setup_dispatch(self, context=None): - """Helper for setting up PagerDutyIncidentOutput dispatch""" - remove_temp_secrets() - - # Cache the _get_default_properties and set it to return None - self.__backup_method = self.__dispatcher._get_default_properties - self.__dispatcher._get_default_properties = lambda: None - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'api': 'https://api.pagerduty.com', - 'token': 'mocked_token', - 'service_key': 'mocked_service_key', - 'escalation_policy': 'mocked_escalation_policy'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert(context) - - def _teardown_dispatch(self): - """Replace method with cached method""" - self.__dispatcher._get_default_properties = self.__backup_method + """PagerDutyIncidentOutput - Get Endpoint""" + endpoint = self._dispatcher._get_endpoint(self.CREDS['api'], 'testtest') + assert_equal(endpoint, 'https://api.pagerduty.com/testtest') @patch('requests.get') def test_check_exists_get_id(self, get_mock): - """Check Exists Get Id - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Check Exists Get ID""" # /check get_mock.return_value.status_code = 200 - json_check = json.loads('{"check": [{"id": "checked_id"}]}') + json_check = {'check': [{'id': 'checked_id'}]} get_mock.return_value.json.return_value = json_check - checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + checked = self._dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') assert_equal(checked, 'checked_id') @patch('requests.get') def test_check_exists_get_id_fail(self, get_mock): - """Check Exists Get Id Fail - PagerDutyIncidentOutput""" - # /check + """PagerDutyIncidentOutput - Check Exists Get Id Fail""" get_mock.return_value.status_code = 200 - json_check = json.loads('{}') - get_mock.return_value.json.return_value = json_check + get_mock.return_value.json.return_value = dict() - checked = self.__dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + checked = self._dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') assert_false(checked) @patch('requests.get') def test_user_verify_success(self, get_mock): - """User Verify Success - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - User Verify Success""" get_mock.return_value.status_code = 200 - json_check = json.loads('{"users": [{"id": "verified_user_id"}]}') + json_check = {'users': [{'id': 'verified_user_id'}]} get_mock.return_value.json.return_value = json_check - 
user_verified = self.__dispatcher._user_verify('valid_user') + user_verified = self._dispatcher._user_verify('valid_user') assert_equal(user_verified['id'], 'verified_user_id') assert_equal(user_verified['type'], 'user_reference') @patch('requests.get') def test_user_verify_fail(self, get_mock): - """User Verify Fail - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - User Verify Fail""" get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_users": [{"not_id": "verified_user_id"}]}') + json_check = {'not_users': [{'not_id': 'verified_user_id'}]} get_mock.return_value.json.return_value = json_check - user_verified = self.__dispatcher._user_verify('valid_user') + user_verified = self._dispatcher._user_verify('valid_user') assert_false(user_verified) @patch('requests.get') def test_policy_verify_success_no_default(self, get_mock): - """Policy Verify Success (No Default) - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Policy Verify Success (No Default)""" # /escalation_policies get_mock.return_value.status_code = 200 - json_check = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + json_check = {'escalation_policies': [{'id': 'good_policy_id'}]} get_mock.return_value.json.return_value = json_check - policy_verified = self.__dispatcher._policy_verify('valid_policy', '') + policy_verified = self._dispatcher._policy_verify('valid_policy', '') assert_equal(policy_verified['id'], 'good_policy_id') assert_equal(policy_verified['type'], 'escalation_policy_reference') @patch('requests.get') def test_policy_verify_success_default(self, get_mock): - """Policy Verify Success (Default) - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Policy Verify Success (Default)""" # /escalation_policies type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_check_bad = json.loads('{"no_escalation_policies": [{"id": "bad_policy_id"}]}') - json_check_good = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + json_check_bad = {'no_escalation_policies': [{'id': 'bad_policy_id'}]} + json_check_good = {'escalation_policies': [{'id': 'good_policy_id'}]} get_mock.return_value.json.side_effect = [json_check_bad, json_check_good] - policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') + policy_verified = self._dispatcher._policy_verify('valid_policy', 'default_policy') assert_equal(policy_verified['id'], 'good_policy_id') assert_equal(policy_verified['type'], 'escalation_policy_reference') @patch('requests.get') def test_policy_verify_fail_default(self, get_mock): - """Policy Verify Fail (Default) - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Policy Verify Fail (Default)""" # /not_escalation_policies type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400]) - json_check_bad = json.loads('{"escalation_policies": [{"id": "bad_policy_id"}]}') - json_check_bad_default = json.loads('{"escalation_policies": [{"id": "good_policy_id"}]}') + json_check_bad = {'escalation_policies': [{'id': 'bad_policy_id'}]} + json_check_bad_default = {'escalation_policies': [{'id': 'good_policy_id'}]} get_mock.return_value.json.side_effect = [json_check_bad, json_check_bad_default] - policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') - assert_false(policy_verified) + + assert_false(self._dispatcher._policy_verify('valid_policy', 'default_policy')) @patch('requests.get') def test_policy_verify_fail_no_default(self, get_mock): - """Policy 
Verify Fail (No Default) - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Policy Verify Fail (No Default)""" # /not_escalation_policies get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_escalation_policies": [{"not_id": "verified_policy_id"}]}') + json_check = {'not_escalation_policies': [{'not_id': 'verified_policy_id'}]} get_mock.return_value.json.return_value = json_check - policy_verified = self.__dispatcher._policy_verify('valid_policy', 'default_policy') - assert_false(policy_verified) + assert_false(self._dispatcher._policy_verify('valid_policy', 'default_policy')) @patch('requests.get') def test_service_verify_success(self, get_mock): - """Service Verify Success - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Service Verify Success""" # /services get_mock.return_value.status_code = 200 - json_check = json.loads('{"services": [{"id": "verified_service_id"}]}') + json_check = {'services': [{'id': 'verified_service_id'}]} get_mock.return_value.json.return_value = json_check - service_verified = self.__dispatcher._service_verify('valid_service') + service_verified = self._dispatcher._service_verify('valid_service') assert_equal(service_verified['id'], 'verified_service_id') assert_equal(service_verified['type'], 'service_reference') @patch('requests.get') def test_service_verify_fail(self, get_mock): - """Service Verify Fail - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Service Verify Fail""" get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_services": [{"not_id": "verified_service_id"}]}') + json_check = {'not_services': [{'not_id': 'verified_service_id'}]} get_mock.return_value.json.return_value = json_check - service_verified = self.__dispatcher._service_verify('valid_service') - assert_false(service_verified) + assert_false(self._dispatcher._service_verify('valid_service')) @patch('requests.get') def test_item_verify_success(self, get_mock): - """Item Verify Success - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Item Verify Success""" # /items get_mock.return_value.status_code = 200 - json_check = json.loads('{"items": [{"id": "verified_item_id"}]}') + json_check = {'items': [{'id': 'verified_item_id'}]} get_mock.return_value.json.return_value = json_check - item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', - 'items', 'item_reference') + item_verified = self._dispatcher._item_verify('http://mock_url', 'valid_item', + 'items', 'item_reference') assert_equal(item_verified['id'], 'verified_item_id') assert_equal(item_verified['type'], 'item_reference') @patch('requests.get') def test_incident_assignment_user(self, get_mock): - """Incident Assignment User - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Incident Assignment User""" context = {'assigned_user': 'user_to_assign'} get_mock.return_value.status_code = 200 - json_user = json.loads('{"users": [{"id": "verified_user_id"}]}') + json_user = {'users': [{'id': 'verified_user_id'}]} get_mock.return_value.json.return_value = json_user - assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + assigned_key, assigned_value = self._dispatcher._incident_assignment(context) assert_equal(assigned_key, 'assignments') assert_equal(assigned_value[0]['assignee']['id'], 'verified_user_id') @@ -420,13 +313,13 @@ def test_incident_assignment_user(self, get_mock): @patch('requests.get') def test_incident_assignment_policy_no_default(self, get_mock): - """Incident Assignment Policy (No 
Default) - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Incident Assignment Policy (No Default)""" context = {'assigned_policy': 'policy_to_assign'} get_mock.return_value.status_code = 200 - json_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + json_policy = {'escalation_policies': [{'id': 'verified_policy_id'}]} get_mock.return_value.json.return_value = json_policy - assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + assigned_key, assigned_value = self._dispatcher._incident_assignment(context) assert_equal(assigned_key, 'escalation_policy') assert_equal(assigned_value['id'], 'verified_policy_id') @@ -434,14 +327,14 @@ def test_incident_assignment_policy_no_default(self, get_mock): @patch('requests.get') def test_incident_assignment_policy_default(self, get_mock): - """Incident Assignment Policy (Default) - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Incident Assignment Policy (Default)""" context = {'assigned_policy': 'bad_invalid_policy_to_assign'} type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_bad_policy = json.loads('{"not_escalation_policies": [{"id": "bad_policy_id"}]}') - json_good_policy = json.loads('{"escalation_policies": [{"id": "verified_policy_id"}]}') + json_bad_policy = {'not_escalation_policies': [{'id': 'bad_policy_id'}]} + json_good_policy = {'escalation_policies': [{'id': 'verified_policy_id'}]} get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy] - assigned_key, assigned_value = self.__dispatcher._incident_assignment(context) + assigned_key, assigned_value = self._dispatcher._incident_assignment(context) assert_equal(assigned_key, 'escalation_policy') assert_equal(assigned_value['id'], 'verified_policy_id') @@ -449,227 +342,170 @@ def test_incident_assignment_policy_default(self, get_mock): @patch('requests.get') def test_item_verify_fail(self, get_mock): - """Item Verify Fail - PagerDutyIncidentOutput""" + """PagerDutyIncidentOutput - Item Verify Fail""" # /not_items get_mock.return_value.status_code = 200 - json_check = json.loads('{"not_items": [{"not_id": "verified_item_id"}]}') + json_check = {'not_items': [{'not_id': 'verified_item_id'}]} get_mock.return_value.json.return_value = json_check - item_verified = self.__dispatcher._item_verify('http://mock_url', 'valid_item', - 'items', 'item_reference') + item_verified = self._dispatcher._item_verify('http://mock_url', 'valid_item', + 'items', 'item_reference') assert_false(item_verified) @patch('logging.Logger.info') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_good_user(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Good User""" - ctx = { - 'pagerduty-incident': { - 'assigned_user': 'valid_user' - } - } - alert = self._setup_dispatch(context=ctx) - + def test_dispatch_success_good_user(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, Good User""" # /users, /services type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_user = json.loads('{"users": [{"id": "valid_user_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') + json_user = {'users': [{'id': 'valid_user_id'}]} + json_service = {'services': [{'id': 'service_id'}]} get_mock.return_value.json.side_effect = [json_user, json_service] # /incidents post_mock.return_value.status_code = 200 - 
self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + ctx = {'pagerduty-incident': {'assigned_user': 'valid_user'}} - self._teardown_dispatch() + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert(context=ctx))) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.info') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_good_policy(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Good Policy""" - ctx = { - 'pagerduty-incident': { - 'assigned_policy': 'valid_policy' - } - } - alert = self._setup_dispatch(context=ctx) - + def test_dispatch_success_good_policy(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, Good Policy""" # /escalation_policies, /services type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') + json_policy = {'escalation_policies': [{'id': 'policy_id'}]} + json_service = {'services': [{'id': 'service_id'}]} get_mock.return_value.json.side_effect = [json_policy, json_service] # /incidents post_mock.return_value.status_code = 200 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + ctx = {'pagerduty-incident': {'assigned_policy': 'valid_policy'}} - self._teardown_dispatch() + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert(context=ctx))) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.info') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_bad_user(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Bad User""" - ctx = { - 'pagerduty-incident': { - 'assigned_user': 'invalid_user' - } - } - alert = self._setup_dispatch(context=ctx) - + def test_dispatch_success_bad_user(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, Bad User""" # /users, /escalation_policies, /services type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) - json_user = json.loads('{"not_users": [{"id": "user_id"}]}') - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') + json_user = {'not_users': [{'id': 'user_id'}]} + json_policy = {'escalation_policies': [{'id': 'policy_id'}]} + json_service = {'services': [{'id': 'service_id'}]} get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] # /incidents post_mock.return_value.status_code = 200 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + ctx = {'pagerduty-incident': {'assigned_user': 'invalid_user'}} - self._teardown_dispatch() + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert(context=ctx))) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully 
sent alert to %s', self.SERVICE) @patch('logging.Logger.info') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_no_context(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - No Context""" - alert = self._setup_dispatch() - + def test_dispatch_success_no_context(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, No Context""" # /escalation_policies, /services type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') + json_policy = {'escalation_policies': [{'id': 'policy_id'}]} + json_service = {'services': [{'id': 'service_id'}]} get_mock.return_value.json.side_effect = [json_policy, json_service] # /incidents post_mock.return_value.status_code = 200 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.error') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_error_mock): - """PagerDutyIncidentOutput dispatch failure - No User, Bad Policy, Bad Service""" - alert = self._setup_dispatch() + def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Failure: No User, Bad Policy, Bad Service""" # /escalation_policies, /services type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400, 400]) - json_empty = json.loads('{}') - get_mock.return_value.json.side_effect = [json_empty, json_empty, json_empty] + get_mock.return_value.json.side_effect = [dict(), dict(), dict()] # /incidents post_mock.return_value.status_code = 400 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('logging.Logger.info') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_info_mock): - """PagerDutyIncidentOutput dispatch success - Bad Policy""" - ctx = { - 'pagerduty-incident': { - 'assigned_policy': 'valid_policy' - } - } - alert = self._setup_dispatch(context=ctx) + def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, Bad Policy""" # /escalation_policies, /services get_mock.return_value.side_effect = [400, 200, 200] type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 200, 200]) - json_bad_policy = json.loads('{}') - json_good_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') + json_bad_policy = dict() + json_good_policy = {'escalation_policies': [{'id': 'policy_id'}]} + json_service 
= {'services': [{'id': 'service_id'}]} get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy, json_service] # /incidents post_mock.return_value.status_code = 200 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + ctx = {'pagerduty-incident': {'assigned_policy': 'valid_policy'}} - self._teardown_dispatch() + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert(context=ctx))) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.error') @patch('requests.post') @patch('requests.get') - @mock_s3 - @mock_kms - def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_error_mock): - """PagerDutyIncidentOutput dispatch - Bad Dispatch""" - alert = self._setup_dispatch() + def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Failure, Bad Request""" # /escalation_policies, /services type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_policy = json.loads('{"escalation_policies": [{"id": "policy_id"}]}') - json_service = json.loads('{"services": [{"id": "service_id"}]}') + json_policy = {'escalation_policies': [{'id': 'policy_id'}]} + json_service = {'services': [{'id': 'service_id'}]} get_mock.return_value.json.side_effect = [json_policy, json_service] # /incidents post_mock.return_value.status_code = 400 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """PagerDutyIncidentOutput dispatch - Bad Descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - self._teardown_dispatch() - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + def test_dispatch_bad_descriptor(self, log_mock): + """PagerDutyIncidentOutput - Dispatch Failure, Bad Descriptor""" + assert_false(self._dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=get_alert())) + + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) From b881f12757b9b72d5e57ef2ec87b3e1fe7363664 Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Mon, 20 Nov 2017 10:49:59 -0800 Subject: [PATCH 4/9] [tests] simplifying phantom output tests --- .../alert_processor/outputs/phantom.py | 18 ++- .../test_outputs/test_phantom.py | 137 +++++++----------- 2 files changed, 65 insertions(+), 90 deletions(-) diff --git a/stream_alert/alert_processor/outputs/phantom.py b/stream_alert/alert_processor/outputs/phantom.py index 0eba6e3d0..7283c1456 100644 --- a/stream_alert/alert_processor/outputs/phantom.py +++ b/stream_alert/alert_processor/outputs/phantom.py @@ -61,7 +61,8 @@ def get_user_defined_properties(cls): cred_requirement=True)) ]) - def _check_container_exists(self, rule_name, container_url, headers): + @classmethod + def _check_container_exists(cls, rule_name, container_url, headers): """Check to see 
if a Phantom container already exists for this rule Args: @@ -79,8 +80,8 @@ def _check_container_exists(self, rule_name, container_url, headers): '_filter_name': '"{}"'.format(rule_name), 'page_size': 1 } - resp = self._get_request(container_url, params, headers, False) - if not self._check_http_response(resp): + resp = cls._get_request(container_url, params, headers, False) + if not cls._check_http_response(resp): return False response = resp.json() @@ -90,7 +91,8 @@ def _check_container_exists(self, rule_name, container_url, headers): # of 'data' with a container id we can use return response and response.get('count') and response.get('data')[0]['id'] - def _setup_container(self, rule_name, rule_description, base_url, headers): + @classmethod + def _setup_container(cls, rule_name, rule_description, base_url, headers): """Establish a Phantom container to write the alerts to. This checks to see if an appropriate containers exists first and returns the ID if so. @@ -103,18 +105,18 @@ def _setup_container(self, rule_name, rule_description, base_url, headers): int: ID of the Phantom container where the alerts will be sent or False if there is an issue getting the container id """ - container_url = os.path.join(base_url, self.CONTAINER_ENDPOINT) + container_url = os.path.join(base_url, cls.CONTAINER_ENDPOINT) # Check to see if there is a container already created for this rule name - existing_id = self._check_container_exists(rule_name, container_url, headers) + existing_id = cls._check_container_exists(rule_name, container_url, headers) if existing_id: return existing_id # Try to use the rule_description from the rule as the container description ph_container = {'name': rule_name, 'description': rule_description} - resp = self._post_request(container_url, ph_container, headers, False) + resp = cls._post_request(container_url, ph_container, headers, False) - if not self._check_http_response(resp): + if not cls._check_http_response(resp): return False response = resp.json() diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py index 12db28cd9..fd2632fc4 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
-# pylint: disable=protected-access """ +# pylint: disable=protected-access,attribute-defined-outside-init from mock import call, patch, PropertyMock from moto import mock_s3, mock_kms -from nose.tools import assert_equal +from nose.tools import assert_false, assert_true -from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput +from stream_alert.alert_processor.outputs.phantom import PhantomOutput from stream_alert_cli.helpers import put_mock_creds from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION from tests.unit.stream_alert_alert_processor.helpers import get_alert, remove_temp_secrets @@ -28,58 +28,40 @@ @mock_kms class TestPhantomOutput(object): """Test class for PhantomOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'phantom' - cls.__descriptor = 'unit_test_phantom' - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def _setup_dispatch(self, url): - """Helper for setting up PhantomOutput dispatch""" + DESCRIPTOR = 'unit_test_phantom' + SERVICE = 'phantom' + CREDS = {'url': 'http://phantom.foo.bar', + 'ph_auth_token': 'mocked_auth_token'} + + def setup(self): + """Setup before each method""" + self._dispatcher = PhantomOutput(REGION, FUNCTION_NAME, CONFIG) remove_temp_secrets() - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': url, - 'ph_auth_token': 'mocked_auth_token'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() + output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) + put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) @patch('logging.Logger.info') @patch('requests.get') @patch('requests.post') def test_dispatch_existing_container(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch success, existing container""" - alert = self._setup_dispatch('http://phantom.foo.bar') + """PhantomOutput - Dispatch Success, Existing Container""" # _check_container_exists get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = {'count': 1, 'data': [{'id': 1948}]} # dispatch post_mock.return_value.status_code = 200 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.info') @patch('requests.get') @patch('requests.post') def test_dispatch_new_container(self, post_mock, get_mock, log_mock): - """PhantomOutput dispatch success, new container""" - alert = self._setup_dispatch('http://phantom.foo.bar') + """PhantomOutput - Dispatch Success, New Container""" # _check_container_exists get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = {'count': 0, 'data': []} @@ -87,18 +69,17 @@ def test_dispatch_new_container(self, post_mock, get_mock, log_mock): post_mock.return_value.status_code = 200 post_mock.return_value.json.return_value = {'id': 1948} - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + 
assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR,
+                                              rule_name='rule_name',
+                                              alert=get_alert()))
 
-        log_mock.assert_called_with('Successfully sent alert to %s', self.__service)
+        log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE)
 
     @patch('logging.Logger.error')
     @patch('requests.get')
     @patch('requests.post')
     def test_dispatch_container_failure(self, post_mock, get_mock, log_mock):
-        """PhantomOutput dispatch failure (setup container)"""
-        alert = self._setup_dispatch('http://phantom.foo.bar')
+        """PhantomOutput - Dispatch Failure, Setup Container"""
         # _check_container_exists
         get_mock.return_value.status_code = 200
         get_mock.return_value.json.return_value = {'count': 0, 'data': []}
@@ -107,18 +88,17 @@ def test_dispatch_container_failure(self, post_mock, get_mock, log_mock):
         json_error = {'message': 'error message', 'errors': ['error1']}
         post_mock.return_value.json.return_value = json_error
 
-        self.__dispatcher.dispatch(descriptor=self.__descriptor,
-                                   rule_name='rule_name',
-                                   alert=alert)
+        assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR,
+                                               rule_name='rule_name',
+                                               alert=get_alert()))
 
-        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
+        log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE)
 
     @patch('logging.Logger.error')
     @patch('requests.get')
     @patch('requests.post')
     def test_dispatch_check_container_error(self, post_mock, get_mock, log_mock):
-        """PhantomOutput dispatch decode error (check container)"""
-        alert = self._setup_dispatch('http://phantom.foo.bar')
+        """PhantomOutput - Dispatch Failure, Decode Error w/ Container Check"""
         # _check_container_exists
         get_mock.return_value.status_code = 200
         get_mock.return_value.text = '{}'
@@ -127,19 +107,17 @@ def test_dispatch_check_container_error(self, post_mock, get_mock, log_mock):
         json_error = {'message': 'error message', 'errors': ['error1']}
         post_mock.return_value.json.return_value = json_error
 
-        result = self.__dispatcher.dispatch(descriptor=self.__descriptor,
-                                            rule_name='rule_name',
-                                            alert=alert)
+        assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR,
+                                               rule_name='rule_name',
+                                               alert=get_alert()))
 
-        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
-        assert_equal(result, False)
+        log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE)
 
     @patch('logging.Logger.error')
     @patch('requests.get')
     @patch('requests.post')
     def test_dispatch_setup_container_error(self, post_mock, get_mock, log_mock):
-        """PhantomOutput dispatch decode error (setup container)"""
-        alert = self._setup_dispatch('http://phantom.foo.bar')
+        """PhantomOutput - Dispatch Failure, Decode Error w/ Container Creation"""
         # _check_container_exists
         get_mock.return_value.status_code = 200
         get_mock.return_value.json.return_value = {'count': 0, 'data': []}
@@ -148,19 +126,17 @@ def test_dispatch_setup_container_error(self, post_mock, get_mock, log_mock):
         post_mock.return_value.json.return_value = dict()
 
 
-        result = self.__dispatcher.dispatch(descriptor=self.__descriptor,
-                                            rule_name='rule_name',
-                                            alert=alert)
+        assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR,
+                                               rule_name='rule_name',
+                                               alert=get_alert()))
 
-        log_mock.assert_called_with('Failed to send alert to %s', self.__service)
-        assert_equal(result, False)
+        log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE)
 
     @patch('logging.Logger.error')
     @patch('requests.get')
     @patch('requests.post')
     def test_dispatch_failure(self, post_mock, get_mock, 
log_mock): - """PhantomOutput dispatch failure (artifact)""" - alert = self._setup_dispatch('http://phantom.foo.bar') + """PhantomOutput - Dispatch Failure, Artifact""" # _check_container_exists get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = {'count': 0, 'data': []} @@ -169,37 +145,34 @@ def test_dispatch_failure(self, post_mock, get_mock, log_mock): json_error = {'message': 'error message', 'errors': ['error1']} post_mock.return_value.json.return_value.side_effect = [{'id': 1948}, json_error] - result = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_mock.assert_called_with('Failed to send alert to %s', self.__service) - assert_equal(result, False) + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('logging.Logger.error') def test_dispatch_bad_descriptor(self, log_error_mock): - """PhantomOutput dispatch bad descriptor""" - alert = self._setup_dispatch('http://phantom.foo.bar') - result = self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) + """PhantomOutput - Dispatch Failure, Bad Descriptor""" + assert_false(self._dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=get_alert())) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) - assert_equal(result, False) + log_error_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher._get_request') @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher._post_request') def test_dispatch_container_query(self, post_mock, get_mock): """PhantomOutput - Container Query URL""" - alert = self._setup_dispatch('http://phantom.foo.bar') - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + rule_description = 'Info about this rule and what actions to take' + headers = {'ph-auth-token': 'mocked_auth_token'} + assert_false(PhantomOutput._setup_container('rule_name', + rule_description, + self.CREDS['url'], + headers)) - full_url = 'http://phantom.foo.bar/rest/container' + full_url = '{}/rest/container'.format(self.CREDS['url']) params = {'_filter_name': '"rule_name"', 'page_size': 1} - headers = {'ph-auth-token': 'mocked_auth_token'} get_mock.assert_has_calls([call(full_url, params, headers, False)]) - rule_description = 'Info about this rule and what actions to take' ph_container = {'name': 'rule_name', 'description': rule_description} post_mock.assert_has_calls([call(full_url, ph_container, headers, False)]) From 937da23637bfca2d7b737369aec8feee5de7fd73 Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Mon, 20 Nov 2017 12:13:58 -0800 Subject: [PATCH 5/9] [tests] simplifying jira output tests --- .../test_outputs/test_jira.py | 154 ++++++++---------- .../test_outputs/test_pagerduty.py | 5 +- 2 files changed, 71 insertions(+), 88 deletions(-) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py index 141f4c62b..e39f79647 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py @@ -13,59 +13,43 @@ See the License for the specific language governing permissions and limitations under the 
License. """ -# pylint: disable=protected-access +# pylint: disable=protected-access,attribute-defined-outside-init from mock import call, patch, PropertyMock from moto import mock_s3, mock_kms -from nose.tools import assert_equal +from nose.tools import assert_equal, assert_false, assert_true from stream_alert.alert_processor.outputs.jira import JiraOutput from stream_alert_cli.helpers import put_mock_creds from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert, remove_temp_secrets @mock_s3 @mock_kms class TestJiraOutput(object): - """Test class for PhantomOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'jira' - cls.__descriptor = 'unit_test_jira' - cls.__dispatcher = JiraOutput(REGION, FUNCTION_NAME, CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None - - def _setup_dispatch(self): - """Helper for setting up JiraOutput dispatch""" + """Test class for JiraOutput""" + DESCRIPTOR = 'unit_test_jira' + SERVICE = 'jira' + CREDS = {'username': 'jira@foo.bar', + 'password': 'jirafoobar', + 'url': 'jira.foo.bar', + 'project_key': 'foobar', + 'issue_type': 'Task', + 'aggregate': 'yes'} + + def setup(self): + """Setup before each method""" + self._dispatcher = JiraOutput(REGION, FUNCTION_NAME, CONFIG) + self._dispatcher._base_url = self.CREDS['url'] remove_temp_secrets() - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'username': 'jira@foo.bar', - 'password': 'jirafoobar', - 'url': 'jira.foo.bar', - 'project_key': 'foobar', - 'issue_type': 'Task', - 'aggregate': 'yes'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - return get_alert() + output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) + put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) @patch('logging.Logger.info') @patch('requests.get') @patch('requests.post') def test_dispatch_issue_new(self, post_mock, get_mock, log_mock): - """JiraOutput dispatch success, new issue""" - alert = self._setup_dispatch() + """JiraOutput - Dispatch Success, New Issue""" # setup the request to not find an existing issue get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = {'issues': []} @@ -74,18 +58,17 @@ def test_dispatch_issue_new(self, post_mock, get_mock, log_mock): post_mock.return_value.status_code = 200 post_mock.return_value.json.side_effect = [auth_resp, {'id': 5000}] - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.info') @patch('requests.get') @patch('requests.post') def test_dispatch_issue_existing(self, post_mock, get_mock, log_mock): - """JiraOutput dispatch success, existing issue""" - alert = self._setup_dispatch() + """JiraOutput - Dispatch Success, Existing Issue""" # setup the request to find an existing issue get_mock.return_value.status_code = 200 existing_issues = {'issues': [{'fields': {'summary': 
'Bogus'}, 'id': '5000'}]} @@ -95,105 +78,108 @@ def test_dispatch_issue_existing(self, post_mock, get_mock, log_mock): post_mock.return_value.status_code = 200 post_mock.return_value.json.side_effect = [auth_resp, {'id': 5000}] - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('requests.get') def test_get_comments_success(self, get_mock): - """JiraOutput get comments success""" + """JiraOutput - Get Comments, Success""" # setup successful get comments response get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = {'comments': [{}, {}]} + expected_result = [{}, {}] + get_mock.return_value.json.return_value = {'comments': expected_result} - self.__dispatcher._load_creds('jira') - resp = self.__dispatcher._get_comments('5000') - assert_equal(resp, [{}, {}]) + self._dispatcher._load_creds('jira') + assert_equal(self._dispatcher._get_comments('5000'), expected_result) @patch('requests.get') def test_get_comments_failure(self, get_mock): - """JiraOutput get comments failure""" + """JiraOutput - Get Comments, Failure""" # setup failed get comments response get_mock.return_value.status_code = 400 - self.__dispatcher._load_creds('jira') - resp = self.__dispatcher._get_comments('5000') - assert_equal(resp, []) + self._dispatcher._load_creds('jira') + assert_equal(self._dispatcher._get_comments('5000'), []) @patch('requests.get') def test_search_failure(self, get_mock): - """JiraOutput search failure""" + """JiraOutput - Search, Failure""" # setup failed search response get_mock.return_value.status_code = 400 - self.__dispatcher._load_creds('jira') - resp = self.__dispatcher._search_jira('foobar') - assert_equal(resp, []) + self._dispatcher._load_creds('jira') + assert_equal(self._dispatcher._search_jira('foobar'), []) @patch('logging.Logger.error') @patch('requests.post') def test_auth_failure(self, post_mock, log_mock): - """JiraOutput auth failure""" - alert = self._setup_dispatch() - + """JiraOutput - Auth, Failure""" # setup unsuccessful auth response post_mock.return_value.status_code = 400 - post_mock.return_value.content = '{}' + post_mock.return_value.content = 'content' post_mock.return_value.json.return_value = dict() - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) log_mock.assert_has_calls([call('Encountered an error while sending to %s:\n%s', - 'jira', '{}'), + 'jira', 'content'), call('Failed to authenticate to Jira'), - call('Failed to send alert to %s', self.__service)]) + call('Failed to send alert to %s', self.SERVICE)]) @patch('logging.Logger.error') @patch('requests.get') @patch('requests.post') def test_issue_creation_failure(self, post_mock, get_mock, log_mock): - """JiraOutput issue creation failure""" - alert = self._setup_dispatch() + """JiraOutput - Issue Creation, Failure""" # setup the successful search response - no results get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = {'issues': []} # setup successful auth response and failed issue creation - auth_resp = {'session': {'name': 'cookie_name', 'value':
'cookie_value'}} type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400]) - post_mock.return_value.content = '{}' + auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} + post_mock.return_value.content = 'some bad content' post_mock.return_value.json.side_effect = [auth_resp, dict()] - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) log_mock.assert_has_calls([call('Encountered an error while sending to %s:\n%s', - self.__service, '{}'), - call('Failed to send alert to %s', self.__service)]) + self.SERVICE, 'some bad content'), + call('Failed to send alert to %s', self.SERVICE)]) @patch('logging.Logger.error') @patch('requests.get') @patch('requests.post') def test_comment_creation_failure(self, post_mock, get_mock, log_mock): - """JiraOutput comment creation failure""" - alert = self._setup_dispatch() + """JiraOutput - Comment Creation, Failure""" # setup successful search response get_mock.return_value.status_code = 200 existing_issues = {'issues': [{'fields': {'summary': 'Bogus'}, 'id': '5000'}]} get_mock.return_value.json.return_value = existing_issues - auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} # setup successful auth, failed comment creation, and successful issue creation type(post_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 200]) - post_mock.return_value.content = '{}' + auth_resp = {'session': {'name': 'cookie_name', 'value': 'cookie_value'}} post_mock.return_value.json.side_effect = [auth_resp, {'id': 6000}] - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) log_mock.assert_called_with('Encountered an error when adding alert to existing Jira ' 'issue %s. 
Attempting to create new Jira issue.', 5000) + + + @patch('logging.Logger.error') + def test_dispatch_bad_descriptor(self, log_error_mock): + """JiraOutput - Dispatch Failure, Bad Descriptor""" + assert_false(self._dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=get_alert())) + + log_error_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py index 0edc1a210..5d563b9d7 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py @@ -25,10 +25,7 @@ ) from stream_alert_cli.helpers import put_mock_creds from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert, remove_temp_secrets @mock_s3 From 3212c0a130f47dd4fe8c70287234b876d8d7faf2 Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Mon, 20 Nov 2017 12:34:03 -0800 Subject: [PATCH 6/9] [tests] simplifying slack output tests --- .../test_outputs/test_aws.py | 5 +- .../test_outputs/test_slack.py | 157 +++++++----------- 2 files changed, 65 insertions(+), 97 deletions(-) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py index 3ce2b19b0..66c54cad8 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -# pylint: disable=abstract-class-instantiated,attribute-defined-outside-init +# pylint: disable=abstract-class-instantiated,attribute-defined-outside-init,no-self-use import boto3 from mock import patch from moto import mock_s3, mock_lambda, mock_kinesis @@ -37,9 +37,8 @@ class TestAWSOutput(object): """Test class for AWSOutput Base""" - @staticmethod @patch.object(aws.AWSOutput, '__service__', 'aws-s3') - def test_aws_format_output_config(): + def test_aws_format_output_config(self): """AWSOutput - Format Output Config""" props = { 'descriptor': OutputProperty( diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py index 8ef3a9aac..bb2ed6660 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py @@ -13,17 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -# pylint: disable=protected-access +# pylint: disable=protected-access,attribute-defined-outside-init,no-self-use from collections import Counter, OrderedDict -import json from mock import patch from moto import mock_s3, mock_kms -from nose.tools import ( - assert_equal, - assert_set_equal -) +from nose.tools import assert_equal, assert_false, assert_true, assert_set_equal -from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput +from stream_alert.alert_processor.outputs.slack import SlackOutput from stream_alert_cli.helpers import put_mock_creds from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, KMS_ALIAS, REGION from tests.unit.stream_alert_alert_processor.helpers import ( @@ -33,28 +29,26 @@ ) +@mock_s3 +@mock_kms class TestSlackOutput(object): """Test class for SlackOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'slack' - cls.__descriptor = 'unit_test_channel' - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.__dispatcher = None + DESCRIPTOR = 'unit_test_channel' + SERVICE = 'slack' + CREDS = {'url': 'https://api.slack.com/web-hook-key'} + + def setup(self): + """Setup before each method""" + self._dispatcher = SlackOutput(REGION, FUNCTION_NAME, CONFIG) + remove_temp_secrets() + output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) + put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) def test_format_message_single(self): - """Format Single Message - Slack""" + """SlackOutput - Format Single Message - Slack""" rule_name = 'test_rule_single' alert = get_random_alert(25, rule_name) - loaded_message = self.__dispatcher._format_message(rule_name, alert) + loaded_message = SlackOutput._format_message(rule_name, alert) # tests assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) @@ -64,51 +58,46 @@ def test_format_message_single(self): assert_equal(len(loaded_message['attachments']), 1) def test_format_message_mutliple(self): - """Format Multi-Message - Slack""" + """SlackOutput - Format Multi-Message""" rule_name = 'test_rule_multi-part' alert = get_random_alert(30, rule_name) - loaded_message = self.__dispatcher._format_message(rule_name, alert) + loaded_message = SlackOutput._format_message(rule_name, alert) # tests assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) - assert_equal( - loaded_message['text'], - '*StreamAlert Rule Triggered: test_rule_multi-part*') + assert_equal(loaded_message['text'], '*StreamAlert Rule Triggered: test_rule_multi-part*') assert_equal(len(loaded_message['attachments']), 2) - assert_equal(loaded_message['attachments'][1] - ['text'].split('\n')[3][1:7], '000028') + assert_equal(loaded_message['attachments'][1]['text'].split('\n')[3][1:7], '000028') def test_format_message_default_rule_description(self): - """Format Message Default Rule Description - Slack""" + """SlackOutput - Format Message, Default Rule Description""" rule_name = 'test_empty_rule_description' alert = get_random_alert(10, rule_name, True) - loaded_message = self.__dispatcher._format_message(rule_name, alert) + loaded_message = SlackOutput._format_message(rule_name, alert) # tests default_rule_description = '*Rule Description:*\nNo rule description provided\n' - assert_equal( - loaded_message['attachments'][0]['pretext'], - 
default_rule_description) + assert_equal(loaded_message['attachments'][0]['pretext'], default_rule_description) def test_json_to_slack_mrkdwn_str(self): - """JSON to Slack mrkdwn - simple str""" + """SlackOutput - JSON to Slack mrkdwn, Simple String""" simple_str = 'value to format' - result = self.__dispatcher._json_to_slack_mrkdwn(simple_str, 0) + result = SlackOutput._json_to_slack_mrkdwn(simple_str, 0) assert_equal(len(result), 1) assert_equal(result[0], simple_str) def test_json_to_slack_mrkdwn_dict(self): - """JSON to Slack mrkdwn - simple dict""" + """SlackOutput - JSON to Slack mrkdwn, Simple Dict""" simple_dict = OrderedDict([('test_key_01', 'test_value_01'), ('test_key_02', 'test_value_02')]) - result = self.__dispatcher._json_to_slack_mrkdwn(simple_dict, 0) + result = SlackOutput._json_to_slack_mrkdwn(simple_dict, 0) assert_equal(len(result), 2) assert_equal(result[1], '*test_key_02:* test_value_02') def test_json_to_slack_mrkdwn_nested_dict(self): - """JSON to Slack mrkdwn - nested dict""" + """SlackOutput - JSON to Slack mrkdwn, Nested Dict""" nested_dict = OrderedDict([ ('root_key_01', 'root_value_01'), ('root_02', 'root_value_02'), @@ -120,23 +109,23 @@ def test_json_to_slack_mrkdwn_nested_dict(self): ])) ])) ]) - result = self.__dispatcher._json_to_slack_mrkdwn(nested_dict, 0) + result = SlackOutput._json_to_slack_mrkdwn(nested_dict, 0) assert_equal(len(result), 7) assert_equal(result[2], '*root_nested_01:*') assert_equal(Counter(result[4])['\t'], 1) assert_equal(Counter(result[6])['\t'], 2) def test_json_to_slack_mrkdwn_list(self): - """JSON to Slack mrkdwn - simple list""" + """SlackOutput - JSON to Slack mrkdwn, Simple List""" simple_list = ['test_value_01', 'test_value_02'] - result = self.__dispatcher._json_to_slack_mrkdwn(simple_list, 0) + result = SlackOutput._json_to_slack_mrkdwn(simple_list, 0) assert_equal(len(result), 2) assert_equal(result[0], '*[1]* test_value_01') assert_equal(result[1], '*[2]* test_value_02') def test_json_to_slack_mrkdwn_multi_nested(self): - """JSON to Slack mrkdwn - multi type nested""" + """SlackOutput - JSON to Slack mrkdwn, Multi-type Nested""" nested_dict = OrderedDict([ ('root_key_01', 'root_value_01'), ('root_02', 'root_value_02'), @@ -152,84 +141,64 @@ def test_json_to_slack_mrkdwn_multi_nested(self): ])) ])) ]) - result = self.__dispatcher._json_to_slack_mrkdwn(nested_dict, 0) + result = SlackOutput._json_to_slack_mrkdwn(nested_dict, 0) assert_equal(len(result), 10) assert_equal(result[2], '*root_nested_01:*') assert_equal(Counter(result[4])['\t'], 1) assert_equal(result[-1], '\t\t\t*[3]* 51919') def test_json_list_to_text(self): - """JSON list to text""" - simple_list = ['test_value_01', 'test_value_02'] - result = self.__dispatcher._json_list_to_text(simple_list, '\t', 0) + """SlackOutput - JSON list to text""" + simple_list = ['test_value_01', 'test_value_02', {'nested': 'value_03'}] + result = SlackOutput._json_list_to_text(simple_list, '\t', 0) - assert_equal(len(result), 2) + assert_equal(len(result), 4) assert_equal(result[0], '*[1]* test_value_01') assert_equal(result[1], '*[2]* test_value_02') + assert_equal(result[2], '*[3]*') + assert_equal(result[3], '\t*nested:* value_03') def test_json_map_to_text(self): - """JSON map to text""" + """SlackOutput - JSON map to text""" simple_dict = OrderedDict([('test_key_01', 'test_value_01'), ('test_key_02', 'test_value_02')]) - result = self.__dispatcher._json_map_to_text(simple_dict, '\t', 0) + result = SlackOutput._json_map_to_text(simple_dict, '\t', 0) assert_equal(len(result), 
2) assert_equal(result[1], '*test_key_02:* test_value_02') - def _setup_dispatch(self): - """Helper for setting up SlackOutput dispatch""" - remove_temp_secrets() - - output_name = self.__dispatcher.output_cred_name(self.__descriptor) - - creds = {'url': 'https://api.slack.com/web-hook-key'} - - put_mock_creds(output_name, creds, self.__dispatcher.secrets_bucket, - REGION, KMS_ALIAS) - - return get_alert() - @patch('logging.Logger.info') @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_success(self, url_mock, log_info_mock): - """SlackOutput dispatch success""" - alert = self._setup_dispatch() + def test_dispatch_success(self, url_mock, log_mock): + """SlackOutput - Dispatch Success""" url_mock.return_value.status_code = 200 - url_mock.return_value.json.return_value = json.loads('{}') + url_mock.return_value.json.return_value = dict() - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_info_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) @patch('logging.Logger.error') @patch('requests.post') - @mock_s3 - @mock_kms - def test_dispatch_failure(self, url_mock, log_error_mock): - """SlackOutput dispatch failure""" - alert = self._setup_dispatch() - json_error = json.loads('{"message": "error message", "errors": ["error1"]}') + def test_dispatch_failure(self, url_mock, log_mock): + """SlackOutput - Dispatch Failure, Bad Request""" + json_error = {'message': 'error message', 'errors': ['error1']} url_mock.return_value.json.return_value = json_error url_mock.return_value.status_code = 400 - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) @patch('logging.Logger.error') - @mock_s3 - @mock_kms - def test_dispatch_bad_descriptor(self, log_error_mock): - """SlackOutput dispatch bad descriptor""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor='bad_descriptor', - rule_name='rule_name', - alert=alert) - - log_error_mock.assert_called_with('Failed to send alert to %s', self.__service) + def test_dispatch_bad_descriptor(self, log_mock): + """SlackOutput - Dispatch Failure, Bad Descriptor""" + assert_false(self._dispatcher.dispatch(descriptor='bad_descriptor', + rule_name='rule_name', + alert=get_alert())) + + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) From 213938390f2186a46533640b55560fdcb87f7d6c Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Mon, 20 Nov 2017 12:50:10 -0800 Subject: [PATCH 7/9] [tests] simplifying aws output tests --- .../test_outputs/test_aws.py | 180 +++++++----------- 1 file changed, 67 insertions(+), 113 deletions(-) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py index 66c54cad8..0b2498454 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_aws.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations 
under the License. """ -# pylint: disable=abstract-class-instantiated,attribute-defined-outside-init,no-self-use +# pylint: disable=abstract-class-instantiated,protected-access,attribute-defined-outside-init,no-self-use import boto3 from mock import patch from moto import mock_s3, mock_lambda, mock_kinesis @@ -24,11 +24,13 @@ assert_true ) -from stream_alert.alert_processor.outputs.output_base import ( - OutputProperty, - StreamAlertOutput +from stream_alert.alert_processor.outputs.output_base import OutputProperty +from stream_alert.alert_processor.outputs.aws import ( + AWSOutput, + KinesisFirehoseOutput, + LambdaOutput, + S3Output ) -from stream_alert.alert_processor.outputs import aws from stream_alert_cli.helpers import create_lambda_function from tests.unit.stream_alert_alert_processor import CONFIG, FUNCTION_NAME, REGION from tests.unit.stream_alert_alert_processor.helpers import get_alert @@ -37,7 +39,7 @@ class TestAWSOutput(object): """Test class for AWSOutput Base""" - @patch.object(aws.AWSOutput, '__service__', 'aws-s3') + @patch.object(AWSOutput, '__service__', 'aws-s3') def test_aws_format_output_config(self): """AWSOutput - Format Output Config""" props = { @@ -48,79 +50,50 @@ def test_aws_format_output_config(self): 'unique arn value, bucket, etc', 'bucket.value')} - formatted_config = aws.AWSOutput.format_output_config(CONFIG, props) + formatted_config = AWSOutput.format_output_config(CONFIG, props) assert_equal(len(formatted_config), 2) assert_is_not_none(formatted_config.get('descriptor_value')) assert_is_not_none(formatted_config.get('unit_test_bucket')) +@mock_s3 class TestS3Ouput(object): """Test class for S3Output""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'aws-s3' - cls.__descriptor = 'unit_test_bucket' - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.dispatcher = None + DESCRIPTOR = 'unit_test_bucket' + SERVICE = 'aws-s3' - def test_locals(self): - """S3Output local variables""" - assert_equal(self.__dispatcher.__class__.__name__, 'S3Output') - assert_equal(self.__dispatcher.__service__, self.__service) - - def _setup_dispatch(self): - """Helper for setting up S3Output dispatch""" - bucket = CONFIG[self.__service][self.__descriptor] + def setup(self): + """Setup before each method""" + self._dispatcher = S3Output(REGION, FUNCTION_NAME, CONFIG) + bucket = CONFIG[self.SERVICE][self.DESCRIPTOR] boto3.client('s3', region_name=REGION).create_bucket(Bucket=bucket) - return get_alert() + def test_locals(self): + """S3Output local variables""" + assert_equal(self._dispatcher.__class__.__name__, 'S3Output') + assert_equal(self._dispatcher.__service__, self.SERVICE) @patch('logging.Logger.info') - @mock_s3 def test_dispatch(self, log_mock): - """S3Output dispatch""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + """S3Output - Dispatch Success""" + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) +@mock_kinesis class TestFirehoseOutput(object): """Test class for AWS Kinesis Firehose""" - @classmethod - def setup_class(cls): - """Setup the class before any 
methods""" - cls.__service = 'aws-firehose' - cls.__descriptor = 'unit_test_delivery_stream' - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.dispatcher = None - - def test_locals(self): - """Output local variables - Kinesis Firehose""" - assert_equal(self.__dispatcher.__class__.__name__, 'KinesisFirehoseOutput') - assert_equal(self.__dispatcher.__service__, self.__service) + DESCRIPTOR = 'unit_test_delivery_stream' + SERVICE = 'aws-firehose' - def _setup_dispatch(self): - """Helper for setting up S3Output dispatch""" - delivery_stream = CONFIG[self.__service][self.__descriptor] + def setup(self): + """Setup before each method""" + self._dispatcher = KinesisFirehoseOutput(REGION, FUNCTION_NAME, CONFIG) + delivery_stream = CONFIG[self.SERVICE][self.DESCRIPTOR] boto3.client('firehose', region_name=REGION).create_delivery_stream( DeliveryStreamName=delivery_stream, S3DestinationConfiguration={ @@ -135,80 +108,61 @@ def _setup_dispatch(self): } ) - return get_alert() + def test_locals(self): + """Output local variables - Kinesis Firehose""" + assert_equal(self._dispatcher.__class__.__name__, 'KinesisFirehoseOutput') + assert_equal(self._dispatcher.__service__, self.SERVICE) @patch('logging.Logger.info') - @mock_kinesis def test_dispatch(self, log_mock): - """Output Dispatch - Kinesis Firehose""" - alert = self._setup_dispatch() - resp = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + """Kinesis Firehose - Output Dispatch Success""" + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - assert_true(resp) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) - @mock_kinesis def test_dispatch_ignore_large_payload(self): """Output Dispatch - Kinesis Firehose with Large Payload""" - alert = self._setup_dispatch() + alert = get_alert() alert['record'] = 'test' * 1000 * 1000 - resp = self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) - - assert_false(resp) - + assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=alert)) +@mock_lambda class TestLambdaOuput(object): """Test class for LambdaOutput""" - @classmethod - def setup_class(cls): - """Setup the class before any methods""" - cls.__service = 'aws-lambda' - cls.__descriptor = 'unit_test_lambda' - cls.__dispatcher = StreamAlertOutput.create_dispatcher(cls.__service, - REGION, - FUNCTION_NAME, - CONFIG) - - @classmethod - def teardown_class(cls): - """Teardown the class after all methods""" - cls.dispatcher = None + DESCRIPTOR = 'unit_test_lambda' + SERVICE = 'aws-lambda' + + def setup(self): + """Setup before each method""" + self._dispatcher = LambdaOutput(REGION, FUNCTION_NAME, CONFIG) + create_lambda_function(CONFIG[self.SERVICE][self.DESCRIPTOR], REGION) def test_locals(self): """LambdaOutput local variables""" - assert_equal(self.__dispatcher.__class__.__name__, 'LambdaOutput') - assert_equal(self.__dispatcher.__service__, self.__service) - - def _setup_dispatch(self, alt_descriptor=''): - """Helper for setting up LambdaOutput dispatch""" - function_name = CONFIG[self.__service][alt_descriptor or self.__descriptor] - create_lambda_function(function_name, REGION) - return get_alert() + 
assert_equal(self._dispatcher.__class__.__name__, 'LambdaOutput') + assert_equal(self._dispatcher.__service__, self.SERVICE) - @mock_lambda @patch('logging.Logger.info') def test_dispatch(self, log_mock): """LambdaOutput dispatch""" - alert = self._setup_dispatch() - self.__dispatcher.dispatch(descriptor=self.__descriptor, - rule_name='rule_name', - alert=alert) + assert_true(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) - @mock_lambda @patch('logging.Logger.info') def test_dispatch_with_qualifier(self, log_mock): - """LambdaOutput dispatch with qualifier""" - alt_descriptor = '{}_qual'.format(self.__descriptor) - alert = self._setup_dispatch(alt_descriptor) - self.__dispatcher.dispatch(descriptor=alt_descriptor, - rule_name='rule_name', - alert=alert) - - log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + """LambdaOutput - Dispatch Success, With Qualifier""" + alt_descriptor = '{}_qual'.format(self.DESCRIPTOR) + create_lambda_function(alt_descriptor, REGION) + assert_true(self._dispatcher.dispatch(descriptor=alt_descriptor, + rule_name='rule_name', + alert=get_alert())) + + log_mock.assert_called_with('Successfully sent alert to %s', self.SERVICE) From 5fcf2e391142589e79a78f145a1fde23d464c60f Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Mon, 20 Nov 2017 13:10:26 -0800 Subject: [PATCH 8/9] [docs] updating docs to reflect new decorator for outputs --- docs/source/outputs.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/outputs.rst b/docs/source/outputs.rst index 59fe869c2..fe899466d 100644 --- a/docs/source/outputs.rst +++ b/docs/source/outputs.rst @@ -110,7 +110,7 @@ Adding support for a new service involves five steps: - This should be a string value that corresponds to an identifier that best represents this service. (ie: ``__service__ = 'aws-s3'``) -4. Add the ``@output`` class decorator to the new subclass so it registered when the `outputs` module is loaded. +4. Add the ``@StreamAlertOutput`` class decorator to the new subclass so it is registered when the `outputs` module is loaded. 5. To allow the cli to configure a new integration for this service, add the value used above for the ``__service__`` property to the ``manage.py`` file.
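As a quick illustration of the five steps above, a minimal new output might look roughly like the sketch below. This is illustrative only: ``ExampleOutput`` and ``example-service`` are hypothetical names, and the ``OutputDispatcher`` helpers used here (``_load_creds``, ``_post_request``, ``_log_status``) and their exact signatures are assumptions inferred from the diffs in this series, not a definitive implementation.

    from collections import OrderedDict

    from stream_alert.alert_processor.outputs.output_base import (
        OutputDispatcher,
        OutputProperty,
        StreamAlertOutput
    )


    @StreamAlertOutput  # registers the subclass under its __service__ identifier
    class ExampleOutput(OutputDispatcher):
        """Hypothetical output that relays alerts to a generic webhook"""
        __service__ = 'example-service'

        @classmethod
        def get_user_defined_properties(cls):
            """Properties the user is prompted for when configuring this output"""
            return OrderedDict([
                ('descriptor',
                 OutputProperty(description='a short and unique descriptor for this integration')),
                ('url',
                 OutputProperty(description='the webhook url to receive alerts',
                                mask_input=True,
                                cred_requirement=True))
            ])

        def dispatch(self, **kwargs):
            """Send an alert to this service, logging and returning the status"""
            creds = self._load_creds(kwargs.get('descriptor'))
            if not creds:
                return self._log_status(False)

            # argument order mirrors the mocked helper calls above: (url, data, headers, verbose)
            resp = self._post_request(creds['url'], {'alert': kwargs.get('alert')}, None, True)

            # real outputs inspect the response more carefully; a status check is enough for a sketch
            return self._log_status(resp is not None and resp.status_code == 200)

Because the decorator populates the registry in ``output_base`` at import time, the new service becomes discoverable by name as soon as its module is imported; the ``manage.py`` entry in step 5 is what exposes it to the CLI.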
From bb51b6487d68218a9df489b5db940bbef803006f Mon Sep 17 00:00:00 2001 From: Ryan Deivert Date: Mon, 20 Nov 2017 20:11:37 -0800 Subject: [PATCH 9/9] [rebase] additions after rebase --- .../alert_processor/outputs/pagerduty.py | 77 ++++++++------- .../test_outputs/test_pagerduty.py | 99 +++++++++++++------ 2 files changed, 112 insertions(+), 64 deletions(-) diff --git a/stream_alert/alert_processor/outputs/pagerduty.py b/stream_alert/alert_processor/outputs/pagerduty.py index b4f1f8767..004cc7c3f 100644 --- a/stream_alert/alert_processor/outputs/pagerduty.py +++ b/stream_alert/alert_processor/outputs/pagerduty.py @@ -221,7 +221,11 @@ def get_user_defined_properties(cls): mask_input=True, cred_requirement=True)), ('escalation_policy', - OutputProperty(description='the name of the default escalation policy')) + OutputProperty(description='the name of the default escalation policy')), + ('email_from', + OutputProperty(description='valid user email from the PagerDuty ' + 'account linked to the token', + cred_requirement=True)) ]) @staticmethod @@ -237,7 +241,7 @@ def _get_endpoint(base_url, endpoint): """ return os.path.join(base_url, endpoint) - def _check_exists_get_id(self, filter_str, url, target_key): + def _check_exists(self, filter_str, url, target_key, get_id=True): """Generic method to run a search in the PagerDuty REST API and return the id of the first occurrence from the results. @@ -245,10 +249,11 @@ filter_str (str): The query filter to search for in the API url (str): The url to send the requests to in the API target_key (str): The key to extract in the returned results + get_id (boolean): Whether to return the matched element's ID (True) or just whether it exists (False) Returns: str: ID of the targeted element that matches the provided filter or - False if a matching element does not exists. + True/False indicating whether a matching element exists. """ params = { 'query': '{}'.format(filter_str) } @@ -262,44 +267,39 @@ if not response: return False + if not get_id: + return True + # If there are results, get the first occurrence from the list return response[target_key][0]['id'] if target_key in response else False - def _user_verify(self, user): + def _user_verify(self, user, get_id=True): """Method to verify the existence of a user with the API - Args: user (str): User to query about in the API. - + get_id (boolean): Whether to return a user reference dict (True) or just whether the user exists (False) Returns: dict or False: JSON object to be used in the API call, containing the user_id and user_reference.
False if user is not found """ - users_url = self._get_endpoint(self._base_url, self.USERS_ENDPOINT) - - return self._item_verify(users_url, user, self.USERS_ENDPOINT, 'user_reference') + return self._item_verify(user, self.USERS_ENDPOINT, 'user_reference', get_id) def _policy_verify(self, policy, default_policy): """Method to verify the existence of an escalation policy with the API - Args: policy (str): Escalation policy to query about in the API default_policy (str): Escalation policy to use if the first one is not verified - Returns: dict: JSON object to be used in the API call, containing the policy_id and escalation_policy_reference """ - policies_url = self._get_endpoint(self._base_url, self.POLICIES_ENDPOINT) - - verified = self._item_verify(policies_url, policy, self.POLICIES_ENDPOINT, - 'escalation_policy_reference') + verified = self._item_verify(policy, self.POLICIES_ENDPOINT, 'escalation_policy_reference') # If the escalation policy provided is not verified in the API, use the default if verified: return verified - return self._item_verify(policies_url, default_policy, self.POLICIES_ENDPOINT, + return self._item_verify(default_policy, self.POLICIES_ENDPOINT, 'escalation_policy_reference') def _service_verify(self, service): @@ -312,32 +312,29 @@ dict: JSON object to be used in the API call, containing the service_id and the service_reference """ - services_url = self._get_endpoint(self._base_url, self.SERVICES_ENDPOINT) - - return self._item_verify(services_url, service, self.SERVICES_ENDPOINT, 'service_reference') + return self._item_verify(service, self.SERVICES_ENDPOINT, 'service_reference') - def _item_verify(self, item_url, item_str, item_key, item_type): + def _item_verify(self, item_str, item_key, item_type, get_id=True): """Method to verify the existence of an item with the API - Args: - item_url (str): URL of the API Endpoint within the API to query item_str (str): Service to query about in the API - item_key (str): Key to be extracted from search results + item_key (str): Endpoint/key to be extracted from search results item_type (str): Type of item reference to be returned - + get_id (boolean): Whether to return a dict with the item id and reference (True) or just whether the item exists (False) Returns: dict: JSON object to be used in the API call, containing the item id - and the item reference, or False if it fails + and the item reference, True if the item merely exists, or False on failure """ - item_id = self._check_exists_get_id(item_str, item_url, item_key) + item_url = self._get_endpoint(self._base_url, item_key) + item_id = self._check_exists(item_str, item_url, item_key, get_id) if not item_id: LOGGER.info('%s not found in %s, %s', item_str, item_key, self.__service__) return False - return { - 'id': item_id, - 'type': item_type - } + if get_id: + return {'id': item_id, 'type': item_type} + + return item_id def _incident_assignment(self, context): """Method to determine if the incident gets assigned to a user or an escalation policy @@ -367,7 +364,6 @@ def dispatch(self, **kwargs): """Send incident to Pagerduty Incidents API v2 - Keyword Args: **kwargs: consists of any combination of the following items: descriptor (str): Service descriptor (ie: slack channel, pd integration) @@ -379,14 +375,23 @@ if not creds: return self._log_status(False) + # Cache base_url + self._base_url = creds['api'] + # Preparing headers for API calls self._headers = { 'Authorization': 'Token token={}'.format(creds['token']), 'Accept':
'application/vnd.pagerduty+json;version=2' } - # Cache base_url - self._base_url = creds['api'] + # Get user email to be added as From header and verify + user_email = creds['email_from'] + if not self._user_verify(user_email, False): + LOGGER.error('Could not verify header From: %s, %s', user_email, self.__service__) + return self._log_status(False) + + # Add From to the headers after verifying + self._headers['From'] = user_email # Cache default escalation policy self._escalation_policy = creds['escalation_policy'] @@ -413,9 +418,9 @@ 'type': 'incident', 'title': incident_title, 'service': incident_service, - 'body': incident_body - }, - assigned_key: assigned_value + 'body': incident_body, + assigned_key: assigned_value + } } incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT) resp = self._post_request(incidents_url, incident, self._headers, True) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py index 5d563b9d7..cd120d104 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py @@ -152,7 +152,8 @@ class TestPagerDutyIncidentOutput(object): CREDS = {'api': 'https://api.pagerduty.com', 'token': 'mocked_token', 'service_key': 'mocked_service_key', - 'escalation_policy': 'mocked_escalation_policy'} + 'escalation_policy': 'mocked_escalation_policy', + 'email_from': 'email@domain.com'} def setup(self): """Setup before each method""" @@ -181,7 +182,7 @@ def test_check_exists_get_id(self, get_mock): json_check = {'check': [{'id': 'checked_id'}]} get_mock.return_value.json.return_value = json_check - checked = self._dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + checked = self._dispatcher._check_exists('filter', 'http://mock_url', 'check') assert_equal(checked, 'checked_id') @patch('requests.get') @@ -190,9 +191,19 @@ def test_check_exists_get_id_fail(self, get_mock): get_mock.return_value.status_code = 200 get_mock.return_value.json.return_value = dict() - checked = self._dispatcher._check_exists_get_id('filter', 'http://mock_url', 'check') + checked = self._dispatcher._check_exists('filter', 'http://mock_url', 'check') assert_false(checked) + @patch('requests.get') + def test_check_exists_no_get_id(self, get_mock): + """PagerDutyIncidentOutput - Check Exists, No Get Id""" + # /check + get_mock.return_value.status_code = 200 + json_check = {'check': [{'id': 'checked_id'}]} + get_mock.return_value.json.return_value = json_check + + assert_true(self._dispatcher._check_exists('filter', 'http://mock_url', 'check', False)) + @patch('requests.get') def test_user_verify_success(self, get_mock): """PagerDutyIncidentOutput - User Verify Success""" @@ -289,11 +300,21 @@ def test_item_verify_success(self, get_mock): json_check = {'items': [{'id': 'verified_item_id'}]} get_mock.return_value.json.return_value = json_check - item_verified = self._dispatcher._item_verify('http://mock_url', 'valid_item', - 'items', 'item_reference') + item_verified = self._dispatcher._item_verify('valid_item', 'items', 'item_reference') + assert_equal(item_verified['id'], 'verified_item_id') assert_equal(item_verified['type'], 'item_reference') + @patch('requests.get') + def test_item_verify_no_get_id_success(self, get_mock): + """PagerDutyIncidentOutput - Item Verify Success, No Get Id""" + # /items + get_mock.return_value.status_code = 200 +
json_check = {'items': [{'id': 'verified_item_id'}]} + get_mock.return_value.json.return_value = json_check + + assert_true(self._dispatcher._item_verify('valid_item', 'items', 'item_reference', False)) + @patch('requests.get') def test_incident_assignment_user(self, get_mock): """PagerDutyIncidentOutput - Incident Assignment User""" @@ -354,11 +375,11 @@ def test_item_verify_fail(self, get_mock): @patch('requests.get') def test_dispatch_success_good_user(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success, Good User""" - # /users, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + # /users, /users, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) json_user = {'users': [{'id': 'valid_user_id'}]} json_service = {'services': [{'id': 'service_id'}]} - get_mock.return_value.json.side_effect = [json_user, json_service] + get_mock.return_value.json.side_effect = [json_user, json_user, json_service] # /incidents post_mock.return_value.status_code = 200 @@ -376,11 +397,12 @@ def test_dispatch_success_good_user(self, get_mock, post_mock, log_mock): @patch('requests.get') def test_dispatch_success_good_policy(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success, Good Policy""" - # /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + # /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) + json_user = {'users': [{'id': 'user_id'}]} json_policy = {'escalation_policies': [{'id': 'policy_id'}]} json_service = {'services': [{'id': 'service_id'}]} - get_mock.return_value.json.side_effect = [json_policy, json_service] + get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] # /incidents post_mock.return_value.status_code = 200 @@ -398,12 +420,14 @@ def test_dispatch_success_good_policy(self, get_mock, post_mock, log_mock): @patch('requests.get') def test_dispatch_success_bad_user(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success, Bad User""" - # /users, /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) - json_user = {'not_users': [{'id': 'user_id'}]} + # /users, /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200, 200]) + json_user = {'users': [{'id': 'user_id'}]} + json_not_user = {'not_users': [{'id': 'user_id'}]} json_policy = {'escalation_policies': [{'id': 'policy_id'}]} json_service = {'services': [{'id': 'service_id'}]} - get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] + get_mock.return_value.json.side_effect = [json_user, json_not_user, + json_policy, json_service] # /incidents post_mock.return_value.status_code = 200 @@ -421,11 +445,12 @@ def test_dispatch_success_bad_user(self, get_mock, post_mock, log_mock): @patch('requests.get') def test_dispatch_success_no_context(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success, No Context""" - # /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + # /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) + json_user = {'users': [{'id': 'user_id'}]} json_policy = {'escalation_policies': 
[{'id': 'policy_id'}]} json_service = {'services': [{'id': 'service_id'}]} - get_mock.return_value.json.side_effect = [json_policy, json_service] + get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] # /incidents post_mock.return_value.status_code = 200 @@ -441,9 +466,10 @@ def test_dispatch_success_no_context(self, get_mock, post_mock, log_mock): @patch('requests.get') def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Failure: No User, Bad Policy, Bad Service""" - # /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400, 400]) - get_mock.return_value.json.side_effect = [dict(), dict(), dict()] + # /users, /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 400, 400]) + json_user = {'users': [{'id': 'user_id'}]} + get_mock.return_value.json.side_effect = [json_user, dict(), dict(), dict()] # /incidents post_mock.return_value.status_code = 400 @@ -459,13 +485,14 @@ def test_dispatch_failure_bad_everything(self, get_mock, post_mock, log_mock): @patch('requests.get') def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success, Bad Policy""" - # /escalation_policies, /services - get_mock.return_value.side_effect = [400, 200, 200] - type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 200, 200]) + # /users, /escalation_policies, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 400, 200, 200]) + json_user = {'users': [{'id': 'user_id'}]} json_bad_policy = dict() json_good_policy = {'escalation_policies': [{'id': 'policy_id'}]} json_service = {'services': [{'id': 'service_id'}]} - get_mock.return_value.json.side_effect = [json_bad_policy, json_good_policy, json_service] + get_mock.return_value.json.side_effect = [json_user, json_bad_policy, + json_good_policy, json_service] # /incidents post_mock.return_value.status_code = 200 @@ -483,11 +510,12 @@ def test_dispatch_success_bad_policy(self, get_mock, post_mock, log_mock): @patch('requests.get') def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Failure, Bad Request""" - # /escalation_policies, /services - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) + # /users, /escalation_policies, /services + type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200, 200]) + json_user = {'users': [{'id': 'user_id'}]} json_policy = {'escalation_policies': [{'id': 'policy_id'}]} json_service = {'services': [{'id': 'service_id'}]} - get_mock.return_value.json.side_effect = [json_policy, json_service] + get_mock.return_value.json.side_effect = [json_user, json_policy, json_service] # /incidents post_mock.return_value.status_code = 400 @@ -498,6 +526,21 @@ def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_mock): log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) + @patch('logging.Logger.error') + @patch('requests.get') + def test_dispatch_bad_email(self, get_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Failure, Bad Email""" + # /users, /escalation_policies, /services + get_mock.return_value.status_code = 400 + json_user = {'not_users': [{'id': 'no_user_id'}]} + get_mock.return_value.json.return_value = json_user + + 
assert_false(self._dispatcher.dispatch(descriptor=self.DESCRIPTOR, + rule_name='rule_name', + alert=get_alert())) + + log_mock.assert_called_with('Failed to send alert to %s', self.SERVICE) + @patch('logging.Logger.error') def test_dispatch_bad_descriptor(self, log_mock): """PagerDutyIncidentOutput - Dispatch Failure, Bad Descriptor"""