feat: clustering base classes with simple visualization
Co-authored-by: Heitor Murilo Gomes <heitor_murilo_gomes@yahoo.com.br>
cassales and hmgomes authored Aug 2, 2024
1 parent e25f6b6 commit 09b3a60
Showing 7 changed files with 690 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/tutorials.rst
@@ -28,6 +28,7 @@ These tutorials will show you how to get started with the CapyMOA library.
notebooks/drift_detection.ipynb
notebooks/parallel_ensembles.ipynb
notebooks/save_and_load_model.ipynb
notebooks/clustering.ipynb

Talks
=====
291 changes: 291 additions & 0 deletions notebooks/clustering.ipynb

Large diffs are not rendered by default.

124 changes: 124 additions & 0 deletions src/capymoa/base.py
@@ -2,6 +2,7 @@
from typing import Optional, Union

from jpype import _jpype
import jpype
from moa.classifiers import (
Classifier as MOA_Classifier_Interface,
Regressor as MOA_Regressor_Interface,
@@ -531,3 +532,126 @@ def score_instance(self, instance):
# However, if it is not the case for a MOA learner, this method should be overridden.
prediction_array = self.moa_learner.getVotesForInstance(instance.java_instance)
return prediction_array[0]

##############################################################
######################### Clustering #########################
##############################################################
class Clusterer(ABC):
def __init__(self, schema: Schema, random_seed=1):
self.random_seed = random_seed
self.schema = schema
if self.schema is None:
raise ValueError("Schema must be initialised")

@abstractmethod
def __str__(self):
pass

@abstractmethod
def train(self, instance: Instance):
pass

# @abstractmethod
# def predict(self, instance: Instance) -> Optional[LabelIndex]:
# pass

# @abstractmethod
# def predict_proba(self, instance: Instance) -> LabelProbabilities:
# pass

class MOAClusterer(Clusterer):
"""
A wrapper class for using MOA (Massive Online Analysis) clusterers in CapyMOA.
Attributes:
- schema: The schema representing the instances. Defaults to None.
- CLI: The command-line interface (CLI) configuration for the MOA learner.
- random_seed: The random seed for reproducibility. Defaults to 1.
- moa_learner: The MOA learner object or class identifier.
"""

def __init__(self, moa_learner, schema=None, CLI=None):
super().__init__(schema=schema)
self.CLI = CLI
# If moa_learner is a class identifier instead of an object
if isinstance(moa_learner, type):
if type(moa_learner) == _jpype._JClass:
moa_learner = moa_learner()
else: # this is not a Java object, thus it certainly isn't a MOA learner
raise ValueError("Invalid MOA clusterer provided.")
self.moa_learner = moa_learner

# self.moa_learner.setRandomSeed(self.random_seed)

if self.schema is not None:
self.moa_learner.setModelContext(self.schema.get_moa_header())

# If the CLI is None, we assume the object has already been configured
# or that default values should be used.
if self.CLI is not None:
self.moa_learner.getOptions().setViaCLIString(CLI)

self.moa_learner.prepareForUse()
self.moa_learner.resetLearningImpl()
self.moa_learner.setModelContext(schema.get_moa_header())

def __str__(self):
# Removes the package information from the name of the learner.
full_name = str(self.moa_learner.getClass().getCanonicalName())
return full_name.rsplit(".", 1)[1] if "." in full_name else full_name

def CLI_help(self):
return str(self.moa_learner.getOptions().getHelpString())

def train(self, instance):
self.moa_learner.trainOnInstance(instance.java_instance.getData())

def get_micro_clusters_centers(self):
ret = []
for c in self.moa_learner.getMicroClusteringResult().getClustering():
java_array = c.getCenter()[:-1]
python_array = [java_array[i] for i in range(len(java_array))] # Convert to Python list
ret.append(python_array)
return ret

def get_micro_clusters_radii(self):
ret = []
for c in self.moa_learner.getMicroClusteringResult().getClustering():
ret.append(c.getRadius())
return ret

def get_micro_clusters_weights(self):
ret = []
for c in self.moa_learner.getMicroClusteringResult().getClustering():
ret.append(c.getWeight())
return ret

def get_clusters_centers(self):
ret = []
for c in self.moa_learner.getClusteringResult().getClustering():
java_array = c.getCenter()[:-1]
python_array = [java_array[i] for i in range(len(java_array))] # Convert to Python list
ret.append(python_array)
return ret

def get_clusters_radii(self):
ret = []
for c in self.moa_learner.getClusteringResult().getClustering():
ret.append(c.getRadius())
return ret

def get_clusters_weights(self):
ret = []
for c in self.moa_learner.getClusteringResult().getClustering():
ret.append(c.getWeight())
return ret


# def predict(self, instance):
# return Utils.maxIndex(
# self.moa_learner.getVotesForInstance(instance.java_instance)
# )

# def predict_proba(self, instance):
# return self.moa_learner.getVotesForInstance(instance.java_instance)
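A hedged usage sketch of the new wrapper follows; it is illustrative only and not part of this commit. It assumes that importing CapyMOA starts the JVM and puts the MOA jar on the jpype classpath (so moa.clusterers.clustream.Clustream is importable), and it uses the small ElectricityTiny dataset purely as a convenient source of numeric instances.

# Minimal usage sketch for MOAClusterer (illustrative, not part of this commit).
# Assumptions: importing capymoa starts the JVM and exposes MOA classes through
# jpype; ElectricityTiny is a small CapyMOA demo dataset downloaded on first use.
from capymoa.base import MOAClusterer
from capymoa.datasets import ElectricityTiny
from moa.clusterers.clustream import Clustream  # MOA's CluStream implementation

stream = ElectricityTiny()
clusterer = MOAClusterer(moa_learner=Clustream(), schema=stream.get_schema())

# Feed the first 1000 instances to the clusterer.
for _ in range(1000):
    clusterer.train(stream.next_instance())

# Inspect the micro-clustering maintained by CluStream.
print(clusterer.get_micro_clusters_centers()[:3])
print(clusterer.get_micro_clusters_weights()[:3])

Passing an already-constructed Java object (Clustream()) side-steps the class-identifier branch of the constructor and is the simplest way to exercise the wrapper.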

2 changes: 2 additions & 0 deletions src/capymoa/evaluation/__init__.py
@@ -10,6 +10,7 @@
PredictionIntervalEvaluator,
PredictionIntervalWindowedEvaluator,
AnomalyDetectionEvaluator,
ClusteringEvaluator,
)
from . import results

@@ -25,5 +26,6 @@
"PredictionIntervalEvaluator",
"PredictionIntervalWindowedEvaluator",
"AnomalyDetectionEvaluator",
"ClusteringEvaluator",
"results"
]
67 changes: 66 additions & 1 deletion src/capymoa/evaluation/evaluation.py
@@ -10,7 +10,13 @@
import os

from capymoa.stream import Schema, Stream
from capymoa.base import AnomalyDetector, ClassifierSSL, MOAPredictionIntervalLearner

from capymoa.base import (
AnomalyDetector,
ClassifierSSL,
MOAPredictionIntervalLearner,
Clusterer
)

from capymoa.evaluation.results import PrequentialResults
from capymoa._utils import _translate_metric_name
@@ -585,6 +591,65 @@ def s_auc(self):
index = self.metrics_header().index("s_auc")
return self.metrics()[index]

class ClusteringEvaluator:
# TODO improve documentation of ClusteringEvaluator
"""
Clustering evaluator for CapyMOA.
It differs from the other evaluators in that it does not wrap a moa_evaluator object.
Clustering evaluation at this point is very simple and only records unsupervised measurements (micro-cluster centers, weights, and radii).
"""
def __init__(self, update_interval=1000):
"""
Only the update_interval is set here.
"""
self.instances_seen = 0
self.update_interval = update_interval
self.measurements = {name: [] for name in self.metrics_header()}
# self.clusterer = None
self.clusterer_name = None

def __str__(self):
return str(self.metrics_dict())

def get_instances_seen(self):
return self.instances_seen

def get_update_interval(self):
return self.update_interval

def get_clusterer_name(self):
return self.clusterer_name

def update(self, clusterer: Clusterer):
if self.clusterer_name is None:
self.clusterer_name = str(clusterer)
self.instances_seen += 1
if self.instances_seen % self.update_interval == 0:
self._update_measurements(clusterer)

def _update_measurements(self, clusterer: Clusterer):
# update centers, weights, sizes, and radii
self.measurements["m_centers"].append(clusterer.get_micro_clusters_centers())
self.measurements["m_weights"].append(clusterer.get_micro_clusters_weights())
self.measurements["m_radii"].append(clusterer.get_micro_clusters_radii())
# if there is a way to get cluster IDs, add it below
# self.measurements["m_IDs"].append(clusterer.get_clusters_ids())

# calculate silhouette score
# TODO: delegate silhouette to moa
# Check how it is done among different clusterers

def metrics_header(self):
performance_names = ["m_centers", "m_weights", "m_radii"]
return performance_names

def metrics(self):
# using the static list to keep the order of the metrics
return [self.measurements[key] for key in self.metrics_header()]

def get_measurements(self):
return self.measurements


class ClassificationWindowedEvaluator(ClassificationEvaluator):
"""
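A hedged sketch of driving the new ClusteringEvaluator from a training loop; the set-up mirrors the base.py sketch above and nothing here is part of the commit itself.

# Usage sketch for ClusteringEvaluator (illustrative, not part of this commit).
# Set-up mirrors the base.py sketch: CluStream via jpype, ElectricityTiny as a
# convenient source of numeric instances.
from capymoa.base import MOAClusterer
from capymoa.datasets import ElectricityTiny
from capymoa.evaluation import ClusteringEvaluator
from moa.clusterers.clustream import Clustream

stream = ElectricityTiny()
clusterer = MOAClusterer(moa_learner=Clustream(), schema=stream.get_schema())
evaluator = ClusteringEvaluator(update_interval=1000)

while stream.has_more_instances():
    clusterer.train(stream.next_instance())
    evaluator.update(clusterer)  # records centers/weights/radii every 1000 instances

measurements = evaluator.get_measurements()
print(evaluator.get_clusterer_name(), len(measurements["m_centers"]), "snapshots")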
122 changes: 121 additions & 1 deletion src/capymoa/evaluation/visualization.py
@@ -10,7 +10,11 @@
PrequentialResults,
# PrequentialRegressionResults
)

from capymoa.base import Clusterer
import os
import shutil
from PIL import Image
import glob

import matplotlib.pyplot as plt
import numpy as np
@@ -833,3 +837,119 @@ def _plot_second(i, alpha):
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
figure_name = figure_name if figure_name else f"prediction_interval_over_time_comparison_{current_time}.pdf"
plt.savefig(figure_path + figure_name)

def _plot_clustering_state(
clusterer_name,
centers,
weights=None,
radii=None,
cluster_ids=None,
figure_path="./",
figure_name=None,
show_fig=True,
save_fig=False,
make_gif=False,
):
fig, ax = plt.subplots()
# Use a colormap to represent weights
if weights is not None:
scatter = ax.scatter(*zip(*centers), c=weights, cmap='copper', label='Centers', s=100, edgecolor='k')
cbar = fig.colorbar(scatter)
cbar.set_label('Weights')

# Add circles representing the radius of each center
if radii is not None:
for (x, y), radius in zip(centers, radii):
circle = plt.Circle((x, y), radius, color='red', fill=False)
ax.add_patch(circle)

# Annotate the centers with cluster IDs
if cluster_ids is not None:
for (x, y), cluster_id in zip(centers, cluster_ids):
ax.text(x, y, str(cluster_id), fontsize=7, ha='center', va='center', color='white')
else:
for (x, y), cluster_id in zip(centers, range(len(centers))):
ax.text(x, y, str(cluster_id), fontsize=7, ha='center', va='center', color='white')

# Add labels and title
output_name = f'Clustering from {clusterer_name}'
ax.set_xlabel('F1')
ax.set_ylabel('F2')
ax.set_title(output_name)
ax.legend()
ax.axis('equal') # Ensure that the circles are not distorted
# Show the plot or save it to the specified path
if show_fig:
plt.show()
else:
plt.close(fig)
if make_gif:
ax.set_title(output_name + f'_{figure_name}')
# print(ax.get_xlim(), ax.get_ylim())
minx, maxx, miny, maxy = ax.get_xlim()[0], ax.get_xlim()[1], ax.get_ylim()[0], ax.get_ylim()[1]
return fig, minx, maxx, miny, maxy
else:
# not a gif, use timestamp
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
figure_name = figure_name if figure_name else f"clustering_result_{current_time}"
fig.savefig(figure_path + figure_name + '.png', dpi=300)

def plot_clustering_state(
clusterer: Clusterer,
plot_radii=True,
plot_weights=True,
plot_IDs=True,
figure_path="./",
figure_name=None,
show_fig=True,
save_fig=False,
make_gif=False,
):
centers = clusterer.get_micro_clusters_centers()
weights = clusterer.get_micro_clusters_weights() if plot_weights else None
radii = clusterer.get_micro_clusters_radii() if plot_radii else None
# Assuming cluster IDs are 0, 1, 2, ..., N-1
cluster_ids = range(len(centers)) if plot_IDs else None
_plot_clustering_state(str(clusterer), centers, weights, radii, cluster_ids, figure_path, figure_name, show_fig, save_fig, make_gif)

def plot_clustering_evolution(clusteringResults, clean_up=True, dpi=300):
centers = clusteringResults.get_measurements()['m_centers']
weights = clusteringResults.get_measurements()['m_weights']
radii = clusteringResults.get_measurements()['m_radii']
gif_path = './gifmaker/'
os.makedirs(gif_path, exist_ok=True)
figs = []

maxx, maxy, minx, miny = -np.inf, -np.inf, np.inf, np.inf
for i in range(len(centers)):
fig, e_minx, e_maxx, e_miny, e_maxy = _plot_clustering_state(clusteringResults.clusterer_name, centers[i], weights[i], radii[i], figure_path=gif_path, figure_name=f'{i:05}', show_fig=False, save_fig=True, make_gif=True)
if e_minx < minx:
minx = e_minx
if e_maxx > maxx:
maxx = e_maxx
if e_miny < miny:
miny = e_miny
if e_maxy > maxy:
maxy = e_maxy
figs.append(fig)

# make the images with shared x and y lim
for f in figs:
f.gca().set_xlim([minx, maxx])
f.gca().set_ylim([miny, maxy])
f.savefig(f'{gif_path}{f.gca().get_title()}.png', dpi=dpi)
plt.close(f)

# Open images and store them in a list
images = [Image.open(img) for img in sorted(glob.glob(gif_path + '*.png'))]
# Create a GIF from the images
images[0].save(
'output.gif',
save_all=True,
append_images=images[1:],
duration=500, # Duration of each frame in milliseconds
loop=1 # 0 means loop forever; set to 1 for single loop
)
# clean up after making the gif
if clean_up:
shutil.rmtree(gif_path)
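Finally, a hedged sketch of the two plotting helpers. Because the helper scatters the center coordinates as F1 versus F2, the sketch assumes a stream with exactly two numeric features; "clusters_2d.arff" is a placeholder path for such a file, and the GIF frames plus output.gif are written to the working directory.

# Usage sketch for the new plotting helpers (illustrative, not part of this
# commit). Assumptions: "clusters_2d.arff" is a placeholder for an ARFF file
# with two numeric features plus a target column and at least a few thousand
# rows, so the micro-cluster centers are two-dimensional and can be plotted.
from capymoa.base import MOAClusterer
from capymoa.evaluation import ClusteringEvaluator
from capymoa.evaluation.visualization import (
    plot_clustering_state,
    plot_clustering_evolution,
)
from capymoa.stream import stream_from_file
from moa.clusterers.clustream import Clustream

stream = stream_from_file("clusters_2d.arff")  # placeholder path
clusterer = MOAClusterer(moa_learner=Clustream(), schema=stream.get_schema())
evaluator = ClusteringEvaluator(update_interval=500)

while stream.has_more_instances():
    clusterer.train(stream.next_instance())
    evaluator.update(clusterer)

# Static snapshot of the current micro-clusters (radii, weights and IDs shown).
plot_clustering_state(clusterer)

# Animated GIF built from the evaluator's snapshots; frames are written to
# ./gifmaker/ (removed when clean_up=True) and the animation to ./output.gif.
plot_clustering_evolution(evaluator, clean_up=True)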
