code formatting in XGBoostModel

This commit is contained in:
CodingCat 2016-03-09 10:30:44 -05:00
parent c9830cd8b1
commit 852c5a4b32

View File

@ -26,12 +26,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
/**
* Predict result with the given testset (represented as RDD)
*/
* Predict result with the given testset (represented as RDD)
*/
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(booster)
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
testSet.mapPartitions { testSamples =>
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predict(dMatrix))
@ -46,10 +45,10 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
}
/**
* Save the model as to HDFS-compatible file system.
*
* @param modelPath The model path as in Hadoop path.
*/
* Save the model as to HDFS-compatible file system.
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModelToHadoop(modelPath: String): Unit = {
val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
booster.saveModel(outputStream)