Skip to content

Commit

Permalink
ENH: Allow usecols to accept callable (GH14154)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Nov 25, 2016
1 parent b1d9599 commit 1986920
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 26 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.20.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Other enhancements
^^^^^^^^^^^^^^^^^^

- ``pd.read_excel`` now preserves sheet order when using ``sheetname=None`` (:issue:`9930`)
- The ``usecols`` argument in ``pd.read_csv`` now accepts a callable, which is evaluated against each column name to decide whether that column is kept (:issue:`14154`)


.. _whatsnew_0200.api_breaking:
Expand Down
45 changes: 30 additions & 15 deletions pandas/io/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pandas.util.decorators import Appender

import pandas.lib as lib
import pandas.core.common as com
import pandas.parser as _parser


Expand Down Expand Up @@ -86,13 +87,14 @@
MultiIndex is used. If you have a malformed file with delimiters at the end
of each line, you might consider index_col=False to force pandas to _not_
use the first column as the index (row names)
usecols : array-like, default None
Return a subset of the columns. All elements in this array must either
usecols : array-like or callable, default None
Return a subset of the columns. If array-like, all elements must either
be positional (i.e. integer indices into the document columns) or strings
that correspond to column names provided either by the user in `names` or
inferred from the document header row(s). For example, a valid `usecols`
parameter would be [0, 1, 2] or ['foo', 'bar', 'baz']. Using this parameter
results in much faster parsing time and lower memory usage.
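parameter would be [0, 1, 2] or ['foo', 'bar', 'baz']. If callable, it is
called with each column name and the columns for which it returns a true
value are kept, e.g. lambda x: x.upper() in ['AAA', 'BBB', 'DDD']. Using this
parameter results in much faster parsing time and lower memory usage.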
as_recarray : boolean, default False
DEPRECATED: this argument will be removed in a future version. Please call
`pd.read_csv(...).to_records()` instead.
Expand Down Expand Up @@ -976,17 +978,26 @@ def _is_index_col(col):
return col is not None and col is not False


def _evaluate_usecols(usecols, names):
if callable(usecols):
return set([i for i, name in enumerate(names)
if com._apply_if_callable(usecols, name)])
return usecols


def _validate_usecols_arg(usecols):
"""
Check whether or not the 'usecols' parameter
contains all integers (column selection by index)
or strings (column by name). Raises a ValueError
if that is not the case.
contains all integers (column selection by index),
strings (column by name) or is a callable. Raises
a ValueError if that is not the case.
"""
msg = ("The elements of 'usecols' must "
"either be all strings, all unicode, or all integers")
msg = ("'usecols' must either be all strings, all unicode, "
"all integers or a callable")

if usecols is not None:
if callable(usecols):
return usecols
usecols_dtype = lib.infer_dtype(usecols)
if usecols_dtype not in ('empty', 'integer',
'string', 'unicode'):
Expand Down Expand Up @@ -1426,11 +1437,12 @@ def __init__(self, src, **kwds):
self.orig_names = self.names[:]

if self.usecols:
if len(self.names) > len(self.usecols):
usecols = _evaluate_usecols(self.usecols, self.orig_names)
if len(self.names) > len(usecols):
self.names = [n for i, n in enumerate(self.names)
if (i in self.usecols or n in self.usecols)]
if (i in usecols or n in usecols)]

if len(self.names) < len(self.usecols):
if len(self.names) < len(usecols):
raise ValueError("Usecols do not match names.")

self._set_noconvert_columns()
Expand Down Expand Up @@ -1592,9 +1604,10 @@ def read(self, nrows=None):

def _filter_usecols(self, names):
    """
    Restrict 'names' to the columns selected by self.usecols.

    self.usecols is first passed through _evaluate_usecols so that a
    callable is turned into a set of column positions. When the result is
    not None and its length differs from len(names), only the names whose
    position or value appears in it are kept; otherwise 'names' is
    returned unchanged.
    """
    # hackish
    # Always work on the evaluated value: a callable self.usecols supports
    # neither len() nor membership tests.
    usecols = _evaluate_usecols(self.usecols, names)
    if usecols is not None and len(names) != len(usecols):
        names = [name for i, name in enumerate(names)
                 if i in usecols or name in usecols]
    return names

def _get_index_names(self):
Expand Down Expand Up @@ -2207,7 +2220,9 @@ def _handle_usecols(self, columns, usecols_key):
usecols_key is used if there are string usecols.
"""
if self.usecols is not None:
if any([isinstance(col, string_types) for col in self.usecols]):
if callable(self.usecols):
col_indices = _evaluate_usecols(self.usecols, usecols_key)
elif any([isinstance(u, string_types) for u in self.usecols]):
if len(columns) > 1:
raise ValueError("If using multiple headers, usecols must "
"be integers.")
Expand Down
32 changes: 28 additions & 4 deletions pandas/io/tests/parser/usecols.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def test_raise_on_mixed_dtype_usecols(self):
1000,2000,3000
4000,5000,6000
"""
msg = ("The elements of 'usecols' must "
"either be all strings, all unicode, or all integers")

msg = ("'usecols' must either be all strings, all unicode, "
"all integers or a callable")
usecols = [0, 'b', 2]

with tm.assertRaisesRegexp(ValueError, msg):
Expand Down Expand Up @@ -302,8 +303,8 @@ def test_usecols_with_mixed_encoding_strings(self):
3.568935038,7,False,a
'''

msg = ("The elements of 'usecols' must "
"either be all strings, all unicode, or all integers")
msg = ("'usecols' must either be all strings, all unicode, "
"all integers or a callable")

with tm.assertRaisesRegexp(ValueError, msg):
self.read_csv(StringIO(s), usecols=[u'AAA', b'BBB'])
Expand Down Expand Up @@ -366,3 +367,26 @@ def test_np_array_usecols(self):
expected = DataFrame([[1, 2]], columns=usecols)
result = self.read_csv(StringIO(data), usecols=usecols)
tm.assert_frame_equal(result, expected)

def test_callable_usecols(self):
    """A callable 'usecols' keeps the columns whose names it accepts.

    See gh-14154.
    """
    csv_text = '''AaA,bBb,CCC,ddd
0.056674973,8,True,a
2.613230982,2,False,b
3.568935038,7,False,a
'''

    def keep_column(x):
        # Case-insensitive match, so 'AaA', 'bBb' and 'ddd' are selected
        # while 'CCC' is dropped.
        return x.upper() in ['AAA', 'BBB', 'DDD']

    first_column = {
        0: 0.056674972999999997,
        1: 2.6132309819999997,
        2: 3.5689350380000002
    }
    second_column = {0: 8, 1: 2, 2: 7}
    last_column = {0: 'a', 1: 'b', 2: 'a'}
    expected = DataFrame({
        'AaA': first_column,
        'bBb': second_column,
        'ddd': last_column
    })

    result = self.read_csv(StringIO(csv_text), usecols=keep_column)
    tm.assert_frame_equal(result, expected)
28 changes: 21 additions & 7 deletions pandas/parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ cimport util

import pandas.lib as lib
import pandas.compat as compat
import pandas.core.common as com
from pandas.types.common import (is_categorical_dtype, CategoricalDtype,
is_integer_dtype, is_float_dtype,
is_bool_dtype, is_object_dtype,
Expand Down Expand Up @@ -300,8 +301,10 @@ cdef class TextReader:
object compression
object mangle_dupe_cols
object tupleize_cols
object usecols
list dtype_cast_order
set noconvert, usecols
set noconvert


def __cinit__(self, source,
delimiter=b',',
Expand Down Expand Up @@ -437,7 +440,10 @@ cdef class TextReader:
# suboptimal
if usecols is not None:
self.has_usecols = 1
self.usecols = set(usecols)
if callable(usecols):
self.usecols = usecols
else:
self.usecols = set(usecols)

# XXX
if skipfooter > 0:
Expand Down Expand Up @@ -701,7 +707,6 @@ cdef class TextReader:
cdef StringPath path = _string_path(self.c_encoding)

header = []

if self.parser.header_start >= 0:

# Header is in the file
Expand Down Expand Up @@ -821,7 +826,7 @@ cdef class TextReader:
# 'data has %d fields'
# % (passed_count, field_count))

if self.has_usecols and self.allow_leading_cols:
if self.has_usecols and self.allow_leading_cols and not callable(self.usecols):
nuse = len(self.usecols)
if nuse == passed_count:
self.leading_cols = 0
Expand Down Expand Up @@ -1015,17 +1020,25 @@ cdef class TextReader:

results = {}
nused = 0

for i in range(self.table_width):

if i < self.leading_cols:
# Pass through leading columns always
name = i
elif self.usecols and nused == len(self.usecols):
elif self.usecols and not callable(self.usecols) and nused == len(self.usecols):
# Once we've gathered all requested columns, stop. GH5766
break
else:
name = self._get_column_name(i, nused)
if self.has_usecols and not (i in self.usecols or
name in self.usecols):
usecols = set()
if callable(self.usecols):
if com._apply_if_callable(self.usecols, name):
usecols = set([i])
else:
usecols = self.usecols
if self.has_usecols and not (i in usecols or
name in usecols):
continue
nused += 1

Expand Down Expand Up @@ -1341,6 +1354,7 @@ def _maybe_upcast(arr):

return arr


cdef enum StringPath:
CSTRING
UTF8
Expand Down

0 comments on commit 1986920

Please sign in to comment.