[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -22,7 +22,7 @@ import java.util.ServiceLoader
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
@@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
Communicator.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||
@@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
override def next(): Row = {
|
||||
val ret = batchIterImpl.next()
|
||||
if (!batchIterImpl.hasNext) {
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
@@ -303,7 +303,7 @@ object XGBoost extends Serializable {
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
|
||||
try {
|
||||
Rabit.init(rabitEnv)
|
||||
Communicator.init(rabitEnv)
|
||||
|
||||
watches = buildWatchesAndCheck(buildWatches)
|
||||
|
||||
@@ -342,7 +342,7 @@ object XGBoost extends Serializable {
|
||||
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
|
||||
throw xgbException
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
if (watches != null) watches.delete()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user