Skip to content

Commit

Permalink
Merge pull request #8767 from bashtage/stata-categorical
Browse files Browse the repository at this point in the history
ENH: Add categorical support for Stata export
  • Loading branch information
jreback committed Nov 13, 2014
2 parents ed47013 + 204b50e commit 8d1ae49
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 29 deletions.
6 changes: 6 additions & 0 deletions doc/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3626,12 +3626,18 @@ outside of this range, the data is cast to ``int16``.
if ``int64`` values are larger than 2**53.

.. warning::

:class:`~pandas.io.stata.StataWriter` and
:func:`~pandas.core.frame.DataFrame.to_stata` only support fixed width
strings containing up to 244 characters, a limitation imposed by the version
115 dta file format. Attempting to write *Stata* dta files with strings
longer than 244 characters raises a ``ValueError``.

.. warning::

*Stata* data files only support text labels for categorical data. Exporting
data frames containing categorical data will convert non-string categorical values
to strings.

.. _io.stata_reader:

Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.15.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ API changes
Enhancements
~~~~~~~~~~~~

- Added ability to export Categorical data to Stata (:issue:`8633`).

.. _whatsnew_0152.performance:

Expand Down
256 changes: 230 additions & 26 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import struct
from dateutil.relativedelta import relativedelta
from pandas.core.base import StringMixin
from pandas.core.categorical import Categorical
from pandas.core.frame import DataFrame
from pandas.core.series import Series
from pandas.core.categorical import Categorical
import datetime
from pandas import compat, to_timedelta, to_datetime, isnull, DatetimeIndex
from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \
zip
zip, BytesIO
import pandas.core.common as com
from pandas.io.common import get_filepath_or_buffer
from pandas.lib import max_len_string_array, infer_dtype
Expand Down Expand Up @@ -336,6 +336,15 @@ class PossiblePrecisionLoss(Warning):
conversion range. This may result in a loss of precision in the saved data.
"""

class ValueLabelTypeMismatch(Warning):
    """Warning category used when categorical labels written to a Stata
    file are not strings and therefore have to be converted to strings."""

# Message template for ValueLabelTypeMismatch warnings; ``{0}`` is filled in
# with the name of the categorical column whose labels are not strings.
value_label_mismatch_doc = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
"""


class InvalidColumnName(Warning):
pass
Expand Down Expand Up @@ -425,6 +434,131 @@ def _cast_to_stata_types(data):
return data


class StataValueLabel(object):
    """
    Parse a categorical column and prepare formatted output

    Parameters
    -----------
    catarray : Series
        Categorical Series to export. Its ``name`` becomes the name of the
        value label and ``cat.categories`` supplies the label text.

    Attributes
    ----------
    labname : string
        Name of the value label, taken from ``catarray.name``
    value_labels : list of tuples
        (code, category) pairs, ordered by code
    n : int
        Number of labels
    txt : list of strings
        Label text, one entry per code (non-strings converted with ``str``)
    val : int32 array
        Integer code of each label
    off : int32 array
        Offset of each label inside the concatenated, null terminated text
    text_len : int
        Total length of the label text, including one terminator per label
    len : int
        Length of the value label table written after the label name

    Methods
    -------
    generate_value_label
    """

    def __init__(self, catarray):
        # Imported once here rather than on every loop iteration
        import warnings

        self.labname = catarray.name

        categories = catarray.cat.categories
        # Codes are the positions 0..n-1 of the categories, so the pairs are
        # already ordered by code and need no further sorting
        self.value_labels = list(zip(np.arange(len(categories)), categories))
        self.text_len = np.int32(0)
        self.off = []
        self.val = []
        self.txt = []
        self.n = 0

        # Compute lengths and setup lists of offsets and labels
        has_non_string = False
        for code, category in self.value_labels:
            if not isinstance(category, string_types):
                category = str(category)
                has_non_string = True

            self.off.append(self.text_len)
            self.text_len += len(category) + 1  # +1 for the padding
            self.val.append(code)
            self.txt.append(category)
            self.n += 1

        # The message is about the column as a whole, so warn once per
        # column instead of once per converted category
        if has_non_string:
            warnings.warn(value_label_mismatch_doc.format(catarray.name),
                          ValueLabelTypeMismatch)

        if self.text_len > 32000:
            raise ValueError('Stata value labels for a single variable must '
                             'have a combined length less than 32,000 '
                             'characters.')

        # Ensure int32
        self.off = np.array(self.off, dtype=np.int32)
        self.val = np.array(self.val, dtype=np.int32)

        # Total length: n (4) + textlen (4) + off (4 * n) + val (4 * n) + txt
        self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len

    def _encode(self, s):
        """
        Python 3 compatibility shim

        Uses ``self._encoding``, which is only set by
        ``generate_value_label``, so it must not be called before it.
        """
        if compat.PY3:
            return s.encode(self._encoding)
        else:
            return s

    def generate_value_label(self, byteorder, encoding):
        """
        Parameters
        ----------
        byteorder : str
            Byte order of the output
        encoding : str
            File encoding

        Returns
        -------
        value_label : bytes
            Bytes containing the formatted value label
        """

        self._encoding = encoding
        bio = BytesIO()
        null_string = '\x00'
        null_byte = b'\x00'
        int32_fmt = byteorder + 'i'

        # len
        bio.write(struct.pack(int32_fmt, self.len))

        # labname
        labname = self._encode(_pad_bytes(self.labname[:32], 33))
        bio.write(labname)

        # padding - 3 bytes
        bio.write(null_byte * 3)

        # value_label_table
        # n - int32
        bio.write(struct.pack(int32_fmt, self.n))

        # textlen - int32
        bio.write(struct.pack(int32_fmt, self.text_len))

        # off - int32 array (n elements)
        for offset in self.off:
            bio.write(struct.pack(int32_fmt, offset))

        # val - int32 array (n elements)
        for value in self.val:
            bio.write(struct.pack(int32_fmt, value))

        # txt - Text labels, null terminated
        for text in self.txt:
            bio.write(self._encode(text + null_string))

        return bio.getvalue()


