code formatting in XGBoostModel

This commit is contained in:
CodingCat 2016-03-09 10:30:44 -05:00
parent c9830cd8b1
commit 852c5a4b32

View File

@ -26,12 +26,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable { class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
/** /**
* Predict result with the given testset (represented as RDD) * Predict result with the given testset (represented as RDD)
*/ */
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = { def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._ import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(booster) val broadcastBooster = testSet.sparkContext.broadcast(booster)
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
testSet.mapPartitions { testSamples => testSet.mapPartitions { testSamples =>
val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predict(dMatrix)) Iterator(broadcastBooster.value.predict(dMatrix))
@ -46,10 +45,10 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
} }
/** /**
* Save the model to an HDFS-compatible file system. * Save the model to an HDFS-compatible file system.
* *
* @param modelPath The model path, given as a Hadoop path. * @param modelPath The model path, given as a Hadoop path.
*/ */
def saveModelToHadoop(modelPath: String): Unit = { def saveModelToHadoop(modelPath: String): Unit = {
val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath)) val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
booster.saveModel(outputStream) booster.saveModel(outputStream)