Disable dense optimization in hist for distributed training. (#9272)

commit ea0deeca68 (parent 8c1065f645)
Author: Jiaming Yuan
Date:   2023-06-10 02:31:34 +08:00, committed by GitHub
5 changed files with 44 additions and 10 deletions
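Background (not part of the commit message): in both the CPU `hist` and GPU `gpu_hist` updaters, the `is_dense` / `IsDense()` flags describe only the worker-local data shard. Under distributed training the per-node histograms are all-reduced across workers, so a shortcut that relies on dense data can assume a property that holds locally but not globally (issue #9271: not every worker has missing values). The change therefore guards each such shortcut with `!collective::IsDistributed()`. A plain-Python restatement of the guard, using a hypothetical helper that is not part of the xgboost API:

def use_dense_optimization(local_is_dense: bool, world_size: int) -> bool:
    """Sketch of `matrix.is_dense && !collective::IsDistributed()`.

    The density flag describes only this worker's shard, so it is only
    trusted when a single worker participates in training.
    """
    return local_is_dense and world_size <= 1

# A worker whose local shard happens to be dense still takes the sparse
# code path as soon as training is distributed:
assert use_dense_optimization(local_is_dense=True, world_size=1) is True
assert use_dense_optimization(local_is_dense=True, world_size=2) is False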

View File

@@ -1,6 +1,8 @@
 """Tests for dask shared by different test modules."""
 import numpy as np
+import pandas as pd
 from dask import array as da
+from dask import dataframe as dd
 from distributed import Client

 import xgboost as xgb
@@ -52,3 +54,22 @@ def check_init_estimation(tree_method: str, client: Client) -> None:
     """Test init estimation."""
     check_init_estimation_reg(tree_method, client)
     check_init_estimation_clf(tree_method, client)
+
+
+def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None:
+    """Issue #9271, not every worker has missing value."""
+    assert n_workers >= 2
+
+    with client.as_current():
+        clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method)
+        X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
+        y = pd.Series([*[0] * 5000, *[1] * 5000])
+
+        X["a"][:3000:1000] = np.NaN
+
+        client.wait_for_workers(n_workers=n_workers)
+
+        clf.fit(
+            dd.from_pandas(X, npartitions=n_workers),
+            dd.from_pandas(y, npartitions=n_workers),
+        )
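For reference (not part of the diff): the assignment above touches only rows 0, 1000 and 2000, so when the frame is split into `n_workers` partitions only the first partition carries any missing values, which is exactly the uneven situation from issue #9271. A quick plain-pandas check of where the NaNs land, a sketch with hypothetical variable names and `.loc` used for the same three rows:

import numpy as np
import pandas as pd

# Same data as the helper above; column "a" built as float so NaN fits.
X = pd.DataFrame({"a": np.arange(10000, dtype=np.float64), "b": range(10000, 0, -1)})
X.loc[[0, 1000, 2000], "a"] = np.nan

# dd.from_pandas(X, npartitions=2) splits this index roughly in half,
# so only the first half sees any missing values.
first, second = X.iloc[:5000], X.iloc[5000:]
print(int(first["a"].isna().sum()), int(second["a"].isna().sum()))  # 3 0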

View File

@@ -285,7 +285,7 @@ struct GPUHistMakerDevice {
         matrix.feature_segments,
         matrix.gidx_fvalue_map,
         matrix.min_fvalue,
-        matrix.is_dense
+        matrix.is_dense && !collective::IsDistributed()
     };
     auto split = this->evaluator_.EvaluateSingleSplit(inputs, shared_inputs);
     return split;
@@ -299,11 +299,11 @@ struct GPUHistMakerDevice {
     std::vector<bst_node_t> nidx(2 * candidates.size());
     auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
     auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
-    EvaluateSplitSharedInputs shared_inputs{
-        GPUTrainingParam{param}, *quantiser, feature_types, matrix.feature_segments,
-        matrix.gidx_fvalue_map, matrix.min_fvalue,
-        matrix.is_dense
-    };
+    EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param}, *quantiser, feature_types,
+                                            matrix.feature_segments, matrix.gidx_fvalue_map,
+                                            matrix.min_fvalue,
+                                            // is_dense represents the local data
+                                            matrix.is_dense && !collective::IsDistributed()};
     dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
     // Store the feature set ptrs so they dont go out of scope before the kernel is called
     std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;

View File

@@ -435,7 +435,7 @@ class HistBuilder {
     {
       GradientPairPrecise grad_stat;
-      if (p_fmat->IsDense()) {
+      if (p_fmat->IsDense() && !collective::IsDistributed()) {
        /**
         * Specialized code for dense data: For dense data (with no missing value), the sum
         * of gradient histogram is equal to snode[nid]
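To illustrate the comment above (a sketch with hypothetical names, not xgboost code): for dense data every row falls into exactly one bin of every feature, so summing one feature's histogram recovers the node's gradient sum. Once histograms from several workers are combined and only some workers have missing values, that identity breaks, which is why the shortcut is now skipped whenever training is distributed.

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-worker data: row gradients and the bin index of one
# feature per row; -1 marks a missing value. Only worker 1 has any,
# so worker 0 considers its local shard dense.
grad_w0, bins_w0 = rng.normal(size=100), rng.integers(0, 8, size=100)
grad_w1, bins_w1 = rng.normal(size=100), rng.integers(0, 8, size=100)
bins_w1[:10] = -1


def feature_histogram(grad, bins, n_bins=8):
    """Gradient sum per bin; rows with a missing value fall into no bin."""
    hist = np.zeros(n_bins)
    for g, b in zip(grad, bins):
        if b >= 0:
            hist[b] += g
    return hist


# Locally dense shard: the shortcut is exact.
print(np.isclose(feature_histogram(grad_w0, bins_w0).sum(), grad_w0.sum()))  # True

# After histograms are summed across workers (the all-reduce), the missing
# rows of worker 1 sit in no bin, so the shortcut undercounts the node sum.
global_hist = feature_histogram(grad_w0, bins_w0) + feature_histogram(grad_w1, bins_w1)
print(np.isclose(global_hist.sum(), grad_w0.sum() + grad_w1.sum()))  # False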

View File

@@ -44,7 +44,7 @@ try:
     from dask_cuda import LocalCUDACluster
     from xgboost import dask as dxgb
-    from xgboost.testing.dask import check_init_estimation
+    from xgboost.testing.dask import check_init_estimation, check_uneven_nan
 except ImportError:
     pass
@@ -224,6 +224,12 @@ class TestDistributedGPU:
     def test_init_estimation(self, local_cuda_client: Client) -> None:
         check_init_estimation("gpu_hist", local_cuda_client)

+    def test_uneven_nan(self) -> None:
+        n_workers = 2
+        with LocalCUDACluster(n_workers=n_workers) as cluster:
+            with Client(cluster) as client:
+                check_uneven_nan(client, "gpu_hist", n_workers)
+
     @pytest.mark.skipif(**tm.no_dask_cudf())
     def test_dask_dataframe(self, local_cuda_client: Client) -> None:
         run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)

View File

@@ -4,7 +4,6 @@ import json
 import os
 import pickle
 import socket
-import subprocess
 import tempfile
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
@@ -41,7 +40,7 @@
 from distributed import Client, LocalCluster
 from toolz import sliding_window  # dependency of dask
 from xgboost.dask import DaskDMatrix
-from xgboost.testing.dask import check_init_estimation
+from xgboost.testing.dask import check_init_estimation, check_uneven_nan

 dask.config.set({"distributed.scheduler.allowed-failures": False})
@@ -2014,6 +2013,14 @@ def test_init_estimation(client: Client) -> None:
     check_init_estimation("hist", client)


+@pytest.mark.parametrize("tree_method", ["hist", "approx"])
+def test_uneven_nan(tree_method) -> None:
+    n_workers = 2
+    with LocalCluster(n_workers=n_workers) as cluster:
+        with Client(cluster) as client:
+            check_uneven_nan(client, tree_method, n_workers)
+
+
 class TestDaskCallbacks:
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_early_stopping(self, client: "Client") -> None: