[fed] Fixes for the encrypted GRPC backend. (#10503)

This commit is contained in:
Jiaming Yuan 2024-07-02 15:15:12 +08:00 committed by GitHub
parent 5f0c1e902b
commit a39fef2c67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 192 additions and 109 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost contributors * Copyright 2023-2024, XGBoost contributors
*/ */
#include "federated_comm.h" #include "federated_comm.h"
@ -11,6 +11,7 @@
#include <string> // for string, stoi #include <string> // for string, stoi
#include "../../src/common/common.h" // for Split #include "../../src/common/common.h" // for Split
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg #include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json #include "xgboost/json.h" // for Json
#include "xgboost/logging.h" #include "xgboost/logging.h"
@ -46,9 +47,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
} else { } else {
stub_ = [&] { stub_ = [&] {
grpc::SslCredentialsOptions options; grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert; options.pem_root_certs = common::ReadAll(server_cert);
options.pem_private_key = client_key; options.pem_private_key = common::ReadAll(client_key);
options.pem_cert_chain = client_cert; options.pem_cert_chain = common::ReadAll(client_cert);
grpc::ChannelArguments args; grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port), auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),

View File

@ -39,9 +39,9 @@ class FederatedTracker(RabitTracker):
n_workers: int, n_workers: int,
port: int, port: int,
secure: bool, secure: bool,
server_key_path: str = "", server_key_path: Optional[str] = None,
server_cert_path: str = "", server_cert_path: Optional[str] = None,
client_cert_path: str = "", client_cert_path: Optional[str] = None,
timeout: int = 300, timeout: int = 300,
) -> None: ) -> None:
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
@ -84,7 +84,13 @@ def run_federated_server( # pylint: disable=too-many-arguments
for path in [server_key_path, server_cert_path, client_cert_path] for path in [server_key_path, server_cert_path, client_cert_path]
) )
tracker = FederatedTracker( tracker = FederatedTracker(
n_workers=n_workers, port=port, secure=secure, timeout=timeout n_workers=n_workers,
port=port,
secure=secure,
timeout=timeout,
server_key_path=server_key_path,
server_cert_path=server_cert_path,
client_cert_path=client_cert_path,
) )
tracker.start() tracker.start()

View File

@ -0,0 +1,153 @@
# pylint: disable=unbalanced-tuple-unpacking, too-many-locals
"""Tests for federated learning."""
import multiprocessing
import os
import subprocess
import tempfile
import time
from typing import List, cast
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split
import xgboost as xgb
import xgboost.federated
from xgboost import testing as tm
from xgboost.training import TrainingCallback
SERVER_KEY = "server-key.pem"
SERVER_CERT = "server-cert.pem"
CLIENT_KEY = "client-key.pem"
CLIENT_CERT = "client-cert.pem"
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
"""Run federated server for test."""
if with_ssl:
xgboost.federated.run_federated_server(
world_size,
port,
server_key_path=SERVER_KEY,
server_cert_path=SERVER_CERT,
client_cert_path=CLIENT_CERT,
)
else:
xgboost.federated.run_federated_server(world_size, port)
def run_worker(
port: int, world_size: int, rank: int, with_ssl: bool, device: str
) -> None:
"""Run federated client worker for test."""
communicator_env = {
"dmlc_communicator": "federated",
"federated_server_address": f"localhost:{port}",
"federated_world_size": world_size,
"federated_rank": rank,
}
if with_ssl:
communicator_env["federated_server_cert_path"] = SERVER_CERT
communicator_env["federated_client_key_path"] = CLIENT_KEY
communicator_env["federated_client_cert_path"] = CLIENT_CERT
cpu_count = os.cpu_count()
assert cpu_count is not None
n_threads = cpu_count // world_size
# Always call this before using distributed module
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
dtrain = xgb.DMatrix(X, y)
X, y = load_svmlight_file(f"agaricus.txt-{rank}.test")
dtest = xgb.DMatrix(X, y)
# Specify parameters via map, definition are same as c++ version
param = {
"max_depth": 2,
"eta": 1,
"objective": "binary:logistic",
"nthread": n_threads,
"tree_method": "hist",
"device": device,
}
# Specify validations set to watch performance
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 20
# Run training, all the features in training API is available.
results: TrainingCallback.EvalsLog = {}
bst = xgb.train(
param,
dtrain,
num_round,
evals=watchlist,
early_stopping_rounds=2,
evals_result=results,
)
assert tm.non_increasing(cast(List[float], results["train"]["logloss"]))
assert tm.non_increasing(cast(List[float], results["eval"]["logloss"]))
# save the model, only ask process 0 to save the model.
if xgb.collective.get_rank() == 0:
with tempfile.TemporaryDirectory() as tmpdir:
bst.save_model(os.path.join(tmpdir, "model.json"))
xgb.collective.communicator_print("Finished training\n")
def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
"""Launcher for clients and the server."""
port = 9091
server = multiprocessing.Process(
target=run_server, args=(port, world_size, with_ssl)
)
server.start()
time.sleep(1)
if not server.is_alive():
raise ValueError("Error starting Federated Learning server")
workers = []
for rank in range(world_size):
device = f"cuda:{rank}" if use_gpu else "cpu"
worker = multiprocessing.Process(
target=run_worker, args=(port, world_size, rank, with_ssl, device)
)
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
server.terminate()
def run_federated_learning(with_ssl: bool, use_gpu: bool, test_path: str) -> None:
"""Run federated learning tests."""
n_workers = 2
if with_ssl:
command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost" # pylint: disable=line-too-long
server_key = command.format(part="server").split()
subprocess.check_call(server_key)
client_key = command.format(part="client").split()
subprocess.check_call(client_key)
train_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.train")
test_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.test")
X_train, y_train = load_svmlight_file(train_path)
X_test, y_test = load_svmlight_file(test_path)
X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
X_test, y_test, test_size=0.5
)
dump_svmlight_file(X0, y0, "agaricus.txt-0.train")
dump_svmlight_file(X0_valid, y0_valid, "agaricus.txt-0.test")
dump_svmlight_file(X1, y1, "agaricus.txt-1.train")
dump_svmlight_file(X1_valid, y1_valid, "agaricus.txt-1.test")
run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)

