Add a public getter for the Booster held by XGBoostModel

This commit is contained in:
CodingCat 2016-03-14 07:26:49 -04:00
parent e3fa7753f5
commit 3a951d0ab8

View File

@ -23,14 +23,14 @@ import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Serializable {
/**
* Predict result with the given testset (represented as RDD)
*/
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(booster)
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predict(dMatrix))
@ -41,7 +41,7 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
* predict result given the test data (represented as DMatrix)
*/
def predict(testSet: DMatrix): Array[Array[Float]] = {
booster.predict(testSet, true, 0)
_booster.predict(testSet, true, 0)
}
/**
@ -52,7 +52,12 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
def saveModelAsHadoopFile(modelPath: String): Unit = {
val path = new Path(modelPath)
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
booster.saveModel(outputStream)
_booster.saveModel(outputStream)
outputStream.close()
}
/**
* Get the Booster instance wrapped by this model.
*/
def booster: Booster = _booster
}