Skip to content

Commit

Permalink
Merge pull request #3673 from VesnaT/transform_enh
Browse files Browse the repository at this point in the history
[FIX] Transform: Replace 'Preprocess' input with 'Template Data' input
  • Loading branch information
janezd authored Mar 15, 2019
2 parents 021604b + 7ae3285 commit b5d6f5a
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 197 deletions.
12 changes: 0 additions & 12 deletions Orange/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,18 +536,6 @@ def transform(var):
return data.transform(domain)


class ApplyDomain(Preprocess):
    """Preprocessor that converts data into a fixed, pre-built domain."""

    def __init__(self, domain, name):
        # `name` is only used as this preprocessor's string representation.
        self._domain, self._name = domain, name

    def __call__(self, data):
        """Return `data` transformed into the stored domain."""
        target = self._domain
        return data.transform(target)

    def __str__(self):
        """Return the name given at construction."""
        return self._name


class PreprocessorList(Preprocess):
"""
Store a list of preprocessors and on call apply them to the dataset.
Expand Down
129 changes: 64 additions & 65 deletions Orange/widgets/data/owtransform.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,77 @@
from typing import Optional

import numpy as np

from Orange.data import Table, Domain
from Orange.preprocess.preprocess import Preprocess, Discretize
from Orange.widgets import gui
from Orange.widgets.report.report import describe_data
from Orange.widgets.settings import Setting
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import OWWidget, Input, Output, Msg


class OWTransform(OWWidget):
name = "Transform"
description = "Transform data table."
name = "Apply Domain"
description = "Applies template domain on data table."
icon = "icons/Transform.svg"
priority = 2110
keywords = []
keywords = ["transform"]

retain_all_data = Setting(False)

class Inputs:
data = Input("Data", Table, default=True)
preprocessor = Input("Preprocessor", Preprocess)
template_data = Input("Template Data", Table)

class Outputs:
transformed_data = Output("Transformed Data", Table)

class Error(OWWidget.Error):
pp_error = Msg("An error occurred while transforming data.\n{}")
error = Msg("An error occurred while transforming data.\n{}")

resizing_enabled = False
want_main_area = False

def __init__(self):
super().__init__()
self.data = None
self.preprocessor = None
self.transformed_data = None
self.data = None # type: Optional[Table]
self.template_domain = None # type: Optional[Domain]
self.transformed_info = describe_data(None) # type: OrderedDict

info_box = gui.widgetBox(self.controlArea, "Info")
self.input_label = gui.widgetLabel(info_box, "")
self.preprocessor_label = gui.widgetLabel(info_box, "")
self.template_label = gui.widgetLabel(info_box, "")
self.output_label = gui.widgetLabel(info_box, "")
self.set_input_label_text()
self.set_preprocessor_label_text()
self.set_template_label_text()

self.retain_all_data_cb = gui.checkBox(
self.controlArea, self, "retain_all_data", label="Retain all data",
callback=self.apply
)
box = gui.widgetBox(self.controlArea, "Output")
gui.checkBox(box, self, "retain_all_data", "Retain all data",
callback=self.apply)

def set_input_label_text(self):
text = "No data on input."
if self.data is not None:
if self.data:
text = "Input data with {:,} instances and {:,} features.".format(
len(self.data),
len(self.data.domain.attributes))
self.input_label.setText(text)

def set_preprocessor_label_text(self):
text = "No preprocessor on input."
if self.transformed_data is not None:
text = "Preprocessor {} applied.".format(self.preprocessor)
elif self.preprocessor is not None:
text = "Preprocessor {} on input.".format(self.preprocessor)
self.preprocessor_label.setText(text)
def set_template_label_text(self):
text = "No template data on input."
if self.data and self.template_domain is not None:
text = "Template domain applied."
elif self.template_domain is not None:
text = "Template data includes {:,} features.".format(
len(self.template_domain.attributes))
self.template_label.setText(text)

def set_output_label_text(self):
def set_output_label_text(self, data):
text = ""
if self.transformed_data:
if data:
text = "Output data includes {:,} features.".format(
len(self.transformed_data.domain.attributes))
len(data.domain.attributes))
self.output_label.setText(text)

@Inputs.data
Expand All @@ -78,56 +80,53 @@ def set_data(self, data):
self.data = data
self.set_input_label_text()

@Inputs.preprocessor
def set_preprocessor(self, preprocessor):
self.preprocessor = preprocessor
@Inputs.template_data
@check_sql_input
def set_template_data(self, data):
    """Remember the domain of the template table.

    `data` is the table received on the "Template Data" input, or None
    when the input is cleared; in the latter case the stored template
    domain is reset to None.
    """
    # Test for None explicitly: `data and data.domain` would store the
    # table itself whenever `data` is falsy but not None (e.g. a table
    # without rows), leaving a non-domain object in `template_domain`.
    self.template_domain = None if data is None else data.domain

def handleNewSignals(self):
self.apply()

def apply(self):
self.clear_messages()
self.transformed_data = None
if self.data is not None and self.preprocessor is not None:
transformed_data = None
if self.data and self.template_domain is not None:
try:
self.transformed_data = self.preprocessor(self.data)
except Exception as ex: # pylint: disable=broad-except
self.Error.pp_error(ex)

if self.retain_all_data:
self.Outputs.transformed_data.send(self.merge_data())
else:
self.Outputs.transformed_data.send(self.transformed_data)

self.set_preprocessor_label_text()
self.set_output_label_text()

def merge_data(self):
    """Return the input data with the transformed features added as metas.

    The attributes and class variables of `self.data` are kept as they
    are; the attributes of `self.transformed_data` are appended after the
    existing meta variables.
    """
    source = self.data
    src_domain = source.domain
    meta_vars = src_domain.metas + self.transformed_data.domain.attributes
    merged_domain = Domain(
        src_domain.attributes, src_domain.class_vars, meta_vars)
    meta_values = np.hstack((source.metas, self.transformed_data.X))
    merged = Table.from_numpy(
        merged_domain, source.X, source.Y, meta_values)
    # Carry over the table-level properties of the input data.
    merged.name = getattr(source, 'name', '')
    merged.attributes = getattr(source, 'attributes', {})
    merged.ids = source.ids
    return merged
transformed_data = self.data.transform(self.template_domain)
except Exception as ex: # pylint: disable=broad-except
self.Error.error(ex)

data = transformed_data
if data and self.retain_all_data:
data = self.merged_data(data)
self.transformed_info = describe_data(data)
self.Outputs.transformed_data.send(data)
self.set_template_label_text()
self.set_output_label_text(data)

def merged_data(self, t_data):
    """Return the input data extended with the transformed columns as metas.

    The attributes and meta variables of `t_data` are appended to the meta
    variables of `self.data`; the attributes and class variables of
    `self.data` are kept unchanged.
    """
    domain = self.data.domain
    t_domain = t_data.domain
    metas = domain.metas + t_domain.attributes + t_domain.metas
    domain = Domain(domain.attributes, domain.class_vars, metas)
    data = self.data.transform(domain)
    t_values = np.hstack((t_data.X, t_data.metas))
    # Address the trailing columns through an explicit start index: when
    # nothing is appended, a negative slice would degenerate to `-0:` and
    # select every meta column, making the assignment fail on shape.
    start = data.metas.shape[1] - t_values.shape[1]
    data.metas[:, start:] = t_values
    return data

def send_report(self):
if self.preprocessor is not None:
self.report_items("Settings",
(("Preprocessor", self.preprocessor),))
if self.data is not None:
if self.data:
self.report_data("Data", self.data)
if self.transformed_data is not None:
self.report_data("Transformed data", self.transformed_data)
if self.template_domain is not None:
self.report_domain("Template data", self.template_domain)
if self.transformed_info:
self.report_items("Transformed data", self.transformed_info)


if __name__ == "__main__": # pragma: no cover
from Orange.preprocess import Discretize

table = Table("iris")
WidgetPreview(OWTransform).run(
set_data=Table("iris"),
set_preprocessor=Discretize())
set_data=table, set_template_data=Discretize()(table))
102 changes: 60 additions & 42 deletions Orange/widgets/data/tests/test_owtransform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
from unittest.mock import Mock

from numpy import testing as npt

from Orange.data import Table
from Orange.preprocess import Discretize
from Orange.preprocess.preprocess import Preprocess
from Orange.preprocess import Discretize, Continuize
from Orange.widgets.data.owtransform import OWTransform
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.unsupervised.owpca import OWPCA
Expand All @@ -12,38 +15,39 @@ class TestOWTransform(WidgetTest):
def setUp(self):
self.widget = self.create_widget(OWTransform)
self.data = Table("iris")
self.preprocessor = Discretize()
self.disc_data = Discretize()(self.data)

def test_output(self):
# send data and preprocessor
self.send_signal(self.widget.Inputs.data, self.data)
self.send_signal(self.widget.Inputs.preprocessor, self.preprocessor)
# send data and template data
self.send_signal(self.widget.Inputs.data, self.data[::15])
self.send_signal(self.widget.Inputs.template_data, self.disc_data)
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsInstance(output, Table)
self.assertEqual("Input data with 150 instances and 4 features.",
self.assertTableEqual(output, self.disc_data[::15])
self.assertEqual("Input data with 10 instances and 4 features.",
self.widget.input_label.text())
self.assertEqual("Preprocessor Discretize() applied.",
self.widget.preprocessor_label.text())
self.assertEqual("Template domain applied.",
self.widget.template_label.text())
self.assertEqual("Output data includes 4 features.",
self.widget.output_label.text())

# remove preprocessor
self.send_signal(self.widget.Inputs.preprocessor, None)
# remove template data
self.send_signal(self.widget.Inputs.template_data, None)
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsNone(output)
self.assertEqual("Input data with 150 instances and 4 features.",
self.assertEqual("Input data with 10 instances and 4 features.",
self.widget.input_label.text())
self.assertEqual("No preprocessor on input.", self.widget.preprocessor_label.text())
self.assertEqual("No template data on input.",
self.widget.template_label.text())
self.assertEqual("", self.widget.output_label.text())

# send preprocessor
self.send_signal(self.widget.Inputs.preprocessor, self.preprocessor)
# send template data
self.send_signal(self.widget.Inputs.template_data, self.disc_data)
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsInstance(output, Table)
self.assertEqual("Input data with 150 instances and 4 features.",
self.assertTableEqual(output, self.disc_data[::15])
self.assertEqual("Input data with 10 instances and 4 features.",
self.widget.input_label.text())
self.assertEqual("Preprocessor Discretize() applied.",
self.widget.preprocessor_label.text())
self.assertEqual("Template domain applied.",
self.widget.template_label.text())
self.assertEqual("Output data includes 4 features.",
self.widget.output_label.text())

Expand All @@ -52,49 +56,63 @@ def test_output(self):
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsNone(output)
self.assertEqual("No data on input.", self.widget.input_label.text())
self.assertEqual("Preprocessor Discretize() on input.",
self.widget.preprocessor_label.text())
self.assertEqual("Template data includes 4 features.",
self.widget.template_label.text())
self.assertEqual("", self.widget.output_label.text())

# remove preprocessor
self.send_signal(self.widget.Inputs.preprocessor, None)
# remove template data
self.send_signal(self.widget.Inputs.template_data, None)
self.assertEqual("No data on input.", self.widget.input_label.text())
self.assertEqual("No preprocessor on input.",
self.widget.preprocessor_label.text())
self.assertEqual("No template data on input.",
self.widget.template_label.text())
self.assertEqual("", self.widget.output_label.text())

def test_input_pca_preprocessor(self):
def assertTableEqual(self, table1, table2):
    """Assert that both tables share one domain object and equal arrays."""
    self.assertIs(table1.domain, table2.domain)
    for part in ("X", "Y", "metas"):
        npt.assert_array_equal(getattr(table1, part), getattr(table2, part))

def test_input_pca_output(self):
owpca = self.create_widget(OWPCA)
self.send_signal(owpca.Inputs.data, self.data, widget=owpca)
owpca.components_spin.setValue(2)
pp = self.get_output(owpca.Outputs.preprocessor, widget=owpca)
self.assertIsNotNone(pp, Preprocess)
pca_out = self.get_output(owpca.Outputs.transformed_data, widget=owpca)

self.send_signal(self.widget.Inputs.data, self.data)
self.send_signal(self.widget.Inputs.preprocessor, pp)
self.send_signal(self.widget.Inputs.data, self.data[::10])
self.send_signal(self.widget.Inputs.template_data, pca_out)
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsInstance(output, Table)
self.assertEqual(output.X.shape, (len(self.data), 2))
npt.assert_array_equal(pca_out.X[::10], output.X)

# test retain data functionality
self.widget.retain_all_data = True
self.widget.apply()
def test_retain_all_data(self):
data = Table("zoo")
cont_data = Continuize()(data)
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.template_data, cont_data)
self.widget.controls.retain_all_data.click()
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsInstance(output, Table)
self.assertEqual(output.X.shape, (len(self.data), 4))
self.assertEqual(output.metas.shape, (len(self.data), 2))
self.assertEqual(output.X.shape, (len(data), 16))
self.assertEqual(output.metas.shape, (len(data), 38))

def test_error_transforming(self):
self.send_signal(self.widget.Inputs.data, self.data)
self.send_signal(self.widget.Inputs.preprocessor, Preprocess())
self.assertTrue(self.widget.Error.pp_error.is_shown())
data = self.data[::10]
data.transform = Mock(side_effect=Exception())
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.template_data, self.disc_data)
self.assertTrue(self.widget.Error.error.is_shown())
output = self.get_output(self.widget.Outputs.transformed_data)
self.assertIsNone(output)
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Error.pp_error.is_shown())
self.assertFalse(self.widget.Error.error.is_shown())

def test_send_report(self):
    """Click the report button once with data on input and once without."""
    data_input = self.widget.Inputs.data
    for payload in (self.data, None):
        self.send_signal(data_input, payload)
        self.widget.report_button.click()


if __name__ == "__main__":
import unittest
unittest.main()
Loading

0 comments on commit b5d6f5a

Please sign in to comment.