Initial support for federated learning (#7831)

Federated learning plugin for xgboost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* Integration test to simulate federated learning.

Additional followups are needed to address GPU support, better security, and privacy, etc.
This commit is contained in:
Rong Ou
2022-05-05 06:49:22 -07:00
committed by GitHub
parent 46e0bce212
commit 14ef38b834
16 changed files with 1087 additions and 1 deletions

View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
rm -f ./*.model* ./agaricus* ./*.pem
world_size=3
# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
# Split train and test files manually to simulate a federated environment.
split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
python test_federated.py ${world_size}

View File

@@ -0,0 +1,78 @@
#!/usr/bin/python
import multiprocessing
import sys
import time
import xgboost as xgb
import xgboost.federated
SERVER_KEY = 'server-key.pem'
SERVER_CERT = 'server-cert.pem'
CLIENT_KEY = 'client-key.pem'
CLIENT_CERT = 'client-cert.pem'
def run_server(port: int, world_size: int) -> None:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
def run_worker(port: int, world_size: int, rank: int) -> None:
# Always call this before using distributed module
rabit_env = [
f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}',
f'federated_rank={rank}',
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]
xgb.rabit.init([e.encode() for e in rabit_env])
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20
# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model.json")
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
def run_test() -> None:
port = 9091
world_size = int(sys.argv[1])
server = multiprocessing.Process(target=run_server, args=(port, world_size))
server.start()
time.sleep(1)
if not server.is_alive():
raise Exception("Error starting Federated Learning server")
workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
server.terminate()
if __name__ == '__main__':
run_test()