distributed in RDD
This commit is contained in:
@@ -28,27 +28,35 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
|
||||
object XGBoost {
|
||||
|
||||
private var _sc: Option[SparkContext] = None
|
||||
|
||||
/**
 * Implicit view that wraps a [[Booster]] in an [[XGBoostModel]].
 *
 * @param booster the booster to wrap
 * @return a new [[XGBoostModel]] constructed from `booster`
 */
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel =
  new XGBoostModel(booster)
|
||||
|
||||
/**
 * Trains one booster per partition of the training data.
 *
 * The input RDD is repartitioned into `numWorkers` partitions. Within each partition all
 * samples are loaded into a single DMatrix, which is handed to `SXGBoost.train`; the
 * partition then yields exactly one [[Booster]]. The resulting RDD is marked with
 * `cache()`. This method only builds the RDD; it does not run an action on it itself.
 *
 * @param trainingData   labeled points to train on
 * @param xgBoostConfMap parameter map passed unchanged to `SXGBoost.train`
 * @param numWorkers     number of partitions to repartition the data into
 * @param round          passed to `SXGBoost.train` as its round argument
 * @param obj            objective passed through to `SXGBoost.train`
 * @param eval           evaluation function passed through to `SXGBoost.train`
 * @return a cached RDD containing one [[Booster]] per partition
 */
private[spark] def buildDistributedBoosters(
    trainingData: RDD[LabeledPoint],
    xgBoostConfMap: Map[String, AnyRef],
    numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
  // NOTE(review): presumably supplies the implicit conversion that lets the LabeledPoint
  // iterator be passed to JDMatrix below -- confirm against DataUtils.
  import DataUtils._
  // The previous `sc.broadcast(DataUtils)` was never read by the closure, so it has been
  // dropped together with the then-unused SparkContext local.
  trainingData.repartition(numWorkers).mapPartitions {
    trainingSamples =>
      // NOTE(review): the meaning of the `null` second JDMatrix argument is not visible
      // here -- verify against the JDMatrix constructor.
      val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
      // An empty watch map is supplied, so no evaluation matrices are passed to train.
      Iterator(SXGBoost.train(xgBoostConfMap, dMatrix, round,
        watches = new HashMap[String, DMatrix], obj, eval))
  }.cache()
}
|
||||
|
||||
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): XGBoostModel = {
|
||||
import DataUtils._
|
||||
val sc = trainingData.sparkContext
|
||||
val dataUtilsBroadcast = sc.broadcast(DataUtils)
|
||||
val filePath = config.getString("inputPath") // configuration entry name to be fixed
|
||||
val numWorkers = config.getInt("numWorkers")
|
||||
val round = config.getInt("round")
|
||||
val sc = trainingData.sparkContext
|
||||
// TODO: build configuration map from config
|
||||
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
||||
val boosters = trainingData.repartition(numWorkers).mapPartitions {
|
||||
trainingSamples =>
|
||||
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
|
||||
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
|
||||
}.cache()
|
||||
val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
|
||||
obj, eval)
|
||||
// force the job
|
||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||
// TODO: how to choose best model
|
||||
|
||||
Reference in New Issue
Block a user