Allow insecure gRPC connections for federated learning (#8181)

* Allow insecure gRPC connections for federated learning

* format
This commit is contained in:
Rong Ou
2022-08-18 21:16:14 -07:00
committed by GitHub
parent 53d2a733b0
commit ad3bc0edee
7 changed files with 75 additions and 37 deletions

View File

@@ -12,21 +12,28 @@ CLIENT_KEY = 'client-key.pem'
CLIENT_CERT = 'client-cert.pem'
def run_server(port: int, world_size: int) -> None:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
if with_ssl:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
else:
xgboost.federated.run_federated_server(port, world_size)
def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
# Always call this before using distributed module
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
rabit_env = [
f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}',
f'federated_rank={rank}',
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
f'federated_rank={rank}'
]
if with_ssl:
rabit_env = rabit_env + [
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]
# Always call this before using distributed module
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
@@ -52,11 +59,11 @@ def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
xgb.rabit.tracker_print("Finished training\n")
def run_test(with_gpu: bool = False) -> None:
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
port = 9091
world_size = int(sys.argv[1])
server = multiprocessing.Process(target=run_server, args=(port, world_size))
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
server.start()
time.sleep(1)
if not server.is_alive():
@@ -64,7 +71,8 @@ def run_test(with_gpu: bool = False) -> None:
workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu))
worker = multiprocessing.Process(target=run_worker,
args=(port, world_size, rank, with_ssl, with_gpu))
workers.append(worker)
worker.start()
for worker in workers:
@@ -73,5 +81,7 @@ def run_test(with_gpu: bool = False) -> None:
if __name__ == '__main__':
run_test()
run_test(with_gpu=True)
run_test(with_ssl=True, with_gpu=False)
run_test(with_ssl=False, with_gpu=False)
run_test(with_ssl=True, with_gpu=True)
run_test(with_ssl=False, with_gpu=True)