Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,20 +1,21 @@
/**
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
* Copyright 2014-2024, XGBoost Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./xgboost4j.h"
#include <rabit/c_api.h>
#include <xgboost/base.h>
#include <xgboost/c_api.h>
#include <xgboost/json.h>
@@ -23,7 +24,6 @@
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
@@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
                                                                               jclass jcls,
                                                                               jstring jargs) {
  // Initialises the communicator from a single configuration string that is handed
  // unchanged to XGCommunicatorInit.
  //
  // jenv:  JNI environment of the calling thread.
  // jcls:  calling Java class (unused).
  // jargs: configuration string passed through to the C API.
  //
  // Returns 0 on success; a failing C API status is handled by JVM_CHECK_CALL.
  const char *args = jenv->GetStringUTFChars(jargs, nullptr);
  if (args == nullptr) {
    // GetStringUTFChars could not provide the characters; a Java exception is pending, so
    // report failure with a non-zero status instead of dereferencing a null pointer.
    return -1;
  }
  // Keep the status in a local so the JNI buffer can be released before any early return
  // performed by JVM_CHECK_CALL.
  auto const status = XGCommunicatorInit(args);
  jenv->ReleaseStringUTFChars(jargs, args);
  JVM_CHECK_CALL(status);
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
    JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
    jlongArray jout) {
  // Builds a JSON configuration from the arguments, creates a tracker through
  // XGTrackerCreate and writes the resulting handle to `jout` via setHandle.
  //
  // host:      tracker host; omitted from the configuration when empty.
  // n_workers: stored under "n_workers".
  // port:      stored under "port".
  // sortby:    stored under "sortby".
  // timeout:   stored under "timeout".
  // jout:      receives the native tracker handle.
  //
  // Returns 0 on success; a failing C API status is handled by JVM_CHECK_CALL.
  using namespace xgboost;  // NOLINT

  TrackerHandle handle{nullptr};
  Json config{Object{}};

  char const *host_chars = jenv->GetStringUTFChars(host, nullptr);
  if (host_chars == nullptr) {
    // The JVM could not provide the characters; a Java exception is pending.
    return -1;
  }
  // GetStringUTFLength is the byte length of the buffer returned by GetStringUTFChars;
  // GetStringLength counts UTF-16 units and would cut non-ASCII host names short.
  std::string shost{host_chars,
                    static_cast<std::string::size_type>(jenv->GetStringUTFLength(host))};
  // The contents are copied into `shost`, so the JNI buffer can be handed back right away.
  jenv->ReleaseStringUTFChars(host, host_chars);

  if (!shost.empty()) {
    config["host"] = shost;
  }
  config["port"] = Integer{static_cast<Integer::Int>(port)};
  config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
  config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
  config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
  config["dmlc_communicator"] = String{"rabit"};

  std::string sconfig = Json::Dump(config);
  JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
  setHandle(jenv, jout, handle);
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
                                                                         jlong jhandle) {
  // The Java side carries the native tracker as a jlong; turn it back into the C API handle
  // and forward to XGTrackerRun without any extra configuration (nullptr).
  TrackerHandle const tracker = reinterpret_cast<TrackerHandle>(jhandle);
  JVM_CHECK_CALL(XGTrackerRun(tracker, nullptr));
  // Reached only when JVM_CHECK_CALL did not bail out.
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
                                                                             jlong jhandle,
                                                                             jlong timeout) {
  // Serialise {"timeout": <timeout>} and pass it, together with the tracker handle, to
  // XGTrackerWaitFor.
  xgboost::Json wait_cfg{xgboost::Object{}};
  wait_cfg["timeout"] = xgboost::Integer{static_cast<xgboost::Integer::Int>(timeout)};
  std::string serialized = xgboost::Json::Dump(wait_cfg);

  // The jlong is the native tracker handle as stored on the Java side.
  auto tracker = reinterpret_cast<TrackerHandle>(jhandle);
  JVM_CHECK_CALL(XGTrackerWaitFor(tracker, serialized.c_str()));
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
    JNIEnv *jenv, jclass, jlong jhandle, jlong /*timeout*/, jobjectArray jout) {
  // Fetches the worker arguments from the tracker behind `jhandle` and stores them, as one
  // Java string, in jout[0].
  //
  // The timeout parameter stays in the signature to match the Java declaration, but it is
  // not used: XGTrackerWorkerArgs is called with only the handle and the output pointer.
  //
  // Returns 0 on success; a failing C API status is handled by JVM_CHECK_CALL.
  auto handle = reinterpret_cast<TrackerHandle>(jhandle);
  char const *args{nullptr};
  JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));

  jstring jret = jenv->NewStringUTF(args);
  if (jret == nullptr) {
    // The Java string could not be constructed and an exception is pending in the JVM;
    // leave the output array untouched and report failure with a non-zero status.
    return -1;
  }
  jenv->SetObjectArrayElement(jout, 0, jret);
  return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
                                                                          jlong jhandle) {
  // Recover the native tracker handle from its jlong representation and pass it to
  // XGTrackerFree.
  TrackerHandle const tracker = reinterpret_cast<TrackerHandle>(jhandle);
  JVM_CHECK_CALL(XGTrackerFree(tracker));
  // Reached only when JVM_CHECK_CALL did not bail out.
  return 0;
}
@@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
* Method: CommunicatorFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
                                                                                   jclass) {
  // Forwards to XGCommunicatorFinalize; neither the JNI environment nor the calling class
  // is needed, so both parameters are left unnamed.
  JVM_CHECK_CALL(XGCommunicatorFinalize());
  return 0;
}

View File

@@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *, jclass, jobjectArray);
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorAllreduce