diff --git a/pshtt/cli.py b/pshtt/cli.py index 13847fde..b43ad6f5 100755 --- a/pshtt/cli.py +++ b/pshtt/cli.py @@ -28,17 +28,68 @@ from . import pshtt from . import utils from . import __version__ +from .utils import smart_open +import csv import docopt import logging import sys +import pytablewriter + + +def to_csv(results, out_filename): + utils.debug("Opening CSV file: {}".format(out_filename)) + with smart_open(out_filename) as out_file: + writer = csv.writer(out_file) + + # Write out header + writer.writerow(pshtt.HEADERS) + + # Write out the row data as it completes + for result in results: + row = [result[header] for header in pshtt.HEADERS] + writer.writerow(row) + + logging.warn("Wrote results to %s.", out_filename) + + +def to_json(results, out_filename): + # Generate (yield) all the results before exporting to JSON + results = list(results) + + with smart_open(out_filename) as out_file: + json_content = utils.json_for(results) + + out_file.write(json_content + '\n') + + if out_file is not sys.stdout: + logging.warn("Wrote results to %s.", out_filename) + + +def to_markdown(results, out_filename): + # Generate (yield) all the results before exporting to Markdown + table = [ + [" %s" % result[header] for header in pshtt.HEADERS] + for result in results + ] + + utils.debug("Printing Markdown...", divider=True) + with smart_open(out_filename) as out_file: + writer = pytablewriter.MarkdownTableWriter() + + writer.header_list = pshtt.HEADERS + writer.value_matrix = table + writer.stream = out_file + + writer.write_table() + def main(): args = docopt.docopt(__doc__, version=__version__) utils.configure_logging(args['--debug']) - out_file = args['--output'] + out_filename = args['--output'] # Read from a .csv, or allow domains on the command line. domains = [] @@ -61,36 +112,24 @@ def main(): 'cache': args['--cache'], 'ca_file': args['--ca-file'] } + + # Do the domain inspections results = pshtt.inspect_domains(domains, options) # JSON can go to STDOUT, or to a file. if args['--json']: - output = utils.json_for(results) - if out_file is None: - - utils.debug("Printing JSON...", divider=True) - print(output) - else: - utils.write(output, out_file) - logging.warn("Wrote results to %s." % out_file) - # Markdwon can go to STDOUT, or to a file - elif args['--markdown']: - output = sys.stdout - if out_file is not None: - output = open(out_file, 'w') + to_json(results, out_filename) - utils.debug("Printing Markdown...", divider=True) - pshtt.md_for(results, output) + # Markdown can go to STDOUT, or to a file + elif args['--markdown']: + to_markdown(results, out_filename) - if out_file is not None: - output.close() # CSV always goes to a file. else: - if args['--output'] is None: - out_file = 'results.csv' - pshtt.csv_for(results, out_file) - utils.debug("Writing results...", divider=True) - logging.warn("Wrote results to %s." % out_file) + if out_filename is None: + out_filename = 'results.csv' + + to_csv(results, out_filename) if __name__ == '__main__': diff --git a/pshtt/pshtt.py b/pshtt/pshtt.py index b7a9e00a..019fe304 100644 --- a/pshtt/pshtt.py +++ b/pshtt/pshtt.py @@ -9,10 +9,8 @@ import re import base64 import json -import csv import os import logging -import pytablewriter import sys import codecs import OpenSSL @@ -134,6 +132,17 @@ def result_for(domain): # But also capture the extended data for those who want it. result['endpoints'] = domain.to_object() + # Convert Header fields from None to False, except for: + # - "HSTS Header" + # - "HSTS Max Age" + # - "Redirect To" + for header in HEADERS: + if header in ("HSTS Header", "HSTS Max Age", "Redirect To"): + continue + + if result[header] is None: + result[header] = False + return result @@ -1080,49 +1089,6 @@ def load_suffix_list(): return suffixes -def md_for(results, out_fd): - value_matrix = [] - for result in results: - row = [] - # TODO: Fix this upstream - for header in HEADERS: - if (header != "HSTS Header") and (header != "HSTS Max Age") and (header != "Redirect To"): - if result[header] is None: - result[header] = False - row.append(" %s" % result[header]) - value_matrix.append(row) - - writer = pytablewriter.MarkdownTableWriter() - writer.header_list = HEADERS - writer.value_matrix = value_matrix - - writer.stream = out_fd - writer.write_table() - - -def csv_for(results, out_filename): - """ - Output a CSV string for an array of results, with a - header row, and with header fields in the desired order. - """ - out_file = open(out_filename, 'w') - writer = csv.writer(out_file) - - writer.writerow(HEADERS) - - for result in results: - row = [] - # TODO: Fix this upstream - for header in HEADERS: - if (header != "HSTS Header") and (header != "HSTS Max Age") and (header != "Redirect To"): - if result[header] is None: - result[header] = False - row.append(result[header]) - writer.writerow(row) - - out_file.close() - - def inspect_domains(domains, options): # Override timeout, user agent, preload cache, default CA bundle global TIMEOUT, USER_AGENT, PRELOAD_CACHE, WEB_CACHE, SUFFIX_CACHE, CA_FILE, STORE @@ -1165,8 +1131,5 @@ def inspect_domains(domains, options): suffix_list = load_suffix_list() # For every given domain, get inspect data. - results = [] for domain in domains: - results.append(inspect(domain)) - - return results + yield inspect(domain) diff --git a/pshtt/utils.py b/pshtt/utils.py index ea9895ff..ab672f41 100644 --- a/pshtt/utils.py +++ b/pshtt/utils.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +import contextlib import os import json import errno @@ -96,3 +97,22 @@ def debug(message, divider=False): if message: logging.debug("%s\n" % message) + + +@contextlib.contextmanager +def smart_open(filename=None): + """ + Context manager that can handle writing to a file or stdout + + Adapted from: https://stackoverflow.com/a/17603000 + """ + if filename is None: + fh = sys.stdout + else: + fh = open(filename, 'w') + + try: + yield fh + finally: + if fh is not sys.stdout: + fh.close() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..44134717 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,92 @@ +import os +import sys +import tempfile +import unittest + +from pshtt.models import Domain, Endpoint +from pshtt import pshtt as _pshtt +from pshtt.cli import to_csv + + +class FakeSuffixList(object): + def get_public_suffix(self, hostname, *args, **kwargs): + return hostname + + +# Artificially setup the the preload and suffix lists +# This should be irrelevant after #126 is decided upon / merged +_pshtt.suffix_list = FakeSuffixList() +_pshtt.preload_list = [] +_pshtt.preload_pending = [] + + +class TestToCSV(unittest.TestCase): + @classmethod + def setUpClass(cls): + base_domain = 'example.com' + + domain = Domain(base_domain) + domain.http = Endpoint("http", "root", base_domain) + domain.httpwww = Endpoint("http", "www", base_domain) + domain.https = Endpoint("https", "root", base_domain) + domain.httpswww = Endpoint("https", "www", base_domain) + + cls.results = _pshtt.result_for(domain) + cls.temp_filename = os.path.join(tempfile.gettempdir(), 'results.csv') + + @unittest.skipIf(sys.version_info[0] < 3, 'Python 3 test only') + def test_no_results(self): + to_csv([], self.temp_filename) + + with open(self.temp_filename) as fh: + content = fh.read() + + expected = ','.join(_pshtt.HEADERS) + '\n' + + self.assertEqual(content, expected) + + @unittest.skipIf(sys.version_info[0] < 3, 'Python 3 test only') + def test_single_result(self): + to_csv([self.results], self.temp_filename) + + with open(self.temp_filename) as fh: + content = fh.read() + + domain_data = [ + ('Domain', 'example.com'), + ('Base Domain', 'example.com'), + ('Canonical URL', 'http://example.com'), + ('Live', 'False'), + ('Redirect', 'False'), + ('Redirect To', ''), + ('Valid HTTPS', 'False'), + ('Defaults to HTTPS', 'False'), + ('Downgrades HTTPS', 'False'), + ('Strictly Forces HTTPS', 'False'), + ('HTTPS Bad Chain', 'False'), + ('HTTPS Bad Hostname', 'False'), + ('HTTPS Expired Cert', 'False'), + ('HTTPS Self Signed Cert', 'False'), + ('HSTS', 'False'), + ('HSTS Header', ''), + ('HSTS Max Age', ''), + ('HSTS Entire Domain', 'False'), + ('HSTS Preload Ready', 'False'), + ('HSTS Preload Pending', 'False'), + ('HSTS Preloaded', 'False'), + ('Base Domain HSTS Preloaded', 'False'), + ('Domain Supports HTTPS', 'False'), + ('Domain Enforces HTTPS', 'False'), + ('Domain Uses Strong HSTS', 'False'), + ('Unknown Error', 'False'), + ] + + header = ','.join(t[0] for t in domain_data) + values = ','.join(t[1] for t in domain_data) + expected = header + '\n' + values + '\n' + self.assertEqual(content, expected) + + # Sanity check that this hard coded data has the same headers as defined + # in the package. This should never fail, as the above assert should + # catch any changes in the header columns. + self.assertEqual(header, ','.join(_pshtt.HEADERS)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..1a0949ac --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,39 @@ +import os +import sys +import tempfile +import unittest + +from pshtt.utils import smart_open + + +class TestSmartOpen(unittest.TestCase): + def test_without_filename(self): + with smart_open() as fh: + self.assertIs(fh, sys.stdout) + + @unittest.skipIf(sys.version_info[0] < 3, 'Python 3 version of test') + def test_with_empty_filename(self): + """Should raise a `FileNotFoundError`""" + with self.assertRaises(FileNotFoundError): # noqa + with smart_open(''): + pass + + @unittest.skipIf(sys.version_info[0] >= 3, 'Python 2 version of test') + def test_with_empty_filename_python2(self): + """Should raise a `FileNotFoundError`""" + with self.assertRaises(IOError): + with smart_open(''): + pass + + @unittest.skipIf(sys.version_info[0] < 3, 'Python 3 version of test') + def test_with_real_filename(self): + test_data = 'This is the test data' + + with tempfile.TemporaryDirectory() as tmp_dirname: + # Make a temporary file to use + filename = os.path.join(tmp_dirname, 'foo') + + with smart_open(filename) as fh: + fh.write(test_data) + + self.assertEqual(test_data, open(filename).read()) diff --git a/tox.ini b/tox.ini index 41cca8ed..235cafd3 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,7 @@ deps = pytest-cov pytest coveralls -commands = pytest --cov=pshtt +commands = pytest --cov={envsitepackagesdir}/pshtt [testenv:flake8] deps = flake8