Skip to content

Commit

Permalink
fix: ONNX conversion for Spark OneHotEncoder model
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz committed May 21, 2022
1 parent f87e306 commit f8587bc
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 76 deletions.
3 changes: 2 additions & 1 deletion onnxmltools/convert/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from onnxconverter_common.container import (
RawModelContainer,
CommonSklearnModelContainer
CommonSklearnModelContainer,
ModelComponentContainer,
)


Expand Down
121 changes: 62 additions & 59 deletions onnxmltools/convert/sparkml/operator_converters/onehot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,74 +3,77 @@
from ...common._registration import register_converter, register_shape_calculator
from ....proto import onnx_proto
from ...common.data_types import FloatTensorType
from ...common._topology import Operator, Scope
from ...common._container import ModelComponentContainer
from pyspark.ml.feature import OneHotEncoderModel
from typing import List

def _get_categories(operator: Operator) -> List[List[int]]:
op: OneHotEncoderModel = operator.raw_operator
categorySizes: List[int] = op.categorySizes

# This is necessary to match SparkML's OneHotEncoder behavior.
# If handleInvalid is set to "keep", an extra "category" indicating invalid values is added as last category:
# - if dropLast is set to false, then an extra bit will be added for invalid input, which does not match ONNX's OneHotEncoder behavior.
# - if dropLast is set to true, then invalid values are encoded as all-zeros vector, which matches ONNX's OneHotEncoder behavior when "zeros" is set to 1.
# If handleInvalid is set to "error", then:
# - if dropLast is set to false, then nothing will be added or dropped, and matches ONNX's OneHotEncoder behavior when "zeros" is set to 0.
# - if dropLast is set to true, then the last bit will be dropped, which does not match ONNX's OneHotEncoder behavior.
if (op.getHandleInvalid() == "keep" and op.getDropLast() == False) or (
op.getHandleInvalid() == "error" and op.getDropLast() == True
):
raise RuntimeError(
f"The 'handleInvalid' and 'dropLast' parameters must be set to ('keep', True) or ('error', False), but got ('{op.getHandleInvalid()}', {op.getDropLast()}) instead."
)

return [list(range(0, size)) for size in categorySizes]

def convert_sparkml_one_hot_encoder(scope: Scope, operator: Operator, container: ModelComponentContainer):
categories = _get_categories(operator)
N = operator.inputs[0].type.shape[0]

zeros = 1 if operator.raw_operator.getHandleInvalid() == "keep" else 0

def convert_sparkml_one_hot_encoder(scope, operator, container):
op = operator.raw_operator
C = operator.inputs[0].type.shape[1]

# encoded_slot_sizes[i] is the number of output coordinates associated with the ith categorical feature

# Variable names produced by one-hot encoders. Each of them is the encoding result of a categorical feature.
final_variable_names = []
final_variable_lengths = []
for i in range(0, len(op.categorySizes)):
catSize = op.categorySizes[i]
cats = range(0, catSize)
# Put a feature index we want to encode to a tensor
index_variable_name = scope.get_unique_variable_name('target_index')
container.add_initializer(index_variable_name, onnx_proto.TensorProto.INT64, [1], [i])

# Extract the categorical feature from the original input tensor
extracted_feature_name = scope.get_unique_variable_name('extracted_feature_at_' + str(i))
extractor_type = 'ArrayFeatureExtractor'
extractor_attrs = {'name': scope.get_unique_operator_name(extractor_type)}
container.add_node(extractor_type, [operator.inputs[0].full_name, index_variable_name],
extracted_feature_name, op_domain='ai.onnx.ml', **extractor_attrs)

# Encode the extracted categorical feature as a one-hot vector
encoder_type = 'OneHotEncoder'
encoder_attrs = {'name': scope.get_unique_operator_name(encoder_type), 'cats_int64s': cats}
encoded_feature_name = scope.get_unique_variable_name('encoded_feature_at_' + str(i))
container.add_node(encoder_type, extracted_feature_name, encoded_feature_name, op_domain='ai.onnx.ml',
**encoder_attrs)

# Collect features produce by one-hot encoders
final_variable_names.append(encoded_feature_name)
# For each categorical value, the length of its encoded result is the number of all possible categorical values
final_variable_lengths.append(catSize)

