[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator
* fix size_t specialization
* really fix size_t
* try again
* add include
* more include
* fix lint errors
* remove rabit includes
* fix pylint error
* return dict from communicator context
* fix communicator shutdown
* fix dask test
* reset communicator mocklist
* fix distributed tests
* do not save device communicator
* fix jvm gpu tests
* add python test for federated communicator
* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>

This commit is contained in:
parent e47b3a3da3
commit 668b8a0ea4
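For orientation, the user-facing shape of this change is that rabit's encoded "key=value" argument list becomes a plain keyword dict passed to the collective communicator. A condensed before/after sketch distilled from the first hunk below (the server address and sizes are placeholder values):

    import xgboost as xgb

    # Before: rabit took a list of encoded "key=value" strings.
    rabit_env = ['federated_server_address=localhost:9091',
                 'federated_world_size=2',
                 'federated_rank=0']
    with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
        xgb.rabit.tracker_print("Finished training\n")

    # After: the collective communicator takes keyword arguments directly.
    communicator_env = {'federated_server_address': 'localhost:9091',
                        'federated_world_size': 2,
                        'federated_rank': 0}
    with xgb.collective.CommunicatorContext(**communicator_env):
        xgb.collective.communicator_print("Finished training\n")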
@@ -52,15 +52,15 @@ class XGBoostTrainer(Executor):
 def _do_training(self, fl_ctx: FLContext):
 client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
 rank = int(client_name.split('-')[1]) - 1
-rabit_env = [
-f'federated_server_address={self._server_address}',
-f'federated_world_size={self._world_size}',
-f'federated_rank={rank}',
-f'federated_server_cert={self._server_cert_path}',
-f'federated_client_key={self._client_key_path}',
-f'federated_client_cert={self._client_cert_path}'
-]
-with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
+communicator_env = {
+'federated_server_address': self._server_address,
+'federated_world_size': self._world_size,
+'federated_rank': rank,
+'federated_server_cert': self._server_cert_path,
+'federated_client_key': self._client_key_path,
+'federated_client_cert': self._client_cert_path
+}
+with xgb.collective.CommunicatorContext(**communicator_env):
 # Load file, file will not be sharded in federated mode.
 dtrain = xgb.DMatrix('agaricus.txt.train')
 dtest = xgb.DMatrix('agaricus.txt.test')
@@ -86,4 +86,4 @@ class XGBoostTrainer(Executor):
 run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
 run_dir = workspace.get_run_dir(run_number)
 bst.save_model(os.path.join(run_dir, "test.model.json"))
-xgb.rabit.tracker_print("Finished training\n")
+xgb.collective.communicator_print("Finished training\n")
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
 import scala.collection.JavaConverters.asScalaIteratorConverter

 import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
+import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}

 import org.apache.commons.logging.LogFactory
@@ -46,7 +46,7 @@ object XGBoost {
 collector: Collector[XGBoostModel]): Unit = {
 workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
 logger.info("start with env" + workerEnvs.toString)
-Rabit.init(workerEnvs)
+Communicator.init(workerEnvs)
 val mapper = (x: LabeledVector) => {
 val (index, value) = x.vector.toSeq.unzip
 LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
@@ -59,7 +59,7 @@ object XGBoost {
 .map(_.toString.toInt).getOrElse(0)
 val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
 earlyStoppingRound = numEarlyStoppingRounds)
-Rabit.shutdown()
+Communicator.shutdown()
 collector.collect(new XGBoostModel(booster))
 }
 }
@@ -22,7 +22,7 @@ import java.util.ServiceLoader
 import scala.collection.JavaConverters._
 import scala.collection.{AbstractIterator, Iterator, mutable}

-import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.java.Communicator
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
 import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
 import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
 if (batchCnt == 0) {
 val rabitEnv = Array(
 "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-Rabit.init(rabitEnv.asJava)
+Communicator.init(rabitEnv.asJava)
 }

 val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
@@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
 override def next(): Row = {
 val ret = batchIterImpl.next()
 if (!batchIterImpl.hasNext) {
-Rabit.shutdown()
+Communicator.shutdown()
 }
 ret
 }
@@ -22,7 +22,7 @@ import scala.collection.mutable
 import scala.util.Random
 import scala.collection.JavaConverters._

-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
 import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
@@ -303,7 +303,7 @@ object XGBoost extends Serializable {
 val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0

 try {
-Rabit.init(rabitEnv)
+Communicator.init(rabitEnv)

 watches = buildWatchesAndCheck(buildWatches)

@@ -342,7 +342,7 @@ object XGBoost extends Serializable {
 logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
 throw xgbException
 } finally {
-Rabit.shutdown()
+Communicator.shutdown()
 if (watches != null) watches.delete()
 }
 }
@@ -1,277 +0,0 @@
-/*
-Copyright (c) 2014-2022 by Contributors
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import java.util.concurrent.LinkedBlockingDeque
-
-import scala.util.Random
-
-import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
-import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
-import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
-import ml.dmlc.xgboost4j.scala.DMatrix
-import org.scalatest.{FunSuite}
-
-class RabitRobustnessSuite extends FunSuite with PerTest {
-
-private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
-val classifier = new XGBoostClassifier(paramMap)
-val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
-xgbParamsFactory.buildXGBRuntimeParams
-}
-
-
-test("Customize host ip and python exec for Rabit tracker") {
-val hostIp = "192.168.22.111"
-val pythonExec = "/usr/bin/python3"
-
-val paramMap = Map(
-"num_workers" -> numWorkers,
-"tracker_conf" -> TrackerConf(0L, "python", hostIp))
-val xgbExecParams = getXGBoostExecutionParams(paramMap)
-val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
-tracker match {
-case pyTracker: PyRabitTracker =>
-val cmd = pyTracker.getRabitTrackerCommand
-assert(cmd.contains(hostIp))
-assert(cmd.startsWith("python"))
-case _ => assert(false, "expected python tracker implementation")
-}
-
-val paramMap1 = Map(
-"num_workers" -> numWorkers,
-"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
-val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
-val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
-tracker1 match {
-case pyTracker: PyRabitTracker =>
-val cmd = pyTracker.getRabitTrackerCommand
-assert(cmd.startsWith(pythonExec))
-assert(!cmd.contains(hostIp))
-case _ => assert(false, "expected python tracker implementation")
-}
-
-val paramMap2 = Map(
-"num_workers" -> numWorkers,
-"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
-val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
-val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
-tracker2 match {
-case pyTracker: PyRabitTracker =>
-val cmd = pyTracker.getRabitTrackerCommand
-assert(cmd.startsWith(pythonExec))
-assert(cmd.contains(s" --host-ip=${hostIp}"))
-case _ => assert(false, "expected python tracker implementation")
-}
-}
-
-test("training with Scala-implemented Rabit tracker") {
-val eval = new EvalError()
-val training = buildDataFrame(Classification.train)
-val testDM = new DMatrix(Classification.test.iterator)
-val paramMap = Map("eta" -> "1", "max_depth" -> "6",
-"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
-"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
-val model = new XGBoostClassifier(paramMap).fit(training)
-assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
-}
-
-test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
-val vectorLength = 100
-val rdd = sc.parallelize(
-(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
-
-val tracker = new ScalaRabitTracker(numWorkers)
-tracker.start(0)
-val trackerEnvs = tracker.getWorkerEnvs
-val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
-
-val rawData = rdd.mapPartitions { iter =>
-Iterator(iter.toArray)
-}.collect()
-
-val maxVec = (0 until vectorLength).toArray.map { j =>
-(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
-}
-
-val allReduceResults = rdd.mapPartitions { iter =>
-Rabit.init(trackerEnvs)
-val arr = iter.toArray
-val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
-Rabit.shutdown()
-Iterator(results)
-}.cache()
-
-val sparkThread = new Thread() {
-override def run(): Unit = {
-allReduceResults.foreachPartition(() => _)
-val byPartitionResults = allReduceResults.collect()
-assert(byPartitionResults(0).length == vectorLength)
-collectedAllReduceResults.put(byPartitionResults(0))
-}
-}
-sparkThread.start()
-assert(tracker.waitFor(0L) == 0)
-sparkThread.join()
-
-assert(collectedAllReduceResults.poll().sameElements(maxVec))
-}
-
-test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
-/*
-Deliberately create new instances of SparkContext in each unit test to avoid reusing the
-same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
-by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
-corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
-tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
-that should be avoided.
-*/
-val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-val tracker = new PyRabitTracker(numWorkers)
-tracker.start(0)
-val trackerEnvs = tracker.getWorkerEnvs
-
-val workerCount: Int = numWorkers
-/*
-Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
-last created worker. A cascading event chain will be triggered once the RuntimeException is
-thrown: the thread running the dummy spark job (sparkThread) catches the exception and
-delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
-
-The Java RabitTracker class reacts to exceptions by killing the spawned process running
-the Python tracker. If at least one Rabit worker has yet connected to the tracker before
-it is killed, the resulted connection failure will trigger the Rabit worker to call
-"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
-
-In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
-isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
-running in separate containers. However, as unit tests are run in Spark local mode, in which
-tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
-ultimately kills the entire process, which also happens to host the Spark driver, causing
-the entire Spark application to crash.
-
-To prevent unit tests from crashing, deterministic delays were introduced to make sure that
-the exception is thrown at last, ideally after all worker connections have been established.
-For the same reason, the Java RabitTracker class delays the killing of the Python tracker
-process to ensure that pending worker connections are handled.
-*/
-val dummyTasks = rdd.mapPartitions { iter =>
-Rabit.init(trackerEnvs)
-val index = iter.next()
-Thread.sleep(100 + index * 10)
-if (index == workerCount) {
-// kill the worker by throwing an exception
-throw new RuntimeException("Worker exception.")
-}
-Rabit.shutdown()
-Iterator(index)
-}.cache()
-
-val sparkThread = new Thread() {
-override def run(): Unit = {
-// forces a Spark job.
-dummyTasks.foreachPartition(() => _)
-}
-}
-
-sparkThread.setUncaughtExceptionHandler(tracker)
-sparkThread.start()
-assert(tracker.waitFor(0) != 0)
-}
-
-test("test Scala RabitTracker's exception handling: it should not hang forever.") {
-val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-val tracker = new ScalaRabitTracker(numWorkers)
-tracker.start(0)
-val trackerEnvs = tracker.getWorkerEnvs
-
-val workerCount: Int = numWorkers
-val dummyTasks = rdd.mapPartitions { iter =>
-Rabit.init(trackerEnvs)
-val index = iter.next()
-Thread.sleep(100 + index * 10)
-if (index == workerCount) {
-// kill the worker by throwing an exception
-throw new RuntimeException("Worker exception.")
-}
-Rabit.shutdown()
-Iterator(index)
-}.cache()
-
-val sparkThread = new Thread() {
-override def run(): Unit = {
-// forces a Spark job.
-dummyTasks.foreachPartition(() => _)
-}
-}
-sparkThread.setUncaughtExceptionHandler(tracker)
-sparkThread.start()
-assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-}
-
-test("test Scala RabitTracker's workerConnectionTimeout") {
-val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-val tracker = new ScalaRabitTracker(numWorkers)
-tracker.start(500)
-val trackerEnvs = tracker.getWorkerEnvs
-
-val dummyTasks = rdd.mapPartitions { iter =>
-val index = iter.next()
-// simulate that the first worker cannot connect to tracker due to network issues.
-if (index != 1) {
-Rabit.init(trackerEnvs)
-Thread.sleep(1000)
-Rabit.shutdown()
-}
-
-Iterator(index)
-}.cache()
-
-val sparkThread = new Thread() {
-override def run(): Unit = {
-// forces a Spark job.
-dummyTasks.foreachPartition(() => _)
-}
-}
-sparkThread.setUncaughtExceptionHandler(tracker)
-sparkThread.start()
-// should fail due to connection timeout
-assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-}
-
-test("should allow the dataframe containing rabit calls to be partially evaluated for" +
-" multiple times (ISSUE-4406)") {
-val paramMap = Map(
-"eta" -> "1",
-"max_depth" -> "6",
-"silent" -> "1",
-"objective" -> "binary:logistic")
-val trainingDF = buildDataFrame(Classification.train)
-val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
-"num_workers" -> numWorkers)).fit(trainingDF)
-val prediction = model.transform(trainingDF)
-// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
-// threads
-prediction.show()
-// a full evaluation here will re-run init and shutdown all rabit proxy
-// expecting no error
-prediction.collect()
-}
-}
@@ -16,7 +16,7 @@

 package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.java.Communicator
 import ml.dmlc.xgboost4j.scala.Booster
 import scala.collection.JavaConverters._

@@ -25,7 +25,7 @@ import org.scalatest.FunSuite

 import org.apache.spark.SparkException

-class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
+class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
 val predictionErrorMin = 0.00001f
 val maxFailure = 2;

@@ -47,8 +47,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
 val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
 .fit(training)

-assert(Rabit.rabitEnvs.asScala.size > 3)
-Rabit.rabitEnvs.asScala.foreach( item => {
+assert(Communicator.communicatorEnvs.asScala.size > 3)
+Communicator.communicatorEnvs.asScala.foreach( item => {
 if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
 })

@@ -70,8 +70,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {

 val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
 ).fit(training)
-assert(Rabit.rabitEnvs.asScala.size > 3)
-Rabit.rabitEnvs.asScala.foreach( item => {
+assert(Communicator.communicatorEnvs.asScala.size > 3)
+Communicator.communicatorEnvs.asScala.foreach( item => {
 if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
 })
 // check the equality of single instance prediction
@@ -85,7 +85,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
 test("test rabit timeout fail handle") {
 val training = buildDataFrame(Classification.train)
 // mock rank 0 failure during 8th allreduce synchronization
-Rabit.mockList = Array("0,8,0,0").toList.asJava
+Communicator.mockList = Array("0,8,0,0").toList.asJava

 intercept[SparkException] {
 new XGBoostClassifier(Map(
@@ -98,6 +98,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
 "rabit_timeout" -> 0))
 .fit(training)
 }
+
+Communicator.mockList = Array.empty.toList.asJava
 }

 }
@@ -1,154 +0,0 @@
-package ml.dmlc.xgboost4j.java;
-
-import java.io.Serializable;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-
-/**
- * Rabit global class for synchronization.
- */
-public class Rabit {
-
-public enum OpType implements Serializable {
-MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
-
-private int op;
-
-public int getOperand() {
-return this.op;
-}
-
-OpType(int op) {
-this.op = op;
-}
-}
-
-public enum DataType implements Serializable {
-CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
-LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
-LONGLONG(8, 8), ULONGLONG(9, 8);
-
-private int enumOp;
-private int size;
-
-public int getEnumOp() {
-return this.enumOp;
-}
-
-public int getSize() {
-return this.size;
-}
-
-DataType(int enumOp, int size) {
-this.enumOp = enumOp;
-this.size = size;
-}
-}
-
-private static void checkCall(int ret) throws XGBoostError {
-if (ret != 0) {
-throw new XGBoostError(XGBoostJNI.XGBGetLastError());
-}
-}
-// used as way to test/debug passed rabit init parameters
-public static Map<String, String> rabitEnvs;
-public static List<String> mockList = new LinkedList<>();
-/**
- * Initialize the rabit library on current working thread.
- * @param envs The additional environment variables to pass to rabit.
- * @throws XGBoostError
- */
-public static void init(Map<String, String> envs) throws XGBoostError {
-rabitEnvs = envs;
-String[] args = new String[envs.size() + mockList.size()];
-int idx = 0;
-for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
-args[idx++] = e.getKey() + '=' + e.getValue();
-}
-// pass list of rabit mock strings eg mock=0,1,0,0
-for(String mock : mockList) {
-args[idx++] = "mock=" + mock;
-}
-checkCall(XGBoostJNI.RabitInit(args));
-}
-
-/**
- * Shutdown the rabit engine in current working thread, equals to finalize.
- * @throws XGBoostError
- */
-public static void shutdown() throws XGBoostError {
-checkCall(XGBoostJNI.RabitFinalize());
-}
-
-/**
- * Print the message on rabit tracker.
- * @param msg
- * @throws XGBoostError
- */
-public static void trackerPrint(String msg) throws XGBoostError {
-checkCall(XGBoostJNI.RabitTrackerPrint(msg));
-}
-
-/**
- * Get version number of current stored model in the thread.
- * which means how many calls to CheckPoint we made so far.
- * @return version Number.
- * @throws XGBoostError
- */
-public static int versionNumber() throws XGBoostError {
-int[] out = new int[1];
-checkCall(XGBoostJNI.RabitVersionNumber(out));
-return out[0];
-}
-
-/**
- * get rank of current thread.
- * @return the rank.
- * @throws XGBoostError
- */
-public static int getRank() throws XGBoostError {
-int[] out = new int[1];
-checkCall(XGBoostJNI.RabitGetRank(out));
-return out[0];
-}
-
-/**
- * get world size of current job.
- * @return the worldsize
- * @throws XGBoostError
- */
-public static int getWorldSize() throws XGBoostError {
-int[] out = new int[1];
-checkCall(XGBoostJNI.RabitGetWorldSize(out));
-return out[0];
-}
-
-/**
- * perform Allreduce on distributed float vectors using operator op.
- * This implementation of allReduce does not support customized prepare function callback in the
- * native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
- *
- * @param elements local elements on distributed workers.
- * @param op operator used for Allreduce.
- * @return All-reduced float elements according to the given operator.
- */
-public static float[] allReduce(float[] elements, OpType op) {
-DataType dataType = DataType.FLOAT;
-ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
-.order(ByteOrder.nativeOrder());
-
-for (float el : elements) {
-buffer.putFloat(el);
-}
-buffer.flip();
-
-XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
-float[] results = new float[elements.length];
-buffer.asFloatBuffer().get(results);
-
-return results;
-}
-}
@@ -254,16 +254,16 @@ public class XGBoost {
 }
 if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
 if (shouldPrint(params, iter)) {
-Rabit.trackerPrint(String.format(
+Communicator.communicatorPrint(String.format(
 "early stopping after %d rounds away from the best iteration",
 earlyStoppingRounds
 ));
 }
 break;
 }
-if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
+if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
 if (shouldPrint(params, iter)){
-Rabit.trackerPrint(evalInfo + '\n');
+Communicator.communicatorPrint(evalInfo + '\n');
 }
 }
 }
@@ -135,19 +135,6 @@ class XGBoostJNI {
 public final static native int XGBoosterSaveRabitCheckpoint(long handle);
 public final static native int XGBoosterGetNumFeature(long handle, long[] feature);

-// rabit functions
-public final static native int RabitInit(String[] args);
-public final static native int RabitFinalize();
-public final static native int RabitTrackerPrint(String msg);
-public final static native int RabitGetRank(int[] out);
-public final static native int RabitGetWorldSize(int[] out);
-public final static native int RabitVersionNumber(int[] out);
-
-// Perform Allreduce operation on data in sendrecvbuf.
-// This JNI function does not support the callback function for data preparation yet.
-final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
-int enum_dtype, int enum_op);
-
 // communicator functions
 public final static native int CommunicatorInit(String[] args);
 public final static native int CommunicatorFinalize();
@@ -872,111 +872,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
 return ret;
 }

-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitInit
- * Signature: ([Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
-(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
-std::vector<std::string> args;
-std::vector<char*> argv;
-bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
-for (bst_ulong i = 0; i < len; ++i) {
-jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
-const char *s = jenv->GetStringUTFChars(arg, 0);
-args.push_back(std::string(s, jenv->GetStringLength(arg)));
-if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
-if (args.back().length() == 0) args.pop_back();
-}
-
-for (size_t i = 0; i < args.size(); ++i) {
-argv.push_back(&args[i][0]);
-}
-
-if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
-return 0;
-} else {
-return 1;
-}
-}
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitFinalize
- * Signature: ()I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
-(JNIEnv *jenv, jclass jcls) {
-if (RabitFinalize()) {
-return 0;
-} else {
-return 1;
-}
-}
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitTrackerPrint
- * Signature: (Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
-(JNIEnv *jenv, jclass jcls, jstring jmsg) {
-std::string str(jenv->GetStringUTFChars(jmsg, 0),
-jenv->GetStringLength(jmsg));
-JVM_CHECK_CALL(RabitTrackerPrint(str.c_str()));
-return 0;
-}
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitGetRank
- * Signature: ([I)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
-(JNIEnv *jenv, jclass jcls, jintArray jout) {
-jint rank = RabitGetRank();
-jenv->SetIntArrayRegion(jout, 0, 1, &rank);
-return 0;
-}
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitGetWorldSize
- * Signature: ([I)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
-(JNIEnv *jenv, jclass jcls, jintArray jout) {
-jint out = RabitGetWorldSize();
-jenv->SetIntArrayRegion(jout, 0, 1, &out);
-return 0;
-}
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitVersionNumber
- * Signature: ([I)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
-(JNIEnv *jenv, jclass jcls, jintArray jout) {
-jint out = RabitVersionNumber();
-jenv->SetIntArrayRegion(jout, 0, 1, &out);
-return 0;
-}
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitAllreduce
- * Signature: (Ljava/nio/ByteBuffer;III)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
-(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
-void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
-JVM_CHECK_CALL(RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL));
-
-return 0;
-}
-
 /*
 * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method: CommunicatorInit
@@ -279,62 +279,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
 (JNIEnv *, jclass, jlong, jlongArray);

-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitInit
- * Signature: ([Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
-(JNIEnv *, jclass, jobjectArray);
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitFinalize
- * Signature: ()I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
-(JNIEnv *, jclass);
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitTrackerPrint
- * Signature: (Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
-(JNIEnv *, jclass, jstring);
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitGetRank
- * Signature: ([I)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
-(JNIEnv *, jclass, jintArray);
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitGetWorldSize
- * Signature: ([I)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
-(JNIEnv *, jclass, jintArray);
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitVersionNumber
- * Signature: ([I)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
-(JNIEnv *, jclass, jintArray);
-
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: RabitAllreduce
- * Signature: (Ljava/nio/ByteBuffer;III)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
-(JNIEnv *, jclass, jobject, jint, jint, jint);
-
 /*
 * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method: CommunicatorInit
@@ -300,7 +300,7 @@ public class DMatrixTest {
 public void testTrainWithDenseMatrixRef() throws XGBoostError {
 Map<String, String> rabitEnv = new HashMap<>();
 rabitEnv.put("DMLC_TASK_ID", "0");
-Rabit.init(rabitEnv);
+Communicator.init(rabitEnv);
 DMatrix trainMat = null;
 BigDenseMatrix data0 = null;
 try {
@@ -348,7 +348,7 @@ public class DMatrixTest {
 else if (data0 != null) {
 data0.dispose();
 }
-Rabit.shutdown();
+Communicator.shutdown();
 }
 }

@@ -23,6 +23,6 @@ target_sources(federated_client INTERFACE federated_client.h)
 target_link_libraries(federated_client INTERFACE federated_proto)

 # Rabit engine for Federated Learning.
-target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
+target_sources(objxgboost PRIVATE federated_server.cc)
 target_link_libraries(objxgboost PRIVATE federated_client)
 target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
@@ -1,197 +0,0 @@
-/*!
- * Copyright 2022 XGBoost contributors
- */
-#include <cstdio>
-#include <fstream>
-#include <sstream>
-
-#include "federated_client.h"
-#include "rabit/internal/engine.h"
-#include "rabit/internal/utils.h"
-
-namespace MPI { // NOLINT
-// MPI data type to be compatible with existing MPI interface
-class Datatype {
-public:
-size_t type_size;
-explicit Datatype(size_t type_size) : type_size(type_size) {}
-};
-} // namespace MPI
-
-namespace rabit {
-namespace engine {
-
-/*! \brief implementation of engine using federated learning */
-class FederatedEngine : public IEngine {
-public:
-void Init(int argc, char *argv[]) {
-// Parse environment variables first.
-for (auto const &env_var : env_vars_) {
-char const *value = getenv(env_var.c_str());
-if (value != nullptr) {
-SetParam(env_var, value);
-}
-}
-// Command line argument overrides.
-for (int i = 0; i < argc; ++i) {
-std::string const key_value = argv[i];
-auto const delimiter = key_value.find('=');
-if (delimiter != std::string::npos) {
-SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
-}
-}
-utils::Printf("Connecting to federated server %s, world size %d, rank %d",
-server_address_.c_str(), world_size_, rank_);
-if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
-utils::Printf("Certificates not specified, turning off SSL.");
-client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
-} else {
-client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
-client_key_, client_cert_));
-}
-}
-
-void Finalize() { client_.reset(); }
-
-void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
-size_t size_prev_slice) override {
-throw std::logic_error("FederatedEngine:: Allgather is not supported");
-}
-
-std::string Allgather(void *sendbuf, size_t total_size) {
-std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
-return client_->Allgather(send_buffer);
-}
-
-void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
-PreprocFunction prepare_fun, void *prepare_arg) override {
-throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
-}
-
-void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
-auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
-std::string const send_buffer(buffer, size);
-auto const receive_buffer =
-client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
-static_cast<xgboost::federated::ReduceOperation>(op));
-receive_buffer.copy(buffer, size);
-}
-
-int GetRingPrevRank() const override {
-throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
-}
-
-void Broadcast(void *sendrecvbuf, size_t size, int root) override {
-if (world_size_ == 1) return;
-auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
-std::string const send_buffer(buffer, size);
-auto const receive_buffer = client_->Broadcast(send_buffer, root);
-if (rank_ != root) {
-receive_buffer.copy(buffer, size);
-}
-}
-
-int LoadCheckPoint() override { return 0; }
-
-void CheckPoint() override { version_number_ += 1; }
-
-int VersionNumber() const override { return version_number_; }
-
-/*! \brief get rank of current node */
-int GetRank() const override { return rank_; }
-
-/*! \brief get total number of */
-int GetWorldSize() const override { return world_size_; }
-
-/*! \brief whether it is distributed */
-bool IsDistributed() const override { return true; }
-
-/*! \brief get the host name of current node */
-std::string GetHost() const override { return "rank" + std::to_string(rank_); }
-
-void TrackerPrint(const std::string &msg) override {
-// simply print information into the tracker
-utils::Printf("%s", msg.c_str());
-}
-
-private:
-void SetParam(std::string const &name, std::string const &val) {
-if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
-server_address_ = val;
-} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
-world_size_ = std::stoi(val);
-} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
-rank_ = std::stoi(val);
-} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
-server_cert_ = ReadFile(val);
-} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
-client_key_ = ReadFile(val);
-} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
-client_cert_ = ReadFile(val);
-}
-}
-
-static std::string ReadFile(std::string const &path) {
-auto stream = std::ifstream(path.data());
-std::ostringstream out;
-out << stream.rdbuf();
-return out.str();
-}
-
-// clang-format off
-std::vector<std::string> const env_vars_{
-"FEDERATED_SERVER_ADDRESS",
-"FEDERATED_WORLD_SIZE",
-"FEDERATED_RANK",
-"FEDERATED_SERVER_CERT",
-"FEDERATED_CLIENT_KEY",
-"FEDERATED_CLIENT_CERT" };
-// clang-format on
-std::string server_address_{"localhost:9091"};
-int world_size_{1};
-int rank_{0};
-std::string server_cert_{};
-std::string client_key_{};
-std::string client_cert_{};
-std::unique_ptr<xgboost::federated::FederatedClient> client_{};
-int version_number_{0};
-};
-
-// Singleton federated engine.
-FederatedEngine engine; // NOLINT(cert-err58-cpp)
-
-/*! \brief initialize the synchronization module */
-bool Init(int argc, char *argv[]) {
-try {
-engine.Init(argc, argv);
-return true;
-} catch (std::exception const &e) {
-fprintf(stderr, " failed in federated Init %s\n", e.what());
-return false;
-}
-}
-
-/*! \brief finalize synchronization module */
-bool Finalize() {
-try {
-engine.Finalize();
-return true;
-} catch (const std::exception &e) {
-fprintf(stderr, "failed in federated shutdown %s\n", e.what());
-return false;
-}
-}
-
-/*! \brief singleton method to get engine */
-IEngine *GetEngine() { return &engine; }
-
-// perform in-place allreduce, on sendrecvbuf
-void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
-mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
-void *prepare_arg) {
-if (prepare_fun != nullptr) prepare_fun(prepare_arg);
-if (engine.GetWorldSize() == 1) return;
-engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
-}
-} // namespace engine
-} // namespace rabit
@@ -11,6 +11,40 @@
 namespace xgboost {
 namespace collective {

+/** @brief Get the size of the data type. */
+inline std::size_t GetTypeSize(DataType data_type) {
+std::size_t size{0};
+switch (data_type) {
+case DataType::kInt8:
+size = sizeof(std::int8_t);
+break;
+case DataType::kUInt8:
+size = sizeof(std::uint8_t);
+break;
+case DataType::kInt32:
+size = sizeof(std::int32_t);
+break;
+case DataType::kUInt32:
+size = sizeof(std::uint32_t);
+break;
+case DataType::kInt64:
+size = sizeof(std::int64_t);
+break;
+case DataType::kUInt64:
+size = sizeof(std::uint64_t);
+break;
+case DataType::kFloat:
+size = sizeof(float);
+break;
+case DataType::kDouble:
+size = sizeof(double);
+break;
+default:
+LOG(FATAL) << "Unknown data type.";
+}
+return size;
+}
+
 /**
 * @brief A Federated Learning communicator class that handles collective communication.
 */
@@ -3,9 +3,8 @@
 Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
 """

-from . import rabit # noqa
 from . import tracker # noqa
-from . import dask
+from . import collective, dask
 from .core import (
 Booster,
 DataIter,
@@ -63,4 +62,6 @@ __all__ = [
 "XGBRFRegressor",
 # dask
 "dask",
+# collective
+"collective",
 ]
@@ -13,7 +13,7 @@ import pickle
 from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast, Sequence, Any
 import numpy

-from . import rabit
+from . import collective
 from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees


@@ -100,7 +100,7 @@ def _allreduce_metric(score: _ART) -> _ART:
 as final result.

 '''
-world = rabit.get_world_size()
+world = collective.get_world_size()
 assert world != 0
 if world == 1:
 return score
@@ -108,7 +108,7 @@ def _allreduce_metric(score: _ART) -> _ART:
 raise ValueError(
 'xgboost.cv function should not be used in distributed environment.')
 arr = numpy.array([score])
-arr = rabit.allreduce(arr, rabit.Op.SUM) / world
+arr = collective.allreduce(arr, collective.Op.SUM) / world
 return arr[0]


@@ -485,7 +485,7 @@ class EvaluationMonitor(TrainingCallback):
 return False

 msg: str = f'[{epoch}]'
-if rabit.get_rank() == self.printer_rank:
+if collective.get_rank() == self.printer_rank:
 for data, metric in evals_log.items():
 for metric_name, log in metric.items():
 stdv: Optional[float] = None
@@ -498,7 +498,7 @@ class EvaluationMonitor(TrainingCallback):
 msg += '\n'

 if (epoch % self.period) == 0 or self.period == 1:
-rabit.tracker_print(msg)
+collective.communicator_print(msg)
 self._latest = None
 else:
 # There is skipped message
@@ -506,8 +506,8 @@ class EvaluationMonitor(TrainingCallback):
 return False

 def after_training(self, model: _Model) -> _Model:
-if rabit.get_rank() == self.printer_rank and self._latest is not None:
-rabit.tracker_print(self._latest)
+if collective.get_rank() == self.printer_rank and self._latest is not None:
+collective.communicator_print(self._latest)
 return model


@@ -552,7 +552,7 @@ class TrainingCheckPoint(TrainingCallback):
 path = os.path.join(self._path, self._name + '_' + str(epoch) +
 ('.pkl' if self._as_pickle else '.json'))
 self._epoch = 0
-if rabit.get_rank() == 0:
+if collective.get_rank() == 0:
 if self._as_pickle:
 with open(path, 'wb') as fd:
 pickle.dump(model, fd)
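For reference, the metric aggregation touched above reduces to a mean over workers via the new collective API. A standalone sketch of that pattern, assuming an already-initialized communicator (the helper name is made up for illustration):

    import numpy as np
    from xgboost import collective

    def mean_across_workers(score: float) -> float:
        # Mirrors _allreduce_metric: sum the scalar over all workers, then divide.
        world = collective.get_world_size()
        if world == 1:
            return score
        arr = np.array([score])
        return float(collective.allreduce(arr, collective.Op.SUM)[0] / world)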
@@ -4,7 +4,7 @@ import json
 import logging
 import pickle
 from enum import IntEnum, unique
-from typing import Any, List
+from typing import Any, List, Dict

 import numpy as np

@@ -233,10 +233,11 @@ class CommunicatorContext:
 def __init__(self, **args: Any) -> None:
 self.args = args

-def __enter__(self) -> None:
+def __enter__(self) -> Dict[str, Any]:
 init(**self.args)
 assert is_distributed()
 LOGGER.debug("-------------- communicator say hello ------------------")
+return self.args

 def __exit__(self, *args: List) -> None:
 finalize()
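The __enter__ change above means the context manager now hands its argument dict back to the caller. A minimal usage sketch (the federated keys mirror the NVFlare hunk earlier; the exact values are placeholders):

    from xgboost import collective

    env = {
        "federated_server_address": "localhost:9091",
        "federated_world_size": 2,
        "federated_rank": 0,
    }
    with collective.CommunicatorContext(**env) as args:
        # __enter__ now returns the args dict instead of None.
        assert args["federated_rank"] == 0
        collective.communicator_print(f"rank {collective.get_rank()} ready\n")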
@ -59,7 +59,7 @@ from typing import (

import numpy

- from . import config, rabit
+ from . import collective, config
from ._typing import _T, FeatureNames, FeatureTypes
from .callback import TrainingCallback
from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
@ -112,7 +112,7 @@ TrainReturnT = TypedDict(
)

__all__ = [
- "RabitContext",
+ "CommunicatorContext",
"DaskDMatrix",
"DaskDeviceQuantileDMatrix",
"DaskXGBRegressor",
@ -158,7 +158,7 @@ def _try_start_tracker(
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
- rabit_context = RabitTracker(
+ rabit_tracker = RabitTracker(
host_ip=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
@ -168,12 +168,12 @@ def _try_start_tracker(
addr = addrs[0]
assert isinstance(addr, str) or addr is None
host_ip = get_host_ip(addr)
- rabit_context = RabitTracker(
+ rabit_tracker = RabitTracker(
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
)
- env.update(rabit_context.worker_envs())
+ env.update(rabit_tracker.worker_envs())
- rabit_context.start(n_workers)
+ rabit_tracker.start(n_workers)
- thread = Thread(target=rabit_context.join)
+ thread = Thread(target=rabit_tracker.join)
thread.daemon = True
thread.start()
except socket.error as e:
@ -213,11 +213,11 @@ def _assert_dask_support() -> None:
LOGGER.warning(msg)


- class RabitContext(rabit.RabitContext):
+ class CommunicatorContext(collective.CommunicatorContext):
- """A context controlling rabit initialization and finalization."""
+ """A context controlling collective communicator initialization and finalization."""

- def __init__(self, args: List[bytes]) -> None:
+ def __init__(self, **args: Any) -> None:
- super().__init__(args)
+ super().__init__(**args)
worker = distributed.get_worker()
with distributed.worker_client() as client:
info = client.scheduler_info()
@ -227,9 +227,7 @@ class RabitContext(rabit.RabitContext):
# not the same as task ID is string and "10" is sorted before "2") with dask
# worker ID. This outsources the rank assignment to dask and prevents
# non-deterministic issue.
- self.args.append(
- (f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
- )
+ self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address)


def dconcat(value: Sequence[_T]) -> _T:
@ -811,7 +809,7 @@ def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:

async def _get_rabit_args(
n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
- ) -> List[bytes]:
+ ) -> Dict[str, Union[str, int]]:
"""Get rabit context arguments from data distribution in DaskDMatrix."""
# There are 3 possible different addresses:
# 1. Provided by user via dask.config
@ -854,9 +852,7 @@ async def _get_rabit_args(
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)
- rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
- return rabit_args
+ return env


def _get_dask_config() -> Optional[Dict[str, Any]]:
@ -911,7 +907,7 @@ async def _train_async(

def dispatched_train(
parameters: Dict,
- rabit_args: List[bytes],
+ rabit_args: Dict[str, Union[str, int]],
train_id: int,
evals_name: List[str],
evals_id: List[int],
@ -935,7 +931,7 @@ async def _train_async(
n_threads = dwnt
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
local_history: TrainingCallback.EvalsLog = {}
- with RabitContext(rabit_args), config.config_context(**global_config):
+ with CommunicatorContext(**rabit_args), config.config_context(**global_config):
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
evals: List[Tuple[DMatrix, str]] = []
for i, ref in enumerate(refs):
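Throughout the Dask integration the tracker environment now travels as a plain dict and is expanded into CommunicatorContext as keyword arguments, instead of being flattened into a list of key=value byte strings. A rough, self-contained illustration of the two shapes (the values are made up):

env = {"DMLC_TRACKER_URI": "10.0.0.1", "DMLC_TRACKER_PORT": 9091, "DMLC_TASK_ID": "worker-0"}

old_style = [f"{k}={v}".encode() for k, v in env.items()]  # what rabit.init() used to receive
new_style = dict(env)                                      # what CommunicatorContext(**env) receives now
print(old_style, new_style)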
@ -1,249 +0,0 @@
"""Distributed XGBoost Rabit related API."""
|
|
||||||
import ctypes
|
|
||||||
from enum import IntEnum, unique
|
|
||||||
import logging
|
|
||||||
import pickle
|
|
||||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .core import _LIB, c_str, _check_call
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger("[xgboost.rabit]")
|
|
||||||
|
|
||||||
|
|
||||||
def _init_rabit() -> None:
|
|
||||||
"""internal library initializer."""
|
|
||||||
if _LIB is not None:
|
|
||||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
|
||||||
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
|
||||||
_LIB.RabitIsDistributed.restype = ctypes.c_int
|
|
||||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
|
||||||
|
|
||||||
|
|
||||||
def init(args: Optional[List[bytes]] = None) -> None:
|
|
||||||
"""Initialize the rabit library with arguments"""
|
|
||||||
if args is None:
|
|
||||||
args = []
|
|
||||||
arr = (ctypes.c_char_p * len(args))()
|
|
||||||
arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
|
|
||||||
_LIB.RabitInit(len(arr), arr)
|
|
||||||
|
|
||||||
|
|
||||||
def finalize() -> None:
|
|
||||||
"""Finalize the process, notify tracker everything is done."""
|
|
||||||
_LIB.RabitFinalize()
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank() -> int:
|
|
||||||
"""Get rank of current process.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
rank : int
|
|
||||||
Rank of current process.
|
|
||||||
"""
|
|
||||||
ret = _LIB.RabitGetRank()
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size() -> int:
|
|
||||||
"""Get total number workers.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
n : int
|
|
||||||
Total number of process.
|
|
||||||
"""
|
|
||||||
ret = _LIB.RabitGetWorldSize()
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def is_distributed() -> int:
|
|
||||||
'''If rabit is distributed.'''
|
|
||||||
is_dist = _LIB.RabitIsDistributed()
|
|
||||||
return is_dist
|
|
||||||
|
|
||||||
|
|
||||||
def tracker_print(msg: Any) -> None:
|
|
||||||
"""Print message to the tracker.
|
|
||||||
|
|
||||||
This function can be used to communicate the information of
|
|
||||||
the progress to the tracker
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
msg : str
|
|
||||||
The message to be printed to tracker.
|
|
||||||
"""
|
|
||||||
if not isinstance(msg, str):
|
|
||||||
msg = str(msg)
|
|
||||||
is_dist = _LIB.RabitIsDistributed()
|
|
||||||
if is_dist != 0:
|
|
||||||
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
|
|
||||||
else:
|
|
||||||
print(msg.strip(), flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_processor_name() -> bytes:
|
|
||||||
"""Get the processor name.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
name : str
|
|
||||||
the name of processor(host)
|
|
||||||
"""
|
|
||||||
mxlen = 256
|
|
||||||
length = ctypes.c_ulong()
|
|
||||||
buf = ctypes.create_string_buffer(mxlen)
|
|
||||||
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
|
|
||||||
return buf.value
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T") # pylint:disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast(data: T, root: int) -> T:
|
|
||||||
"""Broadcast object from one node to all other nodes.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
data : any type that can be pickled
|
|
||||||
Input data, if current rank does not equal root, this can be None
|
|
||||||
root : int
|
|
||||||
Rank of the node to broadcast data from.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
object : int
|
|
||||||
the result of broadcast.
|
|
||||||
"""
|
|
||||||
rank = get_rank()
|
|
||||||
length = ctypes.c_ulong()
|
|
||||||
if root == rank:
|
|
||||||
assert data is not None, 'need to pass in data when broadcasting'
|
|
||||||
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
length.value = len(s)
|
|
||||||
# run first broadcast
|
|
||||||
_check_call(_LIB.RabitBroadcast(ctypes.byref(length),
|
|
||||||
ctypes.sizeof(ctypes.c_ulong), root))
|
|
||||||
if root != rank:
|
|
||||||
dptr = (ctypes.c_char * length.value)()
|
|
||||||
# run second
|
|
||||||
_check_call(_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
|
||||||
length.value, root))
|
|
||||||
data = pickle.loads(dptr.raw)
|
|
||||||
del dptr
|
|
||||||
else:
|
|
||||||
_check_call(_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
|
||||||
length.value, root))
|
|
||||||
del s
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# enumeration of dtypes
|
|
||||||
DTYPE_ENUM__ = {
|
|
||||||
np.dtype('int8'): 0,
|
|
||||||
np.dtype('uint8'): 1,
|
|
||||||
np.dtype('int32'): 2,
|
|
||||||
np.dtype('uint32'): 3,
|
|
||||||
np.dtype('int64'): 4,
|
|
||||||
np.dtype('uint64'): 5,
|
|
||||||
np.dtype('float32'): 6,
|
|
||||||
np.dtype('float64'): 7
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@unique
|
|
||||||
class Op(IntEnum):
|
|
||||||
'''Supported operations for rabit.'''
|
|
||||||
MAX = 0
|
|
||||||
MIN = 1
|
|
||||||
SUM = 2
|
|
||||||
OR = 3
|
|
||||||
|
|
||||||
|
|
||||||
def allreduce( # pylint:disable=invalid-name
|
|
||||||
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Perform allreduce, return the result.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
data :
|
|
||||||
Input data.
|
|
||||||
op :
|
|
||||||
Reduction operators, can be MIN, MAX, SUM, BITOR
|
|
||||||
prepare_fun :
|
|
||||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
|
||||||
will be called by the function before performing allreduce, to initialize the data
|
|
||||||
If the result of Allreduce can be recovered directly,
|
|
||||||
then prepare_fun will NOT be called
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
result :
|
|
||||||
The result of allreduce, have same shape as data
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
This function is not thread-safe.
|
|
||||||
"""
|
|
||||||
if not isinstance(data, np.ndarray):
|
|
||||||
raise Exception('allreduce only takes in numpy.ndarray')
|
|
||||||
buf = data.ravel()
|
|
||||||
if buf.base is data.base:
|
|
||||||
buf = buf.copy()
|
|
||||||
if buf.dtype not in DTYPE_ENUM__:
|
|
||||||
raise Exception(f"data type {buf.dtype} not supported")
|
|
||||||
if prepare_fun is None:
|
|
||||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
|
||||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
|
||||||
int(op), None, None))
|
|
||||||
else:
|
|
||||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
|
||||||
|
|
||||||
def pfunc(_: Any) -> None:
|
|
||||||
"""prepare function."""
|
|
||||||
fn = cast(Callable[[np.ndarray], None], prepare_fun)
|
|
||||||
fn(data)
|
|
||||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
|
||||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
|
||||||
op, func_ptr(pfunc), None))
|
|
||||||
return buf
|
|
||||||
|
|
||||||
|
|
||||||
def version_number() -> int:
|
|
||||||
"""Returns version number of current stored model.
|
|
||||||
|
|
||||||
This means how many calls to CheckPoint we made so far.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
version : int
|
|
||||||
Version number of currently stored model
|
|
||||||
"""
|
|
||||||
ret = _LIB.RabitVersionNumber()
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class RabitContext:
|
|
||||||
"""A context controlling rabit initialization and finalization."""
|
|
||||||
|
|
||||||
def __init__(self, args: List[bytes] = None) -> None:
|
|
||||||
if args is None:
|
|
||||||
args = []
|
|
||||||
self.args = args
|
|
||||||
|
|
||||||
def __enter__(self) -> None:
|
|
||||||
init(self.args)
|
|
||||||
assert is_distributed()
|
|
||||||
LOGGER.debug("-------------- rabit say hello ------------------")
|
|
||||||
|
|
||||||
def __exit__(self, *args: List) -> None:
|
|
||||||
finalize()
|
|
||||||
LOGGER.debug("--------------- rabit say bye ------------------")
|
|
||||||
|
|
||||||
|
|
||||||
# initialization script
|
|
||||||
_init_rabit()
|
|
||||||
@ -2,6 +2,7 @@
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines
+ import json
from typing import Iterator, Optional, Tuple

import numpy as np
@ -57,7 +58,7 @@ from .params import (
HasQueryIdCol,
)
from .utils import (
- RabitContext,
+ CommunicatorContext,
_get_args_from_message_list,
_get_default_params_from_func,
_get_gpu_id,
@ -769,7 +770,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

- _rabit_args = ""
+ _rabit_args = {}
if context.partitionId() == 0:
get_logger("XGBoostPySpark").info(
"booster params: %s\n"
@ -780,12 +781,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dmatrix_kwargs,
)

- _rabit_args = str(_get_rabit_args(context, num_workers))
+ _rabit_args = _get_rabit_args(context, num_workers)

- messages = context.allGather(message=str(_rabit_args))
+ messages = context.allGather(message=json.dumps(_rabit_args))
_rabit_args = _get_args_from_message_list(messages)
evals_result = {}
- with RabitContext(_rabit_args, context):
+ with CommunicatorContext(context, **_rabit_args):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter,
features_cols_names,
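Partition 0 starts the tracker and shares its environment with every task through the barrier allGather, now as JSON rather than a stringified Python list. A small runnable sketch of that round trip, with a made-up env dict standing in for the real tracker output:

import json

env = {"DMLC_TRACKER_URI": "10.0.0.1", "DMLC_TRACKER_PORT": 9091}  # placeholder tracker output

sent = json.dumps(env)       # what partition 0 contributes via context.allGather(...)
received = json.loads(sent)  # what _get_args_from_message_list() now returns on every task
assert received == env       # every task ends up with the same dict of keyword arguments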
@ -1,6 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for helper functions."""
import inspect
+ import json
import logging
import sys
from threading import Thread
@ -9,7 +10,7 @@ import pyspark
from pyspark.sql.session import SparkSession
from xgboost.tracker import RabitTracker

- from xgboost import rabit
+ from xgboost import collective


def get_class_name(cls):
@ -36,21 +37,21 @@ def _get_default_params_from_func(func, unsupported_set):
return filtered_params_dict


- class RabitContext:
+ class CommunicatorContext:
"""
- A context controlling rabit initialization and finalization.
+ A context controlling collective communicator initialization and finalization.
This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
"""

- def __init__(self, args, context):
+ def __init__(self, context, **args):
self.args = args
- self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode())
+ self.args["DMLC_TASK_ID"] = str(context.partitionId())

def __enter__(self):
- rabit.init(self.args)
+ collective.init(**self.args)

def __exit__(self, *args):
- rabit.finalize()
+ collective.finalize()


def _start_tracker(context, n_workers):
@ -74,8 +75,7 @@ def _get_rabit_args(context, n_workers):
"""
# pylint: disable=consider-using-f-string
env = _start_tracker(context, n_workers)
- rabit_args = [("%s=%s" % item).encode() for item in env.items()]
- return rabit_args
+ return env


def _get_host_ip(context):
@ -95,7 +95,7 @@ def _get_args_from_message_list(messages):
if message != "":
output = message
break
- return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")]
+ return json.loads(output)


def _get_spark_session():
@ -6,9 +6,7 @@ set(RABIT_SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)

- if (PLUGIN_FEDERATED)
- # Skip the engine if the Federated Learning plugin is enabled.
- elseif (RABIT_BUILD_MPI)
+ if (RABIT_BUILD_MPI)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
elseif (RABIT_MOCK)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
@ -1,11 +1,8 @@
// Copyright (c) 2014-2022 by Contributors
- #include <rabit/rabit.h>
#include <rabit/c_api.h>

- #include <cstdio>
#include <cstring>
#include <fstream>
- #include <algorithm>
#include <vector>
#include <string>
#include <memory>
@ -22,12 +19,11 @@

#include "c_api_error.h"
#include "c_api_utils.h"
- #include "../collective/communicator.h"
+ #include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
- #include "../data/proxy_dmatrix.h"

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
@ -215,7 +211,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
#if defined(XGBOOST_USE_FEDERATED)
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else
- if (rabit::IsDistributed()) {
+ if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
load_row_split = true;
@ -1560,44 +1556,42 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
API_END();
}

- using xgboost::collective::Communicator;

XGB_DLL int XGCommunicatorInit(char const* json_config) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(json_config);
- Json config { Json::Load(StringView{json_config}) };
+ Json config{Json::Load(StringView{json_config})};
- Communicator::Init(config);
+ collective::Init(config);
API_END();
}

XGB_DLL int XGCommunicatorFinalize() {
API_BEGIN();
- Communicator::Finalize();
+ collective::Finalize();
API_END();
}

- XGB_DLL int XGCommunicatorGetRank() {
+ XGB_DLL int XGCommunicatorGetRank(void) {
- return Communicator::Get()->GetRank();
+ return collective::GetRank();
}

- XGB_DLL int XGCommunicatorGetWorldSize() {
+ XGB_DLL int XGCommunicatorGetWorldSize(void) {
- return Communicator::Get()->GetWorldSize();
+ return collective::GetWorldSize();
}

- XGB_DLL int XGCommunicatorIsDistributed() {
+ XGB_DLL int XGCommunicatorIsDistributed(void) {
- return Communicator::Get()->IsDistributed();
+ return collective::IsDistributed();
}

XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
- Communicator::Get()->Print(message);
+ collective::Print(message);
API_END();
}

XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
- local.ret_str = Communicator::Get()->GetProcessorName();
+ local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
@ -1605,16 +1599,14 @@ XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {

XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
- Communicator::Get()->Broadcast(send_receive_buffer, size, root);
+ collective::Broadcast(send_receive_buffer, size, root);
API_END();
}

XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
- Communicator::Get()->AllReduce(
- send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
- static_cast<xgboost::collective::Operation>(enum_op));
+ collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}
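These XGCommunicator* entry points are what the Python-side xgboost.collective module calls into. A hedged sketch of the corresponding Python surface; with no configuration the communicator runs in single-process mode, so this is safe to run locally (rank 0, world size 1):

from xgboost import collective

collective.init()                   # no configuration: single-process mode
print(collective.get_rank())        # 0
print(collective.get_world_size())  # 1
print(collective.is_distributed())  # falsy
collective.communicator_print("hello from the communicator\n")
collective.finalize()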
@ -25,6 +25,7 @@
#include <cstdio>
#include <cstring>
#include <vector>
+ #include "collective/communicator-inl.h"
#include "common/common.h"
#include "common/config.h"
#include "common/io.h"
@ -156,7 +157,7 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
if (name_pred == "stdout") {
save_period = 0;
}
- if (dsplit == 0 && rabit::IsDistributed()) {
+ if (dsplit == 0 && collective::IsDistributed()) {
dsplit = 2;
}
}
@ -186,26 +187,22 @@ class CLI {
kHelp
} print_info_ {kNone};

- int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
+ void ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
learner_.reset(Learner::Create(matrices));
- int version = rabit::LoadCheckPoint();
- if (version == 0) {
- if (param_.model_in != CLIParam::kNull) {
- this->LoadModel(param_.model_in, learner_.get());
- learner_->SetParams(param_.cfg);
- } else {
- learner_->SetParams(param_.cfg);
- }
+ if (param_.model_in != CLIParam::kNull) {
+ this->LoadModel(param_.model_in, learner_.get());
+ learner_->SetParams(param_.cfg);
+ } else {
+ learner_->SetParams(param_.cfg);
}
learner_->Configure();
- return version;
}

void CLITrain() {
const double tstart_data_load = dmlc::GetTime();
- if (rabit::IsDistributed()) {
+ if (collective::IsDistributed()) {
- std::string pname = rabit::GetProcessorName();
+ std::string pname = collective::GetProcessorName();
- LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
+ LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
}
// load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
@ -230,48 +227,45 @@ class CLI {
eval_data_names.emplace_back("train");
}
// initialize the learner.
- int32_t version = this->ResetLearner(cache_mats);
+ this->ResetLearner(cache_mats);
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
<< " sec";

// start training.
const double start = dmlc::GetTime();
+ int32_t version = 0;
for (int i = version / 2; i < param_.num_round; ++i) {
double elapsed = dmlc::GetTime() - start;
if (version % 2 == 0) {
LOG(INFO) << "boosting round " << i << ", " << elapsed
<< " sec elapsed";
learner_->UpdateOneIter(i, dtrain);
- rabit::CheckPoint();
version += 1;
}
- CHECK_EQ(version, rabit::VersionNumber());
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
- if (rabit::IsDistributed()) {
+ if (collective::IsDistributed()) {
- if (rabit::GetRank() == 0) {
+ if (collective::GetRank() == 0) {
LOG(TRACKER) << res;
}
} else {
LOG(CONSOLE) << res;
}
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
- rabit::GetRank() == 0) {
+ collective::GetRank() == 0) {
std::ostringstream os;
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
<< i + 1 << ".model";
this->SaveModel(os.str(), learner_.get());
}

- rabit::CheckPoint();
version += 1;
- CHECK_EQ(version, rabit::VersionNumber());
}
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
<< " sec";
// always save final round
if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) &&
- rabit::GetRank() == 0) {
+ collective::GetRank() == 0) {
std::ostringstream os;
if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
@ -467,7 +461,6 @@ class CLI {
return;
}

- rabit::Init(argc, argv);
std::string config_path = argv[1];

common::ConfigParser cp(config_path);
@ -480,6 +473,13 @@ class CLI {
}
}

+ // Initialize the collective communicator.
+ Json json{JsonObject()};
+ for (auto& kv : cfg) {
+ json[kv.first] = String(kv.second);
+ }
+ collective::Init(json);

param_.Configure(cfg);
}

@ -517,7 +517,7 @@ class CLI {
}

~CLI() {
- rabit::Finalize();
+ collective::Finalize();
}
};
} // namespace xgboost
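The CLI no longer drives rabit checkpoints; it simply forwards its parsed key/value configuration to the communicator as a JSON object. A rough Python illustration of that hand-off (the config keys shown are illustrative placeholders, not the CLI's actual option list):

import json

cfg = {"num_round": "10", "data": "train.txt", "dmlc_task_id": "0"}  # placeholder CLI entries
json_config = json.dumps(cfg)  # roughly what collective::Init / XGCommunicatorInit receives
print(json_config)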
src/collective/communicator-inl.h (new file, 208 lines)
@ -0,0 +1,208 @@
/*!
 * Copyright 2022 XGBoost contributors
 */
#pragma once
#include <string>

#include "communicator.h"

namespace xgboost {
namespace collective {

/*!
 * \brief Initialize the collective communicator.
 *
 * Currently the communicator API is experimental, function signatures may change in the future
 * without notice.
 *
 * Call this once before using anything.
 *
 * The additional configuration is not required. Usually the communicator will detect settings
 * from environment variables.
 *
 * \param json_config JSON encoded configuration. Accepted JSON keys are:
 *   - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
 *     * rabit: Use Rabit. This is the default if the type is unspecified.
 *     * mpi: Use MPI.
 *     * federated: Use the gRPC interface for Federated Learning.
 * Only applicable to the Rabit communicator (these are case-sensitive):
 *   - rabit_tracker_uri: Hostname of the tracker.
 *   - rabit_tracker_port: Port number of the tracker.
 *   - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
 *   - rabit_world_size: Total number of workers.
 *   - rabit_hadoop_mode: Enable Hadoop support.
 *   - rabit_tree_reduce_minsize: Minimal size for tree reduce.
 *   - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
 *   - rabit_reduce_buffer: Size of the reduce buffer.
 *   - rabit_bootstrap_cache: Size of the bootstrap cache.
 *   - rabit_debug: Enable debugging.
 *   - rabit_timeout: Enable timeout.
 *   - rabit_timeout_sec: Timeout in seconds.
 *   - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
 * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
 * environment variables):
 *   - DMLC_TRACKER_URI: Hostname of the tracker.
 *   - DMLC_TRACKER_PORT: Port number of the tracker.
 *   - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
 *   - DMLC_ROLE: Role of the current task, "worker" or "server".
 *   - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
 *   - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
 * Only applicable to the Federated communicator (use upper case for environment variables, use
 * lower case for runtime configuration):
 *   - federated_server_address: Address of the federated server.
 *   - federated_world_size: Number of federated workers.
 *   - federated_rank: Rank of the current worker.
 *   - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
 *   - federated_client_key: Client key file path. Only needed for the SSL mode.
 *   - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
 */
inline void Init(Json const& config) {
  Communicator::Init(config);
}

/*!
 * \brief Finalize the collective communicator.
 *
 * Call this function after you finished all jobs.
 */
inline void Finalize() { Communicator::Finalize(); }

/*!
 * \brief Get rank of current process.
 *
 * \return Rank of the worker.
 */
inline int GetRank() { return Communicator::Get()->GetRank(); }

/*!
 * \brief Get total number of processes.
 *
 * \return Total world size.
 */
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }

/*!
 * \brief Get if the communicator is distributed.
 *
 * \return True if the communicator is distributed.
 */
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }

/*!
 * \brief Print the message to the communicator.
 *
 * This function can be used to communicate the information of the progress to the user who monitors
 * the communicator.
 *
 * \param message The message to be printed.
 */
inline void Print(char const *message) { Communicator::Get()->Print(message); }

inline void Print(std::string const &message) { Communicator::Get()->Print(message); }

/*!
 * \brief Get the name of the processor.
 *
 * \return Name of the processor.
 */
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }

/*!
 * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
 *
 * Example:
 *   int a = 1;
 *   Broadcast(&a, sizeof(a), root);
 *
 * \param send_receive_buffer Pointer to the send or receive buffer.
 * \param size Size of the data.
 * \param root The process rank to broadcast from.
 */
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
  Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}

inline void Broadcast(std::string *sendrecv_data, int root) {
  size_t size = sendrecv_data->length();
  Broadcast(&size, sizeof(size), root);
  if (sendrecv_data->length() != size) {
    sendrecv_data->resize(size);
  }
  if (size != 0) {
    Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
  }
}

/*!
 * \brief Perform in-place allreduce. This function is NOT thread-safe.
 *
 * Example Usage: the following code gives sum of the result
 *   vector<int> data(10);
 *   ...
 *   Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
 *   ...
 * \param send_receive_buffer Buffer for both sending and receiving data.
 * \param count Number of elements to be reduced.
 * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
 * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
 */
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
                                 static_cast<Operation>(op));
}

inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}

template <Operation op>
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}

template <Operation op>
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}

template <Operation op>
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}

template <Operation op>
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}

template <Operation op>
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}

template <Operation op>
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}

// Specialization for size_t, which is implementation defined, so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <Operation op, typename T,
          typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
inline void Allreduce(T *send_receive_buffer, size_t count) {
  static_assert(sizeof(T) == sizeof(uint64_t), "");
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}

template <Operation op>
inline void Allreduce(float *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}

template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

}  // namespace collective
}  // namespace xgboost
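The configuration keys documented above are exactly what the Python helpers serialize to JSON for XGCommunicatorInit. A hedged sketch of a federated worker configuration built from that key list; the address, world size, rank, and paths are placeholders:

federated_config = {
    "xgboost_communicator": "federated",
    "federated_server_address": "localhost:9091",
    "federated_world_size": 2,
    "federated_rank": 0,
    # "federated_server_cert" / "federated_client_key" / "federated_client_cert" only for SSL mode
}
# Typically handed over as: with xgboost.collective.CommunicatorContext(**federated_config): ...
print(federated_config)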
@ -3,6 +3,7 @@
*/
#include "communicator.h"

+ #include "noop_communicator.h"
#include "rabit_communicator.h"

#if defined(XGBOOST_USE_FEDERATED)
@ -12,14 +13,10 @@
namespace xgboost {
namespace collective {

- thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
+ thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};

void Communicator::Init(Json const& config) {
- if (communicator_) {
- LOG(FATAL) << "Communicator can only be initialized once.";
- }

auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {
@ -51,7 +48,7 @@ void Communicator::Init(Json const& config) {
#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
- communicator_.reset(nullptr);
+ communicator_.reset(new NoOpCommunicator());
}
#endif
@ -4,6 +4,7 @@
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
+ #include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat

void Communicator::Finalize() {
communicator_->Shutdown();
- communicator_.reset(nullptr);
+ communicator_.reset(new NoOpCommunicator());
device_ordinal_ = -1;
device_communicator_.reset(nullptr);
}
@ -23,40 +23,6 @@ enum class DataType {
kDouble = 7
};

- /** @brief Get the size of the data type. */
- inline std::size_t GetTypeSize(DataType data_type) {
- std::size_t size{0};
- switch (data_type) {
- case DataType::kInt8:
- size = sizeof(std::int8_t);
- break;
- case DataType::kUInt8:
- size = sizeof(std::uint8_t);
- break;
- case DataType::kInt32:
- size = sizeof(std::int32_t);
- break;
- case DataType::kUInt32:
- size = sizeof(std::uint32_t);
- break;
- case DataType::kInt64:
- size = sizeof(std::int64_t);
- break;
- case DataType::kUInt64:
- size = sizeof(std::uint64_t);
- break;
- case DataType::kFloat:
- size = sizeof(float);
- break;
- case DataType::kDouble:
- size = sizeof(double);
- break;
- default:
- LOG(FATAL) << "Unknown data type.";
- }
- return size;
- }

/** @brief Defines the reduction operation. */
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
@ -21,7 +21,28 @@ class DeviceCommunicator {
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
- virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
+ virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;

+ /**
+ * @brief Sum values from all processes and distribute the result back to all processes.
+ * @param send_receive_buffer Buffer storing the data.
+ * @param count Number of elements in the buffer.
+ */
+ virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;

+ /**
+ * @brief Sum values from all processes and distribute the result back to all processes.
+ * @param send_receive_buffer Buffer storing the data.
+ * @param count Number of elements in the buffer.
+ */
+ virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;

+ /**
+ * @brief Sum values from all processes and distribute the result back to all processes.
+ * @param send_receive_buffer Buffer storing the data.
+ * @param count Number of elements in the buffer.
+ */
+ virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;

/**
* @brief Gather variable-length values from all processes.
@ -23,17 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

~DeviceCommunicatorAdapter() override = default;

- void AllReduceSum(double *send_receive_buffer, int count) override {
- dh::safe_cuda(cudaSetDevice(device_ordinal_));
- auto size = count * sizeof(double);
- host_buffer_.reserve(size);
- dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
- communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
- dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
+ void AllReduceSum(float *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
+ }

+ void AllReduceSum(double *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
+ }

+ void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
+ }

+ void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }

dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
@ -66,6 +77,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}

private:
+ template <collective::DataType data_type, typename T>
+ void DoAllReduceSum(T *send_receive_buffer, size_t count) {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }

+ dh::safe_cuda(cudaSetDevice(device_ordinal_));
+ auto size = count * sizeof(T);
+ host_buffer_.reserve(size);
+ dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
+ communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
+ dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
+ }

int const device_ordinal_;
Communicator *communicator_;
/// Host buffer used to call communicator functions.
@ -24,6 +24,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();

+ if (world == 1) {
+ return;
+ }

std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
@ -52,8 +56,15 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}

~NcclDeviceCommunicator() override {
- dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
- ncclCommDestroy(nccl_comm_);
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }
+ if (cuda_stream_) {
+ dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
+ }
+ if (nccl_comm_) {
+ dh::safe_nccl(ncclCommDestroy(nccl_comm_));
+ }
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
@ -61,16 +72,28 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
}

- void AllReduceSum(double *send_receive_buffer, int count) override {
- dh::safe_cuda(cudaSetDevice(device_ordinal_));
- dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
- ncclSum, nccl_comm_, cuda_stream_));
- allreduce_bytes_ += count * sizeof(double);
- allreduce_calls_ += 1;
+ void AllReduceSum(float *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
+ }

+ void AllReduceSum(double *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
+ }

+ void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
+ }

+ void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
+ DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }

dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
@ -95,6 +118,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}

void Synchronize() override {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
@ -136,6 +162,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id;
}

+ template <ncclDataType_t data_type, typename T>
+ void DoAllReduceSum(T *send_receive_buffer, size_t count) {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }

+ dh::safe_cuda(cudaSetDevice(device_ordinal_));
+ dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
+ nccl_comm_, cuda_stream_));
+ allreduce_bytes_ += count * sizeof(T);
+ allreduce_calls_ += 1;
+ }

int const device_ordinal_;
Communicator *communicator_;
ncclComm_t nccl_comm_{};
30 src/collective/noop_communicator.h Normal file
@ -0,0 +1,30 @@
/*!
 * Copyright 2022 XGBoost contributors
 */
#pragma once
#include <string>

#include "communicator.h"

namespace xgboost {
namespace collective {

/**
 * A no-op communicator, used for non-distributed training.
 */
class NoOpCommunicator : public Communicator {
 public:
  NoOpCommunicator() : Communicator(1, 0) {}
  bool IsDistributed() const override { return false; }
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {}
  void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
  std::string GetProcessorName() override { return ""; }
  void Print(const std::string &message) override { LOG(CONSOLE) << message; }

 protected:
  void Shutdown() override {}
};

}  // namespace collective
}  // namespace xgboost
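A no-op communicator lets call sites issue collectives unconditionally: in a single-worker run every call simply returns, so the training code needs no "is distributed?" branches. A small standalone sketch of that idea (the Comm/NoOpComm types here are illustrative stand-ins, not the interface declared in communicator.h):

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-ins for the real Communicator interface.
struct Comm {
  virtual ~Comm() = default;
  virtual bool IsDistributed() const = 0;
  virtual void AllReduceSum(double *buf, std::size_t n) = 0;
};

struct NoOpComm : Comm {
  bool IsDistributed() const override { return false; }
  void AllReduceSum(double *, std::size_t) override {}  // single worker: nothing to combine
};

int main() {
  std::unique_ptr<Comm> comm = std::make_unique<NoOpComm>();
  std::vector<double> grad{0.5, 1.5};
  comm->AllReduceSum(grad.data(), grad.size());  // safe even without a cluster
  std::cout << grad[0] + grad[1] << "\n";
}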
@ -1,139 +0,0 @@
/*!
 * Copyright 2017-2019 XGBoost contributors
 *
 * \brief Utilities for CUDA.
 */
#ifdef XGBOOST_USE_NCCL
#include <nccl.h>
#endif  // #ifdef XGBOOST_USE_NCCL
#include <sstream>

#include "device_helpers.cuh"

namespace dh {

constexpr std::size_t kUuidLength =
    sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);

void GetCudaUUID(int device_ord, xgboost::common::Span<uint64_t, kUuidLength> uuid) {
  cudaDeviceProp prob;
  safe_cuda(cudaGetDeviceProperties(&prob, device_ord));
  std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
}

std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
  std::stringstream ss;
  for (auto v : uuid) {
    ss << std::hex << v;
  }
  return ss.str();
}

#ifdef XGBOOST_USE_NCCL
void NcclAllReducer::DoInit(int _device_ordinal) {
  int32_t const rank = rabit::GetRank();
  int32_t const world = rabit::GetWorldSize();
  if (world == 1) {
    return;
  }

  std::vector<uint64_t> uuids(world * kUuidLength, 0);
  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
  GetCudaUUID(_device_ordinal, s_this_uuid);

  // No allgather yet.
  rabit::Allreduce<rabit::op::Sum, uint64_t>(uuids.data(), uuids.size());

  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);;
  size_t j = 0;
  for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
    converted[j] =
        xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
    j++;
  }

  auto iter = std::unique(converted.begin(), converted.end());
  auto n_uniques = std::distance(converted.begin(), iter);

  CHECK_EQ(n_uniques, world)
      << "Multiple processes within communication group running on same CUDA "
      << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

  id_ = GetUniqueId();
  dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
  safe_cuda(cudaStreamCreate(&stream_));
}

void NcclAllReducer::DoAllGather(void const *data, size_t length_bytes,
                                 std::vector<size_t> *segments,
                                 dh::caching_device_vector<char> *recvbuf) {
  int32_t world = rabit::GetWorldSize();
  segments->clear();
  segments->resize(world, 0);
  segments->at(rabit::GetRank()) = length_bytes;
  rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
  auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0);
  recvbuf->resize(total_bytes);

  size_t offset = 0;
  safe_nccl(ncclGroupStart());
  for (int32_t i = 0; i < world; ++i) {
    size_t as_bytes = segments->at(i);
    safe_nccl(
        ncclBroadcast(data, recvbuf->data().get() + offset,
                      as_bytes, ncclChar, i, comm_, stream_));
    offset += as_bytes;
  }
  safe_nccl(ncclGroupEnd());
}

NcclAllReducer::~NcclAllReducer() {
  if (initialised_) {
    dh::safe_cuda(cudaStreamDestroy(stream_));
    ncclCommDestroy(comm_);
  }
  if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
    LOG(CONSOLE) << "======== NCCL Statistics========";
    LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
    LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
  }
}
#else
void RabitAllReducer::DoInit(int _device_ordinal) {
#if !defined(XGBOOST_USE_FEDERATED)
  if (rabit::IsDistributed()) {
    LOG(CONSOLE) << "XGBoost is not compiled with NCCL, falling back to Rabit.";
  }
#endif
}

void RabitAllReducer::DoAllGather(void const *data, size_t length_bytes,
                                  std::vector<size_t> *segments,
                                  dh::caching_device_vector<char> *recvbuf) {
  size_t world = rabit::GetWorldSize();
  segments->clear();
  segments->resize(world, 0);
  segments->at(rabit::GetRank()) = length_bytes;
  rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
  auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
  recvbuf->resize(total_bytes);

  sendrecvbuf_.reserve(total_bytes);
  auto rank = rabit::GetRank();
  size_t offset = 0;
  for (int32_t i = 0; i < world; ++i) {
    size_t as_bytes = segments->at(i);
    if (i == rank) {
      safe_cuda(
          cudaMemcpy(sendrecvbuf_.data() + offset, data, segments->at(rank), cudaMemcpyDefault));
    }
    rabit::Broadcast(sendrecvbuf_.data() + offset, as_bytes, i);
    offset += as_bytes;
  }
  safe_cuda(cudaMemcpy(recvbuf->data().get(), sendrecvbuf_.data(), total_bytes, cudaMemcpyDefault));
}
#endif  // XGBOOST_USE_NCCL

}  // namespace dh
@ -19,7 +19,6 @@
#include <thrust/unique.h>
#include <thrust/binary_search.h>

#include <rabit/rabit.h>
#include <cub/cub.cuh>
#include <cub/util_allocator.cuh>

@ -36,6 +35,7 @@
#include "xgboost/span.h"
#include "xgboost/global_config.h"

#include "../collective/communicator-inl.h"
#include "common.h"
#include "algorithm.cuh"

@ -404,7 +404,7 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
// dh::DebugSyncDevice(__FILE__, __LINE__);
inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
  if (file != "" && line != -1) {
    auto rank = rabit::GetRank();
    auto rank = xgboost::collective::GetRank();
    LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line;
  }
  safe_cuda(cudaDeviceSynchronize());
@ -423,7 +423,7 @@ using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;

inline void ThrowOOMError(std::string const& err, size_t bytes) {
  auto device = CurrentDevice();
  auto rank = rabit::GetRank();
  auto rank = xgboost::collective::GetRank();
  std::stringstream ss;
  ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
     << "- Free memory: " << AvailableMemory(device) << "\n"
@ -737,512 +737,6 @@ using TypedDiscard =
    std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
                       detail::TypedDiscard<T>>;

/**
 * \class AllReducer
 *
 * \brief All reducer class that manages its own communication group and
 * streams. Must be initialised before use. If XGBoost is compiled without NCCL,
 * this falls back to use Rabit.
 */
template <typename AllReducer>
class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
 public:
  virtual ~AllReducerBase() = default;

  /**
   * \brief Initialise with the desired device ordinal for this allreducer.
   *
   * \param device_ordinal The device ordinal.
   */
  void Init(int _device_ordinal) {
    device_ordinal_ = _device_ordinal;
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    this->Underlying().DoInit(_device_ordinal);
    initialised_ = true;
  }

  /**
   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
   * different size of data on different workers.
   *
   * \param data Buffer storing the input data.
   * \param length_bytes Size of input data in bytes.
   * \param segments Size of data on each worker.
   * \param recvbuf Buffer storing the result of data from all workers.
   */
  void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
                 dh::caching_device_vector<char> *recvbuf) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
  }

  /**
   * \brief Allgather. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param data Buffer storing the input data.
   * \param length Size of input data in bytes.
   * \param recvbuf Buffer storing the result of data from all workers.
   */
  void AllGather(uint32_t const *data, size_t length,
                 dh::caching_device_vector<uint32_t> *recvbuf) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllGather(data, length, recvbuf);
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
    allreduce_bytes_ += count * sizeof(double);
    allreduce_calls_ += 1;
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
    allreduce_bytes_ += count * sizeof(float);
    allreduce_calls_ += 1;
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
   *
   * \param count Number of.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of.
   */
  void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
    allreduce_bytes_ += count * sizeof(int64_t);
    allreduce_calls_ += 1;
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
    allreduce_bytes_ += count * sizeof(uint32_t);
    allreduce_calls_ += 1;
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
    allreduce_bytes_ += count * sizeof(uint64_t);
    allreduce_calls_ += 1;
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * Specialization for size_t, which is implementation defined so it might or might not
   * be one of uint64_t/uint32_t/unsigned long long/unsigned long.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  template <typename T = size_t,
            std::enable_if_t<std::is_same<size_t, T>::value &&
                             !std::is_same<size_t, unsigned long long>::value>  // NOLINT
                * = nullptr>
  void AllReduceSum(const T *sendbuff, T *recvbuff, int count) {  // NOLINT
    if (rabit::GetWorldSize() == 1) {
      return;
    }
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    static_assert(sizeof(unsigned long long) == sizeof(uint64_t), "");  // NOLINT
    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
    allreduce_bytes_ += count * sizeof(T);
    allreduce_calls_ += 1;
  }

  /**
   * \fn void Synchronize()
   *
   * \brief Synchronizes the entire communication group.
   */
  void Synchronize() {
    CHECK(initialised_);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    this->Underlying().DoSynchronize();
  }

 protected:
  bool initialised_{false};
  size_t allreduce_bytes_{0};  // Keep statistics of the number of bytes communicated.
  size_t allreduce_calls_{0};  // Keep statistics of the number of reduce calls.

 private:
  int device_ordinal_{-1};
};

#ifdef XGBOOST_USE_NCCL
class NcclAllReducer : public AllReducerBase<NcclAllReducer> {
 public:
  friend class AllReducerBase<NcclAllReducer>;

  ~NcclAllReducer() override;

 private:
  /**
   * \brief Initialise with the desired device ordinal for this communication
   * group.
   *
   * \param device_ordinal The device ordinal.
   */
  void DoInit(int _device_ordinal);

  /**
   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
   * different size of data on different workers.
   *
   * \param data Buffer storing the input data.
   * \param length_bytes Size of input data in bytes.
   * \param segments Size of data on each worker.
   * \param recvbuf Buffer storing the result of data from all workers.
   */
  void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
                   dh::caching_device_vector<char> *recvbuf);

  /**
   * \brief Allgather. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param data Buffer storing the input data.
   * \param length Size of input data in bytes.
   * \param recvbuf Buffer storing the result of data from all workers.
   */
  void DoAllGather(uint32_t const *data, size_t length,
                   dh::caching_device_vector<uint32_t> *recvbuf) {
    size_t world = rabit::GetWorldSize();
    recvbuf->resize(length * world);
    safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, comm_, stream_));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
   *
   * \param count Number of.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of.
   */
  void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
   * streams or comms.
   *
   * Specialization for size_t, which is implementation defined so it might or might not
   * be one of uint64_t/uint32_t/unsigned long long/unsigned long.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  template <typename T = size_t,
            std::enable_if_t<std::is_same<size_t, T>::value &&
                             !std::is_same<size_t, unsigned long long>::value>  // NOLINT
                * = nullptr>
  void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) {  // NOLINT
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
  }

  /**
   * \brief Synchronizes the entire communication group.
   */
  void DoSynchronize() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }

  /**
   * \fn ncclUniqueId GetUniqueId()
   *
   * \brief Gets the Unique ID from NCCL to be used in setting up interprocess
   * communication
   *
   * \return the Unique ID
   */
  ncclUniqueId GetUniqueId() {
    static const int kRootRank = 0;
    ncclUniqueId id;
    if (rabit::GetRank() == kRootRank) {
      dh::safe_nccl(ncclGetUniqueId(&id));
    }
    rabit::Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
    return id;
  }

  ncclComm_t comm_;
  cudaStream_t stream_;
  ncclUniqueId id_;
};

using AllReducer = NcclAllReducer;
#else
class RabitAllReducer : public AllReducerBase<RabitAllReducer> {
 public:
  friend class AllReducerBase<RabitAllReducer>;

 private:
  /**
   * \brief Initialise with the desired device ordinal for this allreducer.
   *
   * \param device_ordinal The device ordinal.
   */
  static void DoInit(int _device_ordinal);

  /**
   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
   * different size of data on different workers.
   *
   * \param data Buffer storing the input data.
   * \param length_bytes Size of input data in bytes.
   * \param segments Size of data on each worker.
   * \param recvbuf Buffer storing the result of data from all workers.
   */
  void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
                   dh::caching_device_vector<char> *recvbuf);

  /**
   * \brief Allgather. Use in exactly the same way as NCCL.
   *
   * \param data Buffer storing the input data.
   * \param length Size of input data in bytes.
   * \param recvbuf Buffer storing the result of data from all workers.
   */
  void DoAllGather(uint32_t *data, size_t length, dh::caching_device_vector<uint32_t> *recvbuf) {
    size_t world = rabit::GetWorldSize();
    auto total_size = length * world;
    recvbuf->resize(total_size);
    sendrecvbuf_.reserve(total_size);
    auto rank = rabit::GetRank();
    safe_cuda(cudaMemcpy(sendrecvbuf_.data() + rank * length, data, length, cudaMemcpyDefault));
    rabit::Allgather(sendrecvbuf_.data(), total_size, rank * length, length, length);
    safe_cuda(cudaMemcpy(data, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
    RabitAllReduceSum(sendbuff, recvbuff, count);
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
    RabitAllReduceSum(sendbuff, recvbuff, count);
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
    RabitAllReduceSum(sendbuff, recvbuff, count);
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
    RabitAllReduceSum(sendbuff, recvbuff, count);
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
    RabitAllReduceSum(sendbuff, recvbuff, count);
  }

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * Specialization for size_t, which is implementation defined so it might or might not
   * be one of uint64_t/uint32_t/unsigned long long/unsigned long.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  template <typename T = size_t,
            std::enable_if_t<std::is_same<size_t, T>::value &&
                             !std::is_same<size_t, unsigned long long>::value>  // NOLINT
                * = nullptr>
  void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) {  // NOLINT
    RabitAllReduceSum(sendbuff, recvbuff, count);
  }

  /**
   * \brief Synchronizes the allreducer.
   */
  void DoSynchronize() {}

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL.
   *
   * Copy the device buffer to host, call rabit allreduce, then copy the buffer back
   * to device.
   *
   * \param sendbuff The sendbuff.
   * \param recvbuff The recvbuff.
   * \param count Number of elements.
   */
  template <typename T>
  void RabitAllReduceSum(const T *sendbuff, T *recvbuff, int count) {
    auto total_size = count * sizeof(T);
    sendrecvbuf_.reserve(total_size);
    safe_cuda(cudaMemcpy(sendrecvbuf_.data(), sendbuff, total_size, cudaMemcpyDefault));
    rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<T*>(sendrecvbuf_.data()), count);
    safe_cuda(cudaMemcpy(recvbuff, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
  }

  /// Host buffer used to call rabit functions.
  std::vector<char> sendrecvbuf_{};
};

using AllReducer = RabitAllReducer;
#endif

template <typename VectorT, typename T = typename VectorT::value_type,
          typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(
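With dh::AllReducer removed, GPU call sites obtain a per-device communicator instead of constructing and initialising their own reducer. The snippet below is a hedged sketch of the new call shape based only on the signatures visible in this diff (collective::Communicator::GetDevice, AllReduceSum, Synchronize); the relative include paths are assumptions and the code is meant to be compiled inside the XGBoost source tree rather than standalone:

// Sketch only: assumes XGBoost's src/ directory layout from this PR.
#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "device_helpers.cuh"

namespace example {
void SumOnDevice(int device, dh::device_vector<double> *values) {
  // One communicator per device; no explicit Init()/stream bookkeeping at the call site.
  auto *communicator = xgboost::collective::Communicator::GetDevice(device);
  communicator->AllReduceSum(values->data().get(), values->size());
  communicator->Synchronize();
}
}  // namespace example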
@ -3,19 +3,14 @@
 * \file hist_util.cc
 */
#include <dmlc/timer.h>
#include <dmlc/omp.h>

#include <rabit/rabit.h>
#include <numeric>
#include <vector>

#include "xgboost/base.h"
#include "../common/common.h"
#include "hist_util.h"
#include "random.h"
#include "column_matrix.h"
#include "quantile.h"
#include "../data/gradient_index.h"

#if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h>
@ -6,10 +6,10 @@
#include <limits>
#include <utility>

#include "../collective/communicator-inl.h"
#include "../data/adapter.h"
#include "categorical.h"
#include "hist_util.h"
#include "rabit/rabit.h"

namespace xgboost {
namespace common {
@ -144,8 +144,8 @@ struct QuantileAllreduce {
void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,
                         std::vector<std::set<float>> *p_categories) {
  auto &categories = *p_categories;
  auto world_size = rabit::GetWorldSize();
  auto world_size = collective::GetWorldSize();
  auto rank = rabit::GetRank();
  auto rank = collective::GetRank();
  if (world_size == 1) {
    return;
  }
@ -163,7 +163,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
  std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
  size_t feat_begin = rank * feature_ptr.size();  // pointer to current worker
  std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
  rabit::Allreduce<rabit::op::Sum>(global_feat_ptrs.data(), global_feat_ptrs.size());
  collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
                                                     global_feat_ptrs.size());

  // move all categories into a flatten vector to prepare for allreduce
  size_t total = feature_ptr.back();
@ -176,7 +177,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
  // indptr for indexing workers
  std::vector<size_t> global_worker_ptr(world_size + 1, 0);
  global_worker_ptr[rank + 1] = total;  // shift 1 to right for constructing the indptr
  rabit::Allreduce<rabit::op::Sum>(global_worker_ptr.data(), global_worker_ptr.size());
  collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
                                                     global_worker_ptr.size());
  std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
  // total number of categories in all workers with all features
  auto gtotal = global_worker_ptr.back();
@ -188,7 +190,8 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
  CHECK_EQ(rank_size, total);
  std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
  // gather values from all workers.
  rabit::Allreduce<rabit::op::Sum>(global_categories.data(), global_categories.size());
  collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
                                                     global_categories.size());
  QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
                                            categories.size()};
  ParallelFor(categories.size(), n_threads, [&](auto fidx) {
@ -217,8 +220,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
    std::vector<typename WQSketch::Entry> *p_global_sketches) {
  auto &worker_segments = *p_worker_segments;
  worker_segments.resize(1, 0);
  auto world = rabit::GetWorldSize();
  auto world = collective::GetWorldSize();
  auto rank = rabit::GetRank();
  auto rank = collective::GetRank();
  auto n_columns = sketches_.size();

  // get the size of each feature.
@ -237,7 +240,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
  std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);

  // Gather all column pointers
  rabit::Allreduce<rabit::op::Sum>(sketches_scan.data(), sketches_scan.size());
  collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
  for (int32_t i = 0; i < world; ++i) {
    size_t back = (i + 1) * (n_columns + 1) - 1;
    auto n_entries = sketches_scan.at(back);
@ -265,7 +268,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(

  static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
                "Unexpected size of sketch entry.");
  rabit::Allreduce<rabit::op::Sum>(
  collective::Allreduce<collective::Operation::kSum>(
      reinterpret_cast<float *>(global_sketches.data()),
      global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
}
@ -277,7 +280,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
  monitor_.Start(__func__);

  size_t n_columns = sketches_.size();
  rabit::Allreduce<rabit::op::Max>(&n_columns, 1);
  collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
  CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";

  AllreduceCategories(feature_types_, n_threads_, &categories_);
@ -291,7 +294,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(

  // Prune the intermediate num cuts for synchronization.
  std::vector<bst_row_t> global_column_size(columns_size_);
  rabit::Allreduce<rabit::op::Sum>(global_column_size.data(), global_column_size.size());
  collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
                                                     global_column_size.size());

  ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
    int32_t intermediate_num_cuts = static_cast<int32_t>(
@ -311,7 +315,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
    num_cuts[i] = intermediate_num_cuts;
  });

  auto world = rabit::GetWorldSize();
  auto world = collective::GetWorldSize();
  if (world == 1) {
    monitor_.Stop(__func__);
    return;
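On the host side the migration pattern above is mechanical: rabit::Allreduce<rabit::op::Sum>(ptr, n) becomes collective::Allreduce<collective::Operation::kSum>(ptr, n), with the reduction expressed as an Operation enum value rather than a functor type. A hedged sketch of a migrated call site, using only the header and call shape shown in this diff (it builds inside the XGBoost tree, not standalone):

// Sketch only: mirrors the call sites changed above.
#include <cstddef>
#include <vector>

#include "../collective/communicator-inl.h"

namespace xgboost {
void SyncColumnCounts(std::vector<std::size_t> *counts) {
  // Sum the per-worker counts in place across all workers.
  collective::Allreduce<collective::Operation::kSum>(counts->data(), counts->size());
}
}  // namespace xgboost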
@ -12,6 +12,8 @@
#include <memory>
#include <utility>

#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@ -501,47 +503,41 @@ void SketchContainer::FixError() {

void SketchContainer::AllReduce() {
  dh::safe_cuda(cudaSetDevice(device_));
  auto world = rabit::GetWorldSize();
  auto world = collective::GetWorldSize();
  if (world == 1) {
    return;
  }

  timer_.Start(__func__);
  if (!reducer_) {
    reducer_ = std::make_shared<dh::AllReducer>();
    reducer_->Init(device_);
  }
  auto* communicator = collective::Communicator::GetDevice(device_);
  // Reduce the overhead on syncing.
  size_t global_sum_rows = num_rows_;
  rabit::Allreduce<rabit::op::Sum>(&global_sum_rows, 1);
  collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
  size_t intermediate_num_cuts =
      std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
  this->Prune(intermediate_num_cuts);

  auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
  CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
  size_t n = d_columns_ptr.size();
  rabit::Allreduce<rabit::op::Max>(&n, 1);
  collective::Allreduce<collective::Operation::kMax>(&n, 1);
  CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";

  // Get the columns ptr from all workers
  dh::device_vector<SketchContainer::OffsetT> gathered_ptrs;
  gathered_ptrs.resize(d_columns_ptr.size() * world, 0);
  size_t rank = rabit::GetRank();
  size_t rank = collective::GetRank();
  auto offset = rank * d_columns_ptr.size();
  thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
               gathered_ptrs.begin() + offset);
  reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(),
                         gathered_ptrs.size());
  communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());

  // Get the data from all workers.
  std::vector<size_t> recv_lengths;
  dh::caching_device_vector<char> recvbuf;
  reducer_->AllGather(this->Current().data().get(),
                      dh::ToSpan(this->Current()).size_bytes(), &recv_lengths,
                      &recvbuf);
  reducer_->Synchronize();
  communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
                           &recv_lengths, &recvbuf);
  communicator->Synchronize();

  // Segment the received data.
  auto s_recvbuf = dh::ToSpan(recvbuf);
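The gather step now goes through DeviceCommunicator::AllGatherV, which, like the old AllReducer::AllGather, accepts a different number of bytes from each worker and reports the per-worker sizes back through the segments vector. A hedged sketch of the call, using only the signature visible above and assuming the include paths introduced in this PR:

// Sketch only: compile inside the XGBoost source tree.
#include <vector>

#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "device_helpers.cuh"

namespace example {
void GatherFromAllWorkers(int device, dh::device_vector<char> const &local) {
  auto *communicator = xgboost::collective::Communicator::GetDevice(device);
  std::vector<std::size_t> recv_lengths;    // filled with one length per worker
  dh::caching_device_vector<char> recvbuf;  // concatenated payload from every worker
  communicator->AllGatherV(local.data().get(), local.size(), &recv_lengths, &recvbuf);
  communicator->Synchronize();              // wait for the collective's stream
}
}  // namespace example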
@ -37,7 +37,6 @@ class SketchContainer {

 private:
  Monitor timer_;
  std::shared_ptr<dh::AllReducer> reducer_;
  HostDeviceVector<FeatureType> feature_types_;
  bst_row_t num_rows_;
  bst_feature_t num_columns_;
@ -93,15 +92,12 @@ class SketchContainer {
   * \param num_columns Total number of columns in dataset.
   * \param num_rows Total number of rows in known dataset (typically the rows in current worker).
   * \param device GPU ID.
   * \param reducer Optional initialised reducer. Useful for speeding up testing.
   */
  SketchContainer(HostDeviceVector<FeatureType> const &feature_types,
                  int32_t max_bin, bst_feature_t num_columns,
                  bst_row_t num_rows, int32_t device,
                  std::shared_ptr<dh::AllReducer> reducer = nullptr)
                  bst_row_t num_rows, int32_t device)
      : num_rows_{num_rows},
        num_columns_{num_columns}, num_bins_{max_bin}, device_{device},
        reducer_(std::move(reducer)) {
        num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
    CHECK_GE(device, 0);
    // Initialize Sketches for this dmatrix
    this->columns_ptr_.SetDevice(device_);
@ -7,20 +7,21 @@
#ifndef XGBOOST_COMMON_RANDOM_H_
#define XGBOOST_COMMON_RANDOM_H_

#include <rabit/rabit.h>
#include <xgboost/logging.h>

#include <algorithm>
#include <functional>
#include <vector>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

#include "xgboost/host_device_vector.h"
#include "../collective/communicator-inl.h"
#include "common.h"
#include "xgboost/host_device_vector.h"

namespace xgboost {
namespace common {
@ -143,7 +144,7 @@ class ColumnSampler {
   */
  ColumnSampler() {
    uint32_t seed = common::GlobalRandom()();
    rabit::Broadcast(&seed, sizeof(seed), 0);
    collective::Broadcast(&seed, sizeof(seed), 0);
    rng_.seed(seed);
  }

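Broadcasting the column-sampler seed keeps feature sampling identical on every worker, and only the namespace changes here since collective::Broadcast keeps rabit's (buffer, size, root) shape. A hedged sketch of the same idiom (the helper name and the use of std::random_device are illustrative, not part of XGBoost):

// Sketch only: mirrors ColumnSampler's constructor above.
#include <cstdint>
#include <random>

#include "../collective/communicator-inl.h"

namespace xgboost {
std::mt19937 MakeSharedRng() {
  std::uint32_t seed = std::random_device{}();    // each worker draws its own seed...
  collective::Broadcast(&seed, sizeof(seed), 0);  // ...then everyone adopts rank 0's value
  return std::mt19937{seed};
}
}  // namespace xgboost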
@ -1,14 +1,13 @@
/*!
 * Copyright by Contributors 2019
 */
#include <rabit/rabit.h>
#include <algorithm>
#include <type_traits>
#include <utility>
#include <vector>
#include <sstream>
#include "timer.h"

#include <sstream>
#include <utility>

#include "../collective/communicator-inl.h"

#if defined(XGBOOST_USE_NVTX)
#include <nvToolsExt.h>
#endif  // defined(XGBOOST_USE_NVTX)
@ -54,7 +53,7 @@ void Monitor::PrintStatistics(StatMap const& statistics) const {

void Monitor::Print() const {
  if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; }
  auto rank = rabit::GetRank();
  auto rank = collective::GetRank();
  StatMap stat_map;
  for (auto const &kv : statistics_map_) {
    stat_map[kv.first] = std::make_pair(
@ -2,36 +2,36 @@
 * Copyright 2015-2022 by XGBoost Contributors
 * \file data.cc
 */
#include "xgboost/data.h"

#include <dmlc/registry.h>

#include <array>
#include <cstring>

#include "dmlc/io.h"
#include "../collective/communicator-inl.h"
#include "xgboost/data.h"
#include "../common/group_data.h"
#include "xgboost/c_api.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "xgboost/learner.h"
#include "xgboost/string_view.h"

#include "sparse_page_writer.h"
#include "simple_dmatrix.h"

#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h"
#include "../common/version.h"
#include "../common/group_data.h"
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
#include "../data/iterative_dmatrix.h"
#include "file_iterator.h"

#include "validation.h"
#include "./sparse_page_source.h"
#include "./sparse_page_dmatrix.h"
#include "./sparse_page_source.h"
#include "dmlc/io.h"
#include "file_iterator.h"
#include "simple_dmatrix.h"
#include "sparse_page_writer.h"
#include "validation.h"
#include "xgboost/c_api.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/logging.h"
#include "xgboost/string_view.h"
#include "xgboost/version_config.h"

namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
@ -793,12 +793,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
      size_t pos = cache_shards[i].rfind('.');
      if (pos == std::string::npos) {
        os << cache_shards[i]
           << ".r" << rabit::GetRank()
           << ".r" << collective::GetRank()
           << "-" << rabit::GetWorldSize();
           << "-" << collective::GetWorldSize();
      } else {
        os << cache_shards[i].substr(0, pos)
           << ".r" << rabit::GetRank()
           << ".r" << collective::GetRank()
           << "-" << rabit::GetWorldSize()
           << "-" << collective::GetWorldSize()
           << cache_shards[i].substr(pos, cache_shards[i].length());
      }
      if (i + 1 != cache_shards.size()) {
@ -821,8 +821,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,

  int partid = 0, npart = 1;
  if (load_row_split) {
    partid = rabit::GetRank();
    partid = collective::GetRank();
    npart = rabit::GetWorldSize();
    npart = collective::GetWorldSize();
  } else {
    // test option to load in part
    npart = 1;
@ -877,7 +877,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
  /* sync up number of features after matrix loaded.
   * partitioned data will fail the train/val validation check
   * since partitioned data not knowing the real number of features. */
  rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
  collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
  return dmat;
}

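Rank and world size for the per-worker cache names and for row-split loading now come from the collective communicator. A hedged sketch of the "<prefix>.r<rank>-<world>" naming used above (helper name and wrapper are illustrative only):

// Sketch only: reproduces the cache-shard naming scheme shown above.
#include <sstream>
#include <string>

#include "../collective/communicator-inl.h"

namespace xgboost {
std::string WorkerCacheName(std::string const &prefix) {
  std::ostringstream os;
  os << prefix << ".r" << collective::GetRank() << "-" << collective::GetWorldSize();
  return os.str();
}
}  // namespace xgboost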
@ -3,13 +3,11 @@
 */
#include "iterative_dmatrix.h"

#include <rabit/rabit.h>

#include <algorithm>  // std::copy

#include "../collective/communicator-inl.h"
#include "../common/categorical.h"  // common::IsCat
#include "../common/column_matrix.h"
#include "../common/hist_util.h"  // common::HistogramCuts
#include "../tree/param.h"  // FIXME(jiamingy): Find a better way to share this parameter.
#include "gradient_index.h"
#include "proxy_dmatrix.h"
@ -140,7 +138,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
  // We use do while here as the first batch is fetched in ctor
  if (n_features == 0) {
    n_features = num_cols();
    rabit::Allreduce<rabit::op::Max>(&n_features, 1);
    collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
    column_sizes.resize(n_features);
    info_.num_col_ = n_features;
  } else {
@ -157,7 +155,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
  // From here on Info() has the correct data shape
  Info().num_row_ = accumulated_rows;
  Info().num_nonzero_ = nnz;
  rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
  CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
    return f > accumulated_rows;
  })) << "Something went wrong during iteration.";
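The column-count sync in the DMatrix constructors uses a max-reduction so that workers whose local shard happens to miss trailing features still agree on the global width. A hedged one-function sketch of that step (the wrapper name is illustrative; the call itself is the one used throughout the hunks above):

// Sketch only: the column-width sync used by the DMatrix constructors above.
#include <cstdint>

#include "../collective/communicator-inl.h"

namespace xgboost {
void SyncNumColumns(std::uint64_t *num_col) {
  // Every worker ends up with the maximum column count seen anywhere.
  collective::Allreduce<collective::Operation::kMax>(num_col, 1);
}
}  // namespace xgboost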
@ -62,7 +62,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
  dh::safe_cuda(cudaSetDevice(get_device()));
  if (cols == 0) {
    cols = num_cols();
    rabit::Allreduce<rabit::op::Max>(&cols, 1);
    collective::Allreduce<collective::Operation::kMax>(&cols, 1);
    this->info_.num_col_ = cols;
  } else {
    CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,

  iter.Reset();
  // Synchronise worker columns
  rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
}

BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
@ -189,7 +189,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {

// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);

if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT

@ -322,7 +322,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
}
// Synchronise worker columns
info_.num_col_ = adapter->NumColumns();
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.num_row_ = total_batch_size;
info_.num_nonzero_ = data_vec.size();
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);

@ -35,7 +35,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
}

template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
@ -5,6 +5,8 @@
* \author Tianqi Chen
*/
#include "./sparse_page_dmatrix.h"

#include "../collective/communicator-inl.h"
#include "./simple_batch_iterator.h"
#include "gradient_index.h"

@ -46,8 +48,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
cache_prefix_{std::move(cache_prefix)} {
ctx_.nthread = nthreads;
cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_;
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
cache_prefix_ += ("-r" + std::to_string(rabit::GetRank()));
cache_prefix_ += ("-r" + std::to_string(collective::GetRank()));
}
DMatrixProxy *proxy = MakeProxy(proxy_);
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{

@ -94,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;

rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
CHECK_NE(info_.num_col_, 0);
}
@ -14,7 +14,6 @@
#include <map>
#include <memory>

#include "rabit/rabit.h"
#include "xgboost/base.h"
#include "xgboost/data.h"

@ -135,7 +135,7 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
return;
}

if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
LOG(INFO) << "Tree method is automatically selected to be 'approx' "
"for distributed training.";
tparam_.tree_method = TreeMethod::kApprox;
@ -23,6 +23,7 @@
#include <utility>
#include <vector>

#include "collective/communicator-inl.h"
#include "common/charconv.h"
#include "common/common.h"
#include "common/io.h"

@ -478,7 +479,7 @@ class LearnerConfiguration : public Learner {

// add additional parameters
// These are cosntraints that need to be satisfied.
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}

@ -757,7 +758,7 @@ class LearnerConfiguration : public Learner {
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
}

rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
collective::Allreduce<collective::Operation::kMax>(&num_feature, 1);
if (num_feature > mparam_.num_feature) {
mparam_.num_feature = num_feature;
}

@ -1083,7 +1084,7 @@ class LearnerIO : public LearnerConfiguration {
cfg_.insert(n.cbegin(), n.cend());

// copy dsplit from config since it will not run again during restore
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}

@ -1228,7 +1229,7 @@ class LearnerImpl : public LearnerIO {
}
// Configuration before data is known.
void CheckDataSplitMode() {
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
if (tparam_.dsplit == DataSplitMode::kCol) {

@ -1488,7 +1489,7 @@ class LearnerImpl : public LearnerIO {
}

if (p_fmat->Info().num_row_ == 0) {
LOG(WARNING) << "Empty dataset at worker: " << rabit::GetRank();
LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank();
}
}
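For orientation, the host-side pattern in the hunks above is a one-for-one swap of the rabit reducer for the collective one: the call keeps the pointer-plus-element-count shape, and only the namespace and the operation tag change. A minimal sketch using only calls that appear in this diff; the local variable and the log message are illustrative, not taken from the change itself:

    #include "collective/communicator-inl.h"

    // Synchronise a worker-local maximum across all workers (same call shape as rabit).
    uint32_t num_feature = 0;  // illustrative local value
    collective::Allreduce<collective::Operation::kMax>(&num_feature, 1);

    // Rank and world queries keep their old meaning under the new namespace.
    if (collective::IsDistributed()) {
      LOG(WARNING) << "running distributed, rank " << collective::GetRank();
    }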
@ -4,14 +4,12 @@
* \brief Implementation of loggers.
* \author Tianqi Chen
*/
#include <rabit/rabit.h>

#include <iostream>
#include <map>

#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "collective/communicator-inl.h"

#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces

@ -32,7 +30,7 @@ ConsoleLogger::~ConsoleLogger() {

TrackerLogger::~TrackerLogger() {
log_stream_ << '\n';
rabit::TrackerPrint(log_stream_.str());
collective::Print(log_stream_.str());
}

} // namespace xgboost
@ -1,27 +1,23 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include "auc.h"

#include <algorithm>
#include <array>
#include <atomic>
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <tuple>
#include <utility>
#include <vector>

#include "rabit/rabit.h"
#include "xgboost/linalg.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/metric.h"

#include "auc.h"

#include "../common/common.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "xgboost/metric.h"

namespace xgboost {
namespace metric {

@ -117,7 +113,8 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,

// we have 2 averages going in here, first is among workers, second is among
// classes. allreduce sums up fp/tp auc for each class.
rabit::Allreduce<rabit::op::Sum>(results.Values().data(), results.Values().size());
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
                                                   results.Values().size());
double auc_sum{0};
double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {

@ -265,7 +262,7 @@ class EvalAUC : public Metric {
}
// We use the global size to handle empty dataset.
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
if (meta[0] == 0) {
// Empty across all workers, which is not supported.
auc = std::numeric_limits<double>::quiet_NaN();

@ -287,7 +284,7 @@ class EvalAUC : public Metric {
}

std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
auc = results[0];
valid_groups = static_cast<uint32_t>(results[1]);

@ -316,7 +313,7 @@ class EvalAUC : public Metric {
}
double local_area = fp * tp;
std::array<double, 2> result{auc, local_area};
rabit::Allreduce<rabit::op::Sum>(result.data(), result.size());
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
if (local_area <= 0) {
// the dataset across all workers have only positive or negative sample
@ -11,11 +11,10 @@
#include <utility>
#include <tuple>

#include "rabit/rabit.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "auc.h"
#include "../common/device_helpers.cuh"
#include "../collective/device_communicator.cuh"
#include "../common/ranking_utils.cuh"

namespace xgboost {

@ -46,9 +45,8 @@ struct DeviceAUCCache {
dh::device_vector<size_t> unique_idx;
// p^T: transposed prediction matrix, used by MultiClassAUC
dh::device_vector<float> predts_t;
std::unique_ptr<dh::AllReducer> reducer;

void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
void Init(common::Span<float const> predts, bool is_multi) {
if (sorted_idx.size() != predts.size()) {
sorted_idx.resize(predts.size());
fptp.resize(sorted_idx.size());

@ -58,10 +56,6 @@ struct DeviceAUCCache {
predts_t.resize(sorted_idx.size());
}
}
if (is_multi && !reducer) {
reducer.reset(new dh::AllReducer);
reducer->Init(device);
}
}
};

@ -72,7 +66,7 @@ void InitCacheOnce(common::Span<float const> predts, int32_t device,
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, is_multi, device);
cache->Init(predts, is_multi);
}

/**

@ -205,9 +199,11 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
common::Span<double> tp, common::Span<double> auc,
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
dh::XGBDeviceAllocator<char> alloc;
if (rabit::IsDistributed()) {
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
auto* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
@ -10,13 +10,13 @@
#include <tuple>
#include <utility>

#include "rabit/rabit.h"
#include "../collective/communicator-inl.h"
#include "xgboost/base.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "xgboost/metric.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/metric.h"
#include "xgboost/span.h"

namespace xgboost {
namespace metric {

@ -101,7 +101,7 @@ XGBOOST_DEVICE inline double CalcDeltaPRAUC(double fp_prev, double fp,

inline void InvalidGroupAUC() {
LOG(INFO) << "Invalid group with less than 3 samples is found on worker "
<< rabit::GetRank() << ". Calculating AUC value requires at "
<< collective::GetRank() << ". Calculating AUC value requires at "
<< "least 2 pairs of samples.";
}
@ -7,11 +7,11 @@
* The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
*/
#include <dmlc/registry.h>
#include <rabit/rabit.h>
#include <xgboost/metric.h>

#include <cmath>

#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "../common/math.h"
#include "../common/pseudo_huber.h"

@ -196,8 +196,8 @@ class PseudoErrorLoss : public Metric {
return std::make_tuple(v, wt);
});
double dat[2]{result.Residue(), result.Weights()};
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
}

@ -365,7 +365,7 @@ struct EvalEWiseBase : public Metric {
});

double dat[2]{result.Residue(), result.Weights()};
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return Policy::GetFinal(dat[0], dat[1]);
}
@ -4,15 +4,14 @@
* \brief evaluation metrics for multiclass classification.
* \author Kailong Chen, Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/metric.h>

#include <atomic>
#include <cmath>

#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/common.h"
#include "../common/threading_utils.h"

#if defined(XGBOOST_USE_CUDA)

@ -185,7 +184,7 @@ struct EvalMClassBase : public Metric {
dat[0] = result.Residue();
dat[1] = result.Weights();
}
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return Derived::GetFinal(dat[0], dat[1]);
}
/*!
|
|||||||
// corresponding headers that brings in those function declaration can't be included with CUDA).
|
// corresponding headers that brings in those function declaration can't be included with CUDA).
|
||||||
// This precludes the CPU and GPU logic to coexist inside a .cu file
|
// This precludes the CPU and GPU logic to coexist inside a .cu file
|
||||||
|
|
||||||
#include <rabit/rabit.h>
|
|
||||||
#include <xgboost/metric.h>
|
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
#include <cmath>
|
#include <xgboost/metric.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "metric_common.h"
|
#include "metric_common.h"
|
||||||
|
#include "xgboost/host_device_vector.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ struct EvalAMS : public Metric {
|
|||||||
}
|
}
|
||||||
|
|
||||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||||
CHECK(!rabit::IsDistributed()) << "metric AMS do not support distributed evaluation";
|
CHECK(!collective::IsDistributed()) << "metric AMS do not support distributed evaluation";
|
||||||
using namespace std; // NOLINT(*)
|
using namespace std; // NOLINT(*)
|
||||||
|
|
||||||
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
||||||
@ -216,10 +216,10 @@ struct EvalRank : public Metric, public EvalRankConfig {
|
|||||||
exc.Rethrow();
|
exc.Rethrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rabit::IsDistributed()) {
|
if (collective::IsDistributed()) {
|
||||||
double dat[2]{sum_metric, static_cast<double>(ngroups)};
|
double dat[2]{sum_metric, static_cast<double>(ngroups)};
|
||||||
// approximately estimate the metric using mean
|
// approximately estimate the metric using mean
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||||
return dat[0] / dat[1];
|
return dat[0] / dat[1];
|
||||||
} else {
|
} else {
|
||||||
return sum_metric / ngroups;
|
return sum_metric / ngroups;
|
||||||
@ -341,7 +341,7 @@ struct EvalCox : public Metric {
|
|||||||
public:
|
public:
|
||||||
EvalCox() = default;
|
EvalCox() = default;
|
||||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||||
CHECK(!rabit::IsDistributed()) << "Cox metric does not support distributed evaluation";
|
CHECK(!collective::IsDistributed()) << "Cox metric does not support distributed evaluation";
|
||||||
using namespace std; // NOLINT(*)
|
using namespace std; // NOLINT(*)
|
||||||
|
|
||||||
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
||||||
|
|||||||
@ -4,15 +4,12 @@
* \brief prediction rank based metrics.
* \author Kailong Chen, Tianqi Chen
*/
#include <rabit/rabit.h>
#include <dmlc/registry.h>

#include <xgboost/metric.h>
#include <xgboost/host_device_vector.h>
#include <thrust/iterator/discard_iterator.h>

#include <cmath>
#include <array>
#include <vector>

#include "metric_common.h"
@ -5,7 +5,6 @@
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
*/

#include <rabit/rabit.h>
#include <dmlc/registry.h>

#include <memory>

@ -16,6 +15,7 @@
#include "xgboost/host_device_vector.h"

#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/survival_util.h"
#include "../common/threading_utils.h"

@ -214,7 +214,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
info.labels_upper_bound_, preds);

double dat[2]{result.Residue(), result.Weights()};
rabit::Allreduce<rabit::op::Sum>(dat, 2);
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return Policy::GetFinal(dat[0], dat[1]);
}
@ -7,8 +7,8 @@
#include <limits>
#include <vector>

#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "rabit/rabit.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/tree_model.h"

@ -39,7 +39,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
auto const& h_node_idx = nidx;

size_t n_leaf{h_node_idx.size()};
rabit::Allreduce<rabit::op::Max>(&n_leaf, 1);
collective::Allreduce<collective::Operation::kMax>(&n_leaf, 1);
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
if (quantiles.empty()) {
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());

@ -49,12 +49,12 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
std::vector<int32_t> n_valids(quantiles.size());
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
rabit::Allreduce<rabit::op::Sum>(n_valids.data(), n_valids.size());
collective::Allreduce<collective::Operation::kSum>(n_valids.data(), n_valids.size());
// convert to 0 for all reduce
std::replace_if(
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
// use the mean value
rabit::Allreduce<rabit::op::Sum>(quantiles.data(), quantiles.size());
collective::Allreduce<collective::Operation::kSum>(quantiles.data(), quantiles.size());
for (size_t i = 0; i < n_leaf; ++i) {
if (n_valids[i] > 0) {
quantiles[i] /= static_cast<float>(n_valids[i]);

@ -724,8 +724,8 @@ class MeanAbsoluteError : public ObjFunction {
}

// Weighted average base score across all workers
rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
rabit::Allreduce<rabit::op::Sum>(&w, 1);
collective::Allreduce<collective::Operation::kSum>(&w, 1);

std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
[w](float v) { return v / w; });
@ -84,11 +84,11 @@ GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
// Treat pair as array of 4 primitive types to allreduce
using ReduceT = typename decltype(p.first)::ValueT;
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
rabit::Allreduce<rabit::op::Sum, ReduceT>(reinterpret_cast<ReduceT*>(&p), 4);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
GradientPair positive_sum{p.first}, negative_sum{p.second};

std::size_t total_rows = gpair.size();
rabit::Allreduce<rabit::op::Sum>(&total_rows, 1);
collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);

auto histogram_rounding = GradientSumT{
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), total_rows),
@ -8,10 +8,10 @@
#include <limits>
#include <vector>

#include "../../collective/communicator-inl.h"
#include "../../common/hist_util.h"
#include "../../data/gradient_index.h"
#include "expand_entry.h"
#include "rabit/rabit.h"
#include "xgboost/tree_model.h"

namespace xgboost {

@ -202,8 +202,9 @@ class HistogramBuilder {
}
});

rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<double*>(this->hist_[starting_index].data()),
                                 builder_.GetNumBins() * sync_count * 2);
collective::Allreduce<collective::Operation::kSum>(
    reinterpret_cast<double *>(this->hist_[starting_index].data()),
    builder_.GetNumBins() * sync_count * 2);

ParallelSubtractionHist(space, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
@ -74,7 +74,7 @@ class GloablApproxBuilder {
}

histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
rabit::IsDistributed());
collective::IsDistributed());
monitor_->Stop(__func__);
}

@ -88,7 +88,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = ConstructHistSpace(partitioner_, nodes);
@ -4,8 +4,6 @@
* \brief use columnwise update to construct a tree
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <memory>
#include <vector>
#include <cmath>
#include <algorithm>

@ -100,7 +98,7 @@ class ColMaker: public TreeUpdater {
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree *> &trees) override {
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
"support distributed training.";
}
@ -19,6 +19,7 @@
#include "xgboost/span.h"
#include "xgboost/json.h"

#include "../collective/device_communicator.cuh"
#include "../common/io.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"

@ -528,13 +529,12 @@ struct GPUHistMakerDevice {
}

// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
reducer->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
                      reinterpret_cast<ReduceT*>(d_node_hist),
                      page->Cuts().TotalBins() * 2 * num_histograms);
communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
                           page->Cuts().TotalBins() * 2 * num_histograms);

monitor.Stop("AllReduce");
}

@ -542,8 +542,8 @@ struct GPUHistMakerDevice {
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, dh::AllReducer* reducer,
const RegTree& tree) {
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
collective::DeviceCommunicator* communicator, const RegTree& tree) {
if (candidates.empty()) return;
// Some nodes we will manually compute histograms
// others we will do by subtraction

@ -574,7 +574,7 @@ struct GPUHistMakerDevice {
// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());
this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());

for (size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);

@ -584,7 +584,7 @@ struct GPUHistMakerDevice {
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
}
}
}

@ -593,7 +593,7 @@ struct GPUHistMakerDevice {
RegTree& tree = *p_tree;

// Sanity check - have we created a leaf with no training instances?
if (!rabit::IsDistributed() && row_partitioner) {
if (!collective::IsDistributed() && row_partitioner) {
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
<< "No training instances in this leaf!";
}

@ -642,7 +642,7 @@ struct GPUHistMakerDevice {
parent.RightChild());
}

GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(

@ -650,11 +650,11 @@ struct GPUHistMakerDevice {
GradientPairPrecise root_sum =
dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double*>(&root_sum), 2);

hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer, 1);
this->AllReduceHist(kRootNIdx, communicator, 1);

// Remember root stats
node_sum_gradients[kRootNIdx] = root_sum;

@ -669,7 +669,7 @@ struct GPUHistMakerDevice {
}

void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
RegTree* p_tree, dh::AllReducer* reducer,
RegTree* p_tree, collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
// Process maximum 32 nodes at a time

@ -680,7 +680,7 @@ struct GPUHistMakerDevice {
monitor.Stop("Reset");

monitor.Start("InitRoot");
driver.Push({ this->InitRoot(p_tree, reducer) });
driver.Push({ this->InitRoot(p_tree, communicator) });
monitor.Stop("InitRoot");

// The set of leaves that can be expanded asynchronously

@ -707,7 +707,7 @@ struct GPUHistMakerDevice {
monitor.Stop("UpdatePosition");

monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
monitor.Stop("BuildHist");

monitor.Start("EvaluateSplits");

@ -789,11 +789,10 @@ class GPUHistMaker : public TreeUpdater {
void InitDataOnce(DMatrix* dmat) {
CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
info_ = &dmat->Info();
reducer_.Init({ctx_->gpu_id}); // NOLINT

// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);

BatchParam batch_param{
ctx_->gpu_id,

@ -823,12 +822,12 @@ class GPUHistMaker : public TreeUpdater {
void CheckTreesSynchronized(RegTree* local_tree) const {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);

@ -841,7 +840,8 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Stop("InitData");

gpair->SetDevice(ctx_->gpu_id);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
}

bool UpdatePredictionCache(const DMatrix* data,

@ -867,8 +867,6 @@ class GPUHistMaker : public TreeUpdater {

GPUHistMakerTrainParam hist_maker_param_;

dh::AllReducer reducer_;

DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
ObjInfo task_;
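On the GPU side the per-updater dh::AllReducer member is dropped; the hunks above instead fetch a device communicator on demand and reduce in place. A minimal sketch of that pattern, using only the calls that appear in this diff (the helper and buffer names below are illustrative, and additional internal headers may be needed beyond the one shown):

    #include "../collective/device_communicator.cuh"

    // Illustrative helper: sum a device-resident histogram buffer across workers.
    void AllReduceDeviceHistogram(double* d_hist, std::size_t n, int device) {
      // Fetch the communicator bound to this CUDA device.
      collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);
      // In-place sum, replacing the old reducer->AllReduceSum(ptr, ptr, n) call.
      communicator->AllReduceSum(d_hist, n);
    }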
@ -4,16 +4,13 @@
* \brief prune a tree given the statistics
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>

#include <string>
#include <memory>

#include "xgboost/base.h"
#include "xgboost/json.h"
#include "./param.h"
#include "../common/io.h"
#include "../common/timer.h"
namespace xgboost {
namespace tree {
@ -6,19 +6,12 @@
*/
#include "./updater_quantile_hist.h"

#include <rabit/rabit.h>

#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "constraints.h"
#include "hist/evaluate_splits.h"
#include "param.h"

@ -103,7 +96,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
for (auto const &grad : gpair_h) {
grad_stat.Add(grad.GetGrad(), grad.GetHess());
}
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&grad_stat), 2);
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&grad_stat), 2);
}

auto weight = evaluator_->InitRoot(GradStats{grad_stat});

@ -320,7 +313,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
rabit::IsDistributed());
collective::IsDistributed());

if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
@ -7,7 +7,6 @@
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_

#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>

#include <algorithm>
@ -4,17 +4,17 @@
* \brief refresh the statistics and leaf value on the tree on the dataset
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>

#include <vector>
#include <limits>
#include <vector>

#include "xgboost/json.h"
#include "../collective/communicator-inl.h"
#include "./param.h"
#include "../common/io.h"
#include "../common/threading_utils.h"
#include "../predictor/predict_fn.h"
#include "./param.h"
#include "xgboost/json.h"

namespace xgboost {
namespace tree {

@ -100,8 +100,9 @@ class TreeRefresher : public TreeUpdater {
}
});
};
rabit::Allreduce<rabit::op::Sum>(&dmlc::BeginPtr(stemp[0])->sum_grad, stemp[0].size() * 2,
                                 lazy_get_stats);
lazy_get_stats();
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
                                                   stemp[0].size() * 2);
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@ -4,12 +4,14 @@
* \brief synchronize the tree in all distributed nodes
*/
#include <xgboost/tree_updater.h>
#include <vector>
#include <string>
#include <limits>

#include "xgboost/json.h"
#include <limits>
#include <string>
#include <vector>

#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "xgboost/json.h"

namespace xgboost {
namespace tree {

@ -35,17 +37,17 @@ class TreeSyncher : public TreeUpdater {
void Update(HostDeviceVector<GradientPair>*, DMatrix*,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree*>& trees) override {
if (rabit::GetWorldSize() == 1) return;
if (collective::GetWorldSize() == 1) return;
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
int rank = collective::GetRank();
if (rank == 0) {
for (auto tree : trees) {
tree->Save(&fs);
}
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
collective::Broadcast(&s_model, 0);
for (auto tree : trees) {
tree->Load(&fs);
}
@ -46,8 +46,8 @@ template <bool use_column>
void TestDistributedQuantile(size_t rows, size_t cols) {
std::string msg {"Skipping AllReduce test"};
int32_t constexpr kWorkers = 4;
InitRabitContext(msg, kWorkers);
InitCommunicatorContext(msg, kWorkers);
auto world = rabit::GetWorldSize();
auto world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, kWorkers);
} else {

@ -65,7 +65,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {

// Generate cuts for distributed environment.
auto sparsity = 0.5f;
auto rank = rabit::GetRank();
auto rank = collective::GetRank();
std::vector<FeatureType> ft(cols);
for (size_t i = 0; i < ft.size(); ++i) {
ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;

@ -99,8 +99,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
sketch_distributed.MakeCuts(&distributed_cuts);

// Generate cuts for single node environment
rabit::Finalize();
collective::Finalize();
CHECK_EQ(rabit::GetWorldSize(), 1);
CHECK_EQ(collective::GetWorldSize(), 1);
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
m->Info().num_row_ = world * rows;
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),

@ -184,8 +184,8 @@ TEST(Quantile, SameOnAllWorkers) {
#if defined(__unix__)
std::string msg{"Skipping Quantile AllreduceBasic test"};
int32_t constexpr kWorkers = 4;
InitRabitContext(msg, kWorkers);
InitCommunicatorContext(msg, kWorkers);
auto world = rabit::GetWorldSize();
auto world = collective::GetWorldSize();
if (world != 1) {
CHECK_EQ(world, kWorkers);
} else {

@ -196,7 +196,7 @@ TEST(Quantile, SameOnAllWorkers) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
auto rank = rabit::GetRank();
auto rank = collective::GetRank();
HostDeviceVector<float> storage;
std::vector<FeatureType> ft(kCols);
for (size_t i = 0; i < ft.size(); ++i) {

@ -217,12 +217,12 @@ TEST(Quantile, SameOnAllWorkers) {
std::vector<float> cut_min_values(cuts.MinValues().size() * world, 0);

size_t value_size = cuts.Values().size();
rabit::Allreduce<rabit::op::Max>(&value_size, 1);
collective::Allreduce<collective::Operation::kMax>(&value_size, 1);
size_t ptr_size = cuts.Ptrs().size();
rabit::Allreduce<rabit::op::Max>(&ptr_size, 1);
collective::Allreduce<collective::Operation::kMax>(&ptr_size, 1);
CHECK_EQ(ptr_size, kCols + 1);
size_t min_value_size = cuts.MinValues().size();
rabit::Allreduce<rabit::op::Max>(&min_value_size, 1);
collective::Allreduce<collective::Operation::kMax>(&min_value_size, 1);
CHECK_EQ(min_value_size, kCols);

size_t value_offset = value_size * rank;

@ -235,9 +235,9 @@ TEST(Quantile, SameOnAllWorkers) {
std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(),
cut_min_values.begin() + min_values_offset);

rabit::Allreduce<rabit::op::Sum>(cut_values.data(), cut_values.size());
collective::Allreduce<collective::Operation::kSum>(cut_values.data(), cut_values.size());
rabit::Allreduce<rabit::op::Sum>(cut_ptrs.data(), cut_ptrs.size());
collective::Allreduce<collective::Operation::kSum>(cut_ptrs.data(), cut_ptrs.size());
rabit::Allreduce<rabit::op::Sum>(cut_min_values.data(), cut_min_values.size());
collective::Allreduce<collective::Operation::kSum>(cut_min_values.data(), cut_min_values.size());

for (int32_t i = 0; i < world; i++) {
for (size_t j = 0; j < value_size; ++j) {

@ -256,7 +256,7 @@ TEST(Quantile, SameOnAllWorkers) {
}
}
});
rabit::Finalize();
collective::Finalize();
#endif // defined(__unix__)
}
} // namespace common
@@ -1,6 +1,7 @@
 #include <gtest/gtest.h>
 #include "test_quantile.h"
 #include "../helpers.h"
+#include "../../../src/collective/device_communicator.cuh"
 #include "../../../src/common/hist_util.cuh"
 #include "../../../src/common/quantile.cuh"
 
@@ -341,17 +342,14 @@ TEST(GPUQuantile, AllReduceBasic) {
   // This test is supposed to run by a python test that setups the environment.
   std::string msg {"Skipping AllReduce test"};
   auto n_gpus = AllVisibleGPUs();
-  InitRabitContext(msg, n_gpus);
-  auto world = rabit::GetWorldSize();
+  InitCommunicatorContext(msg, n_gpus);
+  auto world = collective::GetWorldSize();
   if (world != 1) {
     ASSERT_EQ(world, n_gpus);
   } else {
     return;
   }
 
-  auto reducer = std::make_shared<dh::AllReducer>();
-  reducer->Init(0);
-
   constexpr size_t kRows = 1000, kCols = 100;
   RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
     // Set up single node version;
@@ -385,8 +383,8 @@ TEST(GPUQuantile, AllReduceBasic) {
 
     // Set up distributed version. We rely on using rank as seed to generate
     // the exact same copy of data.
-    auto rank = rabit::GetRank();
-    SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
+    auto rank = collective::GetRank();
+    SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
     HostDeviceVector<float> storage;
     std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
                                     .Device(0)
@@ -422,28 +420,26 @@ TEST(GPUQuantile, AllReduceBasic) {
       ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
     }
   });
-  rabit::Finalize();
+  collective::Finalize();
 }
 
 TEST(GPUQuantile, SameOnAllWorkers) {
   std::string msg {"Skipping SameOnAllWorkers test"};
   auto n_gpus = AllVisibleGPUs();
-  InitRabitContext(msg, n_gpus);
-  auto world = rabit::GetWorldSize();
+  InitCommunicatorContext(msg, n_gpus);
+  auto world = collective::GetWorldSize();
   if (world != 1) {
     ASSERT_EQ(world, n_gpus);
   } else {
     return;
   }
-  auto reducer = std::make_shared<dh::AllReducer>();
-  reducer->Init(0);
 
   constexpr size_t kRows = 1000, kCols = 100;
   RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
                                  MetaInfo const &info) {
-    auto rank = rabit::GetRank();
+    auto rank = collective::GetRank();
     HostDeviceVector<FeatureType> ft;
-    SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
+    SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
     HostDeviceVector<float> storage;
     std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
                                     .Device(0)
@@ -459,7 +455,7 @@ TEST(GPUQuantile, SameOnAllWorkers) {
 
     // Test for all workers having the same sketch.
    size_t n_data = sketch_distributed.Data().size();
-    rabit::Allreduce<rabit::op::Max>(&n_data, 1);
+    collective::Allreduce<collective::Operation::kMax>(&n_data, 1);
     ASSERT_EQ(n_data, sketch_distributed.Data().size());
     size_t size_as_float =
         sketch_distributed.Data().size_bytes() / sizeof(float);
@@ -472,9 +468,10 @@ TEST(GPUQuantile, SameOnAllWorkers) {
     thrust::copy(thrust::device, local_data.data(),
                  local_data.data() + local_data.size(),
                  all_workers.begin() + local_data.size() * rank);
-    reducer->AllReduceSum(all_workers.data().get(), all_workers.data().get(),
-                          all_workers.size());
-    reducer->Synchronize();
+    collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(0);
+
+    communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
+    communicator->Synchronize();
 
     auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
     std::vector<float> h_base_line(base_line.size());
@@ -1,16 +1,16 @@
 #ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
 #define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
 
-#include <rabit/rabit.h>
 #include <algorithm>
 #include <string>
 #include <vector>
 
 #include "../helpers.h"
+#include "../../src/collective/communicator-inl.h"
 
 namespace xgboost {
 namespace common {
-inline void InitRabitContext(std::string msg, int32_t n_workers) {
+inline void InitCommunicatorContext(std::string msg, int32_t n_workers) {
   auto port = std::getenv("DMLC_TRACKER_PORT");
   std::string port_str;
   if (port) {
@@ -28,12 +28,11 @@ inline void InitRabitContext(std::string msg, int32_t n_workers) {
     return;
   }
 
-  std::vector<std::string> envs{
-      "DMLC_TRACKER_PORT=" + port_str,
-      "DMLC_TRACKER_URI=" + uri_str,
-      "DMLC_NUM_WORKER=" + std::to_string(n_workers)};
-  char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
-  rabit::Init(3, c_envs);
+  Json config{JsonObject()};
+  config["DMLC_TRACKER_PORT"] = port_str;
+  config["DMLC_TRACKER_URI"] = uri_str;
+  config["DMLC_NUM_WORKER"] = n_workers;
+  collective::Init(config);
 }
 
 template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
@@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
 
 
 def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
-    rabit_env = [
-        'xgboost_communicator=federated',
-        f'federated_server_address=localhost:{port}',
-        f'federated_world_size={world_size}',
-        f'federated_rank={rank}'
-    ]
+    communicator_env = {
+        'xgboost_communicator': 'federated',
+        'federated_server_address': f'localhost:{port}',
+        'federated_world_size': world_size,
+        'federated_rank': rank
+    }
     if with_ssl:
-        rabit_env = rabit_env + [
-            f'federated_server_cert={SERVER_CERT}',
-            f'federated_client_key={CLIENT_KEY}',
-            f'federated_client_cert={CLIENT_CERT}'
-        ]
+        communicator_env['federated_server_cert'] = SERVER_CERT
+        communicator_env['federated_client_key'] = CLIENT_KEY
+        communicator_env['federated_client_cert'] = CLIENT_CERT
 
     # Always call this before using distributed module
-    with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
+    with xgb.collective.CommunicatorContext(**communicator_env):
         # Load file, file will not be sharded in federated mode.
         dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
         dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
@@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
                     early_stopping_rounds=2)
 
     # Save the model, only ask process 0 to save the model.
-    if xgb.rabit.get_rank() == 0:
+    if xgb.collective.get_rank() == 0:
         bst.save_model("test.model.json")
-    xgb.rabit.tracker_print("Finished training\n")
+    xgb.collective.communicator_print("Finished training\n")
 
 
 def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
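
For reference, the worker setup in the hunk above reduces to the following pattern. This is a minimal sketch rather than part of the change: the dictionary keys and the xgb.collective calls are taken from the diff itself, while the training parameters, the data file, and the running federated server are assumed for illustration (the SSL-related keys are omitted).

import xgboost as xgb


def train_one_worker(port: int, world_size: int, rank: int) -> None:
    # Plain keyword dict instead of the old list of encoded "key=value" strings.
    communicator_env = {
        'xgboost_communicator': 'federated',
        'federated_server_address': f'localhost:{port}',
        'federated_world_size': world_size,
        'federated_rank': rank,
    }
    # Always enter the communicator context before any distributed call.
    with xgb.collective.CommunicatorContext(**communicator_env):
        dtrain = xgb.DMatrix('agaricus.txt.train')  # assumed local data file
        bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=2)
        # Only rank 0 persists the model, mirroring the test above.
        if xgb.collective.get_rank() == 0:
            bst.save_model('test.model.json')
        xgb.collective.communicator_print('Finished training\n')
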
@@ -1,7 +1,7 @@
 """Copyright 2019-2022 XGBoost contributors"""
 import sys
 import os
-from typing import Type, TypeVar, Any, Dict, List
+from typing import Type, TypeVar, Any, Dict, List, Union
 import pytest
 import numpy as np
 import asyncio
@@ -425,7 +425,7 @@ class TestDistributedGPU:
         )
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
-            with dxgb.RabitContext(rabit_args):
+            with dxgb.CommunicatorContext(**rabit_args):
                 local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
                 fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
                 assert fw_rows == local_dtrain.num_col()
@@ -505,20 +505,13 @@ class TestDistributedGPU:
         test = "--gtest_filter=GPUQuantile." + name
 
         def runit(
-            worker_addr: str, rabit_args: List[bytes]
+            worker_addr: str, rabit_args: Dict[str, Union[int, str]]
         ) -> subprocess.CompletedProcess:
             port_env = ""
             # setup environment for running the c++ part.
-            for arg in rabit_args:
-                if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
-                    port_env = arg.decode("utf-8")
-                if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
-                    uri_env = arg.decode("utf-8")
-            port = port_env.split("=")
             env = os.environ.copy()
-            env[port[0]] = port[1]
-            uri = uri_env.split("=")
-            env[uri[0]] = uri[1]
+            env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
+            env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
             return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
 
         workers = _get_client_workers(local_cuda_client)
@@ -1,13 +1,13 @@
 import multiprocessing
 import socket
 import sys
+import time
 
 import numpy as np
 import pytest
 
 import xgboost as xgb
-from xgboost import RabitTracker
-from xgboost import collective
+from xgboost import RabitTracker, build_info, federated
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
@@ -37,3 +37,41 @@ def test_rabit_communicator():
     for worker in workers:
         worker.join()
         assert worker.exitcode == 0
+
+
+def run_federated_worker(port, world_size, rank):
+    with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
+                                            federated_server_address=f'localhost:{port}',
+                                            federated_world_size=world_size,
+                                            federated_rank=rank):
+        assert xgb.collective.get_world_size() == world_size
+        assert xgb.collective.is_distributed()
+        assert xgb.collective.get_processor_name() == f'rank{rank}'
+        ret = xgb.collective.broadcast('test1234', 0)
+        assert str(ret) == 'test1234'
+        ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
+        assert np.array_equal(ret, np.asarray([2, 4, 6]))
+
+
+def test_federated_communicator():
+    if not build_info()["USE_FEDERATED"]:
+        pytest.skip("XGBoost not built with federated learning enabled")
+
+    port = 9091
+    world_size = 2
+    server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
+    server.start()
+    time.sleep(1)
+    if not server.is_alive():
+        raise Exception("Error starting Federated Learning server")
+
+    workers = []
+    for rank in range(world_size):
+        worker = multiprocessing.Process(target=run_federated_worker,
+                                         args=(port, world_size, rank))
+        workers.append(worker)
+        worker.start()
+    for worker in workers:
+        worker.join()
+        assert worker.exitcode == 0
+    server.terminate()
@@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
 def test_rabit_tracker():
     tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
     tracker.start(1)
-    worker_env = tracker.worker_envs()
-    rabit_env = []
-    for k, v in worker_env.items():
-        rabit_env.append(f"{k}={v}".encode())
-    with xgb.rabit.RabitContext(rabit_env):
-        ret = xgb.rabit.broadcast("test1234", 0)
+    with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
+        ret = xgb.collective.broadcast("test1234", 0)
         assert str(ret) == "test1234"
 
 
 def run_rabit_ops(client, n_workers):
     from test_with_dask import _get_client_workers
-    from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
+    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
 
-    from xgboost import rabit
+    from xgboost import collective
 
     workers = _get_client_workers(client)
     rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
-    assert not rabit.is_distributed()
+    assert not collective.is_distributed()
     n_workers_from_dask = len(workers)
     assert n_workers == n_workers_from_dask
 
     def local_test(worker_id):
-        with RabitContext(rabit_args):
+        with CommunicatorContext(**rabit_args):
             a = 1
-            assert rabit.is_distributed()
+            assert collective.is_distributed()
             a = np.array([a])
-            reduced = rabit.allreduce(a, rabit.Op.SUM)
+            reduced = collective.allreduce(a, collective.Op.SUM)
             assert reduced[0] == n_workers
 
             worker_id = np.array([worker_id])
-            reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
+            reduced = collective.allreduce(worker_id, collective.Op.MAX)
             assert reduced == n_workers - 1
 
             return 1
@@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
     from test_with_dask import _get_client_workers
 
     def local_test(worker_id):
-        with xgb.dask.RabitContext(args):
-            for val in args:
-                sval = val.decode("utf-8")
-                if sval.startswith("DMLC_TASK_ID"):
-                    task_id = sval
-                    break
+        with xgb.dask.CommunicatorContext(**args) as ctx:
+            task_id = ctx["DMLC_TASK_ID"]
             matched = re.search(".*-([0-9]).*", task_id)
-            rank = xgb.rabit.get_rank()
+            rank = xgb.collective.get_rank()
             # As long as the number of workers is lesser than 10, rank and worker id
             # should be the same
             assert rank == int(matched.group(1))
@@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
 
 class TestWithDask:
     def test_dmatrix_binary(self, client: "Client") -> None:
-        def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
-            with xgb.dask.RabitContext(rabit_args):
-                rank = xgb.rabit.get_rank()
+        def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
+            with xgb.dask.CommunicatorContext(**rabit_args):
+                rank = xgb.collective.get_rank()
                 X, y = tm.make_categorical(100, 4, 4, False)
                 Xy = xgb.DMatrix(X, y, enable_categorical=True)
                 path = os.path.join(tmpdir, f"{rank}.bin")
                 Xy.save_binary(path)
 
-        def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
-            with xgb.dask.RabitContext(rabit_args):
-                rank = xgb.rabit.get_rank()
+        def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
+            with xgb.dask.CommunicatorContext(**rabit_args):
+                rank = xgb.collective.get_rank()
                 path = os.path.join(tmpdir, f"{rank}.bin")
                 Xy = xgb.DMatrix(path)
                 assert Xy.num_row() == 100
@@ -1488,20 +1488,13 @@ class TestWithDask:
         test = "--gtest_filter=Quantile." + name
 
         def runit(
-            worker_addr: str, rabit_args: List[bytes]
+            worker_addr: str, rabit_args: Dict[str, Union[int, str]]
         ) -> subprocess.CompletedProcess:
             port_env = ''
            # setup environment for running the c++ part.
-            for arg in rabit_args:
-                if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
-                    port_env = arg.decode('utf-8')
-                if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
-                    uri_env = arg.decode("utf-8")
-            port = port_env.split('=')
             env = os.environ.copy()
-            env[port[0]] = port[1]
-            uri = uri_env.split("=")
-            env["DMLC_TRACKER_URI"] = uri[1]
+            env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
+            env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
             return subprocess.run([str(exe), test], env=env, capture_output=True)
 
         with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
@@ -1543,8 +1536,8 @@ class TestWithDask:
         def get_score(config: Dict) -> float:
             return float(config["learner"]["learner_model_param"]["base_score"])
 
-        def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
-            with xgb.dask.RabitContext(rabit_args):
+        def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
+            with xgb.dask.CommunicatorContext(**rabit_args):
                 if worker_id == 0:
                     y = np.array([0.0, 0.0, 0.0])
                     x = np.array([[0.0]] * 3)
@@ -1686,12 +1679,12 @@ class TestWithDask:
         n_workers = len(workers)
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
-            with xgb.dask.RabitContext(rabit_args):
+            with xgb.dask.CommunicatorContext(**rabit_args):
                 local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
                     **data_ref, nthread=7
                 )
                 total = np.array([local_dtrain.num_row()])
-                total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
+                total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
                 assert total[0] == kRows
 
         futures = []
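
Taken together, the Dask test changes above are a mechanical mapping from the rabit helpers to the collective module: RabitContext(rabit_args) becomes CommunicatorContext(**rabit_args) and xgb.rabit.* becomes xgb.collective.*. A rough sketch of a worker body under that mapping, assuming rabit_args is the dict returned by _get_rabit_args as in the tests above:

import numpy as np
import xgboost as xgb


def worker_body(rabit_args: dict) -> None:
    # Old: with xgb.dask.RabitContext(rabit_args): ...
    with xgb.dask.CommunicatorContext(**rabit_args):
        rank = xgb.collective.get_rank()  # was xgb.rabit.get_rank()
        # was xgb.rabit.allreduce(..., xgb.rabit.Op.SUM)
        total = xgb.collective.allreduce(np.array([rank]), xgb.collective.Op.SUM)
        # Each worker contributes its rank, so the sum is 0 + 1 + ... + (world_size - 1).
        assert total[0] == sum(range(xgb.collective.get_world_size()))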