Re-architecting of CSV / MD / JSON writing #125

Merged (16 commits) on Oct 24, 2017
85 changes: 62 additions & 23 deletions pshtt/cli.py
@@ -28,17 +28,68 @@
from . import pshtt
from . import utils
from . import __version__
from .utils import smart_open

import csv
import docopt
import logging
import sys

import pytablewriter


def to_csv(results, out_filename):
utils.debug("Opening CSV file: {}".format(out_filename))
with smart_open(out_filename) as out_file:
writer = csv.writer(out_file)

# Write out header
writer.writerow(pshtt.HEADERS)

# Write out the row data as it completes
for result in results:
row = [result[header] for header in pshtt.HEADERS]
writer.writerow(row)

logging.warn("Wrote results to %s.", out_filename)


def to_json(results, out_filename):
# Generate (yield) all the results before exporting to JSON
results = list(results)

with smart_open(out_filename) as out_file:
json_content = utils.json_for(results)

out_file.write(json_content + '\n')

if out_file is not sys.stdout:
logging.warn("Wrote results to %s.", out_filename)


def to_markdown(results, out_filename):
# Generate (yield) all the results before exporting to Markdown
table = [
[" %s" % result[header] for header in pshtt.HEADERS]
for result in results
]

utils.debug("Printing Markdown...", divider=True)
with smart_open(out_filename) as out_file:
writer = pytablewriter.MarkdownTableWriter()

writer.header_list = pshtt.HEADERS
writer.value_matrix = table
writer.stream = out_file

writer.write_table()


def main():
args = docopt.docopt(__doc__, version=__version__)
utils.configure_logging(args['--debug'])

out_file = args['--output']
out_filename = args['--output']

# Read from a .csv, or allow domains on the command line.
domains = []
@@ -61,36 +61,24 @@ def main():
'cache': args['--cache'],
'ca_file': args['--ca-file']
}

# Do the domain inspections
results = pshtt.inspect_domains(domains, options)

# JSON can go to STDOUT, or to a file.
if args['--json']:
output = utils.json_for(results)
if out_file is None:

utils.debug("Printing JSON...", divider=True)
print(output)
else:
utils.write(output, out_file)
logging.warn("Wrote results to %s." % out_file)
# Markdwon can go to STDOUT, or to a file
elif args['--markdown']:
output = sys.stdout
if out_file is not None:
output = open(out_file, 'w')
to_json(results, out_filename)

utils.debug("Printing Markdown...", divider=True)
pshtt.md_for(results, output)
# Markdown can go to STDOUT, or to a file
elif args['--markdown']:
to_markdown(results, out_filename)

if out_file is not None:
output.close()
# CSV always goes to a file.
else:
if args['--output'] is None:
out_file = 'results.csv'
pshtt.csv_for(results, out_file)
utils.debug("Writing results...", divider=True)
logging.warn("Wrote results to %s." % out_file)
if out_filename is None:
out_filename = 'results.csv'

to_csv(results, out_filename)


if __name__ == '__main__':
61 changes: 12 additions & 49 deletions pshtt/pshtt.py
@@ -9,10 +9,8 @@
import re
import base64
import json
import csv
import os
import logging
import pytablewriter
import sys
import codecs
import OpenSSL
@@ -134,6 +132,17 @@ def result_for(domain):
# But also capture the extended data for those who want it.
result['endpoints'] = domain.to_object()

# Convert Header fields from None to False, except for:
# - "HSTS Header"
# - "HSTS Max Age"
# - "Redirect To"
for header in HEADERS:
if header in ("HSTS Header", "HSTS Max Age", "Redirect To"):
continue

if result[header] is None:
result[header] = False

return result


@@ -1080,49 +1089,6 @@ def load_suffix_list():
return suffixes


def md_for(results, out_fd):
value_matrix = []
for result in results:
row = []
# TODO: Fix this upstream
for header in HEADERS:
if (header != "HSTS Header") and (header != "HSTS Max Age") and (header != "Redirect To"):
if result[header] is None:
result[header] = False
row.append(" %s" % result[header])
value_matrix.append(row)

writer = pytablewriter.MarkdownTableWriter()
writer.header_list = HEADERS
writer.value_matrix = value_matrix

writer.stream = out_fd
writer.write_table()


def csv_for(results, out_filename):
"""
Output a CSV string for an array of results, with a
header row, and with header fields in the desired order.
"""
out_file = open(out_filename, 'w')
writer = csv.writer(out_file)

writer.writerow(HEADERS)

for result in results:
row = []
# TODO: Fix this upstream
for header in HEADERS:
if (header != "HSTS Header") and (header != "HSTS Max Age") and (header != "Redirect To"):
if result[header] is None:
result[header] = False
row.append(result[header])
writer.writerow(row)

out_file.close()


def inspect_domains(domains, options):
# Override timeout, user agent, preload cache, default CA bundle
global TIMEOUT, USER_AGENT, PRELOAD_CACHE, WEB_CACHE, SUFFIX_CACHE, CA_FILE, STORE
@@ -1165,8 +1131,5 @@ def inspect_domains(domains, options):
suffix_list = load_suffix_list()

# For every given domain, get inspect data.
results = []
for domain in domains:
results.append(inspect(domain))

return results
yield inspect(domain)
20 changes: 20 additions & 0 deletions pshtt/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python

import contextlib
import os
import json
import errno
@@ -96,3 +97,22 @@ def debug(message, divider=False):

if message:
logging.debug("%s\n" % message)


@contextlib.contextmanager
def smart_open(filename=None):
"""
Context manager that can handle writing to a file or stdout

Adapted from: https://stackoverflow.com/a/17603000
"""
if filename is None:
fh = sys.stdout
else:
fh = open(filename, 'w')

try:
yield fh
finally:
if fh is not sys.stdout:
fh.close()
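
For reference, a brief usage sketch of this helper, mirroring how the rewritten writers in cli.py call it (the filename here is purely illustrative):

```python
from pshtt.utils import smart_open

# With a filename, smart_open writes to that file and closes it on exit;
# with no argument, it writes to sys.stdout and leaves the stream open.
with smart_open('results.json') as out_file:
    out_file.write('{"example": true}\n')
```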
59 changes: 59 additions & 0 deletions tests/test_cli.py
@@ -0,0 +1,59 @@
import os
import sys
import tempfile
import unittest

from pshtt.models import Domain, Endpoint
from pshtt import pshtt as _pshtt
from pshtt.cli import to_csv


class FakeSuffixList(object):
def get_public_suffix(self, hostname, *args, **kwargs):
return hostname


# Artificially set up the preload and suffix lists
# This should be irrelevant after #126 is decided upon / merged
_pshtt.suffix_list = FakeSuffixList()
_pshtt.preload_list = []
_pshtt.preload_pending = []


class TestToCSV(unittest.TestCase):
@classmethod
def setUpClass(cls):
base_domain = 'example.com'

domain = Domain(base_domain)
domain.http = Endpoint("http", "root", base_domain)
domain.httpwww = Endpoint("http", "www", base_domain)
domain.https = Endpoint("https", "root", base_domain)
domain.httpswww = Endpoint("https", "www", base_domain)

cls.results = _pshtt.result_for(domain)
cls.temp_filename = os.path.join(tempfile.gettempdir(), 'results.csv')

@unittest.skipIf(sys.version_info[0] < 3, 'Python 3 test only')
def test_no_results(self):
to_csv([], self.temp_filename)

with open(self.temp_filename) as fh:
content = fh.read()

expected = 'Domain,Base Domain,Canonical URL,Live,Redirect,Redirect To,Valid HTTPS,Defaults to HTTPS,Downgrades HTTPS,Strictly Forces HTTPS,HTTPS Bad Chain,HTTPS Bad Hostname,HTTPS Expired Cert,HTTPS Self Signed Cert,HSTS,HSTS Header,HSTS Max Age,HSTS Entire Domain,HSTS Preload Ready,HSTS Preload Pending,HSTS Preloaded,Base Domain HSTS Preloaded,Domain Supports HTTPS,Domain Enforces HTTPS,Domain Uses Strong HSTS,Unknown Error\n'
Collaborator:

Could we generate this by pulling in the global var and joining them with `,`? That would prevent us from having to update the test data every time a column header changes.
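
A minimal sketch of that suggestion (assuming `_pshtt.HEADERS` remains importable here, as it is elsewhere in this test module):

```python
# Sketch only: derive the expected header row from the shared HEADERS
# constant rather than a hard-coded literal, so the test tracks column changes.
expected = ','.join(_pshtt.HEADERS) + '\n'
```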


self.assertEqual(content, expected)

@unittest.skipIf(sys.version_info[0] < 3, 'Python 3 test only')
def test_single_result(self):
to_csv([self.results], self.temp_filename)

with open(self.temp_filename) as fh:
content = fh.read()

expected = ''
expected += 'Domain,Base Domain,Canonical URL,Live,Redirect,Redirect To,Valid HTTPS,Defaults to HTTPS,Downgrades HTTPS,Strictly Forces HTTPS,HTTPS Bad Chain,HTTPS Bad Hostname,HTTPS Expired Cert,HTTPS Self Signed Cert,HSTS,HSTS Header,HSTS Max Age,HSTS Entire Domain,HSTS Preload Ready,HSTS Preload Pending,HSTS Preloaded,Base Domain HSTS Preloaded,Domain Supports HTTPS,Domain Enforces HTTPS,Domain Uses Strong HSTS,Unknown Error\n'
Collaborator:

Same note as above.

expected += 'example.com,example.com,http://example.com,False,False,,False,False,False,False,False,False,False,False,False,,,False,False,False,False,False,False,False,False,False\n'
Collaborator:

I think this would be best managed in the test data as an array which is also joined by a comma in code before appending to expected. It still might require changing when the headers change, but it would be much more obvious and direct how to make that change.

Collaborator (Author):

(Replying to all three comments) Yeah, I actually started down that route originally. For the headers, as you say, using ','.join(_pshtt.HEADERS) is pretty simple, the issue for me then became how to convert the data itself, as doing a list of the values will make it tougher to map each back to the appropriate header...

Maybe creating a full domain object and setting all the values there would be best... Let me give that a shot.

Collaborator:

as doing a list of the values will make it tougher to map each back to the appropriate header...

This is true, but it's probably an acceptable annoyance, IMO. I think it's likely to be less brittle than a concatenated string, and so just turning that into a static array would be enough for me for this PR.
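
A sketch of that static-array approach, assuming the expected values stay in the same order as `_pshtt.HEADERS` (the values below are copied from the literal string above):

```python
# Sketch only: hold the expected row as one value per column, then join in
# code; each position lines up with the corresponding entry in _pshtt.HEADERS.
expected_values = [
    'example.com', 'example.com', 'http://example.com',
    'False', 'False', '',                         # Live, Redirect, Redirect To
    'False', 'False', 'False', 'False',           # Valid HTTPS .. Strictly Forces HTTPS
    'False', 'False', 'False', 'False',           # HTTPS Bad Chain .. Self Signed Cert
    'False', '', '',                              # HSTS, HSTS Header, HSTS Max Age
    'False', 'False', 'False', 'False', 'False',  # HSTS Entire Domain .. Base Domain HSTS Preloaded
    'False', 'False', 'False', 'False',           # Supports/Enforces HTTPS, Strong HSTS, Unknown Error
]
assert len(expected_values) == len(_pshtt.HEADERS)  # one value per column

expected = ','.join(_pshtt.HEADERS) + '\n' + ','.join(expected_values) + '\n'
```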


self.assertEqual(content, expected)
39 changes: 39 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,39 @@
import os
import sys
import tempfile
import unittest

from pshtt.utils import smart_open


class TestSmartOpen(unittest.TestCase):
def test_without_filename(self):
with smart_open() as fh:
self.assertIs(fh, sys.stdout)

@unittest.skipIf(sys.version_info[0] < 3, 'Python 3 version of test')
def test_with_empty_filename(self):
"""Should raise a `FileNotFoundError`"""
with self.assertRaises(FileNotFoundError): # noqa
with smart_open(''):
pass

@unittest.skipIf(sys.version_info[0] >= 3, 'Python 2 version of test')
def test_with_empty_filename_python2(self):
"""Should raise a `FileNotFoundError`"""
with self.assertRaises(IOError):
with smart_open(''):
pass

@unittest.skipIf(sys.version_info[0] < 3, 'Python 3 version of test')
def test_with_real_filename(self):
test_data = 'This is the test data'

with tempfile.TemporaryDirectory() as tmp_dirname:
# Make a temporary file to use
filename = os.path.join(tmp_dirname, 'foo')

with smart_open(filename) as fh:
fh.write(test_data)

self.assertEqual(test_data, open(filename).read())
2 changes: 1 addition & 1 deletion tox.ini
@@ -8,7 +8,7 @@ deps =
pytest-cov
pytest
coveralls
commands = pytest --cov=pshtt
commands = pytest --cov={envsitepackagesdir}/pshtt

[testenv:flake8]
deps = flake8