Use in-memory communicator to test quantile (#8710)

This commit is contained in:
Rong Ou
2023-01-27 07:28:28 -08:00
committed by GitHub
parent 96e6b6beba
commit 8af98e30fc
8 changed files with 86 additions and 207 deletions

View File

@@ -1490,62 +1490,6 @@ class TestWithDask:
num_rounds = 10
self.run_updater_test(client, params, num_rounds, dataset, 'approx')
def run_quantile(self, name: str) -> None:
exe: Optional[str] = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/cpubuild/testxgboost',
'../cpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
if exe is None:
return
test = "--gtest_filter=Quantile." + name
def runit(
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
# setup environment for running the c++ part.
env = os.environ.copy()
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
with Client(cluster) as client:
workers = tm.get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = client.map(runit,
workers,
pure=False,
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from Quantile') != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_basic(self) -> None:
self.run_quantile('DistributedBasic')
self.run_quantile('SortedDistributedBasic')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile(self) -> None:
self.run_quantile('Distributed')
self.run_quantile('SortedDistributed')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self) -> None:
self.run_quantile("SameOnAllWorkers")
def test_adaptive(self) -> None:
def get_score(config: Dict) -> float:
return float(config["learner"]["learner_model_param"]["base_score"])