Skip to content

Commit

Permalink
Merge pull request #4313 from ales-erjavec/silhouette-plot-dist-matri…
Browse files Browse the repository at this point in the history
…x-input

[ENH] Silhouette plot: Accept distance matrix on input
  • Loading branch information
janezd authored Jan 10, 2020
2 parents 27aabbf + a9dab28 commit 667af57
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 67 deletions.
197 changes: 130 additions & 67 deletions Orange/widgets/visualize/owsilhouetteplot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import itertools
import enum

from xml.sax.saxutils import escape
from types import SimpleNamespace as namespace

from typing import Optional
from typing import Optional, Union

import numpy as np
import sklearn.metrics
Expand All @@ -13,6 +12,7 @@
QGraphicsScene, QGraphicsWidget, QGraphicsGridLayout,
QGraphicsItemGroup, QGraphicsSimpleTextItem, QGraphicsRectItem,
QSizePolicy, QStyleOptionGraphicsItem, QWidget, QWIDGETSIZE_MAX,
QVBoxLayout
)
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QFontMetrics, QPalette
from AnyQt.QtCore import Qt, QEvent, QRectF, QSizeF, QSize, QPointF
Expand All @@ -22,6 +22,9 @@

import Orange.data
import Orange.distance
import Orange.misc
from Orange.data import Table, Domain
from Orange.misc import DistMatrix

from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
Expand All @@ -39,6 +42,20 @@
ROW_NAMES_WIDTH = 200


class InputValidationError(ValueError):
    """
    Base class for errors raised while validating the widget's input.

    Instances expose a human readable ``message`` attribute, which a
    subclass either declares on the class or assigns per instance.
    """
    message: str

    def __init__(self, *args):
        # A subclass that declares a class-level ``message`` may be raised
        # without arguments; forward that message to ``ValueError`` so that
        # ``str(err)`` and tracebacks are not empty in that case.
        if not args:
            default = getattr(type(self), "message", None)
            if default is not None:
                args = (default,)
        super().__init__(*args)


class NoGroupVariable(InputValidationError):
    """
    Signals that the input does not have any suitable labels.
    """
    message = "Input does not have any suitable labels"


class ValidationError(InputValidationError):
    """
    Input validation error carrying a message supplied by the raiser.
    """
    def __init__(self, message: str):
        # Keep the text available as ``.message`` in addition to passing it
        # on to the base exception.
        self.message = message
        super().__init__(message)


class OWSilhouettePlot(widget.OWWidget):
name = "Silhouette Plot"
description = "Visually assess cluster quality and " \
Expand All @@ -49,7 +66,7 @@ class OWSilhouettePlot(widget.OWWidget):
keywords = []

class Inputs:
data = Input("Data", Orange.data.Table)
data = Input("Data", (Orange.data.Table, Orange.misc.DistMatrix))

class Outputs:
selected_data = Output("Selected Data", Orange.data.Table, default=True)
Expand Down Expand Up @@ -90,6 +107,7 @@ class Error(widget.OWWidget.Error):
singleton_clusters_all = Msg("All clusters are singletons")
memory_error = Msg("Not enough memory")
value_error = Msg("Distances could not be computed: '{}'")
input_validation_error = Msg("{}")

