getter of XGBoostModel
This commit is contained in:
parent
e3fa7753f5
commit
3a951d0ab8
@ -23,14 +23,14 @@ import org.apache.spark.rdd.RDD
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
|
||||
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
|
||||
class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Serializable {
|
||||
|
||||
/**
|
||||
* Predict result with the given testset (represented as RDD)
|
||||
*/
|
||||
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
@ -41,7 +41,7 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
|
||||
* predict result given the test data (represented as DMatrix)
|
||||
*/
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet, true, 0)
|
||||
_booster.predict(testSet, true, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -52,7 +52,12 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
|
||||
def saveModelAsHadoopFile(modelPath: String): Unit = {
|
||||
val path = new Path(modelPath)
|
||||
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
|
||||
booster.saveModel(outputStream)
|
||||
_booster.saveModel(outputStream)
|
||||
outputStream.close()
|
||||
}
|
||||
|
||||
/**
|
||||
* get the booster instance of this model
|
||||
*/
|
||||
def booster: Booster = _booster
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user