Initial support for federated learning (#7831)

Federated learning plugin for xgboost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* Integration test to simulate federated learning.

Additional followups are needed to address GPU support, better security, and privacy, etc.
This commit is contained in:
Rong Ou
2022-05-05 06:49:22 -07:00
committed by GitHub
parent 46e0bce212
commit 14ef38b834
16 changed files with 1087 additions and 1 deletions

View File

@@ -28,6 +28,10 @@
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
#endif
using namespace xgboost; // NOLINT(*);
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
@@ -95,6 +99,12 @@ XGB_DLL int XGBuildInfo(char const **out) {
info["DEBUG"] = Boolean{false};
#endif
#if defined(XGBOOST_USE_FEDERATED)
info["USE_FEDERATED"] = Boolean{true};
#else
info["USE_FEDERATED"] = Boolean{false};
#endif
XGBBuildInfoDevice(&info);
auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str;
@@ -198,11 +208,15 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
DMatrixHandle *out) {
API_BEGIN();
bool load_row_split = false;
#if defined(XGBOOST_USE_FEDERATED)
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else
if (rabit::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
load_row_split = true;
}
#endif
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
API_END();
}
@@ -1342,5 +1356,14 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {
API_BEGIN();
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}
#endif
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();