Cleanup Python GPU tests. (#9934)
* Cleanup Python GPU tests. - Remove the use of `gpu_hist` and `gpu_id` in cudf/cupy tests. - Move base margin test into the testing directory.
This commit is contained in:
@@ -3,7 +3,17 @@
|
||||
import os
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from urllib import request
|
||||
|
||||
import numpy as np
|
||||
@@ -603,3 +613,51 @@ def sort_ltr_samples(
|
||||
data = X, clicks, y, qid
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def run_base_margin_info(
|
||||
DType: Callable, DMatrixT: Type[xgboost.DMatrix], device: str
|
||||
) -> None:
|
||||
"""Run tests for base margin."""
|
||||
rng = np.random.default_rng()
|
||||
X = DType(rng.normal(0, 1.0, size=100).astype(np.float32).reshape(50, 2))
|
||||
if hasattr(X, "iloc"):
|
||||
y = X.iloc[:, 0]
|
||||
else:
|
||||
y = X[:, 0]
|
||||
base_margin = X
|
||||
# no error at set
|
||||
Xy = DMatrixT(X, y, base_margin=base_margin)
|
||||
# Error at train, caused by check in predictor.
|
||||
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
||||
xgboost.train({"tree_method": "hist", "device": device}, Xy)
|
||||
|
||||
if not hasattr(X, "iloc"):
|
||||
# column major matrix
|
||||
got = DType(Xy.get_base_margin().reshape(50, 2))
|
||||
assert (got == base_margin).all()
|
||||
|
||||
assert base_margin.T.flags.c_contiguous is False
|
||||
assert base_margin.T.flags.f_contiguous is True
|
||||
Xy.set_info(base_margin=base_margin.T)
|
||||
got = DType(Xy.get_base_margin().reshape(2, 50))
|
||||
assert (got == base_margin.T).all()
|
||||
|
||||
# Row vs col vec.
|
||||
base_margin = y
|
||||
Xy.set_base_margin(base_margin)
|
||||
bm_col = Xy.get_base_margin()
|
||||
Xy.set_base_margin(base_margin.reshape(1, base_margin.size))
|
||||
bm_row = Xy.get_base_margin()
|
||||
assert (bm_row == bm_col).all()
|
||||
|
||||
# type
|
||||
base_margin = base_margin.astype(np.float64)
|
||||
Xy.set_base_margin(base_margin)
|
||||
bm_f64 = Xy.get_base_margin()
|
||||
assert (bm_f64 == bm_col).all()
|
||||
|
||||
# too many dimensions
|
||||
base_margin = X.reshape(2, 5, 2, 5)
|
||||
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
||||
Xy.set_base_margin(base_margin)
|
||||
|
||||
Reference in New Issue
Block a user