Skip to content

Commit

Permalink
minimal fix to resolve #707 (#720)
Browse files Browse the repository at this point in the history
* minimal fix to resolve #707
* lint
* remove .bool()
* operate on series not df

---------

Co-authored-by: ras44 <rolandrmgservices@gmail.com>
  • Loading branch information
ras44 and rolandrmgservices authored Dec 1, 2023
1 parent bae55f5 commit ae6c28e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
31 changes: 18 additions & 13 deletions causalml/metrics/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,13 @@ def get_cumlift(
Returns:
(pandas.DataFrame): average uplifts of model estimates in cumulative population
"""

assert (
(outcome_col in df.columns)
and (treatment_col in df.columns)
or treatment_effect_col in df.columns
(outcome_col in df.columns and df[outcome_col].notnull().all())
and (treatment_col in df.columns and df[treatment_col].notnull().all())
or (
treatment_effect_col in df.columns
and df[treatment_effect_col].notnull().all()
)
)

df = df.copy()
Expand Down Expand Up @@ -214,9 +216,12 @@ def get_qini(
(pandas.DataFrame): cumulative gains of model estimates in population
"""
assert (
(outcome_col in df.columns)
and (treatment_col in df.columns)
or treatment_effect_col in df.columns
(outcome_col in df.columns and df[outcome_col].notnull().all())
and (treatment_col in df.columns and df[treatment_col].notnull().all())
or (
treatment_effect_col in df.columns
and df[treatment_effect_col].notnull().all()
)
)

df = df.copy()
Expand Down Expand Up @@ -310,9 +315,9 @@ def get_tmlegain(
(pandas.DataFrame): cumulative gains of model estimates based of TMLE
"""
assert (
(outcome_col in df.columns)
and (treatment_col in df.columns)
or p_col in df.columns
(outcome_col in df.columns and df[outcome_col].notnull().all())
and (treatment_col in df.columns and df[treatment_col].notnull().all())
or (p_col in df.columns and df[p_col].notnull().all())
)

inference_col = [x for x in inference_col if x in df.columns]
Expand Down Expand Up @@ -416,9 +421,9 @@ def get_tmleqini(
(pandas.DataFrame): cumulative gains of model estimates based of TMLE
"""
assert (
(outcome_col in df.columns)
and (treatment_col in df.columns)
or p_col in df.columns
(outcome_col in df.columns and df[outcome_col].notnull().all())
and (treatment_col in df.columns and df[treatment_col].notnull().all())
or (p_col in df.columns and df[p_col].notnull().all())
)

inference_col = [x for x in inference_col if x in df.columns]
Expand Down
14 changes: 14 additions & 0 deletions tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pandas as pd
import numpy as np
import pytest
from causalml.metrics.visualize import get_cumlift


def test_visualize_get_cumlift_errors_on_nan():
df = pd.DataFrame(
[[0, np.nan, 0.5], [1, np.nan, 0.1], [1, 1, 0.4], [0, 1, 0.3], [1, 1, 0.2]],
columns=["w", "y", "pred"],
)

with pytest.raises(Exception):
get_cumlift(df)

0 comments on commit ae6c28e

Please sign in to comment.