Skip to content

Commit

Permalink
REF: implement io.pytables.DataCol._get_atom (pandas-dev#30102)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and proost committed Dec 19, 2019
1 parent c982cc5 commit faa642a
Showing 1 changed file with 59 additions and 22 deletions.
81 changes: 59 additions & 22 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
from pandas.core.dtypes.common import (
ensure_object,
is_categorical_dtype,
is_complex_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_extension_array_dtype,
is_list_like,
is_string_dtype,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCExtensionArray
Expand Down Expand Up @@ -2353,16 +2355,48 @@ def set_atom(self, block, data_converted, use_str: bool):
# set as a data block
self.set_atom_data(block)

def get_atom_string(self, shape, itemsize):
@classmethod
def _get_atom(cls, values: Union[np.ndarray, ABCExtensionArray]) -> "Col":
"""
Get an appropriately typed and shaped pytables.Col object for values.
"""

dtype = values.dtype
itemsize = dtype.itemsize

shape = values.shape
if values.ndim == 1:
# EA, use block shape pretending it is 2D
shape = (1, values.size)

if is_categorical_dtype(dtype):
codes = values.codes
atom = cls.get_atom_data(shape, kind=codes.dtype.name)
elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype):
atom = cls.get_atom_datetime64(shape)
elif is_timedelta64_dtype(dtype):
atom = cls.get_atom_timedelta64(shape)
elif is_complex_dtype(dtype):
atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0])

elif is_string_dtype(dtype):
atom = cls.get_atom_string(shape, itemsize)

else:
atom = cls.get_atom_data(shape, kind=dtype.name)

return atom

@classmethod
def get_atom_string(cls, shape, itemsize):
return _tables().StringCol(itemsize=itemsize, shape=shape[0])

def set_atom_string(self, data_converted: np.ndarray):
itemsize = data_converted.dtype.itemsize
self.kind = "string"
self.typ = self.get_atom_string(data_converted.shape, itemsize)
self.set_data(data_converted)

def get_atom_coltype(self, kind: str) -> Type["Col"]:
@classmethod
def get_atom_coltype(cls, kind: str) -> Type["Col"]:
""" return the PyTables column class for this column """
if kind.startswith("uint"):
k4 = kind[4:]
Expand All @@ -2373,18 +2407,16 @@ def get_atom_coltype(self, kind: str) -> Type["Col"]:

return getattr(_tables(), col_name)

def get_atom_data(self, shape, kind: str) -> "Col":
return self.get_atom_coltype(kind=kind)(shape=shape[0])
@classmethod
def get_atom_data(cls, shape, kind: str) -> "Col":
return cls.get_atom_coltype(kind=kind)(shape=shape[0])

def set_atom_complex(self, block):
self.kind = block.dtype.name
itemsize = int(self.kind.split("complex")[-1]) // 8
self.typ = _tables().ComplexCol(itemsize=itemsize, shape=block.shape[0])
self.set_data(block.values)

def set_atom_data(self, block):
self.kind = block.dtype.name
self.typ = self.get_atom_data(block.shape, kind=block.dtype.name)
self.set_data(block.values)

def set_atom_categorical(self, block):
Expand All @@ -2401,7 +2433,6 @@ def set_atom_categorical(self, block):

# write the codes; must be in a block shape
self.ordered = values.ordered
self.typ = self.get_atom_data(block.shape, kind=codes.dtype.name)
self.set_data(block.values)

# write the categories
Expand All @@ -2410,12 +2441,12 @@ def set_atom_categorical(self, block):
assert self.kind == "integer", self.kind
assert self.dtype == codes.dtype.name, codes.dtype.name

def get_atom_datetime64(self, block):
return _tables().Int64Col(shape=block.shape[0])
@classmethod
def get_atom_datetime64(cls, shape):
return _tables().Int64Col(shape=shape[0])

def set_atom_datetime64(self, block):
self.kind = "datetime64"
self.typ = self.get_atom_datetime64(block)
self.set_data(block.values)

def set_atom_datetime64tz(self, block):
Expand All @@ -2424,15 +2455,14 @@ def set_atom_datetime64tz(self, block):
self.tz = _get_tz(block.values.tz)

self.kind = "datetime64"
self.typ = self.get_atom_datetime64(block)
self.set_data(block.values)

def get_atom_timedelta64(self, block):
return _tables().Int64Col(shape=block.shape[0])
@classmethod
def get_atom_timedelta64(cls, shape):
return _tables().Int64Col(shape=shape[0])

def set_atom_timedelta64(self, block):
self.kind = "timedelta64"
self.typ = self.get_atom_timedelta64(block)
self.set_data(block.values)

@property
Expand Down Expand Up @@ -2558,16 +2588,20 @@ def validate_names(self):
# TODO: should the message here be more specifically non-str?
raise ValueError("cannot have non-object label DataIndexableCol")

def get_atom_string(self, shape, itemsize):
@classmethod
def get_atom_string(cls, shape, itemsize):
return _tables().StringCol(itemsize=itemsize)

def get_atom_data(self, shape, kind: str) -> "Col":
return self.get_atom_coltype(kind=kind)()
@classmethod
def get_atom_data(cls, shape, kind: str) -> "Col":
return cls.get_atom_coltype(kind=kind)()

def get_atom_datetime64(self, block):
@classmethod
def get_atom_datetime64(cls, shape):
return _tables().Int64Col()

def get_atom_timedelta64(self, block):
@classmethod
def get_atom_timedelta64(cls, shape):
return _tables().Int64Col()


Expand Down Expand Up @@ -3922,8 +3956,11 @@ def get_blk_items(mgr, blocks):
errors=self.errors,
)

typ = klass._get_atom(data_converted)

col = klass.create_for_block(i=i, name=new_name, version=self.version)
col.values = list(b_items)
col.typ = typ
col.set_atom(block=b, data_converted=data_converted, use_str=use_str)
col.update_info(self.info)
col.set_pos(j)
Expand Down

0 comments on commit faa642a

Please sign in to comment.