diff --git a/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py b/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py index 3aeb8c42..384c8a5a 100644 --- a/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py +++ b/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py @@ -4,7 +4,7 @@ import os import time import numpy - +from pyspark.sql import SparkSession class SparkMLTree(dict): pass @@ -31,10 +31,11 @@ def sparkml_tree_dataset_to_sklearn(tree_df, is_classifier): return tree -def save_read_sparkml_model_data(spark, model): +def save_read_sparkml_model_data(spark: SparkSession, model): tdir = tempfile.tempdir if tdir is None: - tdir = spark.util.Utils.createTempDir().getAbsolutePath() + local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(spark._jsc.sc().conf()) + tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "onnx").getAbsolutePath() if tdir is None: raise FileNotFoundError( "Unable to create a temporary directory for model '{}'"