Use in-memory communicator to test quantile (#8710)
This commit is contained in:
@@ -1490,62 +1490,6 @@ class TestWithDask:
|
||||
num_rounds = 10
|
||||
self.run_updater_test(client, params, num_rounds, dataset, 'approx')
|
||||
|
||||
def run_quantile(self, name: str) -> None:
|
||||
exe: Optional[str] = None
|
||||
for possible_path in {'./testxgboost', './build/testxgboost',
|
||||
'../build/cpubuild/testxgboost',
|
||||
'../cpu-build/testxgboost'}:
|
||||
if os.path.exists(possible_path):
|
||||
exe = possible_path
|
||||
if exe is None:
|
||||
return
|
||||
|
||||
test = "--gtest_filter=Quantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
# setup environment for running the c++ part.
|
||||
env = os.environ.copy()
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = tm.get_client_workers(client)
|
||||
rabit_args = client.sync(
|
||||
xgb.dask._get_rabit_args, len(workers), None, client
|
||||
)
|
||||
futures = client.map(runit,
|
||||
workers,
|
||||
pure=False,
|
||||
workers=workers,
|
||||
rabit_args=rabit_args)
|
||||
results = client.gather(futures)
|
||||
|
||||
for ret in results:
|
||||
msg = ret.stdout.decode('utf-8')
|
||||
assert msg.find('1 test from Quantile') != -1, msg
|
||||
assert ret.returncode == 0, msg
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_basic(self) -> None:
|
||||
self.run_quantile('DistributedBasic')
|
||||
self.run_quantile('SortedDistributedBasic')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile(self) -> None:
|
||||
self.run_quantile('Distributed')
|
||||
self.run_quantile('SortedDistributed')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_same_on_all_workers(self) -> None:
|
||||
self.run_quantile("SameOnAllWorkers")
|
||||
|
||||
def test_adaptive(self) -> None:
|
||||
def get_score(config: Dict) -> float:
|
||||
return float(config["learner"]["learner_model_param"]["base_score"])
|
||||
|
||||
Reference in New Issue
Block a user