Use RabitContext instead of init/finalize (#7911)
commit 77d4a53c32
parent 4fcfd9c96e
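The change replaces paired xgb.rabit.init() / xgb.rabit.finalize() calls with the RabitContext context manager, so rabit is finalized whenever the with block exits, including on errors. A minimal before/after sketch of the pattern (the body shown is a placeholder, not code from this commit):

import xgboost as xgb

# Before: explicit init/finalize bracketing the distributed section.
xgb.rabit.init()
rank = xgb.rabit.get_rank()
# ... distributed training ...
xgb.rabit.finalize()

# After: initialization happens on entry, finalization on exit.
with xgb.rabit.RabitContext():
    rank = xgb.rabit.get_rank()
    # ... distributed training ...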
@@ -230,7 +230,9 @@ def version_number() -> int:
 class RabitContext:
     """A context controlling rabit initialization and finalization."""

-    def __init__(self, args: List[bytes]) -> None:
+    def __init__(self, args: List[bytes] = None) -> None:
+        if args is None:
+            args = []
         self.args = args

     def __enter__(self) -> None:
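Defaulting args to None and normalizing it to an empty list lets callers write RabitContext() with no tracker arguments, as the tests and demos below do, while avoiding a mutable default argument. The hunk stops at the __enter__ signature, so the method bodies are not visible here; a plausible sketch of the full context manager, assuming __enter__ forwards the stored args to the module-level init and __exit__ calls finalize (only the __init__ change is taken from the diff, the rest is an assumption):

from typing import List


class RabitContext:
    """A context controlling rabit initialization and finalization."""

    def __init__(self, args: List[bytes] = None) -> None:
        if args is None:
            args = []
        self.args = args

    def __enter__(self) -> None:
        init(self.args)  # assumed: module-level rabit init

    def __exit__(self, *args: List) -> None:
        finalize()  # assumed: module-level rabit finalize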
@@ -8,7 +8,7 @@ import numpy as np
 def run_test(name, params_fun):
     """Runs a distributed GPU test."""
     # Always call this before using distributed module
-    xgb.rabit.init()
+    with xgb.rabit.RabitContext():
     rank = xgb.rabit.get_rank()
     world = xgb.rabit.get_world_size()

@@ -47,8 +47,6 @@ def run_test(name, params_fun):
         ('Worker models diverged: test.model.%s.%d '
          'differs from test.model.%s.%d') % (name, i, name, j))

-    xgb.rabit.finalize()
-

 base_params = {
     'tree_method': 'gpu_hist',
@@ -2,28 +2,23 @@
 import xgboost as xgb

 # Always call this before using distributed module
-xgb.rabit.init()
-
-# Load file, file will be automatically sharded in distributed mode.
-dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
-dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
-
-# Specify parameters via map, definition are same as c++ version
-param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
-
-# Specify validations set to watch performance
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
-num_round = 20
-
-# Run training, all the features in training API is available.
-# Currently, this script only support calling train once for fault recovery purpose.
-bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
-
-# Save the model, only ask process 0 to save the model.
-if xgb.rabit.get_rank() == 0:
-    bst.save_model("test.model")
-    xgb.rabit.tracker_print("Finished training\n")
-
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
+with xgb.rabit.RabitContext():
+    # Load file, file will be automatically sharded in distributed mode.
+    dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+    dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+    # Specify parameters via map, definition are same as c++ version
+    param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
+
+    # Specify validations set to watch performance
+    watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+    num_round = 20
+
+    # Run training, all the features in training API is available.
+    # Currently, this script only support calling train once for fault recovery purpose.
+    bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
+
+    # Save the model, only ask process 0 to save the model.
+    if xgb.rabit.get_rank() == 0:
+        bst.save_model("test.model")
+        xgb.rabit.tracker_print("Finished training\n")
@@ -27,8 +27,7 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
         f'federated_client_key={CLIENT_KEY}',
         f'federated_client_cert={CLIENT_CERT}'
     ]
-    xgb.rabit.init([e.encode() for e in rabit_env])
-
+    with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
     # Load file, file will not be sharded in federated mode.
     dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
     dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
@@ -41,18 +40,14 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
     num_round = 20

     # Run training, all the features in training API is available.
-    # Currently, this script only support calling train once for fault recovery purpose.
-    bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)
+    bst = xgb.train(param, dtrain, num_round, evals=watchlist,
+                    early_stopping_rounds=2)

     # Save the model, only ask process 0 to save the model.
     if xgb.rabit.get_rank() == 0:
         bst.save_model("test.model.json")
         xgb.rabit.tracker_print("Finished training\n")

-    # Notify the tracker all training has been successful
-    # This is only needed in distributed training.
-    xgb.rabit.finalize()
-

 def run_test() -> None:
     port = 9091
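Where tracker or federated parameters are needed, they are passed to RabitContext the same way they were previously passed to init(): as a list of key=value strings encoded to bytes. A minimal sketch of that calling convention; the dictionary contents here are placeholder values, in practice they come from the tracker (the worker_env dict in the test further below) or from the federated configuration shown above:

import xgboost as xgb

# Placeholder worker environment; real keys and values come from the tracker.
worker_env = {'DMLC_TRACKER_URI': '127.0.0.1', 'DMLC_TRACKER_PORT': '9091'}
rabit_env = [f'{k}={v}'.encode() for k, v in worker_env.items()]

with xgb.rabit.RabitContext(rabit_env):
    rank = xgb.rabit.get_rank()
    world = xgb.rabit.get_world_size()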
@@ -2,9 +2,8 @@
 import xgboost as xgb
 import numpy as np

-xgb.rabit.init()
-
-X = [
+with xgb.rabit.RabitContext():
+    X = [
 [15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05,
 4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00,
 0.01,138.00,1.00,0.02,69.90,0.00,0.83,5.00,0.01,0.12,47.30,0.00,296.00,0.16,0.00,
@@ -60,20 +59,16 @@ X = [
 4415.50,22731.62,1.00,55.00,0.00,499.94,22.00,0.58,67.00,0.21,341.72,16.00,0.00,965.07,
 17.00,138.41,0.00,0.00,1.00,0.14,1.00,0.02,0.35,1.69,369.00,1300.00,25.00,0.00,0.01,
 0.00,0.00,0.00,0.00,52.00,8.00]]
 X = np.array(X)
 y = [1, 0]

 dtrain = xgb.DMatrix(X, label=y)

 param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic' }
 watchlist = [(dtrain,'train')]
 num_round = 2
 bst = xgb.train(param, dtrain, num_round, watchlist)

 if xgb.rabit.get_rank() == 0:
     bst.save_model("test_issue3402.model")
     xgb.rabit.tracker_print("Finished training\n")
-
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
@@ -16,10 +16,9 @@ def test_rabit_tracker():
     rabit_env = []
     for k, v in worker_env.items():
         rabit_env.append(f"{k}={v}".encode())
-    xgb.rabit.init(rabit_env)
+    with xgb.rabit.RabitContext(rabit_env):
     ret = xgb.rabit.broadcast('test1234', 0)
     assert str(ret) == 'test1234'
-    xgb.rabit.finalize()


 def run_rabit_ops(client, n_workers):
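The finalize() calls removed throughout this commit only ran when everything before them succeeded; a failed assertion or a training error would previously leave rabit initialized. The context manager is roughly equivalent to the following try/finally, a sketch assuming __exit__ does nothing more than call finalize():

xgb.rabit.init(rabit_env)
try:
    ret = xgb.rabit.broadcast('test1234', 0)
    assert str(ret) == 'test1234'
finally:
    xgb.rabit.finalize()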