Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhausted tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.funsuite.AnyFunSuite
@@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
xgbParamsFactory.buildXGBRuntimeParams
}
test("Customize host ip and python exec for Rabit tracker") {
val hostIp = "192.168.22.111"
val pythonExec = "/usr/bin/python3"
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.contains(hostIp))
assert(cmd.startsWith("python"))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(!cmd.contains(hostIp))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(cmd.contains(s" --host-ip=${hostIp}"))
case _ => assert(false, "expected python tracker implementation")
}
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
*/
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker. workerArgs
val workerCount: Int = numWorkers
/*
@@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
The Java RabitTracker class reacts to exceptions by killing the spawned process running
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
it is killed, the resulted connection failure will trigger the Rabit worker to call
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
running in separate containers. However, as unit tests are run in Spark local mode, in which
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
ultimately kills the entire process, which also happens to host the Spark driver, causing
the entire Spark application to crash.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
process to ensure that pending worker connections are handled.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
@@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
}
test("Communicator allreduce works.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker.workerArgs
val workerCount: Int = numWorkers
rdd.mapPartitions { iter =>
val index = iter.next()
Communicator.init(trackerEnvs)
val a = Array(1.0f, 2.0f, 3.0f)
System.out.println(a.mkString(", "))
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
for (i <- 0 to 2) {
assert(a(i) * workerCount == b(i))
}
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
for (i <- 0 to 2) {
assert(a(i) == c(i))
}
Communicator.shutdown()
Iterator(index)
}.collect()
}
test("should allow the dataframe containing communicator calls to be partially evaluated for" +

View File

@@ -23,7 +23,6 @@ import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
"objective_type" -> "classification")
@@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
intercept[SparkException] {
xgb.fit(trainingDF)
}
}
test("fail training elegantly with unsupported eval metrics") {

View File

@@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
.fit(training)
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
@@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
).fit(training)
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
// check the equality of single instance prediction
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
@@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
assert(math.abs(p1 - p2) < predictionErrorMin)
}
}
test("test rabit timeout fail handle") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Communicator.mockList = Array("0,8,0,0").toList.asJava
intercept[SparkException] {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"rabit_timeout" -> 0))
.fit(training)
}
Communicator.mockList = Array.empty.toList.asJava
}
}