diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py
index a5327f3b484..c7384b21ea4 100644
--- a/modin/core/dataframe/pandas/dataframe/dataframe.py
+++ b/modin/core/dataframe/pandas/dataframe/dataframe.py
@@ -3321,8 +3321,11 @@ def _extract_partitions(self):
         if self._partitions.size > 0:
             return self._partitions
         else:
+            dtypes = None
+            if self.has_materialized_dtypes:
+                dtypes = self.dtypes
             return self._partition_mgr_cls.create_partition_from_metadata(
-                index=self.index, columns=self.columns
+                index=self.index, columns=self.columns, dtypes=dtypes
             )
 
     @lazy_metadata_decorator(apply_axis="both")
diff --git a/modin/core/dataframe/pandas/partitioning/partition_manager.py b/modin/core/dataframe/pandas/partitioning/partition_manager.py
index d7fa8640066..678091d19ad 100644
--- a/modin/core/dataframe/pandas/partitioning/partition_manager.py
+++ b/modin/core/dataframe/pandas/partitioning/partition_manager.py
@@ -21,7 +21,7 @@
 import warnings
 from abc import ABC
 from functools import wraps
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import numpy as np
 import pandas
@@ -183,12 +183,16 @@ def preprocess_func(cls, map_func):
     # END Abstract Methods
 
     @classmethod
-    def create_partition_from_metadata(cls, **metadata):
+    def create_partition_from_metadata(
+        cls, dtypes: Optional[pandas.Series] = None, **metadata
+    ):
         """
         Create NumPy array of partitions that holds an empty dataframe with given metadata.
 
         Parameters
         ----------
+        dtypes : pandas.Series, optional
+            Dtypes to cast the empty dataframe to; when None, no cast is applied.
         **metadata : dict
             Metadata that has to be wrapped in a partition.
 
@@ -197,7 +201,10 @@ def create_partition_from_metadata(cls, **metadata):
         np.ndarray
             A NumPy 2D array of a single partition which contains the data.
         """
         metadata_dataframe = pandas.DataFrame(**metadata)
+        # `astype(None)` would cast every column to float64; cast only when given.
+        if dtypes is not None:
+            metadata_dataframe = metadata_dataframe.astype(dtypes)
         return np.array([[cls._partition_class.put(metadata_dataframe)]])
 
     @classmethod
diff --git a/modin/tests/pandas/dataframe/test_join_sort.py b/modin/tests/pandas/dataframe/test_join_sort.py
index 1959894dac6..983782e8ad0 100644
--- a/modin/tests/pandas/dataframe/test_join_sort.py
+++ b/modin/tests/pandas/dataframe/test_join_sort.py
@@ -186,5 +186,5 @@ def test_join_empty(how):
     data = np.random.randint(0, 100, size=(64, 64))
     eval_general(
         *create_test_dfs(data),
-        lambda df: df.join(df.iloc[:0], how=how, lsuffix="_caller"),
+        lambda df: df.join(df.iloc[:0], on=1, how=how, lsuffix="_caller"),
     )