code formatting in XGBoostModel
This commit is contained in:
parent
c9830cd8b1
commit
852c5a4b32
@ -26,12 +26,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
|
||||
|
||||
/**
|
||||
* Predict result with the given testset (represented as RDD)
|
||||
*/
|
||||
* Predict result with the given testset (represented as RDD)
|
||||
*/
|
||||
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
||||
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
@ -46,10 +45,10 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as to HDFS-compatible file system.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
* Save the model as to HDFS-compatible file system.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelToHadoop(modelPath: String): Unit = {
|
||||
val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
|
||||
booster.saveModel(outputStream)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user