Disable dense optimization in hist for distributed training. (#9272)
parent 8c1065f645
commit ea0deeca68
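The gist of the change: the dense-data fast path in `hist` is now skipped whenever `collective::IsDistributed()` is true, because a worker's local partition can be fully dense even though the global dataset contains missing values. A minimal standalone sketch of that situation (plain numpy/pandas, no cluster needed; `np.array_split` stands in for what `dd.from_pandas(X, npartitions=2)` hands to two workers):

import numpy as np
import pandas as pd

# Data shaped like the new check_uneven_nan test: 10k rows, two columns.
X = pd.DataFrame({"a": range(10_000), "b": range(10_000, 0, -1)}, dtype=np.float64)
X.iloc[:3000:1000, 0] = np.nan  # NaNs land at rows 0, 1000, 2000 only

# Roughly the split dd.from_pandas(X, npartitions=2) gives two workers.
part0, part1 = np.array_split(X, 2)

print(part0["a"].isna().any())  # True  -> worker 0 sees missing values
print(part1["a"].isna().any())  # False -> worker 1 sees a locally dense block
# Globally the data is sparse, so a dense-only optimization must not be used.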
@@ -1,6 +1,8 @@
 """Tests for dask shared by different test modules."""
 import numpy as np
+import pandas as pd
 from dask import array as da
+from dask import dataframe as dd
 from distributed import Client
 
 import xgboost as xgb
@@ -52,3 +54,22 @@ def check_init_estimation(tree_method: str, client: Client) -> None:
     """Test init estimation."""
     check_init_estimation_reg(tree_method, client)
     check_init_estimation_clf(tree_method, client)
+
+
+def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None:
+    """Issue #9271, not every worker has missing value."""
+    assert n_workers >= 2
+
+    with client.as_current():
+        clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method)
+        X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
+        y = pd.Series([*[0] * 5000, *[1] * 5000])
+
+        X["a"][:3000:1000] = np.NaN
+
+        client.wait_for_workers(n_workers=n_workers)
+
+        clf.fit(
+            dd.from_pandas(X, npartitions=n_workers),
+            dd.from_pandas(y, npartitions=n_workers),
+        )
@@ -285,7 +285,7 @@ struct GPUHistMakerDevice {
         matrix.feature_segments,
         matrix.gidx_fvalue_map,
         matrix.min_fvalue,
-        matrix.is_dense
+        matrix.is_dense && !collective::IsDistributed()
     };
     auto split = this->evaluator_.EvaluateSingleSplit(inputs, shared_inputs);
     return split;
@@ -299,11 +299,11 @@ struct GPUHistMakerDevice {
     std::vector<bst_node_t> nidx(2 * candidates.size());
     auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
     auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
-    EvaluateSplitSharedInputs shared_inputs{
-        GPUTrainingParam{param}, *quantiser, feature_types, matrix.feature_segments,
-        matrix.gidx_fvalue_map, matrix.min_fvalue,
-        matrix.is_dense
-    };
+    EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param}, *quantiser, feature_types,
+                                            matrix.feature_segments, matrix.gidx_fvalue_map,
+                                            matrix.min_fvalue,
+                                            // is_dense represents the local data
+                                            matrix.is_dense && !collective::IsDistributed()};
     dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
     // Store the feature set ptrs so they dont go out of scope before the kernel is called
     std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
@@ -435,7 +435,7 @@ class HistBuilder {
 
     {
       GradientPairPrecise grad_stat;
-      if (p_fmat->IsDense()) {
+      if (p_fmat->IsDense() && !collective::IsDistributed()) {
        /**
         * Specialized code for dense data: For dense data (with no missing value), the sum
         * of gradient histogram is equal to snode[nid]
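The comment in the hunk above is the shortcut being gated: with no missing values, every row lands in some bin of any one feature, so summing that feature's gradient histogram reproduces the node total. A toy illustration (plain Python, hypothetical numbers, not XGBoost's internal data structures) of why that breaks once another worker holds rows with missing values:

# Gradients of three rows for one tree node, spread over two workers.
worker0 = [("r0", 0.5, 1.0), ("r1", -0.25, 2.0)]  # (row, grad, value of feature "a")
worker1 = [("r2", 1.0, None)]                     # feature "a" is missing for r2

# Rows with a missing value contribute to no bin of that feature's histogram.
def feature_hist_sum(rows):
    return sum(g for _, g, v in rows if v is not None)

hist_sum = feature_hist_sum(worker0) + feature_hist_sum(worker1)  # 0.25 after allreduce
node_sum = sum(g for _, g, _ in worker0 + worker1)                # 1.25, the true total

# Worker 0's partition is dense, so its local is_dense flag is true; trusting it
# would take hist_sum (0.25) as the node's gradient sum instead of the true 1.25.
assert hist_sum != node_sum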
@@ -44,7 +44,7 @@ try:
     from dask_cuda import LocalCUDACluster
 
     from xgboost import dask as dxgb
-    from xgboost.testing.dask import check_init_estimation
+    from xgboost.testing.dask import check_init_estimation, check_uneven_nan
 except ImportError:
     pass
 
@@ -224,6 +224,12 @@ class TestDistributedGPU:
     def test_init_estimation(self, local_cuda_client: Client) -> None:
         check_init_estimation("gpu_hist", local_cuda_client)
 
+    def test_uneven_nan(self) -> None:
+        n_workers = 2
+        with LocalCUDACluster(n_workers=n_workers) as cluster:
+            with Client(cluster) as client:
+                check_uneven_nan(client, "gpu_hist", n_workers)
+
     @pytest.mark.skipif(**tm.no_dask_cudf())
     def test_dask_dataframe(self, local_cuda_client: Client) -> None:
         run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
@@ -4,7 +4,6 @@ import json
 import os
 import pickle
 import socket
-import subprocess
 import tempfile
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
@@ -41,7 +40,7 @@ from distributed import Client, LocalCluster
 from toolz import sliding_window  # dependency of dask
 
 from xgboost.dask import DaskDMatrix
-from xgboost.testing.dask import check_init_estimation
+from xgboost.testing.dask import check_init_estimation, check_uneven_nan
 
 dask.config.set({"distributed.scheduler.allowed-failures": False})
 
@@ -2014,6 +2013,14 @@ def test_init_estimation(client: Client) -> None:
     check_init_estimation("hist", client)
 
 
+@pytest.mark.parametrize("tree_method", ["hist", "approx"])
+def test_uneven_nan(tree_method) -> None:
+    n_workers = 2
+    with LocalCluster(n_workers=n_workers) as cluster:
+        with Client(cluster) as client:
+            check_uneven_nan(client, tree_method, n_workers)
+
+
 class TestDaskCallbacks:
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_early_stopping(self, client: "Client") -> None: