[jvm-packages] refine tracker (#10313)
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
xgbExecParam.setRawParamMap(overridedParams)
|
||||
xgbExecParam
|
||||
}
|
||||
|
||||
/**
 * Derives the Rabit settings from `overridedParams`.
 *
 * Keys produced, with the defaults used when the source parameter is absent:
 *  - `rabit_reduce_ring_mincount`: value of `rabit_ring_reduce_threshold` (default `32 << 10`)
 *  - `rabit_debug`: `"true"` only when `verbosity` equals 3 (default verbosity 0)
 *  - `rabit_timeout`: `"true"` when `rabit_timeout` is a non-negative integer (default -1)
 *  - `rabit_timeout_sec`: the configured `rabit_timeout` when it is non-negative,
 *    otherwise `"1800"`
 *  - `DMLC_WORKER_CONNECT_RETRY`: value of `dmlc_worker_connect_retry` (default 5)
 *
 * @return the settings as string key/value pairs
 * @throws NumberFormatException if `verbosity` or `rabit_timeout` is not an integer
 */
private[spark] def buildRabitParams: Map[String, String] = {
  // Read the timeout once so the enabled flag and the seconds value cannot disagree.
  val rawTimeout = overridedParams.getOrElse("rabit_timeout", -1).toString
  val timeoutEnabled = rawTimeout.toInt >= 0

  Map(
    "rabit_reduce_ring_mincount" ->
      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
    "rabit_debug" ->
      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
    "rabit_timeout" -> timeoutEnabled.toString,
    // Use the unwrapped value: calling toString on the Option returned by Map.get
    // would yield "Some(<n>)" rather than the number itself.
    "rabit_timeout_sec" -> (if (timeoutEnabled) rawTimeout else "1800"),
    "DMLC_WORKER_CONNECT_RETRY" ->
      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
  )
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
|
||||
}
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
  // Host IP, port and timeout are forwarded unchanged from the tracker configuration.
  // This method only constructs the tracker; it does not call start() on it.
  new RabitTracker(nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
}
|
||||
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker = getTracker(nWorkers, trackerConf)
|
||||
// Executes the provided code block inside a tracker and then stops the tracker
|
||||
private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
|
||||
val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
try {
|
||||
block(tracker)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -501,28 +481,27 @@ object XGBoost extends XGBoostStageLevel {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
|
||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
|
||||
val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
|
||||
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
|
||||
val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
|
||||
val checkpointManager = new ExternalCheckpointManager(
|
||||
checkpointParam.checkpointPath,
|
||||
FileSystem.get(sc.hadoopConfiguration))
|
||||
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
|
||||
checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
|
||||
checkpointManager.loadCheckpointAsScalaBooster()
|
||||
}.orNull
|
||||
|
||||
// Get the training data RDD and the cachedRDD
|
||||
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
|
||||
val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
|
||||
|
||||
try {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
val (booster, metrics) = try {
|
||||
tracker.workerArgs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.workerArgs
|
||||
val (booster, metrics) = withTracker(
|
||||
runtimeParams.numWorkers,
|
||||
runtimeParams.trackerConf
|
||||
) { tracker =>
|
||||
val rabitEnv = tracker.getWorkerArgs()
|
||||
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
|
||||
var optionWatches: Option[() => Watches] = None
|
||||
|
||||
// take the first Watches to train
|
||||
@@ -530,26 +509,25 @@ object XGBoost extends XGBoostStageLevel {
|
||||
optionWatches = Some(iter.next())
|
||||
}
|
||||
|
||||
optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
|
||||
xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
|
||||
.getOrElse(throw new RuntimeException("No Watches to train"))
|
||||
optionWatches.map { buildWatches =>
|
||||
buildDistributedBooster(buildWatches,
|
||||
runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
|
||||
}.getOrElse(throw new RuntimeException("No Watches to train"))
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
|
||||
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
|
||||
boostersAndMetrics)
|
||||
// The repartition step makes the training stage a ShuffleMapStage, so that when one
|
||||
// of the training tasks fails, the training stage can retry. ResultStage won't retry when
|
||||
// it fails.
|
||||
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
|
||||
// we should delete the checkpoint directory after a successful training
|
||||
xgbExecParams.checkpointParam.foreach {
|
||||
runtimeParams.checkpointParam.foreach {
|
||||
cpParam =>
|
||||
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
|
||||
if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
|
||||
val checkpointManager = new ExternalCheckpointManager(
|
||||
cpParam.checkpointPath,
|
||||
FileSystem.get(sc.hadoopConfiguration))
|
||||
|
||||
@@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker. workerArgs
|
||||
val trackerEnvs = tracker.getWorkerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
@@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker.workerArgs
|
||||
val trackerEnvs = tracker.getWorkerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
|
||||
|
||||
Reference in New Issue
Block a user