Skip to content

Commit

Permalink
Resolves #3467 (#3429)
Browse files: browse the repository at this point in the history
Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
  • Loading branch information
ajpotts and ajpotts authored Jul 22, 2024
1 parent c8f8d69 commit e14ff6a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
41 changes: 41 additions & 0 deletions PROTO_tests/tests/dataframe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import arkouda as ak
from arkouda import io_util
from arkouda.scipy import chisquare as akchisquare
from arkouda.testing import assert_frame_equal as ak_assert_frame_equal


def alternating_1_0(n):
Expand Down Expand Up @@ -239,6 +240,46 @@ def test_dataframe_creation(self, size):
ak_to_pd = akdf.to_pandas()
assert_frame_equal(pddf, ak_to_pd)

@pytest.mark.parametrize("size", pytest.prob_size)
@pytest.mark.parametrize("dtype", ["float64", "int64"])
def test_from_pandas_with_index(self, size, dtype):
    """Build an ``ak.DataFrame`` from a pandas frame that has a non-default index.

    The converted frame must report the same ``inferred_type`` for its index
    as the pandas frame, and must equal an ``ak.DataFrame`` constructed
    directly from the equivalent arkouda array and ``ak.Index``.
    """
    # Alter so that index is different from default: every other value, negated.
    # NOTE(review): multiplying by the float -1.0 presumably promotes the
    # "int64" case to float -- confirm that case still covers an integer index.
    np_values = np.arange(size, dtype=dtype)[::2] * -1.0
    pd_df = pd.DataFrame({"col1": np_values}, index=pd.Index(np_values))

    ak_df = ak.DataFrame(pd_df)
    assert ak_df.index.inferred_type == pd_df.index.inferred_type

    # Same data assembled purely from arkouda objects, used as the reference.
    ak_values = ak.arange(size, dtype=dtype)[::2] * -1.0
    expected_df = ak.DataFrame({"col1": ak_values}, index=ak.Index(ak_values))

    ak_assert_frame_equal(ak_df, expected_df)

@pytest.mark.parametrize("size", pytest.prob_size)
@pytest.mark.parametrize("dtype", ["float64", "int64"])
def test_round_trip_pandas_conversion(self, size, dtype):
    """Round-trip an indexed ``ak.DataFrame`` through pandas and back.

    The frame is exported with ``to_pandas(retain_index=True)`` and rebuilt
    with ``ak.DataFrame``; the result must equal the original frame.
    """
    # Alter so that index is different from default: every other value, negated.
    values = ak.arange(size, dtype=dtype)[::2] * -1.0
    original_df = ak.DataFrame({"col1": values}, index=ak.Index(values))

    as_pandas = original_df.to_pandas(retain_index=True)
    round_trip_df = ak.DataFrame(as_pandas)

    ak_assert_frame_equal(original_df, round_trip_df)

@pytest.mark.parametrize("size", pytest.prob_size)
def test_round_trip_dataframe_conversion2(self, size):
    """Round-trip a float-indexed ``ak.DataFrame`` through pandas and back.

    The index values are ``arange(size) + 0.001``, so they differ from the
    default integer labels. The frame is exported with
    ``to_pandas(retain_index=True)``, rebuilt with ``ak.DataFrame``, and must
    equal the original frame.
    """
    a = ak.arange(size, dtype="float64") + 0.001

    idx = ak.Index(a)
    df = ak.DataFrame({"col1": a}, index=idx)
    pd_df = df.to_pandas(retain_index=True)
    round_trip_df = ak.DataFrame(pd_df)

    # Use the helper imported from arkouda.testing, matching the sibling
    # round-trip tests, rather than reaching through the top-level namespace.
    ak_assert_frame_equal(df, round_trip_df)

def test_convenience_init(self):
dict1 = {"0": [1, 2], "1": [True, False], "2": ["foo", "bar"], "3": [2.3, -1.8]}
dict2 = {"0": (1, 2), "1": (True, False), "2": ("foo", "bar"), "3": (2.3, -1.8)}
Expand Down
9 changes: 5 additions & 4 deletions arkouda/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from arkouda.client import generic_msg, maxTransferBytes
from arkouda.client_dtypes import BitVector, Fields, IPv4
from arkouda.dtypes import BigInt

from arkouda.dtypes import bool_ as akbool
from arkouda.dtypes import float64 as akfloat64
from arkouda.dtypes import int64 as akint64
Expand Down Expand Up @@ -884,6 +883,7 @@ def __init__(self, initialdata=None, index=None, columns=None):
self.data = initialdata.data
self.update_nrows()
return

elif isinstance(initialdata, pd.DataFrame):
# copy pd.DataFrame data into the ak.DataFrame object
self._nrows = initialdata.shape[0]
Expand All @@ -892,14 +892,15 @@ def __init__(self, initialdata=None, index=None, columns=None):
self._columns = initialdata.columns.tolist()

if index is None:
self._set_index(initialdata.index.values.tolist())
self._set_index(initialdata.index)
else:
self._set_index(index)
self.data = {}
for key in initialdata.columns:
self.data[key] = (
SegArray.from_multi_array([array(r) for r in initialdata[key]])
if isinstance(initialdata[key][0], (list, np.ndarray))
if hasattr(initialdata[key], "values")
and isinstance(initialdata[key].values[0], (list, np.ndarray))
else array(initialdata[key])
)

Expand Down Expand Up @@ -1849,7 +1850,7 @@ def index(self):
def _set_index(self, value):
if isinstance(value, Index) or value is None:
self._index = value
elif isinstance(value, (pdarray, Strings)):
elif isinstance(value, (pdarray, Strings, pd.Index)):
self._index = Index(value)
elif isinstance(value, list):
self._index = Index(array(value))
Expand Down

0 comments on commit e14ff6a

Please sign in to comment.