JNI wrapper for the collective communicator (#8242)

This commit is contained in:
Rong Ou
2022-09-20 13:20:25 -07:00
committed by GitHub
parent fffb1fca52
commit 7d43e74e71
5 changed files with 570 additions and 0 deletions

View File

@@ -977,6 +977,89 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
xgboost::Json config{xgboost::Object{}};
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
assert(len % 2 == 0);
for (bst_ulong i = 0; i < len / 2; ++i) {
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
config[key_str] = xgboost::String(value_str);
}
std::string json_str;
xgboost::Json::Dump(config, &json_str);
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
(JNIEnv *jenv, jclass jcls) {
JVM_CHECK_CALL(XGCommunicatorFinalize());
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
JVM_CHECK_CALL(XGCommunicatorPrint(str.c_str()));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
jint rank = XGCommunicatorGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
jint out = XGCommunicatorGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
JVM_CHECK_CALL(XGCommunicatorAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op));
return 0;
}
namespace xgboost {
namespace jni {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,