[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
Authored by Rong Ou on 2022-10-05 15:39:01 -07:00; committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions
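
Before the per-file diffs, a minimal migration sketch for Python users (tracker address and file names are placeholders; a real run needs a tracker or federated server listening at the configured address). It shows the core of the breaking change visible throughout this commit: the encoded "key=value" list taken by xgb.rabit.RabitContext is replaced by keyword arguments to xgb.collective.CommunicatorContext, and tracker printing moves from xgb.rabit.tracker_print to xgb.collective.communicator_print.

```python
import xgboost as xgb

# Hypothetical tracker settings; in a real job these come from the launcher.
env = {"DMLC_TRACKER_URI": "127.0.0.1", "DMLC_TRACKER_PORT": 9091, "DMLC_TASK_ID": "0"}

# Old: xgb.rabit.RabitContext([f"{k}={v}".encode() for k, v in env.items()])
# New: plain keyword arguments.
with xgb.collective.CommunicatorContext(**env):
    dtrain = xgb.DMatrix("agaricus.txt.train")  # placeholder training file
    booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)
    xgb.collective.communicator_print("Finished training\n")
```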

View File

@ -52,15 +52,15 @@ class XGBoostTrainer(Executor):
def _do_training(self, fl_ctx: FLContext):
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
rank = int(client_name.split('-')[1]) - 1
rabit_env = [
f'federated_server_address={self._server_address}',
f'federated_world_size={self._world_size}',
f'federated_rank={rank}',
f'federated_server_cert={self._server_cert_path}',
f'federated_client_key={self._client_key_path}',
f'federated_client_cert={self._client_cert_path}'
]
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
communicator_env = {
'federated_server_address': self._server_address,
'federated_world_size': self._world_size,
'federated_rank': rank,
'federated_server_cert': self._server_cert_path,
'federated_client_key': self._client_key_path,
'federated_client_cert': self._client_cert_path
}
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train')
dtest = xgb.DMatrix('agaricus.txt.test')
@ -86,4 +86,4 @@ class XGBoostTrainer(Executor):
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
run_dir = workspace.get_run_dir(run_number)
bst.save_model(os.path.join(run_dir, "test.model.json"))
xgb.rabit.tracker_print("Finished training\n")
xgb.collective.communicator_print("Finished training\n")

View File

@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
import org.apache.commons.logging.LogFactory
@ -46,7 +46,7 @@ object XGBoost {
collector: Collector[XGBoostModel]): Unit = {
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
logger.info("start with env" + workerEnvs.toString)
Rabit.init(workerEnvs)
Communicator.init(workerEnvs)
val mapper = (x: LabeledVector) => {
val (index, value) = x.vector.toSeq.unzip
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
@ -59,7 +59,7 @@ object XGBoost {
.map(_.toString.toInt).getOrElse(0)
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
earlyStoppingRound = numEarlyStoppingRounds)
Rabit.shutdown()
Communicator.shutdown()
collector.collect(new XGBoostModel(booster))
}
}

View File

@ -22,7 +22,7 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
if (batchCnt == 0) {
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
Communicator.init(rabitEnv.asJava)
}
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
override def next(): Row = {
val ret = batchIterImpl.next()
if (!batchIterImpl.hasNext) {
Rabit.shutdown()
Communicator.shutdown()
}
ret
}

View File

@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
@ -303,7 +303,7 @@ object XGBoost extends Serializable {
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try {
Rabit.init(rabitEnv)
Communicator.init(rabitEnv)
watches = buildWatchesAndCheck(buildWatches)
@ -342,7 +342,7 @@ object XGBoost extends Serializable {
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
throw xgbException
} finally {
Rabit.shutdown()
Communicator.shutdown()
if (watches != null) watches.delete()
}
}

View File

@ -1,277 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.{FunSuite}
class RabitRobustnessSuite extends FunSuite with PerTest {
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
val classifier = new XGBoostClassifier(paramMap)
val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
xgbParamsFactory.buildXGBRuntimeParams
}
test("Customize host ip and python exec for Rabit tracker") {
val hostIp = "192.168.22.111"
val pythonExec = "/usr/bin/python3"
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.contains(hostIp))
assert(cmd.startsWith("python"))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(!cmd.contains(hostIp))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(cmd.contains(s" --host-ip=${hostIp}"))
case _ => assert(false, "expected python tracker implementation")
}
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val arr = iter.toArray
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
Rabit.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
that should be avoided.
*/
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
/*
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
last created worker. A cascading event chain will be triggered once the RuntimeException is
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
The Java RabitTracker class reacts to exceptions by killing the spawned process running
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
it is killed, the resulted connection failure will trigger the Rabit worker to call
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
running in separate containers. However, as unit tests are run in Spark local mode, in which
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
ultimately kills the entire process, which also happens to host the Spark driver, causing
the entire Spark application to crash.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
process to ensure that pending worker connections are handled.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Rabit.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Rabit.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Rabit.init(trackerEnvs)
Thread.sleep(1000)
Rabit.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("should allow the dataframe containing rabit calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction = model.transform(trainingDF)
// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
// threads
prediction.show()
// a full evaluation here will re-run init and shutdown all rabit proxy
// expecting no error
prediction.collect()
}
}

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._
@ -25,7 +25,7 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkException
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
val predictionErrorMin = 0.00001f
val maxFailure = 2;
@ -47,8 +47,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
.fit(training)
assert(Rabit.rabitEnvs.asScala.size > 3)
Rabit.rabitEnvs.asScala.foreach( item => {
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
@ -70,8 +70,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
).fit(training)
assert(Rabit.rabitEnvs.asScala.size > 3)
Rabit.rabitEnvs.asScala.foreach( item => {
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
// check the equality of single instance prediction
@ -85,7 +85,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
test("test rabit timeout fail handle") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Rabit.mockList = Array("0,8,0,0").toList.asJava
Communicator.mockList = Array("0,8,0,0").toList.asJava
intercept[SparkException] {
new XGBoostClassifier(Map(
@ -98,6 +98,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
"rabit_timeout" -> 0))
.fit(training)
}
Communicator.mockList = Array.empty.toList.asJava
}
}

View File

@ -1,154 +0,0 @@
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
public enum OpType implements Serializable {
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
private int op;
public int getOperand() {
return this.op;
}
OpType(int op) {
this.op = op;
}
}
public enum DataType implements Serializable {
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
LONGLONG(8, 8), ULONGLONG(9, 8);
private int enumOp;
private int size;
public int getEnumOp() {
return this.enumOp;
}
public int getSize() {
return this.size;
}
DataType(int enumOp, int size) {
this.enumOp = enumOp;
this.size = size;
}
}
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}
}
// used as way to test/debug passed rabit init parameters
public static Map<String, String> rabitEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
rabitEnvs = envs;
String[] args = new String[envs.size() + mockList.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
// pass list of rabit mock strings eg mock=0,1,0,0
for(String mock : mockList) {
args[idx++] = "mock=" + mock;
}
checkCall(XGBoostJNI.RabitInit(args));
}
/**
* Shutdown the rabit engine in current working thread, equals to finalize.
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XGBoostJNI.RabitFinalize());
}
/**
* Print the message on rabit tracker.
* @param msg
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
}
/**
* Get version number of current stored model in the thread.
* which means how many calls to CheckPoint we made so far.
* @return version Number.
* @throws XGBoostError
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XGBoostJNI.RabitVersionNumber(out));
return out[0];
}
/**
* get rank of current thread.
* @return the rank.
* @throws XGBoostError
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XGBoostJNI.RabitGetRank(out));
return out[0];
}
/**
* get world size of current job.
* @return the worldsize
* @throws XGBoostError
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XGBoostJNI.RabitGetWorldSize(out));
return out[0];
}
/**
* perform Allreduce on distributed float vectors using operator op.
* This implementation of allReduce does not support customized prepare function callback in the
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
*
* @param elements local elements on distributed workers.
* @param op operator used for Allreduce.
* @return All-reduced float elements according to the given operator.
*/
public static float[] allReduce(float[] elements, OpType op) {
DataType dataType = DataType.FLOAT;
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
.order(ByteOrder.nativeOrder());
for (float el : elements) {
buffer.putFloat(el);
}
buffer.flip();
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
float[] results = new float[elements.length];
buffer.asFloatBuffer().get(results);
return results;
}
}

View File

@ -254,16 +254,16 @@ public class XGBoost {
}
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
if (shouldPrint(params, iter)) {
Rabit.trackerPrint(String.format(
Communicator.communicatorPrint(String.format(
"early stopping after %d rounds away from the best iteration",
earlyStoppingRounds
));
}
break;
}
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
if (shouldPrint(params, iter)){
Rabit.trackerPrint(evalInfo + '\n');
Communicator.communicatorPrint(evalInfo + '\n');
}
}
}

View File

@ -135,19 +135,6 @@ class XGBoostJNI {
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
// rabit functions
public final static native int RabitInit(String[] args);
public final static native int RabitFinalize();
public final static native int RabitTrackerPrint(String msg);
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
// Perform Allreduce operation on data in sendrecvbuf.
// This JNI function does not support the callback function for data preparation yet.
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize();

View File

@ -872,111 +872,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
std::vector<std::string> args;
std::vector<char*> argv;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
const char *s = jenv->GetStringUTFChars(arg, 0);
args.push_back(std::string(s, jenv->GetStringLength(arg)));
if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
if (args.back().length() == 0) args.pop_back();
}
for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]);
}
if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
return 0;
} else {
return 1;
}
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
if (RabitFinalize()) {
return 0;
} else {
return 1;
}
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
JVM_CHECK_CALL(RabitTrackerPrint(str.c_str()));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
jint rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
jint out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) {
jint out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
JVM_CHECK_CALL(RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit

View File

@ -279,62 +279,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *, jclass, jobject, jint, jint, jint);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit

View File

@ -300,7 +300,7 @@ public class DMatrixTest {
public void testTrainWithDenseMatrixRef() throws XGBoostError {
Map<String, String> rabitEnv = new HashMap<>();
rabitEnv.put("DMLC_TASK_ID", "0");
Rabit.init(rabitEnv);
Communicator.init(rabitEnv);
DMatrix trainMat = null;
BigDenseMatrix data0 = null;
try {
@ -348,7 +348,7 @@ public class DMatrixTest {
else if (data0 != null) {
data0.dispose();
}
Rabit.shutdown();
Communicator.shutdown();
}
}

View File

@ -23,6 +23,6 @@ target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
target_sources(objxgboost PRIVATE federated_server.cc)
target_link_libraries(objxgboost PRIVATE federated_client)
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -1,197 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <cstdio>
#include <fstream>
#include <sstream>
#include "federated_client.h"
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"
namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
size_t type_size;
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
} // namespace MPI
namespace rabit {
namespace engine {
/*! \brief implementation of engine using federated learning */
class FederatedEngine : public IEngine {
public:
void Init(int argc, char *argv[]) {
// Parse environment variables first.
for (auto const &env_var : env_vars_) {
char const *value = getenv(env_var.c_str());
if (value != nullptr) {
SetParam(env_var, value);
}
}
// Command line argument overrides.
for (int i = 0; i < argc; ++i) {
std::string const key_value = argv[i];
auto const delimiter = key_value.find('=');
if (delimiter != std::string::npos) {
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
}
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
utils::Printf("Certificates not specified, turning off SSL.");
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
} else {
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}
}
void Finalize() { client_.reset(); }
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
size_t size_prev_slice) override {
throw std::logic_error("FederatedEngine:: Allgather is not supported");
}
std::string Allgather(void *sendbuf, size_t total_size) {
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
return client_->Allgather(send_buffer);
}
void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
PreprocFunction prepare_fun, void *prepare_arg) override {
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
}
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
static_cast<xgboost::federated::ReduceOperation>(op));
receive_buffer.copy(buffer, size);
}
int GetRingPrevRank() const override {
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
}
void Broadcast(void *sendrecvbuf, size_t size, int root) override {
if (world_size_ == 1) return;
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Broadcast(send_buffer, root);
if (rank_ != root) {
receive_buffer.copy(buffer, size);
}
}
int LoadCheckPoint() override { return 0; }
void CheckPoint() override { version_number_ += 1; }
int VersionNumber() const override { return version_number_; }
/*! \brief get rank of current node */
int GetRank() const override { return rank_; }
/*! \brief get total number of */
int GetWorldSize() const override { return world_size_; }
/*! \brief whether it is distributed */
bool IsDistributed() const override { return true; }
/*! \brief get the host name of current node */
std::string GetHost() const override { return "rank" + std::to_string(rank_); }
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
world_size_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
rank_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
server_cert_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
client_key_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
client_cert_ = ReadFile(val);
}
}
static std::string ReadFile(std::string const &path) {
auto stream = std::ifstream(path.data());
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}
// clang-format off
std::vector<std::string> const env_vars_{
"FEDERATED_SERVER_ADDRESS",
"FEDERATED_WORLD_SIZE",
"FEDERATED_RANK",
"FEDERATED_SERVER_CERT",
"FEDERATED_CLIENT_KEY",
"FEDERATED_CLIENT_CERT" };
// clang-format on
std::string server_address_{"localhost:9091"};
int world_size_{1};
int rank_{0};
std::string server_cert_{};
std::string client_key_{};
std::string client_cert_{};
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
int version_number_{0};
};
// Singleton federated engine.
FederatedEngine engine; // NOLINT(cert-err58-cpp)
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
engine.Init(argc, argv);
return true;
} catch (std::exception const &e) {
fprintf(stderr, " failed in federated Init %s\n", e.what());
return false;
}
}
/*! \brief finalize synchronization module */
bool Finalize() {
try {
engine.Finalize();
return true;
} catch (const std::exception &e) {
fprintf(stderr, "failed in federated shutdown %s\n", e.what());
return false;
}
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() { return &engine; }
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
}
} // namespace engine
} // namespace rabit

View File

@ -11,6 +11,40 @@
namespace xgboost {
namespace collective {
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
std::size_t size{0};
switch (data_type) {
case DataType::kInt8:
size = sizeof(std::int8_t);
break;
case DataType::kUInt8:
size = sizeof(std::uint8_t);
break;
case DataType::kInt32:
size = sizeof(std::int32_t);
break;
case DataType::kUInt32:
size = sizeof(std::uint32_t);
break;
case DataType::kInt64:
size = sizeof(std::int64_t);
break;
case DataType::kUInt64:
size = sizeof(std::uint64_t);
break;
case DataType::kFloat:
size = sizeof(float);
break;
case DataType::kDouble:
size = sizeof(double);
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return size;
}
/**
* @brief A Federated Learning communicator class that handles collective communication.
*/

View File

@ -3,9 +3,8 @@
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
"""
from . import rabit # noqa
from . import tracker # noqa
from . import dask
from . import collective, dask
from .core import (
Booster,
DataIter,
@ -63,4 +62,6 @@ __all__ = [
"XGBRFRegressor",
# dask
"dask",
# collective
"collective",
]

View File

@ -13,7 +13,7 @@ import pickle
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast, Sequence, Any
import numpy
from . import rabit
from . import collective
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
@ -100,7 +100,7 @@ def _allreduce_metric(score: _ART) -> _ART:
as final result.
'''
world = rabit.get_world_size()
world = collective.get_world_size()
assert world != 0
if world == 1:
return score
@ -108,7 +108,7 @@ def _allreduce_metric(score: _ART) -> _ART:
raise ValueError(
'xgboost.cv function should not be used in distributed environment.')
arr = numpy.array([score])
arr = rabit.allreduce(arr, rabit.Op.SUM) / world
arr = collective.allreduce(arr, collective.Op.SUM) / world
return arr[0]
@ -485,7 +485,7 @@ class EvaluationMonitor(TrainingCallback):
return False
msg: str = f'[{epoch}]'
if rabit.get_rank() == self.printer_rank:
if collective.get_rank() == self.printer_rank:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
stdv: Optional[float] = None
@ -498,7 +498,7 @@ class EvaluationMonitor(TrainingCallback):
msg += '\n'
if (epoch % self.period) == 0 or self.period == 1:
rabit.tracker_print(msg)
collective.communicator_print(msg)
self._latest = None
else:
# There is skipped message
@ -506,8 +506,8 @@ class EvaluationMonitor(TrainingCallback):
return False
def after_training(self, model: _Model) -> _Model:
if rabit.get_rank() == self.printer_rank and self._latest is not None:
rabit.tracker_print(self._latest)
if collective.get_rank() == self.printer_rank and self._latest is not None:
collective.communicator_print(self._latest)
return model
@ -552,7 +552,7 @@ class TrainingCheckPoint(TrainingCallback):
path = os.path.join(self._path, self._name + '_' + str(epoch) +
('.pkl' if self._as_pickle else '.json'))
self._epoch = 0
if rabit.get_rank() == 0:
if collective.get_rank() == 0:
if self._as_pickle:
with open(path, 'wb') as fd:
pickle.dump(model, fd)
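
The callback changes above replace rabit with the collective module for cross-worker metric averaging. A standalone sketch of the same pattern (it degrades to a no-op when the world size is 1, i.e. outside a distributed job):

```python
import numpy as np
from xgboost import collective

def allreduce_metric(score: float) -> float:
    """Average a scalar metric over all workers, mirroring _allreduce_metric above."""
    world = collective.get_world_size()
    if world <= 1:
        return score
    arr = np.array([score])
    arr = collective.allreduce(arr, collective.Op.SUM) / world
    return float(arr[0])

print(allreduce_metric(0.25))  # prints 0.25 when run outside a distributed job
```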

View File

@ -4,7 +4,7 @@ import json
import logging
import pickle
from enum import IntEnum, unique
from typing import Any, List
from typing import Any, List, Dict
import numpy as np
@ -233,10 +233,11 @@ class CommunicatorContext:
def __init__(self, **args: Any) -> None:
self.args = args
def __enter__(self) -> None:
def __enter__(self) -> Dict[str, Any]:
init(**self.args)
assert is_distributed()
LOGGER.debug("-------------- communicator say hello ------------------")
return self.args
def __exit__(self, *args: List) -> None:
finalize()
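
Note the __enter__ change above ("return dict from communicator context" in the commit message): the context manager now yields its argument dict, so callers can bind it with `as`. A small sketch using the same placeholder tracker settings as earlier:

```python
from xgboost.collective import CommunicatorContext

env = {"DMLC_TRACKER_URI": "127.0.0.1", "DMLC_TRACKER_PORT": 9091, "DMLC_TASK_ID": "0"}

# Requires a tracker listening at the configured address, since __enter__
# asserts is_distributed() after init.
with CommunicatorContext(**env) as args:
    assert args["DMLC_TASK_ID"] == "0"  # the yielded dict is the init arguments
```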

View File

@ -59,7 +59,7 @@ from typing import (
import numpy
from . import config, rabit
from . import collective, config
from ._typing import _T, FeatureNames, FeatureTypes
from .callback import TrainingCallback
from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
@ -112,7 +112,7 @@ TrainReturnT = TypedDict(
)
__all__ = [
"RabitContext",
"CommunicatorContext",
"DaskDMatrix",
"DaskDeviceQuantileDMatrix",
"DaskXGBRegressor",
@ -158,7 +158,7 @@ def _try_start_tracker(
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
rabit_tracker = RabitTracker(
host_ip=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
@ -168,12 +168,12 @@ def _try_start_tracker(
addr = addrs[0]
assert isinstance(addr, str) or addr is None
host_ip = get_host_ip(addr)
rabit_context = RabitTracker(
rabit_tracker = RabitTracker(
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
env.update(rabit_tracker.worker_envs())
rabit_tracker.start(n_workers)
thread = Thread(target=rabit_tracker.join)
thread.daemon = True
thread.start()
except socket.error as e:
@ -213,11 +213,11 @@ def _assert_dask_support() -> None:
LOGGER.warning(msg)
class RabitContext(rabit.RabitContext):
"""A context controlling rabit initialization and finalization."""
class CommunicatorContext(collective.CommunicatorContext):
"""A context controlling collective communicator initialization and finalization."""
def __init__(self, args: List[bytes]) -> None:
super().__init__(args)
def __init__(self, **args: Any) -> None:
super().__init__(**args)
worker = distributed.get_worker()
with distributed.worker_client() as client:
info = client.scheduler_info()
@ -227,9 +227,7 @@ class RabitContext(rabit.RabitContext):
# not the same as task ID is string and "10" is sorted before "2") with dask
# worker ID. This outsources the rank assignment to dask and prevents
# non-deterministic issue.
self.args.append(
(f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
)
self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address)
def dconcat(value: Sequence[_T]) -> _T:
@ -811,7 +809,7 @@ def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:
async def _get_rabit_args(
n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
) -> List[bytes]:
) -> Dict[str, Union[str, int]]:
"""Get rabit context arguments from data distribution in DaskDMatrix."""
# There are 3 possible different addresses:
# 1. Provided by user via dask.config
@ -854,9 +852,7 @@ async def _get_rabit_args(
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)
rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args
return env
def _get_dask_config() -> Optional[Dict[str, Any]]:
@ -911,7 +907,7 @@ async def _train_async(
def dispatched_train(
parameters: Dict,
rabit_args: List[bytes],
rabit_args: Dict[str, Union[str, int]],
train_id: int,
evals_name: List[str],
evals_id: List[int],
@ -935,7 +931,7 @@ async def _train_async(
n_threads = dwnt
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
local_history: TrainingCallback.EvalsLog = {}
with RabitContext(rabit_args), config.config_context(**global_config):
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
evals: List[Tuple[DMatrix, str]] = []
for i, ref in enumerate(refs):

View File

@ -1,249 +0,0 @@
"""Distributed XGBoost Rabit related API."""
import ctypes
from enum import IntEnum, unique
import logging
import pickle
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
import numpy as np
from .core import _LIB, c_str, _check_call
LOGGER = logging.getLogger("[xgboost.rabit]")
def _init_rabit() -> None:
"""internal library initializer."""
if _LIB is not None:
_LIB.RabitGetRank.restype = ctypes.c_int
_LIB.RabitGetWorldSize.restype = ctypes.c_int
_LIB.RabitIsDistributed.restype = ctypes.c_int
_LIB.RabitVersionNumber.restype = ctypes.c_int
def init(args: Optional[List[bytes]] = None) -> None:
"""Initialize the rabit library with arguments"""
if args is None:
args = []
arr = (ctypes.c_char_p * len(args))()
arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
_LIB.RabitInit(len(arr), arr)
def finalize() -> None:
"""Finalize the process, notify tracker everything is done."""
_LIB.RabitFinalize()
def get_rank() -> int:
"""Get rank of current process.
Returns
-------
rank : int
Rank of current process.
"""
ret = _LIB.RabitGetRank()
return ret
def get_world_size() -> int:
"""Get total number workers.
Returns
-------
n : int
Total number of process.
"""
ret = _LIB.RabitGetWorldSize()
return ret
def is_distributed() -> int:
'''If rabit is distributed.'''
is_dist = _LIB.RabitIsDistributed()
return is_dist
def tracker_print(msg: Any) -> None:
"""Print message to the tracker.
This function can be used to communicate the information of
the progress to the tracker
Parameters
----------
msg : str
The message to be printed to tracker.
"""
if not isinstance(msg, str):
msg = str(msg)
is_dist = _LIB.RabitIsDistributed()
if is_dist != 0:
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
else:
print(msg.strip(), flush=True)
def get_processor_name() -> bytes:
"""Get the processor name.
Returns
-------
name : str
the name of processor(host)
"""
mxlen = 256
length = ctypes.c_ulong()
buf = ctypes.create_string_buffer(mxlen)
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
return buf.value
T = TypeVar("T") # pylint:disable=invalid-name
def broadcast(data: T, root: int) -> T:
"""Broadcast object from one node to all other nodes.
Parameters
----------
data : any type that can be pickled
Input data, if current rank does not equal root, this can be None
root : int
Rank of the node to broadcast data from.
Returns
-------
object : int
the result of broadcast.
"""
rank = get_rank()
length = ctypes.c_ulong()
if root == rank:
assert data is not None, 'need to pass in data when broadcasting'
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
length.value = len(s)
# run first broadcast
_check_call(_LIB.RabitBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong), root))
if root != rank:
dptr = (ctypes.c_char * length.value)()
# run second
_check_call(_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
length.value, root))
data = pickle.loads(dptr.raw)
del dptr
else:
_check_call(_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
length.value, root))
del s
return data
# enumeration of dtypes
DTYPE_ENUM__ = {
np.dtype('int8'): 0,
np.dtype('uint8'): 1,
np.dtype('int32'): 2,
np.dtype('uint32'): 3,
np.dtype('int64'): 4,
np.dtype('uint64'): 5,
np.dtype('float32'): 6,
np.dtype('float64'): 7
}
@unique
class Op(IntEnum):
'''Supported operations for rabit.'''
MAX = 0
MIN = 1
SUM = 2
OR = 3
def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result.
Parameters
----------
data :
Input data.
op :
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun :
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to initialize the data
If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called
Returns
-------
result :
The result of allreduce, have same shape as data
Notes
-----
This function is not thread-safe.
"""
if not isinstance(data, np.ndarray):
raise Exception('allreduce only takes in numpy.ndarray')
buf = data.ravel()
if buf.base is data.base:
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise Exception(f"data type {buf.dtype} not supported")
if prepare_fun is None:
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
int(op), None, None))
else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pfunc(_: Any) -> None:
"""prepare function."""
fn = cast(Callable[[np.ndarray], None], prepare_fun)
fn(data)
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, func_ptr(pfunc), None))
return buf
def version_number() -> int:
"""Returns version number of current stored model.
This means how many calls to CheckPoint we made so far.
Returns
-------
version : int
Version number of currently stored model
"""
ret = _LIB.RabitVersionNumber()
return ret
class RabitContext:
"""A context controlling rabit initialization and finalization."""
def __init__(self, args: List[bytes] = None) -> None:
if args is None:
args = []
self.args = args
def __enter__(self) -> None:
init(self.args)
assert is_distributed()
LOGGER.debug("-------------- rabit say hello ------------------")
def __exit__(self, *args: List) -> None:
finalize()
LOGGER.debug("--------------- rabit say bye ------------------")
# initialization script
_init_rabit()
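
For reference, the deleted xgboost.rabit module above maps onto xgboost.collective as follows; the list is limited to the replacements that actually appear in this diff and is not exhaustive:

```python
from xgboost import collective

# Old (xgboost.rabit)                 ->  New (xgboost.collective)
#   rabit.init([b"k=v", ...])         ->  collective.init(k="v", ...)
#   rabit.finalize()                  ->  collective.finalize()
#   rabit.get_rank()                  ->  collective.get_rank()
#   rabit.get_world_size()            ->  collective.get_world_size()
#   rabit.is_distributed()            ->  collective.is_distributed()
#   rabit.tracker_print(msg)          ->  collective.communicator_print(msg)
#   rabit.allreduce(arr, rabit.Op.X)  ->  collective.allreduce(arr, collective.Op.X)
#   rabit.RabitContext(arg_list)      ->  collective.CommunicatorContext(**kwargs)

print(collective.get_rank(), collective.get_world_size())
```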

View File

@ -2,6 +2,7 @@
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines
import json
from typing import Iterator, Optional, Tuple
import numpy as np
@ -57,7 +58,7 @@ from .params import (
HasQueryIdCol,
)
from .utils import (
RabitContext,
CommunicatorContext,
_get_args_from_message_list,
_get_default_params_from_func,
_get_gpu_id,
@ -769,7 +770,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
_rabit_args = ""
_rabit_args = {}
if context.partitionId() == 0:
get_logger("XGBoostPySpark").info(
"booster params: %s\n"
@ -780,12 +781,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dmatrix_kwargs,
)
_rabit_args = str(_get_rabit_args(context, num_workers))
_rabit_args = _get_rabit_args(context, num_workers)
messages = context.allGather(message=str(_rabit_args))
messages = context.allGather(message=json.dumps(_rabit_args))
_rabit_args = _get_args_from_message_list(messages)
evals_result = {}
with RabitContext(_rabit_args, context):
with CommunicatorContext(context, **_rabit_args):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter,
features_cols_names,

View File

@ -1,6 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for helper functions."""
import inspect
import json
import logging
import sys
from threading import Thread
@ -9,7 +10,7 @@ import pyspark
from pyspark.sql.session import SparkSession
from xgboost.tracker import RabitTracker
from xgboost import rabit
from xgboost import collective
def get_class_name(cls):
@ -36,21 +37,21 @@ def _get_default_params_from_func(func, unsupported_set):
return filtered_params_dict
class RabitContext:
class CommunicatorContext:
"""
A context controlling rabit initialization and finalization.
A context controlling collective communicator initialization and finalization.
This isn't specifically necessary (note Part 3), but it is more understandable coding-wise.
"""
def __init__(self, args, context):
def __init__(self, context, **args):
self.args = args
self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode())
self.args["DMLC_TASK_ID"] = str(context.partitionId())
def __enter__(self):
rabit.init(self.args)
collective.init(**self.args)
def __exit__(self, *args):
rabit.finalize()
collective.finalize()
def _start_tracker(context, n_workers):
@ -74,8 +75,7 @@ def _get_rabit_args(context, n_workers):
"""
# pylint: disable=consider-using-f-string
env = _start_tracker(context, n_workers)
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
return rabit_args
return env
def _get_host_ip(context):
@ -95,7 +95,7 @@ def _get_args_from_message_list(messages):
if message != "":
output = message
break
return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")]
return json.loads(output)
def _get_spark_session():
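
The PySpark helper changes above swap the repr-and-split exchange of encoded byte strings for plain JSON: partition 0 starts the tracker, serializes its environment with json.dumps for allGather, and every partition decodes the first non-empty message with json.loads before splatting it into CommunicatorContext(context, **args). A tiny sketch of that round trip with placeholder values:

```python
import json

# What partition 0 would publish after starting the tracker (placeholder values).
env = {"DMLC_TRACKER_URI": "10.0.0.1", "DMLC_TRACKER_PORT": 9091}
message = json.dumps(env)

# Every other partition decodes the message back into a dict of keyword arguments.
args = json.loads(message)
assert args == env
```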

View File

@ -6,9 +6,7 @@ set(RABIT_SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)
if (PLUGIN_FEDERATED)
# Skip the engine if the Federated Learning plugin is enabled.
elseif (RABIT_BUILD_MPI)
if (RABIT_BUILD_MPI)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
elseif (RABIT_MOCK)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)

View File

@ -1,11 +1,8 @@
// Copyright (c) 2014-2022 by Contributors
#include <rabit/rabit.h>
#include <rabit/c_api.h>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <algorithm>
#include <vector>
#include <string>
#include <memory>
@ -22,12 +19,11 @@
#include "c_api_error.h"
#include "c_api_utils.h"
#include "../collective/communicator.h"
#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
@ -215,7 +211,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
#if defined(XGBOOST_USE_FEDERATED)
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
load_row_split = true;
@ -1560,44 +1556,42 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
API_END();
}
using xgboost::collective::Communicator;
XGB_DLL int XGCommunicatorInit(char const* json_config) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(json_config);
Json config { Json::Load(StringView{json_config}) };
Communicator::Init(config);
Json config{Json::Load(StringView{json_config})};
collective::Init(config);
API_END();
}
XGB_DLL int XGCommunicatorFinalize() {
API_BEGIN();
Communicator::Finalize();
collective::Finalize();
API_END();
}
XGB_DLL int XGCommunicatorGetRank() {
return Communicator::Get()->GetRank();
XGB_DLL int XGCommunicatorGetRank(void) {
return collective::GetRank();
}
XGB_DLL int XGCommunicatorGetWorldSize() {
return Communicator::Get()->GetWorldSize();
XGB_DLL int XGCommunicatorGetWorldSize(void) {
return collective::GetWorldSize();
}
XGB_DLL int XGCommunicatorIsDistributed() {
return Communicator::Get()->IsDistributed();
XGB_DLL int XGCommunicatorIsDistributed(void) {
return collective::IsDistributed();
}
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
Communicator::Get()->Print(message);
collective::Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
local.ret_str = Communicator::Get()->GetProcessorName();
local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
@ -1605,16 +1599,14 @@ XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
collective::Broadcast(send_receive_buffer, size, root);
API_END();
}
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
Communicator::Get()->AllReduce(
send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
static_cast<xgboost::collective::Operation>(enum_op));
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}
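
At the C level, the diff above shows that configuration is now a single JSON document handed to XGCommunicatorInit. A sketch of driving these entry points directly from Python through xgboost's already-loaded shared library, using only symbols visible in this diff; the rabit communicator with no tracker settings is assumed, which runs single-process:

```python
import json
from xgboost.core import _LIB, _check_call, c_str

# One JSON document configures the communicator; "rabit" is the default type.
config = json.dumps({"xgboost_communicator": "rabit"})
_check_call(_LIB.XGCommunicatorInit(c_str(config)))
print(_LIB.XGCommunicatorGetRank(), _LIB.XGCommunicatorGetWorldSize())
_check_call(_LIB.XGCommunicatorFinalize())
```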

View File

@ -25,6 +25,7 @@
#include <cstdio>
#include <cstring>
#include <vector>
#include "collective/communicator-inl.h"
#include "common/common.h"
#include "common/config.h"
#include "common/io.h"
@ -156,7 +157,7 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
if (name_pred == "stdout") {
save_period = 0;
}
if (dsplit == 0 && rabit::IsDistributed()) {
if (dsplit == 0 && collective::IsDistributed()) {
dsplit = 2;
}
}
@ -186,26 +187,22 @@ class CLI {
kHelp
} print_info_ {kNone};
int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
void ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
learner_.reset(Learner::Create(matrices));
int version = rabit::LoadCheckPoint();
if (version == 0) {
if (param_.model_in != CLIParam::kNull) {
this->LoadModel(param_.model_in, learner_.get());
learner_->SetParams(param_.cfg);
} else {
learner_->SetParams(param_.cfg);
}
}
learner_->Configure();
return version;
}
void CLITrain() {
const double tstart_data_load = dmlc::GetTime();
if (rabit::IsDistributed()) {
std::string pname = rabit::GetProcessorName();
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
if (collective::IsDistributed()) {
std::string pname = collective::GetProcessorName();
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
}
// load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
@ -230,48 +227,45 @@ class CLI {
eval_data_names.emplace_back("train");
}
// initialize the learner.
int32_t version = this->ResetLearner(cache_mats);
this->ResetLearner(cache_mats);
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
<< " sec";
// start training.
const double start = dmlc::GetTime();
int32_t version = 0;
for (int i = version / 2; i < param_.num_round; ++i) {
double elapsed = dmlc::GetTime() - start;
if (version % 2 == 0) {
LOG(INFO) << "boosting round " << i << ", " << elapsed
<< " sec elapsed";
learner_->UpdateOneIter(i, dtrain);
rabit::CheckPoint();
version += 1;
}
CHECK_EQ(version, rabit::VersionNumber());
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
if (rabit::IsDistributed()) {
if (rabit::GetRank() == 0) {
if (collective::IsDistributed()) {
if (collective::GetRank() == 0) {
LOG(TRACKER) << res;
}
} else {
LOG(CONSOLE) << res;
}
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
rabit::GetRank() == 0) {
collective::GetRank() == 0) {
std::ostringstream os;
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
<< i + 1 << ".model";
this->SaveModel(os.str(), learner_.get());
}
rabit::CheckPoint();
version += 1;
CHECK_EQ(version, rabit::VersionNumber());
}
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
<< " sec";
// always save final round
if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) &&
rabit::GetRank() == 0) {
collective::GetRank() == 0) {
std::ostringstream os;
if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
@ -467,7 +461,6 @@ class CLI {
return;
}
rabit::Init(argc, argv);
std::string config_path = argv[1];
common::ConfigParser cp(config_path);
@ -480,6 +473,13 @@ class CLI {
}
}
// Initialize the collective communicator.
Json json{JsonObject()};
for (auto& kv : cfg) {
json[kv.first] = String(kv.second);
}
collective::Init(json);
param_.Configure(cfg);
}
@ -517,7 +517,7 @@ class CLI {
}
~CLI() {
rabit::Finalize();
collective::Finalize();
}
};
} // namespace xgboost

View File

@ -0,0 +1,208 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/*!
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental; function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* Additional configuration is not required; the communicator usually detects settings from
* environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const& config) {
Communicator::Init(config);
}
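For reference, a minimal usage sketch of this Init entry point, following the same string-valued JSON idiom the CLI change above uses; the communicator type, tracker host, and port below are placeholders rather than values taken from this change:

    Json config{JsonObject()};
    config["xgboost_communicator"] = String("rabit");
    config["rabit_tracker_uri"] = String("127.0.0.1");  // placeholder tracker host
    config["rabit_tracker_port"] = String("9091");      // placeholder tracker port
    collective::Init(config);
    // ... distributed training work ...
    collective::Finalize();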
/*!
* \brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
*/
inline void Finalize() { Communicator::Finalize(); }
/*!
* \brief Get rank of current process.
*
* \return Rank of the worker.
*/
inline int GetRank() { return Communicator::Get()->GetRank(); }
/*!
* \brief Get total number of processes.
*
* \return Total world size.
*/
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
/*!
* \brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
*/
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
/*!
* \brief Print the message to the communicator.
*
* This function can be used to communicate progress information to the user who monitors the
* communicator.
*
* \param message The message to be printed.
*/
inline void Print(char const *message) { Communicator::Get()->Print(message); }
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
/*!
* \brief Get the name of the processor.
*
* \return Name of the processor.
*/
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
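A brief sketch of the string overload above: rank 0 fills the buffer, the length is broadcast first and then the bytes, so every worker ends up holding identical data (the payload contents here are illustrative only):

    std::string payload;
    if (collective::GetRank() == 0) {
      payload = "serialized state produced on rank 0";  // illustrative content
    }
    collective::Broadcast(&payload, /*root=*/0);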
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*/
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
static_cast<Operation>(op));
}
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}
template <Operation op>
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
// Specialization for size_t, which is implementation defined, so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <Operation op, typename T,
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
inline void Allreduce(T *send_receive_buffer, size_t count) {
static_assert(sizeof(T) == sizeof(uint64_t), "");
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void Allreduce(float *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
} // namespace collective
} // namespace xgboost
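The typed Allreduce overloads above let call sites pass raw element pointers without naming the DataType enum. A short sketch mirroring how the quantile and learner code in this change invokes them (the values are illustrative):

    std::vector<double> hist(16, 1.0);
    collective::Allreduce<collective::Operation::kSum>(hist.data(), hist.size());

    size_t n_columns = 8;  // size_t routes to the dedicated specialization when it is not uint64_t
    collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);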

View File

@ -3,6 +3,7 @@
*/
#include "communicator.h"
#include "noop_communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
@ -12,14 +13,10 @@
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
void Communicator::Init(Json const& config) {
if (communicator_) {
LOG(FATAL) << "Communicator can only be initialized once.";
}
auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {
@ -51,7 +48,7 @@ void Communicator::Init(Json const& config) {
#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
communicator_.reset(new NoOpCommunicator());
}
#endif

View File

@ -4,6 +4,7 @@
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
communicator_.reset(new NoOpCommunicator());
device_ordinal_ = -1;
device_communicator_.reset(nullptr);
}

View File

@ -23,40 +23,6 @@ enum class DataType {
kDouble = 7
};
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
std::size_t size{0};
switch (data_type) {
case DataType::kInt8:
size = sizeof(std::int8_t);
break;
case DataType::kUInt8:
size = sizeof(std::uint8_t);
break;
case DataType::kInt32:
size = sizeof(std::int32_t);
break;
case DataType::kUInt32:
size = sizeof(std::uint32_t);
break;
case DataType::kInt64:
size = sizeof(std::int64_t);
break;
case DataType::kUInt64:
size = sizeof(std::uint64_t);
break;
case DataType::kFloat:
size = sizeof(float);
break;
case DataType::kDouble:
size = sizeof(double);
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return size;
}
/** @brief Defines the reduction operation. */
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };

View File

@ -21,7 +21,28 @@ class DeviceCommunicator {
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
/**
* @brief Gather variable-length values from all processes.

View File

@ -23,17 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
~DeviceCommunicatorAdapter() override = default;
void AllReduceSum(double *send_receive_buffer, int count) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * sizeof(double);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
}
void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
}
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
}
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
@ -66,6 +77,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}
private:
template <collective::DataType data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * sizeof(T);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}
int const device_ordinal_;
Communicator *communicator_;
/// Host buffer used to call communicator functions.

View File

@ -24,6 +24,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
@ -52,8 +56,15 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
~NcclDeviceCommunicator() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
if (cuda_stream_) {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
ncclCommDestroy(nccl_comm_);
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
@ -61,16 +72,28 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
}
void AllReduceSum(double *send_receive_buffer, int count) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
ncclSum, nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
}
void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
}
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
}
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
@ -95,6 +118,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
void Synchronize() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
@ -136,6 +162,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id;
}
template <ncclDataType_t data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(T);
allreduce_calls_ += 1;
}
int const device_ordinal_;
Communicator *communicator_;
ncclComm_t nccl_comm_{};

View File

@ -0,0 +1,30 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
public:
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
std::string GetProcessorName() override { return ""; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
};
} // namespace collective
} // namespace xgboost
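Because communicator.cc above now installs this no-op implementation as the thread-local default and restores it on Finalize, collective calls made outside an initialized context degrade to single-process behaviour. A small sketch of what callers can rely on, using the project's CHECK macros:

    // Before collective::Init() or after collective::Finalize():
    CHECK(!collective::IsDistributed());
    CHECK_EQ(collective::GetWorldSize(), 1);
    CHECK_EQ(collective::GetRank(), 0);

    double value = 3.0;
    collective::Allreduce<collective::Operation::kSum>(&value, 1);  // no-op: value stays 3.0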

View File

@ -1,139 +0,0 @@
/*!
* Copyright 2017-2019 XGBoost contributors
*
* \brief Utilities for CUDA.
*/
#ifdef XGBOOST_USE_NCCL
#include <nccl.h>
#endif // #ifdef XGBOOST_USE_NCCL
#include <sstream>
#include "device_helpers.cuh"
namespace dh {
constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
void GetCudaUUID(int device_ord, xgboost::common::Span<uint64_t, kUuidLength> uuid) {
cudaDeviceProp prob;
safe_cuda(cudaGetDeviceProperties(&prob, device_ord));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
}
std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
#ifdef XGBOOST_USE_NCCL
void NcclAllReducer::DoInit(int _device_ordinal) {
int32_t const rank = rabit::GetRank();
int32_t const world = rabit::GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
GetCudaUUID(_device_ordinal, s_this_uuid);
// No allgather yet.
rabit::Allreduce<rabit::op::Sum, uint64_t>(uuids.data(), uuids.size());
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);;
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] =
xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
id_ = GetUniqueId();
dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
safe_cuda(cudaStreamCreate(&stream_));
}
void NcclAllReducer::DoAllGather(void const *data, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
int32_t world = rabit::GetWorldSize();
segments->clear();
segments->resize(world, 0);
segments->at(rabit::GetRank()) = length_bytes;
rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0);
recvbuf->resize(total_bytes);
size_t offset = 0;
safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world; ++i) {
size_t as_bytes = segments->at(i);
safe_nccl(
ncclBroadcast(data, recvbuf->data().get() + offset,
as_bytes, ncclChar, i, comm_, stream_));
offset += as_bytes;
}
safe_nccl(ncclGroupEnd());
}
NcclAllReducer::~NcclAllReducer() {
if (initialised_) {
dh::safe_cuda(cudaStreamDestroy(stream_));
ncclCommDestroy(comm_);
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
}
}
#else
void RabitAllReducer::DoInit(int _device_ordinal) {
#if !defined(XGBOOST_USE_FEDERATED)
if (rabit::IsDistributed()) {
LOG(CONSOLE) << "XGBoost is not compiled with NCCL, falling back to Rabit.";
}
#endif
}
void RabitAllReducer::DoAllGather(void const *data, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
size_t world = rabit::GetWorldSize();
segments->clear();
segments->resize(world, 0);
segments->at(rabit::GetRank()) = length_bytes;
rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
recvbuf->resize(total_bytes);
sendrecvbuf_.reserve(total_bytes);
auto rank = rabit::GetRank();
size_t offset = 0;
for (int32_t i = 0; i < world; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank) {
safe_cuda(
cudaMemcpy(sendrecvbuf_.data() + offset, data, segments->at(rank), cudaMemcpyDefault));
}
rabit::Broadcast(sendrecvbuf_.data() + offset, as_bytes, i);
offset += as_bytes;
}
safe_cuda(cudaMemcpy(recvbuf->data().get(), sendrecvbuf_.data(), total_bytes, cudaMemcpyDefault));
}
#endif // XGBOOST_USE_NCCL
} // namespace dh

View File

@ -19,7 +19,6 @@
#include <thrust/unique.h>
#include <thrust/binary_search.h>
#include <rabit/rabit.h>
#include <cub/cub.cuh>
#include <cub/util_allocator.cuh>
@ -36,6 +35,7 @@
#include "xgboost/span.h"
#include "xgboost/global_config.h"
#include "../collective/communicator-inl.h"
#include "common.h"
#include "algorithm.cuh"
@ -404,7 +404,7 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
// dh::DebugSyncDevice(__FILE__, __LINE__);
inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
if (file != "" && line != -1) {
auto rank = rabit::GetRank();
auto rank = xgboost::collective::GetRank();
LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line;
}
safe_cuda(cudaDeviceSynchronize());
@ -423,7 +423,7 @@ using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
inline void ThrowOOMError(std::string const& err, size_t bytes) {
auto device = CurrentDevice();
auto rank = rabit::GetRank();
auto rank = xgboost::collective::GetRank();
std::stringstream ss;
ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
<< "- Free memory: " << AvailableMemory(device) << "\n"
@ -737,512 +737,6 @@ using TypedDiscard =
std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
detail::TypedDiscard<T>>;
/**
* \class AllReducer
*
* \brief All reducer class that manages its own communication group and
* streams. Must be initialised before use. If XGBoost is compiled without NCCL,
* this falls back to use Rabit.
*/
template <typename AllReducer>
class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
public:
virtual ~AllReducerBase() = default;
/**
* \brief Initialise with the desired device ordinal for this allreducer.
*
* \param device_ordinal The device ordinal.
*/
void Init(int _device_ordinal) {
device_ordinal_ = _device_ordinal;
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (rabit::GetWorldSize() == 1) {
return;
}
this->Underlying().DoInit(_device_ordinal);
initialised_ = true;
}
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
*
* \param data Buffer storing the input data.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
}
/**
* \brief Allgather. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param data Buffer storing the input data.
* \param length Size of input data in bytes.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(uint32_t const *data, size_t length,
dh::caching_device_vector<uint32_t> *recvbuf) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllGather(data, length, recvbuf);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*
* \param count Number of.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of.
*/
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(int64_t);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(uint32_t);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(uint64_t);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* Specialization for size_t, which is implementation defined so it might or might not
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(T);
allreduce_calls_ += 1;
}
/**
* \fn void Synchronize()
*
* \brief Synchronizes the entire communication group.
*/
void Synchronize() {
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoSynchronize();
}
protected:
bool initialised_{false};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
private:
int device_ordinal_{-1};
};
#ifdef XGBOOST_USE_NCCL
class NcclAllReducer : public AllReducerBase<NcclAllReducer> {
public:
friend class AllReducerBase<NcclAllReducer>;
~NcclAllReducer() override;
private:
/**
* \brief Initialise with the desired device ordinal for this communication
* group.
*
* \param device_ordinal The device ordinal.
*/
void DoInit(int _device_ordinal);
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
*
* \param data Buffer storing the input data.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf);
/**
* \brief Allgather. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param data Buffer storing the input data.
* \param length Size of input data in bytes.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(uint32_t const *data, size_t length,
dh::caching_device_vector<uint32_t> *recvbuf) {
size_t world = rabit::GetWorldSize();
recvbuf->resize(length * world);
safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*
* \param count Number of.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of.
*/
void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* Specialization for size_t, which is implementation defined so it might or might not
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
}
/**
* \brief Synchronizes the entire communication group.
*/
void DoSynchronize() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (rabit::GetRank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
rabit::Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
ncclComm_t comm_;
cudaStream_t stream_;
ncclUniqueId id_;
};
using AllReducer = NcclAllReducer;
#else
class RabitAllReducer : public AllReducerBase<RabitAllReducer> {
public:
friend class AllReducerBase<RabitAllReducer>;
private:
/**
* \brief Initialise with the desired device ordinal for this allreducer.
*
* \param device_ordinal The device ordinal.
*/
static void DoInit(int _device_ordinal);
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
*
* \param data Buffer storing the input data.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf);
/**
* \brief Allgather. Use in exactly the same way as NCCL.
*
* \param data Buffer storing the input data.
* \param length Size of input data in bytes.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(uint32_t *data, size_t length, dh::caching_device_vector<uint32_t> *recvbuf) {
size_t world = rabit::GetWorldSize();
auto total_size = length * world;
recvbuf->resize(total_size);
sendrecvbuf_.reserve(total_size);
auto rank = rabit::GetRank();
safe_cuda(cudaMemcpy(sendrecvbuf_.data() + rank * length, data, length, cudaMemcpyDefault));
rabit::Allgather(sendrecvbuf_.data(), total_size, rank * length, length, length);
safe_cuda(cudaMemcpy(data, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* Specialization for size_t, which is implementation defined so it might or might not
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Synchronizes the allreducer.
*/
void DoSynchronize() {}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* Copy the device buffer to host, call rabit allreduce, then copy the buffer back
* to device.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T>
void RabitAllReduceSum(const T *sendbuff, T *recvbuff, int count) {
auto total_size = count * sizeof(T);
sendrecvbuf_.reserve(total_size);
safe_cuda(cudaMemcpy(sendrecvbuf_.data(), sendbuff, total_size, cudaMemcpyDefault));
rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<T*>(sendrecvbuf_.data()), count);
safe_cuda(cudaMemcpy(recvbuff, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
}
/// Host buffer used to call rabit functions.
std::vector<char> sendrecvbuf_{};
};
using AllReducer = RabitAllReducer;
#endif
template <typename VectorT, typename T = typename VectorT::value_type,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(

View File

@ -3,19 +3,14 @@
* \file hist_util.cc
*/
#include <dmlc/timer.h>
#include <dmlc/omp.h>
#include <rabit/rabit.h>
#include <numeric>
#include <vector>
#include "xgboost/base.h"
#include "../common/common.h"
#include "hist_util.h"
#include "random.h"
#include "column_matrix.h"
#include "quantile.h"
#include "../data/gradient_index.h"
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h>

View File

@ -6,10 +6,10 @@
#include <limits>
#include <utility>
#include "../collective/communicator-inl.h"
#include "../data/adapter.h"
#include "categorical.h"
#include "hist_util.h"
#include "rabit/rabit.h"
namespace xgboost {
namespace common {
@ -144,8 +144,8 @@ struct QuantileAllreduce {
void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,
std::vector<std::set<float>> *p_categories) {
auto &categories = *p_categories;
auto world_size = rabit::GetWorldSize();
auto rank = rabit::GetRank();
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1) {
return;
}
@ -163,7 +163,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
rabit::Allreduce<rabit::op::Sum>(global_feat_ptrs.data(), global_feat_ptrs.size());
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
global_feat_ptrs.size());
// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
@ -176,7 +177,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
rabit::Allreduce<rabit::op::Sum>(global_worker_ptr.data(), global_worker_ptr.size());
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
global_worker_ptr.size());
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
@ -188,7 +190,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
rabit::Allreduce<rabit::op::Sum>(global_categories.data(), global_categories.size());
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
global_categories.size());
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories.size()};
ParallelFor(categories.size(), n_threads, [&](auto fidx) {
@ -217,8 +220,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
std::vector<typename WQSketch::Entry> *p_global_sketches) {
auto &worker_segments = *p_worker_segments;
worker_segments.resize(1, 0);
auto world = rabit::GetWorldSize();
auto rank = rabit::GetRank();
auto world = collective::GetWorldSize();
auto rank = collective::GetRank();
auto n_columns = sketches_.size();
// get the size of each feature.
@ -237,7 +240,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
// Gather all column pointers
rabit::Allreduce<rabit::op::Sum>(sketches_scan.data(), sketches_scan.size());
collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
@ -265,7 +268,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
"Unexpected size of sketch entry.");
rabit::Allreduce<rabit::op::Sum>(
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
}
@ -277,7 +280,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
monitor_.Start(__func__);
size_t n_columns = sketches_.size();
rabit::Allreduce<rabit::op::Max>(&n_columns, 1);
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
AllreduceCategories(feature_types_, n_threads_, &categories_);
@ -291,7 +294,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
// Prune the intermediate num cuts for synchronization.
std::vector<bst_row_t> global_column_size(columns_size_);
rabit::Allreduce<rabit::op::Sum>(global_column_size.data(), global_column_size.size());
collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
global_column_size.size());
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(
@ -311,7 +315,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
num_cuts[i] = intermediate_num_cuts;
});
auto world = rabit::GetWorldSize();
auto world = collective::GetWorldSize();
if (world == 1) {
monitor_.Stop(__func__);
return;

View File

@ -12,6 +12,8 @@
#include <memory>
#include <utility>
#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@ -501,47 +503,41 @@ void SketchContainer::FixError() {
void SketchContainer::AllReduce() {
dh::safe_cuda(cudaSetDevice(device_));
auto world = rabit::GetWorldSize();
auto world = collective::GetWorldSize();
if (world == 1) {
return;
}
timer_.Start(__func__);
if (!reducer_) {
reducer_ = std::make_shared<dh::AllReducer>();
reducer_->Init(device_);
}
auto* communicator = collective::Communicator::GetDevice(device_);
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
rabit::Allreduce<rabit::op::Sum>(&global_sum_rows, 1);
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
size_t intermediate_num_cuts =
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
this->Prune(intermediate_num_cuts);
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
size_t n = d_columns_ptr.size();
rabit::Allreduce<rabit::op::Max>(&n, 1);
collective::Allreduce<collective::Operation::kMax>(&n, 1);
CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";
// Get the columns ptr from all workers
dh::device_vector<SketchContainer::OffsetT> gathered_ptrs;
gathered_ptrs.resize(d_columns_ptr.size() * world, 0);
size_t rank = rabit::GetRank();
size_t rank = collective::GetRank();
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(),
gathered_ptrs.size());
communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
reducer_->AllGather(this->Current().data().get(),
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths,
&recvbuf);
reducer_->Synchronize();
communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
&recv_lengths, &recvbuf);
communicator->Synchronize();
// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);

View File

@ -37,7 +37,6 @@ class SketchContainer {
private:
Monitor timer_;
std::shared_ptr<dh::AllReducer> reducer_;
HostDeviceVector<FeatureType> feature_types_;
bst_row_t num_rows_;
bst_feature_t num_columns_;
@ -93,15 +92,12 @@ class SketchContainer {
* \param num_columns Total number of columns in dataset.
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
* \param device GPU ID.
* \param reducer Optional initialised reducer. Useful for speeding up testing.
*/
SketchContainer(HostDeviceVector<FeatureType> const &feature_types,
int32_t max_bin, bst_feature_t num_columns,
bst_row_t num_rows, int32_t device,
std::shared_ptr<dh::AllReducer> reducer = nullptr)
bst_row_t num_rows, int32_t device)
: num_rows_{num_rows},
num_columns_{num_columns}, num_bins_{max_bin}, device_{device},
reducer_(std::move(reducer)) {
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
CHECK_GE(device, 0);
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);

View File

@ -7,20 +7,21 @@
#ifndef XGBOOST_COMMON_RANDOM_H_
#define XGBOOST_COMMON_RANDOM_H_
#include <rabit/rabit.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <functional>
#include <vector>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>
#include "xgboost/host_device_vector.h"
#include "../collective/communicator-inl.h"
#include "common.h"
#include "xgboost/host_device_vector.h"
namespace xgboost {
namespace common {
@ -143,7 +144,7 @@ class ColumnSampler {
*/
ColumnSampler() {
uint32_t seed = common::GlobalRandom()();
rabit::Broadcast(&seed, sizeof(seed), 0);
collective::Broadcast(&seed, sizeof(seed), 0);
rng_.seed(seed);
}

View File

@ -1,14 +1,13 @@
/*!
* Copyright by Contributors 2019
*/
#include <rabit/rabit.h>
#include <algorithm>
#include <type_traits>
#include <utility>
#include <vector>
#include <sstream>
#include "timer.h"
#include <sstream>
#include <utility>
#include "../collective/communicator-inl.h"
#if defined(XGBOOST_USE_NVTX)
#include <nvToolsExt.h>
#endif // defined(XGBOOST_USE_NVTX)
@ -54,7 +53,7 @@ void Monitor::PrintStatistics(StatMap const& statistics) const {
void Monitor::Print() const {
if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; }
auto rank = rabit::GetRank();
auto rank = collective::GetRank();
StatMap stat_map;
for (auto const &kv : statistics_map_) {
stat_map[kv.first] = std::make_pair(

View File

@ -2,36 +2,36 @@
* Copyright 2015-2022 by XGBoost Contributors
* \file data.cc
*/
#include "xgboost/data.h"
#include <dmlc/registry.h>
#include <array>
#include <cstring>
#include "dmlc/io.h"
#include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "xgboost/learner.h"
#include "xgboost/string_view.h"
#include "sparse_page_writer.h"
#include "simple_dmatrix.h"
#include "../collective/communicator-inl.h"
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h"
#include "../common/version.h"
#include "../common/group_data.h"
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
#include "../data/iterative_dmatrix.h"
#include "file_iterator.h"
#include "validation.h"
#include "./sparse_page_source.h"
#include "./sparse_page_dmatrix.h"
#include "./sparse_page_source.h"
#include "dmlc/io.h"
#include "file_iterator.h"
#include "simple_dmatrix.h"
#include "sparse_page_writer.h"
#include "validation.h"
#include "xgboost/c_api.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/logging.h"
#include "xgboost/string_view.h"
#include "xgboost/version_config.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
@ -793,12 +793,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
size_t pos = cache_shards[i].rfind('.');
if (pos == std::string::npos) {
os << cache_shards[i]
<< ".r" << rabit::GetRank()
<< "-" << rabit::GetWorldSize();
<< ".r" << collective::GetRank()
<< "-" << collective::GetWorldSize();
} else {
os << cache_shards[i].substr(0, pos)
<< ".r" << rabit::GetRank()
<< "-" << rabit::GetWorldSize()
<< ".r" << collective::GetRank()
<< "-" << collective::GetWorldSize()
<< cache_shards[i].substr(pos, cache_shards[i].length());
}
if (i + 1 != cache_shards.size()) {
@ -821,8 +821,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
int partid = 0, npart = 1;
if (load_row_split) {
partid = rabit::GetRank();
npart = rabit::GetWorldSize();
partid = collective::GetRank();
npart = collective::GetWorldSize();
} else {
// test option to load in part
npart = 1;
@ -877,7 +877,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
/* sync up number of features after matrix loaded.
* partitioned data will fail the train/val validation check
* since partitioned data does not know the real number of features. */
rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
return dmat;
}

View File

@ -3,13 +3,11 @@
*/
#include "iterative_dmatrix.h"
#include <rabit/rabit.h>
#include <algorithm> // std::copy
#include "../collective/communicator-inl.h"
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../common/hist_util.h" // common::HistogramCuts
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "gradient_index.h"
#include "proxy_dmatrix.h"
@ -140,7 +138,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
// We use do while here as the first batch is fetched in ctor
if (n_features == 0) {
n_features = num_cols();
rabit::Allreduce<rabit::op::Max>(&n_features, 1);
collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
column_sizes.resize(n_features);
info_.num_col_ = n_features;
} else {
@ -157,7 +155,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";

View File

@ -62,7 +62,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
dh::safe_cuda(cudaSetDevice(get_device()));
if (cols == 0) {
cols = num_cols();
rabit::Allreduce<rabit::op::Max>(&cols, 1);
collective::Allreduce<collective::Operation::kMax>(&cols, 1);
this->info_.num_col_ = cols;
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
iter.Reset();
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
}
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {

View File

@ -189,7 +189,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT
@ -322,7 +322,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
}
// Synchronise worker columns
info_.num_col_ = adapter->NumColumns();
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.num_row_ = total_batch_size;
info_.num_nonzero_ = data_vec.size();
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);

View File

@ -35,7 +35,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
}
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,

View File

@ -5,6 +5,8 @@
* \author Tianqi Chen
*/
#include "./sparse_page_dmatrix.h"
#include "../collective/communicator-inl.h"
#include "./simple_batch_iterator.h"
#include "gradient_index.h"
@ -46,8 +48,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
cache_prefix_{std::move(cache_prefix)} {
ctx_.nthread = nthreads;
cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_;
if (rabit::IsDistributed()) {
cache_prefix_ += ("-r" + std::to_string(rabit::GetRank()));
if (collective::IsDistributed()) {
cache_prefix_ += ("-r" + std::to_string(collective::GetRank()));
}
DMatrixProxy *proxy = MakeProxy(proxy_);
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
@ -94,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
CHECK_NE(info_.num_col_, 0);
}

View File

@ -14,7 +14,6 @@
#include <map>
#include <memory>
#include "rabit/rabit.h"
#include "xgboost/base.h"
#include "xgboost/data.h"

View File

@ -135,7 +135,7 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
return;
}
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
LOG(INFO) << "Tree method is automatically selected to be 'approx' "
"for distributed training.";
tparam_.tree_method = TreeMethod::kApprox;

View File

@ -23,6 +23,7 @@
#include <utility>
#include <vector>
#include "collective/communicator-inl.h"
#include "common/charconv.h"
#include "common/common.h"
#include "common/io.h"
@ -478,7 +479,7 @@ class LearnerConfiguration : public Learner {
// add additional parameters
// These are constraints that need to be satisfied.
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
@ -757,7 +758,7 @@ class LearnerConfiguration : public Learner {
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
}
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
collective::Allreduce<collective::Operation::kMax>(&num_feature, 1);
if (num_feature > mparam_.num_feature) {
mparam_.num_feature = num_feature;
}
@ -1083,7 +1084,7 @@ class LearnerIO : public LearnerConfiguration {
cfg_.insert(n.cbegin(), n.cend());
// copy dsplit from config since it will not run again during restore
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
@ -1228,7 +1229,7 @@ class LearnerImpl : public LearnerIO {
}
// Configuration before data is known.
void CheckDataSplitMode() {
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
if (tparam_.dsplit == DataSplitMode::kCol) {
@ -1488,7 +1489,7 @@ class LearnerImpl : public LearnerIO {
}
if (p_fmat->Info().num_row_ == 0) {
LOG(WARNING) << "Empty dataset at worker: " << rabit::GetRank();
LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank();
}
}

View File

@ -4,14 +4,12 @@
* \brief Implementation of loggers.
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <iostream>
#include <map>
#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "collective/communicator-inl.h"
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
@ -32,7 +30,7 @@ ConsoleLogger::~ConsoleLogger() {
TrackerLogger::~TrackerLogger() {
log_stream_ << '\n';
rabit::TrackerPrint(log_stream_.str());
collective::Print(log_stream_.str());
}
} // namespace xgboost

View File

@ -1,27 +1,23 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include "auc.h"
#include <algorithm>
#include <array>
#include <atomic>
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <tuple>
#include <utility>
#include <vector>
#include "rabit/rabit.h"
#include "xgboost/linalg.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/metric.h"
#include "auc.h"
#include "../common/common.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "xgboost/metric.h"
namespace xgboost {
namespace metric {
@ -117,7 +113,8 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
// we have 2 averages going on here: the first is among workers, the second is
// among classes. allreduce sums up fp/tp auc for each class.
rabit::Allreduce<rabit::op::Sum>(results.Values().data(), results.Values().size());
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
results.Values().size());
double auc_sum{0};
double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
@ -265,7 +262,7 @@ class EvalAUC : public Metric {
}
// We use the global size to handle empty dataset.
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
if (meta[0] == 0) {
// Empty across all workers, which is not supported.
auc = std::numeric_limits<double>::quiet_NaN();
@ -287,7 +284,7 @@ class EvalAUC : public Metric {
}
std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
auc = results[0];
valid_groups = static_cast<uint32_t>(results[1]);
@ -316,7 +313,7 @@ class EvalAUC : public Metric {
}
double local_area = fp * tp;
std::array<double, 2> result{auc, local_area};
rabit::Allreduce<rabit::op::Sum>(result.data(), result.size());
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
if (local_area <= 0) {
// the dataset across all workers has only positive or negative samples

View File

@ -11,11 +11,10 @@
#include <utility>
#include <tuple>
#include "rabit/rabit.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "auc.h"
#include "../common/device_helpers.cuh"
#include "../collective/device_communicator.cuh"
#include "../common/ranking_utils.cuh"
namespace xgboost {
@ -46,9 +45,8 @@ struct DeviceAUCCache {
dh::device_vector<size_t> unique_idx;
// p^T: transposed prediction matrix, used by MultiClassAUC
dh::device_vector<float> predts_t;
std::unique_ptr<dh::AllReducer> reducer;
void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
void Init(common::Span<float const> predts, bool is_multi) {
if (sorted_idx.size() != predts.size()) {
sorted_idx.resize(predts.size());
fptp.resize(sorted_idx.size());
@ -58,10 +56,6 @@ struct DeviceAUCCache {
predts_t.resize(sorted_idx.size());
}
}
if (is_multi && !reducer) {
reducer.reset(new dh::AllReducer);
reducer->Init(device);
}
}
};
@ -72,7 +66,7 @@ void InitCacheOnce(common::Span<float const> predts, int32_t device,
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, is_multi, device);
cache->Init(predts, is_multi);
}
/**
@ -205,9 +199,11 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
common::Span<double> tp, common::Span<double> auc,
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
dh::XGBDeviceAllocator<char> alloc;
if (rabit::IsDistributed()) {
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
auto* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {

View File

@ -10,13 +10,13 @@
#include <tuple>
#include <utility>
#include "rabit/rabit.h"
#include "xgboost/base.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "xgboost/metric.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/metric.h"
#include "xgboost/span.h"
namespace xgboost {
namespace metric {
@ -101,7 +101,7 @@ XGBOOST_DEVICE inline double CalcDeltaPRAUC(double fp_prev, double fp,
inline void InvalidGroupAUC() {
LOG(INFO) << "Invalid group with less than 3 samples is found on worker "
<< rabit::GetRank() << ". Calculating AUC value requires at "
<< collective::GetRank() << ". Calculating AUC value requires at "
<< "least 2 pairs of samples.";
}

View File

@ -7,11 +7,11 @@
* Expressions like wsum == 0 ? esum : esum / wsum are used to handle empty datasets.
*/
#include <dmlc/registry.h>
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <cmath>
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "../common/math.h"
#include "../common/pseudo_huber.h"
@ -196,8 +196,8 @@ class PseudoErrorLoss : public Metric {
return std::make_tuple(v, wt);
});
double dat[2]{result.Residue(), result.Weights()};
if (rabit::IsDistributed()) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
if (collective::IsDistributed()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
}
@ -365,7 +365,7 @@ struct EvalEWiseBase : public Metric {
});
double dat[2]{result.Residue(), result.Weights()};
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return Policy::GetFinal(dat[0], dat[1]);
}
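
The metric hunks that follow (elementwise, multiclass, ranking, survival) all reduce a {residue, weight} pair before finalising. A hedged sketch of that shape; DistributedMetric is a hypothetical stand-in for the per-metric Eval bodies, and the final expression mirrors the wsum == 0 handling described above.

#include "../collective/communicator-inl.h"

// Hypothetical helper: sum the local numerator/denominator over all workers,
// then finalise on the globally aggregated pair.
double DistributedMetric(double residue, double weights) {
  double dat[2]{residue, weights};
  xgboost::collective::Allreduce<xgboost::collective::Operation::kSum>(dat, 2);
  return dat[1] == 0.0 ? dat[0] : dat[0] / dat[1];
}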

View File

@ -4,15 +4,14 @@
* \brief evaluation metrics for multiclass classification.
* \author Kailong Chen, Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <atomic>
#include <cmath>
#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#if defined(XGBOOST_USE_CUDA)
@ -185,7 +184,7 @@ struct EvalMClassBase : public Metric {
dat[0] = result.Residue();
dat[1] = result.Weights();
}
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return Derived::GetFinal(dat[0], dat[1]);
}
/*!

View File

@ -20,17 +20,17 @@
// corresponding headers that bring in those function declarations can't be included with CUDA).
// This precludes the CPU and GPU logic from coexisting inside a .cu file
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <dmlc/registry.h>
#include <cmath>
#include <xgboost/metric.h>
#include <cmath>
#include <vector>
#include "xgboost/host_device_vector.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "metric_common.h"
#include "xgboost/host_device_vector.h"
namespace {
@ -103,7 +103,7 @@ struct EvalAMS : public Metric {
}
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
CHECK(!rabit::IsDistributed()) << "metric AMS does not support distributed evaluation";
CHECK(!collective::IsDistributed()) << "metric AMS does not support distributed evaluation";
using namespace std; // NOLINT(*)
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
@ -216,10 +216,10 @@ struct EvalRank : public Metric, public EvalRankConfig {
exc.Rethrow();
}
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
double dat[2]{sum_metric, static_cast<double>(ngroups)};
// approximately estimate the metric using mean
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return dat[0] / dat[1];
} else {
return sum_metric / ngroups;
@ -341,7 +341,7 @@ struct EvalCox : public Metric {
public:
EvalCox() = default;
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
CHECK(!rabit::IsDistributed()) << "Cox metric does not support distributed evaluation";
CHECK(!collective::IsDistributed()) << "Cox metric does not support distributed evaluation";
using namespace std; // NOLINT(*)
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());

View File

@ -4,15 +4,12 @@
* \brief prediction rank based metrics.
* \author Kailong Chen, Tianqi Chen
*/
#include <rabit/rabit.h>
#include <dmlc/registry.h>
#include <xgboost/metric.h>
#include <xgboost/host_device_vector.h>
#include <thrust/iterator/discard_iterator.h>
#include <cmath>
#include <array>
#include <vector>
#include "metric_common.h"

View File

@ -5,7 +5,6 @@
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
*/
#include <rabit/rabit.h>
#include <dmlc/registry.h>
#include <memory>
@ -16,6 +15,7 @@
#include "xgboost/host_device_vector.h"
#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/survival_util.h"
#include "../common/threading_utils.h"
@ -214,7 +214,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
info.labels_upper_bound_, preds);
double dat[2]{result.Residue(), result.Weights()};
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return Policy::GetFinal(dat[0], dat[1]);
}

View File

@ -7,8 +7,8 @@
#include <limits>
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "rabit/rabit.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/tree_model.h"
@ -39,7 +39,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
auto const& h_node_idx = nidx;
size_t n_leaf{h_node_idx.size()};
rabit::Allreduce<rabit::op::Max>(&n_leaf, 1);
collective::Allreduce<collective::Operation::kMax>(&n_leaf, 1);
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
if (quantiles.empty()) {
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
@ -49,12 +49,12 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
std::vector<int32_t> n_valids(quantiles.size());
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
rabit::Allreduce<rabit::op::Sum>(n_valids.data(), n_valids.size());
collective::Allreduce<collective::Operation::kSum>(n_valids.data(), n_valids.size());
// convert to 0 for all reduce
std::replace_if(
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
// use the mean value
rabit::Allreduce<rabit::op::Sum>(quantiles.data(), quantiles.size());
collective::Allreduce<collective::Operation::kSum>(quantiles.data(), quantiles.size());
for (size_t i = 0; i < n_leaf; ++i) {
if (n_valids[i] > 0) {
quantiles[i] /= static_cast<float>(n_valids[i]);

View File

@ -724,8 +724,8 @@ class MeanAbsoluteError : public ObjFunction {
}
// Weighted average base score across all workers
rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
rabit::Allreduce<rabit::op::Sum>(&w, 1);
collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
collective::Allreduce<collective::Operation::kSum>(&w, 1);
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
[w](float v) { return v / w; });

View File

@ -84,11 +84,11 @@ GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
// Treat pair as array of 4 primitive types to allreduce
using ReduceT = typename decltype(p.first)::ValueT;
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
rabit::Allreduce<rabit::op::Sum, ReduceT>(reinterpret_cast<ReduceT*>(&p), 4);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
GradientPair positive_sum{p.first}, negative_sum{p.second};
std::size_t total_rows = gpair.size();
rabit::Allreduce<rabit::op::Sum>(&total_rows, 1);
collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);
auto histogram_rounding = GradientSumT{
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), total_rows),

View File

@ -8,10 +8,10 @@
#include <limits>
#include <vector>
#include "../../collective/communicator-inl.h"
#include "../../common/hist_util.h"
#include "../../data/gradient_index.h"
#include "expand_entry.h"
#include "rabit/rabit.h"
#include "xgboost/tree_model.h"
namespace xgboost {
@ -202,7 +202,8 @@ class HistogramBuilder {
}
});
rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<double*>(this->hist_[starting_index].data()),
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()),
builder_.GetNumBins() * sync_count * 2);
ParallelSubtractionHist(space, nodes_for_explicit_hist_build,

View File

@ -74,7 +74,7 @@ class GloablApproxBuilder {
}
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
rabit::IsDistributed());
collective::IsDistributed());
monitor_->Stop(__func__);
}
@ -88,7 +88,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = ConstructHistSpace(partitioner_, nodes);
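
A short sketch (not from the commit) of why the root-sum reductions here and in the quantile-hist and GPU hunks cast to double*: the gradient statistics hold two contiguous doubles (grad, hess), so a single allreduce over two elements sums both fields. SumRootGradient is a hypothetical helper; the include paths mirror nearby hunks.

#include "../collective/communicator-inl.h"
#include "xgboost/base.h"

// Hypothetical helper: reduce a GradientPairPrecise across workers in place.
void SumRootGradient(xgboost::GradientPairPrecise* root_sum) {
  static_assert(sizeof(xgboost::GradientPairPrecise) == 2 * sizeof(double),
                "expects exactly {grad, hess}");
  xgboost::collective::Allreduce<xgboost::collective::Operation::kSum>(
      reinterpret_cast<double*>(root_sum), 2);
}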

View File

@ -4,8 +4,6 @@
* \brief use columnwise update to construct a tree
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <memory>
#include <vector>
#include <cmath>
#include <algorithm>
@ -100,7 +98,7 @@ class ColMaker: public TreeUpdater {
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree *> &trees) override {
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
"support distributed training.";
}

View File

@ -19,6 +19,7 @@
#include "xgboost/span.h"
#include "xgboost/json.h"
#include "../collective/device_communicator.cuh"
#include "../common/io.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
@ -528,12 +529,11 @@ struct GPUHistMakerDevice {
}
// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
reducer->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
reinterpret_cast<ReduceT*>(d_node_hist),
communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
page->Cuts().TotalBins() * 2 * num_histograms);
monitor.Stop("AllReduce");
@ -542,8 +542,8 @@ struct GPUHistMakerDevice {
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, dh::AllReducer* reducer,
const RegTree& tree) {
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
collective::DeviceCommunicator* communicator, const RegTree& tree) {
if (candidates.empty()) return;
// Some nodes we will manually compute histograms
// others we will do by subtraction
@ -574,7 +574,7 @@ struct GPUHistMakerDevice {
// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());
this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());
for (size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);
@ -584,7 +584,7 @@ struct GPUHistMakerDevice {
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
}
}
}
@ -593,7 +593,7 @@ struct GPUHistMakerDevice {
RegTree& tree = *p_tree;
// Sanity check - have we created a leaf with no training instances?
if (!rabit::IsDistributed() && row_partitioner) {
if (!collective::IsDistributed() && row_partitioner) {
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
<< "No training instances in this leaf!";
}
@ -642,7 +642,7 @@ struct GPUHistMakerDevice {
parent.RightChild());
}
GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
@ -650,11 +650,11 @@ struct GPUHistMakerDevice {
GradientPairPrecise root_sum =
dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double*>(&root_sum), 2);
hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer, 1);
this->AllReduceHist(kRootNIdx, communicator, 1);
// Remember root stats
node_sum_gradients[kRootNIdx] = root_sum;
@ -669,7 +669,7 @@ struct GPUHistMakerDevice {
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
RegTree* p_tree, dh::AllReducer* reducer,
RegTree* p_tree, collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
// Process maximum 32 nodes at a time
@ -680,7 +680,7 @@ struct GPUHistMakerDevice {
monitor.Stop("Reset");
monitor.Start("InitRoot");
driver.Push({ this->InitRoot(p_tree, reducer) });
driver.Push({ this->InitRoot(p_tree, communicator) });
monitor.Stop("InitRoot");
// The set of leaves that can be expanded asynchronously
@ -707,7 +707,7 @@ struct GPUHistMakerDevice {
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
@ -789,11 +789,10 @@ class GPUHistMaker : public TreeUpdater {
void InitDataOnce(DMatrix* dmat) {
CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
info_ = &dmat->Info();
reducer_.Init({ctx_->gpu_id}); // NOLINT
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
BatchParam batch_param{
ctx_->gpu_id,
@ -823,12 +822,12 @@ class GPUHistMaker : public TreeUpdater {
void CheckTreesSynchronized(RegTree* local_tree) const {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
@ -841,7 +840,8 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Stop("InitData");
gpair->SetDevice(ctx_->gpu_id);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
}
bool UpdatePredictionCache(const DMatrix* data,
@ -867,8 +867,6 @@ class GPUHistMaker : public TreeUpdater {
GPUHistMakerTrainParam hist_maker_param_;
dh::AllReducer reducer_;
DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
ObjInfo task_;
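
The GPU code paths replace the stored dh::AllReducer member with a per-device communicator fetched on demand from the Communicator singleton. A minimal sketch of that pattern, assuming a .cu translation unit inside the XGBoost tree and the include used in the auc.cu hunk; AllReduceHistogram is a hypothetical wrapper.

#include <cstddef>
#include "../collective/device_communicator.cuh"

// Hypothetical wrapper: sum a device-resident histogram across workers.
void AllReduceHistogram(double* d_hist, std::size_t n, int device) {
  xgboost::collective::DeviceCommunicator* comm =
      xgboost::collective::Communicator::GetDevice(device);
  comm->AllReduceSum(d_hist, n);  // in-place reduction on device memory
  comm->Synchronize();            // wait for the reduction to complete
}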

View File

@ -4,16 +4,13 @@
* \brief prune a tree given the statistics
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>
#include <string>
#include <memory>
#include "xgboost/base.h"
#include "xgboost/json.h"
#include "./param.h"
#include "../common/io.h"
#include "../common/timer.h"
namespace xgboost {
namespace tree {

View File

@ -6,19 +6,12 @@
*/
#include "./updater_quantile_hist.h"
#include <rabit/rabit.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "constraints.h"
#include "hist/evaluate_splits.h"
#include "param.h"
@ -103,7 +96,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
for (auto const &grad : gpair_h) {
grad_stat.Add(grad.GetGrad(), grad.GetHess());
}
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&grad_stat), 2);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&grad_stat), 2);
}
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
@ -320,7 +313,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
rabit::IsDistributed());
collective::IsDistributed());
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)

View File

@ -7,7 +7,6 @@
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>
#include <algorithm>

View File

@ -4,17 +4,17 @@
* \brief refresh the statistics and leaf value on the tree on the dataset
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>
#include <vector>
#include <limits>
#include <vector>
#include "xgboost/json.h"
#include "./param.h"
#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "../common/threading_utils.h"
#include "../predictor/predict_fn.h"
#include "./param.h"
#include "xgboost/json.h"
namespace xgboost {
namespace tree {
@ -100,8 +100,9 @@ class TreeRefresher : public TreeUpdater {
}
});
};
rabit::Allreduce<rabit::op::Sum>(&dmlc::BeginPtr(stemp[0])->sum_grad, stemp[0].size() * 2,
lazy_get_stats);
lazy_get_stats();
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
stemp[0].size() * 2);
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();

View File

@ -4,12 +4,14 @@
* \brief synchronize the tree in all distributed nodes
*/
#include <xgboost/tree_updater.h>
#include <vector>
#include <string>
#include <limits>
#include "xgboost/json.h"
#include <limits>
#include <string>
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "xgboost/json.h"
namespace xgboost {
namespace tree {
@ -35,17 +37,17 @@ class TreeSyncher : public TreeUpdater {
void Update(HostDeviceVector<GradientPair>*, DMatrix*,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree*>& trees) override {
if (rabit::GetWorldSize() == 1) return;
if (collective::GetWorldSize() == 1) return;
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
int rank = collective::GetRank();
if (rank == 0) {
for (auto tree : trees) {
tree->Save(&fs);
}
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
collective::Broadcast(&s_model, 0);
for (auto tree : trees) {
tree->Load(&fs);
}
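
CheckTreesSynchronized and TreeSyncher share the same idiom: rank 0 serialises its trees, then every worker receives the same bytes. A hedged sketch of that step; BroadcastModelBlob is a hypothetical helper operating on an already-serialised buffer.

#include <string>
#include "../collective/communicator-inl.h"

// Hypothetical helper: make every worker hold rank 0's serialised model.
void BroadcastModelBlob(std::string* s_model) {
  if (xgboost::collective::GetWorldSize() == 1) {
    return;  // nothing to synchronise in a single-process run
  }
  // Rank 0's payload is kept; all other workers overwrite theirs with a copy.
  xgboost::collective::Broadcast(s_model, 0);
}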

View File

@ -46,8 +46,8 @@ template <bool use_column>
void TestDistributedQuantile(size_t rows, size_t cols) {
std::string msg {"Skipping AllReduce test"};
int32_t constexpr kWorkers = 4;
InitRabitContext(msg, kWorkers);
auto world = rabit::GetWorldSize();
InitCommunicatorContext(msg, kWorkers);
auto world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, kWorkers);
} else {
@ -65,7 +65,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
// Generate cuts for distributed environment.
auto sparsity = 0.5f;
auto rank = rabit::GetRank();
auto rank = collective::GetRank();
std::vector<FeatureType> ft(cols);
for (size_t i = 0; i < ft.size(); ++i) {
ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
@ -99,8 +99,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
sketch_distributed.MakeCuts(&distributed_cuts);
// Generate cuts for single node environment
rabit::Finalize();
CHECK_EQ(rabit::GetWorldSize(), 1);
collective::Finalize();
CHECK_EQ(collective::GetWorldSize(), 1);
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
m->Info().num_row_ = world * rows;
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
@ -184,8 +184,8 @@ TEST(Quantile, SameOnAllWorkers) {
#if defined(__unix__)
std::string msg{"Skipping Quantile AllreduceBasic test"};
int32_t constexpr kWorkers = 4;
InitRabitContext(msg, kWorkers);
auto world = rabit::GetWorldSize();
InitCommunicatorContext(msg, kWorkers);
auto world = collective::GetWorldSize();
if (world != 1) {
CHECK_EQ(world, kWorkers);
} else {
@ -196,7 +196,7 @@ TEST(Quantile, SameOnAllWorkers) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
auto rank = rabit::GetRank();
auto rank = collective::GetRank();
HostDeviceVector<float> storage;
std::vector<FeatureType> ft(kCols);
for (size_t i = 0; i < ft.size(); ++i) {
@ -217,12 +217,12 @@ TEST(Quantile, SameOnAllWorkers) {
std::vector<float> cut_min_values(cuts.MinValues().size() * world, 0);
size_t value_size = cuts.Values().size();
rabit::Allreduce<rabit::op::Max>(&value_size, 1);
collective::Allreduce<collective::Operation::kMax>(&value_size, 1);
size_t ptr_size = cuts.Ptrs().size();
rabit::Allreduce<rabit::op::Max>(&ptr_size, 1);
collective::Allreduce<collective::Operation::kMax>(&ptr_size, 1);
CHECK_EQ(ptr_size, kCols + 1);
size_t min_value_size = cuts.MinValues().size();
rabit::Allreduce<rabit::op::Max>(&min_value_size, 1);
collective::Allreduce<collective::Operation::kMax>(&min_value_size, 1);
CHECK_EQ(min_value_size, kCols);
size_t value_offset = value_size * rank;
@ -235,9 +235,9 @@ TEST(Quantile, SameOnAllWorkers) {
std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(),
cut_min_values.begin() + min_values_offset);
rabit::Allreduce<rabit::op::Sum>(cut_values.data(), cut_values.size());
rabit::Allreduce<rabit::op::Sum>(cut_ptrs.data(), cut_ptrs.size());
rabit::Allreduce<rabit::op::Sum>(cut_min_values.data(), cut_min_values.size());
collective::Allreduce<collective::Operation::kSum>(cut_values.data(), cut_values.size());
collective::Allreduce<collective::Operation::kSum>(cut_ptrs.data(), cut_ptrs.size());
collective::Allreduce<collective::Operation::kSum>(cut_min_values.data(), cut_min_values.size());
for (int32_t i = 0; i < world; i++) {
for (size_t j = 0; j < value_size; ++j) {
@ -256,7 +256,7 @@ TEST(Quantile, SameOnAllWorkers) {
}
}
});
rabit::Finalize();
collective::Finalize();
#endif // defined(__unix__)
}
} // namespace common

View File

@ -1,6 +1,7 @@
#include <gtest/gtest.h>
#include "test_quantile.h"
#include "../helpers.h"
#include "../../../src/collective/device_communicator.cuh"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
@ -341,17 +342,14 @@ TEST(GPUQuantile, AllReduceBasic) {
// This test is supposed to be run by a Python test that sets up the environment.
std::string msg {"Skipping AllReduce test"};
auto n_gpus = AllVisibleGPUs();
InitRabitContext(msg, n_gpus);
auto world = rabit::GetWorldSize();
InitCommunicatorContext(msg, n_gpus);
auto world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
return;
}
auto reducer = std::make_shared<dh::AllReducer>();
reducer->Init(0);
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
// Set up single node version;
@ -385,8 +383,8 @@ TEST(GPUQuantile, AllReduceBasic) {
// Set up distributed version. We rely on using rank as seed to generate
// the exact same copy of data.
auto rank = rabit::GetRank();
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
auto rank = collective::GetRank();
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
@ -422,28 +420,26 @@ TEST(GPUQuantile, AllReduceBasic) {
ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
}
});
rabit::Finalize();
collective::Finalize();
}
TEST(GPUQuantile, SameOnAllWorkers) {
std::string msg {"Skipping SameOnAllWorkers test"};
auto n_gpus = AllVisibleGPUs();
InitRabitContext(msg, n_gpus);
auto world = rabit::GetWorldSize();
InitCommunicatorContext(msg, n_gpus);
auto world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
return;
}
auto reducer = std::make_shared<dh::AllReducer>();
reducer->Init(0);
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
MetaInfo const &info) {
auto rank = rabit::GetRank();
auto rank = collective::GetRank();
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
@ -459,7 +455,7 @@ TEST(GPUQuantile, SameOnAllWorkers) {
// Test for all workers having the same sketch.
size_t n_data = sketch_distributed.Data().size();
rabit::Allreduce<rabit::op::Max>(&n_data, 1);
collective::Allreduce<collective::Operation::kMax>(&n_data, 1);
ASSERT_EQ(n_data, sketch_distributed.Data().size());
size_t size_as_float =
sketch_distributed.Data().size_bytes() / sizeof(float);
@ -472,9 +468,10 @@ TEST(GPUQuantile, SameOnAllWorkers) {
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
reducer->AllReduceSum(all_workers.data().get(), all_workers.data().get(),
all_workers.size());
reducer->Synchronize();
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(0);
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
communicator->Synchronize();
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());

View File

@ -1,16 +1,16 @@
#ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
#define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
#include <rabit/rabit.h>
#include <algorithm>
#include <string>
#include <vector>
#include "../helpers.h"
#include "../../src/collective/communicator-inl.h"
namespace xgboost {
namespace common {
inline void InitRabitContext(std::string msg, int32_t n_workers) {
inline void InitCommunicatorContext(std::string msg, int32_t n_workers) {
auto port = std::getenv("DMLC_TRACKER_PORT");
std::string port_str;
if (port) {
@ -28,12 +28,11 @@ inline void InitRabitContext(std::string msg, int32_t n_workers) {
return;
}
std::vector<std::string> envs{
"DMLC_TRACKER_PORT=" + port_str,
"DMLC_TRACKER_URI=" + uri_str,
"DMLC_NUM_WORKER=" + std::to_string(n_workers)};
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
rabit::Init(3, c_envs);
Json config{JsonObject()};
config["DMLC_TRACKER_PORT"] = port_str;
config["DMLC_TRACKER_URI"] = uri_str;
config["DMLC_NUM_WORKER"] = n_workers;
collective::Init(config);
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
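
The C++ tests now bootstrap the communicator from a JSON config rather than rabit's encoded argv strings, and call Finalize() to drop back to a single-worker world, as the quantile tests above do. A sketch of that lifecycle; RunDistributedSection is hypothetical and the values are placeholders.

#include <cstdint>
#include <string>
#include "../../src/collective/communicator-inl.h"
#include "xgboost/json.h"

// Hypothetical test scaffold: initialise, run distributed checks, finalise.
void RunDistributedSection(std::string const& uri, std::string const& port,
                           std::int32_t n_workers) {
  using namespace xgboost;  // NOLINT
  Json config{JsonObject()};
  config["DMLC_TRACKER_URI"] = uri;    // tracker host, e.g. from the environment
  config["DMLC_TRACKER_PORT"] = port;  // tracker port as a string
  config["DMLC_NUM_WORKER"] = n_workers;
  collective::Init(config);
  // ... distributed assertions go here, e.g. on collective::GetWorldSize() ...
  collective::Finalize();  // world size is 1 again afterwards
}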

View File

@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
rabit_env = [
'xgboost_communicator=federated',
f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}',
f'federated_rank={rank}'
]
communicator_env = {
'xgboost_communicator': 'federated',
'federated_server_address': f'localhost:{port}',
'federated_world_size': world_size,
'federated_rank': rank
}
if with_ssl:
rabit_env = rabit_env + [
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]
communicator_env['federated_server_cert'] = SERVER_CERT
communicator_env['federated_client_key'] = CLIENT_KEY
communicator_env['federated_client_cert'] = CLIENT_CERT
# Always call this before using distributed module
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
if xgb.collective.get_rank() == 0:
bst.save_model("test.model.json")
xgb.rabit.tracker_print("Finished training\n")
xgb.collective.communicator_print("Finished training\n")
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:

View File

@ -1,7 +1,7 @@
"""Copyright 2019-2022 XGBoost contributors"""
import sys
import os
from typing import Type, TypeVar, Any, Dict, List
from typing import Type, TypeVar, Any, Dict, List, Union
import pytest
import numpy as np
import asyncio
@ -425,7 +425,7 @@ class TestDistributedGPU:
)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
with dxgb.CommunicatorContext(**rabit_args):
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
assert fw_rows == local_dtrain.num_col()
@ -505,20 +505,13 @@ class TestDistributedGPU:
test = "--gtest_filter=GPUQuantile." + name
def runit(
worker_addr: str, rabit_args: List[bytes]
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
port_env = ""
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
port_env = arg.decode("utf-8")
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split("=")
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env[uri[0]] = uri[1]
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
workers = _get_client_workers(local_cuda_client)

View File

@ -1,13 +1,13 @@
import multiprocessing
import socket
import sys
import time
import numpy as np
import pytest
import xgboost as xgb
from xgboost import RabitTracker
from xgboost import collective
from xgboost import RabitTracker, build_info, federated
if sys.platform.startswith("win"):
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
@ -37,3 +37,41 @@ def test_rabit_communicator():
for worker in workers:
worker.join()
assert worker.exitcode == 0
def run_federated_worker(port, world_size, rank):
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
federated_server_address=f'localhost:{port}',
federated_world_size=world_size,
federated_rank=rank):
assert xgb.collective.get_world_size() == world_size
assert xgb.collective.is_distributed()
assert xgb.collective.get_processor_name() == f'rank{rank}'
ret = xgb.collective.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
assert np.array_equal(ret, np.asarray([2, 4, 6]))
def test_federated_communicator():
if not build_info()["USE_FEDERATED"]:
pytest.skip("XGBoost not built with federated learning enabled")
port = 9091
world_size = 2
server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
server.start()
time.sleep(1)
if not server.is_alive():
raise Exception("Error starting Federated Learning server")
workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_federated_worker,
args=(port, world_size, rank))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
assert worker.exitcode == 0
server.terminate()

View File

@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
def test_rabit_tracker():
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start(1)
worker_env = tracker.worker_envs()
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast("test1234", 0)
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
ret = xgb.collective.broadcast("test1234", 0)
assert str(ret) == "test1234"
def run_rabit_ops(client, n_workers):
from test_with_dask import _get_client_workers
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
from xgboost import rabit
from xgboost import collective
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
assert not rabit.is_distributed()
assert not collective.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
def local_test(worker_id):
with RabitContext(rabit_args):
with CommunicatorContext(**rabit_args):
a = 1
assert rabit.is_distributed()
assert collective.is_distributed()
a = np.array([a])
reduced = rabit.allreduce(a, rabit.Op.SUM)
reduced = collective.allreduce(a, collective.Op.SUM)
assert reduced[0] == n_workers
worker_id = np.array([worker_id])
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
reduced = collective.allreduce(worker_id, collective.Op.MAX)
assert reduced == n_workers - 1
return 1
@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
from test_with_dask import _get_client_workers
def local_test(worker_id):
with xgb.dask.RabitContext(args):
for val in args:
sval = val.decode("utf-8")
if sval.startswith("DMLC_TASK_ID"):
task_id = sval
break
with xgb.dask.CommunicatorContext(**args) as ctx:
task_id = ctx["DMLC_TASK_ID"]
matched = re.search(".*-([0-9]).*", task_id)
rank = xgb.rabit.get_rank()
rank = xgb.collective.get_rank()
# As long as the number of workers is less than 10, rank and worker id
# should be the same
assert rank == int(matched.group(1))

View File

@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
class TestWithDask:
def test_dmatrix_binary(self, client: "Client") -> None:
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
with xgb.dask.CommunicatorContext(**rabit_args):
rank = xgb.collective.get_rank()
X, y = tm.make_categorical(100, 4, 4, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
path = os.path.join(tmpdir, f"{rank}.bin")
Xy.save_binary(path)
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
with xgb.dask.CommunicatorContext(**rabit_args):
rank = xgb.collective.get_rank()
path = os.path.join(tmpdir, f"{rank}.bin")
Xy = xgb.DMatrix(path)
assert Xy.num_row() == 100
@ -1488,20 +1488,13 @@ class TestWithDask:
test = "--gtest_filter=Quantile." + name
def runit(
worker_addr: str, rabit_args: List[bytes]
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
port_env = ''
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env["DMLC_TRACKER_URI"] = uri[1]
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
@ -1543,8 +1536,8 @@ class TestWithDask:
def get_score(config: Dict) -> float:
return float(config["learner"]["learner_model_param"]["base_score"])
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
with xgb.dask.RabitContext(rabit_args):
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
with xgb.dask.CommunicatorContext(**rabit_args):
if worker_id == 0:
y = np.array([0.0, 0.0, 0.0])
x = np.array([[0.0]] * 3)
@ -1686,12 +1679,12 @@ class TestWithDask:
n_workers = len(workers)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with xgb.dask.RabitContext(rabit_args):
with xgb.dask.CommunicatorContext(**rabit_args):
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
**data_ref, nthread=7
)
total = np.array([local_dtrain.num_row()])
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
assert total[0] == kRows
futures = []