Fix warnings in GPU dask tests. (#10358)

Jiaming Yuan 2024-06-04 12:58:58 +08:00 committed by GitHub
parent 0808e50ae8
commit 979e392deb
3 changed files with 48 additions and 22 deletions
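
The warnings this commit fixes come mainly from the deprecated "gpu_hist" tree method: since XGBoost 2.0 the accelerator is selected with the device parameter, independently of the tree method, and passing tree_method="gpu_hist" emits a deprecation warning. A minimal sketch of the migration applied throughout the tests below (illustrative, not part of the diff itself):

import xgboost as xgb

# Deprecated spelling: emits a deprecation warning on XGBoost >= 2.0.
clf_old = xgb.XGBClassifier(tree_method="gpu_hist")

# Current spelling: the algorithm ("hist") and the device ("cuda")
# are independent parameters.
clf_new = xgb.XGBClassifier(tree_method="hist", device="cuda")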

File 1 of 3: shared dask test utilities

@@ -1,5 +1,7 @@
 """Tests for dask shared by different test modules."""
 
+from typing import Literal
+
 import numpy as np
 import pandas as pd
 from dask import array as da
@@ -10,19 +12,26 @@ import xgboost as xgb
 from xgboost.testing.updater import get_basescore
 
 
-def check_init_estimation_clf(tree_method: str, client: Client) -> None:
+def check_init_estimation_clf(
+    tree_method: str, device: Literal["cpu", "cuda"], client: Client
+) -> None:
     """Test init estimation for classifier."""
     from sklearn.datasets import make_classification
 
     X, y = make_classification(n_samples=4096 * 2, n_features=32, random_state=1994)
-    clf = xgb.XGBClassifier(n_estimators=1, max_depth=1, tree_method=tree_method)
+    clf = xgb.XGBClassifier(
+        n_estimators=1, max_depth=1, tree_method=tree_method, device=device
+    )
     clf.fit(X, y)
     base_score = get_basescore(clf)
 
     dx = da.from_array(X).rechunk(chunks=(32, None))
     dy = da.from_array(y).rechunk(chunks=(32,))
     dclf = xgb.dask.DaskXGBClassifier(
-        n_estimators=1, max_depth=1, tree_method=tree_method
+        n_estimators=1,
+        max_depth=1,
+        tree_method=tree_method,
+        device=device,
     )
     dclf.client = client
     dclf.fit(dx, dy)
@@ -30,20 +39,24 @@ def check_init_estimation_clf(tree_method: str, client: Client) -> None:
     np.testing.assert_allclose(base_score, dbase_score)
 
 
-def check_init_estimation_reg(tree_method: str, client: Client) -> None:
+def check_init_estimation_reg(
+    tree_method: str, device: Literal["cpu", "cuda"], client: Client
+) -> None:
     """Test init estimation for regressor."""
     from sklearn.datasets import make_regression
 
     # pylint: disable=unbalanced-tuple-unpacking
     X, y = make_regression(n_samples=4096 * 2, n_features=32, random_state=1994)
-    reg = xgb.XGBRegressor(n_estimators=1, max_depth=1, tree_method=tree_method)
+    reg = xgb.XGBRegressor(
+        n_estimators=1, max_depth=1, tree_method=tree_method, device=device
+    )
     reg.fit(X, y)
     base_score = get_basescore(reg)
 
     dx = da.from_array(X).rechunk(chunks=(32, None))
     dy = da.from_array(y).rechunk(chunks=(32,))
     dreg = xgb.dask.DaskXGBRegressor(
-        n_estimators=1, max_depth=1, tree_method=tree_method
+        n_estimators=1, max_depth=1, tree_method=tree_method, device=device
     )
     dreg.client = client
     dreg.fit(dx, dy)
@@ -51,22 +64,26 @@ def check_init_estimation_reg(tree_method: str, client: Client) -> None:
     np.testing.assert_allclose(base_score, dbase_score)
 
 
-def check_init_estimation(tree_method: str, client: Client) -> None:
+def check_init_estimation(
+    tree_method: str, device: Literal["cpu", "cuda"], client: Client
+) -> None:
     """Test init estimation."""
-    check_init_estimation_reg(tree_method, client)
-    check_init_estimation_clf(tree_method, client)
+    check_init_estimation_reg(tree_method, device, client)
+    check_init_estimation_clf(tree_method, device, client)
 
 
-def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None:
+def check_uneven_nan(
+    client: Client, tree_method: str, device: Literal["cpu", "cuda"], n_workers: int
+) -> None:
     """Issue #9271, not every worker has missing value."""
     assert n_workers >= 2
 
     with client.as_current():
-        clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method)
+        clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method, device=device)
         X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
         y = pd.Series([*[0] * 5000, *[1] * 5000])
-        X["a"][:3000:1000] = np.nan
+        X.loc[:3000:1000, "a"] = np.nan
 
         client.wait_for_workers(n_workers=n_workers)
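
The last hunk above also silences a pandas warning: X["a"][:3000:1000] = np.nan is a chained assignment, which may write to a temporary copy and triggers SettingWithCopyWarning, while a single .loc call writes through to the original frame. A small illustration, assuming only pandas and numpy:

import numpy as np
import pandas as pd

X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})

# Chained assignment: X["a"] may return a copy, so the write can be
# lost, and pandas emits SettingWithCopyWarning.
# X["a"][:3000:1000] = np.nan

# Single indexing call: writes through to X itself, no warning.
# Note that .loc slicing is label-based and end-inclusive, so this
# marks rows 0, 1000, 2000, and 3000.
X.loc[:3000:1000, "a"] = np.nan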

File 2 of 3: distributed GPU tests

@@ -230,13 +230,13 @@ class TestDistributedGPU:
         run_boost_from_prediction_multi_class(X, y, "hist", "cuda", local_cuda_client)
 
     def test_init_estimation(self, local_cuda_client: Client) -> None:
-        check_init_estimation("gpu_hist", local_cuda_client)
+        check_init_estimation("hist", "cuda", local_cuda_client)
 
     def test_uneven_nan(self) -> None:
         n_workers = 2
         with LocalCUDACluster(n_workers=n_workers) as cluster:
             with Client(cluster) as client:
-                check_uneven_nan(client, "gpu_hist", n_workers)
+                check_uneven_nan(client, "hist", "cuda", n_workers)
 
     @pytest.mark.skipif(**tm.no_dask_cudf())
     def test_dask_dataframe(self, local_cuda_client: Client) -> None:
@@ -386,7 +386,7 @@ class TestDistributedGPU:
         X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
         y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
         w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
-        run_dask_classifier(X, y, w, model, "gpu_hist", local_cuda_client, 10)
+        run_dask_classifier(X, y, w, model, "hist", "cuda", local_cuda_client, 10)
 
     def test_empty_dmatrix(self, local_cuda_client: Client) -> None:
         parameters = {
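
For context, a self-contained sketch of the pattern these GPU tests exercise after the change; it assumes the dask-cuda package and at least one CUDA device, and the data here is illustrative rather than taken from the tests:

from dask import array as da
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

import xgboost as xgb

with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
    # Small random binary classification problem, chunked for dask.
    X = da.random.random((1000, 8), chunks=(250, 8))
    y = da.random.randint(0, 2, size=1000, chunks=250)
    clf = xgb.dask.DaskXGBClassifier(tree_method="hist", device="cuda")
    clf.client = client
    clf.fit(X, y)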

File 3 of 3: distributed CPU dask tests

@@ -12,7 +12,7 @@ from itertools import starmap
 from math import ceil
 from operator import attrgetter, getitem
 from pathlib import Path
-from typing import Any, Dict, Generator, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, TypeVar, Union
 
 import hypothesis
 import numpy as np
@@ -700,6 +700,7 @@ def run_dask_classifier(
     w: xgb.dask._DaskCollection,
     model: str,
     tree_method: Optional[str],
+    device: Literal["cpu", "cuda"],
     client: "Client",
     n_classes,
 ) -> None:
@@ -707,11 +708,19 @@
     if model == "boosting":
         classifier = xgb.dask.DaskXGBClassifier(
-            verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
+            verbosity=1,
+            n_estimators=2,
+            eval_metric=metric,
+            tree_method=tree_method,
+            device=device,
         )
     else:
         classifier = xgb.dask.DaskXGBRFClassifier(
-            verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
+            verbosity=1,
+            n_estimators=2,
+            eval_metric=metric,
+            tree_method=tree_method,
+            device=device,
         )
 
     assert classifier._estimator_type == "classifier"
@@ -785,12 +794,12 @@ def test_dask_classifier(model: str, client: "Client") -> None:
     X, y, w = generate_array(with_weights=True)
     y = (y * 10).astype(np.int32)
     assert w is not None
-    run_dask_classifier(X, y, w, model, None, client, 10)
+    run_dask_classifier(X, y, w, model, None, "cpu", client, 10)
 
     y_bin = y.copy()
     y_bin[y > 5] = 1.0
     y_bin[y <= 5] = 0.0
-    run_dask_classifier(X, y_bin, w, model, None, client, 2)
+    run_dask_classifier(X, y_bin, w, model, None, "cpu", client, 2)
 
 
 def test_empty_dmatrix_training_continuation(client: "Client") -> None:
@@ -2136,7 +2145,7 @@ def test_parallel_submit_multi_clients() -> None:
 def test_init_estimation(client: Client) -> None:
-    check_init_estimation("hist", client)
+    check_init_estimation("hist", "cpu", client)
 
 
 @pytest.mark.parametrize("tree_method", ["hist", "approx"])
@@ -2144,7 +2153,7 @@ def test_uneven_nan(tree_method) -> None:
     n_workers = 2
     with LocalCluster(n_workers=n_workers) as cluster:
         with Client(cluster) as client:
-            check_uneven_nan(client, tree_method, n_workers)
+            check_uneven_nan(client, tree_method, "cpu", n_workers)
 
 
 class TestDaskCallbacks: