JNI wrapper for the collective communicator (#8242)
This commit is contained in:
parent fffb1fca52
commit 7d43e74e71
@ -0,0 +1,276 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.util.concurrent.LinkedBlockingDeque

import scala.util.Random

import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.FunSuite

class CommunicatorRobustnessSuite extends FunSuite with PerTest {

  private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
    val classifier = new XGBoostClassifier(paramMap)
    val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
    xgbParamsFactory.buildXGBRuntimeParams
  }

  test("Customize host ip and python exec for Rabit tracker") {
    val hostIp = "192.168.22.111"
    val pythonExec = "/usr/bin/python3"

    val paramMap = Map(
      "num_workers" -> numWorkers,
      "tracker_conf" -> TrackerConf(0L, "python", hostIp))
    val xgbExecParams = getXGBoostExecutionParams(paramMap)
    val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
    tracker match {
      case pyTracker: PyRabitTracker =>
        val cmd = pyTracker.getRabitTrackerCommand
        assert(cmd.contains(hostIp))
        assert(cmd.startsWith("python"))
      case _ => assert(false, "expected python tracker implementation")
    }

    val paramMap1 = Map(
      "num_workers" -> numWorkers,
      "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
    val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
    val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
    tracker1 match {
      case pyTracker: PyRabitTracker =>
        val cmd = pyTracker.getRabitTrackerCommand
        assert(cmd.startsWith(pythonExec))
        assert(!cmd.contains(hostIp))
      case _ => assert(false, "expected python tracker implementation")
    }

    val paramMap2 = Map(
      "num_workers" -> numWorkers,
      "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
    val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
    val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
    tracker2 match {
      case pyTracker: PyRabitTracker =>
        val cmd = pyTracker.getRabitTrackerCommand
        assert(cmd.startsWith(pythonExec))
        assert(cmd.contains(s" --host-ip=${hostIp}"))
      case _ => assert(false, "expected python tracker implementation")
    }
  }

  test("training with Scala-implemented Rabit tracker") {
    val eval = new EvalError()
    val training = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)
    val paramMap = Map("eta" -> "1", "max_depth" -> "6",
      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
    val model = new XGBoostClassifier(paramMap).fit(training)
    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
  }

  test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
    val vectorLength = 100
    val rdd = sc.parallelize(
      (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()

    val tracker = new ScalaRabitTracker(numWorkers)
    tracker.start(0)
    val trackerEnvs = tracker.getWorkerEnvs
    val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()

    val rawData = rdd.mapPartitions { iter =>
      Iterator(iter.toArray)
    }.collect()

    val maxVec = (0 until vectorLength).toArray.map { j =>
      (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
    }

    val allReduceResults = rdd.mapPartitions { iter =>
      Communicator.init(trackerEnvs)
      val arr = iter.toArray
      val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
      Communicator.shutdown()
      Iterator(results)
    }.cache()

    val sparkThread = new Thread() {
      override def run(): Unit = {
        allReduceResults.foreachPartition(() => _)
        val byPartitionResults = allReduceResults.collect()
        assert(byPartitionResults(0).length == vectorLength)
        collectedAllReduceResults.put(byPartitionResults(0))
      }
    }
    sparkThread.start()
    assert(tracker.waitFor(0L) == 0)
    sparkThread.join()

    assert(collectedAllReduceResults.poll().sameElements(maxVec))
  }

test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||||
|
/*
|
||||||
|
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||||
|
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
|
||||||
|
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
|
||||||
|
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
|
||||||
|
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
|
||||||
|
that should be avoided.
|
||||||
|
*/
|
||||||
|
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new PyRabitTracker(numWorkers)
|
||||||
|
tracker.start(0)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
|
||||||
|
val workerCount: Int = numWorkers
|
||||||
|
/*
|
||||||
|
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
|
||||||
|
last created worker. A cascading event chain will be triggered once the RuntimeException is
|
||||||
|
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||||
|
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||||
|
|
||||||
|
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||||
|
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||||
|
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||||
|
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||||
|
|
||||||
|
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||||
|
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||||
|
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||||
|
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||||
|
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||||
|
the entire Spark application to crash.
|
||||||
|
|
||||||
|
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||||
|
the exception is thrown at last, ideally after all worker connections have been established.
|
||||||
|
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||||
|
process to ensure that pending worker connections are handled.
|
||||||
|
*/
|
||||||
|
val dummyTasks = rdd.mapPartitions { iter =>
|
||||||
|
Communicator.init(trackerEnvs)
|
||||||
|
val index = iter.next()
|
||||||
|
Thread.sleep(100 + index * 10)
|
||||||
|
if (index == workerCount) {
|
||||||
|
// kill the worker by throwing an exception
|
||||||
|
throw new RuntimeException("Worker exception.")
|
||||||
|
}
|
||||||
|
Communicator.shutdown()
|
||||||
|
Iterator(index)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
// forces a Spark job.
|
||||||
|
dummyTasks.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkThread.start()
|
||||||
|
assert(tracker.waitFor(0) != 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||||
|
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new ScalaRabitTracker(numWorkers)
|
||||||
|
tracker.start(0)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
|
||||||
|
val workerCount: Int = numWorkers
|
||||||
|
val dummyTasks = rdd.mapPartitions { iter =>
|
||||||
|
Communicator.init(trackerEnvs)
|
||||||
|
val index = iter.next()
|
||||||
|
Thread.sleep(100 + index * 10)
|
||||||
|
if (index == workerCount) {
|
||||||
|
// kill the worker by throwing an exception
|
||||||
|
throw new RuntimeException("Worker exception.")
|
||||||
|
}
|
||||||
|
Communicator.shutdown()
|
||||||
|
Iterator(index)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
// forces a Spark job.
|
||||||
|
dummyTasks.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkThread.start()
|
||||||
|
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||||
|
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new ScalaRabitTracker(numWorkers)
|
||||||
|
tracker.start(500)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
|
||||||
|
val dummyTasks = rdd.mapPartitions { iter =>
|
||||||
|
val index = iter.next()
|
||||||
|
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||||
|
if (index != 1) {
|
||||||
|
Communicator.init(trackerEnvs)
|
||||||
|
Thread.sleep(1000)
|
||||||
|
Communicator.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
Iterator(index)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
// forces a Spark job.
|
||||||
|
dummyTasks.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkThread.start()
|
||||||
|
// should fail due to connection timeout
|
||||||
|
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
|
||||||
|
" multiple times (ISSUE-4406)") {
|
||||||
|
val paramMap = Map(
|
||||||
|
"eta" -> "1",
|
||||||
|
"max_depth" -> "6",
|
||||||
|
"silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
val trainingDF = buildDataFrame(Classification.train)
|
||||||
|
val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
|
||||||
|
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||||
|
val prediction = model.transform(trainingDF)
|
||||||
|
// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
|
||||||
|
// threads
|
||||||
|
prediction.show()
|
||||||
|
// a full evaluation here will re-run init and shutdown all rabit proxy
|
||||||
|
// expecting no error
|
||||||
|
prediction.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
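For orientation (not part of the commit): a minimal, single-process Java sketch of the pattern the suite above exercises, using only calls that appear in this diff. The class name CommunicatorExample and the one-worker setup are illustrative assumptions; in the tests the same init/allReduce/shutdown sequence runs inside rdd.mapPartitions with one worker per partition while the tracker runs on the driver.

import java.util.Map;

import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.RabitTracker;

// Illustrative only: a single worker in the current process, so the tracker and
// the worker can share one thread without blocking each other.
public class CommunicatorExample {
  public static void main(String[] args) throws Exception {
    RabitTracker tracker = new RabitTracker(1);   // expect exactly one worker
    tracker.start(0);                             // same timeout argument the tests use
    Map<String, String> trackerEnvs = tracker.getWorkerEnvs();

    Communicator.init(trackerEnvs);               // bind this thread to the tracker
    float[] local = {1.0f, 2.0f, 3.0f};
    float[] reduced = Communicator.allReduce(local, Communicator.OpType.MAX);
    System.out.println("rank " + Communicator.getRank()
        + "/" + Communicator.getWorldSize() + " -> " + reduced[0]);
    Communicator.shutdown();

    tracker.waitFor(0L);                          // block until the tracker reports completion
  }
}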
@ -0,0 +1,152 @@
package ml.dmlc.xgboost4j.java;

import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Collective communicator global class for synchronization.
 *
 * Currently the communicator API is experimental; function signatures may change in the future
 * without notice.
 */
public class Communicator {

  public enum OpType implements Serializable {
    MAX(0), MIN(1), SUM(2);

    private int op;

    public int getOperand() {
      return this.op;
    }

    OpType(int op) {
      this.op = op;
    }
  }

  public enum DataType implements Serializable {
    INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
    INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);

    private final int enumOp;
    private final int size;

    public int getEnumOp() {
      return this.enumOp;
    }

    public int getSize() {
      return this.size;
    }

    DataType(int enumOp, int size) {
      this.enumOp = enumOp;
      this.size = size;
    }
  }

  private static void checkCall(int ret) throws XGBoostError {
    if (ret != 0) {
      throw new XGBoostError(XGBoostJNI.XGBGetLastError());
    }
  }

  // used as a way to test/debug the communicator init parameters that were passed
  public static Map<String, String> communicatorEnvs;
  public static List<String> mockList = new LinkedList<>();

  /**
   * Initialize the collective communicator on the current working thread.
   *
   * @param envs The additional environment variables to pass to the communicator.
   * @throws XGBoostError
   */
  public static void init(Map<String, String> envs) throws XGBoostError {
    communicatorEnvs = envs;
    String[] args = new String[envs.size() * 2 + mockList.size() * 2];
    int idx = 0;
    for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
      args[idx++] = e.getKey();
      args[idx++] = e.getValue();
    }
    // pass the list of rabit mock strings, e.g. mock=0,1,0,0
    for (String mock : mockList) {
      args[idx++] = "mock";
      args[idx++] = mock;
    }
    checkCall(XGBoostJNI.CommunicatorInit(args));
  }

  /**
   * Shut down the communicator on the current working thread; equivalent to finalize.
   *
   * @throws XGBoostError
   */
  public static void shutdown() throws XGBoostError {
    checkCall(XGBoostJNI.CommunicatorFinalize());
  }

  /**
   * Print the message via the communicator.
   *
   * @param msg
   * @throws XGBoostError
   */
  public static void communicatorPrint(String msg) throws XGBoostError {
    checkCall(XGBoostJNI.CommunicatorPrint(msg));
  }

  /**
   * Get the rank of the current thread.
   *
   * @return the rank.
   * @throws XGBoostError
   */
  public static int getRank() throws XGBoostError {
    int[] out = new int[1];
    checkCall(XGBoostJNI.CommunicatorGetRank(out));
    return out[0];
  }

  /**
   * Get the world size of the current job.
   *
   * @return the world size.
   * @throws XGBoostError
   */
  public static int getWorldSize() throws XGBoostError {
    int[] out = new int[1];
    checkCall(XGBoostJNI.CommunicatorGetWorldSize(out));
    return out[0];
  }

  /**
   * Perform Allreduce on distributed float vectors using the given operator.
   *
   * @param elements local elements on distributed workers.
   * @param op       operator used for Allreduce.
   * @return all-reduced float elements according to the given operator.
   */
  public static float[] allReduce(float[] elements, OpType op) {
    DataType dataType = DataType.FLOAT32;
    ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
        .order(ByteOrder.nativeOrder());

    for (float el : elements) {
      buffer.putFloat(el);
    }
    buffer.flip();

    XGBoostJNI.CommunicatorAllreduce(buffer, elements.length, dataType.getEnumOp(),
        op.getOperand());
    float[] results = new float[elements.length];
    buffer.asFloatBuffer().get(results);

    return results;
  }
}
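A note on the allReduce marshalling above: the float payload travels through a direct, native-ordered ByteBuffer because the JNI implementation resolves the memory with GetDirectBufferAddress and the native allreduce overwrites that memory in place. Below is a small, self-contained sketch of just the packing and unpacking step, with no native call involved; the class name is illustrative.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class DirectBufferRoundTrip {
  public static void main(String[] args) {
    float[] elements = {0.5f, 1.5f, 2.5f};

    // Must be direct: the JNI side reads this memory via GetDirectBufferAddress,
    // so a heap-backed buffer would not expose a usable native address.
    ByteBuffer buffer = ByteBuffer.allocateDirect(Float.BYTES * elements.length)
        .order(ByteOrder.nativeOrder());   // native order: C++ reads the raw bytes as float
    for (float el : elements) {
      buffer.putFloat(el);
    }
    buffer.flip();

    // After the native allreduce the same bytes hold the reduced values;
    // reading them back mirrors the tail of Communicator.allReduce.
    float[] results = new float[elements.length];
    buffer.asFloatBuffer().get(results);
    System.out.println(Arrays.toString(results));   // here: the original values, unchanged
  }
}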
@ -148,6 +148,17 @@ class XGBoostJNI {
  final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
                                         int enum_dtype, int enum_op);

  // communicator functions
  public final static native int CommunicatorInit(String[] args);
  public final static native int CommunicatorFinalize();
  public final static native int CommunicatorPrint(String msg);
  public final static native int CommunicatorGetRank(int[] out);
  public final static native int CommunicatorGetWorldSize(int[] out);

  // Perform Allreduce operation on data in sendrecvbuf.
  final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
                                                int enum_dtype, int enum_op);

  public final static native int XGDMatrixSetInfoFromInterface(
      long handle, String field, String json);

@ -977,6 +977,89 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
  return 0;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorInit
 * Signature: ([Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
  (JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
  xgboost::Json config{xgboost::Object{}};
  bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
  assert(len % 2 == 0);
  for (bst_ulong i = 0; i < len / 2; ++i) {
    jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
    std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
    jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
    std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
    config[key_str] = xgboost::String(value_str);
  }
  std::string json_str;
  xgboost::Json::Dump(config, &json_str);
  JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
  return 0;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorFinalize
 * Signature: ()I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
  (JNIEnv *jenv, jclass jcls) {
  JVM_CHECK_CALL(XGCommunicatorFinalize());
  return 0;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorPrint
 * Signature: (Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
  (JNIEnv *jenv, jclass jcls, jstring jmsg) {
  std::string str(jenv->GetStringUTFChars(jmsg, 0),
                  jenv->GetStringLength(jmsg));
  JVM_CHECK_CALL(XGCommunicatorPrint(str.c_str()));
  return 0;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorGetRank
 * Signature: ([I)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
  (JNIEnv *jenv, jclass jcls, jintArray jout) {
  jint rank = XGCommunicatorGetRank();
  jenv->SetIntArrayRegion(jout, 0, 1, &rank);
  return 0;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorGetWorldSize
 * Signature: ([I)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
  (JNIEnv *jenv, jclass jcls, jintArray jout) {
  jint out = XGCommunicatorGetWorldSize();
  jenv->SetIntArrayRegion(jout, 0, 1, &out);
  return 0;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorAllreduce
 * Signature: (Ljava/nio/ByteBuffer;III)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
  (JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
  void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
  JVM_CHECK_CALL(XGCommunicatorAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op));
  return 0;
}

namespace xgboost {
namespace jni {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,

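The CommunicatorInit binding above pairs up the flat key/value argument array produced by Communicator.init and folds it into a JSON document for XGCommunicatorInit. A hedged sketch of that mapping from the Java side follows; the DMLC_TRACKER_URI and DMLC_TRACKER_PORT keys are shown only as assumed example entries, since the real keys come from RabitTracker.getWorkerEnvs().

import java.util.HashMap;
import java.util.Map;

public class CommunicatorInitArgsSketch {
  public static void main(String[] args) {
    // Assumed example entries; in practice this map comes from the tracker.
    Map<String, String> envs = new HashMap<>();
    envs.put("DMLC_TRACKER_URI", "192.168.22.111");
    envs.put("DMLC_TRACKER_PORT", "9091");

    // The same flattening Communicator.init performs before calling
    // XGBoostJNI.CommunicatorInit(args): alternating key, value, key, value, ...
    String[] flat = new String[envs.size() * 2];
    int idx = 0;
    for (Map.Entry<String, String> e : envs.entrySet()) {
      flat[idx++] = e.getKey();
      flat[idx++] = e.getValue();
    }
    // The JNI side pairs the entries back up and emits a JSON object, e.g.
    // {"DMLC_TRACKER_URI": "192.168.22.111", "DMLC_TRACKER_PORT": "9091"},
    // which is the string handed to XGCommunicatorInit.
    System.out.println(String.join(" ", flat));
  }
}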
@ -335,6 +335,54 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
  (JNIEnv *, jclass, jobject, jint, jint, jint);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorInit
 * Signature: ([Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
  (JNIEnv *, jclass, jobjectArray);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorFinalize
 * Signature: ()I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
  (JNIEnv *, jclass);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorPrint
 * Signature: (Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
  (JNIEnv *, jclass, jstring);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorGetRank
 * Signature: ([I)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
  (JNIEnv *, jclass, jintArray);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorGetWorldSize
 * Signature: ([I)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
  (JNIEnv *, jclass, jintArray);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorAllreduce
 * Signature: (Ljava/nio/ByteBuffer;III)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
  (JNIEnv *, jclass, jobject, jint, jint, jint);

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixSetInfoFromInterface