Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/onnx/onnxmltools into o116
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Apr 2, 2024
2 parents 92d6df3 + b1e1068 commit 1b3dcc3
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions onnxmltools/convert/xgboost/shape_calculators/Classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def calculate_xgboost_classifier_output_shapes(operator):
objective = params["objective"]
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
num_class = params.get("num_class", None)

if num_class is not None:

if objective == "binary:logistic":
ncl = 2
elif num_class is not None:
ncl = num_class
n_estimators = ntrees // ncl
elif objective == "binary:logistic":
ncl = 2
else:
ncl = ntrees // n_estimators
if objective == "reg:logistic" and ncl == 1:
Expand All @@ -46,7 +46,7 @@ def calculate_xgboost_classifier_output_shapes(operator):
operator.outputs[0].type = Int64TensorType(shape=[N])
else:
operator.outputs[0].type = StringTensorType(shape=[N])
operator.outputs[1].type = operator.outputs[1].type = FloatTensorType([N, ncl])
operator.outputs[1].type = FloatTensorType([N, ncl])


register_shape_calculator("XGBClassifier", calculate_xgboost_classifier_output_shapes)
Expand Down

0 comments on commit 1b3dcc3

Please sign in to comment.