View File

@ -191,8 +191,11 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
} }
if (device.IsCUDA()) { if (device.IsCUDA()) {
device = CUDAOrdinal(device, fail_on_invalid_gpu_id); device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
if (!device.IsCUDA()) {
// We allow loading a GPU-based pickle on a CPU-only machine.
LOG(WARNING) << "XGBoost is not compiled with CUDA support.";
}
} }
return device; return device;
} }
} // namespace } // namespace

View File

@ -34,6 +34,8 @@ class LintersPaths:
"tests/python/test_with_pandas.py", "tests/python/test_with_pandas.py",
"tests/python-gpu/", "tests/python-gpu/",
"tests/python-sycl/", "tests/python-sycl/",
"tests/test_distributed/test_federated/",
"tests/test_distributed/test_gpu_federated/",
"tests/test_distributed/test_with_dask/", "tests/test_distributed/test_with_dask/",
"tests/test_distributed/test_gpu_with_dask/", "tests/test_distributed/test_gpu_with_dask/",
"tests/test_distributed/test_with_spark/", "tests/test_distributed/test_with_spark/",
@ -94,6 +96,8 @@ class LintersPaths:
"tests/python-gpu/load_pickle.py", "tests/python-gpu/load_pickle.py",
"tests/python-gpu/test_gpu_training_continuation.py", "tests/python-gpu/test_gpu_training_continuation.py",
"tests/python/test_model_io.py", "tests/python/test_model_io.py",
"tests/test_distributed/test_federated/",
"tests/test_distributed/test_gpu_federated/",
"tests/test_distributed/test_with_spark/test_data.py", "tests/test_distributed/test_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_spark/test_data.py", "tests/test_distributed/test_gpu_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py", "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",

View File

@ -70,6 +70,7 @@ case "$suite" in
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_federated
unset_pyspark_envs unset_pyspark_envs
uninstall_xgboost uninstall_xgboost
set +x set +x
@ -84,6 +85,7 @@ case "$suite" in
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_federated
unset_pyspark_envs unset_pyspark_envs
uninstall_xgboost uninstall_xgboost
set +x set +x

View File

@ -1,17 +0,0 @@
#!/bin/bash
set -e
rm -f ./*.model* ./agaricus* ./*.pem
world_size=$(nvidia-smi -L | wc -l)
# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
# Split train and test files manually to simulate a federated environment.
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.train agaricus.txt.train-
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.test agaricus.txt.test-
python test_federated.py "${world_size}"

View File

@ -1,86 +1,8 @@
#!/usr/bin/python import pytest
import multiprocessing
import sys
import time
import xgboost as xgb from xgboost.testing.federated import run_federated_learning
import xgboost.federated
SERVER_KEY = 'server-key.pem'
SERVER_CERT = 'server-cert.pem'
CLIENT_KEY = 'client-key.pem'
CLIENT_CERT = 'client-cert.pem'
def run_server(port: int, world_size: int, with_ssl: bool) -> None: @pytest.mark.parametrize("with_ssl", [True, False])
if with_ssl: def test_federated_learning(with_ssl: bool) -> None:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, run_federated_learning(with_ssl, False, __file__)
CLIENT_CERT)
else:
xgboost.federated.run_federated_server(port, world_size)
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
communicator_env = {
'xgboost_communicator': 'federated',
'federated_server_address': f'localhost:{port}',
'federated_world_size': world_size,
'federated_rank': rank
}
if with_ssl:
communicator_env['federated_server_cert'] = SERVER_CERT
communicator_env['federated_client_key'] = CLIENT_KEY
communicator_env['federated_client_cert'] = CLIENT_CERT
# Always call this before using distributed module
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
if with_gpu:
param['tree_method'] = 'hist'
param['device'] = f"cuda:{rank}"
# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20
# Run training, all the features in training API is available.
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
if xgb.collective.get_rank() == 0:
bst.save_model("test.model.json")
xgb.collective.communicator_print("Finished training\n")
def run_federated(with_ssl: bool = True, with_gpu: bool = False) -> None:
port = 9091
world_size = int(sys.argv[1])
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
server.start()
time.sleep(1)
if not server.is_alive():
raise Exception("Error starting Federated Learning server")
workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_worker,
args=(port, world_size, rank, with_ssl, with_gpu))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
server.terminate()
if __name__ == '__main__':
run_federated(with_ssl=True, with_gpu=False)
run_federated(with_ssl=False, with_gpu=False)
run_federated(with_ssl=True, with_gpu=True)
run_federated(with_ssl=False, with_gpu=True)

View File

@ -0,0 +1,9 @@
import pytest
from xgboost.testing.federated import run_federated_learning
@pytest.mark.parametrize("with_ssl", [True, False])
@pytest.mark.mgpu
def test_federated_learning(with_ssl: bool) -> None:
run_federated_learning(with_ssl, True, __file__)