Allow insecure gRPC connections for federated learning (#8181)
* Allow insecure gRPC connections for federated learning * format
This commit is contained in:
@@ -12,21 +12,28 @@ CLIENT_KEY = 'client-key.pem'
|
||||
CLIENT_CERT = 'client-cert.pem'
|
||||
|
||||
|
||||
def run_server(port: int, world_size: int) -> None:
|
||||
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
|
||||
CLIENT_CERT)
|
||||
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
||||
if with_ssl:
|
||||
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
|
||||
CLIENT_CERT)
|
||||
else:
|
||||
xgboost.federated.run_federated_server(port, world_size)
|
||||
|
||||
|
||||
def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
|
||||
# Always call this before using distributed module
|
||||
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
||||
rabit_env = [
|
||||
f'federated_server_address=localhost:{port}',
|
||||
f'federated_world_size={world_size}',
|
||||
f'federated_rank={rank}',
|
||||
f'federated_server_cert={SERVER_CERT}',
|
||||
f'federated_client_key={CLIENT_KEY}',
|
||||
f'federated_client_cert={CLIENT_CERT}'
|
||||
f'federated_rank={rank}'
|
||||
]
|
||||
if with_ssl:
|
||||
rabit_env = rabit_env + [
|
||||
f'federated_server_cert={SERVER_CERT}',
|
||||
f'federated_client_key={CLIENT_KEY}',
|
||||
f'federated_client_cert={CLIENT_CERT}'
|
||||
]
|
||||
|
||||
# Always call this before using distributed module
|
||||
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
||||
# Load file, file will not be sharded in federated mode.
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
|
||||
@@ -52,11 +59,11 @@ def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
|
||||
def run_test(with_gpu: bool = False) -> None:
|
||||
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
|
||||
port = 9091
|
||||
world_size = int(sys.argv[1])
|
||||
|
||||
server = multiprocessing.Process(target=run_server, args=(port, world_size))
|
||||
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
|
||||
server.start()
|
||||
time.sleep(1)
|
||||
if not server.is_alive():
|
||||
@@ -64,7 +71,8 @@ def run_test(with_gpu: bool = False) -> None:
|
||||
|
||||
workers = []
|
||||
for rank in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu))
|
||||
worker = multiprocessing.Process(target=run_worker,
|
||||
args=(port, world_size, rank, with_ssl, with_gpu))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
@@ -73,5 +81,7 @@ def run_test(with_gpu: bool = False) -> None:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_test()
|
||||
run_test(with_gpu=True)
|
||||
run_test(with_ssl=True, with_gpu=False)
|
||||
run_test(with_ssl=False, with_gpu=False)
|
||||
run_test(with_ssl=True, with_gpu=True)
|
||||
run_test(with_ssl=False, with_gpu=True)
|
||||
|
||||
Reference in New Issue
Block a user