From 82e11c1da7a848b36114dbdde7b2e94dac261abb Mon Sep 17 00:00:00 2001 From: gfyoung Date: Sat, 8 Apr 2017 14:04:43 -0400 Subject: [PATCH] TST: Add decorators for redirecting stdout/err --- pandas/tests/frame/test_repr_info.py | 4 +- pandas/tests/io/parser/common.py | 81 +++++++-------- pandas/tests/io/parser/python_parser_only.py | 14 +-- pandas/tests/io/parser/test_textreader.py | 23 ++--- pandas/tests/io/test_sql.py | 43 +++----- pandas/tests/plotting/test_frame.py | 21 ++-- pandas/tests/series/test_repr.py | 21 ++-- pandas/util/decorators.py | 14 +-- pandas/util/testing.py | 101 ++++++++++++++++++- 9 files changed, 184 insertions(+), 138 deletions(-) diff --git a/pandas/tests/frame/test_repr_info.py b/pandas/tests/frame/test_repr_info.py index 024e11e63a924..918938c1758ed 100644 --- a/pandas/tests/frame/test_repr_info.py +++ b/pandas/tests/frame/test_repr_info.py @@ -191,6 +191,7 @@ def test_latex_repr(self): # GH 12182 self.assertIsNone(df._repr_latex_()) + @tm.capture_stdout def test_info(self): io = StringIO() self.frame.info(buf=io) @@ -198,11 +199,8 @@ def test_info(self): frame = DataFrame(np.random.randn(5, 3)) - import sys - sys.stdout = StringIO() frame.info() frame.info(verbose=False) - sys.stdout = sys.__stdout__ def test_info_wide(self): from pandas import set_option, reset_option diff --git a/pandas/tests/io/parser/common.py b/pandas/tests/io/parser/common.py index ab30301e710a6..6eadf2c61c974 100644 --- a/pandas/tests/io/parser/common.py +++ b/pandas/tests/io/parser/common.py @@ -1254,6 +1254,7 @@ def test_regex_separator(self): columns=['a', 'b', 'c']) tm.assert_frame_equal(result, expected) + @tm.capture_stdout def test_verbose_import(self): text = """a,b,c,d one,1,2,3 @@ -1265,22 +1266,18 @@ def test_verbose_import(self): one,1,2,3 two,1,2,3""" - buf = StringIO() - sys.stdout = buf + # Engines are verbose in different ways. + self.read_csv(StringIO(text), verbose=True) + output = sys.stdout.getvalue() - try: # engines are verbose in different ways - self.read_csv(StringIO(text), verbose=True) - if self.engine == 'c': - self.assertIn('Tokenization took:', buf.getvalue()) - self.assertIn('Parser memory cleanup took:', buf.getvalue()) - else: # Python engine - self.assertEqual(buf.getvalue(), - 'Filled 3 NA values in column a\n') - finally: - sys.stdout = sys.__stdout__ + if self.engine == 'c': + assert 'Tokenization took:' in output + assert 'Parser memory cleanup took:' in output + else: # Python engine + assert output == 'Filled 3 NA values in column a\n' - buf = StringIO() - sys.stdout = buf + # Reset the stdout buffer. + sys.stdout = StringIO() text = """a,b,c,d one,1,2,3 @@ -1292,16 +1289,15 @@ def test_verbose_import(self): seven,1,2,3 eight,1,2,3""" - try: # engines are verbose in different ways - self.read_csv(StringIO(text), verbose=True, index_col=0) - if self.engine == 'c': - self.assertIn('Tokenization took:', buf.getvalue()) - self.assertIn('Parser memory cleanup took:', buf.getvalue()) - else: # Python engine - self.assertEqual(buf.getvalue(), - 'Filled 1 NA values in column a\n') - finally: - sys.stdout = sys.__stdout__ + self.read_csv(StringIO(text), verbose=True, index_col=0) + output = sys.stdout.getvalue() + + # Engines are verbose in different ways. + if self.engine == 'c': + assert 'Tokenization took:' in output + assert 'Parser memory cleanup took:' in output + else: # Python engine + assert output == 'Filled 1 NA values in column a\n' def test_iteration_open_handle(self): if PY3: @@ -1696,6 +1692,7 @@ class InvalidBuffer(object): with tm.assertRaisesRegexp(ValueError, msg): self.read_csv(mock.Mock()) + @tm.capture_stderr def test_skip_bad_lines(self): # see gh-15925 data = 'a\n1\n1,2,3\n4\n5,6,7' @@ -1706,30 +1703,24 @@ def test_skip_bad_lines(self): with tm.assertRaises(ParserError): self.read_csv(StringIO(data), error_bad_lines=True) - stderr = sys.stderr expected = DataFrame({'a': [1, 4]}) - sys.stderr = StringIO() - try: - out = self.read_csv(StringIO(data), - error_bad_lines=False, - warn_bad_lines=False) - tm.assert_frame_equal(out, expected) + out = self.read_csv(StringIO(data), + error_bad_lines=False, + warn_bad_lines=False) + tm.assert_frame_equal(out, expected) - val = sys.stderr.getvalue() - self.assertEqual(val, '') - finally: - sys.stderr = stderr + val = sys.stderr.getvalue() + assert val == '' + # Reset the stderr buffer. sys.stderr = StringIO() - try: - out = self.read_csv(StringIO(data), - error_bad_lines=False, - warn_bad_lines=True) - tm.assert_frame_equal(out, expected) - val = sys.stderr.getvalue() - self.assertTrue('Skipping line 3' in val) - self.assertTrue('Skipping line 5' in val) - finally: - sys.stderr = stderr + out = self.read_csv(StringIO(data), + error_bad_lines=False, + warn_bad_lines=True) + tm.assert_frame_equal(out, expected) + + val = sys.stderr.getvalue() + assert 'Skipping line 3' in val + assert 'Skipping line 5' in val diff --git a/pandas/tests/io/parser/python_parser_only.py b/pandas/tests/io/parser/python_parser_only.py index 510e3c689649c..2949254257d15 100644 --- a/pandas/tests/io/parser/python_parser_only.py +++ b/pandas/tests/io/parser/python_parser_only.py @@ -8,7 +8,6 @@ """ import csv -import sys import pytest import pandas.util.testing as tm @@ -92,16 +91,9 @@ def test_BytesIO_input(self): def test_single_line(self): # see gh-6607: sniff separator - - buf = StringIO() - sys.stdout = buf - - try: - df = self.read_csv(StringIO('1,2'), names=['a', 'b'], - header=None, sep=None) - tm.assert_frame_equal(DataFrame({'a': [1], 'b': [2]}), df) - finally: - sys.stdout = sys.__stdout__ + df = self.read_csv(StringIO('1,2'), names=['a', 'b'], + header=None, sep=None) + tm.assert_frame_equal(DataFrame({'a': [1], 'b': [2]}), df) def test_skipfooter(self): # see gh-6607 diff --git a/pandas/tests/io/parser/test_textreader.py b/pandas/tests/io/parser/test_textreader.py index b6a9900b0b087..505dc16942f31 100644 --- a/pandas/tests/io/parser/test_textreader.py +++ b/pandas/tests/io/parser/test_textreader.py @@ -142,6 +142,7 @@ def test_integer_thousands_alt(self): expected = DataFrame([123456, 12500]) tm.assert_frame_equal(result, expected) + @tm.capture_stderr def test_skip_bad_lines(self): # too many lines, see #2430 for why data = ('a:b:c\n' @@ -165,19 +166,15 @@ def test_skip_bad_lines(self): 2: ['c', 'f', 'i', 'n']} assert_array_dicts_equal(result, expected) - stderr = sys.stderr - sys.stderr = StringIO() - try: - reader = TextReader(StringIO(data), delimiter=':', - header=None, - error_bad_lines=False, - warn_bad_lines=True) - reader.read() - val = sys.stderr.getvalue() - self.assertTrue('Skipping line 4' in val) - self.assertTrue('Skipping line 6' in val) - finally: - sys.stderr = stderr + reader = TextReader(StringIO(data), delimiter=':', + header=None, + error_bad_lines=False, + warn_bad_lines=True) + reader.read() + val = sys.stderr.getvalue() + + assert 'Skipping line 4' in val + assert 'Skipping line 6' in val def test_header_not_enough_lines(self): data = ('skip this\n' diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 890f52e8c65e9..5318e8532c58e 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -23,7 +23,6 @@ import sqlite3 import csv import os -import sys import warnings import numpy as np @@ -36,7 +35,7 @@ from pandas import DataFrame, Series, Index, MultiIndex, isnull, concat from pandas import date_range, to_datetime, to_timedelta, Timestamp import pandas.compat as compat -from pandas.compat import StringIO, range, lrange, string_types, PY36 +from pandas.compat import range, lrange, string_types, PY36 from pandas.tseries.tools import format as date_format import pandas.io.sql as sql @@ -2220,6 +2219,7 @@ def test_schema(self): cur = self.conn.cursor() cur.execute(create_sql) + @tm.capture_stdout def test_execute_fail(self): create_sql = """ CREATE TABLE test @@ -2236,14 +2236,10 @@ def test_execute_fail(self): sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.conn) - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.execute, - 'INSERT INTO test VALUES("foo", "bar", 7)', - self.conn) - finally: - sys.stdout = sys.__stdout__ + with pytest.raises(Exception): + sql.execute('INSERT INTO test VALUES("foo", "bar", 7)', self.conn) + @tm.capture_stdout def test_execute_closed_connection(self): create_sql = """ CREATE TABLE test @@ -2259,12 +2255,9 @@ def test_execute_closed_connection(self): sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) self.conn.close() - try: - sys.stdout = StringIO() - self.assertRaises(Exception, tquery, "select * from test", - con=self.conn) - finally: - sys.stdout = sys.__stdout__ + + with pytest.raises(Exception): + tquery("select * from test", con=self.conn) # Initialize connection again (needed for tearDown) self.setUp() @@ -2534,6 +2527,7 @@ def test_schema(self): cur.execute(drop_sql) cur.execute(create_sql) + @tm.capture_stdout def test_execute_fail(self): _skip_if_no_pymysql() drop_sql = "DROP TABLE IF EXISTS test" @@ -2553,14 +2547,10 @@ def test_execute_fail(self): sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.conn) - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.execute, - 'INSERT INTO test VALUES("foo", "bar", 7)', - self.conn) - finally: - sys.stdout = sys.__stdout__ + with pytest.raises(Exception): + sql.execute('INSERT INTO test VALUES("foo", "bar", 7)', self.conn) + @tm.capture_stdout def test_execute_closed_connection(self): _skip_if_no_pymysql() drop_sql = "DROP TABLE IF EXISTS test" @@ -2579,12 +2569,9 @@ def test_execute_closed_connection(self): sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) self.conn.close() - try: - sys.stdout = StringIO() - self.assertRaises(Exception, tquery, "select * from test", - con=self.conn) - finally: - sys.stdout = sys.__stdout__ + + with pytest.raises(Exception): + tquery("select * from test", con=self.conn) # Initialize connection again (needed for tearDown) self.setUp() diff --git a/pandas/tests/plotting/test_frame.py b/pandas/tests/plotting/test_frame.py index 48af366f24ea4..1527637ea3eff 100644 --- a/pandas/tests/plotting/test_frame.py +++ b/pandas/tests/plotting/test_frame.py @@ -12,7 +12,7 @@ from pandas import (Series, DataFrame, MultiIndex, PeriodIndex, date_range, bdate_range) from pandas.types.api import is_list_like -from pandas.compat import (range, lrange, StringIO, lmap, lzip, u, zip, PY3) +from pandas.compat import range, lrange, lmap, lzip, u, zip, PY3 from pandas.formats.printing import pprint_thing import pandas.util.testing as tm from pandas.util.testing import slow @@ -1558,8 +1558,8 @@ def test_line_label_none(self): self.assertEqual(ax.get_legend().get_texts()[0].get_text(), 'None') @slow + @tm.capture_stdout def test_line_colors(self): - import sys from matplotlib import cm custom_colors = 'rgcby' @@ -1568,16 +1568,13 @@ def test_line_colors(self): ax = df.plot(color=custom_colors) self._check_colors(ax.get_lines(), linecolors=custom_colors) - tmp = sys.stderr - sys.stderr = StringIO() - try: - tm.close() - ax2 = df.plot(colors=custom_colors) - lines2 = ax2.get_lines() - for l1, l2 in zip(ax.get_lines(), lines2): - self.assertEqual(l1.get_color(), l2.get_color()) - finally: - sys.stderr = tmp + tm.close() + + ax2 = df.plot(colors=custom_colors) + lines2 = ax2.get_lines() + + for l1, l2 in zip(ax.get_lines(), lines2): + self.assertEqual(l1.get_color(), l2.get_color()) tm.close() diff --git a/pandas/tests/series/test_repr.py b/pandas/tests/series/test_repr.py index 99a406a71b12b..188b96638344c 100644 --- a/pandas/tests/series/test_repr.py +++ b/pandas/tests/series/test_repr.py @@ -3,13 +3,15 @@ from datetime import datetime, timedelta +import sys + import numpy as np import pandas as pd from pandas import (Index, Series, DataFrame, date_range) from pandas.core.index import MultiIndex -from pandas.compat import StringIO, lrange, range, u +from pandas.compat import lrange, range, u from pandas import compat import pandas.util.testing as tm @@ -112,20 +114,15 @@ def test_tidy_repr(self): a.name = 'title1' repr(a) # should not raise exception + @tm.capture_stderr def test_repr_bool_fails(self): s = Series([DataFrame(np.random.randn(2, 2)) for i in range(5)]) - import sys - - buf = StringIO() - tmp = sys.stderr - sys.stderr = buf - try: - # it works (with no Cython exception barf)! - repr(s) - finally: - sys.stderr = tmp - self.assertEqual(buf.getvalue(), '') + # It works (with no Cython exception barf)! + repr(s) + + output = sys.stderr.getvalue() + assert output == '' def test_repr_name_iterable_indexable(self): s = Series([1, 2, 3], name=np.int64(3)) diff --git a/pandas/util/decorators.py b/pandas/util/decorators.py index 4e1719958e8b7..ca588e2a0432e 100644 --- a/pandas/util/decorators.py +++ b/pandas/util/decorators.py @@ -1,7 +1,6 @@ -from pandas.compat import StringIO, callable, signature +from pandas.compat import callable, signature from pandas._libs.lib import cache_readonly # noqa import types -import sys import warnings from textwrap import dedent from functools import wraps, update_wrapper @@ -196,17 +195,6 @@ def indent(text, indents=1): return jointext.join(text.split('\n')) -def suppress_stdout(f): - def wrapped(*args, **kwargs): - try: - sys.stdout = StringIO() - f(*args, **kwargs) - finally: - sys.stdout = sys.__stdout__ - - return wrapped - - def make_signature(func): """ Returns a string repr of the arg list of a func call, with any defaults diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 9d7b004374318..ef0fa04548cab 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -38,7 +38,7 @@ from pandas.compat import ( filter, map, zip, range, unichr, lrange, lmap, lzip, u, callable, Counter, raise_with_traceback, httplib, is_platform_windows, is_platform_32bit, - PY3 + StringIO, PY3 ) from pandas.computation import expressions as expr @@ -629,6 +629,105 @@ def _valid_locales(locales, normalize): return list(filter(_can_set_locale, map(normalizer, locales))) +# ----------------------------------------------------------------------------- +# Stdout / stderr decorators + + +def capture_stdout(f): + """ + Decorator to capture stdout in a buffer so that it can be checked + (or suppressed) during testing. + + Parameters + ---------- + f : callable + The test that is capturing stdout. + + Returns + ------- + f : callable + The decorated test ``f``, which captures stdout. + + Examples + -------- + + >>> from pandas.util.testing import capture_stdout + >>> + >>> import sys + >>> + >>> @capture_stdout + ... def test_print_pass(): + ... print("foo") + ... out = sys.stdout.getvalue() + ... assert out == "foo\n" + >>> + >>> @capture_stdout + ... def test_print_fail(): + ... print("foo") + ... out = sys.stdout.getvalue() + ... assert out == "bar\n" + ... + AssertionError: assert 'foo\n' == 'bar\n' + """ + + @wraps(f) + def wrapper(*args, **kwargs): + try: + sys.stdout = StringIO() + f(*args, **kwargs) + finally: + sys.stdout = sys.__stdout__ + + return wrapper + + +def capture_stderr(f): + """ + Decorator to capture stderr in a buffer so that it can be checked + (or suppressed) during testing. + + Parameters + ---------- + f : callable + The test that is capturing stderr. + + Returns + ------- + f : callable + The decorated test ``f``, which captures stderr. + + Examples + -------- + + >>> from pandas.util.testing import capture_stderr + >>> + >>> import sys + >>> + >>> @capture_stderr + ... def test_stderr_pass(): + ... sys.stderr.write("foo") + ... out = sys.stderr.getvalue() + ... assert out == "foo\n" + >>> + >>> @capture_stderr + ... def test_stderr_fail(): + ... sys.stderr.write("foo") + ... out = sys.stderr.getvalue() + ... assert out == "bar\n" + ... + AssertionError: assert 'foo\n' == 'bar\n' + """ + + @wraps(f) + def wrapper(*args, **kwargs): + try: + sys.stderr = StringIO() + f(*args, **kwargs) + finally: + sys.stderr = sys.__stderr__ + + return wrapper + # ----------------------------------------------------------------------------- # Console debugging tools