Skip to content

Commit

Permalink
Merge pull request #1974 from janezd/fix-kmeans-fail
Browse files Browse the repository at this point in the history
[FIX] KMeans: Fix crashes when underlying algorithm fails
  • Loading branch information
astaric authored Feb 3, 2017
2 parents af4cbe2 + 6c0a8d1 commit c80741b
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 44 deletions.
117 changes: 75 additions & 42 deletions Orange/widgets/unsupervised/owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
from AnyQt.QtWidgets import QGridLayout, QSizePolicy, QTableView
from AnyQt.QtGui import QStandardItemModel, QStandardItem, QIntValidator
from AnyQt.QtGui import QStandardItemModel, QStandardItem, QIntValidator, QBrush
from AnyQt.QtCore import Qt, QTimer

from Orange.clustering import KMeans
Expand All @@ -24,15 +24,18 @@ class OWKMeans(widget.OWWidget):
outputs = [("Annotated Data", Table, widget.Default),
("Centroids", Table)]

class Error(widget.OWWidget.Error):
failed = widget.Msg("Clustering failed\nError: {}")

INIT_KMEANS, INIT_RANDOM = range(2)
INIT_METHODS = "Initialize with KMeans++", "Random initialization"

SILHOUETTE, INTERCLUSTER, DISTANCES = range(3)
SCORING_METHODS = [("Silhouette", lambda km: km.silhouette, False),
SCORING_METHODS = [("Silhouette", lambda km: km.silhouette, False, True),
("Inter-cluster distance",
lambda km: km.inter_cluster, True),
lambda km: km.inter_cluster, True, False),
("Distance to centroids",
lambda km: km.inertia, True)]
lambda km: km.inertia, True, False)]

OUTPUT_CLASS, OUTPUT_ATTRIBUTE, OUTPUT_META = range(3)
OUTPUT_METHODS = ("Class", "Feature", "Meta")
Expand Down Expand Up @@ -150,7 +153,8 @@ def __init__(self):
table.setSelectionMode(QTableView.SingleSelection)
table.setSelectionBehavior(QTableView.SelectRows)
table.verticalHeader().hide()
table.setItemDelegateForColumn(1, gui.TableBarItem(self))
self.bar_delegate = gui.ColoredBarItemDelegate(self, color=Qt.cyan)
table.setItemDelegateForColumn(1, self.bar_delegate)
table.setModel(self.table_model)
table.selectionModel().selectionChanged.connect(
self.table_item_selected)
Expand Down Expand Up @@ -219,6 +223,7 @@ def run_optimization(self):
try:
self.controlArea.setDisabled(True)
self.optimization_runs = []
error = ""
if not self.check_data_size(self.k_from, self.Error):
return
self.check_data_size(self.k_to, self.Warning)
Expand All @@ -231,7 +236,15 @@ def run_optimization(self):
for k in range(self.k_from, k_to + 1):
progress.advance()
kmeans.params["n_clusters"] = k
self.optimization_runs.append((k, kmeans(self.data)))
try:
self.optimization_runs.append((k, kmeans(self.data)))
except BaseException as exc:
error = str(exc)
self.optimization_runs.append((k, error))
if all(isinstance(score, str)
for _, score in self.optimization_runs):
self.Error.failed(error) # Report just the last error
self.optimization_runs = []
finally:
self.controlArea.setDisabled(False)
self.show_results()
Expand All @@ -240,11 +253,15 @@ def run_optimization(self):
def cluster(self):
if not self.check_data_size(self.k, self.Error):
return
self.km = KMeans(
n_clusters=self.k,
init=['random', 'k-means++'][self.smart_init],
n_init=self.n_init,
max_iter=self.max_iterations)(self.data)
try:
self.km = KMeans(
n_clusters=self.k,
init=['random', 'k-means++'][self.smart_init],
n_init=self.n_init,
max_iter=self.max_iterations)(self.data)
except BaseException as exc:
self.Error.failed(str(exc))
self.km = None
self.send_data()

def run(self):
Expand All @@ -260,40 +277,55 @@ def commit(self):
self.run()

def show_results(self):
minimize = self.SCORING_METHODS[self.scoring][2]
k_scores = [(k, self.SCORING_METHODS[self.scoring][1](run)) for
k, run in self.optimization_runs]
scores = list(zip(*k_scores))[1]
if minimize:
best_score, worst_score = min(scores), max(scores)
_, scoring_method, minimize, normal = self.SCORING_METHODS[self.scoring]
k_scores = [(k,
scoring_method(run) if not isinstance(run, str) else run)
for k, run in self.optimization_runs]
scores = [score for _, score in k_scores if not isinstance(score, str)]

min_score, max_score = min(scores, default=0), max(scores, default=1)
best_score = min_score if minimize else max_score
if normal:
min_score, max_score = 0, 1
nplaces = 3
else:
best_score, worst_score = max(scores), min(scores)
nplaces = min(5, np.floor(abs(math.log(max(max_score, 1e-10)))) + 2)
score_span = (max_score - min_score) or 1
self.bar_delegate.scale = (min_score, max_score)
self.bar_delegate.float_fmt = "%%.%if" % int(nplaces)

