Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -7,6 +7,9 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* Collective communicator global class for synchronization.
*
@@ -30,8 +33,9 @@ public class Communicator {
}
public enum DataType implements Serializable {
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
private final int enumOp;
private final int size;
@@ -56,30 +60,20 @@ public class Communicator {
}
}
// used as way to test/debug passed communicator init parameters
public static Map<String, String> communicatorEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the collective communicator on current working thread.
*
* @param envs The additional environment variables to pass to the communicator.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
communicatorEnvs = envs;
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey();
args[idx++] = e.getValue();
public static void init(Map<String, Object> envs) throws XGBoostError {
ObjectMapper mapper = new ObjectMapper();
try {
String jconfig = mapper.writeValueAsString(envs);
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
}
// pass list of rabit mock strings eg mock=0,1,0,0
for (String mock : mockList) {
args[idx++] = "mock";
args[idx++] = mock;
}
checkCall(XGBoostJNI.CommunicatorInit(args));
}
/**

View File

@@ -1,14 +1,13 @@
package ml.dmlc.xgboost4j.java;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Interface for Rabit tracker implementations with three public methods:
* Interface for tracker implementations with three public methods:
*
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
* timeout value (in milliseconds.)
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
* - start(): Start the tracker and wait for worker connections; the timeout value
* (in seconds) is supplied when the tracker is created rather than passed to start().
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
@@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
* brokers connections between workers.
*/
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
public interface ITracker extends Thread.UncaughtExceptionHandler {
enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
@@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
}
}
Map<String, String> getWorkerEnvs();
boolean start(long workerConnectionTimeout);
void stop();
// taskExecutionTimeout has no effect in current version of XGBoost.
int waitFor(long taskExecutionTimeout);
Map<String, Object> workerArgs() throws XGBoostError;
boolean start() throws XGBoostError;
void stop() throws XGBoostError;
void waitFor(long taskExecutionTimeout) throws XGBoostError;
}

View File

@@ -1,101 +1,40 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
* start() and waitFor() methods (i.e., the timeout is infinite.)
*
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
* provides timeout support.
*
* The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements IRabitTracker {
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
private static String tracker_py = null;
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int numWorkers;
private String hostIp = "";
private String pythonExec = "";
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
private long handle = 0;
private Thread tracker_daemon;
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
/**
 * Create a tracker for the given number of workers.
 *
 * Delegates to the (numWorkers, hostIp) constructor with an empty host IP string.
 *
 * @param numWorkers number of workers, forwarded unchanged to the delegated constructor
 * @throws XGBoostError propagated from the delegated constructor
 */
public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
}
/**
* Tracker logger that logs output from tracker.
*/
private class TrackerProcessLogger implements Runnable {
public void run() {
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
BufferedReader reader = new BufferedReader(new InputStreamReader(
trackerProcess.get().getErrorStream()));
String line;
try {
while ((line = reader.readLine()) != null) {
trackerProcessLogger.info(line);
}
trackerProcess.get().waitFor();
int exitValue = trackerProcess.get().exitValue();
if (exitValue != 0) {
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
} else {
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
}
} catch (IOException ex) {
trackerProcessLogger.error(ex.toString());
} catch (InterruptedException ie) {
// we should not get here as RabitTracker is accessed in the main thread
ie.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}
private static void initTrackerPy() throws IOException {
try {
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
} catch (IOException ioe) {
logger.trace("cannot access tracker python script");
throw ioe;
}
}
public RabitTracker(int numWorkers)
/**
 * Create a tracker for the given number of workers on the given host.
 *
 * Delegates to the four-argument constructor with port 0 and timeout 300.
 * NOTE(review): what port 0 means and the unit of the timeout are decided by the
 * native tracker and are not visible here — confirm.
 *
 * @param numWorkers number of workers, forwarded unchanged
 * @param hostIp host IP string, forwarded unchanged
 * @throws XGBoostError propagated from the delegated constructor
 */
public RabitTracker(int numWorkers, String hostIp)
throws XGBoostError {
this(numWorkers, hostIp, 0, 300);
}
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
if (numWorkers < 1) {
throw new XGBoostError("numWorkers must be greater equal to one");
}
this.numWorkers = numWorkers;
}
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
throws XGBoostError {
this(numWorkers);
this.hostIp = hostIp;
this.pythonExec = pythonExec;
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
this.handle = out[0];
}
public void uncaughtException(Thread t, Throwable e) {
@@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
trackerProcess.get().destroy();
this.tracker_daemon.interrupt();
}
}
@@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
* Get the environment settings that can be passed to workers.
* @return The environment settings.
*/
public Map<String, String> getWorkerEnvs() {
return envs;
public Map<String, Object> workerArgs() throws XGBoostError {
// fixme: timeout
// A timeout of 0 is passed to the native call for now; the result is read from args[0].
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
// args[0] is parsed as JSON into a generic String -> Object map.
ObjectMapper mapper = new ObjectMapper();
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
};
Map<String, Object> config;
try {
config = mapper.readValue(args[0], typeRef);
} catch (JsonProcessingException ex) {
// Report JSON parsing failures as XGBoostError, keeping the original exception as the cause.
throw new XGBoostError("Failed to get worker arguments.", ex);
}
return config;
}
private void loadEnvs(InputStream ins) throws IOException {
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break;
}
String[] sep = line.split("=");
if (sep.length == 2) {
envs.put(sep[0], sep[1]);
}
/**
 * Stop the tracker by passing this instance's native handle to TrackerFree.
 *
 * NOTE(review): the handle field is not reset here, so whether a second call is safe
 * depends on the native implementation — confirm.
 *
 * @throws XGBoostError presumably raised by checkCall when the native call reports an error
 */
