SparkML OneHotEncoder conversion is wrong #551

Closed
memoryz opened this issue May 21, 2022 · 1 comment

memoryz commented May 21, 2022

The SparkML OneHotEncoder conversion is implemented incorrectly. It only happens to work when there is exactly one input column and one output column. As soon as you use multiple inputCols and outputCols, the generated ONNX graph is wrong.

Repro:

import numpy as np
from pyspark.ml.feature import OneHotEncoder
from pyspark.sql import SparkSession
from onnxconverter_common.data_types import Int32TensorType
from onnxmltools.convert import convert_sparkml
from onnxruntime import InferenceSession

# Reuse the active SparkSession (e.g. in a notebook) or create one
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
  [(0, 2), (1, 4)],
  ["col1", "col2"]
)

# Fit an encoder with two input columns and two output columns
encoder = OneHotEncoder(inputCols=["col1", "col2"], outputCols=["col1_enc", "col2_enc"], dropLast=False).fit(df)

initial_types = [
  ("col1", Int32TensorType([None, 1])),
  ("col2", Int32TensorType([None, 1])),
]

onnx_model = convert_sparkml(encoder, 'OneHotEncoder', initial_types)

# Writing the model out for inspection
with open("ohe_quicktest.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

# Run with onnx runtime: one int32 column vector per input, matching the [None, 1] shapes declared above
input_data = {
  "col1": np.array([[0], [1]], dtype=np.int32),
  "col2": np.array([[2], [4]], dtype=np.int32),
}
session = InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
output = session.run(None, input_data)

Expected result:
The model correctly encodes the input features.
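
To make "correctly encodes" concrete, here is a minimal pure-Python sketch of the output I would expect for the repro inputs. It assumes each column's category size is its largest fitted value plus one (2 for col1, 5 for col2) and dropLast=False, so nothing is trimmed. The helper name `one_hot` is hypothetical and is not part of Spark or onnxmltools.

```python
def one_hot(values, size):
    """Return one dense one-hot row of length `size` per integer in `values`."""
    rows = []
    for v in values:
        if not 0 <= v < size:
            raise ValueError(f"index {v} is outside [0, {size})")
        rows.append([1.0 if i == v else 0.0 for i in range(size)])
    return rows


# Assumption: category size per column = max fitted value + 1.
category_sizes = {"col1": 2, "col2": 5}
inputs = {"col1": [0, 1], "col2": [2, 4]}

# One encoded output per input column, named like the outputCols in the repro.
expected = {
    f"{name}_enc": one_hot(values, category_sizes[name])
    for name, values in inputs.items()
}

for name, rows in expected.items():
    print(name, rows)
```

The point of the sketch is only that each input column should yield its own encoded output, with a width that depends on that column's categories.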

Actual result:
InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Failed to load model with error: /onnxruntime_src/onnxruntime/core/graph/graph.cc:1368 void onnxruntime::Graph::InitializeStateFromModelFileGraphProto() This is an invalid model. Graph output (col2_enc) does not exist in the graph.

The generated ONNX graph looks like this, which is obviously wrong:
[image: the generated ONNX graph]
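
Since the error complains about a graph output that does not exist, a quick check like the following sketch can list declared outputs that no node, graph input, or initializer produces. The function name `missing_graph_outputs` is hypothetical, and the `broken` object is a hand-built stand-in shaped to trigger the same class of error; it is not the actual graph from this repro. The function only relies on the `node`, `input`, `output`, and `initializer` fields that an ONNX `GraphProto` exposes.

```python
from types import SimpleNamespace


def missing_graph_outputs(graph):
    """Return names of declared graph outputs that nothing in the graph produces."""
    produced = {name for node in graph.node for name in node.output}
    produced.update(value.name for value in graph.input)
    produced.update(init.name for init in graph.initializer)
    return [out.name for out in graph.output if out.name not in produced]


# Hypothetical stand-in: two declared outputs, but only one is produced by a node.
broken = SimpleNamespace(
    node=[SimpleNamespace(output=["col1_enc"])],
    input=[SimpleNamespace(name="col1"), SimpleNamespace(name="col2")],
    initializer=[],
    output=[SimpleNamespace(name="col1_enc"), SimpleNamespace(name="col2_enc")],
)

print(missing_graph_outputs(broken))

# With the file written by the repro (requires the onnx package):
#   import onnx
#   print(missing_graph_outputs(onnx.load("ohe_quicktest.onnx").graph))
```

An empty list from this check does not prove the conversion is right, but a non-empty one pinpoints which outputCols were dropped.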


memoryz commented May 27, 2022

PR merged.

memoryz closed this as completed May 27, 2022