FEAT-#7254: Support right merge/join #7226

Merged: 22 commits, May 13, 2024. Changes shown are from 16 commits.
9 changes: 6 additions & 3 deletions modin/core/storage_formats/base/query_compiler.py
@@ -22,7 +22,7 @@
import abc
import warnings
from functools import cached_property
from typing import Hashable, List, Optional
from typing import TYPE_CHECKING, Hashable, List, Optional

import numpy as np
import pandas
@@ -53,6 +53,9 @@

from . import doc_utils

if TYPE_CHECKING:
from typing_extensions import Self

Codecov (codecov/patch) warning: added line modin/core/storage_formats/base/query_compiler.py#L57 was not covered by tests.


def _get_axis(axis):
"""
@@ -150,7 +153,7 @@
else:
return obj

def default_to_pandas(self, pandas_op, *args, **kwargs):
def default_to_pandas(self, pandas_op, *args, **kwargs) -> Self:
"""
Do fallback to pandas for the passed function.

@@ -4459,7 +4462,7 @@
# END Abstract methods for QueryCompiler

@cached_property
def __constructor__(self) -> type[BaseQueryCompiler]:
def __constructor__(self) -> type[Self]:
"""
Get query compiler constructor.

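A small, self-contained sketch (hypothetical `Base`/`Child` classes, not Modin code) of what the switch to `Self` buys: methods such as `default_to_pandas` and `__constructor__` return an instance of whichever query-compiler subclass they are called on, and annotating them with `Self` lets type checkers keep track of that subclass instead of widening to `BaseQueryCompiler`.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Self


class Base:
    def copy_like(self) -> Self:
        # Returns an instance of the caller's concrete class.
        return type(self)()


class Child(Base):
    pass


# A type checker infers `Child` here; with a `-> Base` annotation it would
# only infer `Base`.
child_copy = Child().copy_like()
```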
68 changes: 40 additions & 28 deletions modin/core/storage_formats/pandas/merge.py
@@ -15,7 +15,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import pandas
from pandas.core.dtypes.common import is_list_like
@@ -103,7 +103,7 @@ def func(left, right):
@classmethod
def row_axis_merge(
cls, left: PandasQueryCompiler, right: PandasQueryCompiler, kwargs: dict
):
) -> PandasQueryCompiler:
"""
Execute merge using row-axis implementation.

@@ -126,10 +126,25 @@ def row_axis_merge(
right_index = kwargs.get("right_index", False)
sort = kwargs.get("sort", False)

if how in ["left", "inner"] and left_index is False and right_index is False:
if (
(
how in ["left", "inner"]
or (how == "right" and right._modin_frame._partitions.size != 0)
Collaborator (inline comment): same
)
and left_index is False
and right_index is False
):
kwargs["sort"] = False

def should_keep_index(left, right):
reverted = False
if how == "right":
left, right = right, left
reverted = True

def should_keep_index(
left: PandasQueryCompiler,
right: PandasQueryCompiler,
) -> bool:
keep_index = False
if left_on is not None and right_on is not None:
keep_index = any(
@@ -145,21 +160,12 @@ def should_keep_index(left, right):
return keep_index

def map_func(
left, right, *axis_lengths, kwargs=kwargs, **service_kwargs
): # pragma: no cover
df = pandas.merge(left, right, **kwargs)

if kwargs["how"] == "left":
partition_idx = service_kwargs["partition_idx"]
if len(axis_lengths):
if not should_keep_index(left, right):
# Doesn't work for "inner" case, since the partition sizes of the
# left dataframe may change
start = sum(axis_lengths[:partition_idx])
stop = sum(axis_lengths[: partition_idx + 1])

df.index = pandas.RangeIndex(start, stop)

left, right, kwargs=kwargs
) -> pandas.DataFrame: # pragma: no cover
if reverted:
df = pandas.merge(right, left, **kwargs)
else:
df = pandas.merge(left, right, **kwargs)
return df

# Want to ensure that these are python lists
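Context for the row-axis change above: pandas documents that `how="right"` uses only keys from the right frame and preserves their order, which is presumably why the implementation can swap the operands at the query-compiler level, partition by the (original) right frame, and broadcast the left one. A minimal pandas-only sketch of that equivalence (made-up data, not Modin code):

```python
import pandas

left = pandas.DataFrame({"key": ["a", "b", "c", "a"], "x": [1, 2, 3, 4]})
right = pandas.DataFrame({"key": ["b", "a", "a"], "y": [10, 20, 30]})

# One-shot result.
full = pandas.merge(left, right, on="key", how="right")

# Merge each row chunk of the right frame against the full left frame,
# mimicking "partition the right operand, broadcast the left one".
chunks = [right.iloc[:2], right.iloc[2:]]
parts = [pandas.merge(left, chunk, on="key", how="right") for chunk in chunks]
stitched = pandas.concat(parts, ignore_index=True)

pandas.testing.assert_frame_equal(full, stitched)
```

Stitching the per-chunk results back together reproduces the one-shot merge, which is the property the broadcast-apply path relies on.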
@@ -171,7 +177,11 @@ def map_func(

right_to_broadcast = right._modin_frame.combine()
new_columns, new_dtypes = cls._compute_result_metadata(
left, right, on, left_on, right_on, kwargs.get("suffixes", ("_x", "_y"))
*((left, right) if not reverted else (right, left)),
on,
left_on,
right_on,
kwargs.get("suffixes", ("_x", "_y")),
)

# We rebalance when the ratio of the number of existing partitions to
@@ -188,7 +198,6 @@ def map_func(
left._modin_frame.broadcast_apply_full_axis(
axis=1,
func=map_func,
enumerate_partitions=how == "left",
other=right_to_broadcast,
# We're going to explicitly change the shape across the 1-axis,
# so we want for partitioning to adapt as well
Expand All @@ -199,7 +208,6 @@ def map_func(
new_columns=new_columns,
sync_labels=False,
dtypes=new_dtypes,
pass_axis_lengths_to_partitions=how == "left",
)
)

@@ -238,16 +246,20 @@ def map_func(
else new_left.sort_rows_by_column_values(on)
)

return (
new_left.reset_index(drop=True)
if not keep_index and (kwargs["how"] != "left" or sort)
else new_left
)
return new_left if keep_index else new_left.reset_index(drop=True)
else:
return left.default_to_pandas(pandas.DataFrame.merge, right, **kwargs)

@classmethod
def _compute_result_metadata(cls, left, right, on, left_on, right_on, suffixes):
def _compute_result_metadata(
cls,
left: PandasQueryCompiler,
right: PandasQueryCompiler,
on,
left_on,
right_on,
suffixes,
) -> tuple[Optional[pandas.Index], Optional[ModinDtypes]]:
"""
Compute columns and dtypes metadata for the result of merge if possible.

41 changes: 28 additions & 13 deletions modin/core/storage_formats/pandas/query_compiler.py
@@ -18,6 +18,8 @@
queries for the ``PandasDataframe``.
"""

from __future__ import annotations

import ast
import hashlib
import re
@@ -313,8 +315,8 @@ def from_dataframe(cls, df, data_cls):

# END Dataframe exchange protocol

index = property(_get_axis(0), _set_axis(0))
columns = property(_get_axis(1), _set_axis(1))
index: pandas.Index = property(_get_axis(0), _set_axis(0))
columns: pandas.Index = property(_get_axis(1), _set_axis(1))

@property
def dtypes(self):
@@ -524,33 +526,46 @@ def merge(self, right, **kwargs):
get_logger().info(message)
return MergeImpl.row_axis_merge(self, right, kwargs)

def join(self, right, **kwargs):
def join(self, right: PandasQueryCompiler, **kwargs) -> PandasQueryCompiler:
on = kwargs.get("on", None)
how = kwargs.get("how", "left")
sort = kwargs.get("sort", False)
left = self

if how in ["left", "inner"]:

def map_func(left, right, kwargs=kwargs): # pragma: no cover
return pandas.DataFrame.join(left, right, **kwargs)
if how in ["left", "inner"] or (
how == "right" and right._modin_frame._partitions.size != 0
Collaborator: What happens if left has a size equal to 0?

Collaborator (author): Empty dataframes are processed at a higher level (simply defaulted to pandas), and all of the implementation logic relies on the left operand having a non-empty set of partitions. Therefore, to avoid an error, we need to handle this situation as before.
):
reverted = False
if how == "right":
left, right = right, left
reverted = True

def map_func(
left, right, kwargs=kwargs
) -> pandas.DataFrame: # pragma: no cover
if reverted:
df = pandas.DataFrame.join(right, left, **kwargs)
else:
df = pandas.DataFrame.join(left, right, **kwargs)
return df

right_to_broadcast = right._modin_frame.combine()
new_self = self.__constructor__(
self._modin_frame.broadcast_apply_full_axis(
left = left.__constructor__(
left._modin_frame.broadcast_apply_full_axis(
axis=1,
func=map_func,
# We're going to explicitly change the shape across the 1-axis,
# so we want for partitioning to adapt as well
keep_partitioning=False,
num_splits=merge_partitioning(
self._modin_frame, right._modin_frame, axis=1
left._modin_frame, right._modin_frame, axis=1
),
other=right_to_broadcast,
)
)
return new_self.sort_rows_by_column_values(on) if sort else new_self
return left.sort_rows_by_column_values(on) if sort else left
else:
return self.default_to_pandas(pandas.DataFrame.join, right, **kwargs)
return left.default_to_pandas(pandas.DataFrame.join, right, **kwargs)

# END Inter-Data operations
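The join path above uses the same idea: `DataFrame.join(how="right")` takes its rows and row order from the other operand's index, so each row chunk of the right frame can be joined against the full left frame and the pieces concatenated. A small pandas-only sketch (hypothetical frames, not Modin code):

```python
import pandas

left = pandas.DataFrame({"x": [1, 2, 3]}, index=["a", "b", "c"])
right = pandas.DataFrame({"y": [10, 20, 30]}, index=["c", "a", "b"])

full = left.join(right, how="right")

# Join each row chunk of the right frame, then stitch the pieces back together.
chunks = [right.iloc[:2], right.iloc[2:]]
parts = [left.join(chunk, how="right") for chunk in chunks]

pandas.testing.assert_frame_equal(full, pandas.concat(parts))
```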

@@ -586,7 +601,7 @@ def reindex(self, axis, labels, **kwargs):
)
return self.__constructor__(new_modin_frame)

def reset_index(self, **kwargs):
def reset_index(self, **kwargs) -> PandasQueryCompiler:
if self.lazy_execution:

def _reset(df, *axis_lengths, partition_idx): # pragma: no cover
Changes in an additional file (filename not shown):
@@ -1454,9 +1454,13 @@ def _join_by_index(self, other_modin_frames, how, sort, ignore_index):
condition=condition,
)

new_columns = Index.__new__(
Index, data=new_columns, dtype=new_columns_dtype
)
# in the case of heterogeneous data, using the `dtype` parameter of the
# `Index` constructor can lead to the following error:
# `ValueError: string values cannot be losslessly cast to int64`
new_columns = Index(new_columns)
if new_columns.dtype != new_columns_dtype and new_columns_dtype is not None:
# ValueError: string values cannot be losslessly cast to int64
new_columns = new_columns.astype(new_columns_dtype)
lhs = lhs.__constructor__(
dtypes=lhs._dtypes_for_exprs(exprs),
columns=new_columns,
@@ -1994,8 +1998,6 @@ def sort_rows(self, columns, ascending, ignore_index, na_position):
drop_index_cols_after = [
col for col in base._index_cols if col in columns
]
if not drop_index_cols_after:
drop_index_cols_after = None

if drop_index_cols_before:
exprs = dict()
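Regarding the `_join_by_index` change above, a hedged sketch of the defensive pattern it uses, with an illustrative helper and data (not the actual Modin code): the `Index` is built without forcing a dtype, and a cast is attempted only when a target dtype was requested and differs from the inferred one.

```python
import pandas


def build_index(values, target_dtype=None) -> pandas.Index:
    # Passing dtype= straight to the Index constructor can raise for
    # heterogeneous data (e.g. strings mixed with integers), so infer first.
    idx = pandas.Index(values)
    if target_dtype is not None and idx.dtype != target_dtype:
        idx = idx.astype(target_dtype)
    return idx


print(build_index(["a", 0, "b"]))          # object dtype, no cast attempted
print(build_index([0, 1, 2], "int64"))     # already int64, no cast needed
```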
Changes in an additional file (filename not shown):
@@ -1347,13 +1347,13 @@ def build_row_idx_filter_expr(row_idx, row_col):
return row_col.eq(row_idx)

if is_range_like(row_idx):
start = row_idx[0]
stop = row_idx[-1]
start = row_idx.start
stop = row_idx.stop
step = row_idx.step
if step < 0:
start, stop = stop, start
step = -step
exprs = [row_col.ge(start), row_col.le(stop)]
exprs = [row_col.ge(start), row_col.cmp("<", stop)]
if step > 1:
mod = OpExpr("MOD", [row_col, LiteralExpr(step)], _get_dtype(int))
exprs.append(mod.eq(0))
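A plain-Python illustration of the boundary change in `build_row_idx_filter_expr` (hypothetical row ids, unit step for simplicity): `range.stop` is exclusive, so switching from the last element (`row_idx[-1]`) to `row_idx.stop` requires comparing the upper bound with `<` instead of `<=`.

```python
row_ids = list(range(10))   # hypothetical row ids in a column
row_idx = range(2, 8)       # rows 2..7; 8 itself is excluded

# start <= row < stop mirrors ge(start) + cmp("<", stop) in the expression above.
selected = [r for r in row_ids if row_idx.start <= r < row_idx.stop]
assert selected == list(row_idx)
```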