From d914df78d7e6d620252714a40f6f5c194cea4511 Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Thu, 28 Feb 2019 12:25:46 +0100 Subject: [PATCH 1/3] FreeViz: Offload work to a separate thread --- Orange/projection/base.py | 9 + Orange/widgets/tests/base.py | 8 +- Orange/widgets/utils/concurrent.py | 215 ++++++++++++- Orange/widgets/visualize/owfreeviz.py | 300 +++++++----------- .../widgets/visualize/tests/test_owfreeviz.py | 77 ++++- 5 files changed, 404 insertions(+), 205 deletions(-) diff --git a/Orange/projection/base.py b/Orange/projection/base.py index 58211b1ac6b..28fea302eb0 100644 --- a/Orange/projection/base.py +++ b/Orange/projection/base.py @@ -1,3 +1,4 @@ +import copy import inspect import threading @@ -135,6 +136,7 @@ def proj_variable(i, name): super().__init__(proj=proj) self.orig_domain = domain + self.n_components = n_components var_names = self._get_var_names(n_components) self.domain = Orange.data.Domain( [proj_variable(i, var_names[i]) for i in range(n_components)], @@ -145,6 +147,13 @@ def _get_var_names(self, n): names = [f"{self.var_prefix}-{postfix}" for postfix in postfixes] return get_unique_names(self.orig_domain, names) + def copy(self): + proj = copy.deepcopy(self.proj) + model = type(self)(proj, self.domain.copy(), self.n_components) + model.pre_domain = self.pre_domain.copy() + model.name = self.name + return model + class LinearProjector(Projector): name = "Linear Projection" diff --git a/Orange/widgets/tests/base.py b/Orange/widgets/tests/base.py index 8d5ddc34bec..7ec3020df10 100644 --- a/Orange/widgets/tests/base.py +++ b/Orange/widgets/tests/base.py @@ -1013,12 +1013,18 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT): self.widget.graph.update_coordinates.assert_not_called() self.widget.graph.update_point_props.assert_called_once() - def test_saved_selection(self): + def test_saved_selection(self, timeout=DEFAULT_TIMEOUT): self.send_signal(self.widget.Inputs.data, self.data) + if self.widget.isBlocking(): + spy = QSignalSpy(self.widget.blockingStateChanged) + self.assertTrue(spy.wait(timeout)) self.widget.graph.select_by_indices(list(range(0, len(self.data), 10))) settings = self.widget.settingsHandler.pack_data(self.widget) w = self.create_widget(self.widget.__class__, stored_settings=settings) self.send_signal(self.widget.Inputs.data, self.data, widget=w) + if w.isBlocking(): + spy = QSignalSpy(w.blockingStateChanged) + self.assertTrue(spy.wait(timeout)) self.assertEqual(np.sum(w.graph.selection), 15) np.testing.assert_equal(self.widget.graph.selection, w.graph.selection) diff --git a/Orange/widgets/utils/concurrent.py b/Orange/widgets/utils/concurrent.py index d0c54ce10cb..38a2a807519 100644 --- a/Orange/widgets/utils/concurrent.py +++ b/Orange/widgets/utils/concurrent.py @@ -3,6 +3,7 @@ """ # TODO: Rename the module to something that does not conflict with stdlib # concurrent +from typing import Callable, Any import os import threading import atexit @@ -11,18 +12,15 @@ import weakref from functools import partial import concurrent.futures - from concurrent.futures import Future, TimeoutError from contextlib import contextmanager -from typing import Callable, Optional, Any, List from AnyQt.QtCore import ( Qt, QObject, QMetaObject, QThreadPool, QThread, QRunnable, QSemaphore, - QEventLoop, QCoreApplication, QEvent, Q_ARG + QEventLoop, QCoreApplication, QEvent, Q_ARG, + pyqtSignal as Signal, pyqtSlot as Slot ) -from AnyQt.QtCore import pyqtSignal as Signal, pyqtSlot as Slot - _log = logging.getLogger(__name__) @@ -784,3 +782,210 @@ def __call__(self, *args): args = [Q_ARG(atype, arg) for atype, arg in zip(self.arg_types, args)] return QMetaObject.invokeMethod( self.obj, self.method, self.conntype, *args) + + +class TaskState(QObject): + + status_changed = Signal(str) + _p_status_changed = Signal(str) + + progress_changed = Signal(float) + _p_progress_changed = Signal(float) + + partial_result_ready = Signal(object) + _p_partial_result_ready = Signal(object) + + def __init__(self, *args): + super().__init__(*args) + self.__future = None + self.watcher = FutureWatcher() + self.__interruption_requested = False + self.__progress = 0 + # Helpers to route the signal emits via a this object's queue. + # This ensures 'atomic' disconnect from signals for targets/slots + # in the same thread. Requires that the event loop is running in this + # object's thread. + self._p_status_changed.connect( + self.status_changed, Qt.QueuedConnection) + self._p_progress_changed.connect( + self.progress_changed, Qt.QueuedConnection) + self._p_partial_result_ready.connect( + self.partial_result_ready, Qt.QueuedConnection) + + @property + def future(self) -> Future: + return self.__future + + def set_status(self, text: str): + self._p_status_changed.emit(text) + + def set_progress_value(self, value: float): + if round(value, 1) > round(self.__progress, 1): + # Only emit progress when it has changed sufficiently + self._p_progress_changed.emit(value) + self.__progress = value + + def set_partial_result(self, value: Any): + self._p_partial_result_ready.emit(value) + + def is_interruption_requested(self) -> bool: + return self.__interruption_requested + + def start(self, executor: concurrent.futures.Executor, + func: Callable[[], Any] = None) -> Future: + assert self.future is None + assert not self.__interruption_requested + self.__future = executor.submit(func) + self.watcher.setFuture(self.future) + return self.future + + def cancel(self) -> bool: + assert not self.__interruption_requested + self.__interruption_requested = True + if self.future is not None: + rval = self.future.cancel() + else: + # not even scheduled yet + rval = True + return rval + + +class ConcurrentMixin: + """ + A base class for concurrent mixins. The class provides methods for running + tasks in a separate thread. + + Widgets should use `ConcurrentWidgetMixin` rather than this class. + """ + def __init__(self): + self.__executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self.__task = None # type: Optional[TaskState] + + @property + def task(self) -> TaskState: + return self.__task + + def on_partial_result(self, result: Any) -> None: + """ Invoked from runner (by state) to send the partial results + The method should handle partial results, i.e. show them in the plot. + + :param result: any data structure to hold final result + """ + raise NotImplementedError + + def on_done(self, result: Any) -> None: + """ Invoked when task is done. + The method should re-set the result (to double check it) and + perform operations with obtained results, eg. send data to the output. + + :param result: any data structure to hold temporary result + """ + raise NotImplementedError + + def on_exception(self, ex: Exception): + """ Invoked when an exception occurs during the calculation. + Override in order to handle exceptions, eg. show an error + message in the widget. + + :param ex: exception + """ + raise ex + + def start(self, task: Callable, *args, **kwargs): + """ Call from derived class to start the task. + :param task: runner - a method to run in a thread - should accept + `state` parameter + """ + self.__cancel_task(wait=False) + assert callable(task), "`task` must be callable!" + state = TaskState(self) + task = partial(task, *(args + (state,)), **kwargs) + self.__start_task(task, state) + + def cancel(self): + """ Call from derived class to stop the task. """ + self.__cancel_task(wait=False) + + def shutdown(self): + """ Call from derived class when the widget is deleted + (in onDeleteWidget). + """ + self.__cancel_task(wait=True) + self.__executor.shutdown(True) + + def __start_task(self, task: Callable[[], Any], state: TaskState): + assert self.__task is None + self._connect_signals(state) + state.start(self.__executor, task) + state.setParent(self) + self.__task = state + + def __cancel_task(self, wait: bool = True): + if self.__task is not None: + state, self.__task = self.__task, None + state.cancel() + self._disconnect_signals(state) + if wait: + concurrent.futures.wait([state.future]) + state.deleteLater() + else: + w = FutureWatcher(state.future, parent=state) + w.done.connect(state.deleteLater) + + def _connect_signals(self, state: TaskState): + state.partial_result_ready.connect(self.on_partial_result) + state.watcher.done.connect(self._on_task_done) + + def _disconnect_signals(self, state: TaskState): + state.partial_result_ready.disconnect(self.on_partial_result) + state.watcher.done.disconnect(self._on_task_done) + + def _on_task_done(self, future: Future): + assert future.done() + assert self.__task is not None + assert self.__task.future is future + assert self.__task.watcher.future() is future + self.__task, task = None, self.__task + task.deleteLater() + ex = future.exception() + if ex is not None: + self.on_exception(ex) + else: + self.on_done(future.result()) + + +class ConcurrentWidgetMixin(ConcurrentMixin): + """ + A concurrent mixin to be used along with OWWidget. + """ + def __set_state_ready(self): + self.progressBarFinished() + self.setBlocking(False) + self.setStatusMessage("") + + def __set_state_busy(self): + self.progressBarInit() + self.setBlocking(True) + + def start(self, task: Callable, *args, **kwargs): + self.__set_state_ready() + super().start(task, *args, **kwargs) + self.__set_state_busy() + + def cancel(self): + super().cancel() + self.__set_state_ready() + + def _connect_signals(self, state: TaskState): + super()._connect_signals(state) + state.status_changed.connect(self.setStatusMessage) + state.progress_changed.connect(self.progressBarSet) + + def _disconnect_signals(self, state: TaskState): + super()._disconnect_signals(state) + state.status_changed.disconnect(self.setStatusMessage) + state.progress_changed.disconnect(self.progressBarSet) + + def _on_task_done(self, future: Future): + super()._on_task_done(future) + self.__set_state_ready() diff --git a/Orange/widgets/visualize/owfreeviz.py b/Orange/widgets/visualize/owfreeviz.py index 729674ea5ce..749327abcf7 100644 --- a/Orange/widgets/visualize/owfreeviz.py +++ b/Orange/widgets/visualize/owfreeviz.py @@ -1,12 +1,10 @@ +# pylint: disable=too-many-ancestors from enum import IntEnum -import sys +from types import SimpleNamespace as namespace import numpy as np -from AnyQt.QtCore import ( - Qt, QObject, QEvent, QRectF, QLineF, QTimer, QPoint, - pyqtSignal as Signal, pyqtSlot as Slot -) +from AnyQt.QtCore import Qt, QRectF, QLineF, QPoint from AnyQt.QtGui import QColor import pyqtgraph as pg @@ -15,128 +13,45 @@ from Orange.projection import FreeViz from Orange.projection.freeviz import FreeVizModel from Orange.widgets import widget, gui, settings +from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.visualize.utils.component import OWGraphWithAnchors from Orange.widgets.visualize.utils.plotutils import AnchorItem from Orange.widgets.visualize.utils.widget import OWAnchorProjectionWidget -class AsyncUpdateLoop(QObject): - """ - Run/drive an coroutine from the event loop. - - This is a utility class which can be used for implementing - asynchronous update loops. I.e. coroutines which periodically yield - control back to the Qt event loop. - - """ - Next = QEvent.registerEventType() - - #: State flags - Idle, Running, Cancelled, Finished = 0, 1, 2, 3 - #: The coroutine has yielded control to the caller (with `object`) - yielded = Signal(object) - #: The coroutine has finished/exited (either with an exception - #: or with a return statement) - finished = Signal() - - #: The coroutine has returned (normal return statement / StopIteration) - returned = Signal(object) - #: The coroutine has exited with with an exception. - raised = Signal(object) - #: The coroutine was cancelled/closed. - cancelled = Signal() - - def __init__(self, parent=None, **kwargs): - super().__init__(parent, **kwargs) - self.__coroutine = None - self.__next_pending = False # Flag for compressing scheduled events - self.__in_next = False - self.__state = AsyncUpdateLoop.Idle - - @Slot(object) - def setCoroutine(self, loop): - """ - Set the coroutine. - - The coroutine will be resumed (repeatedly) from the event queue. - If there is an existing coroutine set it is first closed/cancelled. - - Raises an RuntimeError if the current coroutine is running. - """ - if self.__coroutine is not None: - self.__coroutine.close() - self.__coroutine = None - self.__state = AsyncUpdateLoop.Cancelled - - self.cancelled.emit() - self.finished.emit() - - if loop is not None: - self.__coroutine = loop - self.__state = AsyncUpdateLoop.Running - self.__schedule_next() - - @Slot() - def cancel(self): - """ - Cancel/close the current coroutine. - - Raises an RuntimeError if the current coroutine is running. - """ - self.setCoroutine(None) - - def state(self): - """ - Return the current state. - """ - return self.__state - - def isRunning(self): - return self.__state == AsyncUpdateLoop.Running - - def __schedule_next(self): - if not self.__next_pending: - self.__next_pending = True - QTimer.singleShot(10, self.__on_timeout) - - def __next(self): - if self.__coroutine is not None: - try: - rval = next(self.__coroutine) - except StopIteration as stop: - self.__state = AsyncUpdateLoop.Finished - self.returned.emit(stop.value) - self.finished.emit() - self.__coroutine = None - except BaseException as er: - self.__state = AsyncUpdateLoop.Finished - self.raised.emit(er) - self.finished.emit() - self.__coroutine = None - else: - self.yielded.emit(rval) - self.__schedule_next() - - @Slot() - def __on_timeout(self): - assert self.__next_pending - self.__next_pending = False - if not self.__in_next: - self.__in_next = True - try: - self.__next() - finally: - self.__in_next = False - else: - # warn - self.__schedule_next() +class Result(namespace): + projector = None # type: FreeViz + projection = None # type: FreeVizModel - def customEvent(self, event): - if event.type() == AsyncUpdateLoop.Next: - self.__on_timeout() - else: - super().customEvent(event) + +MAX_ITERATIONS = 1000 + + +def run_freeviz(data: Table, projector: FreeViz, state: TaskState): + res = Result(projector=projector, projection=None) + step, steps = 0, MAX_ITERATIONS + initial = res.projector.components_.T + state.set_status("Calculating...") + while True: + # Needs a copy because projection should not be modified inplace. + # If it is modified inplace, the widget and the thread hold a + # reference to the same object. When the thread is interrupted it + # is still modifying the object, but the widget receives it + # (the modified object) with a delay. + res.projection = res.projector(data).copy() + anchors = res.projector.components_.T + res.projector.initial = anchors + + state.set_partial_result(res) + if np.allclose(initial, anchors, rtol=1e-5, atol=1e-4): + return res + initial = anchors + + step += 1 + state.set_progress_value(100 * step / steps) + if state.is_interruption_requested(): + return res class OWFreeVizGraph(OWGraphWithAnchors): @@ -207,8 +122,7 @@ def items(): return ["Circular", "Random"] -class OWFreeViz(OWAnchorProjectionWidget): - MAX_ITERATIONS = 1000 +class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin): MAX_INSTANCES = 10000 name = "FreeViz" @@ -237,11 +151,8 @@ class Warning(OWAnchorProjectionWidget.Warning): " two values are not shown.") def __init__(self): - super().__init__() - self._loop = AsyncUpdateLoop(parent=self) - self._loop.yielded.connect(self.__set_projection) - self._loop.finished.connect(self.__freeviz_finished) - self._loop.raised.connect(self.__on_error) + OWAnchorProjectionWidget.__init__(self) + ConcurrentWidgetMixin.__init__(self) def _add_controls(self): self.__add_controls_start_box() @@ -258,8 +169,7 @@ def __add_controls_start_box(self): box, self, "initialization", label="Initialization:", items=InitType.items(), orientation=Qt.Horizontal, labelWidth=90, callback=self.__init_combo_changed) - self.btn_start = gui.button( - box, self, "Optimize", self.__toggle_start, enabled=False) + self.run_button = gui.button(box, self, "Start", self._toggle_run) @property def effective_variables(self): @@ -269,59 +179,73 @@ def effective_variables(self): def __radius_slider_changed(self): self.graph.update_radius() - def __toggle_start(self): - if self._loop.isRunning(): - self._loop.cancel() - self.btn_start.setText("Optimize") - self.progressBarFinished(processEvents=False) - else: - self._start() - def __init_combo_changed(self): - if self.data is None: - return - running = self._loop.isRunning() - if running: - self._loop.cancel() + self.Error.proj_error.clear() self.init_projection() - self.graph.update_coordinates() + self.setup_plot() self.commit() - if running: - self._start() - - def _start(self): - def update_freeviz(anchors): - while True: - self.projection = self.projector(self.effective_data) - _anchors = self.projector.components_.T - self.projector.initial = _anchors - yield _anchors - if np.allclose(anchors, _anchors, rtol=1e-5, atol=1e-4): - return - anchors = _anchors + if self.task is not None: + self._run() + + def _toggle_run(self): + if self.task is not None: + self.cancel() + self.graph.set_sample_size(None) + self.run_button.setText("Resume") + self.commit() + else: + self._run() + def _run(self): + if self.data is None: + return self.graph.set_sample_size(self.SAMPLE_SIZE) - self._loop.setCoroutine(update_freeviz(self.projector.components_.T)) - self.btn_start.setText("Stop") - self.progressBarInit() - self.setBlocking(True) - self.setStatusMessage("Optimizing") - - def __set_projection(self, _): - # Set/update the projection matrix and coordinate embeddings - self.progressBarAdvance(100. / self.MAX_ITERATIONS) + self.run_button.setText("Stop") + self.start(run_freeviz, self.effective_data, self.projector) + + # ConcurrentWidgetMixin + def on_partial_result(self, result: Result): + assert isinstance(result.projector, FreeViz) + assert isinstance(result.projection, FreeVizModel) + self.projector = result.projector + self.projection = result.projection self.graph.update_coordinates() + self.graph.update_density() - def __freeviz_finished(self): + def on_done(self, result: Result): + assert isinstance(result.projector, FreeViz) + assert isinstance(result.projection, FreeVizModel) + self.projector = result.projector + self.projection = result.projection self.graph.set_sample_size(None) - self.btn_start.setText("Optimize") - self.setStatusMessage("") - self.setBlocking(False) - self.progressBarFinished() + self.run_button.setText("Start") self.commit() - def __on_error(self, err): - sys.excepthook(type(err), err, getattr(err, "__traceback__")) + def on_exception(self, ex: Exception): + self.Error.proj_error(ex) + self.graph.set_sample_size(None) + self.run_button.setText("Start") + + # OWAnchorProjectionWidget + def set_data(self, data): + super().set_data(data) + if self._invalidated: + self.init_projection() + + def init_projection(self): + if self.data is None: + return + anchors = FreeViz.init_radial(len(self.effective_variables)) \ + if self.initialization == InitType.Circular \ + else FreeViz.init_random(len(self.effective_variables), 2) + self.projector = FreeViz(scale=False, center=False, + initial=anchors, maxiter=10) + data = self.projector.preprocess(self.effective_data) + self.projector.domain = data.domain + self.projector.components_ = anchors.T + self.projection = FreeVizModel(self.projector, self.projector.domain, 2) + self.projection.pre_domain = data.domain + self.projection.name = self.projector.name def check_data(self): def error(err): @@ -346,25 +270,11 @@ def error(err): else: if len(self.effective_variables) < len(domain.attributes): self.Warning.removed_features() - self.btn_start.setEnabled(self.data is not None) - def set_data(self, data): - super().set_data(data) - if self.data is not None: - self.init_projection() - - def init_projection(self): - anchors = FreeViz.init_radial(len(self.effective_variables)) \ - if self.initialization == InitType.Circular \ - else FreeViz.init_random(len(self.effective_variables), 2) - self.projector = FreeViz(scale=False, center=False, - initial=anchors, maxiter=10) - data = self.projector.preprocess(self.effective_data) - self.projector.domain = data.domain - self.projector.components_ = anchors.T - self.projection = FreeVizModel(self.projector, self.projector.domain, 2) - self.projection.pre_domain = data.domain - self.projection.name = self.projector.name + def enable_controls(self): + super().enable_controls() + self.run_button.setEnabled(self.data is not None) + self.run_button.setText("Start") def get_coordinates_data(self): embedding = self.get_embedding() @@ -379,7 +289,11 @@ def _manual_move(self, anchor_idx, x, y): def clear(self): super().clear() - self._loop.cancel() + self.cancel() + + def onDeleteWidget(self): + self.shutdown() + super().onDeleteWidget() @classmethod def migrate_settings(cls, _settings, version): @@ -416,5 +330,5 @@ def boundingRect(self): if __name__ == "__main__": # pragma: no cover - data = Table("zoo") - WidgetPreview(OWFreeViz).run(set_data=data, set_subset_data=data[::10]) + table = Table("zoo") + WidgetPreview(OWFreeViz).run(set_data=table, set_subset_data=table[::10]) diff --git a/Orange/widgets/visualize/tests/test_owfreeviz.py b/Orange/widgets/visualize/tests/test_owfreeviz.py index b9a90a0ff6f..6cde4e7bef7 100644 --- a/Orange/widgets/visualize/tests/test_owfreeviz.py +++ b/Orange/widgets/visualize/tests/test_owfreeviz.py @@ -1,15 +1,19 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring import warnings +import unittest +from unittest.mock import Mock import numpy as np from Orange.data import Table +from Orange.projection import FreeViz +from Orange.projection.freeviz import FreeVizModel from Orange.widgets.tests.base import ( WidgetTest, WidgetOutputsTestMixin, AnchorProjectionWidgetTestMixin ) from Orange.widgets.tests.utils import simulate -from Orange.widgets.visualize.owfreeviz import OWFreeViz +from Orange.widgets.visualize.owfreeviz import OWFreeViz, Result, run_freeviz class TestOWFreeViz(WidgetTest, AnchorProjectionWidgetTestMixin, @@ -49,22 +53,40 @@ def test_error_msg(self): def test_optimization(self): self.send_signal(self.widget.Inputs.data, self.heart_disease) - self.widget.btn_start.click() + self.widget.run_button.click() + self.assertEqual(self.widget.run_button.text(), "Stop") def test_optimization_cancelled(self): self.test_optimization() - self.widget.btn_start.click() + self.widget.run_button.click() + self.assertEqual(self.widget.run_button.text(), "Resume") def test_optimization_reset(self): - self.send_signal(self.widget.Inputs.data, self.data) + self.test_optimization() init = self.widget.controls.initialization simulate.combobox_activate_index(init, 0) + self.assertEqual(self.widget.run_button.text(), "Stop") simulate.combobox_activate_index(init, 1) + self.assertEqual(self.widget.run_button.text(), "Stop") + + def test_optimization_finish(self): + self.send_signal(self.widget.Inputs.data, self.data) + output1 = self.get_output(self.widget.Outputs.components) + self.widget.run_button.click() + self.assertEqual(self.widget.run_button.text(), "Stop") + self.wait_until_stop_blocking() + self.assertEqual(self.widget.run_button.text(), "Start") + output2 = self.get_output(self.widget.Outputs.components) + self.assertTrue((output1.X != output2.X).any()) + + def test_optimization_no_data(self): + self.widget.run_button.click() + self.assertEqual(self.widget.run_button.text(), "Start") def test_constant_data(self): data = Table("titanic")[56:59] self.send_signal(self.widget.Inputs.data, data) - self.widget.btn_start.click() + self.widget.run_button.click() self.assertTrue(self.widget.Error.constant_data.is_shown()) self.send_signal(self.widget.Inputs.data, None) self.assertFalse(self.widget.Error.constant_data.is_shown()) @@ -100,4 +122,47 @@ def test_discrete_attributes(self): zoo = Table("zoo") self.send_signal(self.widget.Inputs.data, zoo) self.assertTrue(self.widget.Warning.removed_features.is_shown()) - self.widget.btn_start.click() + self.widget.run_button.click() + + +class TestOWFreeVizRunner(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data = Table("iris") + + def setUp(self): + anchors = FreeViz.init_radial(len(self.data.domain.attributes)) + self.projector = projector = FreeViz(scale=False, center=False, + initial=anchors, maxiter=10) + self.projector.domain = self.data.domain + self.projector.components_ = anchors.T + self.projection = FreeVizModel(projector, projector.domain, 2) + self.projection.pre_domain = self.data.domain + + def test_Result(self): + result = Result(projector=self.projector, projection=self.projection) + self.assertIsInstance(result.projector, FreeViz) + self.assertIsInstance(result.projection, FreeVizModel) + + def test_run(self): + state = Mock() + state.is_interruption_requested = Mock(return_value=False) + result = run_freeviz(self.data, self.projector, state) + array = np.array([[1.66883742e-01, 9.40395481e-38], + [-8.86817512e-02, 9.96060012e-01], + [6.67450609e-02, -3.97675811e-01], + [-1.44947052e-01, -5.98384200e-01]]) + np.testing.assert_almost_equal(array.T, result.projection.components_) + state.set_status.assert_called_once_with("Calculating...") + self.assertGreater(state.set_partial_result.call_count, 40) + self.assertGreater(state.set_progress_value.call_count, 40) + + def test_run_do_not_modify_model_inplace(self): + state = Mock() + state.is_interruption_requested.return_value = True + result = run_freeviz(self.data, self.projector, state) + state.set_partial_result.assert_called_once() + self.assertIs(self.projector, result.projector) + self.assertIsNot(self.projection.proj, result.projection.proj) + self.assertTrue((self.projection.components_.T != + result.projection.components_.T).any()) From bee049bd210b13539577e15a059af332c55b445a Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Fri, 1 Mar 2019 07:30:51 +0100 Subject: [PATCH 2/3] OWConcurrentProjectionWidget: Add an example --- Orange/widgets/tests/base.py | 5 +- .../widgets/utils/tests/concurrent_example.py | 129 ++++++++++++++++++ .../utils/tests/test_concurrent_example.py | 60 ++++++++ 3 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 Orange/widgets/utils/tests/concurrent_example.py create mode 100644 Orange/widgets/utils/tests/test_concurrent_example.py diff --git a/Orange/widgets/tests/base.py b/Orange/widgets/tests/base.py index 7ec3020df10..b427ef2c5b4 100644 --- a/Orange/widgets/tests/base.py +++ b/Orange/widgets/tests/base.py @@ -857,10 +857,13 @@ def _compare_selected_annotated_domains(self, selected, annotated): annotated_vars = annotated.domain.variables self.assertLessEqual(set(selected_vars), set(annotated_vars)) - def test_setup_graph(self): + def test_setup_graph(self, timeout=DEFAULT_TIMEOUT): """Plot should exist after data has been sent in order to be properly set/updated""" self.send_signal(self.widget.Inputs.data, self.data) + if self.widget.isBlocking(): + spy = QSignalSpy(self.widget.blockingStateChanged) + self.assertTrue(spy.wait(timeout)) self.assertIsNotNone(self.widget.graph.scatterplot_item) def test_default_attrs(self, timeout=DEFAULT_TIMEOUT): diff --git a/Orange/widgets/utils/tests/concurrent_example.py b/Orange/widgets/utils/tests/concurrent_example.py new file mode 100644 index 00000000000..c37b067ca6c --- /dev/null +++ b/Orange/widgets/utils/tests/concurrent_example.py @@ -0,0 +1,129 @@ +# pylint: disable=too-many-ancestors +from typing import Optional +from types import SimpleNamespace as namespace + +import numpy as np + +from Orange.data import Table +from Orange.widgets import gui +from Orange.widgets.settings import Setting +from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin +from Orange.widgets.utils.widgetpreview import WidgetPreview +from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget + + +class Result(namespace): + embedding = None # type: Optional[np.ndarray] + + +def run(data: Table, embedding: Optional[np.ndarray], state: TaskState): + res = Result(embedding=embedding) + + # simulate wasteful calculation (increase 'steps') + step, steps = 0, 10 + state.set_status("Calculating...") + while step < steps: + for _ in range(steps): + x_data = np.array(np.mean(data.X, axis=1)) + if x_data.ndim == 2: + x_data = x_data.ravel() + y_data = np.random.rand(len(x_data)) + embedding = np.vstack((x_data, y_data)).T + step += 1 + if step % (steps / 10) == 0: + state.set_progress_value(100 * step / steps) + + if state.is_interruption_requested(): + return res + + res.embedding = embedding + state.set_partial_result(res) + return res + + +class OWConcurrentWidget(OWDataProjectionWidget, ConcurrentWidgetMixin): + name = "Projection" + param = Setting(0) + + def __init__(self): + OWDataProjectionWidget.__init__(self) + ConcurrentWidgetMixin.__init__(self) + self.embedding = None # type: Optional[np.ndarray] + + # GUI + def _add_controls(self): + box = gui.vBox(self.controlArea, True) + gui.comboBox( + box, self, "param", label="Parameter:", + items=["Param A", "Param B"], labelWidth=80, + callback=self.__param_combo_changed + ) + self.run_button = gui.button(box, self, "Start", self._toggle_run) + super()._add_controls() + + def __param_combo_changed(self): + self._run() + + def _toggle_run(self): + # Pause task + if self.task is not None: + self.cancel() + self.run_button.setText("Resume") + self.commit() + # Resume task + else: + self._run() + + def _run(self): + if self.data is None: + return + self.run_button.setText("Stop") + self.start(run, self.data, self.embedding) + + # ConcurrentWidgetMixin + def on_partial_result(self, result: Result): + assert isinstance(result.embedding, np.ndarray) + assert len(result.embedding) == len(self.data) + first_result = self.embedding is None + self.embedding = result.embedding + if first_result: + self.setup_plot() + else: + self.graph.update_coordinates() + self.graph.update_density() + + def on_done(self, result: Result): + assert isinstance(result.embedding, np.ndarray) + assert len(result.embedding) == len(self.data) + self.embedding = result.embedding + self.run_button.setText("Start") + self.commit() + + # OWDataProjectionWidget + def set_data(self, data: Table): + super().set_data(data) + if self._invalidated: + self._run() + + def get_embedding(self): + if self.embedding is None: + self.valid_data = None + return None + + self.valid_data = np.all(np.isfinite(self.embedding), 1) + return self.embedding + + def clear(self): + super().clear() + self.cancel() + self.embedding = None + + def onDeleteWidget(self): + self.shutdown() + super().onDeleteWidget() + + +if __name__ == "__main__": + table = Table("iris") + WidgetPreview(OWConcurrentWidget).run( + set_data=table, set_subset_data=table[::10]) diff --git a/Orange/widgets/utils/tests/test_concurrent_example.py b/Orange/widgets/utils/tests/test_concurrent_example.py new file mode 100644 index 00000000000..9850f012905 --- /dev/null +++ b/Orange/widgets/utils/tests/test_concurrent_example.py @@ -0,0 +1,60 @@ +# Test methods with long descriptive names can omit docstrings +# pylint: disable=missing-docstring +import unittest +from unittest.mock import Mock + +from Orange.data import Table +from Orange.widgets.tests.base import ( + WidgetTest, WidgetOutputsTestMixin, ProjectionWidgetTestMixin +) +from Orange.widgets.utils.tests.concurrent_example import ( + OWConcurrentWidget +) + + +class TestOWConcurrentWidget(WidgetTest, ProjectionWidgetTestMixin, + WidgetOutputsTestMixin): + @classmethod + def setUpClass(cls): + super().setUpClass() + WidgetOutputsTestMixin.init(cls) + + cls.signal_name = "Data" + cls.signal_data = cls.data + cls.same_input_output_domain = False + + def setUp(self): + self.widget = self.create_widget(OWConcurrentWidget) + + def test_button_no_data(self): + self.widget.run_button.click() + self.assertEqual(self.widget.run_button.text(), "Start") + + def test_button_with_data(self): + self.send_signal(self.widget.Inputs.data, self.data) + self.assertEqual(self.widget.run_button.text(), "Stop") + self.wait_until_stop_blocking() + self.assertEqual(self.widget.run_button.text(), "Start") + + def test_button_toggle(self): + self.send_signal(self.widget.Inputs.data, self.data) + self.widget.run_button.click() + self.assertEqual(self.widget.run_button.text(), "Resume") + + def test_plot_once(self): + table = Table("heart_disease") + self.widget.setup_plot = Mock() + self.widget.commit = Mock() + self.send_signal(self.widget.Inputs.data, table) + self.widget.setup_plot.assert_called_once() + self.widget.commit.assert_called_once() + self.wait_until_stop_blocking() + self.widget.setup_plot.reset_mock() + self.widget.commit.reset_mock() + self.send_signal(self.widget.Inputs.data_subset, table[::10]) + self.widget.setup_plot.assert_not_called() + self.widget.commit.assert_called_once() + + +if __name__ == "__main__": + unittest.main() From 8ac9f33dee406cf1870400d3e0e8608dcd0a999c Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Tue, 5 Mar 2019 08:15:22 +0100 Subject: [PATCH 3/3] MDS: Offload work to a separate thread --- Orange/widgets/unsupervised/owmds.py | 294 +++++++----------- .../widgets/unsupervised/tests/test_owmds.py | 40 ++- 2 files changed, 151 insertions(+), 183 deletions(-) diff --git a/Orange/widgets/unsupervised/owmds.py b/Orange/widgets/unsupervised/owmds.py index a6a04c23153..a9d67efb880 100644 --- a/Orange/widgets/unsupervised/owmds.py +++ b/Orange/widgets/unsupervised/owmds.py @@ -1,9 +1,10 @@ -import warnings +# pylint: disable=too-many-ancestors +from types import SimpleNamespace as namespace import numpy as np import scipy.spatial.distance -from AnyQt.QtCore import Qt, QTimer +from AnyQt.QtCore import Qt import pyqtgraph as pg @@ -14,6 +15,7 @@ from Orange.widgets import gui, settings from Orange.widgets.settings import SettingProvider +from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget @@ -29,6 +31,47 @@ def stress(X, distD): return delta_sq.sum(axis=0) / 2 +class Result(namespace): + embedding = None # type: np.ndarray + + +def run_mds(matrix: DistMatrix, max_iter: int, step_size: int, init_type: int, + embedding: np.ndarray, state: TaskState): + res = Result(embedding=embedding) + + iterations_done = 0 + init = embedding + state.set_status("Running...") + oldstress = np.finfo(np.float).max + + while True: + step_iter = min(max_iter - iterations_done, step_size) + mds = MDS( + dissimilarity="precomputed", n_components=2, + n_init=1, max_iter=step_iter, + init_type=init_type, init_data=init + ) + + mdsfit = mds(matrix) + iterations_done += step_iter + + embedding, stress = mdsfit.embedding_, mdsfit.stress_ + emb_norm = np.sqrt(np.sum(embedding ** 2, axis=1)).sum() + if emb_norm > 0: + stress /= emb_norm + + res.embedding = embedding + state.set_partial_result(res) + state.set_progress_value(100 * iterations_done / max_iter) + if iterations_done >= max_iter or stress == 0 or \ + (oldstress - stress) < mds.params["eps"]: + return res + init = embedding + oldstress = stress + if state.is_interruption_requested(): + return res + + #: Maximum number of displayed closest pairs. MAX_N_PAIRS = 10000 @@ -104,7 +147,7 @@ def update_pairs(self, reconnect): self.plot_widget.addItem(self.pairs_curve) -class OWMDS(OWDataProjectionWidget): +class OWMDS(OWDataProjectionWidget, ConcurrentWidgetMixin): name = "MDS" description = "Two-dimensional data projection by multidimensional " \ "scaling constructed from a distance matrix." @@ -129,9 +172,6 @@ class Inputs(OWDataProjectionWidget.Inputs): ("None", -1) ] - #: Runtime state - Running, Finished, Waiting = 1, 2, 3 - max_iter = settings.Setting(300) initialization = settings.Setting(PCA) refresh_rate = settings.Setting(3) @@ -150,7 +190,8 @@ class Error(OWDataProjectionWidget.Error): optimization_error = Msg("Error during optimization\n{}") def __init__(self): - super().__init__() + OWDataProjectionWidget.__init__(self) + ConcurrentWidgetMixin.__init__(self) #: Input dissimilarity matrix self.matrix = None # type: Optional[DistMatrix] #: Data table from the `self.matrix.row_items` (if present) @@ -158,15 +199,8 @@ def __init__(self): #: Input data table self.signal_data = None - self.embedding = None - self.effective_matrix = None - - self.__update_loop = None - # timer for scheduling updates - self.__timer = QTimer(self, singleShot=True, interval=0) - self.__timer.timeout.connect(self.__next_step) - self.__state = OWMDS.Waiting - self.__in_next_step = False + self.embedding = None # type: Optional[np.ndarray] + self.effective_matrix = None # type: Optional[DistMatrix] self.graph.pause_drawing_pairs() @@ -175,7 +209,6 @@ def __init__(self): self.gui.points_models[2].order[:1] \ + ("Stress", ) + \ self.gui.points_models[2].order[1:] - # self._initialize() def _add_controls(self): self._add_controls_optimization() @@ -188,17 +221,20 @@ def _add_controls(self): def _add_controls_optimization(self): box = gui.vBox(self.controlArea, box=True) - self.runbutton = gui.button(box, self, "Run optimization", - callback=self._toggle_run) + self.run_button = gui.button(box, self, "Start", self._toggle_run) gui.comboBox(box, self, "refresh_rate", label="Refresh: ", orientation=Qt.Horizontal, items=[t for t, _ in OWMDS.RefreshRate], - callback=self.__invalidate_refresh) + callback=self.__refresh_rate_combo_changed) hbox = gui.hBox(box, margin=0) gui.button(hbox, self, "PCA", callback=self.do_PCA) gui.button(hbox, self, "Randomize", callback=self.do_random) gui.button(hbox, self, "Jitter", callback=self.do_jitter) + def __refresh_rate_combo_changed(self): + if self.task is not None: + self._run() + def set_data(self, data): """Set the input dataset. @@ -234,10 +270,9 @@ def set_disimilarity(self, matrix): def clear(self): super().clear() + self.cancel() self.embedding = None self.graph.set_effective_matrix(None) - self.__set_update_loop(None) - self.__state = OWMDS.Waiting def _initialize(self): matrix_existed = self.effective_matrix is not None @@ -299,159 +334,70 @@ def init_attr_values(self): self.attr_label = self.data.domain["labels"] def _toggle_run(self): - if self.__state == OWMDS.Running: - self.stop() - self._invalidate_output() + if self.task is not None: + self.cancel() + self.run_button.setText("Resume") + self.commit() else: - self.start() + self._run() - def start(self): - if self.__state == OWMDS.Running: + def _run(self): + if self.effective_matrix is None: return - elif self.__state == OWMDS.Finished: - # Resume/continue from a previous run - self.__start() - elif self.__state == OWMDS.Waiting and \ - self.effective_matrix is not None: - self.__start() - - def stop(self): - if self.__state == OWMDS.Running: - self.__set_update_loop(None) - - def __start(self): self.graph.pause_drawing_pairs() - X = self.effective_matrix - init = self.embedding - - # number of iterations per single GUI update step + self.run_button.setText("Stop") _, step_size = OWMDS.RefreshRate[self.refresh_rate] if step_size == -1: step_size = self.max_iter - - def update_loop(X, max_iter, step, init): - """ - return an iterator over successive improved MDS point embeddings. - """ - # NOTE: this code MUST NOT call into QApplication.processEvents - done = False - iterations_done = 0 - oldstress = np.finfo(np.float).max - init_type = "PCA" if self.initialization == OWMDS.PCA else "random" - - while not done: - step_iter = min(max_iter - iterations_done, step) - mds = MDS( - dissimilarity="precomputed", n_components=2, - n_init=1, max_iter=step_iter, - init_type=init_type, init_data=init - ) - - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", ".*double_scalars.*", RuntimeWarning) - mdsfit = mds(X) - iterations_done += step_iter - - embedding, stress = mdsfit.embedding_, mdsfit.stress_ - emb_norm = np.sqrt(np.sum(embedding ** 2, axis=1)).sum() - if emb_norm > 0: - stress /= emb_norm - - if iterations_done >= max_iter \ - or (oldstress - stress) < mds.params["eps"] \ - or stress == 0: - done = True - init = embedding - oldstress = stress - - yield embedding, mdsfit.stress_, iterations_done / max_iter - - self.__set_update_loop(update_loop(X, self.max_iter, step_size, init)) - self.progressBarInit(processEvents=None) - - def __set_update_loop(self, loop): - """ - Set the update `loop` coroutine. - - The `loop` is a generator yielding `(embedding, stress, progress)` - tuples where `embedding` is a `(N, 2) ndarray` of current updated - MDS points, `stress` is the current stress and `progress` a float - ratio (0 <= progress <= 1) - - If an existing update coroutine loop is already in place it is - interrupted (i.e. closed). - - .. note:: - The `loop` must not explicitly yield control flow to the event - loop (i.e. call `QApplication.processEvents`) - - """ - if self.__update_loop is not None: - self.__update_loop.close() - self.__update_loop = None - self.progressBarFinished(processEvents=None) - - self.__update_loop = loop - - if loop is not None: - self.setBlocking(True) - self.progressBarInit(processEvents=None) - self.setStatusMessage("Running") - self.runbutton.setText("Stop") - self.__state = OWMDS.Running - self.__timer.start() + init_type = "PCA" if self.initialization == OWMDS.PCA else "random" + self.start(run_mds, self.effective_matrix, self.max_iter, + step_size, init_type, self.embedding) + + # ConcurrentWidgetMixin + def on_partial_result(self, result: Result): + assert isinstance(result.embedding, np.ndarray) + assert len(result.embedding) == len(self.effective_matrix) + first_result = self.embedding is None + new_embedding = result.embedding + need_update = new_embedding is not self.embedding + self.embedding = new_embedding + if first_result: + self.setup_plot() else: - self.setBlocking(False) - self.setStatusMessage("") - self.runbutton.setText("Start") - self.__state = OWMDS.Finished - self.__timer.stop() - - def __next_step(self): - if self.__update_loop is None: - return + if need_update: + self.graph.update_coordinates() + self.graph.update_density() + + def on_done(self, result: Result): + assert isinstance(result.embedding, np.ndarray) + assert len(result.embedding) == len(self.effective_matrix) + self.embedding = result.embedding + self.graph.resume_drawing_pairs() + self.run_button.setText("Start") + self.commit() - assert not self.__in_next_step - self.__in_next_step = True - - loop = self.__update_loop - self.Error.out_of_memory.clear() - try: - embedding, _, progress = next(self.__update_loop) - assert self.__update_loop is loop - except StopIteration: - self.__set_update_loop(None) - self.unconditional_commit() - self.graph.resume_drawing_pairs() - except MemoryError: + def on_exception(self, ex: Exception): + if isinstance(ex, MemoryError): self.Error.out_of_memory() - self.__set_update_loop(None) - self.graph.resume_drawing_pairs() - except Exception as exc: - self.Error.optimization_error(str(exc)) - self.__set_update_loop(None) - self.graph.resume_drawing_pairs() else: - self.progressBarSet(100.0 * progress, processEvents=None) - self.embedding = embedding - self.graph.update_coordinates() - # schedule next update - self.__timer.start() - - self.__in_next_step = False + self.Error.optimization_error(str(ex)) + self.graph.resume_drawing_pairs() + self.run_button.setText("Start") def do_PCA(self): - self.__invalidate_embedding(self.PCA) - self.setup_plot() + self.do_initialization(self.PCA) def do_random(self): - self.__invalidate_embedding(self.Random) - self.setup_plot() + self.do_initialization(self.Random) def do_jitter(self): - self.__invalidate_embedding(self.Jitter) + self.do_initialization(self.Jitter) + + def do_initialization(self, init_type: int): + self.run_button.setText("Start") + self.__invalidate_embedding(init_type) self.setup_plot() + self.commit() def __invalidate_embedding(self, initialization=PCA): def jitter_coord(part): @@ -460,10 +406,6 @@ def jitter_coord(part): # reset/invalidate the MDS embedding, to the default initialization # (Random or PCA), restarting the optimization if necessary. - state = self.__state - if self.__update_loop is not None: - self.__set_update_loop(None) - if self.effective_matrix is None: self.graph.reset_graph() return @@ -479,20 +421,8 @@ def jitter_coord(part): jitter_coord(self.embedding[:, 1]) # restart the optimization if it was interrupted. - if state == OWMDS.Running: - self.__start() - - def __invalidate_refresh(self): - state = self.__state - - if self.__update_loop is not None: - self.__set_update_loop(None) - - # restart the optimization if it was interrupted. - # TODO: decrease the max iteration count by the already - # completed iterations count. - if state == OWMDS.Running: - self.__start() + if self.task is not None: + self._run() def handleNewSignals(self): self._initialize() @@ -500,12 +430,10 @@ def handleNewSignals(self): self.graph.pause_drawing_pairs() self.__invalidate_embedding() self.enable_controls() - self.start() + if self.effective_matrix is not None: + self._run() super().handleNewSignals() - def _invalidate_output(self): - self.commit() - def _on_connected_changed(self): self.graph.set_effective_matrix(self.effective_matrix) self.graph.update_pairs(reconnect=True) @@ -536,6 +464,10 @@ def _get_projection_data(self): return Table(Domain(variables), self.embedding) return super()._get_projection_data() + def onDeleteWidget(self): + self.shutdown() + super().onDeleteWidget() + @classmethod def migrate_settings(cls, settings_, version): if version < 2: @@ -582,5 +514,5 @@ def migrate_context(cls, context, version): if __name__ == "__main__": # pragma: no cover - data = Table("iris") - WidgetPreview(OWMDS).run(set_data=data, set_subset_data=data[:30]) + table = Table("iris") + WidgetPreview(OWMDS).run(set_data=table, set_subset_data=table[:30]) diff --git a/Orange/widgets/unsupervised/tests/test_owmds.py b/Orange/widgets/unsupervised/tests/test_owmds.py index 39ad10f55f3..01a1a914172 100644 --- a/Orange/widgets/unsupervised/tests/test_owmds.py +++ b/Orange/widgets/unsupervised/tests/test_owmds.py @@ -10,12 +10,13 @@ from Orange.data import Table from Orange.distance import Euclidean from Orange.misc import DistMatrix +from Orange.projection.manifold import torgerson from Orange.widgets.settings import Context from Orange.widgets.tests.base import ( WidgetTest, WidgetOutputsTestMixin, datasets, ProjectionWidgetTestMixin ) from Orange.widgets.tests.utils import simulate -from Orange.widgets.unsupervised.owmds import OWMDS +from Orange.widgets.unsupervised.owmds import OWMDS, run_mds, Result class TestOWMDS(WidgetTest, ProjectionWidgetTestMixin, @@ -123,7 +124,7 @@ def test_small_data(self): def test_run(self): self.send_signal(self.widget.Inputs.data, self.data) - self.widget.runbutton.click() + self.widget.run_button.click() self.widget.initialization = 0 self.widget._OWMDS__invalidate_embedding() # pylint: disable=protected-access @@ -289,5 +290,40 @@ def test_matrix_columns_default_label(self): self.assertEqual(label_text, "labels") +class TestOWMDSRunner(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data = Table("iris") + cls.distances = Euclidean(cls.data) + cls.init = torgerson(cls.distances) + cls.args = (cls.distances, 300, 25, 0, cls.init) + + def test_Result(self): + result = Result(embedding=self.init) + self.assertIsInstance(result.embedding, np.ndarray) + + def test_run_mds(self): + state = Mock() + state.is_interruption_requested.return_value = False + result = run_mds(*(self.args + (state,))) + array = np.array([[-2.69280967, 0.32544313], + [-2.72409383, -0.21287617], + [-2.9022707, -0.13465859], + [-2.75267253, -0.33899134], + [-2.74108069, 0.35393209]]) + np.testing.assert_almost_equal(array, result.embedding[:5]) + state.set_status.assert_called_once_with("Running...") + self.assertGreater(state.set_partial_result.call_count, 2) + self.assertGreater(state.set_progress_value.call_count, 2) + + def test_run_do_not_modify_model_inplace(self): + state = Mock() + state.is_interruption_requested.return_value = True + result = run_mds(*(self.args + (state,))) + state.set_partial_result.assert_called_once() + self.assertIsNot(self.init, result.embedding) + self.assertTrue((self.init != result.embedding).any()) + + if __name__ == "__main__": unittest.main()