Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -7,6 +7,9 @@ import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* Collective communicator global class for synchronization.
|
||||
*
|
||||
@@ -30,8 +33,9 @@ public class Communicator {
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
|
||||
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
|
||||
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
|
||||
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
|
||||
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
|
||||
|
||||
private final int enumOp;
|
||||
private final int size;
|
||||
@@ -56,30 +60,20 @@ public class Communicator {
|
||||
}
|
||||
}
|
||||
|
||||
// used as way to test/debug passed communicator init parameters
|
||||
public static Map<String, String> communicatorEnvs;
|
||||
public static List<String> mockList = new LinkedList<>();
|
||||
|
||||
/**
|
||||
* Initialize the collective communicator on current working thread.
|
||||
*
|
||||
* @param envs The additional environment variables to pass to the communicator.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
communicatorEnvs = envs;
|
||||
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey();
|
||||
args[idx++] = e.getValue();
|
||||
public static void init(Map<String, Object> envs) throws XGBoostError {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
String jconfig = mapper.writeValueAsString(envs);
|
||||
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
|
||||
}
|
||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
||||
for (String mock : mockList) {
|
||||
args[idx++] = "mock";
|
||||
args[idx++] = mock;
|
||||
}
|
||||
checkCall(XGBoostJNI.CommunicatorInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Interface for Rabit tracker implementations with three public methods:
|
||||
* Interface for tracker implementations with three public methods:
|
||||
*
|
||||
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
|
||||
* timeout value (in milliseconds.)
|
||||
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
|
||||
* - start(): Start the tracker awaiting worker connections; the connection timeout
|
||||
* (in seconds) is configured when the tracker is constructed.
|
||||
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
@@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
|
||||
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
public interface ITracker extends Thread.UncaughtExceptionHandler {
|
||||
enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
@@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, String> getWorkerEnvs();
|
||||
boolean start(long workerConnectionTimeout);
|
||||
void stop();
|
||||
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||
int waitFor(long taskExecutionTimeout);
|
||||
Map<String, Object> workerArgs() throws XGBoostError;
|
||||
|
||||
boolean start() throws XGBoostError;
|
||||
|
||||
void stop() throws XGBoostError;
|
||||
|
||||
void waitFor(long taskExecutionTimeout) throws XGBoostError;
|
||||
}
|
||||
@@ -1,101 +1,40 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
|
||||
* start() and waitFor() methods (i.e., the timeout is infinite.)
|
||||
*
|
||||
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
|
||||
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
|
||||
* provides timeout support.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker implements IRabitTracker {
|
||||
public class RabitTracker implements ITracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
private static String tracker_py = null;
|
||||
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
|
||||
// environment variable to be pased.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int numWorkers;
|
||||
private String hostIp = "";
|
||||
private String pythonExec = "";
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
private long handle = 0;
|
||||
private Thread tracker_daemon;
|
||||
|
||||
static {
|
||||
try {
|
||||
initTrackerPy();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load tracker library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
public RabitTracker(int numWorkers) throws XGBoostError {
|
||||
this(numWorkers, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracker logger that logs output from tracker.
|
||||
*/
|
||||
private class TrackerProcessLogger implements Runnable {
|
||||
public void run() {
|
||||
|
||||
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||
trackerProcess.get().getErrorStream()));
|
||||
String line;
|
||||
try {
|
||||
while ((line = reader.readLine()) != null) {
|
||||
trackerProcessLogger.info(line);
|
||||
}
|
||||
trackerProcess.get().waitFor();
|
||||
int exitValue = trackerProcess.get().exitValue();
|
||||
if (exitValue != 0) {
|
||||
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
|
||||
} else {
|
||||
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
trackerProcessLogger.error(ex.toString());
|
||||
} catch (InterruptedException ie) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
ie.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void initTrackerPy() throws IOException {
|
||||
try {
|
||||
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||
} catch (IOException ioe) {
|
||||
logger.trace("cannot access tracker python script");
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers)
|
||||
public RabitTracker(int numWorkers, String hostIp)
|
||||
throws XGBoostError {
|
||||
this(numWorkers, hostIp, 0, 300);
|
||||
}
|
||||
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||
}
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
|
||||
throws XGBoostError {
|
||||
this(numWorkers);
|
||||
this.hostIp = hostIp;
|
||||
this.pythonExec = pythonExec;
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
|
||||
this.handle = out[0];
|
||||
}
|
||||
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
@@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
trackerProcess.get().destroy();
|
||||
this.tracker_daemon.interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
*/
|
||||
public Map<String, String> getWorkerEnvs() {
|
||||
return envs;
|
||||
public Map<String, Object> workerArgs() throws XGBoostError {
|
||||
// fixme: timeout
|
||||
String[] args = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
|
||||
};
|
||||
Map<String, Object> config;
|
||||
try {
|
||||
config = mapper.readValue(args[0], typeRef);
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to get worker arguments.", ex);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private void loadEnvs(InputStream ins) throws IOException {
|
||||
try {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||
break;
|
||||
}
|
||||
String[] sep = line.split("=");
|
||||
if (sep.length == 2) {
|
||||
envs.put(sep[0], sep[1]);
|
||||
}
|
||||
public void stop() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
|
||||
}
|
||||
|
||||
public boolean start() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
|
||||
this.tracker_daemon = new Thread(() -> {
|
||||
try {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return; // exit the thread
|
||||
}
|
||||
reader.close();
|
||||
} catch (IOException ioe){
|
||||
logger.error("cannot get runtime configuration from tracker process");
|
||||
ioe.printStackTrace();
|
||||
throw ioe;
|
||||
}
|
||||
});
|
||||
this.tracker_daemon.setDaemon(true);
|
||||
this.tracker_daemon.start();
|
||||
|
||||
return this.tracker_daemon.isAlive();
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
public String getRabitTrackerCommand() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (pythonExec == null || pythonExec.isEmpty()) {
|
||||
sb.append("python ");
|
||||
} else {
|
||||
sb.append(pythonExec + " ");
|
||||
}
|
||||
sb.append(" " + tracker_py + " ");
|
||||
sb.append(" --log-level=DEBUG" + " ");
|
||||
sb.append(" --num-workers=" + numWorkers + " ");
|
||||
|
||||
// we first check the property then check the parameter
|
||||
String hostIpFromProperties = trackerProperties.getHostIp();
|
||||
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
|
||||
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
|
||||
sb.append(" --host-ip=" + hostIpFromProperties + " ");
|
||||
} else if (hostIp != null & !hostIp.isEmpty()) {
|
||||
logger.debug("Using the parametr host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=" + hostIp + " ");
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
String cmd = getRabitTrackerCommand();
|
||||
trackerProcess.set(Runtime.getRuntime().exec(cmd));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean start(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for all workers to connect indefinitely, unless " +
|
||||
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
System.out.println("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
logger_thread.start();
|
||||
return true;
|
||||
} else {
|
||||
logger.error("FAULT: failed to start tracker process");
|
||||
stop();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public int waitFor(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for either all workers to finish tasks and send " +
|
||||
"shutdown signal, or manual interruptions. " +
|
||||
"Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
try {
|
||||
trackerProcess.get().waitFor();
|
||||
int returnVal = trackerProcess.get().exitValue();
|
||||
logger.info("Tracker Process ends with exit code " + returnVal);
|
||||
stop();
|
||||
return returnVal;
|
||||
} catch (InterruptedException e) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
e.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
return TrackerStatus.INTERRUPTED.getStatusCode();
|
||||
}
|
||||
public void waitFor(long timeout) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -54,7 +54,7 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
|
||||
float[] data, int shapeParam,
|
||||
@@ -146,12 +146,24 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorInit(String args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
public final static native int CommunicatorPrint(String msg);
|
||||
public final static native int CommunicatorGetRank(int[] out);
|
||||
public final static native int CommunicatorGetWorldSize(int[] out);
|
||||
|
||||
// Tracker functions
|
||||
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
|
||||
long[] out);
|
||||
|
||||
public final static native int TrackerRun(long handle);
|
||||
|
||||
public final static native int TrackerWaitFor(long handle, long timeout);
|
||||
|
||||
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
|
||||
|
||||
public final static native int TrackerFree(long handle);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
@@ -168,5 +180,4 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
}
|
||||
|
||||
@@ -42,5 +42,4 @@ public final class UtilUnsafe {
|
||||
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
/**
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "./xgboost4j.h"
|
||||
|
||||
#include <rabit/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/json.h>
|
||||
@@ -23,7 +24,6 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
|
||||
jclass jcls,
|
||||
jstring jargs) {
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
assert(len % 2 == 0);
|
||||
for (bst_ulong i = 0; i < len / 2; ++i) {
|
||||
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
|
||||
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
|
||||
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
|
||||
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
|
||||
config[key_str] = xgboost::String(value_str);
|
||||
const char *args = jenv->GetStringUTFChars(jargs, nullptr);
|
||||
JVM_CHECK_CALL(XGCommunicatorInit(args));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
|
||||
JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
|
||||
jlongArray jout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
TrackerHandle handle;
|
||||
Json config{Object{}};
|
||||
std::string shost{jenv->GetStringUTFChars(host, nullptr),
|
||||
static_cast<std::string::size_type>(jenv->GetStringLength(host))};
|
||||
if (!shost.empty()) {
|
||||
config["host"] = shost;
|
||||
}
|
||||
std::string json_str;
|
||||
xgboost::Json::Dump(config, &json_str);
|
||||
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
|
||||
config["port"] = Integer{static_cast<Integer::Int>(port)};
|
||||
config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
|
||||
config["dmlc_communicator"] = String{"rabit"};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
|
||||
setHandle(jenv, jout, handle);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
|
||||
jlong jhandle) {
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
JVM_CHECK_CALL(XGTrackerRun(handle, nullptr));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
|
||||
jlong jhandle,
|
||||
jlong timeout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
Json config{Object{}};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
JVM_CHECK_CALL(XGTrackerWaitFor(handle, sconfig.c_str()));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
|
||||
JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
Json config{Object{}};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
char const *args;
|
||||
JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));
|
||||
auto jargs = Json::Load(StringView{args});
|
||||
|
||||
jstring jret = jenv->NewStringUTF(args);
|
||||
jenv->SetObjectArrayElement(jout, 0, jret);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
|
||||
jlong jhandle) {
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
JVM_CHECK_CALL(XGTrackerFree(handle));
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
* Method: CommunicatorFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
|
||||
jclass) {
|
||||
JVM_CHECK_CALL(XGCommunicatorFinalize());
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
|
||||
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
|
||||
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorAllreduce
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -298,7 +298,7 @@ public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
Map<String, Object> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Communicator.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
|
||||
Reference in New Issue
Block a user