[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator
* fix size_t specialization
* really fix size_t
* try again
* add include
* more include
* fix lint errors
* remove rabit includes
* fix pylint error
* return dict from communicator context
* fix communicator shutdown
* fix dask test
* reset communicator mocklist
* fix distributed tests
* do not save device communicator
* fix jvm gpu tests
* add python test for federated communicator
* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
|
||||
import scala.collection.JavaConverters.asScalaIteratorConverter
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
@@ -46,7 +46,7 @@ object XGBoost {
|
||||
collector: Collector[XGBoostModel]): Unit = {
|
||||
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
|
||||
logger.info("start with env" + workerEnvs.toString)
|
||||
Rabit.init(workerEnvs)
|
||||
Communicator.init(workerEnvs)
|
||||
val mapper = (x: LabeledVector) => {
|
||||
val (index, value) = x.vector.toSeq.unzip
|
||||
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
@@ -59,7 +59,7 @@ object XGBoost {
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
|
||||
earlyStoppingRound = numEarlyStoppingRounds)
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
collector.collect(new XGBoostModel(booster))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import java.util.ServiceLoader
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
@@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
Communicator.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||
@@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
override def next(): Row = {
|
||||
val ret = batchIterImpl.next()
|
||||
if (!batchIterImpl.hasNext) {
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
@@ -303,7 +303,7 @@ object XGBoost extends Serializable {
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
|
||||
try {
|
||||
Rabit.init(rabitEnv)
|
||||
Communicator.init(rabitEnv)
|
||||
|
||||
watches = buildWatchesAndCheck(buildWatches)
|
||||
|
||||
@@ -342,7 +342,7 @@ object XGBoost extends Serializable {
|
||||
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
|
||||
throw xgbException
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
Communicator.shutdown()
|
||||
if (watches != null) watches.delete()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,277 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.{FunSuite}
|
||||
|
||||
class RabitRobustnessSuite extends FunSuite with PerTest {
|
||||
|
||||
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
|
||||
val classifier = new XGBoostClassifier(paramMap)
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
|
||||
xgbParamsFactory.buildXGBRuntimeParams
|
||||
}
|
||||
|
||||
|
||||
test("Customize host ip and python exec for Rabit tracker") {
|
||||
val hostIp = "192.168.22.111"
|
||||
val pythonExec = "/usr/bin/python3"
|
||||
|
||||
val paramMap = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
|
||||
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
tracker match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.contains(hostIp))
|
||||
assert(cmd.startsWith("python"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap1 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
|
||||
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||
tracker1 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(!cmd.contains(hostIp))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap2 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
|
||||
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||
tracker2 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(cmd.contains(s" --host-ip=${hostIp}"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
}
|
||||
|
||||
test("training with Scala-implemented Rabit tracker") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||
val vectorLength = 100
|
||||
val rdd = sc.parallelize(
|
||||
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
|
||||
|
||||
val rawData = rdd.mapPartitions { iter =>
|
||||
Iterator(iter.toArray)
|
||||
}.collect()
|
||||
|
||||
val maxVec = (0 until vectorLength).toArray.map { j =>
|
||||
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
|
||||
}
|
||||
|
||||
val allReduceResults = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val arr = iter.toArray
|
||||
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
|
||||
Rabit.shutdown()
|
||||
Iterator(results)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
allReduceResults.foreachPartition(() => _)
|
||||
val byPartitionResults = allReduceResults.collect()
|
||||
assert(byPartitionResults(0).length == vectorLength)
|
||||
collectedAllReduceResults.put(byPartitionResults(0))
|
||||
}
|
||||
}
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == 0)
|
||||
sparkThread.join()
|
||||
|
||||
assert(collectedAllReduceResults.poll().sameElements(maxVec))
|
||||
}
|
||||
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
|
||||
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
|
||||
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
|
||||
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
|
||||
that should be avoided.
|
||||
*/
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
|
||||
last created worker. A cascading event chain will be triggered once the RuntimeException is
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(500)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||
if (index != 1) {
|
||||
Rabit.init(trackerEnvs)
|
||||
Thread.sleep(1000)
|
||||
Rabit.shutdown()
|
||||
}
|
||||
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
// should fail due to connection timeout
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
}
|
||||
|
||||
test("should allow the dataframe containing rabit calls to be partially evaluated for" +
|
||||
" multiple times (ISSUE-4406)") {
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
|
||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||
val prediction = model.transform(trainingDF)
|
||||
// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
|
||||
// threads
|
||||
prediction.show()
|
||||
// a full evaluation here will re-run init and shutdown all rabit proxy
|
||||
// expecting no error
|
||||
prediction.collect()
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
|
||||
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
|
||||
val predictionErrorMin = 0.00001f
|
||||
val maxFailure = 2;
|
||||
|
||||
@@ -47,8 +47,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
|
||||
.fit(training)
|
||||
|
||||
assert(Rabit.rabitEnvs.asScala.size > 3)
|
||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
|
||||
@@ -70,8 +70,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
|
||||
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
|
||||
).fit(training)
|
||||
assert(Rabit.rabitEnvs.asScala.size > 3)
|
||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
// check the equality of single instance prediction
|
||||
@@ -85,7 +85,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
test("test rabit timeout fail handle") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
// mock rank 0 failure during 8th allreduce synchronization
|
||||
Rabit.mockList = Array("0,8,0,0").toList.asJava
|
||||
Communicator.mockList = Array("0,8,0,0").toList.asJava
|
||||
|
||||
intercept[SparkException] {
|
||||
new XGBoostClassifier(Map(
|
||||
@@ -98,6 +98,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
"rabit_timeout" -> 0))
|
||||
.fit(training)
|
||||
}
|
||||
|
||||
Communicator.mockList = Array.empty.toList.asJava
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Rabit global class for synchronization.
|
||||
*/
|
||||
public class Rabit {
|
||||
|
||||
public enum OpType implements Serializable {
|
||||
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
|
||||
|
||||
private int op;
|
||||
|
||||
public int getOperand() {
|
||||
return this.op;
|
||||
}
|
||||
|
||||
OpType(int op) {
|
||||
this.op = op;
|
||||
}
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
|
||||
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
|
||||
LONGLONG(8, 8), ULONGLONG(9, 8);
|
||||
|
||||
private int enumOp;
|
||||
private int size;
|
||||
|
||||
public int getEnumOp() {
|
||||
return this.enumOp;
|
||||
}
|
||||
|
||||
public int getSize() {
|
||||
return this.size;
|
||||
}
|
||||
|
||||
DataType(int enumOp, int size) {
|
||||
this.enumOp = enumOp;
|
||||
this.size = size;
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
// used as way to test/debug passed rabit init parameters
|
||||
public static Map<String, String> rabitEnvs;
|
||||
public static List<String> mockList = new LinkedList<>();
|
||||
/**
|
||||
* Initialize the rabit library on current working thread.
|
||||
* @param envs The additional environment variables to pass to rabit.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
rabitEnvs = envs;
|
||||
String[] args = new String[envs.size() + mockList.size()];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||
}
|
||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
||||
for(String mock : mockList) {
|
||||
args[idx++] = "mock=" + mock;
|
||||
}
|
||||
checkCall(XGBoostJNI.RabitInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown the rabit engine in current working thread, equals to finalize.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void shutdown() throws XGBoostError {
|
||||
checkCall(XGBoostJNI.RabitFinalize());
|
||||
}
|
||||
|
||||
/**
|
||||
* Print the message on rabit tracker.
|
||||
* @param msg
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void trackerPrint(String msg) throws XGBoostError {
|
||||
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get version number of current stored model in the thread.
|
||||
* which means how many calls to CheckPoint we made so far.
|
||||
* @return version Number.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int versionNumber() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XGBoostJNI.RabitVersionNumber(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get rank of current thread.
|
||||
* @return the rank.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getRank() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XGBoostJNI.RabitGetRank(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get world size of current job.
|
||||
* @return the worldsize
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getWorldSize() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* perform Allreduce on distributed float vectors using operator op.
|
||||
* This implementation of allReduce does not support customized prepare function callback in the
|
||||
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
|
||||
*
|
||||
* @param elements local elements on distributed workers.
|
||||
* @param op operator used for Allreduce.
|
||||
* @return All-reduced float elements according to the given operator.
|
||||
*/
|
||||
public static float[] allReduce(float[] elements, OpType op) {
|
||||
DataType dataType = DataType.FLOAT;
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
|
||||
.order(ByteOrder.nativeOrder());
|
||||
|
||||
for (float el : elements) {
|
||||
buffer.putFloat(el);
|
||||
}
|
||||
buffer.flip();
|
||||
|
||||
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
|
||||
float[] results = new float[elements.length];
|
||||
buffer.asFloatBuffer().get(results);
|
||||
|
||||
return results;
|
||||
}
|
||||
}
|
||||
@@ -254,16 +254,16 @@ public class XGBoost {
|
||||
}
|
||||
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||
if (shouldPrint(params, iter)) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
Communicator.communicatorPrint(String.format(
|
||||
"early stopping after %d rounds away from the best iteration",
|
||||
earlyStoppingRounds
|
||||
));
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (shouldPrint(params, iter)){
|
||||
Rabit.trackerPrint(evalInfo + '\n');
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,19 +135,6 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
// rabit functions
|
||||
public final static native int RabitInit(String[] args);
|
||||
public final static native int RabitFinalize();
|
||||
public final static native int RabitTrackerPrint(String msg);
|
||||
public final static native int RabitGetRank(int[] out);
|
||||
public final static native int RabitGetWorldSize(int[] out);
|
||||
public final static native int RabitVersionNumber(int[] out);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
// This JNI function does not support the callback function for data preparation yet.
|
||||
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
|
||||
@@ -872,111 +872,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
std::vector<std::string> args;
|
||||
std::vector<char*> argv;
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
||||
const char *s = jenv->GetStringUTFChars(arg, 0);
|
||||
args.push_back(std::string(s, jenv->GetStringLength(arg)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
|
||||
if (args.back().length() == 0) args.pop_back();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
argv.push_back(&args[i][0]);
|
||||
}
|
||||
|
||||
if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
|
||||
return 0;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
if (RabitFinalize()) {
|
||||
return 0;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
|
||||
std::string str(jenv->GetStringUTFChars(jmsg, 0),
|
||||
jenv->GetStringLength(jmsg));
|
||||
JVM_CHECK_CALL(RabitTrackerPrint(str.c_str()));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
jint rank = RabitGetRank();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
jint out = RabitGetWorldSize();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
jint out = RabitVersionNumber();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitAllreduce
|
||||
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
|
||||
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
|
||||
JVM_CHECK_CALL(RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@@ -279,62 +279,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitAllreduce
|
||||
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *, jclass, jobject, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@@ -300,7 +300,7 @@ public class DMatrixTest {
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Rabit.init(rabitEnv);
|
||||
Communicator.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
BigDenseMatrix data0 = null;
|
||||
try {
|
||||
@@ -348,7 +348,7 @@ public class DMatrixTest {
|
||||
else if (data0 != null) {
|
||||
data0.dispose();
|
||||
}
|
||||
Rabit.shutdown();
|
||||
Communicator.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user