getter of XGBoostModel
This commit is contained in:
parent
e3fa7753f5
commit
3a951d0ab8
@ -23,14 +23,14 @@ import org.apache.spark.rdd.RDD
|
|||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||||
|
|
||||||
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
|
class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict result with the given testset (represented as RDD)
|
* Predict result with the given testset (represented as RDD)
|
||||||
*/
|
*/
|
||||||
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||||
testSet.mapPartitions { testSamples =>
|
testSet.mapPartitions { testSamples =>
|
||||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||||
@ -41,7 +41,7 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
|
|||||||
* predict result given the test data (represented as DMatrix)
|
* predict result given the test data (represented as DMatrix)
|
||||||
*/
|
*/
|
||||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||||
booster.predict(testSet, true, 0)
|
_booster.predict(testSet, true, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -52,7 +52,12 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
|
|||||||
def saveModelAsHadoopFile(modelPath: String): Unit = {
|
def saveModelAsHadoopFile(modelPath: String): Unit = {
|
||||||
val path = new Path(modelPath)
|
val path = new Path(modelPath)
|
||||||
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
|
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
|
||||||
booster.saveModel(outputStream)
|
_booster.saveModel(outputStream)
|
||||||
outputStream.close()
|
outputStream.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the booster instance of this model
|
||||||
|
*/
|
||||||
|
def booster: Booster = _booster
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user