use RabitContext instead of init/finalize (#7911)

Rong Ou 2022-05-16 21:15:41 -07:00 committed by GitHub
parent 4fcfd9c96e
commit 77d4a53c32
6 changed files with 137 additions and 153 deletions

View File

@@ -230,7 +230,9 @@ def version_number() -> int:
class RabitContext:
"""A context controlling rabit initialization and finalization."""
-def __init__(self, args: List[bytes]) -> None:
+def __init__(self, args: List[bytes] = None) -> None:
+    if args is None:
+        args = []
self.args = args
def __enter__(self) -> None:
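For orientation, a minimal sketch of how the amended class presumably fits together; the __enter__/__exit__ bodies are an assumption inferred from the docstring and from the call sites converted below, not code copied from this commit:

from typing import Any, List, Optional

class RabitContext:
    """A context controlling rabit initialization and finalization."""

    def __init__(self, args: Optional[List[bytes]] = None) -> None:
        # Default to an empty argument list so a bare RabitContext() works.
        if args is None:
            args = []
        self.args = args

    def __enter__(self) -> None:
        # Assumed: entering the context initializes rabit with the stored args;
        # init() is the module-level function defined alongside this class.
        init(self.args)

    def __exit__(self, *exc: Any) -> None:
        # Assumed: leaving the context finalizes rabit, even if the body raised.
        finalize()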

View File

@@ -8,7 +8,7 @@ import numpy as np
def run_test(name, params_fun):
"""Runs a distributed GPU test."""
# Always call this before using distributed module
-xgb.rabit.init()
+with xgb.rabit.RabitContext():
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()
@@ -47,8 +47,6 @@ def run_test(name, params_fun):
('Worker models diverged: test.model.%s.%d '
'differs from test.model.%s.%d') % (name, i, name, j))
-xgb.rabit.finalize()
base_params = {
'tree_method': 'gpu_hist',

View File

@@ -2,8 +2,7 @@
import xgboost as xgb
# Always call this before using distributed module
-xgb.rabit.init()
+with xgb.rabit.RabitContext():
# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
@@ -23,7 +22,3 @@ bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n")
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
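Every call site in this commit follows the same mechanical migration as the walkthrough above: the explicit init()/finalize() pair disappears and the code between them moves inside a with block. Schematically (an illustrative sketch, not one of the files in this commit):

import xgboost as xgb

# The context manager now owns the init/finalize lifecycle; the calls shown
# here (get_rank, get_world_size, tracker_print) are taken from the converted
# scripts above.
with xgb.rabit.RabitContext():
    rank = xgb.rabit.get_rank()
    world = xgb.rabit.get_world_size()
    xgb.rabit.tracker_print("worker %d of %d is up\n" % (rank, world))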

View File

@@ -27,8 +27,7 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]
-xgb.rabit.init([e.encode() for e in rabit_env])
+with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
@@ -41,18 +40,14 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
num_round = 20
# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
-bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)
+bst = xgb.train(param, dtrain, num_round, evals=watchlist,
+                early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model.json")
xgb.rabit.tracker_print("Finished training\n")
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
def run_test() -> None:
port = 9091
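As the federated worker above shows, RabitContext forwards the same list of byte-encoded "key=value" arguments that init() previously accepted. A hedged sketch of that calling convention; only federated_client_key and federated_client_cert appear in the hunk above, so the other key names and values below are assumptions for illustration, and a federated server is assumed to be listening:

import xgboost as xgb

# Illustrative placeholder values; a real run needs a federated server and
# matching certificates.
rabit_env = [
    "federated_server_address=localhost:9091",
    "federated_world_size=2",
    "federated_rank=0",
]
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
    xgb.rabit.tracker_print("federated worker %d is up\n" % xgb.rabit.get_rank())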

View File

@@ -2,8 +2,7 @@
import xgboost as xgb
import numpy as np
-xgb.rabit.init()
+with xgb.rabit.RabitContext():
X = [
[15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05,
4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00,
@@ -73,7 +72,3 @@ bst = xgb.train(param, dtrain, num_round, watchlist)
if xgb.rabit.get_rank() == 0:
bst.save_model("test_issue3402.model")
xgb.rabit.tracker_print("Finished training\n")
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()

View File

@@ -16,10 +16,9 @@ def test_rabit_tracker():
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
-xgb.rabit.init(rabit_env)
+with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
-xgb.rabit.finalize()
def run_rabit_ops(client, n_workers):
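A practical consequence of the switch, illustrated below under the assumption (implied by the class docstring) that __exit__ performs finalization: cleanup now runs even when the body raises, which the old explicit init()/finalize() pairing did not guarantee.

import xgboost as xgb

try:
    with xgb.rabit.RabitContext():
        xgb.rabit.tracker_print("about to fail\n")
        raise RuntimeError("simulated worker failure")
except RuntimeError:
    # __exit__ has already run here, so rabit was finalized without any
    # explicit finalize() call in the script.
    pass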