Initial support for federated learning (#7831)
Federated learning plugin for xgboost: * A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers. * A Rabit engine for the federated environment. * Integration test to simulate federated learning. Additional followups are needed to address GPU support, better security, and privacy, etc.
This commit is contained in:
@@ -28,6 +28,10 @@
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
#endif
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
|
||||
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||
@@ -95,6 +99,12 @@ XGB_DLL int XGBuildInfo(char const **out) {
|
||||
info["DEBUG"] = Boolean{false};
|
||||
#endif
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
info["USE_FEDERATED"] = Boolean{true};
|
||||
#else
|
||||
info["USE_FEDERATED"] = Boolean{false};
|
||||
#endif
|
||||
|
||||
XGBBuildInfoDevice(&info);
|
||||
|
||||
auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str;
|
||||
@@ -198,11 +208,15 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
bool load_row_split = false;
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
#else
|
||||
if (rabit::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||
<< "will split data among workers";
|
||||
load_row_split = true;
|
||||
}
|
||||
#endif
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
||||
API_END();
|
||||
}
|
||||
@@ -1342,5 +1356,14 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
||||
API_END();
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
|
||||
char const *server_cert_path, char const *client_cert_path) {
|
||||
API_BEGIN();
|
||||
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
|
||||
API_END();
|
||||
}
|
||||
#endif
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user