Skip to content

Commit

Permalink
minimal fix to resolve #707
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandrmgservices committed Nov 28, 2023
1 parent 3615bc8 commit 3c0b7a6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
12 changes: 12 additions & 0 deletions causalml/metrics/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def get_cumlift(
or treatment_effect_col in df.columns
)

assert not (
(df[[outcome_col, treatment_col, treatment_effect_col]].isnull().values.any())
)

df = df.copy()
np.random.seed(random_seed)
random_cols = []
Expand Down Expand Up @@ -219,6 +223,10 @@ def get_qini(
or treatment_effect_col in df.columns
)

assert not (
(df[[outcome_col, treatment_col, treatment_effect_col]].isnull().values.any())
)

df = df.copy()
np.random.seed(random_seed)
random_cols = []
Expand Down Expand Up @@ -315,6 +323,8 @@ def get_tmlegain(
or p_col in df.columns
)

assert not ((df[[outcome_col, treatment_col, p_col]].isnull().values.any()))

inference_col = [x for x in inference_col if x in df.columns]

# Initialize TMLE
Expand Down Expand Up @@ -421,6 +431,8 @@ def get_tmleqini(
or p_col in df.columns
)

assert not ((df[[outcome_col, treatment_col, p_col]].isnull().values.any()))

inference_col = [x for x in inference_col if x in df.columns]

# Initialize TMLE
Expand Down
14 changes: 14 additions & 0 deletions tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pandas as pd
import numpy as np
import pytest
from causalml.metrics.visualize import get_cumlift


def test_visualize_get_cumlift_errors_on_nan():
df = pd.DataFrame(
[[0, np.nan, 0.5], [1, np.nan, 0.1], [1, 1, 0.4], [0, 1, 0.3], [1, 1, 0.2]],
columns=["w", "y", "pred"],
)

with pytest.raises(Exception):
get_cumlift(df)

0 comments on commit 3c0b7a6

Please sign in to comment.