Skip to content

Commit

Permalink
Merge pull request #2430 from jerneju/sparse-tree
Browse files Browse the repository at this point in the history
[FIX] Tree: Sparse Support
  • Loading branch information
janezd authored Jul 21, 2017
2 parents 11de443 + 2545c92 commit 6865a45
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
15 changes: 10 additions & 5 deletions Orange/classification/tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tree inducers: SKL and Orange's own inducer"""
import numpy as np
import scipy.sparse as sp
import sklearn.tree as skl_tree

from Orange.base import TreeModel as TreeModelInterface
Expand Down Expand Up @@ -76,6 +77,7 @@ def _select_attr(self, data):
"""
# Prevent false warnings by pylint
attr = attr_no = None
col_x = None
REJECT_ATTRIBUTE = 0, None, None, 0

def _score_disc():
Expand All @@ -89,8 +91,7 @@ def _score_disc():
if n_values < 2:
return REJECT_ATTRIBUTE

x = data.X[:, attr_no].flatten()
cont = _tree_scorers.contingency(x, len(data.domain.attributes[attr_no].values),
cont = _tree_scorers.contingency(col_x, len(data.domain.attributes[attr_no].values),
data.Y, len(data.domain.class_var.values))
attr_distr = np.sum(cont, axis=0)
null_nodes = attr_distr <= self.min_samples_leaf
Expand All @@ -111,7 +112,7 @@ def _score_disc():
cont_entr = np.sum(cont * np.log(cont))
score = (class_entr - attr_entr + cont_entr) / n / np.log(2)
score *= n / len(data) # punishment for missing values
branches = x
branches = col_x
branches[np.isnan(branches)] = -1
if score == 0:
return REJECT_ATTRIBUTE
Expand All @@ -135,13 +136,12 @@ def _score_disc_bin():
return REJECT_ATTRIBUTE
best_score *= 1 - np.sum(cont.unknowns) / len(data)
mapping, branches = MappedDiscreteNode.branches_from_mapping(
data.X[:, attr_no], best_mapping, n_values)
col_x, best_mapping, n_values)
node = MappedDiscreteNode(attr, attr_no, mapping, None)
return best_score, node, branches, 2

def _score_cont():
"""Scoring for numeric attributes"""
col_x = data.X[:, attr_no]
nans = np.sum(np.isnan(col_x))
non_nans = len(col_x) - nans
arginds = np.argsort(col_x)[:non_nans]
Expand All @@ -159,12 +159,17 @@ def _score_cont():

#######################################
# The real _select_attr starts here
is_sparse = sp.issparse(data.X)
domain = data.domain
class_var = domain.class_var
best_score, *best_res = REJECT_ATTRIBUTE
best_res = [Node(None, None, None)] + best_res[1:]
disc_scorer = _score_disc_bin if self.binarize else _score_disc
for attr_no, attr in enumerate(domain.attributes):
col_x = data.X[:, attr_no]
if is_sparse:
col_x = col_x.toarray()
col_x = col_x.flatten()
sc, *res = disc_scorer() if attr.is_discrete else _score_cont()
if res[0] is not None and sc > best_score:
best_score, best_res = sc, res
Expand Down
20 changes: 20 additions & 0 deletions Orange/widgets/model/tests/test_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# pylint: disable=protected-access
import numpy as np
import scipy.sparse as sp

from Orange.base import Model
from Orange.data import Table
from Orange.widgets.model.owtree import OWTreeLearner
from Orange.widgets.tests.base import (
DefaultParameterMapping,
Expand Down Expand Up @@ -35,3 +40,18 @@ def test_parameters_unchecked(self):
self.parameters = [DefaultParameterMapping(par.name, val)
for par, val in zip(self.parameters, (None, 2, 1))]
self.test_parameters()

def test_sparse_data(self):
"""
Tree can handle sparse data.
GH-2430
"""
table1 = Table("iris")
self.send_signal("Data", table1)
model_dense = self.get_output("Model")
table2 = Table("iris")
table2.X = sp.csr_matrix(table2.X)
model_sparse = self.get_output("Model")
self.assertTrue(np.array_equal(model_dense._code, model_sparse._code))
self.assertTrue(np.array_equal(model_dense._thresholds, model_sparse._thresholds))
self.assertTrue(np.array_equal(model_dense._values, model_sparse._values))

0 comments on commit 6865a45

Please sign in to comment.