Skip to content

Commit

Permalink
alignment for shap
Browse files Browse the repository at this point in the history
  • Loading branch information
givasile committed Oct 26, 2024
1 parent c4db912 commit dc058c1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
5 changes: 4 additions & 1 deletion effector/global_effect_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def plot(
nof_shap_values: Union[int, str] = "all",
show_avg_output: bool = False,
y_limits: Optional[List] = None,
only_shap_values: bool = False,
) -> None:
"""
Plot the SHAP Dependence Plot (SDP) of the s-th feature.
Expand All @@ -282,6 +283,7 @@ def plot(
nof_shap_values: number of shap values to show on top of the SHAP curve
show_avg_output: whether to show the average output of the model
y_limits: limits of the y-axis
only_shap_values: whether to plot only the shap values
"""
heterogeneity = helpers.prep_confidence_interval(heterogeneity)

Expand Down Expand Up @@ -330,5 +332,6 @@ def plot(
avg_output=avg_output,
feature_names=self.feature_names,
target_name=self.target_name,
y_limits=y_limits
y_limits=y_limits,
only_shap_values=only_shap_values,
)
12 changes: 8 additions & 4 deletions effector/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,21 +246,24 @@ def plot_shap(
avg_output: typing.Union[None, float] = None,
feature_names: typing.Union[None, list] = None,
target_name: typing.Union[None, str] = None,
y_limits: typing.Union[None, tuple] = None
y_limits: typing.Union[None, tuple] = None,
only_shap_values: bool = False
):

fig, ax = plt.subplots()
ax.set_title("SHAP Dependence Plot")

# scale x-axis
x = x if scale_x is None else trans_affine(x, scale_x["mean"], scale_x["std"])
xx = xx if scale_x is None else trans_affine(xx, scale_x["mean"], scale_x["std"])
if xx is not None:
xx = xx if scale_x is None else trans_affine(xx, scale_x["mean"], scale_x["std"])

# scale y-axis
if scale_y is not None:
y_std = trans_scale(y_std, scale_y["std"], square=False)
y = trans_affine(y, scale_y["mean"], scale_y["std"])
yy = trans_affine(yy, scale_y["mean"], scale_y["std"])
if yy is not None:
yy = trans_affine(yy, scale_y["mean"], scale_y["std"])
# if avg_output is not None:
# avg_output = trans_scale(avg_output, scale_y["std"], square=False)

Expand All @@ -278,7 +281,8 @@ def plot_shap(
ax.plot(xx[0], yy[0], "rx", alpha=0.5, label="SHAP values")
ax.plot(xx, yy, "rx", alpha=0.5)

ax.plot(x, y, "b-", label="SHAP curve")
if not only_shap_values:
ax.plot(x, y, "b-", label="SHAP curve")

if avg_output is not None:
ax.axhline(y=avg_output, color="black", linestyle="--", label="avg output")
Expand Down

0 comments on commit dc058c1

Please sign in to comment.