Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -22,7 +22,7 @@ import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession
|
||||
/**
|
||||
* Rabit tracker configurations.
|
||||
*
|
||||
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
|
||||
* Set timeout length to zero to disable timeout.
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (in milliseconds)
|
||||
* (supported by "scala" implementation only.)
|
||||
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
|
||||
* @param timeout The number of seconds before timeout waiting for workers to connect and
|
||||
* for the tracker to shutdown.
|
||||
* @param hostIp The Rabit Tracker host IP address.
|
||||
* This is only needed if the host IP cannot be automatically guessed.
|
||||
* @param pythonExec The python executed path for Rabit Tracker,
|
||||
* which is only used for python implementation.
|
||||
* @param port The port number for the tracker to listen to. Use a system allocated one by
|
||||
* default.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Long,
|
||||
hostIp: String = "", pythonExec: String = "")
|
||||
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
def apply(): TrackerConf = TrackerConf(0)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
@@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel {
|
||||
private def buildDistributedBooster(
|
||||
buildWatches: () => Watches,
|
||||
xgbExecutionParam: XGBoostExecutionParams,
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
rabitEnv: java.util.Map[String, Object],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||
@@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel {
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
val attempt = TaskContext.get().attemptNumber.toString
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||
val numRounds = xgbExecutionParam.numRounds
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
|
||||
@@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel {
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = new PyRabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
|
||||
)
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker: ITracker = new RabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
|
||||
tracker
|
||||
}
|
||||
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker = getTracker(nWorkers, trackerConf)
|
||||
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
}
|
||||
|
||||
@@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
val (booster, metrics) = try {
|
||||
tracker.getWorkerEnvs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
tracker.workerArgs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.workerArgs
|
||||
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
|
||||
var optionWatches: Option[() => Watches] = None
|
||||
@@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel {
|
||||
// of the training task fails the training stage can retry. ResultStage won't retry when
|
||||
// it fails.
|
||||
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
if (trackerReturnVal != 0) {
|
||||
throw new XGBoostError("XGBoostModel training failed.")
|
||||
}
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params {
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
* TrackerConf class, which has the following definition:
|
||||
*
|
||||
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
|
||||
* trackerImpl: String)
|
||||
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
|
||||
*
|
||||
* See below for detailed explanations.
|
||||
*
|
||||
* - trackerImpl: Select the implementation of Rabit tracker.
|
||||
* default: "python"
|
||||
*
|
||||
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
|
||||
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
|
||||
* The "scala" version removes Python components, and fully supports timeout settings.
|
||||
*
|
||||
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
|
||||
* default: 0 millisecond (no timeout)
|
||||
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
|
||||
* default: 0 (no timeout)
|
||||
*
|
||||
* Timeout for constructing the communication group and waiting for the tracker to
|
||||
* shutdown when it's instructed to, doesn't apply to communication when tracking
|
||||
* is running.
|
||||
* The timeout value should take the time of data loading and pre-processing into account,
|
||||
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
|
||||
* due to potential lazy execution. Alternatively, you may force Spark to
|
||||
* perform data transformation before calling XGBoost.train(), so that this timeout truly
|
||||
* reflects the connection delay. Set a reasonable timeout value to prevent model
|
||||
* training/testing from hanging indefinitely, possibly due to network issues.
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*
|
||||
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
|
||||
* cannot be automatically guessed.
|
||||
*
|
||||
* - port : The port number for the tracker to listen to. Use a system allocated one by
|
||||
* default.
|
||||
*/
|
||||
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
||||
setDefault(trackerConf, TrackerConf())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
@@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
xgbParamsFactory.buildXGBRuntimeParams
|
||||
}
|
||||
|
||||
test("Customize host ip and python exec for Rabit tracker") {
|
||||
val hostIp = "192.168.22.111"
|
||||
val pythonExec = "/usr/bin/python3"
|
||||
|
||||
val paramMap = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp))
|
||||
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
tracker match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.contains(hostIp))
|
||||
assert(cmd.startsWith("python"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap1 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
|
||||
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||
tracker1 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(!cmd.contains(hostIp))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap2 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
|
||||
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||
tracker2 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(cmd.contains(s" --host-ip=${hostIp}"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
}
|
||||
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
@@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
*/
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker.workerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
@@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Communicator.init(trackerEnvs)
|
||||
@@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
}
|
||||
|
||||
test("Communicator allreduce works.") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker.workerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
|
||||
rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
Communicator.init(trackerEnvs)
|
||||
val a = Array(1.0f, 2.0f, 3.0f)
|
||||
System.out.println(a.mkString(", "))
|
||||
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
|
||||
for (i <- 0 to 2) {
|
||||
assert(a(i) * workerCount == b(i))
|
||||
}
|
||||
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
|
||||
for (i <- 0 to 2) {
|
||||
assert(a(i) == c(i))
|
||||
}
|
||||
Communicator.shutdown()
|
||||
Iterator(index)
|
||||
}.collect()
|
||||
}
|
||||
|
||||
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
|
||||
|
||||
@@ -23,7 +23,6 @@ import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
|
||||
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
|
||||
"objective_type" -> "classification")
|
||||
@@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
intercept[SparkException] {
|
||||
xgb.fit(trainingDF)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
test("fail training elegantly with unsupported eval metrics") {
|
||||
|
||||
@@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
|
||||
.fit(training)
|
||||
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
@@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
|
||||
).fit(training)
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
// check the equality of single instance prediction
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
@@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
assert(math.abs(p1 - p2) < predictionErrorMin)
|
||||
}
|
||||
}
|
||||
|
||||
test("test rabit timeout fail handle") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
// mock rank 0 failure during 8th allreduce synchronization
|
||||
Communicator.mockList = Array("0,8,0,0").toList.asJava
|
||||
|
||||
intercept[SparkException] {
|
||||
new XGBoostClassifier(Map(
|
||||
"eta" -> "0.1",
|
||||
"max_depth" -> "10",
|
||||
"verbosity" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5,
|
||||
"num_workers" -> numWorkers,
|
||||
"rabit_timeout" -> 0))
|
||||
.fit(training)
|
||||
}
|
||||
|
||||
Communicator.mockList = Array.empty.toList.asJava
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user