Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhausted tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession
/**
* Rabit tracker configurations.
*
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* @param timeout The number of seconds before timeout waiting for workers to connect. and
* for the tracker to shutdown.
* @param hostIp The Rabit Tracker host IP address.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
* @param port The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L)
def apply(): TrackerConf = TrackerConf(0)
}
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
@@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel {
private def buildDistributedBooster(
buildWatches: () => Watches,
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
rabitEnv: java.util.Map[String, Object],
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
@@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel {
val taskId = TaskContext.getPartitionId().toString
val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
@@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel {
}
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker: ITracker = new RabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
tracker
}
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker = getTracker(nWorkers, trackerConf)
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
require(tracker.start(), "FAULT: Failed to start tracker")
tracker
}
@@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
val (booster, metrics) = try {
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val rabitEnv = tracker.getWorkerEnvs
tracker.workerArgs().putAll(xgbRabitParams)
val rabitEnv = tracker.workerArgs
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
var optionWatches: Option[() => Watches] = None
@@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel {
// of the training task fails the training stage can retry. ResultStage won't retry when
// it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
if (trackerReturnVal != 0) {
throw new XGBoostError("XGBoostModel training failed.")
}
(booster, metrics)
} finally {
tracker.stop()

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params {
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
*
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
* trackerImpl: String)
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
*
* See below for detailed explanations.
*
* - trackerImpl: Select the implementation of Rabit tracker.
* default: "python"
*
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
* The "scala" version removes Python components, and fully supports timeout settings.
*
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
* default: 0 millisecond (no timeout)
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
* default: 0 (no timeout)
*
* Timeout for constructing the communication group and waiting for the tracker to
* shutdown when it's instructed to, doesn't apply to communication when tracking
* is running.
* The timeout value should take the time of data loading and pre-processing into account,
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
* due to potential lazy execution. Alternatively, you may force Spark to
* perform data transformation before calling XGBoost.train(), so that this timeout truly
* reflects the connection delay. Set a reasonable timeout value to prevent model
* training/testing from hanging indefinitely, possible due to network issues.
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
* cannot be automatically guessed.
*
* - port : The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
setDefault(trackerConf, TrackerConf())