Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -7,6 +7,9 @@ import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* Collective communicator global class for synchronization.
|
||||
*
|
||||
@@ -30,8 +33,9 @@ public class Communicator {
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
|
||||
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
|
||||
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
|
||||
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
|
||||
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
|
||||
|
||||
private final int enumOp;
|
||||
private final int size;
|
||||
@@ -56,30 +60,20 @@ public class Communicator {
|
||||
}
|
||||
}
|
||||
|
||||
// used as way to test/debug passed communicator init parameters
|
||||
public static Map<String, String> communicatorEnvs;
|
||||
public static List<String> mockList = new LinkedList<>();
|
||||
|
||||
/**
|
||||
* Initialize the collective communicator on current working thread.
|
||||
*
|
||||
* @param envs The additional environment variables to pass to the communicator.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
communicatorEnvs = envs;
|
||||
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey();
|
||||
args[idx++] = e.getValue();
|
||||
public static void init(Map<String, Object> envs) throws XGBoostError {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
String jconfig = mapper.writeValueAsString(envs);
|
||||
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
|
||||
}
|
||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
||||
for (String mock : mockList) {
|
||||
args[idx++] = "mock";
|
||||
args[idx++] = mock;
|
||||
}
|
||||
checkCall(XGBoostJNI.CommunicatorInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Interface for Rabit tracker implementations with three public methods:
|
||||
* Interface for a tracker implementations with three public methods:
|
||||
*
|
||||
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
|
||||
* timeout value (in milliseconds.)
|
||||
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
|
||||
* - start(timeout): Start the tracker awaiting for worker connections, with a given
|
||||
* timeout value (in seconds).
|
||||
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
@@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
|
||||
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
public interface ITracker extends Thread.UncaughtExceptionHandler {
|
||||
enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
@@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, String> getWorkerEnvs();
|
||||
boolean start(long workerConnectionTimeout);
|
||||
void stop();
|
||||
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||
int waitFor(long taskExecutionTimeout);
|
||||
Map<String, Object> workerArgs() throws XGBoostError;
|
||||
|
||||
boolean start() throws XGBoostError;
|
||||
|
||||
void stop() throws XGBoostError;
|
||||
|
||||
void waitFor(long taskExecutionTimeout) throws XGBoostError;
|
||||
}
|
||||
@@ -1,101 +1,40 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
|
||||
* start() and waitFor() methods (i.e., the timeout is infinite.)
|
||||
*
|
||||
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
|
||||
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
|
||||
* provides timeout support.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker implements IRabitTracker {
|
||||
public class RabitTracker implements ITracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
private static String tracker_py = null;
|
||||
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
|
||||
// environment variable to be pased.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int numWorkers;
|
||||
private String hostIp = "";
|
||||
private String pythonExec = "";
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
private long handle = 0;
|
||||
private Thread tracker_daemon;
|
||||
|
||||
static {
|
||||
try {
|
||||
initTrackerPy();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load tracker library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
public RabitTracker(int numWorkers) throws XGBoostError {
|
||||
this(numWorkers, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracker logger that logs output from tracker.
|
||||
*/
|
||||
private class TrackerProcessLogger implements Runnable {
|
||||
public void run() {
|
||||
|
||||
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||
trackerProcess.get().getErrorStream()));
|
||||
String line;
|
||||
try {
|
||||
while ((line = reader.readLine()) != null) {
|
||||
trackerProcessLogger.info(line);
|
||||
}
|
||||
trackerProcess.get().waitFor();
|
||||
int exitValue = trackerProcess.get().exitValue();
|
||||
if (exitValue != 0) {
|
||||
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
|
||||
} else {
|
||||
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
trackerProcessLogger.error(ex.toString());
|
||||
} catch (InterruptedException ie) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
ie.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void initTrackerPy() throws IOException {
|
||||
try {
|
||||
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||
} catch (IOException ioe) {
|
||||
logger.trace("cannot access tracker python script");
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers)
|
||||
public RabitTracker(int numWorkers, String hostIp)
|
||||
throws XGBoostError {
|
||||
this(numWorkers, hostIp, 0, 300);
|
||||
}
|
||||
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||
}
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
|
||||
throws XGBoostError {
|
||||
this(numWorkers);
|
||||
this.hostIp = hostIp;
|
||||
this.pythonExec = pythonExec;
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
|
||||
this.handle = out[0];
|
||||
}
|
||||
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
@@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
trackerProcess.get().destroy();
|
||||
this.tracker_daemon.interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
*/
|
||||
public Map<String, String> getWorkerEnvs() {
|
||||
return envs;
|
||||
public Map<String, Object> workerArgs() throws XGBoostError {
|
||||
// fixme: timeout
|
||||
String[] args = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
|
||||
};
|
||||
Map<String, Object> config;
|
||||
try {
|
||||
config = mapper.readValue(args[0], typeRef);
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to get worker arguments.", ex);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private void loadEnvs(InputStream ins) throws IOException {
|
||||
try {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||
break;
|
||||
}
|
||||
String[] sep = line.split("=");
|
||||
if (sep.length == 2) {
|
||||
envs.put(sep[0], sep[1]);
|
||||
}
|
||||
public void stop() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
|
||||
}
|
||||
|
||||
public boolean start() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
|
||||
this.tracker_daemon = new Thread(() -> {
|
||||
try {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return; // exit the thread
|
||||
}
|
||||
reader.close();
|
||||
} catch (IOException ioe){
|
||||
logger.error("cannot get runtime configuration from tracker process");
|
||||
ioe.printStackTrace();
|
||||
throw ioe;
|
||||
}
|
||||
});
|
||||
this.tracker_daemon.setDaemon(true);
|
||||
this.tracker_daemon.start();
|
||||
|
||||
return this.tracker_daemon.isAlive();
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
public String getRabitTrackerCommand() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (pythonExec == null || pythonExec.isEmpty()) {
|
||||
sb.append("python ");
|
||||
} else {
|
||||
sb.append(pythonExec + " ");
|
||||
}
|
||||
sb.append(" " + tracker_py + " ");
|
||||
sb.append(" --log-level=DEBUG" + " ");
|
||||
sb.append(" --num-workers=" + numWorkers + " ");
|
||||
|
||||
// we first check the property then check the parameter
|
||||
String hostIpFromProperties = trackerProperties.getHostIp();
|
||||
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
|
||||
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
|
||||
sb.append(" --host-ip=" + hostIpFromProperties + " ");
|
||||
} else if (hostIp != null & !hostIp.isEmpty()) {
|
||||
logger.debug("Using the parametr host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=" + hostIp + " ");
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
String cmd = getRabitTrackerCommand();
|
||||
trackerProcess.set(Runtime.getRuntime().exec(cmd));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean start(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for all workers to connect indefinitely, unless " +
|
||||
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
System.out.println("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
logger_thread.start();
|
||||
return true;
|
||||
} else {
|
||||
logger.error("FAULT: failed to start tracker process");
|
||||
stop();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public int waitFor(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for either all workers to finish tasks and send " +
|
||||
"shutdown signal, or manual interruptions. " +
|
||||
"Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
try {
|
||||
trackerProcess.get().waitFor();
|
||||
int returnVal = trackerProcess.get().exitValue();
|
||||
logger.info("Tracker Process ends with exit code " + returnVal);
|
||||
stop();
|
||||
return returnVal;
|
||||
} catch (InterruptedException e) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
e.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
return TrackerStatus.INTERRUPTED.getStatusCode();
|
||||
}
|
||||
public void waitFor(long timeout) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -54,7 +54,7 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
|
||||
float[] data, int shapeParam,
|
||||
@@ -146,12 +146,24 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorInit(String args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
public final static native int CommunicatorPrint(String msg);
|
||||
public final static native int CommunicatorGetRank(int[] out);
|
||||
public final static native int CommunicatorGetWorldSize(int[] out);
|
||||
|
||||
// Tracker functions
|
||||
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
|
||||
long[] out);
|
||||
|
||||
public final static native int TrackerRun(long handle);
|
||||
|
||||
public final static native int TrackerWaitFor(long handle, long timeout);
|
||||
|
||||
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
|
||||
|
||||
public final static native int TrackerFree(long handle);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
@@ -168,5 +180,4 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
}
|
||||
|
||||
@@ -42,5 +42,4 @@ public final class UtilUnsafe {
|
||||
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user