[dask] Extend tree stats tests. (#7128)
* Add tests to GPU. * Assert cover in children sums up to the parent.
This commit is contained in:
@@ -28,6 +28,7 @@ from test_with_dask import _get_client_workers # noqa
|
||||
from test_with_dask import generate_array # noqa
|
||||
from test_with_dask import kCols as random_cols # noqa
|
||||
from test_with_dask import suppress # noqa
|
||||
from test_with_dask import run_tree_stats # noqa
|
||||
import testing as tm # noqa
|
||||
|
||||
|
||||
@@ -493,6 +494,17 @@ class TestDistributedGPU:
|
||||
for rn, drn in zip(ranker_names, dranker_names):
|
||||
assert rn == drn
|
||||
|
||||
def test_tree_stats(self) -> None:
|
||||
with LocalCUDACluster(n_workers=1) as cluster:
|
||||
with Client(cluster) as client:
|
||||
local = run_tree_stats(client, "gpu_hist")
|
||||
|
||||
with LocalCUDACluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
distributed = run_tree_stats(client, "gpu_hist")
|
||||
|
||||
assert local == distributed
|
||||
|
||||
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows")
|
||||
|
||||
Reference in New Issue
Block a user