The SparkML OneHotEncoder conversion is implemented incorrectly. It only happens to work when there is a single input column and a single output column; as soon as you pass multiple inputCols and outputCols, the generated ONNX graph is wrong.
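For context, Spark pairs `inputCols[i]` with `outputCols[i]` and encodes each pair independently, so a converted graph would need one encoding step per pair, each producing its own named output. The sketch below only models that pairing as plain data; `OneHotNodeSpec` and `plan_onehot_nodes` are hypothetical names, not part of onnxmltools, and the per-column sizes are assumed to come from something like the fitted model's `categorySizes`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class OneHotNodeSpec:
    """Hypothetical description of one encoding step in a converted graph."""
    input_name: str
    output_name: str
    size: int  # length of the one-hot vector produced for this column


def plan_onehot_nodes(input_cols: List[str], output_cols: List[str],
                      category_sizes: List[int], drop_last: bool) -> List[OneHotNodeSpec]:
    # Every input column needs a matching output column and category size.
    if not (len(input_cols) == len(output_cols) == len(category_sizes)):
        raise ValueError("inputCols, outputCols and categorySizes must have the same length")
    plan = []
    for in_col, out_col, n in zip(input_cols, output_cols, category_sizes):
        # With dropLast=True the last category is dropped, so the vector is one slot shorter.
        plan.append(OneHotNodeSpec(in_col, out_col, n - 1 if drop_last else n))
    return plan
```

For the repro below, `plan_onehot_nodes(["col1", "col2"], ["col1_enc", "col2_enc"], [2, 5], drop_last=False)` yields two steps, one ending in `col1_enc` (size 2) and one ending in `col2_enc` (size 5); the bug is that the converted graph does not contain a producer for every such output.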
Repro:
import numpy as np
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder
from onnxconverter_common.data_types import Int32TensorType
from onnxmltools.convert import convert_sparkml
from onnxruntime import InferenceSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(0, 2), (1, 4)],
    ["col1", "col2"],
)
encoder = OneHotEncoder(inputCols=["col1", "col2"], outputCols=["col1_enc", "col2_enc"], dropLast=False).fit(df)
initial_types = [
    ("col1", Int32TensorType([None, 1])),
    ("col2", Int32TensorType([None, 1])),
]
onnx_model = convert_sparkml(encoder, "OneHotEncoder", initial_types)

# Write the model out for inspection
with open("ohe_quicktest.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

# Run with ONNX Runtime; inputs must match the declared [None, 1] int32 shapes
input_data = {
    "col1": np.array([[0], [1]], dtype=np.int32),
    "col2": np.array([[2], [4]], dtype=np.int32),
}
session = InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
output = session.run(None, input_data)
Expected result:
The model correctly encodes the input features.
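To make "correctly encodes" concrete, here is a small pure-Python reference for the repro inputs. It assumes each column's category size is its largest training value plus one (Spark's behaviour when the column carries no ML attribute metadata), giving 2 for `col1` and 5 for `col2`. It returns dense float lists, whereas Spark produces sparse vectors and the ONNX outputs may be laid out differently, so a comparison may need reshaping. `one_hot_reference` is an illustrative helper, not a library function.

```python
from typing import Dict, List, Sequence


def one_hot_reference(values: Sequence[int], category_size: int, drop_last: bool) -> List[List[float]]:
    """Densely one-hot encode one column, following Spark's OneHotEncoder conventions."""
    # With dropLast=True the vector has category_size - 1 slots and the last category maps to all zeros.
    width = category_size - 1 if drop_last else category_size
    rows = []
    for v in values:
        if not 0 <= v < category_size:
            raise ValueError(f"value {v} is outside [0, {category_size})")
        row = [0.0] * width
        if v < width:
            row[v] = 1.0
        rows.append(row)
    return rows


# Expected outputs for the repro inputs, assuming category sizes of 2 (col1) and 5 (col2).
expected: Dict[str, List[List[float]]] = {
    "col1_enc": one_hot_reference([0, 1], 2, drop_last=False),
    "col2_enc": one_hot_reference([2, 4], 5, drop_last=False),
}
```

Each input column gets its own output of its own width, which is exactly what the converted graph fails to provide for `col2_enc`.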
Actual result:
InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Failed to load model with error: /onnxruntime_src/onnxruntime/core/graph/graph.cc:1368 void onnxruntime::Graph::InitializeStateFromModelFileGraphProto() This is an invalid model. Graph output (col2_enc) does not exist in the graph.
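The error means a name listed in `graph.output` is not defined anywhere in the graph. A quick way to spot this without loading the model in ONNX Runtime is to compare declared outputs against node outputs, graph inputs and initializers. This is a minimal sketch that approximates only this one load-time check (`onnx.checker.check_model` is the full checker); `missing_graph_outputs` is a hypothetical helper, and the `broken` object is an offline stand-in shaped like the error describes, not the actual converted graph.

```python
from types import SimpleNamespace
from typing import List


def missing_graph_outputs(model) -> List[str]:
    """List declared graph outputs that nothing in the graph defines.

    Intended for an onnx.ModelProto (uses graph.node, graph.input, graph.initializer, graph.output).
    """
    graph = model.graph
    defined = {name for node in graph.node for name in node.output}
    defined |= {inp.name for inp in graph.input}
    defined |= {init.name for init in graph.initializer}
    return [out.name for out in graph.output if out.name not in defined]


# Offline stand-in: both encoded columns are declared as outputs,
# but only col1_enc is produced by a node (illustrative only).
broken = SimpleNamespace(graph=SimpleNamespace(
    node=[SimpleNamespace(output=["col1_enc"])],
    input=[SimpleNamespace(name="col1"), SimpleNamespace(name="col2")],
    initializer=[],
    output=[SimpleNamespace(name="col1_enc"), SimpleNamespace(name="col2_enc")],
))
print(missing_graph_outputs(broken))  # ['col2_enc']
```

Given the error above, calling `missing_graph_outputs(onnx_model)` on the model from the repro would be expected to report `col2_enc`.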
The generated ONNX graph looks like this, which is obviously wrong: