Fix federated learning demos and tests (#9488)

This commit is contained in:
Sean Yang
2023-08-16 00:25:05 -07:00
committed by GitHub
parent b2e93d2742
commit 12fe2fc06c
10 changed files with 60 additions and 11 deletions

View File

@@ -11,7 +11,7 @@ openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out se
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
# Split train and test files manually to simulate a federated environment.
split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.train agaricus.txt.train-
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.test agaricus.txt.test-
python test_federated.py "${world_size}"

View File

@@ -35,14 +35,14 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
# Always call this before using distributed module
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
if with_gpu:
param['tree_method'] = 'gpu_hist'
param['gpu_id'] = rank
param['tree_method'] = 'hist'
param['device'] = f"cuda:{rank}"
# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]