Skip to content

Commit

Permalink
feat: update plotting regression/prediction interval
Browse files Browse the repository at this point in the history
  • Loading branch information
Spencer Sun authored and hmgomes committed May 14, 2024
1 parent 4baad0f commit bf3f7ce
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/capymoa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def __init__(self, schema=None, CLI=None, random_seed=1, moa_learner=None):

self.moa_learner.prepareForUse()
self.moa_learner.resetLearning()
self.moa_learner.setModelContext(self.schema.get_moa_header())

def __str__(self):
full_name = str(self.moa_learner.getClass().getCanonicalName())
Expand Down
148 changes: 117 additions & 31 deletions src/capymoa/evaluation/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,13 @@ def plot_regression_results(
plot_predictions=True,
plot_residuals=True,

target_type='line', # line or dots
# target_type='line', # line or dots
add_target_markers=True,
target_marker='*', # can be any markers supported by matplotlib

predictions_type='dots', # line or dots
predictions_marker='.', # can be any markers supported by matplotlib

absolute_residuals=False,

plot_hist_residuals=False,
Expand Down Expand Up @@ -293,30 +297,46 @@ def plot_regression_results(

# plot targets
if plot_target:
if target_type == 'line':
plt.plot(instance_numbers, targets, label="targets", linewidth=3,
plt.plot(instance_numbers, targets, label="targets", linewidth=1,
color=color_target if color_target is not None else "g")
elif target_type == 'dots':
plt.scatter(instance_numbers, targets, label="targets", marker=target_marker, s=20,
if add_target_markers:
plt.scatter(instance_numbers, targets, label="targets", marker=target_marker, s=20,
color=color_target if color_target is not None else "g")
else:
raise ValueError("Target_type must be 'line' or 'dots'.")

# plot predictions
if plot_predictions:
for i, prediction in enumerate(predictions):
plt.plot(instance_numbers, predictions[i],
label=results[i]['learner'] + " predictions",
color=color_predictions[i] if color_predictions is not None else default_colors[i],
linewidth=1, linestyle="--", alpha=0.5)
if predictions_type == 'line':
plt.plot(instance_numbers, predictions[i],
label=results[i]['learner'] + " predictions",
color=color_predictions[i] if color_predictions is not None else default_colors[i],
linewidth=1, linestyle="--", alpha=0.5)
elif predictions_type == 'dots':
plt.scatter(instance_numbers, predictions[i],
label=results[i]['learner'] + " predictions",
color=color_predictions[i] if color_predictions is not None else default_colors[i],
marker=predictions_marker, s=20)
else:
raise ValueError("Predictions_type must be 'line' or 'dots'.")

if predictions_type == 'dots':
if len(results) > 2 :
plot_residuals = False

for i in range(len(instance_numbers)):
values = [predictions[x][i] for x in range(len(predictions))]
values.append(targets[i])
values = np.array(values)
plt.vlines(x=instance_numbers[i], ymin=min(values), ymax=max(values), linestyles='dashed', colors='grey',
linewidth=0.5)

# plot residuals
if plot_residuals:
for i, residual in enumerate(residuals):
plt.bar(instance_numbers, residuals[i] if not absolute_residuals else absolute_values[i],
label=results[i]['learner'] + " residuals" if not absolute_residuals else " absolute residuals",
color=color_predictions[i] if color_predictions is not None else default_colors[i], alpha=0.5)


if stream is not None and isinstance(stream, DriftStream):
if not prevent_plotting_drifts:
drifts = stream.get_drifts()
Expand Down Expand Up @@ -430,6 +450,8 @@ def plot_prediction_interval(
figure_name=None,
save_only=False,

dynamic_switch=True,

prevent_plotting_drifts=False,

):
Expand Down Expand Up @@ -471,7 +493,7 @@ def plot_prediction_interval(
plt.plot(instance_numbers, l, linewidth=0.1, alpha=0.2,
color=colors[0] if colors is not None else default_colors[0])
plt.fill_between(instance_numbers, u, l, color=colors[0] if colors is not None else default_colors[0],
alpha=0.3, label=results[0]["learner"] + " interval")
alpha=0.5, label=results[0]["learner"] + " interval")
if plot_predictions:
plt.plot(instance_numbers, np.array(predictions), linewidth=1, linestyle='-',
color=colors[0] if colors is not None else default_colors[0],
Expand Down Expand Up @@ -592,24 +614,88 @@ def plot_prediction_interval(
u_second = np.array(upper_second)
l_second = np.array(lower_second)

# Plot first area
plt.plot(instance_numbers, u_first, linewidth=0.1, alpha=0.2,
color=colors[0] if colors is not None else default_colors[0])
plt.plot(instance_numbers, l_first, linewidth=0.1, alpha=0.2,
color=colors[0] if colors is not None else default_colors[0])
plt.fill_between(instance_numbers, u_first, l_first,
color=colors[0] if colors is not None else default_colors[0],
alpha=0.1, label=results[0]["learner"] + " interval")

# Plot second area
plt.plot(instance_numbers, u_second, linewidth=0.1, alpha=0.4,
color=colors[1] if colors is not None else default_colors[1])
plt.plot(instance_numbers, l_second, linewidth=0.1, alpha=0.4,
color=colors[1] if colors is not None else default_colors[1])
plt.fill_between(instance_numbers, u_second, l_second,
color=colors[1] if colors is not None else default_colors[1],
alpha=0.3, label=results[1]["learner"] + " interval")

if not dynamic_switch:
# Plot first area
plt.plot(instance_numbers, u_first, linewidth=0.1, alpha=0.4,
color=colors[0] if colors is not None else default_colors[0])
plt.plot(instance_numbers, l_first, linewidth=0.1, alpha=0.4,
color=colors[0] if colors is not None else default_colors[0])
plt.fill_between(instance_numbers, u_first, l_first,
color=colors[0] if colors is not None else default_colors[0],
alpha=0.4, label=results[0]["learner"] + " interval")

# Plot second area
plt.plot(instance_numbers, u_second, linewidth=0.1, alpha=0.5,
color=colors[1] if colors is not None else default_colors[1])
plt.plot(instance_numbers, l_second, linewidth=0.1, alpha=0.5,
color=colors[1] if colors is not None else default_colors[1])
plt.fill_between(instance_numbers, u_second, l_second,
color=colors[1] if colors is not None else default_colors[1],
alpha=0.5, label=results[1]["learner"] + " interval")
else:
# define function for further dynamic plot
def _plot_first(i, alpha):
plt.plot(instance_numbers[switch_points[i]:switch_points[i+1]+1],
u_first[switch_points[i]:switch_points[i + 1]+1], linewidth=0.1, alpha=alpha,
color=colors[0] if colors is not None else default_colors[0])
plt.plot(instance_numbers[switch_points[i]:switch_points[i + 1]+1],
l_first[switch_points[i]:switch_points[i + 1]+1], linewidth=0.1, alpha=alpha,
color=colors[0] if colors is not None else default_colors[0])

plt.fill_between(instance_numbers[switch_points[i]:switch_points[i + 1]+1],
u_first[switch_points[i]:switch_points[i + 1]+1],
l_first[switch_points[i]:switch_points[i + 1]+1],
color=colors[0] if colors is not None else default_colors[0],
alpha=alpha, label=results[0]["learner"] + " interval" if i == 0 else "")

def _plot_second(i, alpha):
plt.plot(instance_numbers[switch_points[i]:switch_points[i + 1]+1],
u_second[switch_points[i]:switch_points[i + 1]+1], linewidth=0.1, alpha=alpha,
color=colors[1] if colors is not None else default_colors[1])
plt.plot(instance_numbers[switch_points[i]:switch_points[i + 1]+1],
l_second[switch_points[i]:switch_points[i + 1]+1], linewidth=0.1, alpha=alpha,
color=colors[1] if colors is not None else default_colors[1])

plt.fill_between(instance_numbers[switch_points[i ]:switch_points[i + 1]+1],
u_second[switch_points[i]:switch_points[i + 1]+1],
l_second[switch_points[i]:switch_points[i + 1]+1],
color=colors[1] if colors is not None else default_colors[1],
alpha=alpha, label=results[1]["learner"] + " interval" if i == 0 else "")

# determine which on top first
first_first = l_first[0] > l_second[0]
# find the switch point
larger = True
switch_points = [0]
for i in range(0, len(l_first)):
if larger:
if l_first[i] < l_second[i]:
switch_points.append(i)
larger = not larger
else:
if l_first[i] > l_second[i]:
switch_points.append(i)
larger = not larger
switch_points.append(len(u_first) -1)

# Plot dynamic switching areas
for i in range(len(switch_points) - 1):
if first_first:
if i % 2 == 0:
_plot_first(i, alpha=0.4)
_plot_second(i, alpha=0.5)
else:
_plot_second(i, alpha=0.4)
_plot_first(i, alpha=0.5)
else:
if i % 2 == 0:
_plot_second(i, alpha=0.4)
_plot_first(i, alpha=0.5)
else:
_plot_first(i, alpha=0.4)
_plot_second(i, alpha=0.5)

# Plot predictions
if plot_predictions:
plt.plot(instance_numbers, np.array(predictions_first), linewidth=1, linestyle='-',
color=colors[0] if colors is not None else default_colors[0],
Expand Down

0 comments on commit bf3f7ce

Please sign in to comment.