support kryo serialization
This commit is contained in:
@@ -29,8 +29,8 @@ import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
||||
|
||||
class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
|
||||
@@ -171,4 +171,23 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
}
|
||||
customSparkContext.stop()
|
||||
}
|
||||
|
||||
test("kryoSerializer test") {
|
||||
sc.stop()
|
||||
sc = null
|
||||
val eval = new EvalError()
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
sparkConf.registerKryoClasses(Array(classOf[Booster]))
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
|
||||
assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
|
||||
customSparkContext.stop()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user