diff --git a/conf/global.json b/conf/global.json index a027417c6..b7b743d4a 100644 --- a/conf/global.json +++ b/conf/global.json @@ -6,11 +6,11 @@ "region": "us-east-1" }, "infrastructure": { + "metrics": { + "enabled": true + }, "monitoring": { "create_sns_topic": true - }, - "metrics": { - "enabled": false } }, "terraform": { diff --git a/stream_alert/alert_processor/__init__.py b/stream_alert/alert_processor/__init__.py index e69de29bb..900f3e8a8 100644 --- a/stream_alert/alert_processor/__init__.py +++ b/stream_alert/alert_processor/__init__.py @@ -0,0 +1,21 @@ +"""Initialize logging for the alert processor.""" +import logging +import os + +from stream_alert.shared import ALERT_PROCESSOR_NAME as FUNCTION_NAME + +# Create a package level logger to import +LEVEL = os.environ.get('LOGGER_LEVEL', 'INFO').upper() + +# Cast integer levels to avoid a ValueError +if LEVEL.isdigit(): + LEVEL = int(LEVEL) + +logging.basicConfig(format='%(name)s [%(levelname)s]: [%(module)s.%(funcName)s] %(message)s') + +LOGGER = logging.getLogger('StreamAlertOutput') +try: + LOGGER.setLevel(LEVEL) +except (TypeError, ValueError) as err: + LOGGER.setLevel('INFO') + LOGGER.error('Defaulting to INFO logging: %s', err) diff --git a/stream_alert/alert_processor/helpers.py b/stream_alert/alert_processor/helpers.py index 88150d121..d79f7af1a 100644 --- a/stream_alert/alert_processor/helpers.py +++ b/stream_alert/alert_processor/helpers.py @@ -13,11 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -import logging - -logging.basicConfig() -LOGGER = logging.getLogger('StreamAlertOutput') -LOGGER.setLevel(logging.DEBUG) +from stream_alert.alert_processor import LOGGER def validate_alert(alert): diff --git a/stream_alert/alert_processor/main.py b/stream_alert/alert_processor/main.py index 399ffab79..37d27bebb 100644 --- a/stream_alert/alert_processor/main.py +++ b/stream_alert/alert_processor/main.py @@ -15,15 +15,11 @@ """ from collections import OrderedDict import json -import logging +from stream_alert.alert_processor import LOGGER from stream_alert.alert_processor.helpers import validate_alert from stream_alert.alert_processor.outputs import get_output_dispatcher -logging.basicConfig() -LOGGER = logging.getLogger('StreamAlertOutput') -LOGGER.setLevel(logging.DEBUG) - def handler(event, context): """StreamAlert Alert Processor diff --git a/stream_alert/alert_processor/output_base.py b/stream_alert/alert_processor/output_base.py index 00a545486..9f3b72ae6 100644 --- a/stream_alert/alert_processor/output_base.py +++ b/stream_alert/alert_processor/output_base.py @@ -16,7 +16,6 @@ from abc import ABCMeta, abstractmethod from collections import namedtuple import json -import logging import os import ssl import tempfile @@ -25,8 +24,7 @@ import boto3 from botocore.exceptions import ClientError -logging.basicConfig() -LOGGER = logging.getLogger('StreamAlertOutput') +from stream_alert.alert_processor import LOGGER OutputProperty = namedtuple('OutputProperty', 'description, value, input_restrictions, mask_input, cred_requirement') @@ -274,7 +272,6 @@ def get_user_defined_properties(self): Returns: OrderedDict: Contains various OutputProperty items """ - pass @abstractmethod def dispatch(self, **kwargs): @@ -287,4 +284,3 @@ def dispatch(self, **kwargs): rule_name (str): Name of the triggered rule alert (dict): Alert relevant to the triggered rule """ - pass diff --git a/stream_alert/alert_processor/outputs.py b/stream_alert/alert_processor/outputs.py index 
48872a199..a4fb43c4a 100644 --- a/stream_alert/alert_processor/outputs.py +++ b/stream_alert/alert_processor/outputs.py @@ -18,17 +18,14 @@ from collections import OrderedDict from datetime import datetime import json -import logging import os import uuid import boto3 +from stream_alert.alert_processor import LOGGER from stream_alert.alert_processor.output_base import OutputProperty, StreamOutputBase -logging.basicConfig() -LOGGER = logging.getLogger('StreamAlertOutput') - # STREAM_OUTPUTS will contain each subclass of the StreamOutputBase # All included subclasses are designated using the '@output' class decorator # The keys are the name of the service and the value is the class itself diff --git a/stream_alert/athena_partition_refresh/__init__.py b/stream_alert/athena_partition_refresh/__init__.py index e69de29bb..3290428fb 100644 --- a/stream_alert/athena_partition_refresh/__init__.py +++ b/stream_alert/athena_partition_refresh/__init__.py @@ -0,0 +1,21 @@ +"""Initialize logging for the athena partition refresh function.""" +import logging +import os + +from stream_alert.shared import ATHENA_PARTITION_REFRESH_NAME as FUNCTION_NAME + +# Create a package level logger to import +LEVEL = os.environ.get('LOGGER_LEVEL', 'INFO').upper() + +# Cast integer levels to avoid a ValueError +if LEVEL.isdigit(): + LEVEL = int(LEVEL) + +logging.basicConfig(format='%(name)s [%(levelname)s]: [%(module)s.%(funcName)s] %(message)s') + +LOGGER = logging.getLogger('StreamAlertAthena') +try: + LOGGER.setLevel(LEVEL) +except (TypeError, ValueError) as err: + LOGGER.setLevel('INFO') + LOGGER.error('Defaulting to INFO logging: %s', err) diff --git a/stream_alert/athena_partition_refresh/main.py b/stream_alert/athena_partition_refresh/main.py index 15669c94d..7fc1476ea 100644 --- a/stream_alert/athena_partition_refresh/main.py +++ b/stream_alert/athena_partition_refresh/main.py @@ -16,7 +16,6 @@ from collections import defaultdict from datetime import datetime import json -import logging import os import re import urllib @@ -24,11 +23,7 @@ import backoff import boto3 -logging.basicConfig( - format='%(name)s [%(levelname)s]: [%(module)s.%(funcName)s] %(message)s') -LEVEL = os.environ.get('LOGGER_LEVEL', 'INFO') -LOGGER = logging.getLogger('StreamAlertAthena') -LOGGER.setLevel(LEVEL.upper()) +from stream_alert.athena_partition_refresh import LOGGER def _backoff_handler(details): diff --git a/stream_alert/rule_processor/__init__.py b/stream_alert/rule_processor/__init__.py index eb4d1dda1..7944b86da 100644 --- a/stream_alert/rule_processor/__init__.py +++ b/stream_alert/rule_processor/__init__.py @@ -2,6 +2,8 @@ import logging import os +from stream_alert.shared import RULE_PROCESSOR_NAME as FUNCTION_NAME + # Create a package level logger to import LEVEL = os.environ.get('LOGGER_LEVEL', 'INFO').upper() diff --git a/stream_alert/rule_processor/handler.py b/stream_alert/rule_processor/handler.py index a89f17b57..f80f6e3a6 100644 --- a/stream_alert/rule_processor/handler.py +++ b/stream_alert/rule_processor/handler.py @@ -16,13 +16,13 @@ from logging import DEBUG as LOG_LEVEL_DEBUG import json -from stream_alert.rule_processor import LOGGER +from stream_alert.rule_processor import FUNCTION_NAME, LOGGER from stream_alert.rule_processor.classifier import StreamClassifier from stream_alert.rule_processor.config import load_config, load_env from stream_alert.rule_processor.payload import load_stream_payload from stream_alert.rule_processor.rules_engine import StreamRules from stream_alert.rule_processor.sink import 
StreamSink -from stream_alert.shared.metrics import Metrics +from stream_alert.shared.metrics import MetricLogger class StreamAlert(object): @@ -51,7 +51,6 @@ def __init__(self, context, enable_alert_processor=True): # Instantiate a classifier that is used for this run self.classifier = StreamClassifier(config=config) - self.metrics = Metrics('RuleProcessor', self.env['lambda_region']) self.enable_alert_processor = enable_alert_processor self._failed_record_count = 0 self._alerts = [] @@ -76,10 +75,7 @@ def run(self, event): if not records: return False - self.metrics.add_metric( - Metrics.Name.TOTAL_RECORDS, - len(records), - Metrics.Unit.COUNT) + MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TOTAL_RECORDS, len(records)) for raw_record in records: # Get the service and entity from the payload. If the service/entity @@ -101,7 +97,7 @@ def run(self, event): continue # Create the StreamPayload to use for encapsulating parsed info - payload = load_stream_payload(service, entity, raw_record, self.metrics) + payload = load_stream_payload(service, entity, raw_record) if not payload: continue @@ -109,25 +105,19 @@ def run(self, event): LOGGER.debug('Invalid record count: %d', self._failed_record_count) - self.metrics.add_metric( - Metrics.Name.FAILED_PARSES, - self._failed_record_count, - Metrics.Unit.COUNT) + MetricLogger.log_metric(FUNCTION_NAME, + MetricLogger.FAILED_PARSES, + self._failed_record_count) LOGGER.debug('%s alerts triggered', len(self._alerts)) - self.metrics.add_metric( - Metrics.Name.TRIGGERED_ALERTS, len( - self._alerts), Metrics.Unit.COUNT) + MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TRIGGERED_ALERTS, len(self._alerts)) # Check if debugging logging is on before json dumping alerts since # this can be time consuming if there are a lot of alerts if self._alerts and LOGGER.isEnabledFor(LOG_LEVEL_DEBUG): LOGGER.debug('Alerts:\n%s', json.dumps(self._alerts, indent=2)) - # Send any cached metrics to CloudWatch before returning - self.metrics.send_metrics() - return self._failed_record_count == 0 def get_alerts(self): diff --git a/stream_alert/rule_processor/payload.py b/stream_alert/rule_processor/payload.py index 6fb04c709..1bfb2ae6e 100644 --- a/stream_alert/rule_processor/payload.py +++ b/stream_alert/rule_processor/payload.py @@ -25,18 +25,17 @@ import boto3 -from stream_alert.rule_processor import LOGGER -from stream_alert.shared.metrics import Metrics +from stream_alert.rule_processor import FUNCTION_NAME, LOGGER +from stream_alert.shared.metrics import MetricLogger -def load_stream_payload(service, entity, raw_record, metrics): +def load_stream_payload(service, entity, raw_record): """Returns the right StreamPayload subclass for this service Args: service (str): service name to load class for entity (str): entity for this service raw_record (str): record raw payload data - metrics (Metrics): payload metrics """ payload_map = {'s3': S3Payload, 'sns': SnsPayload, @@ -46,7 +45,7 @@ def load_stream_payload(service, entity, raw_record, metrics): LOGGER.error('Service payload not supported: %s', service) return - return payload_map[service](raw_record=raw_record, entity=entity, metrics=metrics) + return payload_map[service](raw_record=raw_record, entity=entity) class StreamPayload(object): @@ -76,7 +75,6 @@ def __init__(self, **kwargs): """ self.raw_record = kwargs['raw_record'] self.entity = kwargs['entity'] - self.metrics = kwargs['metrics'] self.pre_parsed_record = None self._refresh_record(None) @@ -177,7 +175,7 @@ def pre_parse(self): avg_record_size, 
self.s3_object_size) - self.metrics.add_metric(Metrics.Name.TOTAL_S3_RECORDS, line_num, Metrics.Unit.COUNT) + MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TOTAL_S3_RECORDS, line_num) def _download_object(self, region, bucket, key): """Download an object from S3. @@ -216,9 +214,8 @@ def _download_object(self, region, bucket, key): total_time = time.time() - start_time LOGGER.info('Completed download in %s seconds', round(total_time, 2)) - # Publish a metric on how long this object took to download - self.metrics.add_metric( - Metrics.Name.S3_DOWNLOAD_TIME, total_time, Metrics.Unit.SECONDS) + # Log a metric on how long this object took to download + MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.S3_DOWNLOAD_TIME, total_time) return downloaded_s3_object diff --git a/stream_alert/shared/__init__.py b/stream_alert/shared/__init__.py index 5f0c03d71..82a143680 100644 --- a/stream_alert/shared/__init__.py +++ b/stream_alert/shared/__init__.py @@ -1,7 +1,12 @@ -"""Define logger for shared functionality.""" +"""Define some shared resources.""" import logging import os + +ALERT_PROCESSOR_NAME = 'alert_processor' +ATHENA_PARTITION_REFRESH_NAME = 'athena_partition_refresh' +RULE_PROCESSOR_NAME = 'rule_processor' + # Create a package level logger to import LEVEL = os.environ.get('LOGGER_LEVEL', 'INFO').upper() diff --git a/stream_alert/shared/metrics.py b/stream_alert/shared/metrics.py index a7ea35715..b8a323dce 100644 --- a/stream_alert/shared/metrics.py +++ b/stream_alert/shared/metrics.py @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. """ -from datetime import datetime -import json import os -import boto3 -from botocore.exceptions import ClientError - -from stream_alert.shared import LOGGER +from stream_alert.shared import ( + ALERT_PROCESSOR_NAME, + ATHENA_PARTITION_REFRESH_NAME, + LOGGER, + RULE_PROCESSOR_NAME +) CLUSTER = os.environ.get('CLUSTER', 'unknown_cluster') @@ -31,121 +31,78 @@ LOGGER.error('Invalid value for metric toggling, expected 0 or 1: %s', err.message) +if not ENABLE_METRICS: + LOGGER.debug('Logging of metric data is currently disabled.') + -class Metrics(object): - """Class to hold metric names and unit constants. +class MetricLogger(object): + """Class to hold metric logging to be picked up by log metric filters. This basically acts as an enum, allowing for the use of dot notation for accessing properties and avoids doing dict lookups a ton. """ - def __init__(self, lambda_function, region): - self.boto_cloudwatch = boto3.client('cloudwatch', region_name=region) - self._metric_data = [] - self._dimensions = [ - { - 'Name': 'Cluster', - 'Value': CLUSTER - }, - { - 'Name': 'Function', - 'Value': lambda_function - } - ] - - class Name(object): - """Constant metric names used for CloudWatch""" - FAILED_PARSES = 'FailedParses' - S3_DOWNLOAD_TIME = 'S3DownloadTime' - TOTAL_RECORDS = 'TotalRecords' - TOTAL_S3_RECORDS = 'TotalS3Records' - TRIGGERED_ALERTS = 'TriggeredAlerts' - - class Unit(object): - """Unit names for metrics. 
These are taken from the boto3 CloudWatch page""" - SECONDS = 'Seconds' - MICROSECONDS = 'Microseconds' - MILLISECONDS = 'Milliseconds' - BYTES = 'Bytes' - KILOBYTES = 'Kilobytes' - MEGABYTES = 'Megabytes' - GIGABYTES = 'Gigabytes' - TERABYTES = 'Terabytes' - BITS = 'Bits' - KILOBITS = 'Kilobits' - MEGABITS = 'Megabits' - GIGABITS = 'Gigabits' - TERABITS = 'Terabits' - PERCENT = 'Percent' - COUNT = 'Count' - BYTES_PER_SECOND = 'Bytes/Second' - KILOBYTES_PER_SECOND = 'Kilobytes/Second' - MEGABYTES_PER_SECOND = 'Megabytes/Second' - GIGABYTES_PER_SECOND = 'Gigabytes/Second' - TERABYTES_PER_SECOND = 'Terabytes/Second' - BITS_PER_SECOND = 'Bits/Second' - KILOBITS_PER_SECOND = 'Kilobits/Second' - MEGABITS_PER_SECOND = 'Megabits/Second' - GIGABITS_PER_SECOND = 'Gigabits/Second' - TERABITS_PER_SECOND = 'Terabits/Second' - COUNT_PER_SECOND = 'Count/Second' - NONE = 'None' - - def add_metric(self, metric_name, value, unit): - """Add a metric to the list of metrics to be sent to CloudWatch + # Constant metric names used for CloudWatch + FAILED_PARSES = 'FailedParses' + S3_DOWNLOAD_TIME = 'S3DownloadTime' + TOTAL_RECORDS = 'TotalRecords' + TOTAL_S3_RECORDS = 'TotalS3Records' + TRIGGERED_ALERTS = 'TriggeredAlerts' + + _default_filter = '{{ $.metric_name = "{}" }}' + _default_value_lookup = '$.metric_value' + + # Establish all of the available metrics for each processor. These use default + # values for the filter pattern and value lookup, created above, but can be + # overridden in special cases. The terraform generate code uses these values to + # create the actual CloudWatch metric filters that will be used for each function. + # If additional metric logging is added that does not conform to this default + # configuration, new filters & lookups should be created to handle them as well. + _available_metrics = { + ALERT_PROCESSOR_NAME: {}, # Placeholder for future alert processor metrics + ATHENA_PARTITION_REFRESH_NAME: {}, # Placeholder for future athena processor metrics + RULE_PROCESSOR_NAME: { + FAILED_PARSES: (_default_filter.format(FAILED_PARSES), _default_value_lookup), + S3_DOWNLOAD_TIME: (_default_filter.format(S3_DOWNLOAD_TIME), _default_value_lookup), + TOTAL_RECORDS: (_default_filter.format(TOTAL_RECORDS), _default_value_lookup), + TOTAL_S3_RECORDS: (_default_filter.format(TOTAL_S3_RECORDS), _default_value_lookup), + TRIGGERED_ALERTS: (_default_filter.format(TRIGGERED_ALERTS), _default_value_lookup) + } + } + + @classmethod + def log_metric(cls, lambda_function, metric_name, value): + """Log a metric via the logger so it can be picked up by a CloudWatch metric filter Args: metric_name (str): Name of metric to publish to. Choices are in `Metrics.Name` above value (num): Numeric information to post to metric. AWS expects this to be of type 'float' but will accept any numeric value that is not super small (negative) or super large. - unit (str): Unit to use for this metric. Choices are in `Metrics.Unit` above.
""" - if metric_name not in self.Name.__dict__.values(): - LOGGER.error('Metric name not defined: %s', metric_name) - return - - if unit not in self.Unit.__dict__.values(): - LOGGER.error('Metric unit not defined: %s', unit) - return - - self._metric_data.append( - { - 'MetricName': metric_name, - 'Timestamp': datetime.utcnow(), - 'Unit': unit, - 'Value': value, - "Dimensions": self._dimensions - } - ) - - def send_metrics(self): - """Public method for publishing custom metric data to CloudWatch.""" + # Do not log any metrics if they have been disabled by the user if not ENABLE_METRICS: - LOGGER.debug('Sending of metric data is currently disabled.') return - if not self._metric_data: - LOGGER.debug('No metric data to send to CloudWatch.') + if lambda_function not in cls._available_metrics: + LOGGER.error('Function \'%s\' not defined in available metrics. Options are: %s', + lambda_function, + ', '.join('\'{}\''.format(key) for key in cls._available_metrics + if cls._available_metrics[key])) return - for metric in self._metric_data: - LOGGER.debug('Sending metric data to CloudWatch: %s', metric['MetricName']) + if metric_name not in cls._available_metrics[lambda_function]: + LOGGER.error('Metric name (\'%s\') not defined for \'%s\' function. Options are: %s', + metric_name, + lambda_function, + ', '.join('\'{}\''.format(value) + for value in cls._available_metrics[lambda_function])) + return - self._put_metrics() + # Use a default format for logging this metric that will get picked up by the filters + LOGGER.info('{"metric_name": "%s", "metric_value": %s}', metric_name, value) - def _put_metrics(self): - """Protected method for publishing custom metric data to CloudWatch that - handles all of the boto3 calls and error handling. - """ - try: - self.boto_cloudwatch.put_metric_data( - Namespace='StreamAlert', MetricData=self._metric_data) - except ClientError as err: - LOGGER.exception( - 'Failed to send metric to CloudWatch. 
Error: %s\nMetric data:\n%s', - err.response, - json.dumps( - self._metric_data, - indent=2, - default=lambda d: d.isoformat())) + @classmethod + def get_available_metrics(cls): + """Return the protected dictionary of metrics for all functions""" + return cls._available_metrics diff --git a/stream_alert_cli/helpers.py b/stream_alert_cli/helpers.py index e77f87127..6229dff51 100644 --- a/stream_alert_cli/helpers.py +++ b/stream_alert_cli/helpers.py @@ -138,8 +138,11 @@ def format_lambda_test_record(test_record): return template -def create_lambda_function(function_name, region): +def create_lambda_function(function_name, region): """Helper function to create mock lambda function""" + if function_name.find(':') != -1: + function_name = function_name.split(':')[0] + boto3.client('lambda', region_name=region).create_function( FunctionName=function_name, Runtime='python2.7', @@ -154,7 +157,6 @@ def create_lambda_function(function_name, region): } ) - def encrypt_with_kms(data, region, alias): """Encrypt the given data with KMS.""" kms_client = boto3.client('kms', region_name=region) diff --git a/stream_alert_cli/terraform_generate.py b/stream_alert_cli/terraform_generate.py index 8871e963d..f37a63698 100644 --- a/stream_alert_cli/terraform_generate.py +++ b/stream_alert_cli/terraform_generate.py @@ -17,6 +17,7 @@ import json import os +from stream_alert.shared import metrics from stream_alert_cli.logger import LOGGER_CLI RESTRICTED_CLUSTER_NAMES = ('main', 'athena') @@ -280,6 +281,57 @@ def generate_stream_alert(cluster_name, cluster_dict, config): return True +def generate_cloudwatch_log_metrics(cluster_name, cluster_dict, config): + """Add the CloudWatch Metric Filters module to the Terraform cluster dict. + + Args: + cluster_name (str): The name of the currently generating cluster + cluster_dict (defaultdict): The dict containing all Terraform config for a given cluster. + config (dict): The loaded config from the 'conf/' directory + + Returns: + None: The cloudwatch metric filters are added to the stream_alert module in place + """ + enable_metrics = config['global'].get('infrastructure', {}).get('metrics', {}).get('enabled', False) + + # Do not add any metric filters if metrics are disabled + if not enable_metrics: + return + + current_metrics = metrics.MetricLogger.get_available_metrics() + + # Add metric filters for the rule and alert processor + # The funcs dict acts as a simple map to a human-readable name + funcs = {metrics.ALERT_PROCESSOR_NAME: 'AlertProcessor', + metrics.RULE_PROCESSOR_NAME: 'RuleProcessor'} + + for func in funcs: + if func not in current_metrics: + continue + + metric_prefix = funcs[func] + filter_pattern_idx, filter_value_idx = 0, 1 + + # Add filters for the cluster and aggregate + # Use a list of strings that represent the following comma separated values: + # <filter_name>,<filter_pattern>,<value> + filters = [] + for metric, settings in current_metrics[func].items(): + filters.extend([ + '{},{},{}'.format( + '{}-{}-{}'.format(metric_prefix, metric, cluster_name.upper()), + settings[filter_pattern_idx], + settings[filter_value_idx]), + '{},{},{}'.format( + '{}-{}'.format(metric_prefix, metric), + settings[filter_pattern_idx], + settings[filter_value_idx]) + ]) + + cluster_dict['module']['stream_alert_{}'.format(cluster_name)] \ ['{}_metric_filters'.format(func)] = filters + def generate_cloudwatch_monitoring(cluster_name, cluster_dict, config): """Add the CloudWatch Monitoring module to the Terraform cluster dict.
@@ -563,6 +615,8 @@ def generate_cluster(**kwargs): if not generate_stream_alert(cluster_name, cluster_dict, config): return + generate_cloudwatch_log_metrics(cluster_name, cluster_dict, config) + if modules['cloudwatch_monitoring']['enabled']: if not generate_cloudwatch_monitoring(cluster_name, cluster_dict, config): return @@ -629,6 +683,28 @@ def generate_athena(config): 'prefix': config['global']['account']['prefix'] } + if not enable_metrics: + return athena_dict + + # Check to see if there are any metrics configured for the athena function + current_metrics = metrics.MetricLogger.get_available_metrics() + if metrics.ATHENA_PARTITION_REFRESH_NAME not in current_metrics: + return athena_dict + + metric_prefix = 'AthenaRefresh' + filter_pattern_idx, filter_value_idx = 0, 1 + + # Add filters for the cluster and aggregate + # Use a list of strings that represent the following comma separated values: + # <filter_name>,<filter_pattern>,<value> + filters = ['{},{},{}'.format('{}-{}'.format(metric_prefix, metric), + settings[filter_pattern_idx], + settings[filter_value_idx]) + for metric, settings in + current_metrics[metrics.ATHENA_PARTITION_REFRESH_NAME].iteritems()] + + athena_dict['module']['stream_alert_athena']['athena_metric_filters'] = filters + return athena_dict diff --git a/stream_alert_cli/test.py b/stream_alert_cli/test.py index 23ad46de5..41c9f8d37 100644 --- a/stream_alert_cli/test.py +++ b/stream_alert_cli/test.py @@ -139,7 +139,7 @@ def _validate_test_records(self, rule_name, test_record, formatted_record, print return # Create the StreamPayload to use for encapsulating parsed info - payload = load_stream_payload(service, entity, formatted_record, Mock()) + payload = load_stream_payload(service, entity, formatted_record) if not payload: self.all_tests_passed = False return diff --git a/terraform/modules/tf_stream_alert/iam.tf b/terraform/modules/tf_stream_alert/iam.tf index 753a8e9b6..4a0001191 100644 --- a/terraform/modules/tf_stream_alert/iam.tf +++ b/terraform/modules/tf_stream_alert/iam.tf @@ -42,37 +42,6 @@ data "aws_iam_policy_document" "rule_processor_invoke_alert_proc" { } } -// IAM Role Policy: Allow the Rule Processor to publish CloudWatch metrics -resource "aws_iam_role_policy" "streamalert_rule_processor_cloudwatch_put_metric_data" { - name = "${var.prefix}_${var.cluster}_streamalert_rule_processor_put_metric_data" - role = "${aws_iam_role.streamalert_rule_processor_role.id}" - - policy = "${data.aws_iam_policy_document.put_metric_data.json}" -} - -// IAM Role Policy: Allow the Alert Processor to publish CloudWatch metrics -resource "aws_iam_role_policy" "streamalert_alert_processor_cloudwatch_put_metric_data" { - name = "${var.prefix}_${var.cluster}_streamalert_alert_processor_put_metric_data" - role = "${aws_iam_role.streamalert_alert_processor_role.id}" - - policy = "${data.aws_iam_policy_document.put_metric_data.json}" -} - -// IAM Policy Doc: Allow the Lambda Functions to publish CloudWatch metrics -data "aws_iam_policy_document" "put_metric_data" { - statement { - effect = "Allow" - - actions = [ - "cloudwatch:PutMetricData", - ] - - resources = [ - "*", - ] - } -} - // IAM Role: Alert Processor Execution Role resource "aws_iam_role" "streamalert_alert_processor_role" { name = "${var.prefix}_${var.cluster}_streamalert_alert_processor_role" diff --git a/terraform/modules/tf_stream_alert/main.tf b/terraform/modules/tf_stream_alert/main.tf index 238b31dde..8d2527de9 100644 --- a/terraform/modules/tf_stream_alert/main.tf +++ b/terraform/modules/tf_stream_alert/main.tf @@ -150,3 +150,33 @@ resource
"aws_cloudwatch_log_group" "alert_processor" { name = "/aws/lambda/${var.prefix}_${var.cluster}_streamalert_alert_processor" retention_in_days = 60 } + +// CloudWatch metric filters for the rule processor +// The split list is made up of: , , +resource "aws_cloudwatch_log_metric_filter" "rule_processor_cw_metric_filters" { + count = "${length(var.rule_processor_metric_filters)}" + name = "${element(split(",", var.rule_processor_metric_filters[count.index]), 0)}" + pattern = "${element(split(",", var.rule_processor_metric_filters[count.index]), 1)}" + log_group_name = "${aws_cloudwatch_log_group.rule_processor.name}" + + metric_transformation { + name = "${element(split(",", var.rule_processor_metric_filters[count.index]), 0)}" + namespace = "${var.namespace}" + value = "${element(split(",", var.rule_processor_metric_filters[count.index]), 2)}" + } +} + +// CloudWatch metric filters for the alert processor +// The split list is made up of: , , +resource "aws_cloudwatch_log_metric_filter" "alert_processor_cw_metric_filters" { + count = "${length(var.alert_processor_metric_filters)}" + name = "${element(split(",", var.alert_processor_metric_filters[count.index]), 0)}" + pattern = "${element(split(",", var.alert_processor_metric_filters[count.index]), 1)}" + log_group_name = "${aws_cloudwatch_log_group.alert_processor.name}" + + metric_transformation { + name = "${element(split(",", var.alert_processor_metric_filters[count.index]), 0)}" + namespace = "${var.namespace}" + value = "${element(split(",", var.alert_processor_metric_filters[count.index]), 2)}" + } +} diff --git a/terraform/modules/tf_stream_alert/variables.tf b/terraform/modules/tf_stream_alert/variables.tf index 1c6afdf1f..9b06cd835 100644 --- a/terraform/modules/tf_stream_alert/variables.tf +++ b/terraform/modules/tf_stream_alert/variables.tf @@ -86,3 +86,18 @@ variable "alert_processor_vpc_security_group_ids" { type = "list" default = [] } + +variable "rule_processor_metric_filters" { + type = "list" + default = [] +} + +variable "alert_processor_metric_filters" { + type = "list" + default = [] +} + +variable "namespace" { + type = "string" + default = "StreamAlert" +} diff --git a/terraform/modules/tf_stream_alert_athena/main.tf b/terraform/modules/tf_stream_alert_athena/main.tf index 95b40cdc2..89d06cd6e 100644 --- a/terraform/modules/tf_stream_alert_athena/main.tf +++ b/terraform/modules/tf_stream_alert_athena/main.tf @@ -97,3 +97,18 @@ resource "aws_cloudwatch_log_group" "athena" { name = "/aws/lambda/${var.prefix}_streamalert_athena_partition_refresh" retention_in_days = 60 } + +// CloudWatch metric filters for the athena partition refresh function +// The split list is made up of: , , +resource "aws_cloudwatch_log_metric_filter" "athena_partition_refresh_cw_metric_filters" { + count = "${length(var.athena_metric_filters)}" + name = "${element(split(",", var.athena_metric_filters[count.index]), 0)}" + pattern = "${element(split(",", var.athena_metric_filters[count.index]), 1)}" + log_group_name = "${aws_cloudwatch_log_group.athena.name}" + + metric_transformation { + name = "${element(split(",", var.athena_metric_filters[count.index]), 0)}" + namespace = "${var.namespace}" + value = "${element(split(",", var.athena_metric_filters[count.index]), 2)}" + } +} diff --git a/terraform/modules/tf_stream_alert_athena/variables.tf b/terraform/modules/tf_stream_alert_athena/variables.tf index b5180a446..b7e45c3fd 100644 --- a/terraform/modules/tf_stream_alert_athena/variables.tf +++ 
b/terraform/modules/tf_stream_alert_athena/variables.tf @@ -46,3 +46,13 @@ variable "refresh_interval" { variable "enable_metrics" { default = false } + +variable "athena_metric_filters" { + type = "list" + default = [] +} + +variable "namespace" { + type = "string" + default = "StreamAlert" +} diff --git a/tests/unit/conf/outputs.json b/tests/unit/conf/outputs.json index 9de1ecb20..756397d2b 100644 --- a/tests/unit/conf/outputs.json +++ b/tests/unit/conf/outputs.json @@ -3,7 +3,8 @@ "unit_test_bucket": "unit.test.bucket.name" }, "aws-lambda": { - "unit_test_lambda": "unit_test_function" + "unit_test_lambda": "unit_test_function", + "unit_test_lambda_qual": "unit_test_function:production" }, "pagerduty": [ "unit_test_pagerduty" diff --git a/tests/unit/stream_alert_alert_processor/test_main.py b/tests/unit/stream_alert_alert_processor/test_main.py index 984ef9239..b4a8fc9d6 100644 --- a/tests/unit/stream_alert_alert_processor/test_main.py +++ b/tests/unit/stream_alert_alert_processor/test_main.py @@ -16,15 +16,18 @@ # pylint: disable=protected-access from collections import OrderedDict import json +import os -from mock import mock_open, patch +from mock import call, mock_open, patch from nose.tools import ( assert_equal, assert_is_instance, assert_list_equal, - assert_true + assert_true, + with_setup ) +import stream_alert.alert_processor as ap from stream_alert.alert_processor.main import _load_output_config, _sort_dict, handler from tests.unit.stream_alert_alert_processor import FUNCTION_NAME, REGION from tests.unit.stream_alert_alert_processor.helpers import get_alert, get_mock_context @@ -177,3 +180,39 @@ def test_running_exception_occurred(creds_mock, dispatch_mock, config_mock, url_ 'An error occurred while sending alert ' 'to %s:%s: %s. alert:\n%s', 'slack', 'unit_test_channel', err, json.dumps(alert, indent=2)) + +def _teardown_env(): + """Helper method to reset environment variables""" + if 'LOGGER_LEVEL' in os.environ: + del os.environ['LOGGER_LEVEL'] + + +@with_setup(setup=None, teardown=_teardown_env) +@patch('stream_alert.alert_processor.LOGGER.error') +def test_init_logging_bad(log_mock): + """Alert Processor Init - Logging, Bad Level""" + level = 'IFNO' + + os.environ['LOGGER_LEVEL'] = level + + # Force reload the alert_processor package to trigger the init + reload(ap) + + message = str(call('Defaulting to INFO logging: %s', + ValueError('Unknown level: \'IFNO\'',))) + + assert_equal(str(log_mock.call_args_list[0]), message) + + +@with_setup(setup=None, teardown=_teardown_env) +@patch('stream_alert.alert_processor.LOGGER.setLevel') +def test_init_logging_int_level(log_mock): + """Alert Processor Init - Logging, Integer Level""" + level = '10' + + os.environ['LOGGER_LEVEL'] = level + + # Force reload the alert_processor package to trigger the init + reload(ap) + + log_mock.assert_called_with(10) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs.py b/tests/unit/stream_alert_alert_processor/test_outputs.py index f09571429..fc935dd5f 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs.py @@ -587,9 +587,9 @@ def test_locals(self): assert_equal(self.__dispatcher.__class__.__name__, 'LambdaOutput') assert_equal(self.__dispatcher.__service__, self.__service) - def _setup_dispatch(self): + def _setup_dispatch(self, alt_descriptor=''): """Helper for setting up LambdaOutput dispatch""" - function_name = CONFIG[self.__service][self.__descriptor] + function_name = 
CONFIG[self.__service][alt_descriptor or self.__descriptor] create_lambda_function(function_name, REGION) return get_alert() @@ -603,3 +603,15 @@ def test_dispatch(self, log_mock): alert=alert) log_mock.assert_called_with('Successfully sent alert to %s', self.__service) + + @mock_lambda + @patch('logging.Logger.info') + def test_dispatch_with_qualifier(self, log_mock): + """LambdaOutput dispatch with qualifier""" + alt_descriptor = '{}_qual'.format(self.__descriptor) + alert = self._setup_dispatch(alt_descriptor) + self.__dispatcher.dispatch(descriptor=alt_descriptor, + rule_name='rule_name', + alert=alert) + + log_mock.assert_called_with('Successfully sent alert to %s', self.__service) diff --git a/tests/unit/stream_alert_athena_partition_refresh/test_main.py b/tests/unit/stream_alert_athena_partition_refresh/test_main.py index 2923ce551..18209cc75 100644 --- a/tests/unit/stream_alert_athena_partition_refresh/test_main.py +++ b/tests/unit/stream_alert_athena_partition_refresh/test_main.py @@ -18,9 +18,10 @@ # specific test: nosetests -v -s tests/unit/file.py:TestStreamPayload.test_name from datetime import datetime import json +import os import boto3 -from mock import patch +from mock import call, patch from moto import mock_sqs from nose.tools import ( assert_equal, @@ -28,9 +29,11 @@ assert_is_none, assert_true, raises, - nottest + nottest, + with_setup ) +import stream_alert.athena_partition_refresh as apr from stream_alert.athena_partition_refresh.main import ( _backoff_handler, _load_config, @@ -582,3 +585,40 @@ def test_run_athena_query(self): assert_true(query_success) assert_equal(query_results['ResultSet']['Rows'], [{'Data': [{'test': 'test'}]}]) + + +def _teardown_env(): + """Helper method to reset environment variables""" + if 'LOGGER_LEVEL' in os.environ: + del os.environ['LOGGER_LEVEL'] + + +@with_setup(setup=None, teardown=_teardown_env) +@patch('stream_alert.athena_partition_refresh.LOGGER.error') +def test_init_logging_bad(log_mock): + """Athena Partition Refresh Init - Logging, Bad Level""" + level = 'IFNO' + + os.environ['LOGGER_LEVEL'] = level + + # Force reload the athena_partition_refresh package to trigger the init + reload(apr) + + message = str(call('Defaulting to INFO logging: %s', + ValueError('Unknown level: \'IFNO\'',))) + + assert_equal(str(log_mock.call_args_list[0]), message) + + +@with_setup(setup=None, teardown=_teardown_env) +@patch('stream_alert.athena_partition_refresh.LOGGER.setLevel') +def test_init_logging_int_level(log_mock): + """Athena Partition Refresh Init - Logging, Integer Level""" + level = '10' + + os.environ['LOGGER_LEVEL'] = level + + # Force reload the athena_partition_refresh package to trigger the init + reload(apr) + + log_mock.assert_called_with(10) diff --git a/tests/unit/stream_alert_cli/test_terraform_generate.py b/tests/unit/stream_alert_cli/test_terraform_generate.py index ff57ed369..58dd8b832 100644 --- a/tests/unit/stream_alert_cli/test_terraform_generate.py +++ b/tests/unit/stream_alert_cli/test_terraform_generate.py @@ -632,7 +632,8 @@ def test_generate_athena(self): 'unit-testing-2.streamalerts' ], 'prefix': 'unit-testing', - 'refresh_interval': 'rate(10 minutes)' + 'refresh_interval': 'rate(10 minutes)', + 'athena_metric_filters': [] } } } diff --git a/tests/unit/stream_alert_rule_processor/test_classifier.py b/tests/unit/stream_alert_rule_processor/test_classifier.py index 8af118e15..4b2f64e1b 100644 --- a/tests/unit/stream_alert_rule_processor/test_classifier.py +++
b/tests/unit/stream_alert_rule_processor/test_classifier.py @@ -44,7 +44,7 @@ def setup(self): def _prepare_and_classify_payload(self, service, entity, raw_record): """Helper method to return a preparsed and classified payload""" - payload = load_stream_payload(service, entity, raw_record, None) + payload = load_stream_payload(service, entity, raw_record) payload = payload.pre_parse().next() self.classifier.load_sources(service, entity) @@ -253,7 +253,7 @@ def test_parse_convert_fail(self, log_mock): }) raw_record = make_kinesis_raw_record(entity, kinesis_data) - payload = load_stream_payload(service, entity, raw_record, None) + payload = load_stream_payload(service, entity, raw_record) payload = payload.pre_parse().next() result = self.classifier._parse(payload) @@ -279,7 +279,7 @@ def test_mult_schema_match_success(self): service, entity = 'kinesis', 'test_stream_2' raw_record = make_kinesis_raw_record(entity, kinesis_data) - payload = load_stream_payload(service, entity, raw_record, None) + payload = load_stream_payload(service, entity, raw_record) self.classifier.load_sources(service, entity) @@ -308,7 +308,7 @@ def test_mult_schema_match_failure(self, log_mock): service, entity = 'kinesis', 'test_stream_2' raw_record = make_kinesis_raw_record(entity, kinesis_data) - payload = load_stream_payload(service, entity, raw_record, None) + payload = load_stream_payload(service, entity, raw_record) self.classifier.load_sources(service, entity) @@ -337,7 +337,7 @@ def test_mult_schema_match(self, log_mock): service, entity = 'kinesis', 'test_stream_2' raw_record = make_kinesis_raw_record(entity, kinesis_data) - payload = load_stream_payload(service, entity, raw_record, None) + payload = load_stream_payload(service, entity, raw_record) self.classifier.load_sources(service, entity) diff --git a/tests/unit/stream_alert_rule_processor/test_handler.py b/tests/unit/stream_alert_rule_processor/test_handler.py index 876fbb0ff..7bff7947f 100644 --- a/tests/unit/stream_alert_rule_processor/test_handler.py +++ b/tests/unit/stream_alert_rule_processor/test_handler.py @@ -17,7 +17,7 @@ import base64 import logging -from mock import call, Mock, mock_open, patch +from mock import call, mock_open, patch from nose.tools import ( assert_equal, assert_false, @@ -32,7 +32,6 @@ from tests.unit.stream_alert_rule_processor.test_helpers import get_mock_context, get_valid_event -@patch('stream_alert.rule_processor.handler.Metrics.send_metrics', Mock()) class TestStreamAlert(object): """Test class for StreamAlert class""" @@ -52,7 +51,7 @@ def test_run_no_records(self): @staticmethod @raises(ConfigError) - def test_run_config_error(_): + def test_run_config_error(): """StreamAlert Class - Run, Config Error""" mock = mock_open(read_data='non-json string that will raise an exception') with patch('__builtin__.open', mock): @@ -115,8 +114,7 @@ def test_run_load_payload_bad( load_payload_mock.assert_called_with( 'lambda', 'entity', - 'record', - self.__sa_handler.metrics + 'record' ) @patch('stream_alert.rule_processor.handler.StreamRules.process') @@ -179,10 +177,9 @@ def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock): sink_mock.assert_called_with(['success!!']) @patch('logging.Logger.debug') - @patch('stream_alert.shared.metrics.Metrics.send_metrics') @patch('stream_alert.rule_processor.handler.StreamRules.process') @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity') - def test_run_debug_log_alert(self, extract_mock, rules_mock, _, log_mock): + def 
test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock): """StreamAlert Class - Run, Debug Log Alert""" extract_mock.return_value = ('kinesis', 'unit_test_default_stream') rules_mock.return_value = ['success!!'] diff --git a/tests/unit/stream_alert_rule_processor/test_helpers.py b/tests/unit/stream_alert_rule_processor/test_helpers.py index 2e1163654..53625fc4a 100644 --- a/tests/unit/stream_alert_rule_processor/test_helpers.py +++ b/tests/unit/stream_alert_rule_processor/test_helpers.py @@ -85,7 +85,7 @@ def get_valid_event(count=1): def load_and_classify_payload(config, service, entity, raw_record): """Return a loaded and classified payload.""" # prepare the payloads - payload = load_stream_payload(service, entity, raw_record, None) + payload = load_stream_payload(service, entity, raw_record) payload = payload.pre_parse().next() classifier = StreamClassifier(config=config) diff --git a/tests/unit/stream_alert_rule_processor/test_payload.py b/tests/unit/stream_alert_rule_processor/test_payload.py index acaa4f2ee..aa1a5ca0e 100644 --- a/tests/unit/stream_alert_rule_processor/test_payload.py +++ b/tests/unit/stream_alert_rule_processor/test_payload.py @@ -20,7 +20,7 @@ import os import tempfile -from mock import call, Mock, patch +from mock import call, patch from nose.tools import ( assert_equal, assert_false, @@ -46,7 +46,7 @@ def teardown_s3(): def test_load_payload_valid(): """StreamPayload - Loading Stream Payload, Valid""" - payload = load_stream_payload('s3', 'entity', 'record', None) + payload = load_stream_payload('s3', 'entity', 'record') assert_is_instance(payload, S3Payload) @@ -54,14 +54,14 @@ def test_load_payload_valid(): @patch('logging.Logger.error') def test_load_payload_invalid(log_mock): """StreamPayload - Loading Stream Payload, Invalid""" - load_stream_payload('blah', 'entity', 'record', None) + load_stream_payload('blah', 'entity', 'record') log_mock.assert_called_with('Service payload not supported: %s', 'blah') def test_repr_string(): """StreamPayload - String Representation""" - s3_payload = load_stream_payload('s3', 'entity', 'record', None) + s3_payload = load_stream_payload('s3', 'entity', 'record') # Set some values that are different than the defaults s3_payload.type = 'unit_type' @@ -77,28 +77,28 @@ def test_repr_string(): def test_get_service_kinesis(): """StreamPayload - Get Service, Kinesis""" - kinesis_payload = load_stream_payload('kinesis', 'entity', 'record', None) + kinesis_payload = load_stream_payload('kinesis', 'entity', 'record') assert_equal(kinesis_payload.service(), 'kinesis') def test_get_service_s3(): """StreamPayload - Get Service, S3""" - s3_payload = load_stream_payload('s3', 'entity', 'record', None) + s3_payload = load_stream_payload('s3', 'entity', 'record') assert_equal(s3_payload.service(), 's3') def test_get_service_sns(): """StreamPayload - Get Service, SNS""" - sns_payload = load_stream_payload('sns', 'entity', 'record', None) + sns_payload = load_stream_payload('sns', 'entity', 'record') assert_equal(sns_payload.service(), 'sns') def test_refresh_record(): """StreamPayload - Refresh Record""" - s3_payload = load_stream_payload('s3', 'entity', 'record', None) + s3_payload = load_stream_payload('s3', 'entity', 'record') # Set some values that are different than the defaults s3_payload.type = 'unit_type' @@ -121,7 +121,7 @@ def test_pre_parse_kinesis(log_mock): kinesis_data = json.dumps({'test': 'value'}) entity = 'unit_test_entity' raw_record = make_kinesis_raw_record(entity, kinesis_data) - kinesis_payload = 
load_stream_payload('kinesis', entity, raw_record, Mock()) + kinesis_payload = load_stream_payload('kinesis', entity, raw_record) kinesis_payload = kinesis_payload.pre_parse().next() @@ -139,7 +139,7 @@ def test_pre_parse_sns(log_mock): """SNSPayload - Pre Parse""" sns_data = json.dumps({'test': 'value'}) raw_record = make_sns_raw_record('unit_topic', sns_data) - sns_payload = load_stream_payload('sns', 'entity', raw_record, Mock()) + sns_payload = load_stream_payload('sns', 'entity', raw_record) sns_payload = sns_payload.pre_parse().next() @@ -159,7 +159,7 @@ def test_pre_parse_s3(s3_mock, *_): s3_mock.side_effect = [((0, records[0]), (1, records[1]))] raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name') - s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record, Mock()) + s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record) for index, record in enumerate(s3_payload.pre_parse()): assert_equal(record.pre_parsed_record, records[index]) @@ -183,7 +183,7 @@ def test_pre_parse_s3_debug(s3_mock, log_mock, _): s3_mock.side_effect = [((100, records[0]), (200, records[1]))] raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name') - s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record, Mock()) + s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record) S3Payload.s3_object_size = 350 _ = [_ for _ in s3_payload.pre_parse()] @@ -208,7 +208,7 @@ def test_pre_parse_s3_debug(s3_mock, log_mock, _): def test_s3_object_too_large(): """S3Payload - S3ObjectSizeError, Object too Large""" raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name') - s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record, None) + s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record) S3Payload.s3_object_size = (128 * 1024 * 1024) + 10 s3_payload._download_object('region', 'bucket', 'key') @@ -219,7 +219,7 @@ def test_s3_object_too_large(): def test_get_object(log_mock, _): """S3Payload - Get S3 Info from Raw Record""" raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name') - s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record, None) + s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record) s3_payload._get_object() log_mock.assert_called_with( @@ -234,7 +234,7 @@ def test_get_object(log_mock, _): def test_s3_download_object(log_mock, *_): """S3Payload - Download Object""" raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name') - s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record, Mock()) + s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record) s3_payload._download_object('us-east-1', 'unit_bucket_name', 'unit_key_name') assert_equal(log_mock.call_args_list[1][0][0], 'Completed download in %s seconds') @@ -247,7 +247,7 @@ def test_s3_download_object(log_mock, *_): def test_s3_download_object_mb(log_mock, *_): """S3Payload - Download Object, Size in MB""" raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name') - s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record, Mock()) + s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record) S3Payload.s3_object_size = (127.8 * 1024 * 1024) s3_payload._download_object('us-east-1', 'unit_bucket_name', 'unit_key_name') diff --git a/tests/unit/stream_alert_shared/test_metrics.py b/tests/unit/stream_alert_shared/test_metrics.py index 951a34884..2584dfd82 100644 --- a/tests/unit/stream_alert_shared/test_metrics.py +++ b/tests/unit/stream_alert_shared/test_metrics.py @@ 
-16,27 +16,23 @@ # pylint: disable=no-self-use,protected-access import os -from botocore.exceptions import ClientError -from mock import call, Mock, patch +from mock import call, patch from nose.tools import assert_equal -import stream_alert.shared as shared -from tests.unit.stream_alert_rule_processor import REGION +from stream_alert import shared class TestMetrics(object): """Test class for Metrics class""" - def __init__(self): - self.__metrics = None - def setup(self): """Setup before each method""" - self.__metrics = shared.metrics.Metrics('TestFunction', REGION) + os.environ['ENABLE_METRICS'] = '1' + # Force reload the metrics package to trigger env var loading + reload(shared.metrics) def teardown(self): """Teardown after each method""" - self.__metrics = None if 'ENABLE_METRICS' in os.environ: del os.environ['ENABLE_METRICS'] @@ -44,71 +40,41 @@ def teardown(self): del os.environ['LOGGER_LEVEL'] @patch('logging.Logger.error') - def test_invalid_metric_name(self, log_mock): - """Metrics - Invalid Name""" - self.__metrics.add_metric('bad metric name', 100, 'Seconds') + def test_invalid_metric_function(self, log_mock): + """Metrics - Invalid Function Name""" + shared.metrics.MetricLogger.log_metric('rule_procesor', '', '') - log_mock.assert_called_with('Metric name not defined: %s', 'bad metric name') + log_mock.assert_called_with( + 'Function \'%s\' not defined in available metrics. ' + 'Options are: %s', 'rule_procesor', '\'rule_processor\'') @patch('logging.Logger.error') - def test_invalid_metric_unit(self, log_mock): - """Metrics - Invalid Unit Type""" - self.__metrics.add_metric('FailedParses', 100, 'Total') + def test_invalid_metric_name(self, log_mock): + """Metrics - Invalid Metric Name""" + shared.metrics.MetricLogger.log_metric('rule_processor', 'FailedParsed', '') - log_mock.assert_called_with('Metric unit not defined: %s', 'Total') + assert_equal(log_mock.call_args[0][0], 'Metric name (\'%s\') not defined for ' + '\'%s\' function. Options are: %s') + assert_equal(log_mock.call_args[0][1], 'FailedParsed') + assert_equal(log_mock.call_args[0][2], 'rule_processor') - @patch('stream_alert.shared.metrics.Metrics._put_metrics') - def test_valid_metric(self, metric_mock): + @patch('logging.Logger.info') + def test_valid_metric(self, log_mock): """Metrics - Valid Metric""" - # Enable the metrics - os.environ['ENABLE_METRICS'] = '1' - - # Force reload the metrics package to trigger constant loading - reload(shared.metrics) - - self.__metrics.add_metric('FailedParses', 100, 'Count') - self.__metrics.send_metrics() - - metric_mock.assert_called() - - @patch('logging.Logger.exception') - def test_boto_failed(self, log_mock): - """Metrics - Boto Call Failed""" - self.__metrics.boto_cloudwatch = Mock() - - err_response = {'Error': {'Code': 100}} + shared.metrics.MetricLogger.log_metric('rule_processor', 'FailedParses', 100) - # Add ClientError side_effect to mock - self.__metrics.boto_cloudwatch.put_metric_data.side_effect = ClientError( - err_response, 'operation') - - self.__metrics._metric_data.append({'test': 'info'}) - self.__metrics._put_metrics() - - log_mock.assert_called_with( - 'Failed to send metric to CloudWatch. 
Error: %s\nMetric data:\n%s', - err_response, - '[\n {\n "test": "info"\n }\n]') + log_mock.assert_called_with('{"metric_name": "%s", "metric_value": %s}', + 'FailedParses', 100) @patch('logging.Logger.debug') - def test_no_metrics_to_send(self, log_mock): - """Metrics - No Metrics To Send""" - # Enable the metrics - os.environ['ENABLE_METRICS'] = '1' + def test_disabled_metrics(self, log_mock): + """Metrics - Metrics Disabled""" + os.environ['ENABLE_METRICS'] = '0' # Force reload the metrics package to trigger constant loading reload(shared.metrics) - self.__metrics.send_metrics() - - log_mock.assert_called_with('No metric data to send to CloudWatch.') - - @patch('logging.Logger.debug') - def test_disabled_metrics(self, log_mock): - """Metrics - Metrics Disabled""" - self.__metrics.send_metrics() - - log_mock.assert_called_with('Sending of metric data is currently disabled.') + log_mock.assert_called_with('Logging of metric data is currently disabled.') @patch('logging.Logger.error') def test_disabled_metrics_error(self, log_mock): @@ -123,7 +89,7 @@ def test_disabled_metrics_error(self, log_mock): 'invalid literal for int() with ' 'base 10: \'bad\'') - @patch('stream_alert.shared.LOGGER.error') + @patch('logging.Logger.error') def test_init_logging_bad(self, log_mock): """Shared Init - Logging, Bad Level""" level = 'IFNO' @@ -138,7 +104,7 @@ def test_init_logging_bad(self, log_mock): assert_equal(str(log_mock.call_args_list[0]), message) - @patch('stream_alert.shared.LOGGER.setLevel') + @patch('logging.Logger.setLevel') def test_init_logging_int_level(self, log_mock): """Shared Init - Logging, Integer Level""" level = '10'
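
For reference, a minimal sketch of how the pieces in this patch fit together: MetricLogger.log_metric() writes a structured JSON line to the function's CloudWatch Log group, and the Terraform-generated aws_cloudwatch_log_metric_filter resources (built from the default '{ $.metric_name = "..." }' pattern and the '$.metric_value' lookup) turn matching lines into custom CloudWatch metrics. The metric value below is an arbitrary example; only 'TriggeredAlerts' and the filter settings come from the patch.

    # Illustrative sketch only -- mirrors the log format and default filter settings above.
    import json

    metric_name, value = 'TriggeredAlerts', 5  # value is an arbitrary example

    # What the rule processor writes to its log group via MetricLogger.log_metric():
    log_line = '{"metric_name": "%s", "metric_value": %s}' % (metric_name, value)

    # The default filter pattern and value lookup generated for this metric; these feed
    # the 'pattern' and metric_transformation 'value' fields in the Terraform resources.
    filter_pattern = '{ $.metric_name = "%s" }' % metric_name
    value_lookup = '$.metric_value'

    # The filter pattern matches the JSON log line and the lookup extracts 5 as the data point.
    assert json.loads(log_line)['metric_value'] == value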