class Warning(widget.OWWidget.Warning):
missing_cluster_assignment = Msg(
Expand All @@ -101,7 +119,10 @@ def __init__(self):
super().__init__()
#: The input data
self.data = None # type: Optional[Orange.data.Table]
#: Distance matrix computed from data
#: The input distance matrix (if present)
self.distances = None # type: Optional[Orange.misc.DistMatrix]
#: The effective distance matrix (is self.distances or computed from
#: self.data depending on input)
self._matrix = None # type: Optional[Orange.misc.DistMatrix]
#: A bool mask (size == len(data)) indicating missing group/cluster
#: assignments
Expand All @@ -114,10 +135,16 @@ def __init__(self):
self._silhouette = None # type: Optional[np.ndarray]
self._silplot = None # type: Optional[SilhouettePlot]

gui.comboBox(
self.controlArea, self, "distance_idx", box="Distance",
controllayout = self.controlArea.layout()
assert isinstance(controllayout, QVBoxLayout)
self._distances_gui_box = distbox = gui.widgetBox(
None, "Distance"
)
self._distances_gui_cb = gui.comboBox(
distbox, self, "distance_idx",
items=[name for name, _ in OWSilhouettePlot.Distances],
orientation=Qt.Horizontal, callback=self._invalidate_distances)
controllayout.addWidget(distbox)

box = gui.vBox(self.controlArea, "Cluster Label")
self.cluster_var_cb = gui.comboBox(
Expand Down Expand Up @@ -179,58 +206,85 @@ def pack_settings(self):

@Inputs.data
@check_sql_input
def set_data(self, data):
def set_data(self, data: Union[Table, DistMatrix, None]):
"""
Set the input dataset.
Set the input dataset or distance matrix.
"""
self.closeContext()
self.clear()
error_msg = ""
warning_msg = ""
candidatevars = []
if data is not None:
candidatevars = [
v for v in data.domain.variables + data.domain.metas
if v.is_discrete and len(v.values) >= 2]
if not candidatevars:
error_msg = "Input does not have any suitable labels."
data = None

self.data = data
if data is not None:
self.cluster_var_model[:] = candidatevars
if data.domain.class_var in candidatevars:
self.cluster_var_idx = \
candidatevars.index(data.domain.class_var)
try:
if isinstance(data, Orange.misc.DistMatrix):
self._set_distances(data)
elif isinstance(data, Orange.data.Table):
self._set_table(data)
else:
self.cluster_var_idx = 0
self.distances = None
self.data = None
except InputValidationError as err:
self.Error.input_validation_error(err.message)
self.distances = None
self.data = None

def _set_table(self, data: Table):
    """
    Use a data table as the widget's input.

    The control models are set up from ``data.domain`` first; if that
    raises, neither ``self.data`` nor ``self.distances`` is assigned here.
    """
    domain = data.domain
    self._setup_control_models(domain)
    # With a table on input there is no input distance matrix.
    self.distances = None
    self.data = data

annotvars = [var for var in data.domain.metas if var.is_string]
self.annotation_var_model[:] = ["None"] + annotvars
self.annotation_var_idx = 1 if len(annotvars) else 0
self.openContext(Orange.data.Domain(candidatevars))
def _set_distances(self, distances: DistMatrix):
if isinstance(distances.row_items, Orange.data.Table) and \
distances.axis == 1:
data = distances.row_items
else:
raise ValidationError("Input matrix does not have associated data")

self.error(error_msg)
self.warning(warning_msg)
if data is not None:
self._setup_control_models(data.domain)
self.distances = distances
self.data = data

def handleNewSignals(self):
if self.data is not None:
if not self._is_empty():
self._update()
self._replot()
if self.pending_selection is not None and self._silplot is not None:
# If the selection contains indices that are too large, the data
# file must have been modified, so we ignore the selection
if max(self.pending_selection) < len(self.data):
if max(self.pending_selection, default=-1) < len(self.data):
self._silplot.setSelection(np.array(self.pending_selection))
self.pending_selection = None

# Disable/enable the Distances GUI controls if applicable
self._distances_gui_box.setEnabled(self.distances is None)

self.unconditional_commit()

def _setup_control_models(self, domain: Domain):
    """
    Populate the cluster label and annotation models from `domain`.

    Candidate group variables are the discrete variables (among
    ``domain.variables`` and ``domain.metas``) with at least two values.
    The class variable is preselected when it is one of the candidates,
    otherwise the first candidate is. String meta variables are offered as
    annotations after a leading "None" entry, and the first of them (if
    any) is preselected. Finally the settings context is opened on a
    domain built from the candidates.

    Raises
    ------
    NoGroupVariable
        If the domain has no candidate group variable.
    """
    candidates = [
        var for var in domain.variables + domain.metas
        if var.is_discrete and len(var.values) >= 2
    ]
    if not candidates:
        raise NoGroupVariable()

    self.cluster_var_model[:] = candidates
    class_var = domain.class_var
    self.cluster_var_idx = (
        candidates.index(class_var) if class_var in candidates else 0
    )

    string_metas = [meta for meta in domain.metas if meta.is_string]
    self.annotation_var_model[:] = ["None", *string_metas]
    # Index 0 is the "None" entry; index 1 is the first string meta.
    self.annotation_var_idx = 1 if string_metas else 0

    self.openContext(Orange.data.Domain(candidates))

def _is_empty(self) -> bool:
# Is empty (does not have any input).
return (self.data is None or len(self.data) == 0) \
and self.distances is None

def clear(self):
"""
Clear the widget state.
"""
self.data = None
self.distances = None
self._matrix = None
self._mask = None
self._silhouette = None
Expand Down Expand Up @@ -263,44 +317,42 @@ def _invalidate_scores(self):
if self.data is not None:
self.commit()

def _ensure_matrix(self):
# ensure self._matrix is computed if necessary
if self._is_empty():
return
if self._matrix is None:
if self.distances is not None:
self._matrix = np.asarray(self.distances)
elif self.data is not None:
data = self.data
_, metric = self.Distances[self.distance_idx]
if not metric.supports_discrete and any(
a.is_discrete for a in data.domain.attributes):
self.Warning.ignoring_categorical()
data = Orange.distance.remove_discrete_features(data)
try:
self._matrix = np.asarray(metric(data))
except MemoryError:
self.Error.memory_error()
return
except ValueError as err:
self.Error.value_error(str(err))
return
else:
assert False, "invalid state"

def _update(self):
# Update/recompute the distances/scores as required
# Update/recompute the effective distances and scores as required.
self._clear_messages()

if self.data is None or not len(self.data):
if self._is_empty():
self._reset_all()
return

if self._matrix is None and self.data is not None:
_, metric = self.Distances[self.distance_idx]
data = self.data
if not metric.supports_discrete and any(
a.is_discrete for a in data.domain.attributes):
self.Warning.ignoring_categorical()
data = Orange.distance.remove_discrete_features(data)
try:
self._matrix = np.asarray(metric(data))
except MemoryError:
self.Error.memory_error()
return
except ValueError as err:
self.Error.value_error(str(err))
return

self._update_labels()

def _reset_all(self):
self._mask = None
self._silhouette = None
self._labels = None
self._matrix = None
self._clear_scene()

def _clear_messages(self):
self.Error.clear()
self.Warning.clear()
self._ensure_matrix()
if self._matrix is None:
return

def _update_labels(self):
labelvar = self.cluster_var_model[self.cluster_var_idx]
labels, _ = self.data.get_column_view(labelvar)
labels = np.asarray(labels, dtype=float)
Expand All @@ -310,7 +362,7 @@ def _update_labels(self):
labels = labels.astype(int)
labels = labels[~mask]

labels_unq, _ = np.unique(labels, return_counts=True)
labels_unq = np.unique(labels)

if len(labels_unq) < 2:
self.Error.need_two_clusters()
Expand All @@ -335,6 +387,17 @@ def _update_labels(self):
self.Warning.nan_distances(
count_nandist, s="s" if count_nandist > 1 else "")

def _reset_all(self):
self._mask = None
self._silhouette = None
self._labels = None
self._matrix = None
self._clear_scene()

def _clear_messages(self):
self.Error.clear()
self.Warning.clear()

def _set_bar_height(self):
visible = self.bar_size >= 5
self._silplot.setBarHeight(self.bar_size)
Expand Down
43 changes: 43 additions & 0 deletions Orange/widgets/visualize/tests/test_owsilhouetteplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import numpy as np

import Orange.distance
from Orange.data import (
Table, Domain, ContinuousVariable, DiscreteVariable, StringVariable)
from Orange.misc import DistMatrix
from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_SIGNAL_NAME
from Orange.widgets.visualize.owsilhouetteplot import OWSilhouettePlot
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin
Expand Down Expand Up @@ -174,6 +176,47 @@ def test_saved_selection(self):
self.send_signal(w.Inputs.data, iris, widget=w)
self.assertEqual(len(self.get_output(w.Outputs.selected_data)), 20)

def test_distance_input(self):
    """
    A distance matrix on input sets both distances and data and disables
    the distance controls; a table on input re-enables them.
    """
    w = self.widget
    table = Table("heart_disease")[::4]
    distances = Orange.distance.Euclidean(table)

    # Distance matrix on input.
    self.send_signal(w.Inputs.data, distances, widget=w)
    self.assertIsNotNone(w.distances)
    self.assertIsNotNone(w.data)
    self.assertFalse(w._distances_gui_box.isEnabled())

    # Plain table on input.
    self.send_signal(w.Inputs.data, table, widget=w)
    self.assertIsNone(w.distances)
    self.assertIsNotNone(w.data)
    self.assertTrue(w._distances_gui_box.isEnabled())

def test_input_distance_no_data(self):
    """
    A distance matrix without row items shows the input validation error,
    which is no longer shown once the input is removed.
    """
    w = self.widget
    values = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
    distances = DistMatrix(values, row_items=None)

    self.send_signal(w.Inputs.data, distances, widget=w)
    self.assertTrue(w.Error.input_validation_error.is_shown())
    self.assertIsNone(w.data)
    self.assertIsNone(w.distances)

    self.send_signal(w.Inputs.data, None, widget=w)
    self.assertFalse(w.Error.input_validation_error.is_shown())

def test_no_group_var(self):
    """
    Distances over data restricted to its attributes show the input
    validation error, which is no longer shown once the input is removed.
    """
    w = self.widget
    table = Table("iris")[::4]
    # Keep only the attribute columns.
    attrs_only = table[:, table.domain.attributes]
    distances = Orange.distance.Euclidean(attrs_only)

    self.send_signal(w.Inputs.data, distances, widget=w)
    self.assertTrue(w.Error.input_validation_error.is_shown())
    self.assertIsNone(w.data)
    self.assertIsNone(w.distances)

    self.send_signal(w.Inputs.data, None, widget=w)
    self.assertFalse(w.Error.input_validation_error.is_shown())


if __name__ == "__main__":
unittest.main()

0 comments on commit 667af57

Please sign in to comment.