Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -23,6 +23,7 @@ CONFIG = {
"USE_NCCL": "OFF",
"JVM_BINDINGS": "ON",
"LOG_CAPI_INVOCATION": "OFF",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
}
@@ -97,10 +98,6 @@ def native_build(args):
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
# if enviorment set rabit_mock
if os.getenv("RABIT_MOCK", None) is not None:
args.append("-DRABIT_MOCK:BOOL=ON")
# if enviorment set GPU_ARCH_FLAG
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
if gpu_arch_flag is not None:
@@ -162,12 +159,6 @@ def native_build(args):
maybe_makedirs(output_folder)
cp("../lib/" + library_name, output_folder)
print("copying pure-Python tracker")
cp(
"../python-package/xgboost/tracker.py",
"{}/src/main/resources".format(xgboost4j),
)
print("copying train/test files")
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
with cd("../demo/CLI/regression"):

View File

@@ -489,6 +489,11 @@
<artifactId>kryo</artifactId>
<version>5.6.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>

View File

@@ -54,9 +54,9 @@ public class XGBoost {
private final Map<String, Object> params;
private final int round;
private final Map<String, String> workerEnvs;
private final Map<String, Object> workerEnvs;
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
this.params = params;
this.round = round;
this.workerEnvs = workerEnvs;
@@ -174,9 +174,9 @@ public class XGBoost {
int numBoostRound) throws Exception {
final RabitTracker tracker =
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start(0L)) {
if (tracker.start()) {
return dtrain
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
.reduce((x, y) -> x)
.collect()
.get(0);

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession
/**
* Rabit tracker configurations.
*
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* @param timeout The number of seconds to wait for workers to connect, and for the tracker
* to shut down, before timing out.
* @param hostIp The Rabit Tracker host IP address.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
* @param port The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L)
def apply(): TrackerConf = TrackerConf(0)
}
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
@@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel {
private def buildDistributedBooster(
buildWatches: () => Watches,
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
rabitEnv: java.util.Map[String, Object],
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
@@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel {
val taskId = TaskContext.getPartitionId().toString
val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
@@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel {
}
/** visible for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker: ITracker = new RabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
tracker
}
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker = getTracker(nWorkers, trackerConf)
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
require(tracker.start(), "FAULT: Failed to start tracker")
tracker
}
@@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
val (booster, metrics) = try {
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val rabitEnv = tracker.getWorkerEnvs
tracker.workerArgs().putAll(xgbRabitParams)
val rabitEnv = tracker.workerArgs
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
var optionWatches: Option[() => Watches] = None
@@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel {
// of the training task fails the training stage can retry. ResultStage won't retry when
// it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
if (trackerReturnVal != 0) {
throw new XGBoostError("XGBoostModel training failed.")
}
(booster, metrics)
} finally {
tracker.stop()

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params {
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
*
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
* trackerImpl: String)
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
*
* See below for detailed explanations.
*
* - trackerImpl: Select the implementation of Rabit tracker.
* default: "python"
*
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
* The "scala" version removes Python components, and fully supports timeout settings.
*
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
* default: 0 millisecond (no timeout)
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
* default: 0 (no timeout)
*
* This timeout applies to constructing the communication group and to waiting for the
* tracker to shut down when it is instructed to; it does not apply to communication
* while the tracker is running.
* The timeout value should take the time of data loading and pre-processing into account,
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
* due to potential lazy execution. Alternatively, you may force Spark to
* perform data transformation before calling XGBoost.train(), so that this timeout truly
* reflects the connection delay. Set a reasonable timeout value to prevent model
* training/testing from hanging indefinitely, possibly due to network issues.
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
* cannot be automatically guessed.
*
* - port : The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
setDefault(trackerConf, TrackerConf())

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.funsuite.AnyFunSuite
@@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
xgbParamsFactory.buildXGBRuntimeParams
}
test("Customize host ip and python exec for Rabit tracker") {
val hostIp = "192.168.22.111"
val pythonExec = "/usr/bin/python3"
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.contains(hostIp))
assert(cmd.startsWith("python"))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(!cmd.contains(hostIp))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(cmd.contains(s" --host-ip=${hostIp}"))
case _ => assert(false, "expected python tracker implementation")
}
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
*/
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker. workerArgs
val workerCount: Int = numWorkers
/*
@@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
The Java RabitTracker class reacts to exceptions by killing the spawned process running
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
it is killed, the resulted connection failure will trigger the Rabit worker to call
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
running in separate containers. However, as unit tests are run in Spark local mode, in which
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
ultimately kills the entire process, which also happens to host the Spark driver, causing
the entire Spark application to crash.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
process to ensure that pending worker connections are handled.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
@@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
}
// Checks Communicator.allReduce across a group of numWorkers Spark partitions that all
// connect to a freshly started RabitTracker.
test("Communicator allreduce works.") {
// numWorkers elements spread over numWorkers partitions.
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
// Arguments obtained from the tracker; each partition passes them to Communicator.init.
val trackerEnvs = tracker.workerArgs
val workerCount: Int = numWorkers
rdd.mapPartitions { iter =>
val index = iter.next()
Communicator.init(trackerEnvs)
val a = Array(1.0f, 2.0f, 3.0f)
System.out.println(a.mkString(", "))
// Every partition contributes the same array, so SUM is expected to be a(i) * workerCount.
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
for (i <- 0 to 2) {
assert(a(i) * workerCount == b(i))
}
// With identical inputs on every partition, MIN is expected to equal the input.
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
for (i <- 0 to 2) {
assert(a(i) == c(i))
}
Communicator.shutdown()
Iterator(index)
// collect() forces the lazy mapPartitions job to run.
}.collect()
// NOTE(review): tracker.stop() is not called in this test — confirm whether cleanup is needed.
}
test("should allow the dataframe containing communicator calls to be partially evaluated for" +

View File

@@ -23,7 +23,6 @@ import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
"objective_type" -> "classification")
@@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
intercept[SparkException] {
xgb.fit(trainingDF)
}
}
test("fail training elegantly with unsupported eval metrics") {

View File

@@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
.fit(training)
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
@@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
).fit(training)
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
// check the equality of single instance prediction
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
@@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
assert(math.abs(p1 - p2) < predictionErrorMin)
}
}
test("test rabit timeout fail handle") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Communicator.mockList = Array("0,8,0,0").toList.asJava
intercept[SparkException] {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"rabit_timeout" -> 0))
.fit(training)
}
Communicator.mockList = Array.empty.toList.asJava
}
}

View File

@@ -51,6 +51,11 @@ pom_template = """
<artifactId>commons-logging</artifactId>
<version>1.2</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${{scala.binary.version}}</artifactId>

View File

@@ -7,6 +7,9 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* Collective communicator global class for synchronization.
*
@@ -30,8 +33,9 @@ public class Communicator {
}
public enum DataType implements Serializable {
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
private final int enumOp;
private final int size;
@@ -56,30 +60,20 @@ public class Communicator {
}
}
// used as way to test/debug passed communicator init parameters
public static Map<String, String> communicatorEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the collective communicator on current working thread.
*
* @param envs The additional environment variables to pass to the communicator.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
communicatorEnvs = envs;
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey();
args[idx++] = e.getValue();
public static void init(Map<String, Object> envs) throws XGBoostError {
ObjectMapper mapper = new ObjectMapper();
try {
String jconfig = mapper.writeValueAsString(envs);
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
}
// pass list of rabit mock strings eg mock=0,1,0,0
for (String mock : mockList) {
args[idx++] = "mock";
args[idx++] = mock;
}
checkCall(XGBoostJNI.CommunicatorInit(args));
}
/**

View File

@@ -1,14 +1,13 @@
package ml.dmlc.xgboost4j.java;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Interface for Rabit tracker implementations with three public methods:
* Interface for tracker implementations with three public methods:
*
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
* timeout value (in milliseconds.)
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
* - start(): Start the tracker and await worker connections. The timeout value
* (in seconds) is supplied when the tracker is constructed.
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
@@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
* brokers connections between workers.
*/
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
public interface ITracker extends Thread.UncaughtExceptionHandler {
enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
@@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
}
}
Map<String, String> getWorkerEnvs();
boolean start(long workerConnectionTimeout);
void stop();
// taskExecutionTimeout has no effect in current version of XGBoost.
int waitFor(long taskExecutionTimeout);
Map<String, Object> workerArgs() throws XGBoostError;
boolean start() throws XGBoostError;
void stop() throws XGBoostError;
void waitFor(long taskExecutionTimeout) throws XGBoostError;
}

View File

@@ -1,101 +1,40 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
* start() and waitFor() methods (i.e., the timeout is infinite.)
*
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
* provides timeout support.
*
* The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements IRabitTracker {
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
private static String tracker_py = null;
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int numWorkers;
private String hostIp = "";
private String pythonExec = "";
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
private long handle = 0;
private Thread tracker_daemon;
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
}
/**
* Tracker logger that logs output from tracker.
*/
private class TrackerProcessLogger implements Runnable {
public void run() {
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
BufferedReader reader = new BufferedReader(new InputStreamReader(
trackerProcess.get().getErrorStream()));
String line;
try {
while ((line = reader.readLine()) != null) {
trackerProcessLogger.info(line);
}
trackerProcess.get().waitFor();
int exitValue = trackerProcess.get().exitValue();
if (exitValue != 0) {
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
} else {
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
}
} catch (IOException ex) {
trackerProcessLogger.error(ex.toString());
} catch (InterruptedException ie) {
// we should not get here as RabitTracker is accessed in the main thread
ie.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}
private static void initTrackerPy() throws IOException {
try {
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
} catch (IOException ioe) {
logger.trace("cannot access tracker python script");
throw ioe;
}
}
public RabitTracker(int numWorkers)
public RabitTracker(int numWorkers, String hostIp)
throws XGBoostError {
this(numWorkers, hostIp, 0, 300);
}
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
if (numWorkers < 1) {
throw new XGBoostError("numWorkers must be greater equal to one");
}
this.numWorkers = numWorkers;
}
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
throws XGBoostError {
this(numWorkers);
this.hostIp = hostIp;
this.pythonExec = pythonExec;
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
this.handle = out[0];
}
public void uncaughtException(Thread t, Throwable e) {
@@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
trackerProcess.get().destroy();
this.tracker_daemon.interrupt();
}
}
@@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
public Map<String, String> getWorkerEnvs() {
return envs;
public Map<String, Object> workerArgs() throws XGBoostError {
// fixme: timeout
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
ObjectMapper mapper = new ObjectMapper();
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
};
Map<String, Object> config;
try {
config = mapper.readValue(args[0], typeRef);
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to get worker arguments.", ex);
}
return config;
}
private void loadEnvs(InputStream ins) throws IOException {
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break;
}
String[] sep = line.split("=");
if (sep.length == 2) {
envs.put(sep[0], sep[1]);
}
public void stop() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
}
public boolean start() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
this.tracker_daemon = new Thread(() -> {
try {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
} catch (XGBoostError ex) {
logger.error(ex);
return; // exit the thread
}
reader.close();
} catch (IOException ioe){
logger.error("cannot get runtime configuration from tracker process");
ioe.printStackTrace();
throw ioe;
}
});
this.tracker_daemon.setDaemon(true);
this.tracker_daemon.start();
return this.tracker_daemon.isAlive();
}
/** visible for testing */
public String getRabitTrackerCommand() {
StringBuilder sb = new StringBuilder();
if (pythonExec == null || pythonExec.isEmpty()) {
sb.append("python ");
} else {
sb.append(pythonExec + " ");
}
sb.append(" " + tracker_py + " ");
sb.append(" --log-level=DEBUG" + " ");
sb.append(" --num-workers=" + numWorkers + " ");
// we first check the property then check the parameter
String hostIpFromProperties = trackerProperties.getHostIp();
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
sb.append(" --host-ip=" + hostIpFromProperties + " ");
} else if (hostIp != null & !hostIp.isEmpty()) {
logger.debug("Using the parametr host-ip: " + hostIp);
sb.append(" --host-ip=" + hostIp + " ");
}
return sb.toString();
}
private boolean startTrackerProcess() {
try {
String cmd = getRabitTrackerCommand();
trackerProcess.set(Runtime.getRuntime().exec(cmd));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
ioe.printStackTrace();
return false;
}
}
public void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
}
}
public boolean start(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for all workers to connect indefinitely, unless " +
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
}
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);
logger_thread.start();
return true;
} else {
logger.error("FAULT: failed to start tracker process");
stop();
return false;
}
}
public int waitFor(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for either all workers to finish tasks and send " +
"shutdown signal, or manual interruptions. " +
"Use the Scala RabitTracker for timeout support.");
}
try {
trackerProcess.get().waitFor();
int returnVal = trackerProcess.get().exitValue();
logger.info("Tracker Process ends with exit code " + returnVal);
stop();
return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
return TrackerStatus.INTERRUPTED.getStatusCode();
}
public void waitFor(long timeout) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -54,7 +54,7 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam,
@@ -146,12 +146,24 @@ class XGBoostJNI {
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorInit(String args);
public final static native int CommunicatorFinalize();
public final static native int CommunicatorPrint(String msg);
public final static native int CommunicatorGetRank(int[] out);
public final static native int CommunicatorGetWorldSize(int[] out);
// Tracker functions
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
long[] out);
public final static native int TrackerRun(long handle);
public final static native int TrackerWaitFor(long handle, long timeout);
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
public final static native int TrackerFree(long handle);
// Perform Allreduce operation on data in sendrecvbuf.
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
@@ -168,5 +180,4 @@ class XGBoostJNI {
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
}