class StataMissingValue(StringMixin):
"""
An observation's missing value.
Expand Down Expand Up @@ -477,25 +611,31 @@ class StataMissingValue(StringMixin):
for i in range(1, 27):
MISSING_VALUES[i + b] = '.' + chr(96 + i)

base = b'\x00\x00\x00\x7f'
float32_base = b'\x00\x00\x00\x7f'
increment = struct.unpack('<i', b'\x00\x08\x00\x00')[0]
for i in range(27):
value = struct.unpack('<f', base)[0]
value = struct.unpack('<f', float32_base)[0]
MISSING_VALUES[value] = '.'
if i > 0:
MISSING_VALUES[value] += chr(96 + i)
int_value = struct.unpack('<i', struct.pack('<f', value))[0] + increment
base = struct.pack('<i', int_value)
float32_base = struct.pack('<i', int_value)

base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
float64_base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
increment = struct.unpack('q', b'\x00\x00\x00\x00\x00\x01\x00\x00')[0]
for i in range(27):
value = struct.unpack('<d', base)[0]
value = struct.unpack('<d', float64_base)[0]
MISSING_VALUES[value] = '.'
if i > 0:
MISSING_VALUES[value] += chr(96 + i)
int_value = struct.unpack('q', struct.pack('<d', value))[0] + increment
base = struct.pack('q', int_value)
float64_base = struct.pack('q', int_value)

BASE_MISSING_VALUES = {'int8': 101,
'int16': 32741,
'int32': 2147483621,
'float32': struct.unpack('<f', float32_base)[0],
'float64': struct.unpack('<d', float64_base)[0]}

def __init__(self, value):
self._value = value
Expand All @@ -518,6 +658,22 @@ def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.string == other.string and self.value == other.value)

@classmethod
def get_base_missing_value(cls, dtype):
    """
    Look up the entry of ``BASE_MISSING_VALUES`` that matches ``dtype``

    Parameters
    ----------
    dtype : int8, int16, int32, float32 or float64
        NumPy type or dtype to look up

    Returns
    -------
    value : int or float
        The ``BASE_MISSING_VALUES`` entry for ``dtype``

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the supported types
    """
    # Checked in the same order as the supported types are listed above
    for type_name in ('int8', 'int16', 'int32', 'float32', 'float64'):
        if dtype == getattr(np, type_name):
            return cls.BASE_MISSING_VALUES[type_name]
    raise ValueError('Unsupported dtype')


class StataParser(object):
_default_encoding = 'cp1252'
Expand Down Expand Up @@ -1111,10 +1267,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
umissing, umissing_loc = np.unique(series[missing],
return_inverse=True)
replacement = Series(series, dtype=np.object)
for i, um in enumerate(umissing):
for j, um in enumerate(umissing):
missing_value = StataMissingValue(um)

loc = missing_loc[umissing_loc == i]
loc = missing_loc[umissing_loc == j]
replacement.iloc[loc] = missing_value
else: # All replacements are identical
dtype = series.dtype
Expand Down Expand Up @@ -1390,6 +1546,45 @@ def _write(self, to_write):
else:
self._file.write(to_write)

def _prepare_categoricals(self, data):
    """Check for categorical columns, retain categorical information for
    Stata file and convert categorical data to int

    Parameters
    ----------
    data : DataFrame
        Data to be written

    Returns
    -------
    data : DataFrame
        ``data`` itself when it has no categorical columns. Otherwise a new
        frame in which every categorical column is replaced by its integer
        codes, with the code -1 replaced by the Stata missing value of the
        column's (possibly upcast) type.

    Raises
    ------
    ValueError
        If the codes of a categorical column are stored as int64

    Notes
    -----
    Sets ``self._is_col_cat`` (one flag per column) and
    ``self._value_labels`` (one StataValueLabel per categorical column).
    """

    is_cat = [com.is_categorical_dtype(data[col]) for col in data]
    self._is_col_cat = is_cat
    self._value_labels = []
    if not any(is_cat):
        return data

    get_base_missing_value = StataMissingValue.get_base_missing_value
    index = data.index
    data_formatted = []
    for col, col_is_cat in zip(data, is_cat):
        if not col_is_cat:
            data_formatted.append((col, data[col]))
            continue

        self._value_labels.append(StataValueLabel(data[col]))
        codes = data[col].cat.codes
        dtype = codes.dtype
        if dtype == np.int64:
            raise ValueError('It is not possible to export int64-based '
                             'categorical data to Stata.')
        values = codes.values.copy()

        # Upcast if needed so that correct missing values can be set.
        # A column without rows has no maximum and never needs upcasting.
        if values.size and values.max() >= get_base_missing_value(dtype):
            if dtype == np.int8:
                dtype = np.int16
            elif dtype == np.int16:
                dtype = np.int32
            else:
                dtype = np.float64
            values = np.array(values, dtype=dtype)

        # Replace missing values with Stata missing value for type
        values[values == -1] = get_base_missing_value(dtype)

        # from_items expects (key, value) pairs, so the index is attached
        # to the values rather than passed as a third tuple element
        data_formatted.append((col, Series(values, index=index)))

    return DataFrame.from_items(data_formatted)

def _replace_nans(self, data):
# return data
Expand Down Expand Up @@ -1480,27 +1675,26 @@ def _check_column_names(self, data):
def _prepare_pandas(self, data):
#NOTE: we might need a different API / class for pandas objects so
# we can set different semantics - handle this with a PR to pandas.io
class DataFrameRowIter(object):
def __init__(self, data):
self.data = data

def __iter__(self):
for row in data.itertuples():
# First element is index, so remove
yield row[1:]

if self._write_index:
data = data.reset_index()
# Check columns for compatibility with stata
data = _cast_to_stata_types(data)

# Ensure column names are strings
data = self._check_column_names(data)

# Check columns for compatibility with stata, upcast if necessary
data = _cast_to_stata_types(data)

# Replace NaNs with Stata missing values
data = self._replace_nans(data)
self.datarows = DataFrameRowIter(data)

# Convert categoricals to int data, and strip labels
data = self._prepare_categoricals(data)

self.nobs, self.nvar = data.shape
self.data = data
self.varlist = data.columns.tolist()

dtypes = data.dtypes
if self._convert_dates is not None:
self._convert_dates = _maybe_convert_to_int_keys(
Expand All @@ -1515,6 +1709,7 @@ def __iter__(self):
self.fmtlist = []
for col, dtype in dtypes.iteritems():
self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col]))

# set the given format for the datetime cols
if self._convert_dates is not None:
for key in self._convert_dates:
Expand All @@ -1529,8 +1724,14 @@ def write_file(self):
self._write(_pad_bytes("", 5))
self._prepare_data()
self._write_data()
self._write_value_labels()
self._file.close()

def _write_value_labels(self):
for vl in self._value_labels:
self._file.write(vl.generate_value_label(self._byteorder,
self._encoding))

def _write_header(self, data_label=None, time_stamp=None):
byteorder = self._byteorder
# ds_format - just use 114
Expand Down Expand Up @@ -1585,9 +1786,15 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
self._write(_pad_bytes(fmt, 49))

# lbllist, 33*nvar, char array
#NOTE: this is where you could get fancy with pandas categorical type
for i in range(nvar):
self._write(_pad_bytes("", 33))
# Use variable name when categorical
if self._is_col_cat[i]:
name = self.varlist[i]
name = self._null_terminate(name, True)
name = _pad_bytes(name[:32], 33)
self._write(name)
else: # Default is empty label
self._write(_pad_bytes("", 33))

def _write_variable_labels(self, labels=None):
nvar = self.nvar
Expand Down Expand Up @@ -1624,9 +1831,6 @@ def _prepare_data(self):
data_cols.append(data[col].values)
dtype = np.dtype(dtype)

# 3. Convert to record array

# data.to_records(index=False, convert_datetime64=False)
if has_strings:
self.data = np.fromiter(zip(*data_cols), dtype=dtype)
else:
Expand Down
Loading

0 comments on commit 8d1ae49

Please sign in to comment.