[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
e47b3a3da3
commit
668b8a0ea4
@ -52,15 +52,15 @@ class XGBoostTrainer(Executor):
|
||||
def _do_training(self, fl_ctx: FLContext):
|
||||
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
|
||||
rank = int(client_name.split('-')[1]) - 1
|
||||
rabit_env = [
|
||||
f'federated_server_address={self._server_address}',
|
||||
f'federated_world_size={self._world_size}',
|
||||
f'federated_rank={rank}',
|
||||
f'federated_server_cert={self._server_cert_path}',
|
||||
f'federated_client_key={self._client_key_path}',
|
||||
f'federated_client_cert={self._client_cert_path}'
|
||||
]
|
||||
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
||||
communicator_env = {
|
||||
'federated_server_address': self._server_address,
|
||||
'federated_world_size': self._world_size,
|
||||
'federated_rank': rank,
|
||||
'federated_server_cert': self._server_cert_path,
|
||||
'federated_client_key': self._client_key_path,
|
||||
'federated_client_cert': self._client_cert_path
|
||||
}
|
||||
with xgb.collective.CommunicatorContext(**communicator_env):
|
||||
# Load file, file will not be sharded in federated mode.
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('agaricus.txt.test')
|
||||
@ -86,4 +86,4 @@ class XGBoostTrainer(Executor):
|
||||
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
|
||||
run_dir = workspace.get_run_dir(run_number)
|
||||
bst.save_model(os.path.join(run_dir, "test.model.json"))
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
xgb.collective.communicator_print("Finished training\n")
|
||||
|
||||
@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
|
||||
import scala.collection.JavaConverters.asScalaIteratorConverter
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
@ -46,7 +46,7 @@ object XGBoost {
|
||||
collector: Collector[XGBoostModel]): Unit = {
|
||||
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
|
||||
logger.info("start with env" + workerEnvs.toString)
|
||||
Rabit.init(workerEnvs)
|
||||
Communicator.init(workerEnvs)
|
||||
val mapper = (x: LabeledVector) => {
|
||||
val (index, value) = x.vector.toSeq.unzip
|
||||
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
@ -59,7 +59,7 @@ object XGBoost {
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
|
||||
earlyStoppingRound = numEarlyStoppingRounds)
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
collector.collect(new XGBoostModel(booster))
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,7 +22,7 @@ import java.util.ServiceLoader
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
Communicator.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||
@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
override def next(): Row = {
|
||||
val ret = batchIterImpl.next()
|
||||
if (!batchIterImpl.hasNext) {
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
@ -22,7 +22,7 @@ import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
@ -303,7 +303,7 @@ object XGBoost extends Serializable {
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
|
||||
try {
|
||||
Rabit.init(rabitEnv)
|
||||
Communicator.init(rabitEnv)
|
||||
|
||||
watches = buildWatchesAndCheck(buildWatches)
|
||||
|
||||
@ -342,7 +342,7 @@ object XGBoost extends Serializable {
|
||||
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
|
||||
throw xgbException
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
if (watches != null) watches.delete()
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,277 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.{FunSuite}
|
||||
|
||||
class RabitRobustnessSuite extends FunSuite with PerTest {
|
||||
|
||||
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
|
||||
val classifier = new XGBoostClassifier(paramMap)
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
|
||||
xgbParamsFactory.buildXGBRuntimeParams
|
||||
}
|
||||
|
||||
|
||||
test("Customize host ip and python exec for Rabit tracker") {
|
||||
val hostIp = "192.168.22.111"
|
||||
val pythonExec = "/usr/bin/python3"
|
||||
|
||||
val paramMap = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
|
||||
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
tracker match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.contains(hostIp))
|
||||
assert(cmd.startsWith("python"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap1 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
|
||||
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||
tracker1 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(!cmd.contains(hostIp))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap2 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
|
||||
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||
tracker2 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(cmd.contains(s" --host-ip=${hostIp}"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
}
|
||||
|
||||
test("training with Scala-implemented Rabit tracker") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||
val vectorLength = 100
|
||||
val rdd = sc.parallelize(
|
||||
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
|
||||
|
||||
val rawData = rdd.mapPartitions { iter =>
|
||||
Iterator(iter.toArray)
|
||||
}.collect()
|
||||
|
||||
val maxVec = (0 until vectorLength).toArray.map { j =>
|
||||
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
|
||||
}
|
||||
|
||||
val allReduceResults = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val arr = iter.toArray
|
||||
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
|
||||
Rabit.shutdown()
|
||||
Iterator(results)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
allReduceResults.foreachPartition(() => _)
|
||||
val byPartitionResults = allReduceResults.collect()
|
||||
assert(byPartitionResults(0).length == vectorLength)
|
||||
collectedAllReduceResults.put(byPartitionResults(0))
|
||||
}
|
||||
}
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == 0)
|
||||
sparkThread.join()
|
||||
|
||||
assert(collectedAllReduceResults.poll().sameElements(maxVec))
|
||||
}
|
||||
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
|
||||
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
|
||||
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
|
||||
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
|
||||
that should be avoided.
|
||||
*/
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
|
||||
last created worker. A cascading event chain will be triggered once the RuntimeException is
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(500)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||
if (index != 1) {
|
||||
Rabit.init(trackerEnvs)
|
||||
Thread.sleep(1000)
|
||||
Rabit.shutdown()
|
||||
}
|
||||
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
// should fail due to connection timeout
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
}
|
||||
|
||||
test("should allow the dataframe containing rabit calls to be partially evaluated for" +
|
||||
" multiple times (ISSUE-4406)") {
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
|
||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||
val prediction = model.transform(trainingDF)
|
||||
// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
|
||||
// threads
|
||||
prediction.show()
|
||||
// a full evaluation here will re-run init and shutdown all rabit proxy
|
||||
// expecting no error
|
||||
prediction.collect()
|
||||
}
|
||||
}
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
@ -25,7 +25,7 @@ import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
|
||||
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
|
||||
val predictionErrorMin = 0.00001f
|
||||
val maxFailure = 2;
|
||||
|
||||
@ -47,8 +47,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
|
||||
.fit(training)
|
||||
|
||||
assert(Rabit.rabitEnvs.asScala.size > 3)
|
||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
|
||||
@ -70,8 +70,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
|
||||
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
|
||||
).fit(training)
|
||||
assert(Rabit.rabitEnvs.asScala.size > 3)
|
||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
// check the equality of single instance prediction
|
||||
@ -85,7 +85,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
test("test rabit timeout fail handle") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
// mock rank 0 failure during 8th allreduce synchronization
|
||||
Rabit.mockList = Array("0,8,0,0").toList.asJava
|
||||
Communicator.mockList = Array("0,8,0,0").toList.asJava
|
||||
|
||||
intercept[SparkException] {
|
||||
new XGBoostClassifier(Map(
|
||||
@ -98,6 +98,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
"rabit_timeout" -> 0))
|
||||
.fit(training)
|
||||
}
|
||||
|
||||
Communicator.mockList = Array.empty.toList.asJava
|
||||
}
|
||||
|
||||
}
|
||||
@ -1,154 +0,0 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Rabit global class for synchronization.
|
||||
*/
|
||||
public class Rabit {
|
||||
|
||||
public enum OpType implements Serializable {
|
||||
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
|
||||
|
||||
private int op;
|
||||
|
||||
public int getOperand() {
|
||||
return this.op;
|
||||
}
|
||||
|
||||
OpType(int op) {
|
||||
this.op = op;
|
||||
}
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
|
||||
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
|
||||
LONGLONG(8, 8), ULONGLONG(9, 8);
|
||||
|
||||
private int enumOp;
|
||||
private int size;
|
||||
|
||||
public int getEnumOp() {
|
||||
return this.enumOp;
|
||||
}
|
||||
|
||||
public int getSize() {
|
||||
return this.size;
|
||||
}
|
||||
|
||||
DataType(int enumOp, int size) {
|
||||
this.enumOp = enumOp;
|
||||
this.size = size;
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
// used as way to test/debug passed rabit init parameters
|
||||
public static Map<String, String> rabitEnvs;
|
||||
public static List<String> mockList = new LinkedList<>();
|
||||
/**
|
||||
* Initialize the rabit library on current working thread.
|
||||
* @param envs The additional environment variables to pass to rabit.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
rabitEnvs = envs;
|
||||
String[] args = new String[envs.size() + mockList.size()];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||
}
|
||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
||||
for(String mock : mockList) {
|
||||
args[idx++] = "mock=" + mock;
|
||||
}
|
||||
checkCall(XGBoostJNI.RabitInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown the rabit engine in current working thread, equals to finalize.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void shutdown() throws XGBoostError {
|
||||
checkCall(XGBoostJNI.RabitFinalize());
|
||||
}
|
||||
|
||||
/**
|
||||
* Print the message on rabit tracker.
|
||||
* @param msg
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void trackerPrint(String msg) throws XGBoostError {
|
||||
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get version number of current stored model in the thread.
|
||||
* which means how many calls to CheckPoint we made so far.
|
||||
* @return version Number.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int versionNumber() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XGBoostJNI.RabitVersionNumber(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get rank of current thread.
|
||||
* @return the rank.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getRank() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XGBoostJNI.RabitGetRank(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get world size of current job.
|
||||
* @return the worldsize
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getWorldSize() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* perform Allreduce on distributed float vectors using operator op.
|
||||
* This implementation of allReduce does not support customized prepare function callback in the
|
||||
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
|
||||
*
|
||||
* @param elements local elements on distributed workers.
|
||||
* @param op operator used for Allreduce.
|
||||
* @return All-reduced float elements according to the given operator.
|
||||
*/
|
||||
public static float[] allReduce(float[] elements, OpType op) {
|
||||
DataType dataType = DataType.FLOAT;
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
|
||||
.order(ByteOrder.nativeOrder());
|
||||
|
||||
for (float el : elements) {
|
||||
buffer.putFloat(el);
|
||||
}
|
||||
buffer.flip();
|
||||
|
||||
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
|
||||
float[] results = new float[elements.length];
|
||||
buffer.asFloatBuffer().get(results);
|
||||
|
||||
return results;
|
||||
}
|
||||
}
|
||||
@ -254,16 +254,16 @@ public class XGBoost {
|
||||
}
|
||||
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||
if (shouldPrint(params, iter)) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
Communicator.communicatorPrint(String.format(
|
||||
"early stopping after %d rounds away from the best iteration",
|
||||
earlyStoppingRounds
|
||||
));
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (shouldPrint(params, iter)){
|
||||
Rabit.trackerPrint(evalInfo + '\n');
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,19 +135,6 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
// rabit functions
|
||||
public final static native int RabitInit(String[] args);
|
||||
public final static native int RabitFinalize();
|
||||
public final static native int RabitTrackerPrint(String msg);
|
||||
public final static native int RabitGetRank(int[] out);
|
||||
public final static native int RabitGetWorldSize(int[] out);
|
||||
public final static native int RabitVersionNumber(int[] out);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
// This JNI function does not support the callback function for data preparation yet.
|
||||
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
|
||||
@ -872,111 +872,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
std::vector<std::string> args;
|
||||
std::vector<char*> argv;
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
||||
const char *s = jenv->GetStringUTFChars(arg, 0);
|
||||
args.push_back(std::string(s, jenv->GetStringLength(arg)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
|
||||
if (args.back().length() == 0) args.pop_back();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
argv.push_back(&args[i][0]);
|
||||
}
|
||||
|
||||
if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
|
||||
return 0;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
if (RabitFinalize()) {
|
||||
return 0;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
|
||||
std::string str(jenv->GetStringUTFChars(jmsg, 0),
|
||||
jenv->GetStringLength(jmsg));
|
||||
JVM_CHECK_CALL(RabitTrackerPrint(str.c_str()));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
jint rank = RabitGetRank();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
jint out = RabitGetWorldSize();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
jint out = RabitVersionNumber();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitAllreduce
|
||||
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
|
||||
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
|
||||
JVM_CHECK_CALL(RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@ -279,62 +279,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitAllreduce
|
||||
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *, jclass, jobject, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@ -300,7 +300,7 @@ public class DMatrixTest {
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Rabit.init(rabitEnv);
|
||||
Communicator.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
BigDenseMatrix data0 = null;
|
||||
try {
|
||||
@ -348,7 +348,7 @@ public class DMatrixTest {
|
||||
else if (data0 != null) {
|
||||
data0.dispose();
|
||||
}
|
||||
Rabit.shutdown();
|
||||
Communicator.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -23,6 +23,6 @@ target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
|
||||
target_sources(objxgboost PRIVATE federated_server.cc)
|
||||
target_link_libraries(objxgboost PRIVATE federated_client)
|
||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
||||
|
||||
@ -1,197 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "federated_client.h"
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "rabit/internal/utils.h"
|
||||
|
||||
namespace MPI { // NOLINT
|
||||
// MPI data type to be compatible with existing MPI interface
|
||||
class Datatype {
|
||||
public:
|
||||
size_t type_size;
|
||||
explicit Datatype(size_t type_size) : type_size(type_size) {}
|
||||
};
|
||||
} // namespace MPI
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
|
||||
/*! \brief implementation of engine using federated learning */
|
||||
class FederatedEngine : public IEngine {
|
||||
public:
|
||||
void Init(int argc, char *argv[]) {
|
||||
// Parse environment variables first.
|
||||
for (auto const &env_var : env_vars_) {
|
||||
char const *value = getenv(env_var.c_str());
|
||||
if (value != nullptr) {
|
||||
SetParam(env_var, value);
|
||||
}
|
||||
}
|
||||
// Command line argument overrides.
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
std::string const key_value = argv[i];
|
||||
auto const delimiter = key_value.find('=');
|
||||
if (delimiter != std::string::npos) {
|
||||
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
|
||||
}
|
||||
}
|
||||
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
|
||||
server_address_.c_str(), world_size_, rank_);
|
||||
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
|
||||
utils::Printf("Certificates not specified, turning off SSL.");
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
|
||||
} else {
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
|
||||
client_key_, client_cert_));
|
||||
}
|
||||
}
|
||||
|
||||
void Finalize() { client_.reset(); }
|
||||
|
||||
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
|
||||
size_t size_prev_slice) override {
|
||||
throw std::logic_error("FederatedEngine:: Allgather is not supported");
|
||||
}
|
||||
|
||||
std::string Allgather(void *sendbuf, size_t total_size) {
|
||||
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
|
||||
return client_->Allgather(send_buffer);
|
||||
}
|
||||
|
||||
void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun, void *prepare_arg) override {
|
||||
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||
}
|
||||
|
||||
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
|
||||
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
|
||||
std::string const send_buffer(buffer, size);
|
||||
auto const receive_buffer =
|
||||
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
|
||||
static_cast<xgboost::federated::ReduceOperation>(op));
|
||||
receive_buffer.copy(buffer, size);
|
||||
}
|
||||
|
||||
int GetRingPrevRank() const override {
|
||||
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
|
||||
}
|
||||
|
||||
void Broadcast(void *sendrecvbuf, size_t size, int root) override {
|
||||
if (world_size_ == 1) return;
|
||||
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
|
||||
std::string const send_buffer(buffer, size);
|
||||
auto const receive_buffer = client_->Broadcast(send_buffer, root);
|
||||
if (rank_ != root) {
|
||||
receive_buffer.copy(buffer, size);
|
||||
}
|
||||
}
|
||||
|
||||
int LoadCheckPoint() override { return 0; }
|
||||
|
||||
void CheckPoint() override { version_number_ += 1; }
|
||||
|
||||
int VersionNumber() const override { return version_number_; }
|
||||
|
||||
/*! \brief get rank of current node */
|
||||
int GetRank() const override { return rank_; }
|
||||
|
||||
/*! \brief get total number of */
|
||||
int GetWorldSize() const override { return world_size_; }
|
||||
|
||||
/*! \brief whether it is distributed */
|
||||
bool IsDistributed() const override { return true; }
|
||||
|
||||
/*! \brief get the host name of current node */
|
||||
std::string GetHost() const override { return "rank" + std::to_string(rank_); }
|
||||
|
||||
void TrackerPrint(const std::string &msg) override {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
void SetParam(std::string const &name, std::string const &val) {
|
||||
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
|
||||
server_address_ = val;
|
||||
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
|
||||
world_size_ = std::stoi(val);
|
||||
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
|
||||
rank_ = std::stoi(val);
|
||||
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
|
||||
server_cert_ = ReadFile(val);
|
||||
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
|
||||
client_key_ = ReadFile(val);
|
||||
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
|
||||
client_cert_ = ReadFile(val);
|
||||
}
|
||||
}
|
||||
|
||||
static std::string ReadFile(std::string const &path) {
|
||||
auto stream = std::ifstream(path.data());
|
||||
std::ostringstream out;
|
||||
out << stream.rdbuf();
|
||||
return out.str();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
std::vector<std::string> const env_vars_{
|
||||
"FEDERATED_SERVER_ADDRESS",
|
||||
"FEDERATED_WORLD_SIZE",
|
||||
"FEDERATED_RANK",
|
||||
"FEDERATED_SERVER_CERT",
|
||||
"FEDERATED_CLIENT_KEY",
|
||||
"FEDERATED_CLIENT_CERT" };
|
||||
// clang-format on
|
||||
std::string server_address_{"localhost:9091"};
|
||||
int world_size_{1};
|
||||
int rank_{0};
|
||||
std::string server_cert_{};
|
||||
std::string client_key_{};
|
||||
std::string client_cert_{};
|
||||
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
||||
int version_number_{0};
|
||||
};
|
||||
|
||||
// Singleton federated engine.
|
||||
FederatedEngine engine; // NOLINT(cert-err58-cpp)
|
||||
|
||||
/*! \brief initialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
try {
|
||||
engine.Init(argc, argv);
|
||||
return true;
|
||||
} catch (std::exception const &e) {
|
||||
fprintf(stderr, " failed in federated Init %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief finalize synchronization module */
|
||||
bool Finalize() {
|
||||
try {
|
||||
engine.Finalize();
|
||||
return true;
|
||||
} catch (const std::exception &e) {
|
||||
fprintf(stderr, "failed in federated shutdown %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine() { return &engine; }
|
||||
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
if (engine.GetWorldSize() == 1) return;
|
||||
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
@ -11,6 +11,40 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief A Federated Learning communicator class that handles collective communication.
|
||||
*/
|
||||
|
||||
@ -3,9 +3,8 @@
|
||||
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
"""
|
||||
|
||||
from . import rabit # noqa
|
||||
from . import tracker # noqa
|
||||
from . import dask
|
||||
from . import collective, dask
|
||||
from .core import (
|
||||
Booster,
|
||||
DataIter,
|
||||
@ -63,4 +62,6 @@ __all__ = [
|
||||
"XGBRFRegressor",
|
||||
# dask
|
||||
"dask",
|
||||
# collective
|
||||
"collective",
|
||||
]
|
||||
|
||||
@ -13,7 +13,7 @@ import pickle
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast, Sequence, Any
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from . import collective
|
||||
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
|
||||
|
||||
|
||||
@ -100,7 +100,7 @@ def _allreduce_metric(score: _ART) -> _ART:
|
||||
as final result.
|
||||
|
||||
'''
|
||||
world = rabit.get_world_size()
|
||||
world = collective.get_world_size()
|
||||
assert world != 0
|
||||
if world == 1:
|
||||
return score
|
||||
@ -108,7 +108,7 @@ def _allreduce_metric(score: _ART) -> _ART:
|
||||
raise ValueError(
|
||||
'xgboost.cv function should not be used in distributed environment.')
|
||||
arr = numpy.array([score])
|
||||
arr = rabit.allreduce(arr, rabit.Op.SUM) / world
|
||||
arr = collective.allreduce(arr, collective.Op.SUM) / world
|
||||
return arr[0]
|
||||
|
||||
|
||||
@ -485,7 +485,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
return False
|
||||
|
||||
msg: str = f'[{epoch}]'
|
||||
if rabit.get_rank() == self.printer_rank:
|
||||
if collective.get_rank() == self.printer_rank:
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
stdv: Optional[float] = None
|
||||
@ -498,7 +498,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
msg += '\n'
|
||||
|
||||
if (epoch % self.period) == 0 or self.period == 1:
|
||||
rabit.tracker_print(msg)
|
||||
collective.communicator_print(msg)
|
||||
self._latest = None
|
||||
else:
|
||||
# There is skipped message
|
||||
@ -506,8 +506,8 @@ class EvaluationMonitor(TrainingCallback):
|
||||
return False
|
||||
|
||||
def after_training(self, model: _Model) -> _Model:
|
||||
if rabit.get_rank() == self.printer_rank and self._latest is not None:
|
||||
rabit.tracker_print(self._latest)
|
||||
if collective.get_rank() == self.printer_rank and self._latest is not None:
|
||||
collective.communicator_print(self._latest)
|
||||
return model
|
||||
|
||||
|
||||
@ -552,7 +552,7 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
||||
('.pkl' if self._as_pickle else '.json'))
|
||||
self._epoch = 0
|
||||
if rabit.get_rank() == 0:
|
||||
if collective.get_rank() == 0:
|
||||
if self._as_pickle:
|
||||
with open(path, 'wb') as fd:
|
||||
pickle.dump(model, fd)
|
||||
|
||||
@ -4,7 +4,7 @@ import json
|
||||
import logging
|
||||
import pickle
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -233,10 +233,11 @@ class CommunicatorContext:
|
||||
def __init__(self, **args: Any) -> None:
|
||||
self.args = args
|
||||
|
||||
def __enter__(self) -> None:
|
||||
def __enter__(self) -> Dict[str, Any]:
|
||||
init(**self.args)
|
||||
assert is_distributed()
|
||||
LOGGER.debug("-------------- communicator say hello ------------------")
|
||||
return self.args
|
||||
|
||||
def __exit__(self, *args: List) -> None:
|
||||
finalize()
|
||||
|
||||
@ -59,7 +59,7 @@ from typing import (
|
||||
|
||||
import numpy
|
||||
|
||||
from . import config, rabit
|
||||
from . import collective, config
|
||||
from ._typing import _T, FeatureNames, FeatureTypes
|
||||
from .callback import TrainingCallback
|
||||
from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
|
||||
@ -112,7 +112,7 @@ TrainReturnT = TypedDict(
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RabitContext",
|
||||
"CommunicatorContext",
|
||||
"DaskDMatrix",
|
||||
"DaskDeviceQuantileDMatrix",
|
||||
"DaskXGBRegressor",
|
||||
@ -158,7 +158,7 @@ def _try_start_tracker(
|
||||
if isinstance(addrs[0], tuple):
|
||||
host_ip = addrs[0][0]
|
||||
port = addrs[0][1]
|
||||
rabit_context = RabitTracker(
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=get_host_ip(host_ip),
|
||||
n_workers=n_workers,
|
||||
port=port,
|
||||
@ -168,12 +168,12 @@ def _try_start_tracker(
|
||||
addr = addrs[0]
|
||||
assert isinstance(addr, str) or addr is None
|
||||
host_ip = get_host_ip(addr)
|
||||
rabit_context = RabitTracker(
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
|
||||
)
|
||||
env.update(rabit_context.worker_envs())
|
||||
rabit_context.start(n_workers)
|
||||
thread = Thread(target=rabit_context.join)
|
||||
env.update(rabit_tracker.worker_envs())
|
||||
rabit_tracker.start(n_workers)
|
||||
thread = Thread(target=rabit_tracker.join)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
except socket.error as e:
|
||||
@ -213,11 +213,11 @@ def _assert_dask_support() -> None:
|
||||
LOGGER.warning(msg)
|
||||
|
||||
|
||||
class RabitContext(rabit.RabitContext):
|
||||
"""A context controlling rabit initialization and finalization."""
|
||||
class CommunicatorContext(collective.CommunicatorContext):
|
||||
"""A context controlling collective communicator initialization and finalization."""
|
||||
|
||||
def __init__(self, args: List[bytes]) -> None:
|
||||
super().__init__(args)
|
||||
def __init__(self, **args: Any) -> None:
|
||||
super().__init__(**args)
|
||||
worker = distributed.get_worker()
|
||||
with distributed.worker_client() as client:
|
||||
info = client.scheduler_info()
|
||||
@ -227,9 +227,7 @@ class RabitContext(rabit.RabitContext):
|
||||
# not the same as task ID is string and "10" is sorted before "2") with dask
|
||||
# worker ID. This outsources the rank assignment to dask and prevents
|
||||
# non-deterministic issue.
|
||||
self.args.append(
|
||||
(f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
|
||||
)
|
||||
self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address)
|
||||
|
||||
|
||||
def dconcat(value: Sequence[_T]) -> _T:
|
||||
@ -811,7 +809,7 @@ def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:
|
||||
|
||||
async def _get_rabit_args(
|
||||
n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
|
||||
) -> List[bytes]:
|
||||
) -> Dict[str, Union[str, int]]:
|
||||
"""Get rabit context arguments from data distribution in DaskDMatrix."""
|
||||
# There are 3 possible different addresses:
|
||||
# 1. Provided by user via dask.config
|
||||
@ -854,9 +852,7 @@ async def _get_rabit_args(
|
||||
env = await client.run_on_scheduler(
|
||||
_start_tracker, n_workers, sched_addr, user_addr
|
||||
)
|
||||
|
||||
rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
|
||||
return rabit_args
|
||||
return env
|
||||
|
||||
|
||||
def _get_dask_config() -> Optional[Dict[str, Any]]:
|
||||
@ -911,7 +907,7 @@ async def _train_async(
|
||||
|
||||
def dispatched_train(
|
||||
parameters: Dict,
|
||||
rabit_args: List[bytes],
|
||||
rabit_args: Dict[str, Union[str, int]],
|
||||
train_id: int,
|
||||
evals_name: List[str],
|
||||
evals_id: List[int],
|
||||
@ -935,7 +931,7 @@ async def _train_async(
|
||||
n_threads = dwnt
|
||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||
local_history: TrainingCallback.EvalsLog = {}
|
||||
with RabitContext(rabit_args), config.config_context(**global_config):
|
||||
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
|
||||
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
|
||||
evals: List[Tuple[DMatrix, str]] = []
|
||||
for i, ref in enumerate(refs):
|
||||
|
||||
@ -1,249 +0,0 @@
|
||||
"""Distributed XGBoost Rabit related API."""
|
||||
import ctypes
|
||||
from enum import IntEnum, unique
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .core import _LIB, c_str, _check_call
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.rabit]")
|
||||
|
||||
|
||||
def _init_rabit() -> None:
|
||||
"""internal library initializer."""
|
||||
if _LIB is not None:
|
||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
||||
_LIB.RabitIsDistributed.restype = ctypes.c_int
|
||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||
|
||||
|
||||
def init(args: Optional[List[bytes]] = None) -> None:
|
||||
"""Initialize the rabit library with arguments"""
|
||||
if args is None:
|
||||
args = []
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
|
||||
_LIB.RabitInit(len(arr), arr)
|
||||
|
||||
|
||||
def finalize() -> None:
|
||||
"""Finalize the process, notify tracker everything is done."""
|
||||
_LIB.RabitFinalize()
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""Get rank of current process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Rank of current process.
|
||||
"""
|
||||
ret = _LIB.RabitGetRank()
|
||||
return ret
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Get total number workers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
n : int
|
||||
Total number of process.
|
||||
"""
|
||||
ret = _LIB.RabitGetWorldSize()
|
||||
return ret
|
||||
|
||||
|
||||
def is_distributed() -> int:
|
||||
'''If rabit is distributed.'''
|
||||
is_dist = _LIB.RabitIsDistributed()
|
||||
return is_dist
|
||||
|
||||
|
||||
def tracker_print(msg: Any) -> None:
|
||||
"""Print message to the tracker.
|
||||
|
||||
This function can be used to communicate the information of
|
||||
the progress to the tracker
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : str
|
||||
The message to be printed to tracker.
|
||||
"""
|
||||
if not isinstance(msg, str):
|
||||
msg = str(msg)
|
||||
is_dist = _LIB.RabitIsDistributed()
|
||||
if is_dist != 0:
|
||||
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
|
||||
else:
|
||||
print(msg.strip(), flush=True)
|
||||
|
||||
|
||||
def get_processor_name() -> bytes:
|
||||
"""Get the processor name.
|
||||
|
||||
Returns
|
||||
-------
|
||||
name : str
|
||||
the name of processor(host)
|
||||
"""
|
||||
mxlen = 256
|
||||
length = ctypes.c_ulong()
|
||||
buf = ctypes.create_string_buffer(mxlen)
|
||||
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
|
||||
return buf.value
|
||||
|
||||
|
||||
T = TypeVar("T") # pylint:disable=invalid-name
|
||||
|
||||
|
||||
def broadcast(data: T, root: int) -> T:
|
||||
"""Broadcast object from one node to all other nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : any type that can be pickled
|
||||
Input data, if current rank does not equal root, this can be None
|
||||
root : int
|
||||
Rank of the node to broadcast data from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object : int
|
||||
the result of broadcast.
|
||||
"""
|
||||
rank = get_rank()
|
||||
length = ctypes.c_ulong()
|
||||
if root == rank:
|
||||
assert data is not None, 'need to pass in data when broadcasting'
|
||||
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
length.value = len(s)
|
||||
# run first broadcast
|
||||
_check_call(_LIB.RabitBroadcast(ctypes.byref(length),
|
||||
ctypes.sizeof(ctypes.c_ulong), root))
|
||||
if root != rank:
|
||||
dptr = (ctypes.c_char * length.value)()
|
||||
# run second
|
||||
_check_call(_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
||||
length.value, root))
|
||||
data = pickle.loads(dptr.raw)
|
||||
del dptr
|
||||
else:
|
||||
_check_call(_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
||||
length.value, root))
|
||||
del s
|
||||
return data
|
||||
|
||||
|
||||
# enumeration of dtypes
|
||||
DTYPE_ENUM__ = {
|
||||
np.dtype('int8'): 0,
|
||||
np.dtype('uint8'): 1,
|
||||
np.dtype('int32'): 2,
|
||||
np.dtype('uint32'): 3,
|
||||
np.dtype('int64'): 4,
|
||||
np.dtype('uint64'): 5,
|
||||
np.dtype('float32'): 6,
|
||||
np.dtype('float64'): 7
|
||||
}
|
||||
|
||||
|
||||
@unique
|
||||
class Op(IntEnum):
|
||||
'''Supported operations for rabit.'''
|
||||
MAX = 0
|
||||
MIN = 1
|
||||
SUM = 2
|
||||
OR = 3
|
||||
|
||||
|
||||
def allreduce( # pylint:disable=invalid-name
|
||||
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||
) -> np.ndarray:
|
||||
"""Perform allreduce, return the result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data :
|
||||
Input data.
|
||||
op :
|
||||
Reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun :
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to initialize the data
|
||||
If the result of Allreduce can be recovered directly,
|
||||
then prepare_fun will NOT be called
|
||||
|
||||
Returns
|
||||
-------
|
||||
result :
|
||||
The result of allreduce, have same shape as data
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is not thread-safe.
|
||||
"""
|
||||
if not isinstance(data, np.ndarray):
|
||||
raise Exception('allreduce only takes in numpy.ndarray')
|
||||
buf = data.ravel()
|
||||
if buf.base is data.base:
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise Exception(f"data type {buf.dtype} not supported")
|
||||
if prepare_fun is None:
|
||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
int(op), None, None))
|
||||
else:
|
||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
|
||||
def pfunc(_: Any) -> None:
|
||||
"""prepare function."""
|
||||
fn = cast(Callable[[np.ndarray], None], prepare_fun)
|
||||
fn(data)
|
||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, func_ptr(pfunc), None))
|
||||
return buf
|
||||
|
||||
|
||||
def version_number() -> int:
|
||||
"""Returns version number of current stored model.
|
||||
|
||||
This means how many calls to CheckPoint we made so far.
|
||||
|
||||
Returns
|
||||
-------
|
||||
version : int
|
||||
Version number of currently stored model
|
||||
"""
|
||||
ret = _LIB.RabitVersionNumber()
|
||||
return ret
|
||||
|
||||
|
||||
class RabitContext:
|
||||
"""A context controlling rabit initialization and finalization."""
|
||||
|
||||
def __init__(self, args: List[bytes] = None) -> None:
|
||||
if args is None:
|
||||
args = []
|
||||
self.args = args
|
||||
|
||||
def __enter__(self) -> None:
|
||||
init(self.args)
|
||||
assert is_distributed()
|
||||
LOGGER.debug("-------------- rabit say hello ------------------")
|
||||
|
||||
def __exit__(self, *args: List) -> None:
|
||||
finalize()
|
||||
LOGGER.debug("--------------- rabit say bye ------------------")
|
||||
|
||||
|
||||
# initialization script
|
||||
_init_rabit()
|
||||
@ -2,6 +2,7 @@
|
||||
"""Xgboost pyspark integration submodule for core code."""
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=too-few-public-methods, too-many-lines
|
||||
import json
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -57,7 +58,7 @@ from .params import (
|
||||
HasQueryIdCol,
|
||||
)
|
||||
from .utils import (
|
||||
RabitContext,
|
||||
CommunicatorContext,
|
||||
_get_args_from_message_list,
|
||||
_get_default_params_from_func,
|
||||
_get_gpu_id,
|
||||
@ -769,7 +770,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
):
|
||||
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
|
||||
|
||||
_rabit_args = ""
|
||||
_rabit_args = {}
|
||||
if context.partitionId() == 0:
|
||||
get_logger("XGBoostPySpark").info(
|
||||
"booster params: %s\n"
|
||||
@ -780,12 +781,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
dmatrix_kwargs,
|
||||
)
|
||||
|
||||
_rabit_args = str(_get_rabit_args(context, num_workers))
|
||||
_rabit_args = _get_rabit_args(context, num_workers)
|
||||
|
||||
messages = context.allGather(message=str(_rabit_args))
|
||||
messages = context.allGather(message=json.dumps(_rabit_args))
|
||||
_rabit_args = _get_args_from_message_list(messages)
|
||||
evals_result = {}
|
||||
with RabitContext(_rabit_args, context):
|
||||
with CommunicatorContext(context, **_rabit_args):
|
||||
dtrain, dvalid = create_dmatrix_from_partitions(
|
||||
pandas_df_iter,
|
||||
features_cols_names,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for helper functions."""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from threading import Thread
|
||||
@ -9,7 +10,7 @@ import pyspark
|
||||
from pyspark.sql.session import SparkSession
|
||||
from xgboost.tracker import RabitTracker
|
||||
|
||||
from xgboost import rabit
|
||||
from xgboost import collective
|
||||
|
||||
|
||||
def get_class_name(cls):
|
||||
@ -36,21 +37,21 @@ def _get_default_params_from_func(func, unsupported_set):
|
||||
return filtered_params_dict
|
||||
|
||||
|
||||
class RabitContext:
|
||||
class CommunicatorContext:
|
||||
"""
|
||||
A context controlling rabit initialization and finalization.
|
||||
A context controlling collective communicator initialization and finalization.
|
||||
This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
|
||||
"""
|
||||
|
||||
def __init__(self, args, context):
|
||||
def __init__(self, context, **args):
|
||||
self.args = args
|
||||
self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode())
|
||||
self.args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||
|
||||
def __enter__(self):
|
||||
rabit.init(self.args)
|
||||
collective.init(**self.args)
|
||||
|
||||
def __exit__(self, *args):
|
||||
rabit.finalize()
|
||||
collective.finalize()
|
||||
|
||||
|
||||
def _start_tracker(context, n_workers):
|
||||
@ -74,8 +75,7 @@ def _get_rabit_args(context, n_workers):
|
||||
"""
|
||||
# pylint: disable=consider-using-f-string
|
||||
env = _start_tracker(context, n_workers)
|
||||
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
|
||||
return rabit_args
|
||||
return env
|
||||
|
||||
|
||||
def _get_host_ip(context):
|
||||
@ -95,7 +95,7 @@ def _get_args_from_message_list(messages):
|
||||
if message != "":
|
||||
output = message
|
||||
break
|
||||
return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")]
|
||||
return json.loads(output)
|
||||
|
||||
|
||||
def _get_spark_session():
|
||||
|
||||
@ -6,9 +6,7 @@ set(RABIT_SOURCES
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)
|
||||
|
||||
if (PLUGIN_FEDERATED)
|
||||
# Skip the engine if the Federated Learning plugin is enabled.
|
||||
elseif (RABIT_BUILD_MPI)
|
||||
if (RABIT_BUILD_MPI)
|
||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
|
||||
elseif (RABIT_MOCK)
|
||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
// Copyright (c) 2014-2022 by Contributors
|
||||
#include <rabit/rabit.h>
|
||||
#include <rabit/c_api.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
@ -22,12 +19,11 @@
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "../collective/communicator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/charconv.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
@ -215,7 +211,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
#else
|
||||
if (rabit::IsDistributed()) {
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||
<< "will split data among workers";
|
||||
load_row_split = true;
|
||||
@ -1560,44 +1556,42 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
|
||||
API_END();
|
||||
}
|
||||
|
||||
using xgboost::collective::Communicator;
|
||||
|
||||
XGB_DLL int XGCommunicatorInit(char const* json_config) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||
Json config{Json::Load(StringView{json_config})};
|
||||
Communicator::Init(config);
|
||||
collective::Init(config);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorFinalize() {
|
||||
API_BEGIN();
|
||||
Communicator::Finalize();
|
||||
collective::Finalize();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetRank() {
|
||||
return Communicator::Get()->GetRank();
|
||||
XGB_DLL int XGCommunicatorGetRank(void) {
|
||||
return collective::GetRank();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetWorldSize() {
|
||||
return Communicator::Get()->GetWorldSize();
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void) {
|
||||
return collective::GetWorldSize();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorIsDistributed() {
|
||||
return Communicator::Get()->IsDistributed();
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void) {
|
||||
return collective::IsDistributed();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message) {
|
||||
API_BEGIN();
|
||||
Communicator::Get()->Print(message);
|
||||
collective::Print(message);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
||||
API_BEGIN();
|
||||
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
||||
local.ret_str = Communicator::Get()->GetProcessorName();
|
||||
local.ret_str = collective::GetProcessorName();
|
||||
xgboost_CHECK_C_ARG_PTR(name_str);
|
||||
*name_str = local.ret_str.c_str();
|
||||
API_END();
|
||||
@ -1605,16 +1599,14 @@ XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
||||
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
API_BEGIN();
|
||||
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
|
||||
collective::Broadcast(send_receive_buffer, size, root);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
|
||||
int enum_op) {
|
||||
API_BEGIN();
|
||||
Communicator::Get()->AllReduce(
|
||||
send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
|
||||
static_cast<xgboost::collective::Operation>(enum_op));
|
||||
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/common.h"
|
||||
#include "common/config.h"
|
||||
#include "common/io.h"
|
||||
@ -156,7 +157,7 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
if (name_pred == "stdout") {
|
||||
save_period = 0;
|
||||
}
|
||||
if (dsplit == 0 && rabit::IsDistributed()) {
|
||||
if (dsplit == 0 && collective::IsDistributed()) {
|
||||
dsplit = 2;
|
||||
}
|
||||
}
|
||||
@ -186,26 +187,22 @@ class CLI {
|
||||
kHelp
|
||||
} print_info_ {kNone};
|
||||
|
||||
int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
|
||||
void ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
|
||||
learner_.reset(Learner::Create(matrices));
|
||||
int version = rabit::LoadCheckPoint();
|
||||
if (version == 0) {
|
||||
if (param_.model_in != CLIParam::kNull) {
|
||||
this->LoadModel(param_.model_in, learner_.get());
|
||||
learner_->SetParams(param_.cfg);
|
||||
} else {
|
||||
learner_->SetParams(param_.cfg);
|
||||
}
|
||||
}
|
||||
learner_->Configure();
|
||||
return version;
|
||||
}
|
||||
|
||||
void CLITrain() {
|
||||
const double tstart_data_load = dmlc::GetTime();
|
||||
if (rabit::IsDistributed()) {
|
||||
std::string pname = rabit::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
|
||||
if (collective::IsDistributed()) {
|
||||
std::string pname = collective::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
|
||||
}
|
||||
// load in data.
|
||||
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
|
||||
@ -230,48 +227,45 @@ class CLI {
|
||||
eval_data_names.emplace_back("train");
|
||||
}
|
||||
// initialize the learner.
|
||||
int32_t version = this->ResetLearner(cache_mats);
|
||||
this->ResetLearner(cache_mats);
|
||||
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
|
||||
<< " sec";
|
||||
|
||||
// start training.
|
||||
const double start = dmlc::GetTime();
|
||||
int32_t version = 0;
|
||||
for (int i = version / 2; i < param_.num_round; ++i) {
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
if (version % 2 == 0) {
|
||||
LOG(INFO) << "boosting round " << i << ", " << elapsed
|
||||
<< " sec elapsed";
|
||||
learner_->UpdateOneIter(i, dtrain);
|
||||
rabit::CheckPoint();
|
||||
version += 1;
|
||||
}
|
||||
CHECK_EQ(version, rabit::VersionNumber());
|
||||
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
|
||||
if (rabit::IsDistributed()) {
|
||||
if (rabit::GetRank() == 0) {
|
||||
if (collective::IsDistributed()) {
|
||||
if (collective::GetRank() == 0) {
|
||||
LOG(TRACKER) << res;
|
||||
}
|
||||
} else {
|
||||
LOG(CONSOLE) << res;
|
||||
}
|
||||
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
|
||||
rabit::GetRank() == 0) {
|
||||
collective::GetRank() == 0) {
|
||||
std::ostringstream os;
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
<< i + 1 << ".model";
|
||||
this->SaveModel(os.str(), learner_.get());
|
||||
}
|
||||
|
||||
rabit::CheckPoint();
|
||||
version += 1;
|
||||
CHECK_EQ(version, rabit::VersionNumber());
|
||||
}
|
||||
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
|
||||
<< " sec";
|
||||
// always save final round
|
||||
if ((param_.save_period == 0 ||
|
||||
param_.num_round % param_.save_period != 0) &&
|
||||
rabit::GetRank() == 0) {
|
||||
collective::GetRank() == 0) {
|
||||
std::ostringstream os;
|
||||
if (param_.model_out == CLIParam::kNull) {
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
@ -467,7 +461,6 @@ class CLI {
|
||||
return;
|
||||
}
|
||||
|
||||
rabit::Init(argc, argv);
|
||||
std::string config_path = argv[1];
|
||||
|
||||
common::ConfigParser cp(config_path);
|
||||
@ -480,6 +473,13 @@ class CLI {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the collective communicator.
|
||||
Json json{JsonObject()};
|
||||
for (auto& kv : cfg) {
|
||||
json[kv.first] = String(kv.second);
|
||||
}
|
||||
collective::Init(json);
|
||||
|
||||
param_.Configure(cfg);
|
||||
}
|
||||
|
||||
@ -517,7 +517,7 @@ class CLI {
|
||||
}
|
||||
|
||||
~CLI() {
|
||||
rabit::Finalize();
|
||||
collective::Finalize();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
208
src/collective/communicator-inl.h
Normal file
208
src/collective/communicator-inl.h
Normal file
@ -0,0 +1,208 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * mpi: Use MPI.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
* - rabit_debug: Enable debugging.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
inline void Init(Json const& config) {
|
||||
Communicator::Init(config);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
*/
|
||||
inline void Finalize() { Communicator::Finalize(); }
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
*/
|
||||
inline int GetRank() { return Communicator::Get()->GetRank(); }
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
*/
|
||||
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
*/
|
||||
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
*/
|
||||
inline void Print(char const *message) { Communicator::Get()->Print(message); }
|
||||
|
||||
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
*
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
|
||||
* ...
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*/
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
|
||||
static_cast<Operation>(op));
|
||||
}
|
||||
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
// Specialization for size_t, which is implementation defined, so it might or might not
|
||||
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
template <Operation op, typename T,
|
||||
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
|
||||
inline void Allreduce(T *send_receive_buffer, size_t count) {
|
||||
static_assert(sizeof(T) == sizeof(uint64_t), "");
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(float *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
@ -12,14 +13,10 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
|
||||
void Communicator::Init(Json const& config) {
|
||||
if (communicator_) {
|
||||
LOG(FATAL) << "Communicator can only be initialized once.";
|
||||
}
|
||||
|
||||
auto type = GetTypeFromEnv();
|
||||
auto const arg = GetTypeFromConfig(config);
|
||||
if (arg != CommunicatorType::kUnknown) {
|
||||
@ -51,7 +48,7 @@ void Communicator::Init(Json const& config) {
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(nullptr);
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "device_communicator_adapter.cuh"
|
||||
#include "noop_communicator.h"
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl_device_communicator.cuh"
|
||||
#endif
|
||||
@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat
|
||||
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(nullptr);
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
device_ordinal_ = -1;
|
||||
device_communicator_.reset(nullptr);
|
||||
}
|
||||
|
||||
@ -23,40 +23,6 @@ enum class DataType {
|
||||
kDouble = 7
|
||||
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/** @brief Defines the reduction operation. */
|
||||
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
|
||||
|
||||
|
||||
@ -21,7 +21,28 @@ class DeviceCommunicator {
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
|
||||
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
|
||||
@ -23,17 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(double);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
@ -66,6 +77,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
private:
|
||||
template <collective::DataType data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(T);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
|
||||
@ -24,6 +24,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
int32_t const rank = communicator_->GetRank();
|
||||
int32_t const world = communicator_->GetWorldSize();
|
||||
|
||||
if (world == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||
@ -52,8 +56,15 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
~NcclDeviceCommunicator() override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
if (cuda_stream_) {
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
ncclCommDestroy(nccl_comm_);
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
@ -61,16 +72,28 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
|
||||
ncclSum, nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
@ -95,6 +118,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
}
|
||||
@ -136,6 +162,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return id;
|
||||
}
|
||||
|
||||
template <ncclDataType_t data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
||||
nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(T);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
|
||||
30
src/collective/noop_communicator.h
Normal file
30
src/collective/noop_communicator.h
Normal file
@ -0,0 +1,30 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* A no-op communicator, used for non-distributed training.
|
||||
*/
|
||||
class NoOpCommunicator : public Communicator {
|
||||
public:
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {}
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
|
||||
std::string GetProcessorName() override { return ""; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -1,139 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2017-2019 XGBoost contributors
|
||||
*
|
||||
* \brief Utilities for CUDA.
|
||||
*/
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include <nccl.h>
|
||||
#endif // #ifdef XGBOOST_USE_NCCL
|
||||
#include <sstream>
|
||||
|
||||
#include "device_helpers.cuh"
|
||||
|
||||
namespace dh {
|
||||
|
||||
constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
|
||||
void GetCudaUUID(int device_ord, xgboost::common::Span<uint64_t, kUuidLength> uuid) {
|
||||
cudaDeviceProp prob;
|
||||
safe_cuda(cudaGetDeviceProperties(&prob, device_ord));
|
||||
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
|
||||
std::stringstream ss;
|
||||
for (auto v : uuid) {
|
||||
ss << std::hex << v;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
void NcclAllReducer::DoInit(int _device_ordinal) {
|
||||
int32_t const rank = rabit::GetRank();
|
||||
int32_t const world = rabit::GetWorldSize();
|
||||
if (world == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(_device_ordinal, s_this_uuid);
|
||||
|
||||
// No allgather yet.
|
||||
rabit::Allreduce<rabit::op::Sum, uint64_t>(uuids.data(), uuids.size());
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);;
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] =
|
||||
xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
j++;
|
||||
}
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
|
||||
id_ = GetUniqueId();
|
||||
dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
|
||||
safe_cuda(cudaStreamCreate(&stream_));
|
||||
}
|
||||
|
||||
void NcclAllReducer::DoAllGather(void const *data, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *recvbuf) {
|
||||
int32_t world = rabit::GetWorldSize();
|
||||
segments->clear();
|
||||
segments->resize(world, 0);
|
||||
segments->at(rabit::GetRank()) = length_bytes;
|
||||
rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0);
|
||||
recvbuf->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
safe_nccl(ncclGroupStart());
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
safe_nccl(
|
||||
ncclBroadcast(data, recvbuf->data().get() + offset,
|
||||
as_bytes, ncclChar, i, comm_, stream_));
|
||||
offset += as_bytes;
|
||||
}
|
||||
safe_nccl(ncclGroupEnd());
|
||||
}
|
||||
|
||||
NcclAllReducer::~NcclAllReducer() {
|
||||
if (initialised_) {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream_));
|
||||
ncclCommDestroy(comm_);
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
|
||||
}
|
||||
}
|
||||
#else
|
||||
void RabitAllReducer::DoInit(int _device_ordinal) {
|
||||
#if !defined(XGBOOST_USE_FEDERATED)
|
||||
if (rabit::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost is not compiled with NCCL, falling back to Rabit.";
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void RabitAllReducer::DoAllGather(void const *data, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *recvbuf) {
|
||||
size_t world = rabit::GetWorldSize();
|
||||
segments->clear();
|
||||
segments->resize(world, 0);
|
||||
segments->at(rabit::GetRank()) = length_bytes;
|
||||
rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
recvbuf->resize(total_bytes);
|
||||
|
||||
sendrecvbuf_.reserve(total_bytes);
|
||||
auto rank = rabit::GetRank();
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank) {
|
||||
safe_cuda(
|
||||
cudaMemcpy(sendrecvbuf_.data() + offset, data, segments->at(rank), cudaMemcpyDefault));
|
||||
}
|
||||
rabit::Broadcast(sendrecvbuf_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
safe_cuda(cudaMemcpy(recvbuf->data().get(), sendrecvbuf_.data(), total_bytes, cudaMemcpyDefault));
|
||||
}
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
|
||||
} // namespace dh
|
||||
@ -19,7 +19,6 @@
|
||||
#include <thrust/unique.h>
|
||||
#include <thrust/binary_search.h>
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_allocator.cuh>
|
||||
|
||||
@ -36,6 +35,7 @@
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/global_config.h"
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "common.h"
|
||||
#include "algorithm.cuh"
|
||||
|
||||
@ -404,7 +404,7 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
|
||||
// dh::DebugSyncDevice(__FILE__, __LINE__);
|
||||
inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
|
||||
if (file != "" && line != -1) {
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = xgboost::collective::GetRank();
|
||||
LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line;
|
||||
}
|
||||
safe_cuda(cudaDeviceSynchronize());
|
||||
@ -423,7 +423,7 @@ using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
|
||||
|
||||
inline void ThrowOOMError(std::string const& err, size_t bytes) {
|
||||
auto device = CurrentDevice();
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = xgboost::collective::GetRank();
|
||||
std::stringstream ss;
|
||||
ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
|
||||
<< "- Free memory: " << AvailableMemory(device) << "\n"
|
||||
@ -737,512 +737,6 @@ using TypedDiscard =
|
||||
std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
|
||||
detail::TypedDiscard<T>>;
|
||||
|
||||
/**
|
||||
* \class AllReducer
|
||||
*
|
||||
* \brief All reducer class that manages its own communication group and
|
||||
* streams. Must be initialised before use. If XGBoost is compiled without NCCL,
|
||||
* this falls back to use Rabit.
|
||||
*/
|
||||
template <typename AllReducer>
|
||||
class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
public:
|
||||
virtual ~AllReducerBase() = default;
|
||||
|
||||
/**
|
||||
* \brief Initialise with the desired device ordinal for this allreducer.
|
||||
*
|
||||
* \param device_ordinal The device ordinal.
|
||||
*/
|
||||
void Init(int _device_ordinal) {
|
||||
device_ordinal_ = _device_ordinal;
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
this->Underlying().DoInit(_device_ordinal);
|
||||
initialised_ = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
|
||||
* different size of data on different workers.
|
||||
*
|
||||
* \param data Buffer storing the input data.
|
||||
* \param length_bytes Size of input data in bytes.
|
||||
* \param segments Size of data on each worker.
|
||||
* \param recvbuf Buffer storing the result of data from all workers.
|
||||
*/
|
||||
void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *recvbuf) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allgather. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param data Buffer storing the input data.
|
||||
* \param length Size of input data in bytes.
|
||||
* \param recvbuf Buffer storing the result of data from all workers.
|
||||
*/
|
||||
void AllGather(uint32_t const *data, size_t length,
|
||||
dh::caching_device_vector<uint32_t> *recvbuf) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllGather(data, length, recvbuf);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
allreduce_bytes_ += count * sizeof(float);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
|
||||
*
|
||||
* \param count Number of.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of.
|
||||
*/
|
||||
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
allreduce_bytes_ += count * sizeof(int64_t);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
allreduce_bytes_ += count * sizeof(uint32_t);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
allreduce_bytes_ += count * sizeof(uint64_t);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* Specialization for size_t, which is implementation defined so it might or might not
|
||||
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
template <typename T = size_t,
|
||||
std::enable_if_t<std::is_same<size_t, T>::value &&
|
||||
!std::is_same<size_t, unsigned long long>::value> // NOLINT
|
||||
* = nullptr>
|
||||
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
allreduce_bytes_ += count * sizeof(T);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn void Synchronize()
|
||||
*
|
||||
* \brief Synchronizes the entire communication group.
|
||||
*/
|
||||
void Synchronize() {
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoSynchronize();
|
||||
}
|
||||
|
||||
protected:
|
||||
bool initialised_{false};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
|
||||
private:
|
||||
int device_ordinal_{-1};
|
||||
};
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
class NcclAllReducer : public AllReducerBase<NcclAllReducer> {
|
||||
public:
|
||||
friend class AllReducerBase<NcclAllReducer>;
|
||||
|
||||
~NcclAllReducer() override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* \brief Initialise with the desired device ordinal for this communication
|
||||
* group.
|
||||
*
|
||||
* \param device_ordinal The device ordinal.
|
||||
*/
|
||||
void DoInit(int _device_ordinal);
|
||||
|
||||
/**
|
||||
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
|
||||
* different size of data on different workers.
|
||||
*
|
||||
* \param data Buffer storing the input data.
|
||||
* \param length_bytes Size of input data in bytes.
|
||||
* \param segments Size of data on each worker.
|
||||
* \param recvbuf Buffer storing the result of data from all workers.
|
||||
*/
|
||||
void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *recvbuf);
|
||||
|
||||
/**
|
||||
* \brief Allgather. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param data Buffer storing the input data.
|
||||
* \param length Size of input data in bytes.
|
||||
* \param recvbuf Buffer storing the result of data from all workers.
|
||||
*/
|
||||
void DoAllGather(uint32_t const *data, size_t length,
|
||||
dh::caching_device_vector<uint32_t> *recvbuf) {
|
||||
size_t world = rabit::GetWorldSize();
|
||||
recvbuf->resize(length * world);
|
||||
safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
|
||||
*
|
||||
* \param count Number of.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of.
|
||||
*/
|
||||
void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* Specialization for size_t, which is implementation defined so it might or might not
|
||||
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
template <typename T = size_t,
|
||||
std::enable_if_t<std::is_same<size_t, T>::value &&
|
||||
!std::is_same<size_t, unsigned long long>::value> // NOLINT
|
||||
* = nullptr>
|
||||
void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Synchronizes the entire communication group.
|
||||
*/
|
||||
void DoSynchronize() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
|
||||
|
||||
/**
|
||||
* \fn ncclUniqueId GetUniqueId()
|
||||
*
|
||||
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||
* communication
|
||||
*
|
||||
* \return the Unique ID
|
||||
*/
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rabit::GetRank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
rabit::Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
ncclComm_t comm_;
|
||||
cudaStream_t stream_;
|
||||
ncclUniqueId id_;
|
||||
};
|
||||
|
||||
using AllReducer = NcclAllReducer;
|
||||
#else
|
||||
class RabitAllReducer : public AllReducerBase<RabitAllReducer> {
|
||||
public:
|
||||
friend class AllReducerBase<RabitAllReducer>;
|
||||
|
||||
private:
|
||||
/**
|
||||
* \brief Initialise with the desired device ordinal for this allreducer.
|
||||
*
|
||||
* \param device_ordinal The device ordinal.
|
||||
*/
|
||||
static void DoInit(int _device_ordinal);
|
||||
|
||||
/**
|
||||
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
|
||||
* different size of data on different workers.
|
||||
*
|
||||
* \param data Buffer storing the input data.
|
||||
* \param length_bytes Size of input data in bytes.
|
||||
* \param segments Size of data on each worker.
|
||||
* \param recvbuf Buffer storing the result of data from all workers.
|
||||
*/
|
||||
void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *recvbuf);
|
||||
|
||||
/**
|
||||
* \brief Allgather. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* \param data Buffer storing the input data.
|
||||
* \param length Size of input data in bytes.
|
||||
* \param recvbuf Buffer storing the result of data from all workers.
|
||||
*/
|
||||
void DoAllGather(uint32_t *data, size_t length, dh::caching_device_vector<uint32_t> *recvbuf) {
|
||||
size_t world = rabit::GetWorldSize();
|
||||
auto total_size = length * world;
|
||||
recvbuf->resize(total_size);
|
||||
sendrecvbuf_.reserve(total_size);
|
||||
auto rank = rabit::GetRank();
|
||||
safe_cuda(cudaMemcpy(sendrecvbuf_.data() + rank * length, data, length, cudaMemcpyDefault));
|
||||
rabit::Allgather(sendrecvbuf_.data(), total_size, rank * length, length, length);
|
||||
safe_cuda(cudaMemcpy(data, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
|
||||
RabitAllReduceSum(sendbuff, recvbuff, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
|
||||
RabitAllReduceSum(sendbuff, recvbuff, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
|
||||
RabitAllReduceSum(sendbuff, recvbuff, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
|
||||
RabitAllReduceSum(sendbuff, recvbuff, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
|
||||
RabitAllReduceSum(sendbuff, recvbuff, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* Specialization for size_t, which is implementation defined so it might or might not
|
||||
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
template <typename T = size_t,
|
||||
std::enable_if_t<std::is_same<size_t, T>::value &&
|
||||
!std::is_same<size_t, unsigned long long>::value> // NOLINT
|
||||
* = nullptr>
|
||||
void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
|
||||
RabitAllReduceSum(sendbuff, recvbuff, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Synchronizes the allreducer.
|
||||
*/
|
||||
void DoSynchronize() {}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL.
|
||||
*
|
||||
* Copy the device buffer to host, call rabit allreduce, then copy the buffer back
|
||||
* to device.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
template <typename T>
|
||||
void RabitAllReduceSum(const T *sendbuff, T *recvbuff, int count) {
|
||||
auto total_size = count * sizeof(T);
|
||||
sendrecvbuf_.reserve(total_size);
|
||||
safe_cuda(cudaMemcpy(sendrecvbuf_.data(), sendbuff, total_size, cudaMemcpyDefault));
|
||||
rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<T*>(sendrecvbuf_.data()), count);
|
||||
safe_cuda(cudaMemcpy(recvbuff, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
/// Host buffer used to call rabit functions.
|
||||
std::vector<char> sendrecvbuf_{};
|
||||
};
|
||||
|
||||
using AllReducer = RabitAllReducer;
|
||||
#endif
|
||||
|
||||
template <typename VectorT, typename T = typename VectorT::value_type,
|
||||
typename IndexT = typename xgboost::common::Span<T>::index_type>
|
||||
xgboost::common::Span<T> ToSpan(
|
||||
|
||||
@ -3,19 +3,14 @@
|
||||
* \file hist_util.cc
|
||||
*/
|
||||
#include <dmlc/timer.h>
|
||||
#include <dmlc/omp.h>
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "../common/common.h"
|
||||
#include "hist_util.h"
|
||||
#include "random.h"
|
||||
#include "column_matrix.h"
|
||||
#include "quantile.h"
|
||||
#include "../data/gradient_index.h"
|
||||
|
||||
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
||||
#include <xmmintrin.h>
|
||||
|
||||
@ -6,10 +6,10 @@
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "categorical.h"
|
||||
#include "hist_util.h"
|
||||
#include "rabit/rabit.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -144,8 +144,8 @@ struct QuantileAllreduce {
|
||||
void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,
|
||||
std::vector<std::set<float>> *p_categories) {
|
||||
auto &categories = *p_categories;
|
||||
auto world_size = rabit::GetWorldSize();
|
||||
auto rank = rabit::GetRank();
|
||||
auto world_size = collective::GetWorldSize();
|
||||
auto rank = collective::GetRank();
|
||||
if (world_size == 1) {
|
||||
return;
|
||||
}
|
||||
@ -163,7 +163,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
|
||||
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
|
||||
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
|
||||
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
|
||||
rabit::Allreduce<rabit::op::Sum>(global_feat_ptrs.data(), global_feat_ptrs.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
|
||||
global_feat_ptrs.size());
|
||||
|
||||
// move all categories into a flatten vector to prepare for allreduce
|
||||
size_t total = feature_ptr.back();
|
||||
@ -176,7 +177,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
|
||||
// indptr for indexing workers
|
||||
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
|
||||
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
|
||||
rabit::Allreduce<rabit::op::Sum>(global_worker_ptr.data(), global_worker_ptr.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
|
||||
global_worker_ptr.size());
|
||||
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
|
||||
// total number of categories in all workers with all features
|
||||
auto gtotal = global_worker_ptr.back();
|
||||
@ -188,7 +190,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
|
||||
CHECK_EQ(rank_size, total);
|
||||
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
|
||||
// gather values from all workers.
|
||||
rabit::Allreduce<rabit::op::Sum>(global_categories.data(), global_categories.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
|
||||
global_categories.size());
|
||||
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
|
||||
categories.size()};
|
||||
ParallelFor(categories.size(), n_threads, [&](auto fidx) {
|
||||
@ -217,8 +220,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches) {
|
||||
auto &worker_segments = *p_worker_segments;
|
||||
worker_segments.resize(1, 0);
|
||||
auto world = rabit::GetWorldSize();
|
||||
auto rank = rabit::GetRank();
|
||||
auto world = collective::GetWorldSize();
|
||||
auto rank = collective::GetRank();
|
||||
auto n_columns = sketches_.size();
|
||||
|
||||
// get the size of each feature.
|
||||
@ -237,7 +240,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
|
||||
|
||||
// Gather all column pointers
|
||||
rabit::Allreduce<rabit::op::Sum>(sketches_scan.data(), sketches_scan.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
size_t back = (i + 1) * (n_columns + 1) - 1;
|
||||
auto n_entries = sketches_scan.at(back);
|
||||
@ -265,7 +268,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
|
||||
static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
|
||||
"Unexpected size of sketch entry.");
|
||||
rabit::Allreduce<rabit::op::Sum>(
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<float *>(global_sketches.data()),
|
||||
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
|
||||
}
|
||||
@ -277,7 +280,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
monitor_.Start(__func__);
|
||||
|
||||
size_t n_columns = sketches_.size();
|
||||
rabit::Allreduce<rabit::op::Max>(&n_columns, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
|
||||
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
|
||||
|
||||
AllreduceCategories(feature_types_, n_threads_, &categories_);
|
||||
@ -291,7 +294,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
|
||||
// Prune the intermediate num cuts for synchronization.
|
||||
std::vector<bst_row_t> global_column_size(columns_size_);
|
||||
rabit::Allreduce<rabit::op::Sum>(global_column_size.data(), global_column_size.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
|
||||
global_column_size.size());
|
||||
|
||||
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
|
||||
int32_t intermediate_num_cuts = static_cast<int32_t>(
|
||||
@ -311,7 +315,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
num_cuts[i] = intermediate_num_cuts;
|
||||
});
|
||||
|
||||
auto world = rabit::GetWorldSize();
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world == 1) {
|
||||
monitor_.Stop(__func__);
|
||||
return;
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator.h"
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "device_helpers.cuh"
|
||||
@ -501,47 +503,41 @@ void SketchContainer::FixError() {
|
||||
|
||||
void SketchContainer::AllReduce() {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
auto world = rabit::GetWorldSize();
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
timer_.Start(__func__);
|
||||
if (!reducer_) {
|
||||
reducer_ = std::make_shared<dh::AllReducer>();
|
||||
reducer_->Init(device_);
|
||||
}
|
||||
auto* communicator = collective::Communicator::GetDevice(device_);
|
||||
// Reduce the overhead on syncing.
|
||||
size_t global_sum_rows = num_rows_;
|
||||
rabit::Allreduce<rabit::op::Sum>(&global_sum_rows, 1);
|
||||
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
|
||||
size_t intermediate_num_cuts =
|
||||
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
|
||||
this->Prune(intermediate_num_cuts);
|
||||
|
||||
|
||||
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
|
||||
size_t n = d_columns_ptr.size();
|
||||
rabit::Allreduce<rabit::op::Max>(&n, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&n, 1);
|
||||
CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";
|
||||
|
||||
// Get the columns ptr from all workers
|
||||
dh::device_vector<SketchContainer::OffsetT> gathered_ptrs;
|
||||
gathered_ptrs.resize(d_columns_ptr.size() * world, 0);
|
||||
size_t rank = rabit::GetRank();
|
||||
size_t rank = collective::GetRank();
|
||||
auto offset = rank * d_columns_ptr.size();
|
||||
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
|
||||
gathered_ptrs.begin() + offset);
|
||||
reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(),
|
||||
gathered_ptrs.size());
|
||||
communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
|
||||
|
||||
// Get the data from all workers.
|
||||
std::vector<size_t> recv_lengths;
|
||||
dh::caching_device_vector<char> recvbuf;
|
||||
reducer_->AllGather(this->Current().data().get(),
|
||||
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths,
|
||||
&recvbuf);
|
||||
reducer_->Synchronize();
|
||||
communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
|
||||
&recv_lengths, &recvbuf);
|
||||
communicator->Synchronize();
|
||||
|
||||
// Segment the received data.
|
||||
auto s_recvbuf = dh::ToSpan(recvbuf);
|
||||
|
||||
@ -37,7 +37,6 @@ class SketchContainer {
|
||||
|
||||
private:
|
||||
Monitor timer_;
|
||||
std::shared_ptr<dh::AllReducer> reducer_;
|
||||
HostDeviceVector<FeatureType> feature_types_;
|
||||
bst_row_t num_rows_;
|
||||
bst_feature_t num_columns_;
|
||||
@ -93,15 +92,12 @@ class SketchContainer {
|
||||
* \param num_columns Total number of columns in dataset.
|
||||
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
|
||||
* \param device GPU ID.
|
||||
* \param reducer Optional initialised reducer. Useful for speeding up testing.
|
||||
*/
|
||||
SketchContainer(HostDeviceVector<FeatureType> const &feature_types,
|
||||
int32_t max_bin, bst_feature_t num_columns,
|
||||
bst_row_t num_rows, int32_t device,
|
||||
std::shared_ptr<dh::AllReducer> reducer = nullptr)
|
||||
bst_row_t num_rows, int32_t device)
|
||||
: num_rows_{num_rows},
|
||||
num_columns_{num_columns}, num_bins_{max_bin}, device_{device},
|
||||
reducer_(std::move(reducer)) {
|
||||
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
CHECK_GE(device, 0);
|
||||
// Initialize Sketches for this dmatrix
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
|
||||
@ -7,20 +7,21 @@
|
||||
#ifndef XGBOOST_COMMON_RANDOM_H_
|
||||
#define XGBOOST_COMMON_RANDOM_H_
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "common.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -143,7 +144,7 @@ class ColumnSampler {
|
||||
*/
|
||||
ColumnSampler() {
|
||||
uint32_t seed = common::GlobalRandom()();
|
||||
rabit::Broadcast(&seed, sizeof(seed), 0);
|
||||
collective::Broadcast(&seed, sizeof(seed), 0);
|
||||
rng_.seed(seed);
|
||||
}
|
||||
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2019
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include "timer.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
#include <nvToolsExt.h>
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
@ -54,7 +53,7 @@ void Monitor::PrintStatistics(StatMap const& statistics) const {
|
||||
|
||||
void Monitor::Print() const {
|
||||
if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; }
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
StatMap stat_map;
|
||||
for (auto const &kv : statistics_map_) {
|
||||
stat_map[kv.first] = std::make_pair(
|
||||
|
||||
@ -2,36 +2,36 @@
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include "xgboost/data.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
|
||||
#include "dmlc/io.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/version_config.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/string_view.h"
|
||||
|
||||
#include "sparse_page_writer.h"
|
||||
#include "simple_dmatrix.h"
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/version.h"
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/version.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/iterative_dmatrix.h"
|
||||
#include "file_iterator.h"
|
||||
|
||||
#include "validation.h"
|
||||
#include "./sparse_page_source.h"
|
||||
#include "./sparse_page_dmatrix.h"
|
||||
#include "./sparse_page_source.h"
|
||||
#include "dmlc/io.h"
|
||||
#include "file_iterator.h"
|
||||
#include "simple_dmatrix.h"
|
||||
#include "sparse_page_writer.h"
|
||||
#include "validation.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h"
|
||||
#include "xgboost/version_config.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
|
||||
@ -793,12 +793,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
size_t pos = cache_shards[i].rfind('.');
|
||||
if (pos == std::string::npos) {
|
||||
os << cache_shards[i]
|
||||
<< ".r" << rabit::GetRank()
|
||||
<< "-" << rabit::GetWorldSize();
|
||||
<< ".r" << collective::GetRank()
|
||||
<< "-" << collective::GetWorldSize();
|
||||
} else {
|
||||
os << cache_shards[i].substr(0, pos)
|
||||
<< ".r" << rabit::GetRank()
|
||||
<< "-" << rabit::GetWorldSize()
|
||||
<< ".r" << collective::GetRank()
|
||||
<< "-" << collective::GetWorldSize()
|
||||
<< cache_shards[i].substr(pos, cache_shards[i].length());
|
||||
}
|
||||
if (i + 1 != cache_shards.size()) {
|
||||
@ -821,8 +821,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
|
||||
int partid = 0, npart = 1;
|
||||
if (load_row_split) {
|
||||
partid = rabit::GetRank();
|
||||
npart = rabit::GetWorldSize();
|
||||
partid = collective::GetRank();
|
||||
npart = collective::GetWorldSize();
|
||||
} else {
|
||||
// test option to load in part
|
||||
npart = 1;
|
||||
@ -877,7 +877,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
/* sync up number of features after matrix loaded.
|
||||
* partitioned data will fail the train/val validation check
|
||||
* since partitioned data not knowing the real number of features. */
|
||||
rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
|
||||
return dmat;
|
||||
}
|
||||
|
||||
|
||||
@ -3,13 +3,11 @@
|
||||
*/
|
||||
#include "iterative_dmatrix.h"
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <algorithm> // std::copy
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/categorical.h" // common::IsCat
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h" // common::HistogramCuts
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
@ -140,7 +138,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
// We use do while here as the first batch is fetched in ctor
|
||||
if (n_features == 0) {
|
||||
n_features = num_cols();
|
||||
rabit::Allreduce<rabit::op::Max>(&n_features, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
|
||||
column_sizes.resize(n_features);
|
||||
info_.num_col_ = n_features;
|
||||
} else {
|
||||
@ -157,7 +155,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
// From here on Info() has the correct data shape
|
||||
Info().num_row_ = accumulated_rows;
|
||||
Info().num_nonzero_ = nnz;
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
|
||||
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
|
||||
return f > accumulated_rows;
|
||||
})) << "Something went wrong during iteration.";
|
||||
|
||||
@ -62,7 +62,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
rabit::Allreduce<rabit::op::Max>(&cols, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&cols, 1);
|
||||
this->info_.num_col_ = cols;
|
||||
} else {
|
||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||
@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
|
||||
iter.Reset();
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
|
||||
|
||||
@ -189,7 +189,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
|
||||
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
|
||||
|
||||
if (adapter->NumRows() == kAdapterUnknownSize) {
|
||||
using IteratorAdapterT
|
||||
@ -322,7 +322,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
||||
}
|
||||
// Synchronise worker columns
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
|
||||
info_.num_row_ = total_batch_size;
|
||||
info_.num_nonzero_ = data_vec.size();
|
||||
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
|
||||
|
||||
@ -35,7 +35,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.num_row_ = adapter->NumRows();
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
|
||||
}
|
||||
|
||||
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
|
||||
|
||||
@ -5,6 +5,8 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include "./sparse_page_dmatrix.h"
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "gradient_index.h"
|
||||
|
||||
@ -46,8 +48,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
||||
cache_prefix_{std::move(cache_prefix)} {
|
||||
ctx_.nthread = nthreads;
|
||||
cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_;
|
||||
if (rabit::IsDistributed()) {
|
||||
cache_prefix_ += ("-r" + std::to_string(rabit::GetRank()));
|
||||
if (collective::IsDistributed()) {
|
||||
cache_prefix_ += ("-r" + std::to_string(collective::GetRank()));
|
||||
}
|
||||
DMatrixProxy *proxy = MakeProxy(proxy_);
|
||||
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
||||
@ -94,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
||||
this->info_.num_col_ = n_features;
|
||||
this->info_.num_nonzero_ = nnz;
|
||||
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
|
||||
CHECK_NE(info_.num_col_, 0);
|
||||
}
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
|
||||
@ -135,7 +135,7 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (rabit::IsDistributed()) {
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(INFO) << "Tree method is automatically selected to be 'approx' "
|
||||
"for distributed training.";
|
||||
tparam_.tree_method = TreeMethod::kApprox;
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/charconv.h"
|
||||
#include "common/common.h"
|
||||
#include "common/io.h"
|
||||
@ -478,7 +479,7 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
// add additional parameters
|
||||
// These are cosntraints that need to be satisfied.
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
@ -757,7 +758,7 @@ class LearnerConfiguration : public Learner {
|
||||
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
|
||||
}
|
||||
|
||||
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&num_feature, 1);
|
||||
if (num_feature > mparam_.num_feature) {
|
||||
mparam_.num_feature = num_feature;
|
||||
}
|
||||
@ -1083,7 +1084,7 @@ class LearnerIO : public LearnerConfiguration {
|
||||
cfg_.insert(n.cbegin(), n.cend());
|
||||
|
||||
// copy dsplit from config since it will not run again during restore
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
@ -1228,7 +1229,7 @@ class LearnerImpl : public LearnerIO {
|
||||
}
|
||||
// Configuration before data is known.
|
||||
void CheckDataSplitMode() {
|
||||
if (rabit::IsDistributed()) {
|
||||
if (collective::IsDistributed()) {
|
||||
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
|
||||
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
|
||||
if (tparam_.dsplit == DataSplitMode::kCol) {
|
||||
@ -1488,7 +1489,7 @@ class LearnerImpl : public LearnerIO {
|
||||
}
|
||||
|
||||
if (p_fmat->Info().num_row_ == 0) {
|
||||
LOG(WARNING) << "Empty dataset at worker: " << rabit::GetRank();
|
||||
LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -4,14 +4,12 @@
|
||||
* \brief Implementation of loggers.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#include "collective/communicator-inl.h"
|
||||
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
// Override logging mechanism for non-R interfaces
|
||||
@ -32,7 +30,7 @@ ConsoleLogger::~ConsoleLogger() {
|
||||
|
||||
TrackerLogger::~TrackerLogger() {
|
||||
log_stream_ << '\n';
|
||||
rabit::TrackerPrint(log_stream_.str());
|
||||
collective::Print(log_stream_.str());
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,27 +1,23 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#include "auc.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
#include "auc.h"
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
@ -117,7 +113,8 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
|
||||
// we have 2 averages going in here, first is among workers, second is among
|
||||
// classes. allreduce sums up fp/tp auc for each class.
|
||||
rabit::Allreduce<rabit::op::Sum>(results.Values().data(), results.Values().size());
|
||||
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
|
||||
results.Values().size());
|
||||
double auc_sum{0};
|
||||
double tp_sum{0};
|
||||
for (size_t c = 0; c < n_classes; ++c) {
|
||||
@ -265,7 +262,7 @@ class EvalAUC : public Metric {
|
||||
}
|
||||
// We use the global size to handle empty dataset.
|
||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
|
||||
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
||||
if (meta[0] == 0) {
|
||||
// Empty across all workers, which is not supported.
|
||||
auc = std::numeric_limits<double>::quiet_NaN();
|
||||
@ -287,7 +284,7 @@ class EvalAUC : public Metric {
|
||||
}
|
||||
|
||||
std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
|
||||
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
|
||||
auc = results[0];
|
||||
valid_groups = static_cast<uint32_t>(results[1]);
|
||||
|
||||
@ -316,7 +313,7 @@ class EvalAUC : public Metric {
|
||||
}
|
||||
double local_area = fp * tp;
|
||||
std::array<double, 2> result{auc, local_area};
|
||||
rabit::Allreduce<rabit::op::Sum>(result.data(), result.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
|
||||
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
|
||||
if (local_area <= 0) {
|
||||
// the dataset across all workers have only positive or negative sample
|
||||
|
||||
@ -11,11 +11,10 @@
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "auc.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../common/ranking_utils.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@ -46,9 +45,8 @@ struct DeviceAUCCache {
|
||||
dh::device_vector<size_t> unique_idx;
|
||||
// p^T: transposed prediction matrix, used by MultiClassAUC
|
||||
dh::device_vector<float> predts_t;
|
||||
std::unique_ptr<dh::AllReducer> reducer;
|
||||
|
||||
void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
|
||||
void Init(common::Span<float const> predts, bool is_multi) {
|
||||
if (sorted_idx.size() != predts.size()) {
|
||||
sorted_idx.resize(predts.size());
|
||||
fptp.resize(sorted_idx.size());
|
||||
@ -58,10 +56,6 @@ struct DeviceAUCCache {
|
||||
predts_t.resize(sorted_idx.size());
|
||||
}
|
||||
}
|
||||
if (is_multi && !reducer) {
|
||||
reducer.reset(new dh::AllReducer);
|
||||
reducer->Init(device);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -72,7 +66,7 @@ void InitCacheOnce(common::Span<float const> predts, int32_t device,
|
||||
if (!cache) {
|
||||
cache.reset(new DeviceAUCCache);
|
||||
}
|
||||
cache->Init(predts, is_multi, device);
|
||||
cache->Init(predts, is_multi);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -205,9 +199,11 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
|
||||
common::Span<double> tp, common::Span<double> auc,
|
||||
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
if (rabit::IsDistributed()) {
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
|
||||
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
|
||||
auto* communicator = collective::Communicator::GetDevice(device);
|
||||
communicator->AllReduceSum(results.data(), results.size());
|
||||
}
|
||||
auto reduce_in = dh::MakeTransformIterator<Pair>(
|
||||
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
|
||||
|
||||
@ -10,13 +10,13 @@
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/metric.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/metric.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
@ -101,7 +101,7 @@ XGBOOST_DEVICE inline double CalcDeltaPRAUC(double fp_prev, double fp,
|
||||
|
||||
inline void InvalidGroupAUC() {
|
||||
LOG(INFO) << "Invalid group with less than 3 samples is found on worker "
|
||||
<< rabit::GetRank() << ". Calculating AUC value requires at "
|
||||
<< collective::GetRank() << ". Calculating AUC value requires at "
|
||||
<< "least 2 pairs of samples.";
|
||||
}
|
||||
|
||||
|
||||
@ -7,11 +7,11 @@
|
||||
* The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
|
||||
*/
|
||||
#include <dmlc/registry.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/pseudo_huber.h"
|
||||
@ -196,8 +196,8 @@ class PseudoErrorLoss : public Metric {
|
||||
return std::make_tuple(v, wt);
|
||||
});
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
if (rabit::IsDistributed()) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
if (collective::IsDistributed()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
@ -365,7 +365,7 @@ struct EvalEWiseBase : public Metric {
|
||||
});
|
||||
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
|
||||
@ -4,15 +4,14 @@
|
||||
* \brief evaluation metrics for multiclass classification.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
@ -185,7 +184,7 @@ struct EvalMClassBase : public Metric {
|
||||
dat[0] = result.Residue();
|
||||
dat[1] = result.Weights();
|
||||
}
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
|
||||
@ -20,17 +20,17 @@
|
||||
// corresponding headers that brings in those function declaration can't be included with CUDA).
|
||||
// This precludes the CPU and GPU logic to coexist inside a .cu file
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@ -103,7 +103,7 @@ struct EvalAMS : public Metric {
|
||||
}
|
||||
|
||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||
CHECK(!rabit::IsDistributed()) << "metric AMS do not support distributed evaluation";
|
||||
CHECK(!collective::IsDistributed()) << "metric AMS do not support distributed evaluation";
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
||||
@ -216,10 +216,10 @@ struct EvalRank : public Metric, public EvalRankConfig {
|
||||
exc.Rethrow();
|
||||
}
|
||||
|
||||
if (rabit::IsDistributed()) {
|
||||
if (collective::IsDistributed()) {
|
||||
double dat[2]{sum_metric, static_cast<double>(ngroups)};
|
||||
// approximately estimate the metric using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return sum_metric / ngroups;
|
||||
@ -341,7 +341,7 @@ struct EvalCox : public Metric {
|
||||
public:
|
||||
EvalCox() = default;
|
||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||
CHECK(!rabit::IsDistributed()) << "Cox metric does not support distributed evaluation";
|
||||
CHECK(!collective::IsDistributed()) << "Cox metric does not support distributed evaluation";
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
||||
|
||||
@ -4,15 +4,12 @@
|
||||
* \brief prediction rank based metrics.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "metric_common.h"
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
|
||||
*/
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <memory>
|
||||
@ -16,6 +15,7 @@
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/survival_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
@ -214,7 +214,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
info.labels_upper_bound_, preds);
|
||||
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
|
||||
@ -7,8 +7,8 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
@ -39,7 +39,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
||||
auto const& h_node_idx = nidx;
|
||||
|
||||
size_t n_leaf{h_node_idx.size()};
|
||||
rabit::Allreduce<rabit::op::Max>(&n_leaf, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_leaf, 1);
|
||||
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
|
||||
if (quantiles.empty()) {
|
||||
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
|
||||
@ -49,12 +49,12 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
||||
std::vector<int32_t> n_valids(quantiles.size());
|
||||
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
|
||||
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
|
||||
rabit::Allreduce<rabit::op::Sum>(n_valids.data(), n_valids.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(n_valids.data(), n_valids.size());
|
||||
// convert to 0 for all reduce
|
||||
std::replace_if(
|
||||
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
|
||||
// use the mean value
|
||||
rabit::Allreduce<rabit::op::Sum>(quantiles.data(), quantiles.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(quantiles.data(), quantiles.size());
|
||||
for (size_t i = 0; i < n_leaf; ++i) {
|
||||
if (n_valids[i] > 0) {
|
||||
quantiles[i] /= static_cast<float>(n_valids[i]);
|
||||
|
||||
@ -724,8 +724,8 @@ class MeanAbsoluteError : public ObjFunction {
|
||||
}
|
||||
|
||||
// Weighted average base score across all workers
|
||||
rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
|
||||
rabit::Allreduce<rabit::op::Sum>(&w, 1);
|
||||
collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
|
||||
collective::Allreduce<collective::Operation::kSum>(&w, 1);
|
||||
|
||||
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
|
||||
[w](float v) { return v / w; });
|
||||
|
||||
@ -84,11 +84,11 @@ GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
|
||||
// Treat pair as array of 4 primitive types to allreduce
|
||||
using ReduceT = typename decltype(p.first)::ValueT;
|
||||
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
|
||||
rabit::Allreduce<rabit::op::Sum, ReduceT>(reinterpret_cast<ReduceT*>(&p), 4);
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
|
||||
GradientPair positive_sum{p.first}, negative_sum{p.second};
|
||||
|
||||
std::size_t total_rows = gpair.size();
|
||||
rabit::Allreduce<rabit::op::Sum>(&total_rows, 1);
|
||||
collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);
|
||||
|
||||
auto histogram_rounding = GradientSumT{
|
||||
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), total_rows),
|
||||
|
||||
@ -8,10 +8,10 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../../collective/communicator-inl.h"
|
||||
#include "../../common/hist_util.h"
|
||||
#include "../../data/gradient_index.h"
|
||||
#include "expand_entry.h"
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -202,7 +202,8 @@ class HistogramBuilder {
|
||||
}
|
||||
});
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<double*>(this->hist_[starting_index].data()),
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double *>(this->hist_[starting_index].data()),
|
||||
builder_.GetNumBins() * sync_count * 2);
|
||||
|
||||
ParallelSubtractionHist(space, nodes_for_explicit_hist_build,
|
||||
|
||||
@ -74,7 +74,7 @@ class GloablApproxBuilder {
|
||||
}
|
||||
|
||||
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
|
||||
rabit::IsDistributed());
|
||||
collective::IsDistributed());
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@ class GloablApproxBuilder {
|
||||
for (auto const &g : gpair) {
|
||||
root_sum.Add(g);
|
||||
}
|
||||
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
size_t i = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||
|
||||
@ -4,8 +4,6 @@
|
||||
* \brief use columnwise update to construct a tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
@ -100,7 +98,7 @@ class ColMaker: public TreeUpdater {
|
||||
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (rabit::IsDistributed()) {
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
|
||||
"support distributed training.";
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../common/io.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
@ -528,12 +529,11 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
// num histograms is the number of contiguous histograms in memory to reduce over
|
||||
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
|
||||
void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
|
||||
monitor.Start("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
|
||||
reducer->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * 2 * num_histograms);
|
||||
|
||||
monitor.Stop("AllReduce");
|
||||
@ -542,8 +542,8 @@ struct GPUHistMakerDevice {
|
||||
/**
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, dh::AllReducer* reducer,
|
||||
const RegTree& tree) {
|
||||
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
|
||||
collective::DeviceCommunicator* communicator, const RegTree& tree) {
|
||||
if (candidates.empty()) return;
|
||||
// Some nodes we will manually compute histograms
|
||||
// others we will do by subtraction
|
||||
@ -574,7 +574,7 @@ struct GPUHistMakerDevice {
|
||||
// Reduce all in one go
|
||||
// This gives much better latency in a distributed setting
|
||||
// when processing a large batch
|
||||
this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());
|
||||
this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());
|
||||
|
||||
for (size_t i = 0; i < subtraction_nidx.size(); i++) {
|
||||
auto build_hist_nidx = hist_nidx.at(i);
|
||||
@ -584,7 +584,7 @@ struct GPUHistMakerDevice {
|
||||
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
|
||||
// Calculate other histogram manually
|
||||
this->BuildHist(subtraction_trick_nidx);
|
||||
this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
|
||||
this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -593,7 +593,7 @@ struct GPUHistMakerDevice {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
// Sanity check - have we created a leaf with no training instances?
|
||||
if (!rabit::IsDistributed() && row_partitioner) {
|
||||
if (!collective::IsDistributed() && row_partitioner) {
|
||||
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
|
||||
<< "No training instances in this leaf!";
|
||||
}
|
||||
@ -642,7 +642,7 @@ struct GPUHistMakerDevice {
|
||||
parent.RightChild());
|
||||
}
|
||||
|
||||
GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
|
||||
GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
|
||||
constexpr bst_node_t kRootNIdx = 0;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
|
||||
@ -650,11 +650,11 @@ struct GPUHistMakerDevice {
|
||||
GradientPairPrecise root_sum =
|
||||
dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
|
||||
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
|
||||
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double*>(&root_sum), 2);
|
||||
|
||||
hist.AllocateHistograms({kRootNIdx});
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, reducer, 1);
|
||||
this->AllReduceHist(kRootNIdx, communicator, 1);
|
||||
|
||||
// Remember root stats
|
||||
node_sum_gradients[kRootNIdx] = root_sum;
|
||||
@ -669,7 +669,7 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
|
||||
RegTree* p_tree, dh::AllReducer* reducer,
|
||||
RegTree* p_tree, collective::DeviceCommunicator* communicator,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
@ -680,7 +680,7 @@ struct GPUHistMakerDevice {
|
||||
monitor.Stop("Reset");
|
||||
|
||||
monitor.Start("InitRoot");
|
||||
driver.Push({ this->InitRoot(p_tree, reducer) });
|
||||
driver.Push({ this->InitRoot(p_tree, communicator) });
|
||||
monitor.Stop("InitRoot");
|
||||
|
||||
// The set of leaves that can be expanded asynchronously
|
||||
@ -707,7 +707,7 @@ struct GPUHistMakerDevice {
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
|
||||
this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
|
||||
monitor.Stop("BuildHist");
|
||||
|
||||
monitor.Start("EvaluateSplits");
|
||||
@ -789,11 +789,10 @@ class GPUHistMaker : public TreeUpdater {
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
|
||||
info_ = &dmat->Info();
|
||||
reducer_.Init({ctx_->gpu_id}); // NOLINT
|
||||
|
||||
// Synchronise the column sampling seed
|
||||
uint32_t column_sampling_seed = common::GlobalRandom()();
|
||||
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||
|
||||
BatchParam batch_param{
|
||||
ctx_->gpu_id,
|
||||
@ -823,12 +822,12 @@ class GPUHistMaker : public TreeUpdater {
|
||||
void CheckTreesSynchronized(RegTree* local_tree) const {
|
||||
std::string s_model;
|
||||
common::MemoryBufferStream fs(&s_model);
|
||||
int rank = rabit::GetRank();
|
||||
int rank = collective::GetRank();
|
||||
if (rank == 0) {
|
||||
local_tree->Save(&fs);
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
collective::Broadcast(&s_model, 0);
|
||||
RegTree reference_tree{}; // rank 0 tree
|
||||
reference_tree.Load(&fs);
|
||||
CHECK(*local_tree == reference_tree);
|
||||
@ -841,7 +840,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
monitor_.Stop("InitData");
|
||||
|
||||
gpair->SetDevice(ctx_->gpu_id);
|
||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
|
||||
auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
|
||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
@ -867,8 +867,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
|
||||
dh::AllReducer reducer_;
|
||||
|
||||
DMatrix* p_last_fmat_{nullptr};
|
||||
RegTree const* p_last_tree_{nullptr};
|
||||
ObjInfo task_;
|
||||
|
||||
@ -4,16 +4,13 @@
|
||||
* \brief prune a tree given the statistics
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "./param.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
@ -6,19 +6,12 @@
|
||||
*/
|
||||
#include "./updater_quantile_hist.h"
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "constraints.h"
|
||||
#include "hist/evaluate_splits.h"
|
||||
#include "param.h"
|
||||
@ -103,7 +96,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
|
||||
for (auto const &grad : gpair_h) {
|
||||
grad_stat.Add(grad.GetGrad(), grad.GetHess());
|
||||
}
|
||||
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&grad_stat), 2);
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&grad_stat), 2);
|
||||
}
|
||||
|
||||
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
||||
@ -320,7 +313,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
||||
++page_id;
|
||||
}
|
||||
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
rabit::IsDistributed());
|
||||
collective::IsDistributed());
|
||||
|
||||
if (param_.subsample < 1.0f) {
|
||||
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
||||
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@ -4,17 +4,17 @@
|
||||
* \brief refresh the statistics and leaf value on the tree on the dataset
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "./param.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
#include "./param.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -100,8 +100,9 @@ class TreeRefresher : public TreeUpdater {
|
||||
}
|
||||
});
|
||||
};
|
||||
rabit::Allreduce<rabit::op::Sum>(&dmlc::BeginPtr(stemp[0])->sum_grad, stemp[0].size() * 2,
|
||||
lazy_get_stats);
|
||||
lazy_get_stats();
|
||||
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
|
||||
stemp[0].size() * 2);
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
|
||||
@ -4,12 +4,14 @@
|
||||
* \brief synchronize the tree in all distributed nodes
|
||||
*/
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/io.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -35,17 +37,17 @@ class TreeSyncher : public TreeUpdater {
|
||||
void Update(HostDeviceVector<GradientPair>*, DMatrix*,
|
||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
if (rabit::GetWorldSize() == 1) return;
|
||||
if (collective::GetWorldSize() == 1) return;
|
||||
std::string s_model;
|
||||
common::MemoryBufferStream fs(&s_model);
|
||||
int rank = rabit::GetRank();
|
||||
int rank = collective::GetRank();
|
||||
if (rank == 0) {
|
||||
for (auto tree : trees) {
|
||||
tree->Save(&fs);
|
||||
}
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
collective::Broadcast(&s_model, 0);
|
||||
for (auto tree : trees) {
|
||||
tree->Load(&fs);
|
||||
}
|
||||
|
||||
@ -46,8 +46,8 @@ template <bool use_column>
|
||||
void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitRabitContext(msg, kWorkers);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, kWorkers);
|
||||
} else {
|
||||
@ -65,7 +65,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
std::vector<FeatureType> ft(cols);
|
||||
for (size_t i = 0; i < ft.size(); ++i) {
|
||||
ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
|
||||
@ -99,8 +99,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
sketch_distributed.MakeCuts(&distributed_cuts);
|
||||
|
||||
// Generate cuts for single node environment
|
||||
rabit::Finalize();
|
||||
CHECK_EQ(rabit::GetWorldSize(), 1);
|
||||
collective::Finalize();
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
@ -184,8 +184,8 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
#if defined(__unix__)
|
||||
std::string msg{"Skipping Quantile AllreduceBasic test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitRabitContext(msg, kWorkers);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
CHECK_EQ(world, kWorkers);
|
||||
} else {
|
||||
@ -196,7 +196,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(
|
||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<float> storage;
|
||||
std::vector<FeatureType> ft(kCols);
|
||||
for (size_t i = 0; i < ft.size(); ++i) {
|
||||
@ -217,12 +217,12 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
std::vector<float> cut_min_values(cuts.MinValues().size() * world, 0);
|
||||
|
||||
size_t value_size = cuts.Values().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&value_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&value_size, 1);
|
||||
size_t ptr_size = cuts.Ptrs().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&ptr_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&ptr_size, 1);
|
||||
CHECK_EQ(ptr_size, kCols + 1);
|
||||
size_t min_value_size = cuts.MinValues().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&min_value_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&min_value_size, 1);
|
||||
CHECK_EQ(min_value_size, kCols);
|
||||
|
||||
size_t value_offset = value_size * rank;
|
||||
@ -235,9 +235,9 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(),
|
||||
cut_min_values.begin() + min_values_offset);
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_values.data(), cut_values.size());
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_ptrs.data(), cut_ptrs.size());
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_min_values.data(), cut_min_values.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_values.data(), cut_values.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_ptrs.data(), cut_ptrs.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_min_values.data(), cut_min_values.size());
|
||||
|
||||
for (int32_t i = 0; i < world; i++) {
|
||||
for (size_t j = 0; j < value_size; ++j) {
|
||||
@ -256,7 +256,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
}
|
||||
}
|
||||
});
|
||||
rabit::Finalize();
|
||||
collective::Finalize();
|
||||
#endif // defined(__unix__)
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_quantile.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/collective/device_communicator.cuh"
|
||||
#include "../../../src/common/hist_util.cuh"
|
||||
#include "../../../src/common/quantile.cuh"
|
||||
|
||||
@ -341,17 +342,14 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
// This test is supposed to run by a python test that setups the environment.
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
auto n_gpus = AllVisibleGPUs();
|
||||
InitRabitContext(msg, n_gpus);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, n_gpus);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, n_gpus);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
auto reducer = std::make_shared<dh::AllReducer>();
|
||||
reducer->Init(0);
|
||||
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
|
||||
// Set up single node version;
|
||||
@ -385,8 +383,8 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
|
||||
// Set up distributed version. We rely on using rank as seed to generate
|
||||
// the exact same copy of data.
|
||||
auto rank = rabit::GetRank();
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
|
||||
auto rank = collective::GetRank();
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
|
||||
HostDeviceVector<float> storage;
|
||||
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||
.Device(0)
|
||||
@ -422,28 +420,26 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
|
||||
}
|
||||
});
|
||||
rabit::Finalize();
|
||||
collective::Finalize();
|
||||
}
|
||||
|
||||
TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
std::string msg {"Skipping SameOnAllWorkers test"};
|
||||
auto n_gpus = AllVisibleGPUs();
|
||||
InitRabitContext(msg, n_gpus);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, n_gpus);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, n_gpus);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
auto reducer = std::make_shared<dh::AllReducer>();
|
||||
reducer->Init(0);
|
||||
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
|
||||
MetaInfo const &info) {
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
|
||||
HostDeviceVector<float> storage;
|
||||
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||
.Device(0)
|
||||
@ -459,7 +455,7 @@ TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
|
||||
// Test for all workers having the same sketch.
|
||||
size_t n_data = sketch_distributed.Data().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&n_data, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_data, 1);
|
||||
ASSERT_EQ(n_data, sketch_distributed.Data().size());
|
||||
size_t size_as_float =
|
||||
sketch_distributed.Data().size_bytes() / sizeof(float);
|
||||
@ -472,9 +468,10 @@ TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
thrust::copy(thrust::device, local_data.data(),
|
||||
local_data.data() + local_data.size(),
|
||||
all_workers.begin() + local_data.size() * rank);
|
||||
reducer->AllReduceSum(all_workers.data().get(), all_workers.data().get(),
|
||||
all_workers.size());
|
||||
reducer->Synchronize();
|
||||
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(0);
|
||||
|
||||
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
|
||||
communicator->Synchronize();
|
||||
|
||||
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
|
||||
std::vector<float> h_base_line(base_line.size());
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
#ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
|
||||
#define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../src/collective/communicator-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
inline void InitRabitContext(std::string msg, int32_t n_workers) {
|
||||
inline void InitCommunicatorContext(std::string msg, int32_t n_workers) {
|
||||
auto port = std::getenv("DMLC_TRACKER_PORT");
|
||||
std::string port_str;
|
||||
if (port) {
|
||||
@ -28,12 +28,11 @@ inline void InitRabitContext(std::string msg, int32_t n_workers) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> envs{
|
||||
"DMLC_TRACKER_PORT=" + port_str,
|
||||
"DMLC_TRACKER_URI=" + uri_str,
|
||||
"DMLC_NUM_WORKER=" + std::to_string(n_workers)};
|
||||
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
|
||||
rabit::Init(3, c_envs);
|
||||
Json config{JsonObject()};
|
||||
config["DMLC_TRACKER_PORT"] = port_str;
|
||||
config["DMLC_TRACKER_URI"] = uri_str;
|
||||
config["DMLC_NUM_WORKER"] = n_workers;
|
||||
collective::Init(config);
|
||||
}
|
||||
|
||||
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
|
||||
|
||||
@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
||||
|
||||
|
||||
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
||||
rabit_env = [
|
||||
'xgboost_communicator=federated',
|
||||
f'federated_server_address=localhost:{port}',
|
||||
f'federated_world_size={world_size}',
|
||||
f'federated_rank={rank}'
|
||||
]
|
||||
communicator_env = {
|
||||
'xgboost_communicator': 'federated',
|
||||
'federated_server_address': f'localhost:{port}',
|
||||
'federated_world_size': world_size,
|
||||
'federated_rank': rank
|
||||
}
|
||||
if with_ssl:
|
||||
rabit_env = rabit_env + [
|
||||
f'federated_server_cert={SERVER_CERT}',
|
||||
f'federated_client_key={CLIENT_KEY}',
|
||||
f'federated_client_cert={CLIENT_CERT}'
|
||||
]
|
||||
communicator_env['federated_server_cert'] = SERVER_CERT
|
||||
communicator_env['federated_client_key'] = CLIENT_KEY
|
||||
communicator_env['federated_client_cert'] = CLIENT_CERT
|
||||
|
||||
# Always call this before using distributed module
|
||||
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
||||
with xgb.collective.CommunicatorContext(**communicator_env):
|
||||
# Load file, file will not be sharded in federated mode.
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
|
||||
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
|
||||
@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
|
||||
early_stopping_rounds=2)
|
||||
|
||||
# Save the model, only ask process 0 to save the model.
|
||||
if xgb.rabit.get_rank() == 0:
|
||||
if xgb.collective.get_rank() == 0:
|
||||
bst.save_model("test.model.json")
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
xgb.collective.communicator_print("Finished training\n")
|
||||
|
||||
|
||||
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Copyright 2019-2022 XGBoost contributors"""
|
||||
import sys
|
||||
import os
|
||||
from typing import Type, TypeVar, Any, Dict, List
|
||||
from typing import Type, TypeVar, Any, Dict, List, Union
|
||||
import pytest
|
||||
import numpy as np
|
||||
import asyncio
|
||||
@ -425,7 +425,7 @@ class TestDistributedGPU:
|
||||
)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with dxgb.RabitContext(rabit_args):
|
||||
with dxgb.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
|
||||
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
|
||||
assert fw_rows == local_dtrain.num_col()
|
||||
@ -505,20 +505,13 @@ class TestDistributedGPU:
|
||||
test = "--gtest_filter=GPUQuantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ""
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
|
||||
port_env = arg.decode("utf-8")
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split("=")
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env[uri[0]] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
||||
|
||||
workers = _get_client_workers(local_cuda_client)
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import multiprocessing
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker
|
||||
from xgboost import collective
|
||||
from xgboost import RabitTracker, build_info, federated
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
|
||||
@ -37,3 +37,41 @@ def test_rabit_communicator():
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
def run_federated_worker(port, world_size, rank):
|
||||
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
|
||||
federated_server_address=f'localhost:{port}',
|
||||
federated_world_size=world_size,
|
||||
federated_rank=rank):
|
||||
assert xgb.collective.get_world_size() == world_size
|
||||
assert xgb.collective.is_distributed()
|
||||
assert xgb.collective.get_processor_name() == f'rank{rank}'
|
||||
ret = xgb.collective.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
def test_federated_communicator():
|
||||
if not build_info()["USE_FEDERATED"]:
|
||||
pytest.skip("XGBoost not built with federated learning enabled")
|
||||
|
||||
port = 9091
|
||||
world_size = 2
|
||||
server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
|
||||
server.start()
|
||||
time.sleep(1)
|
||||
if not server.is_alive():
|
||||
raise Exception("Error starting Federated Learning server")
|
||||
|
||||
workers = []
|
||||
for rank in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_federated_worker,
|
||||
args=(port, world_size, rank))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
server.terminate()
|
||||
|
||||
@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
|
||||
def test_rabit_tracker():
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||
tracker.start(1)
|
||||
worker_env = tracker.worker_envs()
|
||||
rabit_env = []
|
||||
for k, v in worker_env.items():
|
||||
rabit_env.append(f"{k}={v}".encode())
|
||||
with xgb.rabit.RabitContext(rabit_env):
|
||||
ret = xgb.rabit.broadcast("test1234", 0)
|
||||
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
|
||||
ret = xgb.collective.broadcast("test1234", 0)
|
||||
assert str(ret) == "test1234"
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from test_with_dask import _get_client_workers
|
||||
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
|
||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
from xgboost import rabit
|
||||
from xgboost import collective
|
||||
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
assert not rabit.is_distributed()
|
||||
assert not collective.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
assert n_workers == n_workers_from_dask
|
||||
|
||||
def local_test(worker_id):
|
||||
with RabitContext(rabit_args):
|
||||
with CommunicatorContext(**rabit_args):
|
||||
a = 1
|
||||
assert rabit.is_distributed()
|
||||
assert collective.is_distributed()
|
||||
a = np.array([a])
|
||||
reduced = rabit.allreduce(a, rabit.Op.SUM)
|
||||
reduced = collective.allreduce(a, collective.Op.SUM)
|
||||
assert reduced[0] == n_workers
|
||||
|
||||
worker_id = np.array([worker_id])
|
||||
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
|
||||
reduced = collective.allreduce(worker_id, collective.Op.MAX)
|
||||
assert reduced == n_workers - 1
|
||||
|
||||
return 1
|
||||
@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
|
||||
from test_with_dask import _get_client_workers
|
||||
|
||||
def local_test(worker_id):
|
||||
with xgb.dask.RabitContext(args):
|
||||
for val in args:
|
||||
sval = val.decode("utf-8")
|
||||
if sval.startswith("DMLC_TASK_ID"):
|
||||
task_id = sval
|
||||
break
|
||||
with xgb.dask.CommunicatorContext(**args) as ctx:
|
||||
task_id = ctx["DMLC_TASK_ID"]
|
||||
matched = re.search(".*-([0-9]).*", task_id)
|
||||
rank = xgb.rabit.get_rank()
|
||||
rank = xgb.collective.get_rank()
|
||||
# As long as the number of workers is lesser than 10, rank and worker id
|
||||
# should be the same
|
||||
assert rank == int(matched.group(1))
|
||||
|
||||
@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
|
||||
|
||||
class TestWithDask:
|
||||
def test_dmatrix_binary(self, client: "Client") -> None:
|
||||
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
X, y = tm.make_categorical(100, 4, 4, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy.save_binary(path)
|
||||
|
||||
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy = xgb.DMatrix(path)
|
||||
assert Xy.num_row() == 100
|
||||
@ -1488,20 +1488,13 @@ class TestWithDask:
|
||||
test = "--gtest_filter=Quantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ''
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env["DMLC_TRACKER_URI"] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
|
||||
@ -1543,8 +1536,8 @@ class TestWithDask:
|
||||
def get_score(config: Dict) -> float:
|
||||
return float(config["learner"]["learner_model_param"]["base_score"])
|
||||
|
||||
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
if worker_id == 0:
|
||||
y = np.array([0.0, 0.0, 0.0])
|
||||
x = np.array([[0.0]] * 3)
|
||||
@ -1686,12 +1679,12 @@ class TestWithDask:
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
|
||||
**data_ref, nthread=7
|
||||
)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
futures = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user