Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Column dropping #1551

Merged
merged 8 commits into from
Jul 25, 2024
15 changes: 15 additions & 0 deletions python/tests/experimental/core/test_udf_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@ def test_multioutput_udf_dataframe() -> None:
assert results.get_column("bar") is not None


def test_drop_columns() -> None:
schema = udf_schema(drop_columns={"xx1", "xx2"})
df = pd.DataFrame({"xx1": [42, 7], "xx2": [3.14, 2.72]})
results = why.log(df, schema=schema).view()
assert results.get_column("xx1") is None
assert results.get_column("xx2") is None
# UDFs that needed the dropped columns as input still work
assert results.get_column("f1.foo") is not None
assert results.get_column("f1.bar") is not None
assert results.get_column("blah.foo") is not None
assert results.get_column("blah.bar") is not None
assert results.get_column("foo") is not None
assert results.get_column("bar") is not None


@register_dataset_udf(["col1"], schema_name="unit-tests")
def add5(x: Union[Dict[str, List], pd.DataFrame]) -> Union[List, pd.Series]:
return [xx + 5 for xx in x["col1"]]
Expand Down
10 changes: 10 additions & 0 deletions python/whylogs/experimental/core/udf_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
List,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -178,6 +179,7 @@ def __init__(
segments: Optional[Dict[str, SegmentationPartition]] = None,
validators: Optional[Dict[str, List[Validator]]] = None,
udf_specs: Optional[List[UdfSpec]] = None,
drop_columns: Optional[Set[str]] = None,
) -> None:
super().__init__(
resolvers=resolvers,
Expand All @@ -189,6 +191,7 @@ def __init__(
segments=segments,
validators=validators,
)
self.drop_columns = drop_columns if drop_columns else set()
udf_specs = udf_specs if udf_specs else []
self.multicolumn_udfs = [spec for spec in udf_specs if spec.column_names]
self.type_udfs = defaultdict(list)
Expand Down Expand Up @@ -242,10 +245,15 @@ def _run_udfs(
new_df = pd.DataFrame()
if row is not None:
self._run_udfs_on_row(row, new_columns, row.keys()) # type: ignore
if self.drop_columns:
for col in set(row.keys()).intersection(self.drop_columns):
row.pop(col)

if pandas is not None:
self._run_udfs_on_dataframe(pandas, new_df, pandas.keys())
new_df = pd.concat([pandas, new_df], axis=1)
if self.drop_columns:
new_df = new_df.drop(columns=list(set(new_df.keys()).intersection(self.drop_columns)))

return new_df if pandas is not None else None, new_columns

Expand Down Expand Up @@ -459,6 +467,7 @@ def udf_schema(
validators: Optional[Dict[str, List[Validator]]] = None,
schema_name: Union[str, List[str]] = "",
include_default_schema: bool = True,
drop_columns: Optional[Set[str]] = None,
) -> UdfSchema:
"""
Returns a UdfSchema that implements any registered UDFs, along with any
Expand All @@ -484,4 +493,5 @@ def udf_schema(
segments,
validators,
generate_udf_specs(other_udf_specs, schema_name, include_default_schema),
drop_columns,
)
Loading