Fix federated learning demos and tests (#9488)
This commit is contained in:
@@ -11,7 +11,7 @@ openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out se
|
||||
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
|
||||
|
||||
# Split train and test files manually to simulate a federated environment.
|
||||
split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
|
||||
split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
|
||||
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.train agaricus.txt.train-
|
||||
split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.test agaricus.txt.test-
|
||||
|
||||
python test_federated.py "${world_size}"
|
||||
|
||||
@@ -35,14 +35,14 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
|
||||
# Always call this before using distributed module
|
||||
with xgb.collective.CommunicatorContext(**communicator_env):
|
||||
# Load file, file will not be sharded in federated mode.
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
|
||||
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
|
||||
dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
|
||||
|
||||
# Specify parameters via map, definition are same as c++ version
|
||||
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
|
||||
if with_gpu:
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
param['gpu_id'] = rank
|
||||
param['tree_method'] = 'hist'
|
||||
param['device'] = f"cuda:{rank}"
|
||||
|
||||
# Specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
|
||||
Reference in New Issue
Block a user