
fix: ONNX conversion for Spark OneHotEncoder model #552

Merged (5 commits) on May 27, 2022
126 changes: 64 additions & 62 deletions onnxmltools/convert/sparkml/operator_converters/onehot_encoder.py
@@ -3,74 +3,76 @@
from ...common._registration import register_converter, register_shape_calculator
from ....proto import onnx_proto
from ...common.data_types import FloatTensorType


def convert_sparkml_one_hot_encoder(scope, operator, container):
    op = operator.raw_operator
    C = operator.inputs[0].type.shape[1]

    # encoded_slot_sizes[i] is the number of output coordinates associated with the ith categorical feature

    # Variable names produced by one-hot encoders. Each of them is the encoding result of a categorical feature.
    final_variable_names = []
    final_variable_lengths = []
    for i in range(0, len(op.categorySizes)):
        catSize = op.categorySizes[i]
        cats = range(0, catSize)
        # Put the index of the feature we want to encode into a tensor
        index_variable_name = scope.get_unique_variable_name('target_index')
        container.add_initializer(index_variable_name, onnx_proto.TensorProto.INT64, [1], [i])

        # Extract the categorical feature from the original input tensor
        extracted_feature_name = scope.get_unique_variable_name('extracted_feature_at_' + str(i))
        extractor_type = 'ArrayFeatureExtractor'
        extractor_attrs = {'name': scope.get_unique_operator_name(extractor_type)}
        container.add_node(extractor_type, [operator.inputs[0].full_name, index_variable_name],
                           extracted_feature_name, op_domain='ai.onnx.ml', **extractor_attrs)

        # Encode the extracted categorical feature as a one-hot vector
        encoder_type = 'OneHotEncoder'
        encoder_attrs = {'name': scope.get_unique_operator_name(encoder_type), 'cats_int64s': cats}
        encoded_feature_name = scope.get_unique_variable_name('encoded_feature_at_' + str(i))
        container.add_node(encoder_type, extracted_feature_name, encoded_feature_name, op_domain='ai.onnx.ml',
                           **encoder_attrs)

        # Collect the features produced by the one-hot encoders
        final_variable_names.append(encoded_feature_name)
        # For each categorical value, the length of its encoded result is the number of all possible categorical values
        final_variable_lengths.append(catSize)

    # Combine the encoded features and the passed-through features
    collector_type = 'FeatureVectorizer'
    collector_attrs = {'name': scope.get_unique_operator_name(collector_type)}
    collector_attrs['inputdimensions'] = final_variable_lengths
    container.add_node(collector_type, final_variable_names, operator.outputs[0].full_name,
                       op_domain='ai.onnx.ml', **collector_attrs)
from ...common._topology import Operator, Scope
from pyspark.ml.feature import OneHotEncoderModel
from typing import List

def _get_categories(operator: Operator) -> List[List[int]]:
    op: OneHotEncoderModel = operator.raw_operator
    categorySizes: List[int] = op.categorySizes

    # This is necessary to match SparkML's OneHotEncoder behavior.
    # If handleInvalid is set to "keep", an extra "category" indicating invalid values is added as the last category:
    #   - if dropLast is set to false, an extra bit is added for invalid input, which does not match ONNX's OneHotEncoder behavior;
    #   - if dropLast is set to true, invalid values are encoded as an all-zeros vector, which matches ONNX's OneHotEncoder behavior when "zeros" is set to 1.
    # If handleInvalid is set to "error":
    #   - if dropLast is set to false, nothing is added or dropped, which matches ONNX's OneHotEncoder behavior when "zeros" is set to 0;
    #   - if dropLast is set to true, the last bit is dropped, which does not match ONNX's OneHotEncoder behavior.
    if (op.getHandleInvalid() == "keep" and op.getDropLast() == False) or (
        op.getHandleInvalid() == "error" and op.getDropLast() == True
    ):
        raise RuntimeError(
            f"The 'handleInvalid' and 'dropLast' parameters must be set to ('keep', True) or ('error', False), but got ('{op.getHandleInvalid()}', {op.getDropLast()}) instead."
        )

    return [list(range(0, size)) for size in categorySizes]

def convert_sparkml_one_hot_encoder(scope: Scope, operator: Operator, container):
    categories = _get_categories(operator)
    N = operator.inputs[0].type.shape[0] or -1

    zeros = 1 if operator.raw_operator.getHandleInvalid() == "keep" else 0

    for i in range(0, len(categories)):
        encoder_type = "OneHotEncoder"

        # zeros=0 matches the "error" handleInvalid behavior of SparkML's OneHotEncoder;
        # zeros=1 matches the "keep" behavior.
        encoder_attrs = {
            "name": scope.get_unique_operator_name(encoder_type),
            "cats_int64s": categories[i],
            "zeros": zeros,
        }

        encoded_feature_name = scope.get_unique_variable_name("encoded_feature_at_" + str(i))
        container.add_node(
            op_type=encoder_type,
            inputs=[operator.inputs[i].full_name],
            outputs=[encoded_feature_name],
            op_domain="ai.onnx.ml",
            op_version=1,
            **encoder_attrs,
        )

        shape_variable_name = scope.get_unique_variable_name("shape_at_" + str(i))
        container.add_initializer(shape_variable_name, onnx_proto.TensorProto.INT64, [2], [N, len(categories[i])])

        container.add_node(
            op_type="Reshape",
            inputs=[encoded_feature_name, shape_variable_name],
            outputs=[operator.outputs[i].full_name],
            op_domain="ai.onnx",
            op_version=13,
        )

Collaborator: Why switch to an output per category and not concatenate them into one single output?

Contributor Author: This is to match the behavior of the Spark OneHotEncoderModel. In Spark, if the OneHotEncoderModel has multiple outputCols (one per input column), the model does not concatenate the outputs into a single vector. Instead, you use a VectorAssembler to concatenate them, and the VectorAssembler is translated into a Concat ONNX operator.
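A minimal PySpark sketch of the pattern the author describes (column names are hypothetical, not taken from the PR): the encoder emits one vector column per input column, and the downstream VectorAssembler, not the encoder, performs the concatenation.

from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, VectorAssembler

# One vector output column per input column; Spark does not concatenate them.
encoder = OneHotEncoder(inputCols=["color_idx", "size_idx"],
                        outputCols=["color_vec", "size_vec"],
                        handleInvalid="keep", dropLast=True)
# The VectorAssembler concatenates the per-column vectors and is the stage
# that gets converted into a single Concat ONNX operator.
assembler = VectorAssembler(inputCols=["color_vec", "size_vec"], outputCol="features")
pipeline = Pipeline(stages=[encoder, assembler])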


register_converter('pyspark.ml.feature.OneHotEncoderModel', convert_sparkml_one_hot_encoder)


def calculate_sparkml_one_hot_encoder_output_shapes(operator):
    '''
    Allowed input/output patterns are
        1. [N, C] ---> [N, C']
        2. [N, 'None'] ---> [N, 'None']
    '''
    op = operator.raw_operator

    # encoded_slot_sizes[i] is the number of output coordinates associated with the ith categorical feature.
    encoded_slot_sizes = op.categorySizes

    N = operator.inputs[0].type.shape[0]
    # Calculate the output feature length by replacing the count of categorical
    # features with their encoded widths
    if operator.inputs[0].type.shape[1] != 'None':
        C = operator.inputs[0].type.shape[1] - 1 + sum(encoded_slot_sizes)
    else:
        C = 'None'

    operator.outputs[0].type = FloatTensorType([N, C])


def calculate_sparkml_one_hot_encoder_output_shapes(operator: Operator):
    categories = _get_categories(operator)
    N = operator.inputs[0].type.shape[0]
    for i, output in enumerate(operator.outputs):
        output.type = FloatTensorType([N, len(categories[i])])


register_shape_calculator('pyspark.ml.feature.OneHotEncoderModel', calculate_sparkml_one_hot_encoder_output_shapes)
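To make the supported parameter matrix concrete, here is a hedged sketch (hypothetical column names and values, not part of the diff) for a feature whose seen categories are 0, 1, and 2:

from pyspark.ml.feature import OneHotEncoder

# ("error", dropLast=False): each seen value gets its own bit and invalid
# input raises; corresponds to ONNX OneHotEncoder with zeros=0.
#   0.0 -> [1, 0, 0]    2.0 -> [0, 0, 1]    3.0 -> error
strict = OneHotEncoder(inputCols=["idx"], outputCols=["vec"],
                       handleInvalid="error", dropLast=False)

# ("keep", dropLast=True): the extra "invalid" slot appended by "keep" is
# dropped again, so invalid input encodes as all zeros; corresponds to
# ONNX OneHotEncoder with zeros=1.
#   0.0 -> [1, 0, 0]    2.0 -> [0, 0, 1]    3.0 -> [0, 0, 0]
lenient = OneHotEncoder(inputCols=["idx"], outputCols=["vec"],
                        handleInvalid="keep", dropLast=True)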
6 changes: 5 additions & 1 deletion onnxmltools/convert/sparkml/ops_input_output.py
@@ -150,8 +150,12 @@ def build_io_name_map():
            lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")]
        ),
        "pyspark.ml.feature.OneHotEncoderModel": (
            lambda model: model.getOrDefault("inputCols"),
            lambda model: model.getOrDefault("inputCols")
            if model.isSet("inputCols")
            else [model.getOrDefault("inputCol")],
            lambda model: model.getOrDefault("outputCols")
            if model.isSet("outputCols")
            else [model.getOrDefault("outputCol")],
        ),
        "pyspark.ml.feature.StringIndexerModel": (
            lambda model: [model.getOrDefault("inputCol")],
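The isSet fallback above exists because the Spark encoder can be parameterized through either the multi-column or the single-column API. A small sketch (hypothetical names; it probes the estimator rather than the fitted model for brevity) of what each branch sees:

from pyspark.ml.feature import OneHotEncoder

# Multi-column API: "inputCols" is set, so the first branch applies.
multi = OneHotEncoder(inputCols=["a", "b"], outputCols=["aVec", "bVec"])
assert multi.isSet("inputCols")

# Single-column API: only "inputCol"/"outputCol" are set, so the map falls
# back to [model.getOrDefault("inputCol")].
single = OneHotEncoder(inputCol="a", outputCol="aVec")
assert single.isSet("inputCol") and not single.isSet("inputCols")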
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -19,4 +19,4 @@ pytest-spark
scikit-learn
scipy
wheel
xgboost
xgboost==1.5.2
Collaborator: Why do you need this?

Contributor Author (@memoryz, May 25, 2022): The latest xgboost release breaks some of the xgboost unit tests, which has nothing to do with my changes, so I'm pinning it to the older version that works.

76 changes: 61 additions & 15 deletions tests/sparkml/test_onehot_encoder.py
@@ -19,30 +19,76 @@ class TestSparkmlOneHotEncoder(SparkMlTestCase):

    @unittest.skipIf(sys.version_info < (3, 8),
                     reason="pickle fails on python 3.7")
    def test_model_onehot_encoder(self):
        encoder = OneHotEncoder(inputCols=['index'], outputCols=['indexVec'])
        data = self.spark.createDataFrame(
            [(0.0,), (1.0,), (2.0,), (2.0,), (0.0,), (2.0,)], ['index'])
    def test_model_onehot_encoder_1(self):
        """
        Test ONNX conversion of the Spark OneHotEncoder with handleInvalid="error" and dropLast=False.
        """
        encoder = OneHotEncoder(inputCols=['index1', 'index2'], outputCols=['index1Vec', 'index2Vec'],
                                handleInvalid="error", dropLast=False)
        data = self.spark.createDataFrame(
            [(0.0, 5.0,), (1.0, 4.0,), (2.0, 3.0,), (2.0, 2.0,), (0.0, 1.0,), (2.0, 0.0,)],
            ['index1', 'index2'])
        model = encoder.fit(data)
        model_onnx = convert_sparkml(
            model, 'Sparkml OneHotEncoder', [('index', FloatTensorType([None, 1]))],
            model, 'Sparkml OneHotEncoder',
            [('index1', FloatTensorType([None, 1])), ('index2', FloatTensorType([-1, 1]))],
            target_opset=TARGET_OPSET)
        self.assertTrue(model_onnx is not None)
        self.assertTrue(model_onnx.graph.node is not None)
        # run the model
        predicted = model.transform(data)
        data_np = data.select("index").toPandas().values.astype(numpy.float32)
        predicted_np = predicted.select("indexVec").toPandas().indexVec.apply(
            lambda x: x.toArray().tolist()).values
        expected = numpy.asarray(
            [x + [0] if numpy.amax(x) == 1 else x + [1] for x in predicted_np])

        paths = save_data_models(data_np, expected, model, model_onnx,
                                 basename="SparkmlOneHotEncoder")
        data_np = {
            "index1": data.select("index1").toPandas().values.astype(numpy.float32),
            "index2": data.select("index2").toPandas().values.astype(numpy.float32),
        }

        predicted_np_1 = predicted.select("index1Vec").toPandas().index1Vec.apply(lambda x: x.toArray()).values
        predicted_np_2 = predicted.select("index2Vec").toPandas().index2Vec.apply(lambda x: x.toArray()).values
        expected = {"index1Vec": numpy.asarray(predicted_np_1.tolist()),
                    "index2Vec": numpy.asarray(predicted_np_2.tolist())}

        paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlOneHotEncoder")
        onnx_model_path = paths[-1]

        output_names = ['index1Vec', 'index2Vec']
        output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
        actual_output = dict(zip(output_names, output))

        compare_results(expected["index1Vec"], actual_output["index1Vec"], decimal=5)
        compare_results(expected["index2Vec"], actual_output["index2Vec"], decimal=5)


    @unittest.skipIf(sys.version_info < (3, 8),
                     reason="pickle fails on python 3.7")
    def test_model_onehot_encoder_2(self):
        """
        Test ONNX conversion of the Spark OneHotEncoder with handleInvalid="keep" and dropLast=True.
        """
        encoder = OneHotEncoder(inputCols=['index1', 'index2'], outputCols=['index1Vec', 'index2Vec'],
                                handleInvalid="keep", dropLast=True)
        data = self.spark.createDataFrame(
            [(0.0, 5.0,), (1.0, 4.0,), (2.0, 3.0,), (2.0, 2.0,), (0.0, 1.0,), (2.0, 0.0,)],
            ['index1', 'index2'])
        test = self.spark.createDataFrame([(3.0, 7.0,)], ['index1', 'index2'])  # invalid data

        model = encoder.fit(data)
        model_onnx = convert_sparkml(
            model, 'Sparkml OneHotEncoder',
            [('index1', FloatTensorType([None, 1])), ('index2', FloatTensorType([-1, 1]))],
            target_opset=TARGET_OPSET)
        self.assertTrue(model_onnx is not None)
        self.assertTrue(model_onnx.graph.node is not None)
        # run the model
        predicted = model.transform(test)
        data_np = {
            "index1": test.select("index1").toPandas().values.astype(numpy.float32),
            "index2": test.select("index2").toPandas().values.astype(numpy.float32),
        }

        predicted_np_1 = predicted.select("index1Vec").toPandas().index1Vec.apply(lambda x: x.toArray()).values
        predicted_np_2 = predicted.select("index2Vec").toPandas().index2Vec.apply(lambda x: x.toArray()).values
        expected = {"index1Vec": numpy.asarray(predicted_np_1.tolist()),
                    "index2Vec": numpy.asarray(predicted_np_2.tolist())}

        paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlOneHotEncoder")
        onnx_model_path = paths[-1]
        output, output_shapes = run_onnx_model(['indexVec'], data_np, onnx_model_path)
        compare_results(expected, output, decimal=5)

        output_names = ['index1Vec', 'index2Vec']
        output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
        actual_output = dict(zip(output_names, output))

        compare_results(expected["index1Vec"], actual_output["index1Vec"], decimal=5)
        compare_results(expected["index2Vec"], actual_output["index2Vec"], decimal=5)

if __name__ == "__main__":
    unittest.main()
13 changes: 3 additions & 10 deletions tests/sparkml/test_pipeline.py
@@ -125,7 +125,7 @@ def test_model_pipeline_2_stage(self):
        stages = []
        for col in cols:
            stages.append(StringIndexer(inputCol=col, outputCol=col+'_index', handleInvalid='skip'))
            stages.append(OneHotEncoder(inputCols=[col+'_index'], outputCols=[col+'_vec']))
            stages.append(OneHotEncoder(inputCols=[col+'_index'], outputCols=[col+'_vec'], dropLast=False))

        pipeline = Pipeline(stages=stages)

@@ -149,7 +149,8 @@ def test_model_pipeline_2_stage(self):
            predicted.toPandas().education_vec.apply(lambda x: pandas.Series(x.toArray())).values,
            predicted.toPandas().marital_status_vec.apply(lambda x: pandas.Series(x.toArray())).values
        ]
        expected = [numpy.asarray([expand_one_hot_vec(x) for x in row]) for row in predicted_np]

        expected = [numpy.asarray(row) for row in predicted_np]
        paths = save_data_models(data_np, expected, model, model_onnx,
                                 basename="SparkmlPipeline_2Stage")
        onnx_model_path = paths[-1]
@@ -158,13 +159,5 @@
        compare_results(expected, output, decimal=5)


def expand_one_hot_vec(v):
    import numpy
    if numpy.amax(v) == 1:
        return v.tolist() + [0]
    else:
        return v.tolist() + [1]


if __name__ == "__main__":
    unittest.main()