diff --git a/causalml/metrics/visualize.py b/causalml/metrics/visualize.py index a2923a09..8dd62d08 100644 --- a/causalml/metrics/visualize.py +++ b/causalml/metrics/visualize.py @@ -113,11 +113,17 @@ def get_cumlift( sorted_df["cumsum_tr"] = sorted_df[treatment_col].cumsum() sorted_df["cumsum_ct"] = sorted_df.index.values - sorted_df["cumsum_tr"] sorted_df["cumsum_y_tr"] = ( - sorted_df[outcome_col] * sorted_df[treatment_col] - ).fillna(0).cumsum(skipna=True).astype(float) + (sorted_df[outcome_col] * sorted_df[treatment_col]) + .fillna(0) + .cumsum(skipna=True) + .astype(float) + ) sorted_df["cumsum_y_ct"] = ( - sorted_df[outcome_col] * (1 - sorted_df[treatment_col]) - ).fillna(0).cumsum(skipna=True).astype(float) + (sorted_df[outcome_col] * (1 - sorted_df[treatment_col])) + .fillna(0) + .cumsum(skipna=True) + .astype(float) + ) lift.append( sorted_df["cumsum_y_tr"] / sorted_df["cumsum_tr"]