Optimize cpu sketch allreduce for sparse data. (#6009)
* Bypass RABIT serialization reducer and use custom allgather based merging.
This commit is contained in:
@@ -501,17 +501,20 @@ class TestWithDask:
|
||||
num_boost_round=num_rounds,
|
||||
evals=[(m, 'train')])['history']
|
||||
note(history)
|
||||
assert tm.non_increasing(history['train'][dataset.metric])
|
||||
history = history['train'][dataset.metric]
|
||||
assert tm.non_increasing(history)
|
||||
# Make sure that it's decreasing
|
||||
assert history[-1] < history[0]
|
||||
|
||||
@given(params=hist_parameter_strategy,
|
||||
num_rounds=strategies.integers(10, 20),
|
||||
num_rounds=strategies.integers(20, 30),
|
||||
dataset=tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_hist(self, params, num_rounds, dataset, client):
|
||||
self.run_updater_test(client, params, num_rounds, dataset, 'hist')
|
||||
|
||||
@given(params=exact_parameter_strategy,
|
||||
num_rounds=strategies.integers(10, 20),
|
||||
num_rounds=strategies.integers(20, 30),
|
||||
dataset=tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_approx(self, client, params, num_rounds, dataset):
|
||||
@@ -524,8 +527,7 @@ class TestWithDask:
|
||||
exe = None
|
||||
for possible_path in {'./testxgboost', './build/testxgboost',
|
||||
'../build/testxgboost',
|
||||
'../cpu-build/testxgboost',
|
||||
'../gpu-build/testxgboost'}:
|
||||
'../cpu-build/testxgboost'}:
|
||||
if os.path.exists(possible_path):
|
||||
exe = possible_path
|
||||
if exe is None:
|
||||
@@ -542,7 +544,7 @@ class TestWithDask:
|
||||
port = port.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
|
||||
return subprocess.run([exe, test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
@@ -555,6 +557,7 @@ class TestWithDask:
|
||||
workers=workers,
|
||||
rabit_args=rabit_args)
|
||||
results = client.gather(futures)
|
||||
|
||||
for ret in results:
|
||||
msg = ret.stdout.decode('utf-8')
|
||||
assert msg.find('1 test from Quantile') != -1, msg
|
||||
@@ -563,4 +566,14 @@ class TestWithDask:
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_basic(self):
|
||||
self.run_quantile('DistributedBasic')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile(self):
|
||||
self.run_quantile('Distributed')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_same_on_all_workers(self):
|
||||
self.run_quantile('SameOnAllWorkers')
|
||||
|
||||
Reference in New Issue
Block a user