Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhausted tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -7,6 +7,9 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* Collective communicator global class for synchronization.
*
@@ -30,8 +33,9 @@ public class Communicator {
}
public enum DataType implements Serializable {
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
private final int enumOp;
private final int size;
@@ -56,30 +60,20 @@ public class Communicator {
}
}
// used as way to test/debug passed communicator init parameters
public static Map<String, String> communicatorEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the collective communicator on current working thread.
*
* @param envs The additional environment variables to pass to the communicator.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
communicatorEnvs = envs;
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey();
args[idx++] = e.getValue();
public static void init(Map<String, Object> envs) throws XGBoostError {
ObjectMapper mapper = new ObjectMapper();
try {
String jconfig = mapper.writeValueAsString(envs);
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
}
// pass list of rabit mock strings eg mock=0,1,0,0
for (String mock : mockList) {
args[idx++] = "mock";
args[idx++] = mock;
}
checkCall(XGBoostJNI.CommunicatorInit(args));
}
/**

View File

@@ -1,14 +1,13 @@
package ml.dmlc.xgboost4j.java;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Interface for Rabit tracker implementations with three public methods:
* Interface for a tracker implementations with three public methods:
*
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
* timeout value (in milliseconds.)
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
* - start(timeout): Start the tracker awaiting for worker connections, with a given
* timeout value (in seconds).
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
@@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
* brokers connections between workers.
*/
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
public interface ITracker extends Thread.UncaughtExceptionHandler {
enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
@@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
}
}
Map<String, String> getWorkerEnvs();
boolean start(long workerConnectionTimeout);
void stop();
// taskExecutionTimeout has no effect in current version of XGBoost.
int waitFor(long taskExecutionTimeout);
Map<String, Object> workerArgs() throws XGBoostError;
boolean start() throws XGBoostError;
void stop() throws XGBoostError;
void waitFor(long taskExecutionTimeout) throws XGBoostError;
}

View File

@@ -1,101 +1,40 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
* start() and waitFor() methods (i.e., the timeout is infinite.)
*
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
* provides timeout support.
*
* The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements IRabitTracker {
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
private static String tracker_py = null;
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int numWorkers;
private String hostIp = "";
private String pythonExec = "";
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
private long handle = 0;
private Thread tracker_daemon;
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
}
/**
* Tracker logger that logs output from tracker.
*/
private class TrackerProcessLogger implements Runnable {
public void run() {
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
BufferedReader reader = new BufferedReader(new InputStreamReader(
trackerProcess.get().getErrorStream()));
String line;
try {
while ((line = reader.readLine()) != null) {
trackerProcessLogger.info(line);
}
trackerProcess.get().waitFor();
int exitValue = trackerProcess.get().exitValue();
if (exitValue != 0) {
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
} else {
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
}
} catch (IOException ex) {
trackerProcessLogger.error(ex.toString());
} catch (InterruptedException ie) {
// we should not get here as RabitTracker is accessed in the main thread
ie.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}
private static void initTrackerPy() throws IOException {
try {
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
} catch (IOException ioe) {
logger.trace("cannot access tracker python script");
throw ioe;
}
}
public RabitTracker(int numWorkers)
public RabitTracker(int numWorkers, String hostIp)
throws XGBoostError {
this(numWorkers, hostIp, 0, 300);
}
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
if (numWorkers < 1) {
throw new XGBoostError("numWorkers must be greater equal to one");
}
this.numWorkers = numWorkers;
}
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
throws XGBoostError {
this(numWorkers);
this.hostIp = hostIp;
this.pythonExec = pythonExec;
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
this.handle = out[0];
}
public void uncaughtException(Thread t, Throwable e) {
@@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
trackerProcess.get().destroy();
this.tracker_daemon.interrupt();
}
}
@@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
public Map<String, String> getWorkerEnvs() {
return envs;
public Map<String, Object> workerArgs() throws XGBoostError {
// fixme: timeout
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
ObjectMapper mapper = new ObjectMapper();
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
};
Map<String, Object> config;
try {
config = mapper.readValue(args[0], typeRef);
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to get worker arguments.", ex);
}
return config;
}
private void loadEnvs(InputStream ins) throws IOException {
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break;
}
String[] sep = line.split("=");
if (sep.length == 2) {
envs.put(sep[0], sep[1]);
}
public void stop() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
}
public boolean start() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
this.tracker_daemon = new Thread(() -> {
try {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
} catch (XGBoostError ex) {
logger.error(ex);
return; // exit the thread
}
reader.close();
} catch (IOException ioe){
logger.error("cannot get runtime configuration from tracker process");
ioe.printStackTrace();
throw ioe;
}
});
this.tracker_daemon.setDaemon(true);
this.tracker_daemon.start();
return this.tracker_daemon.isAlive();
}
/** visible for testing */
public String getRabitTrackerCommand() {
StringBuilder sb = new StringBuilder();
if (pythonExec == null || pythonExec.isEmpty()) {
sb.append("python ");
} else {
sb.append(pythonExec + " ");
}
sb.append(" " + tracker_py + " ");
sb.append(" --log-level=DEBUG" + " ");
sb.append(" --num-workers=" + numWorkers + " ");
// we first check the property then check the parameter
String hostIpFromProperties = trackerProperties.getHostIp();
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
sb.append(" --host-ip=" + hostIpFromProperties + " ");
} else if (hostIp != null & !hostIp.isEmpty()) {
logger.debug("Using the parametr host-ip: " + hostIp);
sb.append(" --host-ip=" + hostIp + " ");
}
return sb.toString();
}
private boolean startTrackerProcess() {
try {
String cmd = getRabitTrackerCommand();
trackerProcess.set(Runtime.getRuntime().exec(cmd));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
ioe.printStackTrace();
return false;
}
}
public void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
}
}
public boolean start(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for all workers to connect indefinitely, unless " +
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
}
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);
logger_thread.start();
return true;
} else {
logger.error("FAULT: failed to start tracker process");
stop();
return false;
}
}
public int waitFor(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for either all workers to finish tasks and send " +
"shutdown signal, or manual interruptions. " +
"Use the Scala RabitTracker for timeout support.");
}
try {
trackerProcess.get().waitFor();
int returnVal = trackerProcess.get().exitValue();
logger.info("Tracker Process ends with exit code " + returnVal);
stop();
return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
return TrackerStatus.INTERRUPTED.getStatusCode();
}
public void waitFor(long timeout) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -54,7 +54,7 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam,
@@ -146,12 +146,24 @@ class XGBoostJNI {
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorInit(String args);
public final static native int CommunicatorFinalize();
public final static native int CommunicatorPrint(String msg);
public final static native int CommunicatorGetRank(int[] out);
public final static native int CommunicatorGetWorldSize(int[] out);
// Tracker functions
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
long[] out);
public final static native int TrackerRun(long handle);
public final static native int TrackerWaitFor(long handle, long timeout);
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
public final static native int TrackerFree(long handle);
// Perform Allreduce operation on data in sendrecvbuf.
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
@@ -168,5 +180,4 @@ class XGBoostJNI {
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
}

