Cleanup Python GPU tests. (#9934)

* Cleanup Python GPU tests.

- Remove the use of `gpu_hist` and `gpu_id` in cudf/cupy tests.
- Move base margin test into the testing directory.
This commit is contained in:
Jiaming Yuan
2024-01-04 13:15:18 +08:00
committed by GitHub
parent 3c004a4145
commit 9f73127a23
14 changed files with 282 additions and 240 deletions

View File

@@ -3,7 +3,17 @@
import os
import zipfile
from dataclasses import dataclass
from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union
from typing import (
Any,
Callable,
Generator,
List,
NamedTuple,
Optional,
Tuple,
Type,
Union,
)
from urllib import request
import numpy as np
@@ -603,3 +613,51 @@ def sort_ltr_samples(
data = X, clicks, y, qid
return data
def run_base_margin_info(
DType: Callable, DMatrixT: Type[xgboost.DMatrix], device: str
) -> None:
"""Run tests for base margin."""
rng = np.random.default_rng()
X = DType(rng.normal(0, 1.0, size=100).astype(np.float32).reshape(50, 2))
if hasattr(X, "iloc"):
y = X.iloc[:, 0]
else:
y = X[:, 0]
base_margin = X
# no error at set
Xy = DMatrixT(X, y, base_margin=base_margin)
# Error at train, caused by check in predictor.
with pytest.raises(ValueError, match=r".*base_margin.*"):
xgboost.train({"tree_method": "hist", "device": device}, Xy)
if not hasattr(X, "iloc"):
# column major matrix
got = DType(Xy.get_base_margin().reshape(50, 2))
assert (got == base_margin).all()
assert base_margin.T.flags.c_contiguous is False
assert base_margin.T.flags.f_contiguous is True
Xy.set_info(base_margin=base_margin.T)
got = DType(Xy.get_base_margin().reshape(2, 50))
assert (got == base_margin.T).all()
# Row vs col vec.
base_margin = y
Xy.set_base_margin(base_margin)
bm_col = Xy.get_base_margin()
Xy.set_base_margin(base_margin.reshape(1, base_margin.size))
bm_row = Xy.get_base_margin()
assert (bm_row == bm_col).all()
# type
base_margin = base_margin.astype(np.float64)
Xy.set_base_margin(base_margin)
bm_f64 = Xy.get_base_margin()
assert (bm_f64 == bm_col).all()
# too many dimensions
base_margin = X.reshape(2, 5, 2, 5)
with pytest.raises(ValueError, match=r".*base_margin.*"):
Xy.set_base_margin(base_margin)