Fix warnings in GPU dask tests. (#10358)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Tests for dask shared by different test modules."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dask import array as da
|
||||
@@ -10,19 +12,26 @@ import xgboost as xgb
|
||||
from xgboost.testing.updater import get_basescore
|
||||
|
||||
|
||||
def check_init_estimation_clf(tree_method: str, client: Client) -> None:
|
||||
def check_init_estimation_clf(
|
||||
tree_method: str, device: Literal["cpu", "cuda"], client: Client
|
||||
) -> None:
|
||||
"""Test init estimation for classsifier."""
|
||||
from sklearn.datasets import make_classification
|
||||
|
||||
X, y = make_classification(n_samples=4096 * 2, n_features=32, random_state=1994)
|
||||
clf = xgb.XGBClassifier(n_estimators=1, max_depth=1, tree_method=tree_method)
|
||||
clf = xgb.XGBClassifier(
|
||||
n_estimators=1, max_depth=1, tree_method=tree_method, device=device
|
||||
)
|
||||
clf.fit(X, y)
|
||||
base_score = get_basescore(clf)
|
||||
|
||||
dx = da.from_array(X).rechunk(chunks=(32, None))
|
||||
dy = da.from_array(y).rechunk(chunks=(32,))
|
||||
dclf = xgb.dask.DaskXGBClassifier(
|
||||
n_estimators=1, max_depth=1, tree_method=tree_method
|
||||
n_estimators=1,
|
||||
max_depth=1,
|
||||
tree_method=tree_method,
|
||||
device=device,
|
||||
)
|
||||
dclf.client = client
|
||||
dclf.fit(dx, dy)
|
||||
@@ -30,20 +39,24 @@ def check_init_estimation_clf(tree_method: str, client: Client) -> None:
|
||||
np.testing.assert_allclose(base_score, dbase_score)
|
||||
|
||||
|
||||
def check_init_estimation_reg(tree_method: str, client: Client) -> None:
|
||||
def check_init_estimation_reg(
|
||||
tree_method: str, device: Literal["cpu", "cuda"], client: Client
|
||||
) -> None:
|
||||
"""Test init estimation for regressor."""
|
||||
from sklearn.datasets import make_regression
|
||||
|
||||
# pylint: disable=unbalanced-tuple-unpacking
|
||||
X, y = make_regression(n_samples=4096 * 2, n_features=32, random_state=1994)
|
||||
reg = xgb.XGBRegressor(n_estimators=1, max_depth=1, tree_method=tree_method)
|
||||
reg = xgb.XGBRegressor(
|
||||
n_estimators=1, max_depth=1, tree_method=tree_method, device=device
|
||||
)
|
||||
reg.fit(X, y)
|
||||
base_score = get_basescore(reg)
|
||||
|
||||
dx = da.from_array(X).rechunk(chunks=(32, None))
|
||||
dy = da.from_array(y).rechunk(chunks=(32,))
|
||||
dreg = xgb.dask.DaskXGBRegressor(
|
||||
n_estimators=1, max_depth=1, tree_method=tree_method
|
||||
n_estimators=1, max_depth=1, tree_method=tree_method, device=device
|
||||
)
|
||||
dreg.client = client
|
||||
dreg.fit(dx, dy)
|
||||
@@ -51,22 +64,26 @@ def check_init_estimation_reg(tree_method: str, client: Client) -> None:
|
||||
np.testing.assert_allclose(base_score, dbase_score)
|
||||
|
||||
|
||||
def check_init_estimation(tree_method: str, client: Client) -> None:
|
||||
def check_init_estimation(
|
||||
tree_method: str, device: Literal["cpu", "cuda"], client: Client
|
||||
) -> None:
|
||||
"""Test init estimation."""
|
||||
check_init_estimation_reg(tree_method, client)
|
||||
check_init_estimation_clf(tree_method, client)
|
||||
check_init_estimation_reg(tree_method, device, client)
|
||||
check_init_estimation_clf(tree_method, device, client)
|
||||
|
||||
|
||||
def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None:
|
||||
def check_uneven_nan(
|
||||
client: Client, tree_method: str, device: Literal["cpu", "cuda"], n_workers: int
|
||||
) -> None:
|
||||
"""Issue #9271, not every worker has missing value."""
|
||||
assert n_workers >= 2
|
||||
|
||||
with client.as_current():
|
||||
clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method)
|
||||
clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method, device=device)
|
||||
X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
|
||||
y = pd.Series([*[0] * 5000, *[1] * 5000])
|
||||
|
||||
X["a"][:3000:1000] = np.nan
|
||||
X.loc[:3000:1000, "a"] = np.nan
|
||||
|
||||
client.wait_for_workers(n_workers=n_workers)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user