[jvm-packages] refine tracker (#10313)

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2024-05-23 12:46:21 +08:00
committed by GitHub
parent 966dc81788
commit 932d7201f9
8 changed files with 71 additions and 92 deletions

View File

@@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
private[spark] def buildRabitParams : Map[String, String] = Map(
"rabit_reduce_ring_mincount" ->
overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
"rabit_debug" ->
(overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
"rabit_timeout" ->
(overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
"rabit_timeout_sec" -> {
if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
overridedParams.get("rabit_timeout").toString
} else {
"1800"
}
},
"DMLC_WORKER_CONNECT_RETRY" ->
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
)
}
/**
@@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
}
}
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker: ITracker = new RabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
tracker
}
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker = getTracker(nWorkers, trackerConf)
// Executes the provided code block inside a tracker and then stops the tracker
private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
require(tracker.start(), "FAULT: Failed to start tracker")
tracker
try {
block(tracker)
} finally {
tracker.stop()
}
}
/**
@@ -501,28 +481,27 @@ object XGBoost extends XGBoostStageLevel {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull
// Get the training data RDD and the cachedRDD
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
try {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
val (booster, metrics) = try {
tracker.workerArgs().putAll(xgbRabitParams)
val rabitEnv = tracker.workerArgs
val (booster, metrics) = withTracker(
runtimeParams.numWorkers,
runtimeParams.trackerConf
) { tracker =>
val rabitEnv = tracker.getWorkerArgs()
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
var optionWatches: Option[() => Watches] = None
// take the first Watches to train
@@ -530,26 +509,25 @@ object XGBoost extends XGBoostStageLevel {
optionWatches = Some(iter.next())
}
optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
.getOrElse(throw new RuntimeException("No Watches to train"))
optionWatches.map { buildWatches =>
buildDistributedBooster(buildWatches,
runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
}.getOrElse(throw new RuntimeException("No Watches to train"))
}
}}
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
boostersAndMetrics)
// The repartition step is to make training stage as ShuffleMapStage, so that when one
// of the training task fails the training stage can retry. ResultStage won't retry when
// it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
(booster, metrics)
} finally {
tracker.stop()
}
// we should delete the checkpoint directory after a successful training
xgbExecParams.checkpointParam.foreach {
runtimeParams.checkpointParam.foreach {
cpParam =>
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
val checkpointManager = new ExternalCheckpointManager(
cpParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))

View File

@@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker. workerArgs
val trackerEnvs = tracker.getWorkerArgs
val workerCount: Int = numWorkers
/*
@@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker.workerArgs
val trackerEnvs = tracker.getWorkerArgs
val workerCount: Int = numWorkers