diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 53f2ff455d32e..db9362c5c821e 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -949,6 +949,21 @@ def __init__(self, path_or_buf, convert_dates=True, self._read_header() + def __enter__(self): + """ enter context manager """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ exit context manager """ + self.close() + + def close(self): + """ close the handle if its open """ + try: + self.path_or_buf.close() + except IOError: + pass + def _read_header(self): first_char = self.path_or_buf.read(1) if struct.unpack('c', first_char)[0] == b'<': diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index a06c4384d72c5..4b2781c9dceb6 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -430,10 +430,11 @@ def test_timestamp_and_label(self): data_label = 'This is a data file.' with tm.ensure_clean() as path: original.to_stata(path, time_stamp=time_stamp, data_label=data_label) - reader = StataReader(path) - parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M')) - assert parsed_time_stamp == time_stamp - assert reader.data_label == data_label + + with StataReader(path) as reader: + parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M')) + assert parsed_time_stamp == time_stamp + assert reader.data_label == data_label def test_numeric_column_names(self): original = DataFrame(np.reshape(np.arange(25.0), (5, 5))) @@ -599,13 +600,14 @@ def test_minimal_size_col(self): original = DataFrame(s) with tm.ensure_clean() as path: original.to_stata(path, write_index=False) - sr = StataReader(path) - typlist = sr.typlist - variables = sr.varlist - formats = sr.fmtlist - for variable, fmt, typ in zip(variables, formats, typlist): - self.assertTrue(int(variable[1:]) == int(fmt[1:-1])) - self.assertTrue(int(variable[1:]) == typ) + + with StataReader(path) as sr: + typlist = sr.typlist + variables = sr.varlist + formats = sr.fmtlist + for variable, fmt, typ in zip(variables, formats, typlist): + self.assertTrue(int(variable[1:]) == int(fmt[1:-1])) + self.assertTrue(int(variable[1:]) == typ) def test_excessively_long_string(self): str_lens = (1, 244, 500)