From fee9871f047b1a05442f11b8662f2c943da4172c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jaka=20Kokos=CC=8Car?= Date: Mon, 4 Dec 2023 19:56:05 +0100 Subject: [PATCH] owassurvivaldata: properly handle widget input Widget now respects the information about survival variables stored in the domain. --- .../widgets/owassurvivaldata.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/orangecontrib/survival_analysis/widgets/owassurvivaldata.py b/orangecontrib/survival_analysis/widgets/owassurvivaldata.py index 20413b7..45cb9f4 100644 --- a/orangecontrib/survival_analysis/widgets/owassurvivaldata.py +++ b/orangecontrib/survival_analysis/widgets/owassurvivaldata.py @@ -17,6 +17,7 @@ TIME_VAR, EVENT_VAR, TIME_TO_EVENT_VAR, + get_survival_endpoints, ) @@ -36,9 +37,9 @@ class Outputs: data = Output('Data', Table) settingsHandler = DomainContextHandler() - time_var = ContextSetting(None, schema_only=True) - event_var = ContextSetting(None, schema_only=True) - auto_commit: bool = Setting(True, schema_only=True) + time_var = ContextSetting(None) + event_var = ContextSetting(None) + auto_commit: bool = Setting(True) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -82,12 +83,18 @@ def set_data(self, data: Table) -> None: self._data = data.transform(data.domain) self._data.attributes = data.attributes.copy() # look for survival data in meta and class vars only. - vars_ = [ + metas = [ var - for var in data.domain.metas + data.domain.class_vars + for var in data.domain.metas + if not isinstance(var, (TimeVariable, StringVariable)) + ] + class_vars = [ + var + for var in data.domain.class_vars if not isinstance(var, (TimeVariable, StringVariable)) ] - domain = Domain(vars_) + + domain = Domain([], metas=metas, class_vars=class_vars) self.controls.time_var.model().set_domain(domain) self.controls.event_var.model().set_domain(domain) @@ -95,9 +102,16 @@ def set_data(self, data: Table) -> None: time_var_model = self.controls.time_var.model() event_var_model = self.controls.event_var.model() - self.time_var = time_var_model[0] if len(time_var_model) else None - self.event_var = event_var_model[0] if len(event_var_model) else None + # If not found in the domain then default to the first var in model. + _time_var, _event_var = get_survival_endpoints(domain) + + if len(time_var_model): + self.time_var = time_var_model[0] if _time_var is None else _time_var + + if len(event_var_model): + self.event_var = event_var_model[0] if _event_var is None else _event_var + # Lastly, respect saved domain context if self.time_var is not None and self.event_var is not None: self.openContext(domain)