Optimize CPU sketch allreduce for sparse data. (#6009)

* Bypass the RABIT serialization reducer and use custom allgather-based merging.
This commit is contained in:
Jiaming Yuan
2020-08-19 10:03:45 +08:00
committed by GitHub
parent 90355b4f00
commit 29b7fea572
10 changed files with 357 additions and 87 deletions

View File

@@ -501,17 +501,20 @@ class TestWithDask:
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
history = history['train'][dataset.metric]
assert tm.non_increasing(history)
# Make sure that it's decreasing
assert history[-1] < history[0]
@given(params=hist_parameter_strategy,
       num_rounds=strategies.integers(20, 30),
       dataset=tm.dataset_strategy)
@settings(deadline=None)
def test_hist(self, params, num_rounds, dataset, client):
    """Run the shared updater check with the ``'hist'`` tree method.

    ``params``, ``num_rounds`` (an integer in the range 20-30) and
    ``dataset`` are generated by the strategies named in the ``@given``
    decorator; the per-example deadline is disabled via
    ``@settings(deadline=None)``.

    NOTE(review): ``client`` is not supplied by ``@given`` -- presumably a
    pytest fixture providing a dask client; confirm against the fixture
    definitions.
    """
    # All verification is delegated to run_updater_test; this method only
    # selects the tree method.
    self.run_updater_test(client, params, num_rounds, dataset, 'hist')
@given(params=exact_parameter_strategy,
num_rounds=strategies.integers(10, 20),
num_rounds=strategies.integers(20, 30),
dataset=tm.dataset_strategy)
@settings(deadline=None)
def test_approx(self, client, params, num_rounds, dataset):
@@ -524,8 +527,7 @@ class TestWithDask:
exe = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/testxgboost',
'../cpu-build/testxgboost',
'../gpu-build/testxgboost'}:
'../cpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
if exe is None:
@@ -542,7 +544,7 @@ class TestWithDask:
port = port.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
return subprocess.run([exe, test], env=env, capture_output=True)
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
@@ -555,6 +557,7 @@ class TestWithDask:
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from Quantile') != -1, msg
@@ -563,4 +566,14 @@ class TestWithDask:
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_basic(self):
    """Run the ``'DistributedBasic'`` case through ``run_quantile``.

    Skipped under the condition returned by ``tm.no_dask()``.

    NOTE(review): the string is presumably the name of a native test case
    selected by ``run_quantile`` -- confirm against that helper.
    """
    case_name = 'DistributedBasic'
    self.run_quantile(case_name)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile(self):
    """Run the ``'Distributed'`` case through ``run_quantile``.

    Skipped under the condition returned by ``tm.no_dask()``.

    NOTE(review): the string is presumably the name of a native test case
    selected by ``run_quantile`` -- confirm against that helper.
    """
    case_name = 'Distributed'
    self.run_quantile(case_name)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self):
    """Run the ``'SameOnAllWorkers'`` case through ``run_quantile``.

    Skipped under the condition returned by ``tm.no_dask()``.

    NOTE(review): the string is presumably the name of a native test case
    selected by ``run_quantile`` -- confirm against that helper.
    """
    case_name = 'SameOnAllWorkers'
    self.run_quantile(case_name)