Disable dense optimization in hist for distributed training. (#9272)
This commit is contained in:
@@ -44,7 +44,7 @@ try:
|
||||
from dask_cuda import LocalCUDACluster
|
||||
|
||||
from xgboost import dask as dxgb
|
||||
from xgboost.testing.dask import check_init_estimation
|
||||
from xgboost.testing.dask import check_init_estimation, check_uneven_nan
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -224,6 +224,12 @@ class TestDistributedGPU:
|
||||
def test_init_estimation(self, local_cuda_client: Client) -> None:
|
||||
check_init_estimation("gpu_hist", local_cuda_client)
|
||||
|
||||
def test_uneven_nan(self) -> None:
|
||||
n_workers = 2
|
||||
with LocalCUDACluster(n_workers=n_workers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
check_uneven_nan(client, "gpu_hist", n_workers)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask_cudf())
|
||||
def test_dask_dataframe(self, local_cuda_client: Client) -> None:
|
||||
run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
@@ -41,7 +40,7 @@ from distributed import Client, LocalCluster
|
||||
from toolz import sliding_window # dependency of dask
|
||||
|
||||
from xgboost.dask import DaskDMatrix
|
||||
from xgboost.testing.dask import check_init_estimation
|
||||
from xgboost.testing.dask import check_init_estimation, check_uneven_nan
|
||||
|
||||
dask.config.set({"distributed.scheduler.allowed-failures": False})
|
||||
|
||||
@@ -2014,6 +2013,14 @@ def test_init_estimation(client: Client) -> None:
|
||||
check_init_estimation("hist", client)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||
def test_uneven_nan(tree_method) -> None:
|
||||
n_workers = 2
|
||||
with LocalCluster(n_workers=n_workers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
check_uneven_nan(client, tree_method, n_workers)
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client: "Client") -> None:
|
||||
|
||||
Reference in New Issue
Block a user