best_run = scores.index(best_score)
score_span = (best_score - worst_score) or 1
max_score = max(scores)
nplaces = min(5, np.floor(abs(math.log(max(max_score, 1e-10)))) + 2)
fmt = "{{:.{}f}}".format(int(nplaces))
model = self.table_model
model.setRowCount(len(k_scores))
no_selection = True
for i, (k, score) in enumerate(k_scores):
item = model.item(i, 0)
if item is None:
item = QStandardItem()
item.setData(k, Qt.DisplayRole)
item.setTextAlignment(Qt.AlignCenter)
model.setItem(i, 0, item)
item = model.item(i, 1)
if item is None:
item = QStandardItem()
item.setData(fmt.format(score) if not np.isnan(score) else 'out-of-memory error',
Qt.DisplayRole)
bar_ratio = 0.95 * (score - worst_score) / score_span
item.setData(bar_ratio, gui.TableBarItem.BarRole)
item0 = model.item(i, 0) or QStandardItem()
item0.setData(k, Qt.DisplayRole)
item0.setTextAlignment(Qt.AlignCenter)
model.setItem(i, 0, item0)
item = model.item(i, 1) or QStandardItem()
if not isinstance(score, str):
item.setData(score, Qt.DisplayRole)
item.setData(None, Qt.ToolTipRole)
bar_ratio = 0.95 * (score - min_score) / score_span
item.setData(bar_ratio, gui.BarRatioRole)
if no_selection and score == best_score:
self.table_view.selectRow(i)
no_selection = False
color = Qt.black
flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
else:
item.setData("clustering failed", Qt.DisplayRole)
item.setData(score, Qt.ToolTipRole)
item.setData(None, gui.BarRatioRole)
color = Qt.gray
flags = Qt.NoItemFlags
item0.setData(QBrush(color), Qt.ForegroundRole)
item0.setFlags(flags)
item.setData(QBrush(color), Qt.ForegroundRole)
item.setFlags(flags)
model.setItem(i, 1, item)
self.table_view.resizeRowsToContents()

self.table_view.selectRow(best_run)
self.table_view.show()
if minimize:
self.table_box.setTitle("Scoring (smaller is better)")
Expand All @@ -314,13 +346,12 @@ def selected_row(self):
def table_item_selected(self):
row = self.selected_row()
if row is not None:
self.send_data(row)
self.send_data()

def send_data(self, row=None):
def send_data(self):
if self.optimize_k:
if row is None:
row = self.selected_row()
km = self.optimization_runs[row][1]
row = self.selected_row() if self.optimization_runs else None
km = self.optimization_runs[row][1] if row is not None else None
else:
km = self.km
if not self.data or not km:
Expand Down Expand Up @@ -356,6 +387,8 @@ def send_data(self, row=None):
def set_data(self, data):
self.data = data
if data is None:
self.Error.clear()
self.Warning.clear()
self.table_model.setRowCount(0)
self.send("Annotated Data", None)
self.send("Centroids", None)
Expand Down
63 changes: 61 additions & 2 deletions Orange/widgets/unsupervised/tests/test_owkmeans.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from unittest.mock import patch

from AnyQt.QtWidgets import QRadioButton

from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.unsupervised.owkmeans import OWKMeans
import Orange.clustering

from Orange.data import Table
class TestOWKMeans(WidgetTest):

def setUp(self):
self.widget = self.create_widget(OWKMeans,
stored_settings={"auto_apply": False})
self.widget = self.create_widget(
OWKMeans, stored_settings={"auto_apply": False}) # type: OWKMeans
self.iris = Table("iris")

def test_optimization_report_display(self):
Expand All @@ -32,3 +35,59 @@ def test_data_on_output(self):
self.send_signal("Data", None)
# removing data should have cleared the output
self.assertEqual(self.widget.data, None)

class KMeansFail(Orange.clustering.KMeans):
fail_on = set()

def fit(self, *args):
# when not optimizing, params is empty?!
k = self.params.get("n_clusters", 3)
if k in self.fail_on:
raise ValueError("k={} fails".format(k))
return super().fit(*args)

@patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail)
def test_optimization_fails(self):
widget = self.widget
widget.k_from = 3
widget.k_to = 8
widget.scoring = 0
widget.optimize_k = True

self.KMeansFail.fail_on = {3, 5, 7}
self.send_signal("Data", self.iris)
self.assertIsInstance(widget.optimization_runs[0][1], str)
self.assertIsInstance(widget.optimization_runs[2][1], str)
self.assertIsInstance(widget.optimization_runs[4][1], str)
self.assertNotIsInstance(widget.optimization_runs[1][1], str)
self.assertNotIsInstance(widget.optimization_runs[3][1], str)
self.assertNotIsInstance(widget.optimization_runs[5][1], str)
self.assertFalse(widget.Error.failed.is_shown())
self.assertEqual(widget.selected_row(), 1)
self.assertIsNotNone(self.get_output("Annotated Data"))

self.KMeansFail.fail_on = set(range(3, 9))
widget.run()
self.assertTrue(widget.Error.failed.is_shown())
self.assertEqual(widget.optimization_runs, [])
self.assertIsNone(self.get_output("Annotated Data"))

self.KMeansFail.fail_on = set()
widget.run()
self.assertFalse(widget.Error.failed.is_shown())
self.assertEqual(widget.selected_row(), 0)
self.assertIsNotNone(self.get_output("Annotated Data"))

@patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail)
def test_run_fails(self):
self.widget.k = 3
self.widget.optimize_k = False
self.KMeansFail.fail_on = {3}
self.send_signal("Data", self.iris)
self.assertTrue(self.widget.Error.failed.is_shown())
self.assertIsNone(self.get_output("Annotated Data"))

self.KMeansFail.fail_on = set()
self.widget.run()
self.assertFalse(self.widget.Error.failed.is_shown())
self.assertIsNotNone(self.get_output("Annotated Data"))

0 comments on commit c80741b

Please sign in to comment.