public void stop() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
}
public boolean start() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
this.tracker_daemon = new Thread(() -> {
try {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
} catch (XGBoostError ex) {
logger.error(ex);
return; // exit the thread
}
reader.close();
} catch (IOException ioe){
logger.error("cannot get runtime configuration from tracker process");
ioe.printStackTrace();
throw ioe;
}
});
this.tracker_daemon.setDaemon(true);
this.tracker_daemon.start();
return this.tracker_daemon.isAlive();
}
/** visible for testing */
public String getRabitTrackerCommand() {
StringBuilder sb = new StringBuilder();
if (pythonExec == null || pythonExec.isEmpty()) {
sb.append("python ");
} else {
sb.append(pythonExec + " ");
}
sb.append(" " + tracker_py + " ");
sb.append(" --log-level=DEBUG" + " ");
sb.append(" --num-workers=" + numWorkers + " ");
// we first check the property then check the parameter
String hostIpFromProperties = trackerProperties.getHostIp();
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
sb.append(" --host-ip=" + hostIpFromProperties + " ");
} else if (hostIp != null & !hostIp.isEmpty()) {
logger.debug("Using the parametr host-ip: " + hostIp);
sb.append(" --host-ip=" + hostIp + " ");
}
return sb.toString();
}
private boolean startTrackerProcess() {
try {
String cmd = getRabitTrackerCommand();
trackerProcess.set(Runtime.getRuntime().exec(cmd));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
ioe.printStackTrace();
return false;
}
}
public void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
}
}
public boolean start(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for all workers to connect indefinitely, unless " +
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
}
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);
logger_thread.start();
return true;
} else {
logger.error("FAULT: failed to start tracker process");
stop();
return false;
}
}
public int waitFor(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for either all workers to finish tasks and send " +
"shutdown signal, or manual interruptions. " +
"Use the Scala RabitTracker for timeout support.");
}
try {
trackerProcess.get().waitFor();
int returnVal = trackerProcess.get().exitValue();
logger.info("Tracker Process ends with exit code " + returnVal);
stop();
return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
return TrackerStatus.INTERRUPTED.getStatusCode();
}
/**
 * Wait on the native tracker by forwarding to TrackerWaitFor with this instance's handle.
 *
 * @param timeout passed through to the native call unchanged; its unit is not visible
 *                here — TODO confirm against the native API
 * @throws XGBoostError presumably raised by checkCall when the native call reports an error
 */
public void waitFor(long timeout) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -54,7 +54,7 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam,
@@ -146,12 +146,24 @@ class XGBoostJNI {
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorInit(String args);
public final static native int CommunicatorFinalize();
public final static native int CommunicatorPrint(String msg);
public final static native int CommunicatorGetRank(int[] out);
public final static native int CommunicatorGetWorldSize(int[] out);
// Tracker functions
// Each function returns an int status that callers pass to checkCall.

// Creates a native tracker; callers read the resulting handle from out[0].
// NOTE(review): the meaning of sortby and the unit of timeout are defined by the native
// side and are not visible here — confirm.
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
long[] out);
// Invoked by RabitTracker.start() before it spawns the daemon thread that waits on the tracker.
public final static native int TrackerRun(long handle);
// Waits on the tracker identified by handle; the timeout is forwarded from the caller.
public final static native int TrackerWaitFor(long handle, long timeout);
// Writes the worker arguments into out[0]; the caller parses that string as JSON.
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
// Invoked by RabitTracker.stop() with the tracker's handle.
public final static native int TrackerFree(long handle);
// Perform Allreduce operation on data in sendrecvbuf.
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
@@ -168,5 +180,4 @@ class XGBoostJNI {
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
}

View File

@@ -42,5 +42,4 @@ public final class UtilUnsafe {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
}
}
}