Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Decision Tree Visualization #386

Merged
merged 12 commits into from
Aug 1, 2023
10 changes: 9 additions & 1 deletion skops/io/_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,15 @@ def walk_tree(
"here: https://github.com/skops-dev/skops/issues"
)

if node_name == "constructor":
if isinstance(node, type):
yield NodeInfo(
level=level,
key=node_name,
val=type(node).__name__,
is_self_safe=False,
is_safe=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we should defer to the information in the Node instead of always assuming it to be unsafe. The Tree constructor, for instance, is safe.

Copy link
Contributor Author

@reidjohnson reidjohnson Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I agree. However, the challenge I'm running into is that it's not a valid Node object, so it does not have the necessary attributes:

AttributeError: type object 'sklearn.tree._tree.Tree' has no attribute 'is_self_safe'

If I understand correctly, it's because the constructor class is directly set as a child node here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the parent should probably pass this information down to the child when doing the walk_tree.

is_last=is_last,
)
return

if isinstance(node, dict):
Expand Down
58 changes: 56 additions & 2 deletions skops/io/tests/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import pytest
import sklearn
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import (
FunctionTransformer,
MinMaxScaler,
PolynomialFeatures,
StandardScaler,
)
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

import skops.io as sio

Expand Down Expand Up @@ -39,7 +40,7 @@ def unsafe_function(x):
("scale", MinMaxScaler()),
])),
])),
("clf", RandomForestRegressor(random_state=0)),
("clf", LogisticRegression(random_state=0, solver="lbfgs")),
]).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2])
# fmt: on
return pipeline
Expand Down Expand Up @@ -282,3 +283,56 @@ def test_long_bytes(self, capsys):
]
stdout, _ = capsys.readouterr()
assert stdout.strip() == "\n".join(expected)

@pytest.mark.parametrize("cls", [DecisionTreeClassifier, DecisionTreeRegressor])
def test_decision_tree(self, cls, capsys):
    """Check the text printed by ``sio.visualize`` for a dumped decision tree.

    The estimator class under test is fitted on a three-sample dataset,
    serialized with ``sio.dumps`` and passed to ``sio.visualize``; the
    captured stdout is then compared line by line against the expected
    tree rendering.
    """
    fitted = cls(random_state=0).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2])
    sio.visualize(sio.dumps(fitted))

    # Only the classifier contributes the class-related attribute lines;
    # the criterion shown in the output differs between the two estimators.
    class_lines = []
    if isinstance(fitted, DecisionTreeClassifier):
        criterion = "gini"
        class_lines = [
            " ├── classes_: numpy.ndarray",
            " ├── n_classes_: numpy.int64",
        ]
    elif isinstance(fitted, DecisionTreeRegressor):
        criterion = "squared_error"

    # NOTE(review): the leading whitespace inside these expected lines may
    # have been collapsed -- confirm against the real visualize output.
    expected = [
        "root: sklearn.tree._classes.{}".format(cls.__name__),
        "└── attrs: builtins.dict",
        ' ├── criterion: json-type("{}")'.format(criterion),
        ' ├── splitter: json-type("best")',
        " ├── max_depth: json-type(null)",
        " ├── min_samples_split: json-type(2)",
        " ├── min_samples_leaf: json-type(1)",
        " ├── min_weight_fraction_leaf: json-type(0.0)",
        " ├── max_features: json-type(null)",
        " ├── max_leaf_nodes: json-type(null)",
        " ├── random_state: json-type(0)",
        " ├── min_impurity_decrease: json-type(0.0)",
        " ├── class_weight: json-type(null)",
        " ├── ccp_alpha: json-type(0.0)",
        " ├── n_features_in_: json-type(2)",
        " ├── n_outputs_: json-type(1)",
        # classifier-specific attributes (empty for the regressor)
        *class_lines,
        " ├── max_features_: json-type(2)",
        " ├── tree_: sklearn.tree._tree.Tree",
        " │ ├── attrs: builtins.dict",
        " │ │ ├── max_depth: json-type(2)",
        " │ │ ├── node_count: json-type(5)",
        " │ │ ├── nodes: numpy.ndarray",
        " │ │ └── values: numpy.ndarray",
        " │ ├── args: builtins.tuple",
        " │ │ ├── content: json-type(2)",
        " │ │ ├── content: numpy.ndarray",
        " │ │ └── content: json-type(1)",
        " │ └── constructor: type [UNSAFE]",
        ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__),
    ]

    printed = capsys.readouterr().out
    assert printed.strip() == "\n".join(expected)