Skip to content

Commit

Permalink
code review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz committed May 25, 2022
1 parent e297f6e commit 5cc6a82
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/sparkml/test_onehot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_model_onehot_encoder_1(self):
data = self.spark.createDataFrame([(0.0,5.0,), (1.0,4.0,), (2.0,3.0,), (2.0,2.0,), (0.0,1.0,), (2.0,0.0,)], ['index1','index2'])
model = encoder.fit(data)
model_onnx = convert_sparkml(
model, 'Sparkml OneHotEncoder', [('index1', FloatTensorType([None, 1])), ('index2', FloatTensorType([-1, 1]))],
model, 'Sparkml OneHotEncoder', [('index1', FloatTensorType([None, 1])), ('index2', FloatTensorType([None, 1]))],
target_opset=TARGET_OPSET)
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_model_onehot_encoder_2(self):

model = encoder.fit(data)
model_onnx = convert_sparkml(
model, 'Sparkml OneHotEncoder', [('index1', FloatTensorType([None, 1])), ('index2', FloatTensorType([-1, 1]))],
model, 'Sparkml OneHotEncoder', [('index1', FloatTensorType([None, 1])), ('index2', FloatTensorType([None, 1]))],
target_opset=TARGET_OPSET)
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
Expand Down

0 comments on commit 5cc6a82

Please sign in to comment.