Skip to content

Commit

Permalink
add node to container instead of changing output node of container (#478
Browse files Browse the repository at this point in the history
)

* add node to container instead of changing output node of container
* remove unused imports

Signed-off-by: Jan-Benedikt Jagusch <jan.jagusch@gmail.com>
  • Loading branch information
janjagusch authored Jul 14, 2021
1 parent 93e95db commit f295db6
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import copy
import numbers
import numpy as np
import onnx
from collections import Counter
from ...common._apply_operation import (
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
from ...common._registration import register_converter
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
from ....proto import onnx_proto
from onnxconverter_common.container import ModelComponentContainer


def _translate_split_criterion(criterion):
Expand Down Expand Up @@ -399,26 +397,18 @@ def convert_lightgbm(scope, operator, container):

apply_div(scope, [output_name, denominator_name],
operator.output_full_names, container, broadcast=1)
elif post_transform:
container.add_node(
post_transform,
output_name,
operator.output_full_names,
name=scope.get_unique_operator_name(
post_transform),
)
else:
container.add_node('Identity', output_name,
operator.output_full_names,
name=scope.get_unique_operator_name('Identity'))
if post_transform:
_add_post_transform_node(container, post_transform)


def _add_post_transform_node(container: ModelComponentContainer, op_type: str):
"""
Add a post transform node to a ModelComponentContainer.
Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp').
"""
assert len(container.outputs) == 1, "Adding a post transform node is only possible for models with 1 output."
original_output_name = container.outputs[0].name
new_output_name = f"{op_type.lower()}_{original_output_name}"
post_transform_node = onnx.helper.make_node(op_type, inputs=[original_output_name], outputs=[new_output_name])
container.nodes.append(post_transform_node)
container.outputs[0].name = new_output_name


def modify_tree_for_rule_in_set(gbm, use_float=False):
Expand Down

0 comments on commit f295db6

Please sign in to comment.