[jvm-packages] fix the persistence of XGBoostEstimator (#2265)
* add back train method but mark as deprecated
* fix scalastyle error
* fix the persistence of XGBoostEstimator
* test persistence of a complete pipeline
* fix compilation issue
* do not allow persist custom_eval and custom_obj
* fix the failed test
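For context on the mechanism this fix leans on: Spark ML persists a pipeline stage by writing its uid and its Params as JSON metadata under the save path, and re-applies them on load. Because those Params must serialize to JSON, function-valued parameters such as custom_obj and custom_eval cannot be written, which is why this commit stops persisting them. A minimal round-trip with a stock stage (VectorAssembler here; the path and object name are illustrative, not part of this patch):

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object PersistenceRoundTrip {
  def main(args: Array[String]): Unit = {
    // save()/load() need an active SparkSession to write the JSON metadata
    val spark = SparkSession.builder()
      .master("local[*]").appName("persistence-sketch").getOrCreate()

    val assembler = new VectorAssembler()
      .setInputCols(Array("feature"))
      .setOutputCol("features")
    assembler.write.overwrite().save("/tmp/assembler-sketch")

    // load() reconstructs the stage and re-applies the persisted Params
    val restored = VectorAssembler.load("/tmp/assembler-sketch")
    assert(restored.getInputCols.sameElements(assembler.getInputCols))
    assert(restored.getOutputCol == assembler.getOutputCol)
    spark.stop()
  }
}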
@@ -18,17 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
 
 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
 import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 
 import org.apache.spark.SparkContext
+import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
 import org.apache.spark.ml.linalg.DenseVector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._
 
 class XGBoostDFSuite extends SharedSparkContext with Utils {
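The imports added above (RegressionEvaluator, ParamMap, CrossValidator, ParamGridBuilder) point at model-selection usage in this suite. A self-contained sketch of that pattern, with a stock LinearRegression standing in where the suite would put an XGBoostEstimator (the data, object name, and grid values are all illustrative):

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

object CrossValidationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("cv-sketch").getOrCreate()

    // Toy regression data: the label is a noisy copy of the feature.
    val r = new scala.util.Random(0)
    val df = spark.createDataFrame(
      Seq.fill(100)(r.nextDouble()).map(x => (x, x + 0.1 * r.nextGaussian()))
    ).toDF("feature", "label")

    val assembler = new VectorAssembler()
      .setInputCols(Array("feature")).setOutputCol("features")
    val regressor = new LinearRegression() // stand-in for XGBoostEstimator

    val cv = new CrossValidator()
      .setEstimator(new Pipeline().setStages(Array(assembler, regressor)))
      .setEvaluator(new RegressionEvaluator()) // RMSE by default
      .setEstimatorParamMaps(
        new ParamGridBuilder().addGrid(regressor.maxIter, Array(5, 10)).build())
      .setNumFolds(3)

    val model = cv.fit(df)
    println(model.avgMetrics.mkString(", "))
    spark.stop()
  }
}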
@@ -110,7 +110,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
-      "tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
     val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
       nWorkers = numWorkers)
     assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
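The change above replaces a scala.concurrent.duration value (1 minute) with a plain number of milliseconds; note the two are not equal: 60 * 60 * 1000 ms is one hour, so the test also lengthens the timeout. A plain Long round-trips cleanly through param serialization, which fits the persistence goal of this commit, while a Duration does not. Roughly (the case-class shapes and field names below are assumptions for illustration, not taken from this diff):

import scala.concurrent.duration._

// Hypothetical before/after shapes of the tracker configuration.
case class OldTrackerConf(workerTimeout: Duration, trackerImpl: String)
case class NewTrackerConf(workerTimeoutMillis: Long, trackerImpl: String)

val before = OldTrackerConf(1.minute, "scala")        // 60,000 ms
val after  = NewTrackerConf(60 * 60 * 1000L, "scala") // 3,600,000 ms = 1 hour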
@@ -18,67 +18,84 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.{File, FileNotFoundException}
 
+import scala.util.Random
+
 import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.SparkSession
-import scala.concurrent.duration._
 
-case class Foobar(TARGET: Int, bar: Double, baz: Double)
-
 class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
 
   override def afterAll(): Unit = {
     super.afterAll()
     delete(new File("./testxgbPipe"))
-    delete(new File("./test2xgbPipe"))
+    delete(new File("./testxgbEst"))
+    delete(new File("./testxgbModel"))
+    delete(new File("./test2xgbModel"))
   }
 
   private def delete(f: File) {
-    if (f.isDirectory()) {
-      for (c <- f.listFiles()) {
-        delete(c)
+    if (f.exists()) {
+      if (f.isDirectory()) {
+        for (c <- f.listFiles()) {
+          delete(c)
+        }
+      }
+      if (!f.delete()) {
+        throw new FileNotFoundException("Failed to delete file: " + f)
       }
     }
-    if (!f.delete()) {
-      throw new FileNotFoundException("Failed to delete file: " + f)
-    }
   }
 
-  test("test sparks pipeline persistence of dataframe-based model") {
-    // maybe move to shared context, but requires session to import implicits.
-    // what about introducing https://github.com/holdenk/spark-testing-base ?
-    val conf: SparkConf = new SparkConf()
-      .setAppName("foo")
-      .setMaster("local[*]")
+  test("test persistence of XGBoostEstimator") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    xgbEstimator.write.overwrite().save("./testxgbEst")
+    val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
+    val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }
 
-    val spark: SparkSession = SparkSession
-      .builder()
-      .config(conf)
-      .getOrCreate()
+  test("test persistence of a complete pipeline") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val r = new Random(0)
+    val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
+    pipeline.write.overwrite().save("testxgbPipe")
+    val loadedPipeline = Pipeline.read.load("testxgbPipe")
+    val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
+    val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }
 
-    import spark.implicits._
+  test("test persistence of XGBoostModel") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val r = new Random(0)
+    // maybe move to shared context, but requires session to import implicits
 
-    val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
-      Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
-      .toDS
-
+    val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
+      toDF("feature", "label")
     val vectorAssembler = new VectorAssembler()
       .setInputCols(df.columns
-        .filter(!_.contains("TARGET")))
+        .filter(!_.contains("label")))
       .setOutputCol("features")
 
     val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
-      "tracker_conf" -> TrackerConf(1 minute, "scala")
-    ))
-      .setFeaturesCol("features")
-      .setLabelCol("TARGET")
-
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
+    )).setFeaturesCol("features").setLabelCol("label")
     // separate
     val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
-    predModel.write.overwrite.save("test2xgbPipe")
-    val same2Model = XGBoostModel.load("test2xgbPipe")
+    predModel.write.overwrite.save("test2xgbModel")
+    val same2Model = XGBoostModel.load("test2xgbModel")
 
     assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
     val predParamMap = predModel.extractParamMap()
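Both new tests round-trip the estimator and compare params via fromParamsToXGBParamMap, whose implementation is not shown in this diff. A hedged sketch of what such a helper can do with Spark ML's Params API (the name toXGBParamMap and the lack of filtering are assumptions, not the actual XGBoostEstimator code):

import org.apache.spark.ml.param.Params

// Turn every Param that has a value (set or default) back into the
// string-keyed map that the native XGBoost trainer consumes.
def toXGBParamMap(stage: Params): Map[String, Any] =
  stage.extractParamMap().toSeq
    .map(pair => pair.param.name -> pair.value)
    .toMap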
@@ -93,8 +110,8 @@ class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
 
     // chained
     val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
-    predictionModel.write.overwrite.save("testxgbPipe")
-    val sameModel = PipelineModel.load("testxgbPipe")
+    predictionModel.write.overwrite.save("testxgbModel")
+    val sameModel = PipelineModel.load("testxgbModel")
 
     val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
     val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head
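Once reloaded, the PipelineModel scores new rows directly. Continuing from the test above (its "feature"/"label" columns and Spark ML's default "prediction" output column), a usage sketch:

import org.apache.spark.ml.PipelineModel

// `df` is the ("feature", "label") DataFrame built in the test above.
val sameModel = PipelineModel.load("testxgbModel")
val scored = sameModel.transform(df)
scored.select("feature", "label", "prediction").show(5)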