[jvm-packages] Fix model compatibility (#7845)
parent 686caad40c
commit a94e1b172e
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,18 +16,22 @@
 
 package ml.dmlc.xgboost4j.scala.spark.params
 
+import ml.dmlc.xgboost4j.scala.spark
+import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.Path
 import org.json4s.{DefaultFormats, JValue}
 import org.json4s.JsonAST.JObject
 import org.json4s.jackson.JsonMethods.{compact, parse, render}
 
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.param.Params
 import org.apache.spark.ml.util.MLReader
 
 // This originates from apache-spark DefaultPramsReader copy paste
 private[spark] object DefaultXGBoostParamsReader {
 
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
   private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
 
   private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
@@ -126,9 +130,16 @@ private[spark] object DefaultXGBoostParamsReader {
     metadata.params match {
       case JObject(pairs) =>
         pairs.foreach { case (paramName, jsonValue) =>
-          val param = instance.getParam(handleBrokenlyChangedName(paramName))
-          val value = param.jsonDecode(compact(render(jsonValue)))
-          instance.set(param, handleBrokenlyChangedValue(paramName, value))
+          val finalName = handleBrokenlyChangedName(paramName)
+          // For the deleted parameters, we'd better to remove it instead of throwing an exception.
+          // So we need to check if the parameter exists instead of blindly setting it.
+          if (instance.hasParam(finalName)) {
+            val param = instance.getParam(finalName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            instance.set(param, handleBrokenlyChangedValue(paramName, value))
+          } else {
+            logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
+          }
         }
      case _ =>
        throw new IllegalArgumentException(
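The change in the last hunk is the heart of the fix: the persisted parameter name is first resolved through the rename map, and the value is decoded and set only when the resolved parameter still exists on the instance; otherwise the loader logs a warning and drops the entry instead of throwing. Below is a minimal standalone sketch of that guarded restore pattern, not XGBoost code. LegacyParamRestore, restore, and renamedParams are illustrative names; the real handleBrokenlyChangedName / handleBrokenlyChangedValue helpers are defined elsewhere in DefaultXGBoostParamsReader and are not shown in this diff.

import org.apache.commons.logging.LogFactory
import org.apache.spark.ml.param.Params

object LegacyParamRestore {

  private val logger = LogFactory.getLog("LegacyParamRestore")

  // Hypothetical legacy-name map; the real mapping is
  // DefaultXGBoostParamsReader.paramNameCompatibilityMap ("silent" -> "verbosity").
  private val renamedParams: Map[String, String] = Map("silent" -> "verbosity")

  /** Restore one persisted parameter onto `instance`.
   *  `paramName` is the name written by the older release and `jsonValue` is the
   *  JSON-encoded value as persisted in the metadata. Parameters that no longer
   *  exist are skipped with a warning rather than failing the whole load. */
  def restore(instance: Params, paramName: String, jsonValue: String): Unit = {
    val finalName = renamedParams.getOrElse(paramName, paramName)
    if (instance.hasParam(finalName)) {
      val param = instance.getParam(finalName)          // Param[Any]
      instance.set(param, param.jsonDecode(jsonValue))  // decode, then set on the model
    } else {
      // Parameter was removed in a newer release: warn and drop the stale value.
      logger.warn(s"$finalName is no longer used; ignoring persisted value")
    }
  }
}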
Loading…
x
Reference in New Issue
Block a user
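The practical effect is on model loading: metadata saved by an earlier XGBoost4J-Spark release can contain parameters that later releases renamed or removed, and before this fix getParam on such a name made the load throw. With the guard in place the obsolete entry is skipped with a warning and the rest of the model loads normally. A hedged usage sketch, with a hypothetical path to a model written by an older release:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel

// Loads despite parameters persisted by the older version that no longer exist.
val model = XGBoostClassificationModel.load("/models/xgb-classifier-old")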