Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -1,20 +1,21 @@
|
||||
/**
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "./xgboost4j.h"
|
||||
|
||||
#include <rabit/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/json.h>
|
||||
@@ -23,7 +24,6 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
|
||||
jclass jcls,
|
||||
jstring jargs) {
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
assert(len % 2 == 0);
|
||||
for (bst_ulong i = 0; i < len / 2; ++i) {
|
||||
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
|
||||
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
|
||||
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
|
||||
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
|
||||
config[key_str] = xgboost::String(value_str);
|
||||
const char *args = jenv->GetStringUTFChars(jargs, nullptr);
|
||||
JVM_CHECK_CALL(XGCommunicatorInit(args));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
|
||||
JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
|
||||
jlongArray jout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
TrackerHandle handle;
|
||||
Json config{Object{}};
|
||||
std::string shost{jenv->GetStringUTFChars(host, nullptr),
|
||||
static_cast<std::string::size_type>(jenv->GetStringLength(host))};
|
||||
if (!shost.empty()) {
|
||||
config["host"] = shost;
|
||||
}
|
||||
std::string json_str;
|
||||
xgboost::Json::Dump(config, &json_str);
|
||||
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
|
||||
config["port"] = Integer{static_cast<Integer::Int>(port)};
|
||||
config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
|
||||
config["dmlc_communicator"] = String{"rabit"};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
|
||||
setHandle(jenv, jout, handle);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
|
||||
jlong jhandle) {
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
JVM_CHECK_CALL(XGTrackerRun(handle, nullptr));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
|
||||
jlong jhandle,
|
||||
jlong timeout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
Json config{Object{}};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
JVM_CHECK_CALL(XGTrackerWaitFor(handle, sconfig.c_str()));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
|
||||
JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) {
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
Json config{Object{}};
|
||||
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
|
||||
std::string sconfig = Json::Dump(config);
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
char const *args;
|
||||
JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));
|
||||
auto jargs = Json::Load(StringView{args});
|
||||
|
||||
jstring jret = jenv->NewStringUTF(args);
|
||||
jenv->SetObjectArrayElement(jout, 0, jret);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
|
||||
jlong jhandle) {
|
||||
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
|
||||
JVM_CHECK_CALL(XGTrackerFree(handle));
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
* Method: CommunicatorFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
|
||||
jclass) {
|
||||
JVM_CHECK_CALL(XGCommunicatorFinalize());
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
|
||||
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
|
||||
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorAllreduce
|
||||
|
||||
Reference in New Issue
Block a user