Fix warnings in GPU dask tests. (#10358)

This commit is contained in:
Jiaming Yuan
2024-06-04 12:58:58 +08:00
committed by GitHub
parent 0808e50ae8
commit 979e392deb
3 changed files with 48 additions and 22 deletions

View File

@@ -12,7 +12,7 @@ from itertools import starmap
from math import ceil
from operator import attrgetter, getitem
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, TypeVar, Union
import hypothesis
import numpy as np
@@ -700,6 +700,7 @@ def run_dask_classifier(
w: xgb.dask._DaskCollection,
model: str,
tree_method: Optional[str],
device: Literal["cpu", "cuda"],
client: "Client",
n_classes,
) -> None:
@@ -707,11 +708,19 @@ def run_dask_classifier(
if model == "boosting":
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
verbosity=1,
n_estimators=2,
eval_metric=metric,
tree_method=tree_method,
device=device,
)
else:
classifier = xgb.dask.DaskXGBRFClassifier(
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
verbosity=1,
n_estimators=2,
eval_metric=metric,
tree_method=tree_method,
device=device,
)
assert classifier._estimator_type == "classifier"
@@ -785,12 +794,12 @@ def test_dask_classifier(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
assert w is not None
run_dask_classifier(X, y, w, model, None, client, 10)
run_dask_classifier(X, y, w, model, None, "cpu", client, 10)
y_bin = y.copy()
y_bin[y > 5] = 1.0
y_bin[y <= 5] = 0.0
run_dask_classifier(X, y_bin, w, model, None, client, 2)
run_dask_classifier(X, y_bin, w, model, None, "cpu", client, 2)
def test_empty_dmatrix_training_continuation(client: "Client") -> None:
@@ -2136,7 +2145,7 @@ def test_parallel_submit_multi_clients() -> None:
def test_init_estimation(client: Client) -> None:
check_init_estimation("hist", client)
check_init_estimation("hist", "cpu", client)
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
@@ -2144,7 +2153,7 @@ def test_uneven_nan(tree_method) -> None:
n_workers = 2
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
check_uneven_nan(client, tree_method, n_workers)
check_uneven_nan(client, tree_method, "cpu", n_workers)
class TestDaskCallbacks: