Skip to content

Commit

Permalink
Edit Domain: Add option to remove compute_value
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jun 19, 2020
1 parent 1ad65b9 commit 3629531
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 62 deletions.
139 changes: 96 additions & 43 deletions Orange/widgets/data/oweditdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Categorical(
("name", str),
("categories", Tuple[str, ...]),
("annotations", AnnotationsType),
("linked", bool)
])): pass


Expand All @@ -104,20 +105,23 @@ class Real(
# a precision (int, and a format specifier('f', 'g', or '')
("format", Tuple[int, str]),
("annotations", AnnotationsType),
("linked", bool)
])): pass


class String(
_DataType, NamedTuple("String", [
("name", str),
("annotations", AnnotationsType),
("linked", bool)
])): pass


class Time(
_DataType, NamedTuple("Time", [
("name", str),
("annotations", AnnotationsType),
("linked", bool)
])): pass


Expand Down Expand Up @@ -175,10 +179,14 @@ def __call__(self, var):
return var._replace(annotations=self.annotations)


Transform = Union[Rename, CategoriesMapping, Annotate]
TransformTypes = (Rename, CategoriesMapping, Annotate)
class Unlink(_DataType, namedtuple("Unlink", [])):
"""Unlink variable from its source, that is, remove compute_value"""

CategoricalTransformTypes = (CategoriesMapping, )

Transform = Union[Rename, CategoriesMapping, Annotate, Unlink]
TransformTypes = (Rename, CategoriesMapping, Annotate, Unlink)

CategoricalTransformTypes = (CategoriesMapping, Unlink)


# Reinterpret vector transformations.
Expand Down Expand Up @@ -221,7 +229,7 @@ def __call__(self, vector: DataVector) -> StringVector:
if isinstance(var, String):
return vector
return StringVector(
String(var.name, var.annotations),
String(var.name, var.annotations, False),
lambda: as_string(vector.data()),
)

Expand All @@ -241,19 +249,19 @@ def data() -> MArray:
a = categorical_to_string_vector(d, var.values)
return MArray(as_float_or_nan(a, where=a.mask), mask=a.mask)
return RealVector(
Real(var.name, (6, 'g'), var.annotations), data
Real(var.name, (6, 'g'), var.annotations, var.linked), data
)
elif isinstance(var, Time):
return RealVector(
Real(var.name, (6, 'g'), var.annotations),
Real(var.name, (6, 'g'), var.annotations, var.linked),
lambda: vector.data().astype(float)
)
elif isinstance(var, String):
def data():
s = vector.data()
return MArray(as_float_or_nan(s, where=s.mask), mask=s.mask)
return RealVector(
Real(var.name, (6, "g"), var.annotations), data
Real(var.name, (6, "g"), var.annotations, var.linked), data
)
raise AssertionError

Expand All @@ -266,22 +274,10 @@ def __call__(self, vector: DataVector) -> CategoricalVector:
var, _ = vector
if isinstance(var, Categorical):
return vector
if isinstance(var, Real):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations),
lambda: data
)
elif isinstance(var, Time):
if isinstance(var, (Real, Time, String)):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations),
lambda: data
)
elif isinstance(var, String):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations),
Categorical(var.name, values, var.annotations, var.linked),
lambda: data
)
raise AssertionError
Expand All @@ -295,7 +291,7 @@ def __call__(self, vector: DataVector) -> TimeVector:
return vector
elif isinstance(var, Real):
return TimeVector(
Time(var.name, var.annotations),
Time(var.name, var.annotations, var.linked),
lambda: vector.data().astype("M8[us]")
)
elif isinstance(var, Categorical):
Expand All @@ -305,15 +301,15 @@ def data():
dt = pd.to_datetime(s, errors="coerce").values.astype("M8[us]")
return MArray(dt, mask=d.mask)
return TimeVector(
Time(var.name, var.annotations), data
Time(var.name, var.annotations, var.linked), data
)
elif isinstance(var, String):
def data():
s = vector.data()
dt = pd.to_datetime(s, errors="coerce").values.astype("M8[us]")
return MArray(dt, mask=s.mask)
return TimeVector(
Time(var.name, var.annotations), data
Time(var.name, var.annotations, var.linked), data
)
raise AssertionError

Expand Down Expand Up @@ -532,6 +528,17 @@ def __init__(self, parent=None, **kwargs):
)
form.addRow("Name:", self.name_edit)

self.unlink_var_cb = QCheckBox(
"Unlink variable from its source variable", self,
toolTip="Make Orange forget that the variable is derived from "
"another.\n"
"Use this for instance when you want to consider variables "
"with the same name but from different sources as the same "
"variable."
)
self.unlink_var_cb.toggled.connect(self._set_unlink)
form.addRow("", self.unlink_var_cb)

vlayout = QVBoxLayout(margin=0, spacing=1)
self.labels_edit = view = QTreeView(
objectName="annotation-pairs-edit",
Expand Down Expand Up @@ -616,17 +623,23 @@ def set_data(self, var, transform=()):
if var is not None:
name = var.name
annotations = var.annotations
unlink = False
for tr in transform:
if isinstance(tr, Rename):
name = tr.name
elif isinstance(tr, Annotate):
annotations = tr.annotations
elif isinstance(tr, Unlink):
unlink = True
self.name_edit.setText(name)
self.labels_model.set_dict(dict(annotations))
self.add_label_action.actionGroup().setEnabled(True)
self.unlink_var_cb.setChecked(unlink)
else:
self.add_label_action.actionGroup().setEnabled(False)

self.unlink_var_cb.setDisabled(var is None or not var.linked)

def get_data(self):
"""Retrieve the modified variable.
"""
Expand All @@ -639,6 +652,8 @@ def get_data(self):
tr.append(Rename(name))
if self.var.annotations != labels:
tr.append(Annotate(labels))
if self.var.linked and self.unlink_var_cb.isChecked():
tr.append(Unlink())
return self.var, tr

def clear(self):
Expand All @@ -647,6 +662,7 @@ def clear(self):
self.var = None
self.name_edit.setText("")
self.labels_model.setRowCount(0)
self.unlink_var_cb.setChecked(False)

@Slot()
def on_name_changed(self):
Expand All @@ -661,6 +677,10 @@ def on_label_selection_changed(self):
selected = self.labels_edit.selectionModel().selectedRows()
self.remove_label_action.setEnabled(bool(len(selected)))

def _set_unlink(self, unlink):
self.unlink_var_cb.setChecked(unlink)
self.variable_changed.emit()


class GroupItemsDialog(QDialog):
"""
Expand Down Expand Up @@ -1157,7 +1177,7 @@ def __init__(self, *args, **kwargs):
hlayout.addStretch(10)
vlayout.addLayout(hlayout)

form.insertRow(1, "Values:", vlayout)
form.insertRow(2, "Values:", vlayout)

QWidget.setTabOrder(self.name_edit, self.values_edit)
QWidget.setTabOrder(self.values_edit, button1)
Expand Down Expand Up @@ -2030,23 +2050,32 @@ def state(i):
model.data(midx, TransformRole))

state = [state(i) for i in range(model.rowCount())]
if all(tr is None or not tr for _, tr in state) \
and self.output_table_name in ("", data.name):
input_vars = data.domain.variables + data.domain.metas
if self.output_table_name in ("", data.name) \
and not any(requires_transform(var, trs)
for var, (_, trs) in zip(input_vars, state)):
self.Outputs.data.send(data)
self.info.set_output_summary(len(data),
format_summary_details(data))
return

output_vars = []
input_vars = data.domain.variables + data.domain.metas
assert all(v_.vtype.name == v.name
for v, (v_, _) in zip(input_vars, state))
output_vars = []
unlinked_vars = []
unlink_domain = False
for (_, tr), v in zip(state, input_vars):
if tr:
var = apply_transform(v, data, tr)
if requires_unlink(v, tr):
unlinked_var = var.copy(compute_value=None)
unlink_domain = True
else:
unlinked_var = var
else:
var = v
unlinked_var = var = v
output_vars.append(var)
unlinked_vars.append(unlinked_var)

if len(output_vars) != len({v.name for v in output_vars}):
self.Error.duplicate_var_name()
Expand All @@ -2058,15 +2087,23 @@ def state(i):
nx = len(domain.attributes)
ny = len(domain.class_vars)

Xs = output_vars[:nx]
Ys = output_vars[nx: nx + ny]
Ms = output_vars[nx + ny:]
# Move non primitive Xs, Ys to metas (if they were changed)
Ms += [v for v in Xs + Ys if not v.is_primitive()]
Xs = [v for v in Xs if v.is_primitive()]
Ys = [v for v in Ys if v.is_primitive()]
domain = Orange.data.Domain(Xs, Ys, Ms)
def construct_domain(vars_list):
# Move non primitive Xs, Ys to metas (if they were changed)
Xs = [v for v in vars_list[:nx] if v.is_primitive()]
Ys = [v for v in vars_list[nx: nx + ny] if v.is_primitive()]
Ms = vars_list[nx + ny:] + \
[v for v in vars_list[:nx + ny] if not v.is_primitive()]
return Orange.data.Domain(Xs, Ys, Ms)

domain = construct_domain(output_vars)
new_data = data.transform(domain)
if unlink_domain:
unlinked_domain = construct_domain(unlinked_vars)
new_data = new_data.from_numpy(
unlinked_domain,
new_data.X, new_data.Y, new_data.metas, new_data.W,
new_data.attributes, new_data.ids
)
if self.output_table_name:
new_data.name = self.output_table_name
self.Outputs.data.send(new_data)
Expand Down Expand Up @@ -2236,7 +2273,7 @@ def i(text):
def text(text):
return "<span>{}</span>".format(escape(text))
assert trs
rename = annotate = catmap = None
rename = annotate = catmap = unlink = None
reinterpret = None

for tr in trs:
Expand All @@ -2246,6 +2283,8 @@ def text(text):
annotate = tr
elif isinstance(tr, CategoriesMapping):
catmap = tr
elif isinstance(tr, Unlink):
unlink = tr
elif isinstance(tr, ReinterpretTransformTypes):
reinterpret = tr

Expand All @@ -2258,6 +2297,8 @@ def text(text):
header = "{} → {}".format(var.name, rename.name)
else:
header = var.name
if unlink is not None:
header += "(unlinked from source)"

values_section = None
if catmap is not None:
Expand Down Expand Up @@ -2323,14 +2364,15 @@ def abstract(var):
(key, str(value))
for key, value in var.attributes.items()
))
linked = var.compute_value is not None
if isinstance(var, Orange.data.DiscreteVariable):
return Categorical(var.name, tuple(var.values), annotations)
return Categorical(var.name, tuple(var.values), annotations, linked)
elif isinstance(var, Orange.data.TimeVariable):
return Time(var.name, annotations)
return Time(var.name, annotations, linked)
elif isinstance(var, Orange.data.ContinuousVariable):
return Real(var.name, (var.number_of_decimals, 'f'), annotations)
return Real(var.name, (var.number_of_decimals, 'f'), annotations, linked)
elif isinstance(var, Orange.data.StringVariable):
return String(var.name, annotations)
return String(var.name, annotations, linked)
else:
raise TypeError

Expand Down Expand Up @@ -2359,6 +2401,17 @@ def apply_transform(var, table, trs):
return var


def requires_unlink(var: Orange.data.Variable, trs: List[Transform]) -> bool:
return trs is not None \
and any(isinstance(tr, Unlink) for tr in trs) \
and (var.compute_value is not None or len(trs) > 1)


def requires_transform(var: Orange.data.Variable, trs: List[Transform]) -> bool:
return trs and not all (isinstance(tr, Unlink) for tr in trs) \
or requires_unlink(var, trs)


@singledispatch
def apply_transform_var(var, trs):
# type: (Orange.data.Variable, List[Transform]) -> Orange.data.Variable
Expand Down
Loading

0 comments on commit 3629531

Please sign in to comment.