[jvm-packages] support spark 2.4 and compatibility test with previous xgboost version (#4377)
* bump spark version * keep float.nan * handle brokenly changed name/value * add test * add model files * add model files * update doc
This commit is contained in:
@@ -22,12 +22,17 @@ import org.json4s.JsonAST.JObject
|
||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.param.Params
|
||||
import org.apache.spark.ml.param.{Param, Params}
|
||||
import org.apache.spark.ml.util.MLReader
|
||||
|
||||
// This originates from apache-spark DefaultPramsReader copy paste
|
||||
private[spark] object DefaultXGBoostParamsReader {
|
||||
|
||||
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
|
||||
|
||||
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
|
||||
Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
|
||||
|
||||
/**
|
||||
* All info from metadata file.
|
||||
*
|
||||
@@ -103,6 +108,14 @@ private[spark] object DefaultXGBoostParamsReader {
|
||||
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
|
||||
}
|
||||
|
||||
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
|
||||
paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
|
||||
}
|
||||
|
||||
private def handleBrokenlyChangedName(paramName: String): String = {
|
||||
paramNameCompatibilityMap.getOrElse(paramName, paramName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract Params from metadata, and set them in the instance.
|
||||
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
||||
@@ -113,9 +126,9 @@ private[spark] object DefaultXGBoostParamsReader {
|
||||
metadata.params match {
|
||||
case JObject(pairs) =>
|
||||
pairs.foreach { case (paramName, jsonValue) =>
|
||||
val param = instance.getParam(paramName)
|
||||
val param = instance.getParam(handleBrokenlyChangedName(paramName))
|
||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
||||
instance.set(param, value)
|
||||
instance.set(param, handleBrokenlyChangedValue(paramName, value))
|
||||
}
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(
|
||||
|
||||
Reference in New Issue
Block a user