revise current API
This commit is contained in:
@@ -19,23 +19,21 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import scala.collection.mutable
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
|
||||
object XGBoost extends Serializable {
|
||||
var boosters: RDD[Booster] = null
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
|
||||
private implicit def convertBoosterToXGBoostModel(booster: Booster)
|
||||
(implicit sc: SparkContext): XGBoostModel = {
|
||||
new XGBoostModel(booster)
|
||||
}
|
||||
|
||||
@@ -57,27 +55,36 @@ object XGBoost extends Serializable {
|
||||
}.cache()
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param trainingData the trainingset represented as RDD
|
||||
* @param configMap Map containing the configuration entries
|
||||
* @param round the number of iterations
|
||||
* @param obj the user-defined objective function, null by default
|
||||
* @param eval the user-defined evaluation function, null by default
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
val numWorkers = trainingData.partitions.length
|
||||
val sc = trainingData.sparkContext
|
||||
implicit val sc = trainingData.sparkContext
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||
@volatile var booster: Booster = null
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boosters.foreachPartition(_ => ())
|
||||
boosters.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkJobThread.start()
|
||||
val returnVal = tracker.waitFor()
|
||||
logger.info(s"Rabit returns with exit code $returnVal")
|
||||
if (returnVal == 0) {
|
||||
booster = boosters.first()
|
||||
Some(booster).get
|
||||
boosters.first()
|
||||
} else {
|
||||
try {
|
||||
if (sparkJobThread.isAlive) {
|
||||
@@ -87,21 +94,20 @@ object XGBoost extends Serializable {
|
||||
case ie: InterruptedException =>
|
||||
logger.info("spark job thread is interrupted")
|
||||
}
|
||||
null
|
||||
throw new XGBoostError("XGBoostModel training failed")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
|
||||
new XGBoostModel(
|
||||
SXGBoost.loadModel(
|
||||
FileSystem
|
||||
.get(new Configuration)
|
||||
.open(new Path(modelPath))))
|
||||
* Load XGBoost model from path in HDFS-compatible file system
|
||||
*
|
||||
* @param modelPath The path of the file representing the model
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoop(modelPath: String)(implicit sparkContext: SparkContext): XGBoostModel = {
|
||||
val dataInStream = FileSystem.get(sparkContext.hadoopConfiguration).open(new Path(modelPath))
|
||||
val xgBoostModel = new XGBoostModel(SXGBoost.loadModel(dataInStream))
|
||||
dataInStream.close()
|
||||
xgBoostModel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,18 +16,17 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.rdd.RDD
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
|
||||
class XGBoostModel(booster: Booster) extends Serializable {
|
||||
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
|
||||
|
||||
/**
|
||||
* Predict result given testRDD
|
||||
* @param testSet the testSet of Data vectors
|
||||
* @return The predicted RDD
|
||||
* Predict result with the given testset (represented as RDD)
|
||||
*/
|
||||
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||
import DataUtils._
|
||||
@@ -39,18 +38,21 @@ class XGBoostModel(booster: Booster) extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* predict result given the test data (represented as DMatrix)
|
||||
*/
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet)
|
||||
booster.predict(testSet, true, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* Save the model as to HDFS-compatible file system.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelToHadoop(modelPath: String): Unit = {
|
||||
booster.saveModel(FileSystem
|
||||
.get(new Configuration)
|
||||
.create(new Path(modelPath)))
|
||||
val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
|
||||
booster.saveModel(outputStream)
|
||||
outputStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user