From 4e2da3b0aeba14eaa48383f08dbf814e7a346d97 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 2 Feb 2024 17:17:58 +0100 Subject: [PATCH] update style Signed-off-by: Xavier Dupre --- .../operator_converters/TreeEnsemble.py | 18 +++---- .../convert/sparkml/ops_input_output.py | 48 ++++++++++++------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/onnxmltools/convert/coreml/operator_converters/TreeEnsemble.py b/onnxmltools/convert/coreml/operator_converters/TreeEnsemble.py index db52cc01..6dda3fbf 100644 --- a/onnxmltools/convert/coreml/operator_converters/TreeEnsemble.py +++ b/onnxmltools/convert/coreml/operator_converters/TreeEnsemble.py @@ -45,9 +45,9 @@ def convert_tree_ensemble_model(scope, operator, container): op_type = "TreeEnsembleClassifier" prefix = "class" nodes = raw_model.treeEnsembleClassifier.treeEnsemble.nodes - attrs[ - "base_values" - ] = raw_model.treeEnsembleClassifier.treeEnsemble.basePredictionValue + attrs["base_values"] = ( + raw_model.treeEnsembleClassifier.treeEnsemble.basePredictionValue + ) attrs["post_transform"] = get_onnx_tree_post_transform( raw_model.treeEnsembleClassifier.postEvaluationTransform ) @@ -72,12 +72,12 @@ def convert_tree_ensemble_model(scope, operator, container): op_type = "TreeEnsembleRegressor" prefix = "target" nodes = raw_model.treeEnsembleRegressor.treeEnsemble.nodes - attrs[ - "base_values" - ] = raw_model.treeEnsembleRegressor.treeEnsemble.basePredictionValue - attrs[ - "n_targets" - ] = raw_model.treeEnsembleRegressor.treeEnsemble.numPredictionDimensions + attrs["base_values"] = ( + raw_model.treeEnsembleRegressor.treeEnsemble.basePredictionValue + ) + attrs["n_targets"] = ( + raw_model.treeEnsembleRegressor.treeEnsemble.numPredictionDimensions + ) attrs["post_transform"] = get_onnx_tree_post_transform( raw_model.treeEnsembleRegressor.postEvaluationTransform ) diff --git a/onnxmltools/convert/sparkml/ops_input_output.py b/onnxmltools/convert/sparkml/ops_input_output.py index ae241667..d74f95dd 100644 --- a/onnxmltools/convert/sparkml/ops_input_output.py +++ b/onnxmltools/convert/sparkml/ops_input_output.py @@ -134,12 +134,16 @@ def build_io_name_map(): lambda model: [model.getOrDefault("predictionCol")], ), "pyspark.ml.feature.ImputerModel": ( - lambda model: model.getOrDefault("inputCols") - if model.isSet("inputCols") - else [model.getOrDefault("inputCol")], - lambda model: model.getOrDefault("outputCols") - if model.isSet("outputCols") - else [model.getOrDefault("outputCol")], + lambda model: ( + model.getOrDefault("inputCols") + if model.isSet("inputCols") + else [model.getOrDefault("inputCol")] + ), + lambda model: ( + model.getOrDefault("outputCols") + if model.isSet("outputCols") + else [model.getOrDefault("outputCol")] + ), ), "pyspark.ml.feature.MaxAbsScalerModel": ( lambda model: [model.getOrDefault("inputCol")], @@ -177,20 +181,28 @@ def build_io_name_map(): ], ), "pyspark.ml.feature.OneHotEncoderModel": ( - lambda model: model.getOrDefault("inputCols") - if model.isSet("inputCols") - else [model.getOrDefault("inputCol")], - lambda model: model.getOrDefault("outputCols") - if model.isSet("outputCols") - else [model.getOrDefault("outputCol")], + lambda model: ( + model.getOrDefault("inputCols") + if model.isSet("inputCols") + else [model.getOrDefault("inputCol")] + ), + lambda model: ( + model.getOrDefault("outputCols") + if model.isSet("outputCols") + else [model.getOrDefault("outputCol")] + ), ), "pyspark.ml.feature.StringIndexerModel": ( - lambda model: model.getOrDefault("inputCols") - if model.isSet("inputCols") - else [model.getOrDefault("inputCol")], - lambda model: model.getOrDefault("outputCols") - if model.isSet("outputCols") - else [model.getOrDefault("outputCol")], + lambda model: ( + model.getOrDefault("inputCols") + if model.isSet("inputCols") + else [model.getOrDefault("inputCol")] + ), + lambda model: ( + model.getOrDefault("outputCols") + if model.isSet("outputCols") + else [model.getOrDefault("outputCol")] + ), ), "pyspark.ml.feature.VectorAssembler": ( lambda model: model.getOrDefault("inputCols"),