After converting a pyspark.ml RandomForestClassifier to ONNX and running inference with InferenceSession from onnxruntime, all the output probabilities (and thus the predictions) were the same.
While comparing the output of the ONNX model with that of the sparkml model, I noticed that the ONNX model had assigned BRANCH_EQ to every node that should be BRANCH_LEQ.
Tracing this back to onnxmltools/convert/sparkml/operator_converters/tree_helper.py, function rewrite_ids_and_process, it appears to happen because the values for these nodes are contained in an array (L282).
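For anyone who wants to check their own converted model, here is a minimal sketch that tallies the branch modes stored in the tree-ensemble nodes. It assumes the model contains a TreeEnsembleClassifier (or TreeEnsembleRegressor) node with the standard nodes_modes attribute; count_node_modes is a hypothetical helper name, not part of onnx or onnxmltools.

```python
from collections import Counter

# Operator types from the ai.onnx.ml domain that carry a nodes_modes attribute.
TREE_OPS = {"TreeEnsembleClassifier", "TreeEnsembleRegressor"}


def count_node_modes(model):
    """Count branch modes (BRANCH_LEQ, BRANCH_EQ, LEAF, ...) in a loaded ONNX model.

    Hypothetical helper: it only reads model.graph.node, node.op_type,
    node.attribute, attr.name and attr.strings.
    """
    counts = Counter()
    for node in model.graph.node:
        if node.op_type not in TREE_OPS:
            continue
        for attr in node.attribute:
            if attr.name == "nodes_modes":
                # String attributes are stored as bytes in the ONNX protobuf.
                counts.update(s.decode("utf-8") for s in attr.strings)
    return counts


# Usage, assuming the onnx package is installed and "model.onnx" is a
# placeholder path to the converted model:
#   import onnx
#   print(count_node_modes(onnx.load("model.onnx")))
```

With the behaviour described above, I would expect the tally to contain BRANCH_EQ and LEAF entries but no BRANCH_LEQ.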
Going further back, the cause seems to be the function sparkml_tree_dataset_to_sklearn in onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py. The issue appears to be at L31:
threshold.append(item["leftCategoriesOrThreshold"])
which appends the whole array instead of a scalar threshold.
Is there any reason not to use:
threshold.append(item["leftCategoriesOrThreshold"][0] if len(item["leftCategoriesOrThreshold"]) >= 1 else -1.0)
just as is done at L37 for the tuple case:
threshold.append(item[1][0] if len(item[1]) >= 1 else -1.0)
?
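A minimal standalone sketch of the proposed change, outside of onnxmltools. The helper scalar_threshold and the sample rows are hypothetical and for illustration only; the sketch assumes a continuous split, where the array holds a single threshold, and it reuses the -1.0 fallback from the tuple case.

```python
def scalar_threshold(value, default=-1.0):
    """Reduce a leftCategoriesOrThreshold entry to a single float.

    Hypothetical helper: sequences yield their first element, or `default`
    when empty; plain numbers pass through unchanged.
    """
    if isinstance(value, (list, tuple)):
        return float(value[0]) if len(value) >= 1 else default
    return float(value)


# Made-up rows shaped like the dict case handled at L31.
rows = [
    {"leftCategoriesOrThreshold": [0.5]},  # internal node with one threshold
    {"leftCategoriesOrThreshold": []},     # empty array, falls back to -1.0
]

threshold = []
for item in rows:
    threshold.append(scalar_threshold(item["leftCategoriesOrThreshold"]))

print(threshold)  # [0.5, -1.0]
```

With scalars in the threshold list, the values that reach rewrite_ids_and_process would no longer be arrays, which is what seems to trigger the BRANCH_EQ assignment.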
Thank you :)
Best regards,
Tiago
tiago-rib-goncalves changed the title from "RandomForestClassifier to ONNX: conversion succeeded but retrieves same probability on inference for all observations" to "RandomForestClassifier: ONNX output converts all BRANCH_LEQ to BRANCH_EQ" on Apr 9, 2024