JNI wrapper for the collective communicator (#8242)

This commit is contained in:
Rong Ou
2022-09-20 13:20:25 -07:00
committed by GitHub
parent fffb1fca52
commit 7d43e74e71
5 changed files with 570 additions and 0 deletions

View File

@@ -0,0 +1,152 @@
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
* Collective communicator global class for synchronization.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*/
public class Communicator {

  /**
   * Reduction operators accepted by {@link #allReduce(float[], OpType)}. The integer attached to
   * each constant is the code forwarded to the native layer.
   */
  public enum OpType implements Serializable {
    MAX(0), MIN(1), SUM(2);

    private final int op;

    /**
     * @return the integer code forwarded to the native layer for this operator.
     */
    public int getOperand() {
      return this.op;
    }

    OpType(int op) {
      this.op = op;
    }
  }

  /**
   * Element types understood by the native allreduce call. Each constant carries the integer
   * code forwarded to the native layer and the size of one element in bytes.
   */
  public enum DataType implements Serializable {
    INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
    INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);

    private final int enumOp;
    private final int size;

    /**
     * @return the integer code forwarded to the native layer for this data type.
     */
    public int getEnumOp() {
      return this.enumOp;
    }

    /**
     * @return the size in bytes of a single element of this type.
     */
    public int getSize() {
      return this.size;
    }

    DataType(int enumOp, int size) {
      this.enumOp = enumOp;
      this.size = size;
    }
  }

  /**
   * Converts a non-zero native return code into an {@link XGBoostError} carrying the last error
   * message reported by the native library.
   */
  private static void checkCall(int ret) throws XGBoostError {
    if (ret != 0) {
      throw new XGBoostError(XGBoostJNI.XGBGetLastError());
    }
  }

  // used as way to test/debug passed communicator init parameters
  public static Map<String, String> communicatorEnvs;
  public static List<String> mockList = new LinkedList<>();

  /**
   * Initialize the collective communicator on current working thread.
   *
   * <p>The parameters are flattened into a {@code [key0, value0, key1, value1, ...]} array,
   * followed by one {@code "mock", value} pair per entry of {@link #mockList}.
   *
   * @param envs The additional environment variables to pass to the communicator. A
   *             {@code null} map is treated as an empty one.
   * @throws XGBoostError if the native initialization reports a failure.
   * @throws IllegalArgumentException if a key, a value, or a mock entry is {@code null}.
   */
  public static void init(Map<String, String> envs) throws XGBoostError {
    communicatorEnvs = envs;
    Map<String, String> params =
        envs == null ? java.util.Collections.<String, String>emptyMap() : envs;
    String[] args = new String[params.size() * 2 + mockList.size() * 2];
    int idx = 0;
    for (Map.Entry<String, String> e : params.entrySet()) {
      // The native side dereferences every string, so reject nulls before crossing into JNI.
      if (e.getKey() == null || e.getValue() == null) {
        throw new IllegalArgumentException(
            "Communicator parameter must not contain null: " + e.getKey() + "=" + e.getValue());
      }
      args[idx++] = e.getKey();
      args[idx++] = e.getValue();
    }
    // pass list of rabit mock strings eg mock=0,1,0,0
    for (String mock : mockList) {
      if (mock == null) {
        throw new IllegalArgumentException("mockList must not contain null entries");
      }
      args[idx++] = "mock";
      args[idx++] = mock;
    }
    checkCall(XGBoostJNI.CommunicatorInit(args));
  }

  /**
   * Shutdown the communicator in current working thread, equals to finalize.
   *
   * @throws XGBoostError if the native call reports a failure.
   */
  public static void shutdown() throws XGBoostError {
    checkCall(XGBoostJNI.CommunicatorFinalize());
  }

  /**
   * Print the message via the communicator.
   *
   * @param msg the message to print.
   * @throws XGBoostError if the native call reports a failure.
   */
  public static void communicatorPrint(String msg) throws XGBoostError {
    checkCall(XGBoostJNI.CommunicatorPrint(msg));
  }

  /**
   * Get rank of current thread.
   *
   * @return the rank.
   * @throws XGBoostError if the native call reports a failure.
   */
  public static int getRank() throws XGBoostError {
    int[] out = new int[1];
    checkCall(XGBoostJNI.CommunicatorGetRank(out));
    return out[0];
  }

  /**
   * Get world size of current job.
   *
   * @return the world size.
   * @throws XGBoostError if the native call reports a failure.
   */
  public static int getWorldSize() throws XGBoostError {
    int[] out = new int[1];
    checkCall(XGBoostJNI.CommunicatorGetWorldSize(out));
    return out[0];
  }

  /**
   * Perform Allreduce on distributed float vectors using operator op.
   *
   * <p>The elements are copied into a direct buffer in native byte order, reduced in place by
   * the native call, and copied back out into a new array.
   *
   * @param elements local elements on distributed workers.
   * @param op operator used for Allreduce.
   * @return All-reduced float elements according to the given operator.
   * @throws IllegalStateException if the native allreduce reports a failure. An unchecked
   *         exception is used because this method's signature declares no checked exceptions.
   */
  public static float[] allReduce(float[] elements, OpType op) {
    DataType dataType = DataType.FLOAT32;
    ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
        .order(ByteOrder.nativeOrder());

    for (float el : elements) {
      buffer.putFloat(el);
    }
    // Rewind so the float view created below starts at the first element.
    buffer.flip();

    int ret = XGBoostJNI.CommunicatorAllreduce(buffer, elements.length, dataType.getEnumOp(),
        op.getOperand());
    if (ret != 0) {
      // Do not hand back the un-reduced local values as if the reduction had succeeded.
      throw new IllegalStateException("Allreduce failed: " + XGBoostJNI.XGBGetLastError());
    }

    float[] results = new float[elements.length];
    buffer.asFloatBuffer().get(results);
    return results;
  }
}

View File

@@ -148,6 +148,17 @@ class XGBoostJNI {
// Native allreduce entry point of the Rabit API; takes the buffer, the element count and
// integer codes for the element type and the reduction operator.
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
// communicator functions
// Each function returns an int status code; Communicator.checkCall treats non-zero as failure.
// args is a flat array of alternating keys and values (see Communicator.init).
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize();
public final static native int CommunicatorPrint(String msg);
// The rank is written into out[0].
public final static native int CommunicatorGetRank(int[] out);
// The world size is written into out[0].
public final static native int CommunicatorGetWorldSize(int[] out);
// Perform Allreduce operation on data in sendrecvbuf.
// The native side reads the buffer through GetDirectBufferAddress, so sendrecvbuf is expected
// to be a direct ByteBuffer; count is the number of elements, not bytes (see
// Communicator.allReduce).
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
// Native method taking a DMatrix handle, a field name and a JSON string; its implementation
// is not shown in this part of the file.
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);

View File

@@ -977,6 +977,89 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
return 0;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorInit
 * Signature: ([Ljava/lang/String;)I
 *
 * Converts the flat [key0, value0, key1, value1, ...] string array into a JSON
 * object and hands its serialized form to XGCommunicatorInit.
 * Returns 0 on success and a non-zero code on failure.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
  (JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
  // Copies a Java string into `out` and gives the JVM-owned characters back.
  // The length must come from GetStringUTFLength (bytes of the UTF-8 form);
  // GetStringLength counts UTF-16 units and truncates non-ASCII text.
  auto copy_jstring = [jenv](jstring jstr, std::string *out) -> bool {
    const char *chars = jenv->GetStringUTFChars(jstr, nullptr);
    if (chars == nullptr) {
      return false;  // the JVM could not provide the characters
    }
    out->assign(chars, static_cast<size_t>(jenv->GetStringUTFLength(jstr)));
    jenv->ReleaseStringUTFChars(jstr, chars);
    return true;
  };

  xgboost::Json config{xgboost::Object{}};
  bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
  assert(len % 2 == 0);
  for (bst_ulong i = 0; i < len / 2; ++i) {
    jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
    jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
    std::string key_str;
    std::string value_str;
    bool ok = copy_jstring(key, &key_str) && copy_jstring(value, &value_str);
    // Drop the local references eagerly so a long argument list does not
    // exhaust the local reference table.
    jenv->DeleteLocalRef(key);
    jenv->DeleteLocalRef(value);
    if (!ok) {
      return -1;
    }
    config[key_str] = xgboost::String(value_str);
  }
  std::string json_str;
  xgboost::Json::Dump(config, &json_str);
  JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorFinalize
* Signature: ()I
*
* Calls XGCommunicatorFinalize through JVM_CHECK_CALL and returns 0 when the
* statement after the macro is reached.
* NOTE(review): JVM_CHECK_CALL is defined outside this section; presumably it
* returns the non-zero error code early - confirm against its definition.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
(JNIEnv *jenv, jclass jcls) {
JVM_CHECK_CALL(XGCommunicatorFinalize());
return 0;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorPrint
 * Signature: (Ljava/lang/String;)I
 *
 * Copies the Java message into a std::string and forwards it to
 * XGCommunicatorPrint. Returns 0 on success and a non-zero code on failure.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
  (JNIEnv *jenv, jclass jcls, jstring jmsg) {
  const char *msg_chars = jenv->GetStringUTFChars(jmsg, nullptr);
  if (msg_chars == nullptr) {
    return -1;  // the JVM could not provide the characters
  }
  // Size the copy by the UTF-8 byte length; GetStringLength counts UTF-16
  // units and would cut off messages containing non-ASCII characters.
  std::string str(msg_chars, static_cast<size_t>(jenv->GetStringUTFLength(jmsg)));
  // The characters are owned by the JVM and must be released once copied.
  jenv->ReleaseStringUTFChars(jmsg, msg_chars);
  JVM_CHECK_CALL(XGCommunicatorPrint(str.c_str()));
  return 0;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorGetRank
 * Signature: ([I)I
 *
 * Stores the value returned by XGCommunicatorGetRank in element 0 of the
 * supplied int array and always returns 0.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
  (JNIEnv *jenv, jclass jcls, jintArray jout) {
  // Query the rank first, then copy that single value into slot 0 of jout.
  const jint current_rank = XGCommunicatorGetRank();
  const jsize first_slot = 0;
  const jsize one_value = 1;
  jenv->SetIntArrayRegion(jout, first_slot, one_value, &current_rank);
  return 0;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorGetWorldSize
 * Signature: ([I)I
 *
 * Stores the value returned by XGCommunicatorGetWorldSize in element 0 of the
 * supplied int array and always returns 0.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
  (JNIEnv *jenv, jclass jcls, jintArray jout) {
  // Query the world size first, then copy that single value into slot 0 of jout.
  const jint world_size = XGCommunicatorGetWorldSize();
  const jsize first_slot = 0;
  const jsize one_value = 1;
  jenv->SetIntArrayRegion(jout, first_slot, one_value, &world_size);
  return 0;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    CommunicatorAllreduce
 * Signature: (Ljava/nio/ByteBuffer;III)I
 *
 * Resolves the address behind the given buffer and passes it, together with
 * the element count and the data type / operator codes, to
 * XGCommunicatorAllreduce.
 * NOTE(review): the address returned by GetDirectBufferAddress is not checked
 * here; confirm callers always pass a direct ByteBuffer.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
  (JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
  void *const buffer_addr = jenv->GetDirectBufferAddress(jsendrecvbuf);
  const size_t element_count = static_cast<size_t>(jcount);
  JVM_CHECK_CALL(
      XGCommunicatorAllreduce(buffer_addr, element_count, jenum_dtype, jenum_op));
  return 0;
}
namespace xgboost {
namespace jni {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,

View File

@@ -335,6 +335,54 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *, jclass, jobject, jint, jint, jint);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
*
* Takes a flat array of alternating key and value strings.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorFinalize
* Signature: ()I
*
* Takes no arguments beyond the JNI environment and class.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorPrint
* Signature: (Ljava/lang/String;)I
*
* Takes the message to print as a Java string.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorGetRank
* Signature: ([I)I
*
* The rank is written into element 0 of the int array argument.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorGetWorldSize
* Signature: ([I)I
*
* The world size is written into element 0 of the int array argument.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*
* Arguments after the class are: the send/receive ByteBuffer, the element
* count, the data type code and the reduction operator code.
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
(JNIEnv *, jclass, jobject, jint, jint, jint);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetInfoFromInterface