adjust the return values of RabitTracker.waitFor(), remove typesafe.Config
This commit is contained in:
parent
457ff82e33
commit
f768edfede
@ -161,10 +161,5 @@
|
|||||||
<version>2.2.6</version>
|
<version>2.2.6</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>com.typesafe</groupId>
|
|
||||||
<artifactId>config</artifactId>
|
|
||||||
<version>1.2.1</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@ -16,10 +16,10 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import scala.collection.immutable.HashMap
|
import scala.collection.mutable
|
||||||
|
|
||||||
import com.typesafe.config.Config
|
import org.apache.commons.logging.LogFactory
|
||||||
import org.apache.spark.{TaskContext, SparkContext}
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
|
||||||
@ -28,6 +28,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
|||||||
|
|
||||||
object XGBoost extends Serializable {
|
object XGBoost extends Serializable {
|
||||||
|
|
||||||
|
var boosters: RDD[Booster] = null
|
||||||
|
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
|
|
||||||
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
|
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
|
||||||
new XGBoostModel(booster)
|
new XGBoostModel(booster)
|
||||||
}
|
}
|
||||||
@ -37,9 +40,6 @@ object XGBoost extends Serializable {
|
|||||||
xgBoostConfMap: Map[String, AnyRef],
|
xgBoostConfMap: Map[String, AnyRef],
|
||||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
|
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val sc = trainingData.sparkContext
|
|
||||||
val tracker = new RabitTracker(numWorkers)
|
|
||||||
if (tracker.start()) {
|
|
||||||
trainingData.repartition(numWorkers).mapPartitions {
|
trainingData.repartition(numWorkers).mapPartitions {
|
||||||
trainingSamples =>
|
trainingSamples =>
|
||||||
Rabit.init(new java.util.HashMap[String, String]() {
|
Rabit.init(new java.util.HashMap[String, String]() {
|
||||||
@ -47,32 +47,26 @@ object XGBoost extends Serializable {
|
|||||||
})
|
})
|
||||||
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
|
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
|
||||||
val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
|
val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
|
||||||
watches = new HashMap[String, DMatrix], obj, eval)
|
watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
Iterator(booster)
|
Iterator(booster)
|
||||||
}.cache()
|
}.cache()
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
|
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
|
||||||
eval: EvalTrait = null): Option[XGBoostModel] = {
|
obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
|
||||||
import DataUtils._
|
val numWorkers = trainingData.partitions.length
|
||||||
val numWorkers = config.getInt("numWorkers")
|
|
||||||
val round = config.getInt("round")
|
|
||||||
val sc = trainingData.sparkContext
|
val sc = trainingData.sparkContext
|
||||||
val tracker = new RabitTracker(numWorkers)
|
val tracker = new RabitTracker(numWorkers)
|
||||||
if (tracker.start()) {
|
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||||
// TODO: build configuration map from config
|
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
|
||||||
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
|
||||||
val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
|
|
||||||
obj, eval)
|
|
||||||
// force the job
|
// force the job
|
||||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||||
tracker.waitFor()
|
val booster = boosters.first()
|
||||||
// TODO: how to choose best model
|
val returnVal = tracker.waitFor()
|
||||||
Some(boosters.first())
|
logger.info(s"Rabit returns with exit code $returnVal")
|
||||||
|
if (returnVal == 0) {
|
||||||
|
Some(booster)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|||||||
@ -134,15 +134,18 @@ public class RabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void waitFor() {
|
public int waitFor() {
|
||||||
try {
|
try {
|
||||||
trackerProcess.get().waitFor();
|
trackerProcess.get().waitFor();
|
||||||
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
|
int returnVal = trackerProcess.get().exitValue();
|
||||||
|
logger.info("Tracker Process ends with exit code " + returnVal);
|
||||||
stop();
|
stop();
|
||||||
|
return returnVal;
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
// we should not get here as RabitTracker is accessed in the main thread
|
// we should not get here as RabitTracker is accessed in the main thread
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user