Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -23,6 +23,7 @@ CONFIG = {
|
||||
"USE_NCCL": "OFF",
|
||||
"JVM_BINDINGS": "ON",
|
||||
"LOG_CAPI_INVOCATION": "OFF",
|
||||
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
|
||||
}
|
||||
|
||||
|
||||
@@ -97,10 +98,6 @@ def native_build(args):
|
||||
|
||||
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
|
||||
|
||||
# if enviorment set rabit_mock
|
||||
if os.getenv("RABIT_MOCK", None) is not None:
|
||||
args.append("-DRABIT_MOCK:BOOL=ON")
|
||||
|
||||
# if environment set GPU_ARCH_FLAG
|
||||
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
|
||||
if gpu_arch_flag is not None:
|
||||
@@ -162,12 +159,6 @@ def native_build(args):
|
||||
maybe_makedirs(output_folder)
|
||||
cp("../lib/" + library_name, output_folder)
|
||||
|
||||
print("copying pure-Python tracker")
|
||||
cp(
|
||||
"../python-package/xgboost/tracker.py",
|
||||
"{}/src/main/resources".format(xgboost4j),
|
||||
)
|
||||
|
||||
print("copying train/test files")
|
||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
|
||||
with cd("../demo/CLI/regression"):
|
||||
|
||||
@@ -489,6 +489,11 @@
|
||||
<artifactId>kryo</artifactId>
|
||||
<version>5.6.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.14.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
|
||||
@@ -54,9 +54,9 @@ public class XGBoost {
|
||||
|
||||
private final Map<String, Object> params;
|
||||
private final int round;
|
||||
private final Map<String, String> workerEnvs;
|
||||
private final Map<String, Object> workerEnvs;
|
||||
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
|
||||
this.params = params;
|
||||
this.round = round;
|
||||
this.workerEnvs = workerEnvs;
|
||||
@@ -174,9 +174,9 @@ public class XGBoost {
|
||||
int numBoostRound) throws Exception {
|
||||
final RabitTracker tracker =
|
||||
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
|
||||
if (tracker.start(0L)) {
|
||||
if (tracker.start()) {
|
||||
return dtrain
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
|
||||
.reduce((x, y) -> x)
|
||||
.collect()
|
||||
.get(0);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -22,7 +22,7 @@ import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession
|
||||
/**
|
||||
* Rabit tracker configurations.
|
||||
*
|
||||
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
|
||||
* Set timeout length to zero to disable timeout.
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (in milliseconds)
|
||||
* (supported by "scala" implementation only.)
|
||||
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
|
||||
* @param timeout The number of seconds before timeout waiting for workers to connect, and
|
||||
* for the tracker to shutdown.
|
||||
* @param hostIp The Rabit Tracker host IP address.
|
||||
* This is only needed if the host IP cannot be automatically guessed.
|
||||
* @param pythonExec The python executed path for Rabit Tracker,
|
||||
* which is only used for python implementation.
|
||||
* @param port The port number for the tracker to listen to. Use a system allocated one by
|
||||
* default.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Long,
|
||||
hostIp: String = "", pythonExec: String = "")
|
||||
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
def apply(): TrackerConf = TrackerConf(0)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
@@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel {
|
||||
private def buildDistributedBooster(
|
||||
buildWatches: () => Watches,
|
||||
xgbExecutionParam: XGBoostExecutionParams,
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
rabitEnv: java.util.Map[String, Object],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||
@@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel {
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
val attempt = TaskContext.get().attemptNumber.toString
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||
val numRounds = xgbExecutionParam.numRounds
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
|
||||
@@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel {
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = new PyRabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
|
||||
)
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker: ITracker = new RabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
|
||||
tracker
|
||||
}
|
||||
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker = getTracker(nWorkers, trackerConf)
|
||||
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
}
|
||||
|
||||
@@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
val (booster, metrics) = try {
|
||||
tracker.getWorkerEnvs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
tracker.workerArgs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.workerArgs
|
||||
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
|
||||
var optionWatches: Option[() => Watches] = None
|
||||
@@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel {
|
||||
// of the training task fails the training stage can retry. ResultStage won't retry when
|
||||
// it fails.
|
||||
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
if (trackerReturnVal != 0) {
|
||||
throw new XGBoostError("XGBoostModel training failed.")
|
||||
}
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params {
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
* TrackerConf class, which has the following definition:
|
||||
*
|
||||
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
|
||||
* trackerImpl: String)
|
||||
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
|
||||
*
|
||||
* See below for detailed explanations.
|
||||
*
|
||||
* - trackerImpl: Select the implementation of Rabit tracker.
|
||||
* default: "python"
|
||||
*
|
||||
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
|
||||
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
|
||||
* The "scala" version removes Python components, and fully supports timeout settings.
|
||||
*
|
||||
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
|
||||
* default: 0 millisecond (no timeout)
|
||||
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
|
||||
* default: 0 (no timeout)
|
||||
*
|
||||
* Timeout for constructing the communication group and waiting for the tracker to
|
||||
* shutdown when it's instructed to, doesn't apply to communication when tracking
|
||||
* is running.
|
||||
* The timeout value should take the time of data loading and pre-processing into account,
|
||||
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
|
||||
* due to potential lazy execution. Alternatively, you may force Spark to
|
||||
* perform data transformation before calling XGBoost.train(), so that this timeout truly
|
||||
* reflects the connection delay. Set a reasonable timeout value to prevent model
|
||||
* training/testing from hanging indefinitely, possibly due to network issues.
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*
|
||||
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
|
||||
* cannot be automatically guessed.
|
||||
*
|
||||
* - port : The port number for the tracker to listen to. Use a system allocated one by
|
||||
* default.
|
||||
*/
|
||||
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
||||
setDefault(trackerConf, TrackerConf())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
@@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
xgbParamsFactory.buildXGBRuntimeParams
|
||||
}
|
||||
|
||||
test("Customize host ip and python exec for Rabit tracker") {
|
||||
val hostIp = "192.168.22.111"
|
||||
val pythonExec = "/usr/bin/python3"
|
||||
|
||||
val paramMap = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp))
|
||||
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
tracker match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.contains(hostIp))
|
||||
assert(cmd.startsWith("python"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap1 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
|
||||
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||
tracker1 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(!cmd.contains(hostIp))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap2 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
|
||||
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||
tracker2 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(cmd.contains(s" --host-ip=${hostIp}"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
}
|
||||
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
@@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
*/
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker. workerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
@@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Communicator.init(trackerEnvs)
|
||||
@@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
}
|
||||
|
||||
test("Communicator allreduce works.") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker.workerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
|
||||
rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
Communicator.init(trackerEnvs)
|
||||
val a = Array(1.0f, 2.0f, 3.0f)
|
||||
System.out.println(a.mkString(", "))
|
||||
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
|
||||
for (i <- 0 to 2) {
|
||||
assert(a(i) * workerCount == b(i))
|
||||
}
|
||||
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
|
||||
for (i <- 0 to 2) {
|
||||
assert(a(i) == c(i))
|
||||
}
|
||||
Communicator.shutdown()
|
||||
Iterator(index)
|
||||
}.collect()
|
||||
}
|
||||
|
||||
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
|
||||
|
||||
@@ -23,7 +23,6 @@ import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
|
||||
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
|
||||
"objective_type" -> "classification")
|
||||
@@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
intercept[SparkException] {
|
||||
xgb.fit(trainingDF)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
test("fail training elegantly with unsupported eval metrics") {
|
||||
|
||||
@@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
|
||||
.fit(training)
|
||||
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
@@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
|
||||
).fit(training)
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
// check the equality of single instance prediction
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
@@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
assert(math.abs(p1 - p2) < predictionErrorMin)
|
||||
}
|
||||
}
|
||||
|
||||
test("test rabit timeout fail handle") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
// mock rank 0 failure during 8th allreduce synchronization
|
||||
Communicator.mockList = Array("0,8,0,0").toList.asJava
|
||||
|
||||
intercept[SparkException] {
|
||||
new XGBoostClassifier(Map(
|
||||
"eta" -> "0.1",
|
||||
"max_depth" -> "10",
|
||||
"verbosity" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5,
|
||||
"num_workers" -> numWorkers,
|
||||
"rabit_timeout" -> 0))
|
||||
.fit(training)
|
||||
}
|
||||
|
||||
Communicator.mockList = Array.empty.toList.asJava
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -51,6 +51,11 @@ pom_template = """
|
||||
<artifactId>commons-logging</artifactId>
|
||||
<version>1.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.14.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${{scala.binary.version}}</artifactId>
|
||||
|
||||
@@ -7,6 +7,9 @@ import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* Collective communicator global class for synchronization.
|
||||
*
|
||||
@@ -30,8 +33,9 @@ public class Communicator {
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
|
||||
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
|
||||
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
|
||||
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
|
||||
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
|
||||
|
||||
private final int enumOp;
|
||||
private final int size;
|
||||
@@ -56,30 +60,20 @@ public class Communicator {
|
||||
}
|
||||
}
|
||||
|
||||
// used as way to test/debug passed communicator init parameters
|
||||
public static Map<String, String> communicatorEnvs;
|
||||
public static List<String> mockList = new LinkedList<>();
|
||||
|
||||
/**
|
||||
* Initialize the collective communicator on current working thread.
|
||||
*
|
||||
* @param envs The additional environment variables to pass to the communicator.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
communicatorEnvs = envs;
|
||||
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey();
|
||||
args[idx++] = e.getValue();
|
||||
public static void init(Map<String, Object> envs) throws XGBoostError {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
String jconfig = mapper.writeValueAsString(envs);
|
||||
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
|
||||
}
|
||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
||||
for (String mock : mockList) {
|
||||
args[idx++] = "mock";
|
||||
args[idx++] = mock;
|
||||
}
|
||||
checkCall(XGBoostJNI.CommunicatorInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Interface for Rabit tracker implementations with three public methods:
|
||||
* Interface for tracker implementations with three public methods:
|
||||
*
|
||||
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
|
||||
* timeout value (in milliseconds.)
|
||||
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
|
||||
* - start(timeout): Start the tracker awaiting for worker connections, with a given
|
||||
* timeout value (in seconds).
|
||||
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
@@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
|
||||
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
public interface ITracker extends Thread.UncaughtExceptionHandler {
|
||||
enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
@@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, String> getWorkerEnvs();
|
||||
boolean start(long workerConnectionTimeout);
|
||||
void stop();
|
||||
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||
int waitFor(long taskExecutionTimeout);
|
||||
Map<String, Object> workerArgs() throws XGBoostError;
|
||||
|
||||
boolean start() throws XGBoostError;
|
||||
|
||||
void stop() throws XGBoostError;
|
||||
|
||||
void waitFor(long taskExecutionTimeout) throws XGBoostError;
|
||||
}
|
||||
@@ -1,101 +1,40 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
|
||||
* start() and waitFor() methods (i.e., the timeout is infinite.)
|
||||
*
|
||||
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
|
||||
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
|
||||
* provides timeout support.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker implements IRabitTracker {
|
||||
public class RabitTracker implements ITracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
private static String tracker_py = null;
|
||||
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
|
||||
// environment variable to be pased.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int numWorkers;
|
||||
private String hostIp = "";
|
||||
private String pythonExec = "";
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
private long handle = 0;
|
||||
private Thread tracker_daemon;
|
||||
|
||||
static {
|
||||
try {
|
||||
initTrackerPy();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load tracker library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
public RabitTracker(int numWorkers) throws XGBoostError {
|
||||
this(numWorkers, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracker logger that logs output from tracker.
|
||||
*/
|
||||
private class TrackerProcessLogger implements Runnable {
|
||||
public void run() {
|
||||
|
||||
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||
trackerProcess.get().getErrorStream()));
|
||||
String line;
|
||||
try {
|
||||
while ((line = reader.readLine()) != null) {
|
||||
trackerProcessLogger.info(line);
|
||||
}
|
||||
trackerProcess.get().waitFor();
|
||||
int exitValue = trackerProcess.get().exitValue();
|
||||
if (exitValue != 0) {
|
||||
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
|
||||
} else {
|
||||
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
trackerProcessLogger.error(ex.toString());
|
||||
} catch (InterruptedException ie) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
ie.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void initTrackerPy() throws IOException {
|
||||
try {
|
||||
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||
} catch (IOException ioe) {
|
||||
logger.trace("cannot access tracker python script");
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers)
|
||||
public RabitTracker(int numWorkers, String hostIp)
|
||||
throws XGBoostError {
|
||||
this(numWorkers, hostIp, 0, 300);
|
||||
}
|
||||
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||
}
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
|
||||
throws XGBoostError {
|
||||
this(numWorkers);
|
||||
this.hostIp = hostIp;
|
||||
this.pythonExec = pythonExec;
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
|
||||
this.handle = out[0];
|
||||
}
|
||||
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
@@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
trackerProcess.get().destroy();
|
||||
this.tracker_daemon.interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
*/
|
||||
public Map<String, String> getWorkerEnvs() {
|
||||
return envs;
|
||||
public Map<String, Object> workerArgs() throws XGBoostError {
|
||||
// fixme: timeout
|
||||
String[] args = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
|
||||
};
|
||||
Map<String, Object> config;
|
||||
try {
|
||||
config = mapper.readValue(args[0], typeRef);
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to get worker arguments.", ex);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private void loadEnvs(InputStream ins) throws IOException {
|
||||
try {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||
break;
|
||||
}
|
||||
String[] sep = line.split("=");
|
||||
if (sep.length == 2) {
|
||||
envs.put(sep[0], sep[1]);
|
||||
}
|
||||
public void stop() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
|
||||
}
|
||||
|
||||
public boolean start() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
|
||||
this.tracker_daemon = new Thread(() -> {
|
||||
try {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return; // exit the thread
|
||||
}
|
||||
reader.close();
|
||||
} catch (IOException ioe){
|
||||
logger.error("cannot get runtime configuration from tracker process");
|
||||
ioe.printStackTrace();
|
||||
throw ioe;
|
||||
}
|
||||
});
|
||||
this.tracker_daemon.setDaemon(true);
|
||||
this.tracker_daemon.start();
|
||||
|
||||
return this.tracker_daemon.isAlive();
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
public String getRabitTrackerCommand() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (pythonExec == null || pythonExec.isEmpty()) {
|
||||
sb.append("python ");
|
||||
} else {
|
||||
sb.append(pythonExec + " ");
|
||||
}
|
||||
sb.append(" " + tracker_py + " ");
|
||||
sb.append(" --log-level=DEBUG" + " ");
|
||||
sb.append(" --num-workers=" + numWorkers + " ");
|
||||
|
||||
// we first check the property then check the parameter
|
||||
String hostIpFromProperties = trackerProperties.getHostIp();
|
||||
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
|
||||
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
|
||||
sb.append(" --host-ip=" + hostIpFromProperties + " ");
|
||||
} else if (hostIp != null & !hostIp.isEmpty()) {
|
||||
logger.debug("Using the parametr host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=" + hostIp + " ");
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
String cmd = getRabitTrackerCommand();
|
||||
trackerProcess.set(Runtime.getRuntime().exec(cmd));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean start(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for all workers to connect indefinitely, unless " +
|
||||
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
System.out.println("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
logger_thread.start();
|
||||
return true;
|
||||
} else {
|
||||
logger.error("FAULT: failed to start tracker process");
|
||||
stop();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public int waitFor(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for either all workers to finish tasks and send " +
|
||||
"shutdown signal, or manual interruptions. " +
|
||||
"Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
try {
|
||||
trackerProcess.get().waitFor();
|
||||
int returnVal = trackerProcess.get().exitValue();
|
||||
logger.info("Tracker Process ends with exit code " + returnVal);
|
||||
stop();
|
||||
return returnVal;
|
||||
} catch (InterruptedException e) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
e.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
return TrackerStatus.INTERRUPTED.getStatusCode();
|
||||
}
|
||||
public void waitFor(long timeout) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -54,7 +54,7 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
|
||||
float[] data, int shapeParam,
|
||||
@@ -146,12 +146,24 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorInit(String args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
public final static native int CommunicatorPrint(String msg);
|
||||
public final static native int CommunicatorGetRank(int[] out);
|
||||
public final static native int CommunicatorGetWorldSize(int[] out);
|
||||
|
||||
// Tracker functions
|
||||
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
|
||||
long[] out);
|
||||
|
||||
public final static native int TrackerRun(long handle);
|
||||
|
||||
public final static native int TrackerWaitFor(long handle, long timeout);
|
||||
|
||||
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
|
||||
|
||||
public final static native int TrackerFree(long handle);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
@@ -168,5 +180,4 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
}
|
||||
|
||||
@@ -42,5 +42,4 @@ public final class UtilUnsafe {
|
||||
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
/**
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "./xgboost4j.h"
|
||||
|
||||
#include <rabit/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/json.h>
|
||||
@@ -23,7 +24,6 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
|
||||
jclass jcls,
|
||||
jstring jargs) {
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
assert(len % 2 == 0);
|
||||
for (bst_ulong i = 0; i < len / 2; ++i) {
|
||||
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
|
||||
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
|
||||
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
|
||||
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
|
||||
config[key_str] = xgboost::String(value_str);
|
||||
const char *args = jenv->GetStringUTFChars(jargs, nullptr);
|
||||
JVM_CHECK_CALL(XGCommunicatorInit(args));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
|
||||
JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
|
||||
jlongArray jout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
TrackerHandle handle;
|
||||
Json config{Object{}};
|
||||
std::string shost{jenv->GetStringUTFChars(host, nullptr),
|
||||
static_cast<std::string::size_type>(jenv->GetStringLength(host))};
|
||||
if (!shost.empty()) {
|
||||
config["host"] = shost;
|
||||
}
|
||||
std::string json_str;
|
||||
xgboost::Json::Dump(config, &json_str);
|
||||
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
|
||||
config["port"] = Integer{static_cast<Integer::Int>(port)};
|
||||
config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
|
||||
config["dmlc_communicator"] = String{"rabit"};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
|
||||
setHandle(jenv, jout, handle);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
// JNI entry point for XGBoostJNI.TrackerRun: forwards the Java-side handle to
// XGTrackerRun. Returns 0 when the call chain completes; JVM_CHECK_CALL presumably
// returns the error code early on failure (macro defined elsewhere -- confirm).
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
                                                                         jlong jhandle) {
  // The jlong is the tracker handle as stored on the Java side; the second argument
  // of XGTrackerRun is passed as null.
  JVM_CHECK_CALL(XGTrackerRun(reinterpret_cast<TrackerHandle>(jhandle), nullptr));
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
// JNI entry point for XGBoostJNI.TrackerWaitFor: wraps the timeout in a JSON object
// of the form {"timeout": <value>} and passes its serialized form to XGTrackerWaitFor.
// Returns 0 when the call chain completes.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
                                                                             jlong jhandle,
                                                                             jlong timeout) {
  xgboost::Json wait_config{xgboost::Object{}};
  wait_config["timeout"] = xgboost::Integer{static_cast<xgboost::Integer::Int>(timeout)};
  std::string const serialized = xgboost::Json::Dump(wait_config);

  JVM_CHECK_CALL(
      XGTrackerWaitFor(reinterpret_cast<TrackerHandle>(jhandle), serialized.c_str()));
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
// JNI entry point for XGBoostJNI.TrackerWorkerArgs: obtains the worker-argument string
// from XGTrackerWorkerArgs and stores it, as a Java string, in element 0 of `jout`.
//
// The `timeout` parameter is accepted for signature compatibility but is not used:
// the XGTrackerWorkerArgs call made here takes only the handle and the output pointer.
//
// Returns 0 on success, or -1 if the Java string could not be created (the JVM then
// has an exception pending, which surfaces when control returns to Java).
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
    JNIEnv *jenv, jclass, jlong jhandle, jlong /*timeout*/, jobjectArray jout) {
  auto handle = reinterpret_cast<TrackerHandle>(jhandle);
  char const *args = nullptr;
  JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));

  jstring jret = jenv->NewStringUTF(args);
  if (jret == nullptr) {
    // NewStringUTF failed; do not touch the output array while an exception is pending.
    return -1;
  }
  jenv->SetObjectArrayElement(jout, 0, jret);
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
// JNI entry point for XGBoostJNI.TrackerFree: forwards the Java-side handle to
// XGTrackerFree. Returns 0 when the call chain completes.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
                                                                          jlong jhandle) {
  JVM_CHECK_CALL(XGTrackerFree(reinterpret_cast<TrackerHandle>(jhandle)));
  return 0;
}
|
||||
|
||||
@@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
* Method: CommunicatorFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
|
||||
jclass) {
|
||||
JVM_CHECK_CALL(XGCommunicatorFinalize());
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
|
||||
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
|
||||
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorAllreduce
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -298,7 +298,7 @@ public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
Map<String, Object> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Communicator.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
|
||||
Reference in New Issue
Block a user