Commit
updating all unit tests
ryandeivert committed Sep 19, 2018
1 parent 59f3058 commit 7e115de
Showing 5 changed files with 27 additions and 134 deletions.
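
The unifying change across these files: tests that previously cached the shared LOGGER level, raised it to DEBUG, and restored it afterward now patch a module-level LOGGER_DEBUG_ENABLED flag instead. A minimal sketch of the new pattern, assuming only what the diff shows (the handler module exposes that flag; the test name here is illustrative):

from mock import patch
from nose.tools import assert_true

import stream_alert.rule_processor.handler as handler


def test_debug_flag_patching():
    # patch.object flips the module flag for the duration of the with
    # block and restores the prior value on exit, replacing the old
    # cache/setLevel/restore dance around the shared LOGGER.
    with patch.object(handler, 'LOGGER_DEBUG_ENABLED', True):
        assert_true(handler.LOGGER_DEBUG_ENABLED)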
25 changes: 1 addition & 24 deletions tests/unit/stream_alert_athena_partition_refresh/test_main.py
@@ -17,7 +17,7 @@
 import json
 import os
 
-from mock import call, Mock, patch
+from mock import Mock, patch
 from nose.tools import assert_equal, assert_raises, assert_true, raises
 
 from stream_alert.athena_partition_refresh.main import AthenaRefresher, AthenaRefreshError
@@ -26,29 +26,6 @@
 from tests.unit.helpers.aws_mocks import MockAthenaClient
 
 
-@patch('logging.Logger.error')
-def test_init_logging_bad(log_mock):
-    """Athena Partition Refresh Init - Logging, Bad Level"""
-    level = 'IFNO'
-    with patch.dict(os.environ, {'LOGGER_LEVEL': level}):
-        import stream_alert.athena_partition_refresh
-        reload(stream_alert.athena_partition_refresh)
-
-    message = str(call('Defaulting to INFO logging: %s',
-                       ValueError('Unknown level: \'IFNO\'',)))
-
-    assert_equal(str(log_mock.call_args_list[0]), message)
-
-
-@patch('stream_alert.athena_partition_refresh.LOGGER.setLevel')
-def test_init_logging_int_level(log_mock):
-    """Athena Partition Refresh Init - Logging, Integer Level"""
-    with patch.dict(os.environ, {'LOGGER_LEVEL': '10'}):
-        import stream_alert.athena_partition_refresh
-        reload(stream_alert.athena_partition_refresh)
-    log_mock.assert_called_with(10)
-
-
 # Without this time.sleep patch, backoff performs sleep
 # operations and drastically slows down testing
 @patch('time.sleep', Mock())
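
The comment kept in the hunk above explains the time.sleep patch; a standalone sketch of that idea (the test body here is hypothetical):

import time

from mock import Mock, patch


# backoff-decorated functions call time.sleep between retries, so the
# test module stubs it out with a no-op Mock to keep the suite fast.
@patch('time.sleep', Mock())
def test_retry_without_delay():
    time.sleep(5)  # swallowed by the Mock; returns immediately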
25 changes: 7 additions & 18 deletions tests/unit/stream_alert_rule_processor/test_handler.py
@@ -16,7 +16,6 @@
 # pylint: disable=protected-access,attribute-defined-outside-init
 import base64
 import json
-import logging
 import os
 
 from mock import ANY, call, Mock, patch
@@ -29,8 +28,7 @@
 )
 import boto3
 
-from stream_alert.rule_processor import LOGGER
-from stream_alert.rule_processor.handler import StreamAlert
+import stream_alert.rule_processor.handler as handler
 from stream_alert.rule_processor.threat_intel import StreamThreatIntel
 from stream_alert.shared.alert import Alert
 from stream_alert.shared.config import load_config
@@ -55,7 +53,7 @@ class TestStreamAlert(object):
                              'AWS_DEFAULT_REGION': 'us-east-1'})
     def setup(self):
         """Setup before each method"""
-        self.__sa_handler = StreamAlert(get_mock_context())
+        self.__sa_handler = handler.StreamAlert(get_mock_context())
 
     def test_run_no_records(self):
         """StreamAlert Class - Run, No Records"""
@@ -200,18 +198,9 @@ def test_run_debug_log_alert(self, extract_mock, rules_mock, alerts_mock, log_mo
         rules_mock.return_value = ([Alert('rule_name', {}, {'output'})], ['normalized_records'])
         alerts_mock.return_value = []
 
-        # Cache the logger level
-        log_level = LOGGER.getEffectiveLevel()
-
-        # Increase the logger level to debug
-        LOGGER.setLevel(logging.DEBUG)
-
-        self.__sa_handler.run(get_valid_event())
-
-        # Reset the logger level
-        LOGGER.setLevel(log_level)
-
-        log_mock.assert_called_with('Alerts:\n%s', ANY)
+        with patch.object(handler, 'LOGGER_DEBUG_ENABLED', True):
+            self.__sa_handler.run(get_valid_event())
+            log_mock.assert_called_with('Alerts:\n%s', ANY)
 
     @patch('stream_alert.rule_processor.handler.load_stream_payload')
     @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
@@ -285,7 +274,7 @@ def match_ipaddress(_):  # pylint: disable=unused-variable
                                             region='us-east-1')
         mock_query.return_value = ([], [])
 
-        sa_handler = StreamAlert(get_mock_context())
+        sa_handler = handler.StreamAlert(get_mock_context())
         event = {
             'account': 123456,
             'region': '123456123456',
@@ -327,7 +316,7 @@ def match_ip_address(_):  # pylint: disable=unused-variable
                                             region='us-east-1')
         mock_query.return_value = ([], [])
 
-        sa_handler = StreamAlert(get_mock_context())
+        sa_handler = handler.StreamAlert(get_mock_context())
         event = {
             'account': 123456,
             'region': '123456123456',
44 changes: 0 additions & 44 deletions tests/unit/stream_alert_rule_processor/test_init.py

This file was deleted.

44 changes: 18 additions & 26 deletions tests/unit/stream_alert_rule_processor/test_payload.py
@@ -16,7 +16,6 @@
 # pylint: disable=protected-access
 import json
 import gzip
-import logging
 import os
 import tempfile
 
@@ -32,7 +31,7 @@
 )
 from pyfakefs import fake_filesystem_unittest
 
-from stream_alert.rule_processor import LOGGER
+import stream_alert.rule_processor.payload as payload
 from stream_alert.rule_processor.payload import load_stream_payload, S3ObjectSizeError, S3Payload
 from tests.unit.stream_alert_rule_processor.test_helpers import (
     make_kinesis_raw_record,
@@ -173,36 +172,29 @@ def test_pre_parse_s3(s3_mock, *_):
 @patch('stream_alert.rule_processor.payload.S3Payload._read_downloaded_s3_object')
 def test_pre_parse_s3_debug(s3_mock, log_mock, _):
     """S3Payload - Pre Parse, Debug On"""
-    # Cache the logger level
-    log_level = LOGGER.getEffectiveLevel()
+    with patch.object(payload, 'LOGGER_DEBUG_ENABLED', True):
 
-    # Increase the logger level to debug
-    LOGGER.setLevel(logging.DEBUG)
+        records = ['_first_line_test_' * 10,
+                   '_second_line_test_' * 10]
 
-    records = ['_first_line_test_' * 10,
-               '_second_line_test_' * 10]
+        s3_mock.side_effect = [((100, records[0]), (200, records[1]))]
 
-    s3_mock.side_effect = [((100, records[0]), (200, records[1]))]
+        raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name')
+        s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record)
+        S3Payload.s3_object_size = 350
 
-    raw_record = make_s3_raw_record('unit_bucket_name', 'unit_key_name')
-    s3_payload = load_stream_payload('s3', 'unit_key_name', raw_record)
-    S3Payload.s3_object_size = 350
-
-    _ = [_ for _ in s3_payload.pre_parse()]
-
-    calls = [
-        call('Processed %s S3 records out of an approximate total of %s '
-             '(average record size: %s bytes, total size: %s bytes)',
-             100, 350, 1, 350),
-        call('Processed %s S3 records out of an approximate total of %s '
-             '(average record size: %s bytes, total size: %s bytes)',
-             200, 350, 1, 350)
-    ]
+        _ = [_ for _ in s3_payload.pre_parse()]
 
-    log_mock.assert_has_calls(calls)
+        calls = [
+            call('Processed %s S3 records out of an approximate total of %s '
+                 '(average record size: %s bytes, total size: %s bytes)',
+                 100, 350, 1, 350),
+            call('Processed %s S3 records out of an approximate total of %s '
+                 '(average record size: %s bytes, total size: %s bytes)',
+                 200, 350, 1, 350)
+        ]
 
-    # Reset the logger level and stop the patchers
-    LOGGER.setLevel(log_level)
+        log_mock.assert_has_calls(calls)
 
 
 @with_setup(setup=None, teardown=teardown_s3)
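
The reworked test above asserts a sequence of parameterized debug calls; a self-contained toy of that assertion style (the logger name and messages here are hypothetical):

import logging

from mock import call, patch

LOGGER = logging.getLogger('toy')


def process_batches():
    for count in (100, 200):
        LOGGER.debug('Processed %s records', count)


@patch('logging.Logger.debug')
def test_batch_logging(log_mock):
    process_batches()
    # assert_has_calls checks that these calls occurred in order,
    # matching both the format string and its arguments.
    log_mock.assert_has_calls([
        call('Processed %s records', 100),
        call('Processed %s records', 200),
    ])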
23 changes: 1 addition & 22 deletions tests/unit/stream_alert_shared/test_metrics.py
@@ -16,7 +16,7 @@
 # pylint: disable=no-self-use,protected-access
 import os
 
-from mock import call, patch
+from mock import patch
 from nose.tools import assert_equal
 
 from stream_alert import shared
@@ -69,24 +69,3 @@ def test_disabled_metrics_error(self, log_mock):
                                  'expected 0 or 1: %s',
                                  'invalid literal for int() with '
                                  'base 10: \'bad\'')
-
-    @patch('logging.Logger.error')
-    def test_init_logging_bad(self, log_mock):
-        """Shared Init - Logging, Bad Level"""
-        with patch.dict('os.environ', {'LOGGER_LEVEL': 'IFNO'}):
-            # Force reload the shared package to trigger the init
-            reload(shared)
-
-        message = str(call('Defaulting to INFO logging: %s',
-                           ValueError('Unknown level: \'IFNO\'',)))
-
-        assert_equal(str(log_mock.call_args_list[0]), message)
-
-    @patch('logging.Logger.setLevel')
-    def test_init_logging_int_level(self, log_mock):
-        """Shared Init - Logging, Integer Level"""
-        with patch.dict('os.environ', {'LOGGER_LEVEL': '10'}):
-            # Force reload the shared package to trigger the init
-            reload(shared)
-
-        log_mock.assert_called_with(10)
