From 3a951d0ab89ea243568c39d5b8ae2fdc001b535a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 14 Mar 2016 07:26:49 -0400 Subject: [PATCH] getter of XGBoostModel --- .../dmlc/xgboost4j/scala/spark/XGBoostModel.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 6dba097a6..14fa3c0f6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -23,14 +23,14 @@ import org.apache.spark.rdd.RDD import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.scala.{DMatrix, Booster} -class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable { +class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Serializable { /** * Predict result with the given testset (represented as RDD) */ def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = { import DataUtils._ - val broadcastBooster = testSet.sparkContext.broadcast(booster) + val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) Iterator(broadcastBooster.value.predict(dMatrix)) @@ -41,7 +41,7 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri * predict result given the test data (represented as DMatrix) */ def predict(testSet: DMatrix): Array[Array[Float]] = { - booster.predict(testSet, true, 0) + _booster.predict(testSet, true, 0) } /** @@ -52,7 +52,12 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri def saveModelAsHadoopFile(modelPath: String): Unit = { val path = new Path(modelPath) val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path) - booster.saveModel(outputStream) + _booster.saveModel(outputStream) outputStream.close() } + + /** + * get the booster instance of this model + */ + def booster: Booster = _booster }