Skip to content

Commit

Permalink
Merge pull request #10613 from jreback/stata
Browse files Browse the repository at this point in the history
ENH: add StataReader context manager to ensure closing of the path
  • Loading branch information
jreback committed Jul 18, 2015
2 parents 4bb45b1 + 59dd18b commit 5a9a9da
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
15 changes: 15 additions & 0 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,21 @@ def __init__(self, path_or_buf, convert_dates=True,

self._read_header()

def __enter__(self):
""" enter context manager """
return self

def __exit__(self, exc_type, exc_value, traceback):
""" exit context manager """
self.close()

def close(self):
""" close the handle if its open """
try:
self.path_or_buf.close()
except IOError:
pass

def _read_header(self):
first_char = self.path_or_buf.read(1)
if struct.unpack('c', first_char)[0] == b'<':
Expand Down
24 changes: 13 additions & 11 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,11 @@ def test_timestamp_and_label(self):
data_label = 'This is a data file.'
with tm.ensure_clean() as path:
original.to_stata(path, time_stamp=time_stamp, data_label=data_label)
reader = StataReader(path)
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
assert parsed_time_stamp == time_stamp
assert reader.data_label == data_label

with StataReader(path) as reader:
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
assert parsed_time_stamp == time_stamp
assert reader.data_label == data_label

def test_numeric_column_names(self):
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
Expand Down Expand Up @@ -599,13 +600,14 @@ def test_minimal_size_col(self):
original = DataFrame(s)
with tm.ensure_clean() as path:
original.to_stata(path, write_index=False)
sr = StataReader(path)
typlist = sr.typlist
variables = sr.varlist
formats = sr.fmtlist
for variable, fmt, typ in zip(variables, formats, typlist):
self.assertTrue(int(variable[1:]) == int(fmt[1:-1]))
self.assertTrue(int(variable[1:]) == typ)

with StataReader(path) as sr:
typlist = sr.typlist
variables = sr.varlist
formats = sr.fmtlist
for variable, fmt, typ in zip(variables, formats, typlist):
self.assertTrue(int(variable[1:]) == int(fmt[1:-1]))
self.assertTrue(int(variable[1:]) == typ)

def test_excessively_long_string(self):
str_lens = (1, 244, 500)
Expand Down

0 comments on commit 5a9a9da

Please sign in to comment.