Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Set index when reading stata file #17328

Merged
merged 1 commit into from
Sep 16, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.21.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ Other API Changes
- :func:`Series.argmin` and :func:`Series.argmax` will now raise a ``TypeError`` when used with ``object`` dtypes, instead of a ``ValueError`` (:issue:`13595`)
- :class:`Period` is now immutable, and will now raise an ``AttributeError`` when a user tries to assign a new value to the ``ordinal`` or ``freq`` attributes (:issue:`17116`).
- :func:`to_datetime` when passed a tz-aware ``origin=`` kwarg will now raise a more informative ``ValueError`` rather than a ``TypeError`` (:issue:`16842`)
- Renamed non-functional ``index`` to ``index_col`` in :func:`read_stata` to improve API consistency (:issue:`16342`)


.. _whatsnew_0210.deprecations:
Expand Down Expand Up @@ -370,6 +371,7 @@ I/O
- Bug in :func:`read_csv` when called with ``low_memory=False`` in which a CSV with at least one column > 2GB in size would incorrectly raise a ``MemoryError`` (:issue:`16798`).
- Bug in :func:`read_csv` when called with a single-element list ``header`` would return a ``DataFrame`` of all NaN values (:issue:`7757`)
- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`)
- Bug in :func:`read_stata` where the index was not set (:issue:`16342`)
- Bug in :func:`read_html` where import check fails when run in multiple threads (:issue:`16928`)

Plotting
Expand Down
61 changes: 34 additions & 27 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,30 @@
You can find more information on http://presbrey.mit.edu/PyDTA and
http://www.statsmodels.org/devel/
"""
import numpy as np

import sys
import datetime
import struct
from dateutil.relativedelta import relativedelta
import sys

from pandas.core.dtypes.common import (
is_categorical_dtype, is_datetime64_dtype,
_ensure_object)
import numpy as np
from dateutil.relativedelta import relativedelta
from pandas._libs.lib import max_len_string_array, infer_dtype
from pandas._libs.tslib import NaT, Timestamp

import pandas as pd
from pandas import compat, to_timedelta, to_datetime, isna, DatetimeIndex
from pandas.compat import (lrange, lmap, lzip, text_type, string_types, range,
zip, BytesIO)
from pandas.core.base import StringMixin
from pandas.core.categorical import Categorical
from pandas.core.dtypes.common import (is_categorical_dtype, _ensure_object,
is_datetime64_dtype)
from pandas.core.frame import DataFrame
from pandas.core.series import Series
import datetime
from pandas import compat, to_timedelta, to_datetime, isna, DatetimeIndex
from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \
zip, BytesIO
from pandas.util._decorators import Appender
import pandas as pd

from pandas.io.common import (get_filepath_or_buffer, BaseIterator,
_stringify_path)
from pandas._libs.lib import max_len_string_array, infer_dtype
from pandas._libs.tslib import NaT, Timestamp
from pandas.util._decorators import Appender
from pandas.util._decorators import deprecate_kwarg

VALID_ENCODINGS = ('ascii', 'us-ascii', 'latin-1', 'latin_1', 'iso-8859-1',
'iso8859-1', '8859', 'cp819', 'latin', 'latin1', 'L1')
Expand All @@ -53,8 +52,8 @@
Encoding used to parse the files. None defaults to latin-1."""

_statafile_processing_params2 = """\
index : identifier of index column
identifier of column that should be used as index of the DataFrame
index_col : string, optional, default: None
Column to set as index
convert_missing : boolean, defaults to False
Flag indicating whether to convert missing values to their Stata
representations. If False, missing values are replaced with nans.
Expand Down Expand Up @@ -159,15 +158,16 @@


@Appender(_read_stata_doc)
@deprecate_kwarg(old_arg_name='index', new_arg_name='index_col')
def read_stata(filepath_or_buffer, convert_dates=True,
convert_categoricals=True, encoding=None, index=None,
convert_categoricals=True, encoding=None, index_col=None,
convert_missing=False, preserve_dtypes=True, columns=None,
order_categoricals=True, chunksize=None, iterator=False):

reader = StataReader(filepath_or_buffer,
convert_dates=convert_dates,
convert_categoricals=convert_categoricals,
index=index, convert_missing=convert_missing,
index_col=index_col, convert_missing=convert_missing,
preserve_dtypes=preserve_dtypes,
columns=columns,
order_categoricals=order_categoricals,
Expand Down Expand Up @@ -944,8 +944,9 @@ def __init__(self, encoding):
class StataReader(StataParser, BaseIterator):
__doc__ = _stata_reader_doc

@deprecate_kwarg(old_arg_name='index', new_arg_name='index_col')
def __init__(self, path_or_buf, convert_dates=True,
convert_categoricals=True, index=None,
convert_categoricals=True, index_col=None,
convert_missing=False, preserve_dtypes=True,
columns=None, order_categoricals=True,
encoding='latin-1', chunksize=None):
Expand All @@ -956,7 +957,7 @@ def __init__(self, path_or_buf, convert_dates=True,
# calls to read).
self._convert_dates = convert_dates
self._convert_categoricals = convert_categoricals
self._index = index
self._index_col = index_col
self._convert_missing = convert_missing
self._preserve_dtypes = preserve_dtypes
self._columns = columns
Expand Down Expand Up @@ -1460,8 +1461,9 @@ def get_chunk(self, size=None):
return self.read(nrows=size)

@Appender(_read_method_doc)
@deprecate_kwarg(old_arg_name='index', new_arg_name='index_col')
def read(self, nrows=None, convert_dates=None,
convert_categoricals=None, index=None,
convert_categoricals=None, index_col=None,
convert_missing=None, preserve_dtypes=None,
columns=None, order_categoricals=None):
# Handle empty file or chunk. If reading incrementally raise
Expand All @@ -1486,6 +1488,8 @@ def read(self, nrows=None, convert_dates=None,
columns = self._columns
if order_categoricals is None:
order_categoricals = self._order_categoricals
if index_col is None:
index_col = self._index_col

if nrows is None:
nrows = self.nobs
Expand Down Expand Up @@ -1524,14 +1528,14 @@ def read(self, nrows=None, convert_dates=None,
self._read_value_labels()

if len(data) == 0:
data = DataFrame(columns=self.varlist, index=index)
data = DataFrame(columns=self.varlist)
else:
data = DataFrame.from_records(data, index=index)
data = DataFrame.from_records(data)
data.columns = self.varlist

# If index is not specified, use actual row number rather than
# restarting at 0 for each chunk.
if index is None:
if index_col is None:
ix = np.arange(self._lines_read - read_lines, self._lines_read)
data = data.set_index(ix)

Expand All @@ -1553,7 +1557,7 @@ def read(self, nrows=None, convert_dates=None,
cols_ = np.where(self.dtyplist)[0]

# Convert columns (if needed) to match input type
index = data.index
ix = data.index
requires_type_conversion = False
data_formatted = []
for i in cols_:
Expand All @@ -1563,7 +1567,7 @@ def read(self, nrows=None, convert_dates=None,
if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
requires_type_conversion = True
data_formatted.append(
(col, Series(data[col], index, self.dtyplist[i])))
(col, Series(data[col], ix, self.dtyplist[i])))
else:
data_formatted.append((col, data[col]))
if requires_type_conversion:
Expand Down Expand Up @@ -1606,6 +1610,9 @@ def read(self, nrows=None, convert_dates=None,
if convert:
data = DataFrame.from_items(retyped_data)

if index_col is not None:
data = data.set_index(data.pop(index_col))

return data

def _do_convert_missing(self, data, convert_missing):
Expand Down
11 changes: 10 additions & 1 deletion pandas/tests/io/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_read_write_reread_dta15(self):
tm.assert_frame_equal(parsed_114, parsed_117)

def test_timestamp_and_label(self):
original = DataFrame([(1,)], columns=['var'])
original = DataFrame([(1,)], columns=['variable'])
time_stamp = datetime(2000, 2, 29, 14, 21)
data_label = 'This is a data file.'
with tm.ensure_clean() as path:
Expand Down Expand Up @@ -1309,3 +1309,12 @@ def test_value_labels_iterator(self, write_index):
dta_iter = pd.read_stata(path, iterator=True)
value_labels = dta_iter.value_labels()
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}

def test_set_index(self):
    # GH 17328: round-trip a DataFrame through a Stata file and confirm
    # that the new ``index_col`` keyword of ``read_stata`` restores the
    # named index column.
    df = tm.makeDataFrame()
    # Stata stores the index as an ordinary column, so it must be named
    # for ``index_col='index'`` to find it on the way back in.
    df.index.name = 'index'
    with tm.ensure_clean() as path:
        df.to_stata(path)
        reread = pd.read_stata(path, index_col='index')
    # Equal frames implies both values and the restored index match.
    tm.assert_frame_equal(df, reread)