View File

@@ -42,5 +42,4 @@ public final class UtilUnsafe {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
}
}
}

View File

@@ -1,20 +1,21 @@
/**
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
* Copyright 2014-2024, XGBoost Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./xgboost4j.h"
#include <rabit/c_api.h>
#include <xgboost/base.h>
#include <xgboost/c_api.h>
#include <xgboost/json.h>
@@ -23,7 +24,6 @@
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
@@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
jclass jcls,
jstring jargs) {
xgboost::Json config{xgboost::Object{}};
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
assert(len % 2 == 0);
for (bst_ulong i = 0; i < len / 2; ++i) {
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
config[key_str] = xgboost::String(value_str);
const char *args = jenv->GetStringUTFChars(jargs, nullptr);
JVM_CHECK_CALL(XGCommunicatorInit(args));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
jlongArray jout) {
using namespace xgboost; // NOLINT
TrackerHandle handle;
Json config{Object{}};
std::string shost{jenv->GetStringUTFChars(host, nullptr),
static_cast<std::string::size_type>(jenv->GetStringLength(host))};
if (!shost.empty()) {
config["host"] = shost;
}
std::string json_str;
xgboost::Json::Dump(config, &json_str);
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
config["port"] = Integer{static_cast<Integer::Int>(port)};
config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
config["dmlc_communicator"] = String{"rabit"};
std::string sconfig = Json::Dump(config);
JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
setHandle(jenv, jout, handle);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
jlong jhandle) {
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
JVM_CHECK_CALL(XGTrackerRun(handle, nullptr));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
jlong jhandle,
jlong timeout) {
using namespace xgboost; // NOLINT
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
Json config{Object{}};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
std::string sconfig = Json::Dump(config);
JVM_CHECK_CALL(XGTrackerWaitFor(handle, sconfig.c_str()));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) {
using namespace xgboost; // NOLINT
Json config{Object{}};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
std::string sconfig = Json::Dump(config);
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
char const *args;
JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));
auto jargs = Json::Load(StringView{args});
jstring jret = jenv->NewStringUTF(args);
jenv->SetObjectArrayElement(jout, 0, jret);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
jlong jhandle) {
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
JVM_CHECK_CALL(XGTrackerFree(handle));
return 0;
}
@@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
* Method: CommunicatorFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
(JNIEnv *jenv, jclass jcls) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
jclass) {
JVM_CHECK_CALL(XGCommunicatorFinalize());
return 0;
}

View File

@@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *, jclass, jobjectArray);
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorAllreduce

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -298,7 +298,7 @@ public class DMatrixTest {
@Test
public void testTrainWithDenseMatrixRef() throws XGBoostError {
Map<String, String> rabitEnv = new HashMap<>();
Map<String, Object> rabitEnv = new HashMap<>();
rabitEnv.put("DMLC_TASK_ID", "0");
Communicator.init(rabitEnv);
DMatrix trainMat = null;