
Commit

update style
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
xadupre committed Feb 2, 2024
1 parent d820053 commit 4e2da3b
Showing 2 changed files with 39 additions and 27 deletions.
18 changes: 9 additions & 9 deletions onnxmltools/convert/coreml/operator_converters/TreeEnsemble.py
@@ -45,9 +45,9 @@ def convert_tree_ensemble_model(scope, operator, container):
         op_type = "TreeEnsembleClassifier"
         prefix = "class"
         nodes = raw_model.treeEnsembleClassifier.treeEnsemble.nodes
-        attrs[
-            "base_values"
-        ] = raw_model.treeEnsembleClassifier.treeEnsemble.basePredictionValue
+        attrs["base_values"] = (
+            raw_model.treeEnsembleClassifier.treeEnsemble.basePredictionValue
+        )
         attrs["post_transform"] = get_onnx_tree_post_transform(
             raw_model.treeEnsembleClassifier.postEvaluationTransform
         )
@@ -72,12 +72,12 @@ def convert_tree_ensemble_model(scope, operator, container):
         op_type = "TreeEnsembleRegressor"
         prefix = "target"
         nodes = raw_model.treeEnsembleRegressor.treeEnsemble.nodes
-        attrs[
-            "base_values"
-        ] = raw_model.treeEnsembleRegressor.treeEnsemble.basePredictionValue
-        attrs[
-            "n_targets"
-        ] = raw_model.treeEnsembleRegressor.treeEnsemble.numPredictionDimensions
+        attrs["base_values"] = (
+            raw_model.treeEnsembleRegressor.treeEnsemble.basePredictionValue
+        )
+        attrs["n_targets"] = (
+            raw_model.treeEnsembleRegressor.treeEnsemble.numPredictionDimensions
+        )
         attrs["post_transform"] = get_onnx_tree_post_transform(
             raw_model.treeEnsembleRegressor.postEvaluationTransform
         )
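Both hunks above are pure formatting changes: instead of splitting the subscripted target across lines, the long right-hand side is now wrapped in parentheses so the target stays on one line, which is what recent black releases produce (presumably the 2024 stable style, given the commit title "update style"). A minimal runnable sketch of the before/after wrapping, using hypothetical names (attrs, ensemble) rather than the repository's objects:

# Hypothetical stand-ins for the converter's objects, for illustration only.
class _Ensemble:
    basePredictionValue = [0.0]
    numPredictionDimensions = 1

ensemble = _Ensemble()
attrs = {}

# Old wrapping: the subscript itself was split across lines.
attrs[
    "base_values"
] = ensemble.basePredictionValue

# New wrapping: the target stays on one line and the long value is
# parenthesized instead.
attrs["base_values"] = (
    ensemble.basePredictionValue
)

print(attrs)  # {'base_values': [0.0]}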
48 changes: 30 additions & 18 deletions onnxmltools/convert/sparkml/ops_input_output.py
@@ -134,12 +134,16 @@ def build_io_name_map():
             lambda model: [model.getOrDefault("predictionCol")],
         ),
         "pyspark.ml.feature.ImputerModel": (
-            lambda model: model.getOrDefault("inputCols")
-            if model.isSet("inputCols")
-            else [model.getOrDefault("inputCol")],
-            lambda model: model.getOrDefault("outputCols")
-            if model.isSet("outputCols")
-            else [model.getOrDefault("outputCol")],
+            lambda model: (
+                model.getOrDefault("inputCols")
+                if model.isSet("inputCols")
+                else [model.getOrDefault("inputCol")]
+            ),
+            lambda model: (
+                model.getOrDefault("outputCols")
+                if model.isSet("outputCols")
+                else [model.getOrDefault("outputCol")]
+            ),
         ),
         "pyspark.ml.feature.MaxAbsScalerModel": (
             lambda model: [model.getOrDefault("inputCol")],
@@ -177,20 +181,28 @@ def build_io_name_map():
             ],
         ),
         "pyspark.ml.feature.OneHotEncoderModel": (
-            lambda model: model.getOrDefault("inputCols")
-            if model.isSet("inputCols")
-            else [model.getOrDefault("inputCol")],
-            lambda model: model.getOrDefault("outputCols")
-            if model.isSet("outputCols")
-            else [model.getOrDefault("outputCol")],
+            lambda model: (
+                model.getOrDefault("inputCols")
+                if model.isSet("inputCols")
+                else [model.getOrDefault("inputCol")]
+            ),
+            lambda model: (
+                model.getOrDefault("outputCols")
+                if model.isSet("outputCols")
+                else [model.getOrDefault("outputCol")]
+            ),
         ),
         "pyspark.ml.feature.StringIndexerModel": (
-            lambda model: model.getOrDefault("inputCols")
-            if model.isSet("inputCols")
-            else [model.getOrDefault("inputCol")],
-            lambda model: model.getOrDefault("outputCols")
-            if model.isSet("outputCols")
-            else [model.getOrDefault("outputCol")],
+            lambda model: (
+                model.getOrDefault("inputCols")
+                if model.isSet("inputCols")
+                else [model.getOrDefault("inputCol")]
+            ),
+            lambda model: (
+                model.getOrDefault("outputCols")
+                if model.isSet("outputCols")
+                else [model.getOrDefault("outputCol")]
+            ),
         ),
         "pyspark.ml.feature.VectorAssembler": (
             lambda model: model.getOrDefault("inputCols"),
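The spark-ml hunks apply the same style update to conditional expressions inside lambdas: the whole "x if cond else y" expression is now wrapped in parentheses so each clause sits on its own line. A small self-contained sketch of the resulting pattern, with a hypothetical FakeModel standing in for a real pyspark model:

# FakeModel is a hypothetical stand-in for a pyspark model; only the two
# methods the lambdas call are mimicked here.
class FakeModel:
    def isSet(self, name):
        return False  # pretend only the single-column params are set

    def getOrDefault(self, name):
        return "features" if name == "inputCol" else "scaled"

# New-style entry: the conditional expression is parenthesized so the
# if/else clauses each get their own line.
io_map = {
    "SomeModel": (
        lambda model: (
            model.getOrDefault("inputCols")
            if model.isSet("inputCols")
            else [model.getOrDefault("inputCol")]
        ),
        lambda model: (
            model.getOrDefault("outputCols")
            if model.isSet("outputCols")
            else [model.getOrDefault("outputCol")]
        ),
    ),
}

get_inputs, get_outputs = io_map["SomeModel"]
print(get_inputs(FakeModel()), get_outputs(FakeModel()))  # ['features'] ['scaled']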
