Skip to content

Commit

Permalink
requiring target opset to be >= 7 in order to use broadcast in add an…
Browse files Browse the repository at this point in the history
…d div
  • Loading branch information
memoryz committed Jun 2, 2022
1 parent 81e003a commit 273515a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions onnxmltools/convert/sparkml/operator_converters/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import numpy as np

def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: ModelComponentContainer):
if container.target_opset < 7:
raise NotImplementedError("Converting to ONNX for KMeansModel is not supported in opset < 7")

op: KMeansModel = operator.raw_operator
centers: np.ndarray = np.vstack(op.clusterCenters())

Expand Down Expand Up @@ -65,6 +68,7 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
op_type="Gemm",
inputs=[operator.inputs[0].full_name, centers_variable_name, input_row_squared_sum_variable_name],
outputs=[gemm_output_variable_name],
op_version=7,
**gemm_attrs
)

Expand All @@ -74,6 +78,7 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
op_type="Add",
inputs=[gemm_output_variable_name, centers_row_squared_sum_variable_name],
outputs=[distance_output_variable_name],
op_version=7,
)
elif op.getDistanceMeasure() == "cosine":
# centers_row_norm2: [1 x K]
Expand Down Expand Up @@ -118,6 +123,7 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
op_type="Gemm",
inputs=[operator.inputs[0].full_name, centers_variable_name, zeros_variable_name],
outputs=[gemm_output_variable_name],
op_version=7,
**gemm_attrs
)

Expand All @@ -127,14 +133,14 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
op_type="Div",
inputs=[gemm_output_variable_name, input_row_norm2_variable_name],
outputs=[div_output_variable_name],
op_version=7, # Setting to version 7 for broadcasting support
op_version=7,
)
cosine_similarity_output_variable_name = scope.get_unique_variable_name("cosine_similarity_output")
container.add_node(
op_type="Div",
inputs=[div_output_variable_name, centers_row_norm2_variable_name],
outputs=[cosine_similarity_output_variable_name],
op_version=7, # Setting to version 7 for broadcasting support
op_version=7,
)

# Cosine distance - 1 = -Cosine similarity: [N x K]
Expand Down

0 comments on commit 273515a

Please sign in to comment.