[jvm-packages] support spark 2.4 and compatibility test with previous xgboost version (#4377)
* bump spark version * keep float.nan * handle brokenly changed name/value * add test * add model files * add model files * update doc
This commit is contained in:
parent
711397d645
commit
65db8d0626
@ -61,9 +61,9 @@ and then refer to the snapshot dependency by adding:
|
|||||||
<version>next_version_num-SNAPSHOT</version>
|
<version>next_version_num-SNAPSHOT</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
.. note:: XGBoost4J-Spark requires Apache Spark 2.3+
|
.. note:: XGBoost4J-Spark requires Apache Spark 2.4+
|
||||||
|
|
||||||
XGBoost4J-Spark now requires **Apache Spark 2.3+**. Latest versions of XGBoost4J-Spark use facilities of `org.apache.spark.ml.param.shared` extensively to provide for a tight integration with Spark MLLIB framework, and these facilities are not fully available on earlier versions of Spark.
|
XGBoost4J-Spark now requires **Apache Spark 2.4+**. Latest versions of XGBoost4J-Spark use facilities of `org.apache.spark.ml.param.shared` extensively to provide for a tight integration with Spark MLLIB framework, and these facilities are not fully available on earlier versions of Spark.
|
||||||
|
|
||||||
Also, make sure to install Spark directly from `Apache website <https://spark.apache.org/>`_. **Upstream XGBoost is not guaranteed to work with third-party distributions of Spark, such as Cloudera Spark.** Consult appropriate third parties to obtain their distribution of XGBoost.
|
Also, make sure to install Spark directly from `Apache website <https://spark.apache.org/>`_. **Upstream XGBoost is not guaranteed to work with third-party distributions of Spark, such as Cloudera Spark.** Consult appropriate third parties to obtain their distribution of XGBoost.
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,7 @@
|
|||||||
<maven.compiler.source>1.7</maven.compiler.source>
|
<maven.compiler.source>1.7</maven.compiler.source>
|
||||||
<maven.compiler.target>1.7</maven.compiler.target>
|
<maven.compiler.target>1.7</maven.compiler.target>
|
||||||
<flink.version>1.5.0</flink.version>
|
<flink.version>1.5.0</flink.version>
|
||||||
<spark.version>2.3.3</spark.version>
|
<spark.version>2.4.1</spark.version>
|
||||||
<scala.version>2.11.12</scala.version>
|
<scala.version>2.11.12</scala.version>
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
<scala.binary.version>2.11</scala.binary.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|||||||
@ -22,12 +22,17 @@ import org.json4s.JsonAST.JObject
|
|||||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.ml.param.Params
|
import org.apache.spark.ml.param.{Param, Params}
|
||||||
import org.apache.spark.ml.util.MLReader
|
import org.apache.spark.ml.util.MLReader
|
||||||
|
|
||||||
// This is copied from Apache Spark's DefaultParamsReader
|
// This is copied from Apache Spark's DefaultParamsReader
|
||||||
private[spark] object DefaultXGBoostParamsReader {
|
private[spark] object DefaultXGBoostParamsReader {
|
||||||
|
|
||||||
|
// Maps a parameter name found in previously saved model metadata to the name used by the
// current version; consulted by handleBrokenlyChangedName when restoring params.
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
|
||||||
|
|
||||||
|
// Per-parameter table of values that were renamed between versions: the outer key is the
// parameter name as stored in the metadata, the inner map goes from old value to new value.
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] = Map(
  "objective" -> Map("reg:linear" -> "reg:squarederror")
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* All info from metadata file.
|
* All info from metadata file.
|
||||||
*
|
*
|
||||||
@ -103,6 +108,14 @@ private[spark] object DefaultXGBoostParamsReader {
|
|||||||
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
|
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Translate a parameter value that was renamed between versions.
 *
 * @param paramName parameter name used to look up the value-translation table
 * @param value     decoded value read from the metadata
 * @return the replacement registered for this (name, value) pair, or `value` itself
 *         when no replacement is registered
 */
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
  val translated: Any = paramValueCompatibilityMap.get(paramName) match {
    case Some(valueMapping) => valueMapping.getOrElse(value, value)
    case None => value
  }
  // The table is typed Map[Any, Any], so the result has to be cast back to the caller's type.
  translated.asInstanceOf[T]
}
|
||||||
|
|
||||||
|
/**
 * Translate a parameter name that was renamed between versions.
 *
 * @param paramName parameter name as it appears in the metadata
 * @return the name registered in paramNameCompatibilityMap, or `paramName` unchanged
 *         when it has no entry
 */
private def handleBrokenlyChangedName(paramName: String): String =
  paramNameCompatibilityMap.get(paramName) match {
    case Some(renamed) => renamed
    case None => paramName
  }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extract Params from metadata, and set them in the instance.
|
* Extract Params from metadata, and set them in the instance.
|
||||||
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
||||||
@ -113,9 +126,9 @@ private[spark] object DefaultXGBoostParamsReader {
|
|||||||
metadata.params match {
|
metadata.params match {
|
||||||
case JObject(pairs) =>
|
case JObject(pairs) =>
|
||||||
pairs.foreach { case (paramName, jsonValue) =>
|
pairs.foreach { case (paramName, jsonValue) =>
|
||||||
val param = instance.getParam(paramName)
|
val param = instance.getParam(handleBrokenlyChangedName(paramName))
|
||||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
val value = param.jsonDecode(compact(render(jsonValue)))
|
||||||
instance.set(param, value)
|
instance.set(param, handleBrokenlyChangedValue(paramName, value))
|
||||||
}
|
}
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
|
|||||||
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
{"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel","timestamp":1555350539033,"sparkVersion":"2.3.2-uber-109","uid":"xgbc_5e7bec215a4c","paramMap":{"useExternalMemory":false,"trainTestRatio":1.0,"alpha":0.0,"seed":0,"numWorkers":100,"skipDrop":0.0,"treeLimit":0,"silent":0,"trackerConf":{"workerConnectionTimeout":0,"trackerImpl":"python"},"missing":"NaN","colsampleBylevel":1.0,"probabilityCol":"probability","checkpointPath":"","lambda":1.0,"rawPredictionCol":"rawPrediction","eta":0.3,"numEarlyStoppingRounds":0,"growPolicy":"depthwise","gamma":0.0,"sampleType":"uniform","maxDepth":6,"rateDrop":0.0,"objective":"reg:linear","customObj":null,"lambdaBias":0.0,"baseScore":0.5,"labelCol":"label","minChildWeight":1.0,"customEval":null,"normalizeType":"tree","maxBin":16,"nthread":4,"numRound":20,"colsampleBytree":1.0,"predictionCol":"prediction","subsample":1.0,"timeoutRequestWorkers":1800000,"featuresCol":"features","evalMetric":"error","sketchEps":0.03,"scalePosWeight":1.0,"checkpointInterval":-1,"maxDeltaStep":0.0,"treeMethod":"approx"}}
|
||||||
@ -19,9 +19,11 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
import java.io.{File, FileNotFoundException}
|
import java.io.{File, FileNotFoundException}
|
||||||
import java.util.Arrays
|
import java.util.Arrays
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import scala.io.Source
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import org.apache.spark.ml.feature._
|
import org.apache.spark.ml.feature._
|
||||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
import org.apache.spark.network.util.JavaUtils
|
import org.apache.spark.network.util.JavaUtils
|
||||||
@ -162,5 +164,17 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
|||||||
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
|
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
|
||||||
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("cross-version model loading (0.82)") {
  // Load the model stored under the 0.82 test-resource directory.
  val savedModelDir = getClass.getResource("/model/0.82/model").getPath
  val legacyModel = XGBoostClassificationModel.read.load(savedModelDir)

  // 100 rows of (feature, label) where both columns carry the same random 0/1 value.
  val rng = new Random(0)
  val rows = Seq.fill(100)(rng.nextInt(2)).map(v => (v, v))
  val rawDF = ss.createDataFrame(rows).toDF("feature", "label")

  // Assemble every non-label column into the "features" vector column.
  val featureCols = rawDF.columns.filter(c => !c.contains("label"))
  val assembler = new VectorAssembler()
    .setInputCols(featureCols)
    .setOutputCol("features")

  // Running the loaded model over the assembled frame is the check; no values are asserted.
  legacyModel.transform(assembler.transform(rawDF)).show()
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -261,6 +261,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
val vectorAssembler = new VectorAssembler()
|
val vectorAssembler = new VectorAssembler()
|
||||||
.setInputCols(Array("col1", "col2", "col3"))
|
.setInputCols(Array("col1", "col2", "col3"))
|
||||||
.setOutputCol("features")
|
.setOutputCol("features")
|
||||||
|
.setHandleInvalid("keep")
|
||||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||||
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
|
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user