# Combine encoded features and passed features
collector_type = 'FeatureVectorizer'
collector_attrs = {'name': scope.get_unique_operator_name(collector_type)}
collector_attrs['inputdimensions'] = final_variable_lengths
container.add_node(collector_type, final_variable_names, operator.outputs[0].full_name,
op_domain='ai.onnx.ml', **collector_attrs)
for i in range(0, len(categories)):
encoder_type = "OneHotEncoder"

# Set zeros to 0 to match the "error" handleInvalid behavior of SparkML's OneHotEncoder.
encoder_attrs = {
"name": scope.get_unique_operator_name(encoder_type),
"cats_int64s": categories[i],
"zeros": zeros,
}

register_converter('pyspark.ml.feature.OneHotEncoderModel', convert_sparkml_one_hot_encoder)
encoded_feature_name = scope.get_unique_variable_name("encoded_feature_at_" + str(i))
container.add_node(
op_type=encoder_type,
inputs=[operator.inputs[i].full_name],
outputs=[encoded_feature_name],
op_domain="ai.onnx.ml",
op_version=1,
**encoder_attrs,
)

shape_variable_name = scope.get_unique_variable_name("shape_at_" + str(i))
container.add_initializer(shape_variable_name, onnx_proto.TensorProto.INT64, [2], [N, len(categories[i])])

container.add_node(
op_type="Reshape",
inputs=[encoded_feature_name, shape_variable_name],
outputs=[operator.outputs[i].full_name],
op_domain="ai.onnx",
op_version=13,
)

def calculate_sparkml_one_hot_encoder_output_shapes(operator):
'''
Allowed input/output patterns are
1. [N, C] ---> [N, C']
2. [N, 'None'] ---> [N, 'None']
'''
op = operator.raw_operator

# encoded_slot_sizes[i] is the number of output coordinates associated with the ith categorical feature.
encoded_slot_sizes = op.categorySizes
register_converter('pyspark.ml.feature.OneHotEncoderModel', convert_sparkml_one_hot_encoder)


def calculate_sparkml_one_hot_encoder_output_shapes(operator: Operator):
categories = _get_categories(operator)
N = operator.inputs[0].type.shape[0]
# Calculate the output feature length by replacing the count of categorical
# features with their encoded widths
if operator.inputs[0].type.shape[1] != 'None':
C = operator.inputs[0].type.shape[1] - 1 + sum(encoded_slot_sizes)
else:
C = 'None'

operator.outputs[0].type = FloatTensorType([N, C])
for i, output in enumerate(operator.outputs):
output.type = FloatTensorType([N, len(categories[i])])