View File

@@ -42,5 +42,4 @@ public final class UtilUnsafe {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
}
}
}

View File

@@ -1,20 +1,21 @@
/**
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
* Copyright 2014-2024, XGBoost Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./xgboost4j.h"
#include <rabit/c_api.h>
#include <xgboost/base.h>
#include <xgboost/c_api.h>
#include <xgboost/json.h>
@@ -23,7 +24,6 @@
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
@@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
jclass jcls,
jstring jargs) {
xgboost::Json config{xgboost::Object{}};
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
assert(len % 2 == 0);
for (bst_ulong i = 0; i < len / 2; ++i) {
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
config[key_str] = xgboost::String(value_str);
const char *args = jenv->GetStringUTFChars(jargs, nullptr);
JVM_CHECK_CALL(XGCommunicatorInit(args));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
// Builds a JSON configuration from the Java arguments, creates a tracker through
// XGTrackerCreate and hands the resulting handle to setHandle together with `jout`.
// Returns 0 once every checked call has gone through.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
jlongArray jout) {
using namespace xgboost; // NOLINT
TrackerHandle handle;
Json config{Object{}};
// NOTE(review): the buffer returned by GetStringUTFChars is never passed to
// ReleaseStringUTFChars in this function - looks like a leak, confirm.
// NOTE(review): the byte count comes from GetStringLength; for non-ASCII host names
// GetStringUTFLength is presumably the matching length for a UTF-8 buffer - confirm.
std::string shost{jenv->GetStringUTFChars(host, nullptr),
static_cast<std::string::size_type>(jenv->GetStringLength(host))};
// The "host" key is only set when a non-empty host string was supplied.
if (!shost.empty()) {
config["host"] = shost;
}
// NOTE(review): the next three statements initialise the communicator in the middle of
// tracker creation; in this diff rendering they look like removed lines left over from the
// previous CommunicatorInit body rather than part of TrackerCreate - confirm against the
// actual file.
std::string json_str;
xgboost::Json::Dump(config, &json_str);
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
// Remaining tracker settings, each stored as a JSON integer.
config["port"] = Integer{static_cast<Integer::Int>(port)};
config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
// The communicator type is fixed to "rabit" here.
config["dmlc_communicator"] = String{"rabit"};
std::string sconfig = Json::Dump(config);
JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
setHandle(jenv, jout, handle);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
                                                                         jlong jhandle) {
  // Java carries the native tracker as an opaque jlong; turn it back into a handle and
  // start the tracker. The second argument of XGTrackerRun is passed as nullptr.
  JVM_CHECK_CALL(XGTrackerRun(reinterpret_cast<TrackerHandle>(jhandle), nullptr));
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
                                                                             jlong jhandle,
                                                                             jlong timeout) {
  using namespace xgboost;  // NOLINT

  // Wrap the timeout in a one-key JSON object; its serialised form is what
  // XGTrackerWaitFor receives below.
  Json wait_config{Object{}};
  wait_config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
  auto const serialized = Json::Dump(wait_config);

  // The jlong from Java is the opaque native tracker handle.
  auto tracker = reinterpret_cast<TrackerHandle>(jhandle);
  JVM_CHECK_CALL(XGTrackerWaitFor(tracker, serialized.c_str()));
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
// Fetches the worker-argument string from the tracker and stores it, as a Java string, in
// element 0 of `jout`.
//
// jhandle: opaque native tracker handle carried by the Java side as a jlong.
// timeout: accepted for signature compatibility; XGTrackerWorkerArgs as called here takes
//          no configuration argument, so the value is not forwarded.
// jout:    Java String[]; only index 0 is written.
//
// Returns 0 on success. A failing XGTrackerWorkerArgs is handled by JVM_CHECK_CALL; if the
// Java string cannot be created, -1 is returned with the JVM exception left pending.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
    JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) {
  // The previous body built a JSON config from the timeout and parsed `args` into a Json
  // value, but never used either result; both were dead work and have been dropped.
  static_cast<void>(timeout);

  auto handle = reinterpret_cast<TrackerHandle>(jhandle);
  char const *args = nullptr;
  JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));

  jstring jret = jenv->NewStringUTF(args);
  if (jret == nullptr) {
    // NewStringUTF could not build the string and has left an exception pending; further
    // JNI calls are not allowed in that state, so return and let the JVM raise it.
    return -1;
  }
  jenv->SetObjectArrayElement(jout, 0, jret);
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
                                                                          jlong jhandle) {
  // Convert the jlong held by the Java object back into the native handle and release it.
  JVM_CHECK_CALL(XGTrackerFree(reinterpret_cast<TrackerHandle>(jhandle)));
  return 0;
}
@@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
* Method: CommunicatorFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
(JNIEnv *jenv, jclass jcls) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
jclass) {
JVM_CHECK_CALL(XGCommunicatorFinalize());
return 0;
}

View File

@@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *, jclass, jobjectArray);
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorAllreduce

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -298,7 +298,7 @@ public class DMatrixTest {
@Test
public void testTrainWithDenseMatrixRef() throws XGBoostError {
Map<String, String> rabitEnv = new HashMap<>();
Map<String, Object> rabitEnv = new HashMap<>();
rabitEnv.put("DMLC_TASK_ID", "0");
Communicator.init(rabitEnv);
DMatrix trainMat = null;