Skip to content

Commit

Permalink
lowering div/add operator to version 13
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Wang <jasowang@microsoft.com>
  • Loading branch information
memoryz committed May 31, 2022
1 parent b6cff38 commit 95a2a4d
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions onnxmltools/convert/sparkml/operator_converters/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
# Euclidean Distance Squared = input_row_squared_sum - 2 * input * Transpose(Center) + Transpose(centers_row_squared_sum)
# [N x K]
container.add_node(
op_type = "Add",
inputs = [gemm_output_variable_name, centers_row_squared_sum_variable_name],
outputs = [distance_output_variable_name],
op_domain = "ai.onnx",
op_version = 14,
op_type="Add",
inputs=[gemm_output_variable_name, centers_row_squared_sum_variable_name],
outputs=[distance_output_variable_name],
op_domain="ai.onnx",
op_version=13,
)
elif op.getDistanceMeasure() == "cosine":
# centers_row_norm2: [1 x K]
Expand Down Expand Up @@ -127,19 +127,19 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
# Cosine similarity = gemm_output / input_row_norm2 / centers_row_norm2: [N x K]
div_output_variable_name = scope.get_unique_variable_name("div_output")
container.add_node(
op_type = "Div",
inputs = [gemm_output_variable_name, input_row_norm2_variable_name],
outputs = [div_output_variable_name],
op_domain = "ai.onnx",
op_version = 14,
op_type="Div",
inputs=[gemm_output_variable_name, input_row_norm2_variable_name],
outputs=[div_output_variable_name],
op_domain="ai.onnx",
op_version=13,
)
cosine_similarity_output_variable_name = scope.get_unique_variable_name("cosine_similarity_output")
container.add_node(
op_type = "Div",
inputs = [div_output_variable_name, centers_row_norm2_variable_name],
outputs = [cosine_similarity_output_variable_name],
op_domain = "ai.onnx",
op_version = 14,
op_type="Div",
inputs=[div_output_variable_name, centers_row_norm2_variable_name],
outputs=[cosine_similarity_output_variable_name],
op_domain="ai.onnx",
op_version=13,
)

# Cosine distance - 1 = -Cosine similarity: [N x K]
Expand Down

0 comments on commit 95a2a4d

Please sign in to comment.