support kryo serialization

This commit is contained in:
CodingCat
2016-03-13 11:55:14 -04:00
parent 9011acf52b
commit f2ef958ebb
6 changed files with 88 additions and 10 deletions

View File

@@ -29,8 +29,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
class XGBoostSuite extends FunSuite with BeforeAndAfter {
@@ -171,4 +171,23 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
customSparkContext.stop()
}
// Trains a model on a SparkContext configured with the Kryo serializer (with Booster
// registered as a Kryo class) and checks that the evaluation error on the agaricus
// test set stays below 0.1.
test("kryoSerializer test") {
  // Stop the suite-level context before creating a custom one.
  // NOTE(review): presumably `sc` is created in a `before` hook outside this view — confirm.
  sc.stop()
  sc = null
  val eval = new EvalError()
  val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  sparkConf.registerKryoClasses(Array(classOf[Booster]))
  val customSparkContext = new SparkContext(sparkConf)
  // Always stop the custom context, even when training or the assertion fails, so a
  // leaked SparkContext cannot interfere with tests that run afterwards.
  try {
    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
    // NOTE(review): this import presumably supplies the implicit conversion that lets
    // `testSet` be passed to the JDMatrix constructor — confirm against DataUtils.
    import DataUtils._
    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
      "objective" -> "binary:logistic").toMap
    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
    assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
  } finally {
    customSparkContext.stop()
  }
}
}