[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
|
||||
import scala.collection.JavaConverters.asScalaIteratorConverter
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
@@ -46,7 +46,7 @@ object XGBoost {
|
||||
collector: Collector[XGBoostModel]): Unit = {
|
||||
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
|
||||
logger.info("start with env" + workerEnvs.toString)
|
||||
Rabit.init(workerEnvs)
|
||||
Communicator.init(workerEnvs)
|
||||
val mapper = (x: LabeledVector) => {
|
||||
val (index, value) = x.vector.toSeq.unzip
|
||||
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
@@ -59,7 +59,7 @@ object XGBoost {
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
|
||||
earlyStoppingRound = numEarlyStoppingRounds)
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
collector.collect(new XGBoostModel(booster))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user