[jvm-packages] Fix model compatibility (#7845)
This commit is contained in:
parent
686caad40c
commit
a94e1b172e
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -16,18 +16,22 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s.{DefaultFormats, JValue}
|
||||
import org.json4s.JsonAST.JObject
|
||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.param.{Param, Params}
|
||||
import org.apache.spark.ml.param.Params
|
||||
import org.apache.spark.ml.util.MLReader
|
||||
|
||||
// This originates from apache-spark DefaultPramsReader copy paste
|
||||
private[spark] object DefaultXGBoostParamsReader {
|
||||
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
|
||||
|
||||
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
|
||||
@ -126,9 +130,16 @@ private[spark] object DefaultXGBoostParamsReader {
|
||||
metadata.params match {
|
||||
case JObject(pairs) =>
|
||||
pairs.foreach { case (paramName, jsonValue) =>
|
||||
val param = instance.getParam(handleBrokenlyChangedName(paramName))
|
||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
||||
instance.set(param, handleBrokenlyChangedValue(paramName, value))
|
||||
val finalName = handleBrokenlyChangedName(paramName)
|
||||
// For the deleted parameters, we'd better to remove it instead of throwing an exception.
|
||||
// So we need to check if the parameter exists instead of blindly setting it.
|
||||
if (instance.hasParam(finalName)) {
|
||||
val param = instance.getParam(finalName)
|
||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
||||
instance.set(param, handleBrokenlyChangedValue(paramName, value))
|
||||
} else {
|
||||
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user