register_shape_calculator('pyspark.ml.feature.OneHotEncoderModel', calculate_sparkml_one_hot_encoder_output_shapes)
Expand Down
6 changes: 5 additions & 1 deletion onnxmltools/convert/sparkml/ops_input_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,12 @@ def build_io_name_map():
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")]
),
"pyspark.ml.feature.OneHotEncoderModel": (
lambda model: model.getOrDefault("inputCols"),
lambda model: model.getOrDefault("inputCols")
if model.isSet("inputCols")
else [model.getOrDefault("inputCol")],
lambda model: model.getOrDefault("outputCols")
if model.isSet("outputCols")
else [model.getOrDefault("outputCol")],
),
"pyspark.ml.feature.StringIndexerModel": (
lambda model: [model.getOrDefault("inputCol")],
Expand Down
76 changes: 61 additions & 15 deletions tests/sparkml/test_onehot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,76 @@ class TestSparkmlOneHotEncoder(SparkMlTestCase):

@unittest.skipIf(sys.version_info < (3, 8),
reason="pickle fails on python 3.7")
def test_model_onehot_encoder(self):
encoder = OneHotEncoder(inputCols=['index'], outputCols=['indexVec'])
data = self.spark.createDataFrame(
[(0.0,), (1.0,), (2.0,), (2.0,), (0.0,), (2.0,)], ['index'])
def test_model_onehot_encoder_1(self):
"""
Testing ONNX conversion for Spark OneHotEncoder when handleInvalid is set to "error" and dropLast set to False.
"""
encoder = OneHotEncoder(inputCols=['index1', 'index2'], outputCols=['index1Vec', 'index2Vec'], handleInvalid="error", dropLast=False)
data = self.spark.createDataFrame([(0.0,5.0,), (1.0,4.0,), (2.0,3.0,), (2.0,2.0,), (0.0,1.0,), (2.0,0.0,)], ['index1','index2'])
model = encoder.fit(data)
model_onnx = convert_sparkml(
model, 'Sparkml OneHotEncoder', [('index', FloatTensorType([None, 1]))],
model, 'Sparkml OneHotEncoder', [('index1', FloatTensorType([-1, 1])), ('index2', FloatTensorType([-1, 1]))],
target_opset=TARGET_OPSET)
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
# run the model
predicted = model.transform(data)
data_np = data.select("index").toPandas().values.astype(numpy.float32)
predicted_np = predicted.select("indexVec").toPandas().indexVec.apply(
lambda x: x.toArray().tolist()).values
expected = numpy.asarray(
[x + [0] if numpy.amax(x) == 1 else x + [1] for x in predicted_np])

paths = save_data_models(data_np, expected, model, model_onnx,
basename="SparkmlOneHotEncoder")
data_np = {
"index1": data.select("index1").toPandas().values.astype(numpy.float32),
"index2": data.select("index2").toPandas().values.astype(numpy.float32)
}

predicted_np_1 = predicted.select("index1Vec").toPandas().index1Vec.apply(lambda x: x.toArray()).values
predicted_np_2 = predicted.select("index2Vec").toPandas().index2Vec.apply(lambda x: x.toArray()).values
expected = {"index1Vec": numpy.asarray(predicted_np_1.tolist()), "index2Vec": numpy.asarray(predicted_np_2.tolist())}

paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlOneHotEncoder")
onnx_model_path = paths[-1]

output_names = ['index1Vec', 'index2Vec']
output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
actual_output = dict(zip(output_names, output))

compare_results(expected["index1Vec"], actual_output["index1Vec"], decimal=5)
compare_results(expected["index2Vec"], actual_output["index2Vec"], decimal=5)


@unittest.skipIf(sys.version_info < (3, 8),
reason="pickle fails on python 3.7")
def test_model_onehot_encoder_2(self):
"""
Testing ONNX conversion for Spark OneHotEncoder when handleInvalid is set to "keep" and dropLast set to True.
"""
encoder = OneHotEncoder(inputCols=['index1', 'index2'], outputCols=['index1Vec', 'index2Vec'], handleInvalid="keep", dropLast=True)
data = self.spark.createDataFrame([(0.0,5.0,), (1.0,4.0,), (2.0,3.0,), (2.0,2.0,), (0.0,1.0,), (2.0,0.0,)], ['index1','index2'])
test = self.spark.createDataFrame([(3.0,7.0,)], ['index1','index2']) # invalid data

model = encoder.fit(data)
model_onnx = convert_sparkml(
model, 'Sparkml OneHotEncoder', [('index1', FloatTensorType([-1, 1])), ('index2', FloatTensorType([-1, 1]))],
target_opset=TARGET_OPSET)
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
# run the model
predicted = model.transform(test)
data_np = {
"index1": test.select("index1").toPandas().values.astype(numpy.float32),
"index2": test.select("index2").toPandas().values.astype(numpy.float32)
}

predicted_np_1 = predicted.select("index1Vec").toPandas().index1Vec.apply(lambda x: x.toArray()).values
predicted_np_2 = predicted.select("index2Vec").toPandas().index2Vec.apply(lambda x: x.toArray()).values
expected = {"index1Vec": numpy.asarray(predicted_np_1.tolist()), "index2Vec": numpy.asarray(predicted_np_2.tolist())}

paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlOneHotEncoder")
onnx_model_path = paths[-1]
output, output_shapes = run_onnx_model(['indexVec'], data_np, onnx_model_path)
compare_results(expected, output, decimal=5)

output_names = ['index1Vec', 'index2Vec']
output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
actual_output = dict(zip(output_names, output))

compare_results(expected["index1Vec"], actual_output["index1Vec"], decimal=5)
compare_results(expected["index2Vec"], actual_output["index2Vec"], decimal=5)

if __name__ == "__main__":
unittest.main()

0 comments on commit f8587bc

Please sign in to comment.