Skip to content

Commit

Permalink
[jvm-packages] Fix model compatibility (#7845)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 authored Apr 27, 2022
1 parent 686caad commit a94e1b1
Showing 1 changed file with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,18 +16,22 @@

package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.spark
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonAST.JObject
import org.json4s.jackson.JsonMethods.{compact, parse, render}

import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.MLReader

// This originates from apache-spark DefaultPramsReader copy paste
private[spark] object DefaultXGBoostParamsReader {

private val logger = LogFactory.getLog("XGBoostSpark")

private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")

private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
Expand Down Expand Up @@ -126,9 +130,16 @@ private[spark] object DefaultXGBoostParamsReader {
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(handleBrokenlyChangedName(paramName))
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, handleBrokenlyChangedValue(paramName, value))
val finalName = handleBrokenlyChangedName(paramName)
// For the deleted parameters, we'd better to remove it instead of throwing an exception.
// So we need to check if the parameter exists instead of blindly setting it.
if (instance.hasParam(finalName)) {
val param = instance.getParam(finalName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, handleBrokenlyChangedValue(paramName, value))
} else {
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
}
}
case _ =>
throw new IllegalArgumentException(
Expand Down

0 comments on commit a94e1b1

Please sign in to comment.