[fed] Fixes for the encrypted GRPC backend. (#10503)
This commit is contained in:
parent
5f0c1e902b
commit
a39fef2c67
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost contributors
|
* Copyright 2023-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include "federated_comm.h"
|
#include "federated_comm.h"
|
||||||
|
|
||||||
@ -11,6 +11,7 @@
|
|||||||
#include <string> // for string, stoi
|
#include <string> // for string, stoi
|
||||||
|
|
||||||
#include "../../src/common/common.h" // for Split
|
#include "../../src/common/common.h" // for Split
|
||||||
|
#include "../../src/common/io.h" // for ReadAll
|
||||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||||
#include "xgboost/json.h" // for Json
|
#include "xgboost/json.h" // for Json
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
@ -46,9 +47,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
|
|||||||
} else {
|
} else {
|
||||||
stub_ = [&] {
|
stub_ = [&] {
|
||||||
grpc::SslCredentialsOptions options;
|
grpc::SslCredentialsOptions options;
|
||||||
options.pem_root_certs = server_cert;
|
options.pem_root_certs = common::ReadAll(server_cert);
|
||||||
options.pem_private_key = client_key;
|
options.pem_private_key = common::ReadAll(client_key);
|
||||||
options.pem_cert_chain = client_cert;
|
options.pem_cert_chain = common::ReadAll(client_cert);
|
||||||
grpc::ChannelArguments args;
|
grpc::ChannelArguments args;
|
||||||
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||||
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
|
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
|
||||||
|
|||||||
@ -39,9 +39,9 @@ class FederatedTracker(RabitTracker):
|
|||||||
n_workers: int,
|
n_workers: int,
|
||||||
port: int,
|
port: int,
|
||||||
secure: bool,
|
secure: bool,
|
||||||
server_key_path: str = "",
|
server_key_path: Optional[str] = None,
|
||||||
server_cert_path: str = "",
|
server_cert_path: Optional[str] = None,
|
||||||
client_cert_path: str = "",
|
client_cert_path: Optional[str] = None,
|
||||||
timeout: int = 300,
|
timeout: int = 300,
|
||||||
) -> None:
|
) -> None:
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
@ -84,7 +84,13 @@ def run_federated_server( # pylint: disable=too-many-arguments
|
|||||||
for path in [server_key_path, server_cert_path, client_cert_path]
|
for path in [server_key_path, server_cert_path, client_cert_path]
|
||||||
)
|
)
|
||||||
tracker = FederatedTracker(
|
tracker = FederatedTracker(
|
||||||
n_workers=n_workers, port=port, secure=secure, timeout=timeout
|
n_workers=n_workers,
|
||||||
|
port=port,
|
||||||
|
secure=secure,
|
||||||
|
timeout=timeout,
|
||||||
|
server_key_path=server_key_path,
|
||||||
|
server_cert_path=server_cert_path,
|
||||||
|
client_cert_path=client_cert_path,
|
||||||
)
|
)
|
||||||
tracker.start()
|
tracker.start()
|
||||||
|
|
||||||
|
|||||||
153
python-package/xgboost/testing/federated.py
Normal file
153
python-package/xgboost/testing/federated.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
# pylint: disable=unbalanced-tuple-unpacking, too-many-locals
|
||||||
|
"""Tests for federated learning."""
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from typing import List, cast
|
||||||
|
|
||||||
|
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
import xgboost.federated
|
||||||
|
from xgboost import testing as tm
|
||||||
|
from xgboost.training import TrainingCallback
|
||||||
|
|
||||||
|
SERVER_KEY = "server-key.pem"
|
||||||
|
SERVER_CERT = "server-cert.pem"
|
||||||
|
CLIENT_KEY = "client-key.pem"
|
||||||
|
CLIENT_CERT = "client-cert.pem"
|
||||||
|
|
||||||
|
|
||||||
|
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
||||||
|
"""Run federated server for test."""
|
||||||
|
if with_ssl:
|
||||||
|
xgboost.federated.run_federated_server(
|
||||||
|
world_size,
|
||||||
|
port,
|
||||||
|
server_key_path=SERVER_KEY,
|
||||||
|
server_cert_path=SERVER_CERT,
|
||||||
|
client_cert_path=CLIENT_CERT,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
xgboost.federated.run_federated_server(world_size, port)
|
||||||
|
|
||||||
|
|
||||||
|
def run_worker(
|
||||||
|
port: int, world_size: int, rank: int, with_ssl: bool, device: str
|
||||||
|
) -> None:
|
||||||
|
"""Run federated client worker for test."""
|
||||||
|
communicator_env = {
|
||||||
|
"dmlc_communicator": "federated",
|
||||||
|
"federated_server_address": f"localhost:{port}",
|
||||||
|
"federated_world_size": world_size,
|
||||||
|
"federated_rank": rank,
|
||||||
|
}
|
||||||
|
if with_ssl:
|
||||||
|
communicator_env["federated_server_cert_path"] = SERVER_CERT
|
||||||
|
communicator_env["federated_client_key_path"] = CLIENT_KEY
|
||||||
|
communicator_env["federated_client_cert_path"] = CLIENT_CERT
|
||||||
|
|
||||||
|
cpu_count = os.cpu_count()
|
||||||
|
assert cpu_count is not None
|
||||||
|
n_threads = cpu_count // world_size
|
||||||
|
|
||||||
|
# Always call this before using distributed module
|
||||||
|
with xgb.collective.CommunicatorContext(**communicator_env):
|
||||||
|
# Load file, file will not be sharded in federated mode.
|
||||||
|
X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
|
||||||
|
dtrain = xgb.DMatrix(X, y)
|
||||||
|
X, y = load_svmlight_file(f"agaricus.txt-{rank}.test")
|
||||||
|
dtest = xgb.DMatrix(X, y)
|
||||||
|
|
||||||
|
# Specify parameters via map, definition are same as c++ version
|
||||||
|
param = {
|
||||||
|
"max_depth": 2,
|
||||||
|
"eta": 1,
|
||||||
|
"objective": "binary:logistic",
|
||||||
|
"nthread": n_threads,
|
||||||
|
"tree_method": "hist",
|
||||||
|
"device": device,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Specify validations set to watch performance
|
||||||
|
watchlist = [(dtest, "eval"), (dtrain, "train")]
|
||||||
|
num_round = 20
|
||||||
|
|
||||||
|
# Run training, all the features in training API is available.
|
||||||
|
results: TrainingCallback.EvalsLog = {}
|
||||||
|
bst = xgb.train(
|
||||||
|
param,
|
||||||
|
dtrain,
|
||||||
|
num_round,
|
||||||
|
evals=watchlist,
|
||||||
|
early_stopping_rounds=2,
|
||||||
|
evals_result=results,
|
||||||
|
)
|
||||||
|
assert tm.non_increasing(cast(List[float], results["train"]["logloss"]))
|
||||||
|
assert tm.non_increasing(cast(List[float], results["eval"]["logloss"]))
|
||||||
|
|
||||||
|
# save the model, only ask process 0 to save the model.
|
||||||
|
if xgb.collective.get_rank() == 0:
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
bst.save_model(os.path.join(tmpdir, "model.json"))
|
||||||
|
xgb.collective.communicator_print("Finished training\n")
|
||||||
|
|
||||||
|
|
||||||
|
def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
|
||||||
|
"""Launcher for clients and the server."""
|
||||||
|
port = 9091
|
||||||
|
|
||||||
|
server = multiprocessing.Process(
|
||||||
|
target=run_server, args=(port, world_size, with_ssl)
|
||||||
|
)
|
||||||
|
server.start()
|
||||||
|
time.sleep(1)
|
||||||
|
if not server.is_alive():
|
||||||
|
raise ValueError("Error starting Federated Learning server")
|
||||||
|
|
||||||
|
workers = []
|
||||||
|
for rank in range(world_size):
|
||||||
|
device = f"cuda:{rank}" if use_gpu else "cpu"
|
||||||
|
worker = multiprocessing.Process(
|
||||||
|
target=run_worker, args=(port, world_size, rank, with_ssl, device)
|
||||||
|
)
|
||||||
|
workers.append(worker)
|
||||||
|
worker.start()
|
||||||
|
for worker in workers:
|
||||||
|
worker.join()
|
||||||
|
server.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def run_federated_learning(with_ssl: bool, use_gpu: bool, test_path: str) -> None:
|
||||||
|
"""Run federated learning tests."""
|
||||||
|
n_workers = 2
|
||||||
|
|
||||||
|
if with_ssl:
|
||||||
|
command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost" # pylint: disable=line-too-long
|
||||||
|
server_key = command.format(part="server").split()
|
||||||
|
subprocess.check_call(server_key)
|
||||||
|
client_key = command.format(part="client").split()
|
||||||
|
subprocess.check_call(client_key)
|
||||||
|
|
||||||
|
train_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.train")
|
||||||
|
test_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.test")
|
||||||
|
|
||||||
|
X_train, y_train = load_svmlight_file(train_path)
|
||||||
|
X_test, y_test = load_svmlight_file(test_path)
|
||||||
|
|
||||||
|
X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
|
||||||
|
X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
|
||||||
|
X_test, y_test, test_size=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
dump_svmlight_file(X0, y0, "agaricus.txt-0.train")
|
||||||
|
dump_svmlight_file(X0_valid, y0_valid, "agaricus.txt-0.test")
|
||||||
|
|
||||||
|
dump_svmlight_file(X1, y1, "agaricus.txt-1.train")
|
||||||
|
dump_svmlight_file(X1_valid, y1_valid, "agaricus.txt-1.test")
|
||||||
|
|
||||||
|
run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)
|
||||||
@ -191,8 +191,11 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
|
|||||||
}
|
}
|
||||||
if (device.IsCUDA()) {
|
if (device.IsCUDA()) {
|
||||||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
|
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
|
||||||
|
if (!device.IsCUDA()) {
|
||||||
|
// We allow loading a GPU-based pickle on a CPU-only machine.
|
||||||
|
LOG(WARNING) << "XGBoost is not compiled with CUDA support.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return device;
|
return device;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@ -34,6 +34,8 @@ class LintersPaths:
|
|||||||
"tests/python/test_with_pandas.py",
|
"tests/python/test_with_pandas.py",
|
||||||
"tests/python-gpu/",
|
"tests/python-gpu/",
|
||||||
"tests/python-sycl/",
|
"tests/python-sycl/",
|
||||||
|
"tests/test_distributed/test_federated/",
|
||||||
|
"tests/test_distributed/test_gpu_federated/",
|
||||||
"tests/test_distributed/test_with_dask/",
|
"tests/test_distributed/test_with_dask/",
|
||||||
"tests/test_distributed/test_gpu_with_dask/",
|
"tests/test_distributed/test_gpu_with_dask/",
|
||||||
"tests/test_distributed/test_with_spark/",
|
"tests/test_distributed/test_with_spark/",
|
||||||
@ -94,6 +96,8 @@ class LintersPaths:
|
|||||||
"tests/python-gpu/load_pickle.py",
|
"tests/python-gpu/load_pickle.py",
|
||||||
"tests/python-gpu/test_gpu_training_continuation.py",
|
"tests/python-gpu/test_gpu_training_continuation.py",
|
||||||
"tests/python/test_model_io.py",
|
"tests/python/test_model_io.py",
|
||||||
|
"tests/test_distributed/test_federated/",
|
||||||
|
"tests/test_distributed/test_gpu_federated/",
|
||||||
"tests/test_distributed/test_with_spark/test_data.py",
|
"tests/test_distributed/test_with_spark/test_data.py",
|
||||||
"tests/test_distributed/test_gpu_with_spark/test_data.py",
|
"tests/test_distributed/test_gpu_with_spark/test_data.py",
|
||||||
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
|
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
|
||||||
|
|||||||
@ -70,6 +70,7 @@ case "$suite" in
|
|||||||
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
|
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
|
||||||
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask
|
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask
|
||||||
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark
|
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark
|
||||||
|
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_federated
|
||||||
unset_pyspark_envs
|
unset_pyspark_envs
|
||||||
uninstall_xgboost
|
uninstall_xgboost
|
||||||
set +x
|
set +x
|
||||||
@ -84,6 +85,7 @@ case "$suite" in
|
|||||||
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
|
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
|
||||||
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask
|
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask
|
||||||
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark
|
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark
|
||||||
|
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_federated
|
||||||
unset_pyspark_envs
|
unset_pyspark_envs
|
||||||
uninstall_xgboost
|
uninstall_xgboost
|
||||||
set +x
|
set +x
|
||||||
|
|||||||
@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
rm -f ./*.model* ./agaricus* ./*.pem
|
|
||||||
|
|
||||||
world_size=$(nvidia-smi -L | wc -l)
|
|
||||||
|
|
||||||
# Generate server and client certificates.
|
|
||||||
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
|
|
||||||
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
|
|
||||||
|
|
||||||
# Split train and test files manually to simulate a federated environment.
|
|
||||||
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.train agaricus.txt.train-
|
|
||||||
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.test agaricus.txt.test-
|
|
||||||
|
|
||||||
python test_federated.py "${world_size}"
|
|
||||||
@ -1,86 +1,8 @@
|
|||||||
#!/usr/bin/python
|
import pytest
|
||||||
import multiprocessing
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import xgboost as xgb
|
from xgboost.testing.federated import run_federated_learning
|
||||||
import xgboost.federated
|
|
||||||
|
|
||||||
SERVER_KEY = 'server-key.pem'
|
|
||||||
SERVER_CERT = 'server-cert.pem'
|
|
||||||
CLIENT_KEY = 'client-key.pem'
|
|
||||||
CLIENT_CERT = 'client-cert.pem'
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
@pytest.mark.parametrize("with_ssl", [True, False])
|
||||||
if with_ssl:
|
def test_federated_learning(with_ssl: bool) -> None:
|
||||||
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
|
run_federated_learning(with_ssl, False, __file__)
|
||||||
CLIENT_CERT)
|
|
||||||
else:
|
|
||||||
xgboost.federated.run_federated_server(port, world_size)
|
|
||||||
|
|
||||||
|
|
||||||
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
|
||||||
communicator_env = {
|
|
||||||
'xgboost_communicator': 'federated',
|
|
||||||
'federated_server_address': f'localhost:{port}',
|
|
||||||
'federated_world_size': world_size,
|
|
||||||
'federated_rank': rank
|
|
||||||
}
|
|
||||||
if with_ssl:
|
|
||||||
communicator_env['federated_server_cert'] = SERVER_CERT
|
|
||||||
communicator_env['federated_client_key'] = CLIENT_KEY
|
|
||||||
communicator_env['federated_client_cert'] = CLIENT_CERT
|
|
||||||
|
|
||||||
# Always call this before using distributed module
|
|
||||||
with xgb.collective.CommunicatorContext(**communicator_env):
|
|
||||||
# Load file, file will not be sharded in federated mode.
|
|
||||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
|
|
||||||
dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
|
|
||||||
|
|
||||||
# Specify parameters via map, definition are same as c++ version
|
|
||||||
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
|
|
||||||
if with_gpu:
|
|
||||||
param['tree_method'] = 'hist'
|
|
||||||
param['device'] = f"cuda:{rank}"
|
|
||||||
|
|
||||||
# Specify validations set to watch performance
|
|
||||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
|
||||||
num_round = 20
|
|
||||||
|
|
||||||
# Run training, all the features in training API is available.
|
|
||||||
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
|
|
||||||
early_stopping_rounds=2)
|
|
||||||
|
|
||||||
# Save the model, only ask process 0 to save the model.
|
|
||||||
if xgb.collective.get_rank() == 0:
|
|
||||||
bst.save_model("test.model.json")
|
|
||||||
xgb.collective.communicator_print("Finished training\n")
|
|
||||||
|
|
||||||
|
|
||||||
def run_federated(with_ssl: bool = True, with_gpu: bool = False) -> None:
|
|
||||||
port = 9091
|
|
||||||
world_size = int(sys.argv[1])
|
|
||||||
|
|
||||||
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
|
|
||||||
server.start()
|
|
||||||
time.sleep(1)
|
|
||||||
if not server.is_alive():
|
|
||||||
raise Exception("Error starting Federated Learning server")
|
|
||||||
|
|
||||||
workers = []
|
|
||||||
for rank in range(world_size):
|
|
||||||
worker = multiprocessing.Process(target=run_worker,
|
|
||||||
args=(port, world_size, rank, with_ssl, with_gpu))
|
|
||||||
workers.append(worker)
|
|
||||||
worker.start()
|
|
||||||
for worker in workers:
|
|
||||||
worker.join()
|
|
||||||
server.terminate()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run_federated(with_ssl=True, with_gpu=False)
|
|
||||||
run_federated(with_ssl=False, with_gpu=False)
|
|
||||||
run_federated(with_ssl=True, with_gpu=True)
|
|
||||||
run_federated(with_ssl=False, with_gpu=True)
|
|
||||||
|
|||||||
@ -0,0 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from xgboost.testing.federated import run_federated_learning
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("with_ssl", [True, False])
|
||||||
|
@pytest.mark.mgpu
|
||||||
|
def test_federated_learning(with_ssl: bool) -> None:
|
||||||
|
run_federated_learning(with_ssl, True, __file__)
|
||||||
Loading…
x
Reference in New Issue
Block a user