From 198692094685e4940acfda943c1417d015034939 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Fri, 25 Nov 2016 18:06:31 -0500 Subject: [PATCH] ENH: Allow usecols to accept callable (GH14154) --- doc/source/whatsnew/v0.20.0.txt | 1 + pandas/io/parsers.py | 45 ++++++++++++++++++++----------- pandas/io/tests/parser/usecols.py | 32 +++++++++++++++++++--- pandas/parser.pyx | 28 ++++++++++++++----- 4 files changed, 80 insertions(+), 26 deletions(-) diff --git a/doc/source/whatsnew/v0.20.0.txt b/doc/source/whatsnew/v0.20.0.txt index 65b62601c70224..4f837437a86027 100644 --- a/doc/source/whatsnew/v0.20.0.txt +++ b/doc/source/whatsnew/v0.20.0.txt @@ -31,6 +31,7 @@ Other enhancements ^^^^^^^^^^^^^^^^^^ - ``pd.read_excel`` now preserves sheet order when using ``sheetname=None`` (:issue:`9930`) +- The ``usecols`` argument in ``pd.read_csv`` now accepts a callable function as a value (:issue:`14154`) .. _whatsnew_0200.api_breaking: diff --git a/pandas/io/parsers.py b/pandas/io/parsers.py index 929b360854d5bc..50611a0dfcd14d 100755 --- a/pandas/io/parsers.py +++ b/pandas/io/parsers.py @@ -34,6 +34,7 @@ from pandas.util.decorators import Appender import pandas.lib as lib +import pandas.core.common as com import pandas.parser as _parser @@ -86,13 +87,14 @@ MultiIndex is used. If you have a malformed file with delimiters at the end of each line, you might consider index_col=False to force pandas to _not_ use the first column as the index (row names) -usecols : array-like, default None - Return a subset of the columns. All elements in this array must either +usecols : array-like or callable, default None + Return a subset of the columns. If array-like, all elements must either be positional (i.e. integer indices into the document columns) or strings that correspond to column names provided either by the user in `names` or inferred from the document header row(s). For example, a valid `usecols` - parameter would be [0, 1, 2] or ['foo', 'bar', 'baz']. 
Using this parameter - results in much faster parsing time and lower memory usage. + parameter would be [0, 1, 2], ['foo', 'bar', 'baz'] or lambda x: x.upper() + in ['AAA', 'BBB', 'DDD']. If callable, it is evaluated against each column + name and the columns for which it returns a true value are kept. Using this + parameter results in much faster parsing time and lower memory usage. as_recarray : boolean, default False DEPRECATED: this argument will be removed in a future version. Please call `pd.read_csv(...).to_records()` instead. @@ -976,17 +978,26 @@ def _is_index_col(col): return col is not None and col is not False +def _evaluate_usecols(usecols, names): + if callable(usecols): + return set([i for i, name in enumerate(names) + if com._apply_if_callable(usecols, name)]) + return usecols + + def _validate_usecols_arg(usecols): """ Check whether or not the 'usecols' parameter - contains all integers (column selection by index) - or strings (column by name). Raises a ValueError - if that is not the case. + contains all integers (column selection by index), + strings (column by name) or is a callable. Raises + a ValueError if that is not the case. 
""" - msg = ("The elements of 'usecols' must " - "either be all strings, all unicode, or all integers") + msg = ("'usecols' must either be all strings, all unicode, " + "all integers or a callable") if usecols is not None: + if callable(usecols): + return usecols usecols_dtype = lib.infer_dtype(usecols) if usecols_dtype not in ('empty', 'integer', 'string', 'unicode'): @@ -1426,11 +1437,12 @@ def __init__(self, src, **kwds): self.orig_names = self.names[:] if self.usecols: - if len(self.names) > len(self.usecols): + usecols = _evaluate_usecols(self.usecols, self.orig_names) + if len(self.names) > len(usecols): self.names = [n for i, n in enumerate(self.names) - if (i in self.usecols or n in self.usecols)] + if (i in usecols or n in usecols)] - if len(self.names) < len(self.usecols): + if len(self.names) < len(usecols): raise ValueError("Usecols do not match names.") self._set_noconvert_columns() @@ -1592,9 +1604,10 @@ def read(self, nrows=None): def _filter_usecols(self, names): # hackish - if self.usecols is not None and len(names) != len(self.usecols): + usecols = _evaluate_usecols(self.usecols, names) + if usecols is not None and len(names) != len(usecols): names = [name for i, name in enumerate(names) - if i in self.usecols or name in self.usecols] + if i in usecols or name in usecols] return names def _get_index_names(self): @@ -2207,7 +2220,9 @@ def _handle_usecols(self, columns, usecols_key): usecols_key is used if there are string usecols. 
""" if self.usecols is not None: - if any([isinstance(col, string_types) for col in self.usecols]): + if callable(self.usecols): + col_indices = _evaluate_usecols(self.usecols, usecols_key) + elif any([isinstance(u, string_types) for u in self.usecols]): if len(columns) > 1: raise ValueError("If using multiple headers, usecols must " "be integers.") diff --git a/pandas/io/tests/parser/usecols.py b/pandas/io/tests/parser/usecols.py index 5051171ccb8f07..7ee088631276dc 100644 --- a/pandas/io/tests/parser/usecols.py +++ b/pandas/io/tests/parser/usecols.py @@ -23,8 +23,9 @@ def test_raise_on_mixed_dtype_usecols(self): 1000,2000,3000 4000,5000,6000 """ - msg = ("The elements of 'usecols' must " - "either be all strings, all unicode, or all integers") + + msg = ("'usecols' must either be all strings, all unicode, " + "all integers or a callable") usecols = [0, 'b', 2] with tm.assertRaisesRegexp(ValueError, msg): @@ -302,8 +303,8 @@ def test_usecols_with_mixed_encoding_strings(self): 3.568935038,7,False,a ''' - msg = ("The elements of 'usecols' must " - "either be all strings, all unicode, or all integers") + msg = ("'usecols' must either be all strings, all unicode, " + "all integers or a callable") with tm.assertRaisesRegexp(ValueError, msg): self.read_csv(StringIO(s), usecols=[u'AAA', b'BBB']) @@ -366,3 +367,26 @@ def test_np_array_usecols(self): expected = DataFrame([[1, 2]], columns=usecols) result = self.read_csv(StringIO(data), usecols=usecols) tm.assert_frame_equal(result, expected) + + def test_callable_usecols(self): + # See gh-14154 + s = '''AaA,bBb,CCC,ddd + 0.056674973,8,True,a + 2.613230982,2,False,b + 3.568935038,7,False,a + ''' + + data = { + 'AaA': { + 0: 0.056674972999999997, + 1: 2.6132309819999997, + 2: 3.5689350380000002 + }, + 'bBb': {0: 8, 1: 2, 2: 7}, + 'ddd': {0: 'a', 1: 'b', 2: 'a'} + } + expected = DataFrame(data) + + df = self.read_csv(StringIO(s), usecols=lambda x: + x.upper() in ['AAA', 'BBB', 'DDD']) + tm.assert_frame_equal(df, expected) 
diff --git a/pandas/parser.pyx b/pandas/parser.pyx index 6b43dfbabc4a0d..97782ad9a0e59c 100644 --- a/pandas/parser.pyx +++ b/pandas/parser.pyx @@ -38,6 +38,7 @@ cimport util import pandas.lib as lib import pandas.compat as compat +import pandas.core.common as com from pandas.types.common import (is_categorical_dtype, CategoricalDtype, is_integer_dtype, is_float_dtype, is_bool_dtype, is_object_dtype, @@ -300,8 +301,10 @@ cdef class TextReader: object compression object mangle_dupe_cols object tupleize_cols + object usecols list dtype_cast_order - set noconvert, usecols + set noconvert + def __cinit__(self, source, delimiter=b',', @@ -437,7 +440,10 @@ cdef class TextReader: # suboptimal if usecols is not None: self.has_usecols = 1 - self.usecols = set(usecols) + if callable(usecols): + self.usecols = usecols + else: + self.usecols = set(usecols) # XXX if skipfooter > 0: @@ -701,7 +707,6 @@ cdef class TextReader: cdef StringPath path = _string_path(self.c_encoding) header = [] - if self.parser.header_start >= 0: # Header is in the file @@ -821,7 +826,7 @@ cdef class TextReader: # 'data has %d fields' # % (passed_count, field_count)) - if self.has_usecols and self.allow_leading_cols: + if self.has_usecols and self.allow_leading_cols and not callable(self.usecols): nuse = len(self.usecols) if nuse == passed_count: self.leading_cols = 0 @@ -1015,17 +1020,25 @@ cdef class TextReader: results = {} nused = 0 + for i in range(self.table_width): + if i < self.leading_cols: # Pass through leading columns always name = i - elif self.usecols and nused == len(self.usecols): + elif self.usecols and not callable(self.usecols) and nused == len(self.usecols): # Once we've gathered all requested columns, stop. 
GH5766 break else: name = self._get_column_name(i, nused) - if self.has_usecols and not (i in self.usecols or - name in self.usecols): + usecols = set() + if callable(self.usecols): + if com._apply_if_callable(self.usecols, name): + usecols = set([i]) + else: + usecols = self.usecols + if self.has_usecols and not (i in usecols or + name in usecols): continue nused += 1 @@ -1341,6 +1354,7 @@ def _maybe_upcast(arr): return arr + cdef enum StringPath: CSTRING UTF8