[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -1,154 +0,0 @@
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
public enum OpType implements Serializable {
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
private int op;
public int getOperand() {
return this.op;
}
OpType(int op) {
this.op = op;
}
}
public enum DataType implements Serializable {
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
LONGLONG(8, 8), ULONGLONG(9, 8);
private int enumOp;
private int size;
public int getEnumOp() {
return this.enumOp;
}
public int getSize() {
return this.size;
}
DataType(int enumOp, int size) {
this.enumOp = enumOp;
this.size = size;
}
}
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}
}
// used as way to test/debug passed rabit init parameters
public static Map<String, String> rabitEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
rabitEnvs = envs;
String[] args = new String[envs.size() + mockList.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
// pass list of rabit mock strings eg mock=0,1,0,0
for(String mock : mockList) {
args[idx++] = "mock=" + mock;
}
checkCall(XGBoostJNI.RabitInit(args));
}
/**
* Shutdown the rabit engine in current working thread, equals to finalize.
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XGBoostJNI.RabitFinalize());
}
/**
* Print the message on rabit tracker.
* @param msg
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
}
/**
* Get version number of current stored model in the thread.
* which means how many calls to CheckPoint we made so far.
* @return version Number.
* @throws XGBoostError
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XGBoostJNI.RabitVersionNumber(out));
return out[0];
}
/**
* get rank of current thread.
* @return the rank.
* @throws XGBoostError
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XGBoostJNI.RabitGetRank(out));
return out[0];
}
/**
* get world size of current job.
* @return the worldsize
* @throws XGBoostError
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XGBoostJNI.RabitGetWorldSize(out));
return out[0];
}
/**
* perform Allreduce on distributed float vectors using operator op.
* This implementation of allReduce does not support customized prepare function callback in the
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
*
* @param elements local elements on distributed workers.
* @param op operator used for Allreduce.
* @return All-reduced float elements according to the given operator.
*/
public static float[] allReduce(float[] elements, OpType op) {
DataType dataType = DataType.FLOAT;
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
.order(ByteOrder.nativeOrder());
for (float el : elements) {
buffer.putFloat(el);
}
buffer.flip();
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
float[] results = new float[elements.length];
buffer.asFloatBuffer().get(results);
return results;
}
}

View File

@@ -254,16 +254,16 @@ public class XGBoost {
}
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
if (shouldPrint(params, iter)) {
Rabit.trackerPrint(String.format(
Communicator.communicatorPrint(String.format(
"early stopping after %d rounds away from the best iteration",
earlyStoppingRounds
));
}
break;
}
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
if (shouldPrint(params, iter)){
Rabit.trackerPrint(evalInfo + '\n');
Communicator.communicatorPrint(evalInfo + '\n');
}
}
}

View File

@@ -135,19 +135,6 @@ class XGBoostJNI {
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
// rabit functions
public final static native int RabitInit(String[] args);
public final static native int RabitFinalize();
public final static native int RabitTrackerPrint(String msg);
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
// Perform Allreduce operation on data in sendrecvbuf.
// This JNI function does not support the callback function for data preparation yet.
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize();