Cleanup Python GPU tests. (#9934)
* Cleanup Python GPU tests. - Remove the use of `gpu_hist` and `gpu_id` in cudf/cupy tests. - Move base margin test into the testing directory.
This commit is contained in:
parent
3c004a4145
commit
9f73127a23
@ -58,7 +58,7 @@ def individual_tree() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def model_slices() -> None:
|
def model_slices() -> None:
|
||||||
"""Inference with each individual using model slices."""
|
"""Inference with each individual tree using model slices."""
|
||||||
X_train, y_train = load_svmlight_file(train)
|
X_train, y_train = load_svmlight_file(train)
|
||||||
X_test, y_test = load_svmlight_file(test)
|
X_test, y_test = load_svmlight_file(test)
|
||||||
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
|
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
|
||||||
|
|||||||
@ -3,7 +3,17 @@
|
|||||||
import os
|
import os
|
||||||
import zipfile
|
import zipfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -603,3 +613,51 @@ def sort_ltr_samples(
|
|||||||
data = X, clicks, y, qid
|
data = X, clicks, y, qid
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def run_base_margin_info(
|
||||||
|
DType: Callable, DMatrixT: Type[xgboost.DMatrix], device: str
|
||||||
|
) -> None:
|
||||||
|
"""Run tests for base margin."""
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
X = DType(rng.normal(0, 1.0, size=100).astype(np.float32).reshape(50, 2))
|
||||||
|
if hasattr(X, "iloc"):
|
||||||
|
y = X.iloc[:, 0]
|
||||||
|
else:
|
||||||
|
y = X[:, 0]
|
||||||
|
base_margin = X
|
||||||
|
# no error at set
|
||||||
|
Xy = DMatrixT(X, y, base_margin=base_margin)
|
||||||
|
# Error at train, caused by check in predictor.
|
||||||
|
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
||||||
|
xgboost.train({"tree_method": "hist", "device": device}, Xy)
|
||||||
|
|
||||||
|
if not hasattr(X, "iloc"):
|
||||||
|
# column major matrix
|
||||||
|
got = DType(Xy.get_base_margin().reshape(50, 2))
|
||||||
|
assert (got == base_margin).all()
|
||||||
|
|
||||||
|
assert base_margin.T.flags.c_contiguous is False
|
||||||
|
assert base_margin.T.flags.f_contiguous is True
|
||||||
|
Xy.set_info(base_margin=base_margin.T)
|
||||||
|
got = DType(Xy.get_base_margin().reshape(2, 50))
|
||||||
|
assert (got == base_margin.T).all()
|
||||||
|
|
||||||
|
# Row vs col vec.
|
||||||
|
base_margin = y
|
||||||
|
Xy.set_base_margin(base_margin)
|
||||||
|
bm_col = Xy.get_base_margin()
|
||||||
|
Xy.set_base_margin(base_margin.reshape(1, base_margin.size))
|
||||||
|
bm_row = Xy.get_base_margin()
|
||||||
|
assert (bm_row == bm_col).all()
|
||||||
|
|
||||||
|
# type
|
||||||
|
base_margin = base_margin.astype(np.float64)
|
||||||
|
Xy.set_base_margin(base_margin)
|
||||||
|
bm_f64 = Xy.get_base_margin()
|
||||||
|
assert (bm_f64 == bm_col).all()
|
||||||
|
|
||||||
|
# too many dimensions
|
||||||
|
base_margin = X.reshape(2, 5, 2, 5)
|
||||||
|
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
||||||
|
Xy.set_base_margin(base_margin)
|
||||||
|
|||||||
@ -27,13 +27,8 @@ class LintersPaths:
|
|||||||
"tests/python/test_tree_regularization.py",
|
"tests/python/test_tree_regularization.py",
|
||||||
"tests/python/test_shap.py",
|
"tests/python/test_shap.py",
|
||||||
"tests/python/test_with_pandas.py",
|
"tests/python/test_with_pandas.py",
|
||||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
"tests/python-gpu/",
|
||||||
"tests/python-gpu/test_gpu_prediction.py",
|
"tests/python-sycl/",
|
||||||
"tests/python-gpu/load_pickle.py",
|
|
||||||
"tests/python-gpu/test_gpu_pickling.py",
|
|
||||||
"tests/python-gpu/test_gpu_eval_metrics.py",
|
|
||||||
"tests/python-gpu/test_gpu_with_sklearn.py",
|
|
||||||
"tests/python-sycl/test_sycl_prediction.py",
|
|
||||||
"tests/test_distributed/test_with_spark/",
|
"tests/test_distributed/test_with_spark/",
|
||||||
"tests/test_distributed/test_gpu_with_spark/",
|
"tests/test_distributed/test_gpu_with_spark/",
|
||||||
# demo
|
# demo
|
||||||
|
|||||||
@ -203,9 +203,7 @@ class TestQuantileDMatrix:
|
|||||||
np.testing.assert_equal(h_ret.indptr, d_ret.indptr)
|
np.testing.assert_equal(h_ret.indptr, d_ret.indptr)
|
||||||
np.testing.assert_equal(h_ret.indices, d_ret.indices)
|
np.testing.assert_equal(h_ret.indices, d_ret.indices)
|
||||||
|
|
||||||
booster = xgb.train(
|
booster = xgb.train({"tree_method": "hist", "device": "cuda:0"}, dtrain=d_m)
|
||||||
{"tree_method": "hist", "device": "cuda:0"}, dtrain=d_m
|
|
||||||
)
|
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
booster.predict(d_m),
|
booster.predict(d_m),
|
||||||
@ -215,6 +213,7 @@ class TestQuantileDMatrix:
|
|||||||
|
|
||||||
def test_ltr(self) -> None:
|
def test_ltr(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
X, y, qid, w = tm.make_ltr(100, 3, 3, 5)
|
X, y, qid, w = tm.make_ltr(100, 3, 3, 5)
|
||||||
# make sure GPU is used to run sketching.
|
# make sure GPU is used to run sketching.
|
||||||
cpX = cp.array(X)
|
cpX = cp.array(X)
|
||||||
|
|||||||
@ -1,19 +1,17 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.data import run_base_margin_info
|
||||||
|
|
||||||
sys.path.append("tests/python")
|
cudf = pytest.importorskip("cudf")
|
||||||
from test_dmatrix import set_base_margin_info
|
|
||||||
|
|
||||||
|
|
||||||
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
||||||
'''Test constructing DMatrix from cudf'''
|
"""Test constructing DMatrix from cudf"""
|
||||||
import cudf
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
kRows = 80
|
kRows = 80
|
||||||
@ -25,9 +23,7 @@ def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
|||||||
na[5, 0] = missing
|
na[5, 0] = missing
|
||||||
na[3, 1] = missing
|
na[3, 1] = missing
|
||||||
|
|
||||||
pa = pd.DataFrame({'0': na[:, 0],
|
pa = pd.DataFrame({"0": na[:, 0], "1": na[:, 1], "2": na[:, 2].astype(np.int32)})
|
||||||
'1': na[:, 1],
|
|
||||||
'2': na[:, 2].astype(np.int32)})
|
|
||||||
|
|
||||||
np_label = np.random.randn(kRows).astype(input_type)
|
np_label = np.random.randn(kRows).astype(input_type)
|
||||||
pa_label = pd.DataFrame(np_label)
|
pa_label = pd.DataFrame(np_label)
|
||||||
@ -41,8 +37,7 @@ def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
|||||||
|
|
||||||
|
|
||||||
def _test_from_cudf(DMatrixT):
|
def _test_from_cudf(DMatrixT):
|
||||||
'''Test constructing DMatrix from cudf'''
|
"""Test constructing DMatrix from cudf"""
|
||||||
import cudf
|
|
||||||
dmatrix_from_cudf(np.float32, DMatrixT, np.NAN)
|
dmatrix_from_cudf(np.float32, DMatrixT, np.NAN)
|
||||||
dmatrix_from_cudf(np.float64, DMatrixT, np.NAN)
|
dmatrix_from_cudf(np.float64, DMatrixT, np.NAN)
|
||||||
|
|
||||||
@ -50,37 +45,38 @@ def _test_from_cudf(DMatrixT):
|
|||||||
dmatrix_from_cudf(np.int32, DMatrixT, -2)
|
dmatrix_from_cudf(np.int32, DMatrixT, -2)
|
||||||
dmatrix_from_cudf(np.int64, DMatrixT, -3)
|
dmatrix_from_cudf(np.int64, DMatrixT, -3)
|
||||||
|
|
||||||
cd = cudf.DataFrame({'x': [1, 2, 3], 'y': [0.1, 0.2, 0.3]})
|
cd = cudf.DataFrame({"x": [1, 2, 3], "y": [0.1, 0.2, 0.3]})
|
||||||
dtrain = DMatrixT(cd)
|
dtrain = DMatrixT(cd)
|
||||||
|
|
||||||
assert dtrain.feature_names == ['x', 'y']
|
assert dtrain.feature_names == ["x", "y"]
|
||||||
assert dtrain.feature_types == ['int', 'float']
|
assert dtrain.feature_types == ["int", "float"]
|
||||||
|
|
||||||
series = cudf.DataFrame({'x': [1, 2, 3]}).iloc[:, 0]
|
series = cudf.DataFrame({"x": [1, 2, 3]}).iloc[:, 0]
|
||||||
assert isinstance(series, cudf.Series)
|
assert isinstance(series, cudf.Series)
|
||||||
dtrain = DMatrixT(series)
|
dtrain = DMatrixT(series)
|
||||||
|
|
||||||
assert dtrain.feature_names == ['x']
|
assert dtrain.feature_names == ["x"]
|
||||||
assert dtrain.feature_types == ['int']
|
assert dtrain.feature_types == ["int"]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*multi.*"):
|
with pytest.raises(ValueError, match=r".*multi.*"):
|
||||||
dtrain = DMatrixT(cd, label=cd)
|
dtrain = DMatrixT(cd, label=cd)
|
||||||
xgb.train({"tree_method": "gpu_hist", "objective": "multi:softprob"}, dtrain)
|
xgb.train(
|
||||||
|
{"tree_method": "hist", "device": "cuda", "objective": "multi:softprob"},
|
||||||
|
dtrain,
|
||||||
|
)
|
||||||
|
|
||||||
# Test when number of elements is less than 8
|
# Test when number of elements is less than 8
|
||||||
X = cudf.DataFrame({'x': cudf.Series([0, 1, 2, np.NAN, 4],
|
X = cudf.DataFrame({"x": cudf.Series([0, 1, 2, np.NAN, 4], dtype=np.int32)})
|
||||||
dtype=np.int32)})
|
|
||||||
dtrain = DMatrixT(X)
|
dtrain = DMatrixT(X)
|
||||||
assert dtrain.num_col() == 1
|
assert dtrain.num_col() == 1
|
||||||
assert dtrain.num_row() == 5
|
assert dtrain.num_row() == 5
|
||||||
|
|
||||||
# Boolean is not supported.
|
# Boolean is not supported.
|
||||||
X_boolean = cudf.DataFrame({'x': cudf.Series([True, False])})
|
X_boolean = cudf.DataFrame({"x": cudf.Series([True, False])})
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
dtrain = DMatrixT(X_boolean)
|
dtrain = DMatrixT(X_boolean)
|
||||||
|
|
||||||
y_boolean = cudf.DataFrame({
|
y_boolean = cudf.DataFrame({"x": cudf.Series([True, False, True, True, True])})
|
||||||
'x': cudf.Series([True, False, True, True, True])})
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
dtrain = DMatrixT(X_boolean, label=y_boolean)
|
dtrain = DMatrixT(X_boolean, label=y_boolean)
|
||||||
|
|
||||||
@ -88,6 +84,7 @@ def _test_from_cudf(DMatrixT):
|
|||||||
def _test_cudf_training(DMatrixT):
|
def _test_cudf_training(DMatrixT):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from cudf import DataFrame as df
|
from cudf import DataFrame as df
|
||||||
|
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
X = pd.DataFrame(np.random.randn(50, 10))
|
X = pd.DataFrame(np.random.randn(50, 10))
|
||||||
y = pd.DataFrame(np.random.randn(50))
|
y = pd.DataFrame(np.random.randn(50))
|
||||||
@ -97,21 +94,33 @@ def _test_cudf_training(DMatrixT):
|
|||||||
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))
|
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))
|
||||||
|
|
||||||
evals_result_cudf = {}
|
evals_result_cudf = {}
|
||||||
dtrain_cudf = DMatrixT(df.from_pandas(X), df.from_pandas(y), weight=cudf_weights,
|
dtrain_cudf = DMatrixT(
|
||||||
base_margin=cudf_base_margin)
|
df.from_pandas(X),
|
||||||
params = {'gpu_id': 0, 'tree_method': 'gpu_hist'}
|
df.from_pandas(y),
|
||||||
xgb.train(params, dtrain_cudf, evals=[(dtrain_cudf, "train")],
|
weight=cudf_weights,
|
||||||
evals_result=evals_result_cudf)
|
base_margin=cudf_base_margin,
|
||||||
|
)
|
||||||
|
params = {"device": "cuda", "tree_method": "hist"}
|
||||||
|
xgb.train(
|
||||||
|
params,
|
||||||
|
dtrain_cudf,
|
||||||
|
evals=[(dtrain_cudf, "train")],
|
||||||
|
evals_result=evals_result_cudf,
|
||||||
|
)
|
||||||
evals_result_np = {}
|
evals_result_np = {}
|
||||||
dtrain_np = xgb.DMatrix(X, y, weight=weights, base_margin=base_margin)
|
dtrain_np = xgb.DMatrix(X, y, weight=weights, base_margin=base_margin)
|
||||||
xgb.train(params, dtrain_np, evals=[(dtrain_np, "train")],
|
xgb.train(
|
||||||
evals_result=evals_result_np)
|
params, dtrain_np, evals=[(dtrain_np, "train")], evals_result=evals_result_np
|
||||||
assert np.array_equal(evals_result_cudf["train"]["rmse"], evals_result_np["train"]["rmse"])
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
evals_result_cudf["train"]["rmse"], evals_result_np["train"]["rmse"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_cudf_metainfo(DMatrixT):
|
def _test_cudf_metainfo(DMatrixT):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from cudf import DataFrame as df
|
from cudf import DataFrame as df
|
||||||
|
|
||||||
n = 100
|
n = 100
|
||||||
X = np.random.random((n, 2))
|
X = np.random.random((n, 2))
|
||||||
dmat_cudf = DMatrixT(df.from_pandas(pd.DataFrame(X)))
|
dmat_cudf = DMatrixT(df.from_pandas(pd.DataFrame(X)))
|
||||||
@ -120,39 +129,53 @@ def _test_cudf_metainfo(DMatrixT):
|
|||||||
uints = np.array([4, 2, 8]).astype("uint32")
|
uints = np.array([4, 2, 8]).astype("uint32")
|
||||||
cudf_floats = df.from_pandas(pd.DataFrame(floats))
|
cudf_floats = df.from_pandas(pd.DataFrame(floats))
|
||||||
cudf_uints = df.from_pandas(pd.DataFrame(uints))
|
cudf_uints = df.from_pandas(pd.DataFrame(uints))
|
||||||
dmat.set_float_info('weight', floats)
|
dmat.set_float_info("weight", floats)
|
||||||
dmat.set_float_info('label', floats)
|
dmat.set_float_info("label", floats)
|
||||||
dmat.set_float_info('base_margin', floats)
|
dmat.set_float_info("base_margin", floats)
|
||||||
dmat.set_uint_info('group', uints)
|
dmat.set_uint_info("group", uints)
|
||||||
dmat_cudf.set_info(weight=cudf_floats)
|
dmat_cudf.set_info(weight=cudf_floats)
|
||||||
dmat_cudf.set_info(label=cudf_floats)
|
dmat_cudf.set_info(label=cudf_floats)
|
||||||
dmat_cudf.set_info(base_margin=cudf_floats)
|
dmat_cudf.set_info(base_margin=cudf_floats)
|
||||||
dmat_cudf.set_info(group=cudf_uints)
|
dmat_cudf.set_info(group=cudf_uints)
|
||||||
|
|
||||||
# Test setting info with cudf DataFrame
|
# Test setting info with cudf DataFrame
|
||||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
assert np.array_equal(
|
||||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
dmat.get_float_info("weight"), dmat_cudf.get_float_info("weight")
|
||||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
)
|
||||||
dmat_cudf.get_float_info('base_margin'))
|
assert np.array_equal(
|
||||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
dmat.get_float_info("label"), dmat_cudf.get_float_info("label")
|
||||||
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
dmat.get_float_info("base_margin"), dmat_cudf.get_float_info("base_margin")
|
||||||
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
dmat.get_uint_info("group_ptr"), dmat_cudf.get_uint_info("group_ptr")
|
||||||
|
)
|
||||||
|
|
||||||
# Test setting info with cudf Series
|
# Test setting info with cudf Series
|
||||||
dmat_cudf.set_info(weight=cudf_floats[cudf_floats.columns[0]])
|
dmat_cudf.set_info(weight=cudf_floats[cudf_floats.columns[0]])
|
||||||
dmat_cudf.set_info(label=cudf_floats[cudf_floats.columns[0]])
|
dmat_cudf.set_info(label=cudf_floats[cudf_floats.columns[0]])
|
||||||
dmat_cudf.set_info(base_margin=cudf_floats[cudf_floats.columns[0]])
|
dmat_cudf.set_info(base_margin=cudf_floats[cudf_floats.columns[0]])
|
||||||
dmat_cudf.set_info(group=cudf_uints[cudf_uints.columns[0]])
|
dmat_cudf.set_info(group=cudf_uints[cudf_uints.columns[0]])
|
||||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
assert np.array_equal(
|
||||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
dmat.get_float_info("weight"), dmat_cudf.get_float_info("weight")
|
||||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
)
|
||||||
dmat_cudf.get_float_info('base_margin'))
|
assert np.array_equal(
|
||||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
dmat.get_float_info("label"), dmat_cudf.get_float_info("label")
|
||||||
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
dmat.get_float_info("base_margin"), dmat_cudf.get_float_info("base_margin")
|
||||||
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
dmat.get_uint_info("group_ptr"), dmat_cudf.get_uint_info("group_ptr")
|
||||||
|
)
|
||||||
|
|
||||||
set_base_margin_info(df, DMatrixT, "gpu_hist")
|
run_base_margin_info(df, DMatrixT, "cuda")
|
||||||
|
|
||||||
|
|
||||||
class TestFromColumnar:
|
class TestFromColumnar:
|
||||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
"""Tests for constructing DMatrix from data structure conforming Apache
|
||||||
Arrow specification.'''
|
Arrow specification."""
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_simple_dmatrix_from_cudf(self):
|
def test_simple_dmatrix_from_cudf(self):
|
||||||
@ -180,7 +203,6 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_categorical(self) -> None:
|
def test_cudf_categorical(self) -> None:
|
||||||
import cudf
|
|
||||||
n_features = 30
|
n_features = 30
|
||||||
_X, _y = tm.make_categorical(100, n_features, 17, False)
|
_X, _y = tm.make_categorical(100, n_features, 17, False)
|
||||||
X = cudf.from_pandas(_X)
|
X = cudf.from_pandas(_X)
|
||||||
@ -251,6 +273,7 @@ def test_cudf_training_with_sklearn():
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from cudf import DataFrame as df
|
from cudf import DataFrame as df
|
||||||
from cudf import Series as ss
|
from cudf import Series as ss
|
||||||
|
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
X = pd.DataFrame(np.random.randn(50, 10))
|
X = pd.DataFrame(np.random.randn(50, 10))
|
||||||
y = pd.DataFrame((np.random.randn(50) > 0).astype(np.int8))
|
y = pd.DataFrame((np.random.randn(50) > 0).astype(np.int8))
|
||||||
@ -264,29 +287,34 @@ def test_cudf_training_with_sklearn():
|
|||||||
y_cudf_series = ss(data=y.iloc[:, 0])
|
y_cudf_series = ss(data=y.iloc[:, 0])
|
||||||
|
|
||||||
for y_obj in [y_cudf, y_cudf_series]:
|
for y_obj in [y_cudf, y_cudf_series]:
|
||||||
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist')
|
clf = xgb.XGBClassifier(tree_method="hist", device="cuda:0")
|
||||||
clf.fit(X_cudf, y_obj, sample_weight=cudf_weights, base_margin=cudf_base_margin,
|
clf.fit(
|
||||||
eval_set=[(X_cudf, y_obj)])
|
X_cudf,
|
||||||
|
y_obj,
|
||||||
|
sample_weight=cudf_weights,
|
||||||
|
base_margin=cudf_base_margin,
|
||||||
|
eval_set=[(X_cudf, y_obj)],
|
||||||
|
)
|
||||||
pred = clf.predict(X_cudf)
|
pred = clf.predict(X_cudf)
|
||||||
assert np.array_equal(np.unique(pred), np.array([0, 1]))
|
assert np.array_equal(np.unique(pred), np.array([0, 1]))
|
||||||
|
|
||||||
|
|
||||||
class IterForDMatrixTest(xgb.core.DataIter):
|
class IterForDMatrixTest(xgb.core.DataIter):
|
||||||
'''A data iterator for XGBoost DMatrix.
|
"""A data iterator for XGBoost DMatrix.
|
||||||
|
|
||||||
`reset` and `next` are required for any data iterator, other functions here
|
`reset` and `next` are required for any data iterator, other functions here
|
||||||
are utilites for demonstration's purpose.
|
are utilites for demonstration's purpose.
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
ROWS_PER_BATCH = 100 # data is splited by rows
|
ROWS_PER_BATCH = 100 # data is splited by rows
|
||||||
BATCHES = 16
|
BATCHES = 16
|
||||||
|
|
||||||
def __init__(self, categorical):
|
def __init__(self, categorical):
|
||||||
'''Generate some random data for demostration.
|
"""Generate some random data for demostration.
|
||||||
|
|
||||||
Actual data can be anything that is currently supported by XGBoost.
|
Actual data can be anything that is currently supported by XGBoost.
|
||||||
'''
|
"""
|
||||||
import cudf
|
|
||||||
self.rows = self.ROWS_PER_BATCH
|
self.rows = self.ROWS_PER_BATCH
|
||||||
|
|
||||||
if categorical:
|
if categorical:
|
||||||
@ -300,34 +328,37 @@ class IterForDMatrixTest(xgb.core.DataIter):
|
|||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
self._data = [
|
self._data = [
|
||||||
cudf.DataFrame(
|
cudf.DataFrame(
|
||||||
{'a': rng.randn(self.ROWS_PER_BATCH),
|
{
|
||||||
'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
|
"a": rng.randn(self.ROWS_PER_BATCH),
|
||||||
|
"b": rng.randn(self.ROWS_PER_BATCH),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
] * self.BATCHES
|
||||||
self._labels = [rng.randn(self.rows)] * self.BATCHES
|
self._labels = [rng.randn(self.rows)] * self.BATCHES
|
||||||
|
|
||||||
self.it = 0 # set iterator to 0
|
self.it = 0 # set iterator to 0
|
||||||
super().__init__(cache_prefix=None)
|
super().__init__(cache_prefix=None)
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
import cudf
|
|
||||||
return cudf.concat(self._data)
|
return cudf.concat(self._data)
|
||||||
|
|
||||||
def as_array_labels(self):
|
def as_array_labels(self):
|
||||||
return np.concatenate(self._labels)
|
return np.concatenate(self._labels)
|
||||||
|
|
||||||
def data(self):
|
def data(self):
|
||||||
'''Utility function for obtaining current batch of data.'''
|
"""Utility function for obtaining current batch of data."""
|
||||||
return self._data[self.it]
|
return self._data[self.it]
|
||||||
|
|
||||||
def labels(self):
|
def labels(self):
|
||||||
'''Utility function for obtaining current batch of label.'''
|
"""Utility function for obtaining current batch of label."""
|
||||||
return self._labels[self.it]
|
return self._labels[self.it]
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
'''Reset the iterator'''
|
"""Reset the iterator"""
|
||||||
self.it = 0
|
self.it = 0
|
||||||
|
|
||||||
def next(self, input_data):
|
def next(self, input_data):
|
||||||
'''Yield next batch of data'''
|
"""Yield next batch of data"""
|
||||||
if self.it == len(self._data):
|
if self.it == len(self._data):
|
||||||
# Return 0 when there's no more batch.
|
# Return 0 when there's no more batch.
|
||||||
return 0
|
return 0
|
||||||
@ -341,7 +372,7 @@ class IterForDMatrixTest(xgb.core.DataIter):
|
|||||||
def test_from_cudf_iter(enable_categorical):
|
def test_from_cudf_iter(enable_categorical):
|
||||||
rounds = 100
|
rounds = 100
|
||||||
it = IterForDMatrixTest(enable_categorical)
|
it = IterForDMatrixTest(enable_categorical)
|
||||||
params = {"tree_method": "gpu_hist"}
|
params = {"tree_method": "hist", "device": "cuda"}
|
||||||
|
|
||||||
# Use iterator
|
# Use iterator
|
||||||
m_it = xgb.QuantileDMatrix(it, enable_categorical=enable_categorical)
|
m_it = xgb.QuantileDMatrix(it, enable_categorical=enable_categorical)
|
||||||
|
|||||||
@ -1,31 +1,25 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
|
||||||
sys.path.append("tests/python")
|
|
||||||
from test_dmatrix import set_base_margin_info
|
|
||||||
|
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.data import run_base_margin_info
|
||||||
|
|
||||||
cupy = pytest.importorskip("cupy")
|
cp = pytest.importorskip("cupy")
|
||||||
|
|
||||||
|
|
||||||
def test_array_interface() -> None:
|
def test_array_interface() -> None:
|
||||||
arr = cupy.array([[1, 2, 3, 4], [1, 2, 3, 4]])
|
arr = cp.array([[1, 2, 3, 4], [1, 2, 3, 4]])
|
||||||
i_arr = arr.__cuda_array_interface__
|
i_arr = arr.__cuda_array_interface__
|
||||||
i_arr = json.loads(json.dumps(i_arr))
|
i_arr = json.loads(json.dumps(i_arr))
|
||||||
ret = xgb.core.from_array_interface(i_arr)
|
ret = xgb.core.from_array_interface(i_arr)
|
||||||
np.testing.assert_equal(cupy.asnumpy(arr), cupy.asnumpy(ret))
|
np.testing.assert_equal(cp.asnumpy(arr), cp.asnumpy(ret))
|
||||||
|
|
||||||
|
|
||||||
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||||
'''Test constructing DMatrix from cupy'''
|
"""Test constructing DMatrix from cupy"""
|
||||||
import cupy as cp
|
|
||||||
|
|
||||||
kRows = 80
|
kRows = 80
|
||||||
kCols = 3
|
kCols = 3
|
||||||
|
|
||||||
@ -51,9 +45,7 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
|||||||
|
|
||||||
|
|
||||||
def _test_from_cupy(DMatrixT):
|
def _test_from_cupy(DMatrixT):
|
||||||
'''Test constructing DMatrix from cupy'''
|
"""Test constructing DMatrix from cupy"""
|
||||||
import cupy as cp
|
|
||||||
|
|
||||||
dmatrix_from_cupy(np.float16, DMatrixT, np.NAN)
|
dmatrix_from_cupy(np.float16, DMatrixT, np.NAN)
|
||||||
dmatrix_from_cupy(np.float32, DMatrixT, np.NAN)
|
dmatrix_from_cupy(np.float32, DMatrixT, np.NAN)
|
||||||
dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)
|
dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)
|
||||||
@ -73,7 +65,6 @@ def _test_from_cupy(DMatrixT):
|
|||||||
|
|
||||||
|
|
||||||
def _test_cupy_training(DMatrixT):
|
def _test_cupy_training(DMatrixT):
|
||||||
import cupy as cp
|
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
cp.random.seed(1)
|
cp.random.seed(1)
|
||||||
X = cp.random.randn(50, 10, dtype="float32")
|
X = cp.random.randn(50, 10, dtype="float32")
|
||||||
@ -85,19 +76,23 @@ def _test_cupy_training(DMatrixT):
|
|||||||
|
|
||||||
evals_result_cupy = {}
|
evals_result_cupy = {}
|
||||||
dtrain_cp = DMatrixT(X, y, weight=cupy_weights, base_margin=cupy_base_margin)
|
dtrain_cp = DMatrixT(X, y, weight=cupy_weights, base_margin=cupy_base_margin)
|
||||||
params = {'gpu_id': 0, 'nthread': 1, 'tree_method': 'gpu_hist'}
|
params = {"tree_method": "hist", "device": "cuda:0"}
|
||||||
xgb.train(params, dtrain_cp, evals=[(dtrain_cp, "train")],
|
xgb.train(
|
||||||
evals_result=evals_result_cupy)
|
params, dtrain_cp, evals=[(dtrain_cp, "train")], evals_result=evals_result_cupy
|
||||||
|
)
|
||||||
evals_result_np = {}
|
evals_result_np = {}
|
||||||
dtrain_np = xgb.DMatrix(cp.asnumpy(X), cp.asnumpy(y), weight=weights,
|
dtrain_np = xgb.DMatrix(
|
||||||
base_margin=base_margin)
|
cp.asnumpy(X), cp.asnumpy(y), weight=weights, base_margin=base_margin
|
||||||
xgb.train(params, dtrain_np, evals=[(dtrain_np, "train")],
|
)
|
||||||
evals_result=evals_result_np)
|
xgb.train(
|
||||||
assert np.array_equal(evals_result_cupy["train"]["rmse"], evals_result_np["train"]["rmse"])
|
params, dtrain_np, evals=[(dtrain_np, "train")], evals_result=evals_result_np
|
||||||
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
evals_result_cupy["train"]["rmse"], evals_result_np["train"]["rmse"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_cupy_metainfo(DMatrixT):
|
def _test_cupy_metainfo(DMatrixT):
|
||||||
import cupy as cp
|
|
||||||
n = 100
|
n = 100
|
||||||
X = np.random.random((n, 2))
|
X = np.random.random((n, 2))
|
||||||
dmat_cupy = DMatrixT(cp.array(X))
|
dmat_cupy = DMatrixT(cp.array(X))
|
||||||
@ -106,33 +101,35 @@ def _test_cupy_metainfo(DMatrixT):
|
|||||||
uints = np.array([4, 2, 8]).astype("uint32")
|
uints = np.array([4, 2, 8]).astype("uint32")
|
||||||
cupy_floats = cp.array(floats)
|
cupy_floats = cp.array(floats)
|
||||||
cupy_uints = cp.array(uints)
|
cupy_uints = cp.array(uints)
|
||||||
dmat.set_float_info('weight', floats)
|
dmat.set_float_info("weight", floats)
|
||||||
dmat.set_float_info('label', floats)
|
dmat.set_float_info("label", floats)
|
||||||
dmat.set_float_info('base_margin', floats)
|
dmat.set_float_info("base_margin", floats)
|
||||||
dmat.set_uint_info('group', uints)
|
dmat.set_uint_info("group", uints)
|
||||||
dmat_cupy.set_info(weight=cupy_floats)
|
dmat_cupy.set_info(weight=cupy_floats)
|
||||||
dmat_cupy.set_info(label=cupy_floats)
|
dmat_cupy.set_info(label=cupy_floats)
|
||||||
dmat_cupy.set_info(base_margin=cupy_floats)
|
dmat_cupy.set_info(base_margin=cupy_floats)
|
||||||
dmat_cupy.set_info(group=cupy_uints)
|
dmat_cupy.set_info(group=cupy_uints)
|
||||||
|
|
||||||
# Test setting info with cupy
|
# Test setting info with cupy
|
||||||
assert np.array_equal(dmat.get_float_info('weight'),
|
assert np.array_equal(
|
||||||
dmat_cupy.get_float_info('weight'))
|
dmat.get_float_info("weight"), dmat_cupy.get_float_info("weight")
|
||||||
assert np.array_equal(dmat.get_float_info('label'),
|
)
|
||||||
dmat_cupy.get_float_info('label'))
|
assert np.array_equal(
|
||||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
dmat.get_float_info("label"), dmat_cupy.get_float_info("label")
|
||||||
dmat_cupy.get_float_info('base_margin'))
|
)
|
||||||
assert np.array_equal(dmat.get_uint_info('group_ptr'),
|
assert np.array_equal(
|
||||||
dmat_cupy.get_uint_info('group_ptr'))
|
dmat.get_float_info("base_margin"), dmat_cupy.get_float_info("base_margin")
|
||||||
|
)
|
||||||
|
assert np.array_equal(
|
||||||
|
dmat.get_uint_info("group_ptr"), dmat_cupy.get_uint_info("group_ptr")
|
||||||
|
)
|
||||||
|
|
||||||
set_base_margin_info(cp.asarray, DMatrixT, "gpu_hist")
|
run_base_margin_info(cp.asarray, DMatrixT, "cuda")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_cupy_training_with_sklearn():
|
def test_cupy_training_with_sklearn():
|
||||||
import cupy as cp
|
|
||||||
|
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
cp.random.seed(1)
|
cp.random.seed(1)
|
||||||
X = cp.random.randn(50, 10, dtype="float32")
|
X = cp.random.randn(50, 10, dtype="float32")
|
||||||
@ -142,7 +139,7 @@ def test_cupy_training_with_sklearn():
|
|||||||
base_margin = np.random.random(50)
|
base_margin = np.random.random(50)
|
||||||
cupy_base_margin = cp.array(base_margin)
|
cupy_base_margin = cp.array(base_margin)
|
||||||
|
|
||||||
clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist")
|
clf = xgb.XGBClassifier(tree_method="hist", device="cuda:0")
|
||||||
clf.fit(
|
clf.fit(
|
||||||
X,
|
X,
|
||||||
y,
|
y,
|
||||||
@ -155,8 +152,8 @@ def test_cupy_training_with_sklearn():
|
|||||||
|
|
||||||
|
|
||||||
class TestFromCupy:
|
class TestFromCupy:
|
||||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
"""Tests for constructing DMatrix from data structure conforming Apache
|
||||||
Arrow specification.'''
|
Arrow specification."""
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_simple_dmat_from_cupy(self):
|
def test_simple_dmat_from_cupy(self):
|
||||||
@ -184,19 +181,17 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dlpack_simple_dmat(self):
|
def test_dlpack_simple_dmat(self):
|
||||||
import cupy as cp
|
|
||||||
n = 100
|
n = 100
|
||||||
X = cp.random.random((n, 2))
|
X = cp.random.random((n, 2))
|
||||||
xgb.DMatrix(X.toDlpack())
|
xgb.DMatrix(X.toDlpack())
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_cupy_categorical(self):
|
def test_cupy_categorical(self):
|
||||||
import cupy as cp
|
|
||||||
n_features = 10
|
n_features = 10
|
||||||
X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False)
|
X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False)
|
||||||
X = cp.asarray(X.values.astype(cp.float32))
|
X = cp.asarray(X.values.astype(cp.float32))
|
||||||
y = cp.array(y)
|
y = cp.array(y)
|
||||||
feature_types = ['c'] * n_features
|
feature_types = ["c"] * n_features
|
||||||
|
|
||||||
assert isinstance(X, cp.ndarray)
|
assert isinstance(X, cp.ndarray)
|
||||||
Xy = xgb.DMatrix(X, y, feature_types=feature_types)
|
Xy = xgb.DMatrix(X, y, feature_types=feature_types)
|
||||||
@ -204,7 +199,6 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dlpack_device_dmat(self):
|
def test_dlpack_device_dmat(self):
|
||||||
import cupy as cp
|
|
||||||
n = 100
|
n = 100
|
||||||
X = cp.random.random((n, 2))
|
X = cp.random.random((n, 2))
|
||||||
m = xgb.QuantileDMatrix(X.toDlpack())
|
m = xgb.QuantileDMatrix(X.toDlpack())
|
||||||
@ -213,7 +207,6 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_qid(self):
|
def test_qid(self):
|
||||||
import cupy as cp
|
|
||||||
rng = cp.random.RandomState(1994)
|
rng = cp.random.RandomState(1994)
|
||||||
rows = 100
|
rows = 100
|
||||||
cols = 10
|
cols = 10
|
||||||
@ -223,19 +216,16 @@ Arrow specification.'''
|
|||||||
|
|
||||||
Xy = xgb.DMatrix(X, y)
|
Xy = xgb.DMatrix(X, y)
|
||||||
Xy.set_info(qid=qid)
|
Xy.set_info(qid=qid)
|
||||||
group_ptr = Xy.get_uint_info('group_ptr')
|
group_ptr = Xy.get_uint_info("group_ptr")
|
||||||
assert group_ptr[0] == 0
|
assert group_ptr[0] == 0
|
||||||
assert group_ptr[-1] == rows
|
assert group_ptr[-1] == rows
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
@pytest.mark.mgpu
|
@pytest.mark.mgpu
|
||||||
def test_specified_device(self):
|
def test_specified_device(self):
|
||||||
import cupy as cp
|
|
||||||
cp.cuda.runtime.setDevice(0)
|
cp.cuda.runtime.setDevice(0)
|
||||||
dtrain = dmatrix_from_cupy(np.float32, xgb.QuantileDMatrix, np.nan)
|
dtrain = dmatrix_from_cupy(np.float32, xgb.QuantileDMatrix, np.nan)
|
||||||
with pytest.raises(
|
with pytest.raises(xgb.core.XGBoostError, match="Invalid device ordinal"):
|
||||||
xgb.core.XGBoostError, match="Invalid device ordinal"
|
|
||||||
):
|
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10
|
{"tree_method": "hist", "device": "cuda:1"}, dtrain, num_boost_round=10
|
||||||
)
|
)
|
||||||
|
|||||||
@ -21,21 +21,21 @@ class TestGPUBasicModels:
|
|||||||
cpu_test_bm = test_bm.TestModels()
|
cpu_test_bm = test_bm.TestModels()
|
||||||
|
|
||||||
def run_cls(self, X, y):
|
def run_cls(self, X, y):
|
||||||
cls = xgb.XGBClassifier(tree_method='gpu_hist')
|
cls = xgb.XGBClassifier(tree_method="hist", device="cuda")
|
||||||
cls.fit(X, y)
|
cls.fit(X, y)
|
||||||
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
|
cls.get_booster().save_model("test_deterministic_gpu_hist-0.json")
|
||||||
|
|
||||||
cls = xgb.XGBClassifier(tree_method='gpu_hist')
|
cls = xgb.XGBClassifier(tree_method="hist", device="cuda")
|
||||||
cls.fit(X, y)
|
cls.fit(X, y)
|
||||||
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
|
cls.get_booster().save_model("test_deterministic_gpu_hist-1.json")
|
||||||
|
|
||||||
with open('test_deterministic_gpu_hist-0.json', 'r') as fd:
|
with open("test_deterministic_gpu_hist-0.json", "r") as fd:
|
||||||
model_0 = fd.read()
|
model_0 = fd.read()
|
||||||
with open('test_deterministic_gpu_hist-1.json', 'r') as fd:
|
with open("test_deterministic_gpu_hist-1.json", "r") as fd:
|
||||||
model_1 = fd.read()
|
model_1 = fd.read()
|
||||||
|
|
||||||
os.remove('test_deterministic_gpu_hist-0.json')
|
os.remove("test_deterministic_gpu_hist-0.json")
|
||||||
os.remove('test_deterministic_gpu_hist-1.json')
|
os.remove("test_deterministic_gpu_hist-1.json")
|
||||||
|
|
||||||
return hash(model_0), hash(model_1)
|
return hash(model_0), hash(model_1)
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ class TestGPUBasicModels:
|
|||||||
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
||||||
|
|
||||||
def test_eta_decay(self):
|
def test_eta_decay(self):
|
||||||
self.cpu_test_cb.run_eta_decay('gpu_hist')
|
self.cpu_test_cb.run_eta_decay("gpu_hist")
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"objective", ["binary:logistic", "reg:absoluteerror", "reg:quantileerror"]
|
"objective", ["binary:logistic", "reg:absoluteerror", "reg:quantileerror"]
|
||||||
|
|||||||
@ -12,18 +12,18 @@ import test_demos as td # noqa
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_data_iterator():
|
def test_data_iterator():
|
||||||
script = os.path.join(td.PYTHON_DEMO_DIR, 'quantile_data_iterator.py')
|
script = os.path.join(td.PYTHON_DEMO_DIR, "quantile_data_iterator.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_update_process_demo():
|
def test_update_process_demo():
|
||||||
script = os.path.join(td.PYTHON_DEMO_DIR, 'update_process.py')
|
script = os.path.join(td.PYTHON_DEMO_DIR, "update_process.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_categorical_demo():
|
def test_categorical_demo():
|
||||||
script = os.path.join(td.PYTHON_DEMO_DIR, 'categorical.py')
|
script = os.path.join(td.PYTHON_DEMO_DIR, "categorical.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|||||||
@ -6,22 +6,29 @@ from xgboost import testing as tm
|
|||||||
|
|
||||||
pytestmark = tm.timeout(10)
|
pytestmark = tm.timeout(10)
|
||||||
|
|
||||||
parameter_strategy = strategies.fixed_dictionaries({
|
parameter_strategy = strategies.fixed_dictionaries(
|
||||||
'booster': strategies.just('gblinear'),
|
{
|
||||||
'eta': strategies.floats(0.01, 0.25),
|
"booster": strategies.just("gblinear"),
|
||||||
'tolerance': strategies.floats(1e-5, 1e-2),
|
"eta": strategies.floats(0.01, 0.25),
|
||||||
'nthread': strategies.integers(1, 4),
|
"tolerance": strategies.floats(1e-5, 1e-2),
|
||||||
'feature_selector': strategies.sampled_from(['cyclic', 'shuffle',
|
"nthread": strategies.integers(1, 4),
|
||||||
'greedy', 'thrifty']),
|
"feature_selector": strategies.sampled_from(
|
||||||
'top_k': strategies.integers(1, 10),
|
["cyclic", "shuffle", "greedy", "thrifty"]
|
||||||
})
|
),
|
||||||
|
"top_k": strategies.integers(1, 10),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_result(param, dmat, num_rounds):
|
def train_result(param, dmat, num_rounds):
|
||||||
result = {}
|
result = {}
|
||||||
booster = xgb.train(
|
booster = xgb.train(
|
||||||
param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
|
param,
|
||||||
evals_result=result
|
dmat,
|
||||||
|
num_rounds,
|
||||||
|
[(dmat, "train")],
|
||||||
|
verbose_eval=False,
|
||||||
|
evals_result=result,
|
||||||
)
|
)
|
||||||
assert booster.num_boosted_rounds() == num_rounds
|
assert booster.num_boosted_rounds() == num_rounds
|
||||||
return result
|
return result
|
||||||
@ -32,9 +39,11 @@ class TestGPULinear:
|
|||||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||||
def test_gpu_coordinate(self, param, num_rounds, dataset):
|
def test_gpu_coordinate(self, param, num_rounds, dataset):
|
||||||
assume(len(dataset.y) > 0)
|
assume(len(dataset.y) > 0)
|
||||||
param['updater'] = 'gpu_coord_descent'
|
param["updater"] = "gpu_coord_descent"
|
||||||
param = dataset.set_params(param)
|
param = dataset.set_params(param)
|
||||||
result = train_result(param, dataset.get_dmat(), num_rounds)['train'][dataset.metric]
|
result = train_result(param, dataset.get_dmat(), num_rounds)["train"][
|
||||||
|
dataset.metric
|
||||||
|
]
|
||||||
note(result)
|
note(result)
|
||||||
assert tm.non_increasing(result)
|
assert tm.non_increasing(result)
|
||||||
|
|
||||||
@ -46,16 +55,18 @@ class TestGPULinear:
|
|||||||
strategies.integers(10, 50),
|
strategies.integers(10, 50),
|
||||||
tm.make_dataset_strategy(),
|
tm.make_dataset_strategy(),
|
||||||
strategies.floats(1e-5, 0.8),
|
strategies.floats(1e-5, 0.8),
|
||||||
strategies.floats(1e-5, 0.8)
|
strategies.floats(1e-5, 0.8),
|
||||||
)
|
)
|
||||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||||
def test_gpu_coordinate_regularised(self, param, num_rounds, dataset, alpha, lambd):
|
def test_gpu_coordinate_regularised(self, param, num_rounds, dataset, alpha, lambd):
|
||||||
assume(len(dataset.y) > 0)
|
assume(len(dataset.y) > 0)
|
||||||
param['updater'] = 'gpu_coord_descent'
|
param["updater"] = "gpu_coord_descent"
|
||||||
param['alpha'] = alpha
|
param["alpha"] = alpha
|
||||||
param['lambda'] = lambd
|
param["lambda"] = lambd
|
||||||
param = dataset.set_params(param)
|
param = dataset.set_params(param)
|
||||||
result = train_result(param, dataset.get_dmat(), num_rounds)['train'][dataset.metric]
|
result = train_result(param, dataset.get_dmat(), num_rounds)["train"][
|
||||||
|
dataset.metric
|
||||||
|
]
|
||||||
note(result)
|
note(result)
|
||||||
assert tm.non_increasing([result[0], result[-1]])
|
assert tm.non_increasing([result[0], result[-1]])
|
||||||
|
|
||||||
@ -64,8 +75,12 @@ class TestGPULinear:
|
|||||||
# Training linear model is quite expensive, so we don't include it in
|
# Training linear model is quite expensive, so we don't include it in
|
||||||
# test_from_cupy.py
|
# test_from_cupy.py
|
||||||
import cupy
|
import cupy
|
||||||
params = {'booster': 'gblinear', 'updater': 'gpu_coord_descent',
|
|
||||||
'n_estimators': 100}
|
params = {
|
||||||
|
"booster": "gblinear",
|
||||||
|
"updater": "gpu_coord_descent",
|
||||||
|
"n_estimators": 100,
|
||||||
|
}
|
||||||
X, y = tm.get_california_housing()
|
X, y = tm.get_california_housing()
|
||||||
cpu_model = xgb.XGBRegressor(**params)
|
cpu_model = xgb.XGBRegressor(**params)
|
||||||
cpu_model.fit(X, y)
|
cpu_model.fit(X, y)
|
||||||
|
|||||||
@ -14,14 +14,18 @@ class TestGPUTrainingContinuation:
|
|||||||
X = np.random.randn(kRows, kCols)
|
X = np.random.randn(kRows, kCols)
|
||||||
y = np.random.randn(kRows)
|
y = np.random.randn(kRows)
|
||||||
dtrain = xgb.DMatrix(X, y)
|
dtrain = xgb.DMatrix(X, y)
|
||||||
params = {'tree_method': 'gpu_hist', 'max_depth': '2',
|
params = {
|
||||||
'gamma': '0.1', 'alpha': '0.01'}
|
"tree_method": "gpu_hist",
|
||||||
|
"max_depth": "2",
|
||||||
|
"gamma": "0.1",
|
||||||
|
"alpha": "0.01",
|
||||||
|
}
|
||||||
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
|
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
|
||||||
dump_0 = bst_0.get_dump(dump_format='json')
|
dump_0 = bst_0.get_dump(dump_format="json")
|
||||||
|
|
||||||
bst_1 = xgb.train(params, dtrain, num_boost_round=32)
|
bst_1 = xgb.train(params, dtrain, num_boost_round=32)
|
||||||
bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
|
bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
|
||||||
dump_1 = bst_1.get_dump(dump_format='json')
|
dump_1 = bst_1.get_dump(dump_format="json")
|
||||||
|
|
||||||
def recursive_compare(obj_0, obj_1):
|
def recursive_compare(obj_0, obj_1):
|
||||||
if isinstance(obj_0, float):
|
if isinstance(obj_0, float):
|
||||||
@ -37,9 +41,8 @@ class TestGPUTrainingContinuation:
|
|||||||
values_1 = list(obj_1.values())
|
values_1 = list(obj_1.values())
|
||||||
for i in range(len(obj_0.items())):
|
for i in range(len(obj_0.items())):
|
||||||
assert keys_0[i] == keys_1[i]
|
assert keys_0[i] == keys_1[i]
|
||||||
if list(obj_0.keys())[i] != 'missing':
|
if list(obj_0.keys())[i] != "missing":
|
||||||
recursive_compare(values_0[i],
|
recursive_compare(values_0[i], values_1[i])
|
||||||
values_1[i])
|
|
||||||
else:
|
else:
|
||||||
for i in range(len(obj_0)):
|
for i in range(len(obj_0)):
|
||||||
recursive_compare(obj_0[i], obj_1[i])
|
recursive_compare(obj_0[i], obj_1[i])
|
||||||
|
|||||||
@ -22,12 +22,13 @@ def non_increasing(L):
|
|||||||
|
|
||||||
def assert_constraint(constraint, tree_method):
|
def assert_constraint(constraint, tree_method):
|
||||||
from sklearn.datasets import make_regression
|
from sklearn.datasets import make_regression
|
||||||
|
|
||||||
n = 1000
|
n = 1000
|
||||||
X, y = make_regression(n, random_state=rng, n_features=1, n_informative=1)
|
X, y = make_regression(n, random_state=rng, n_features=1, n_informative=1)
|
||||||
dtrain = xgb.DMatrix(X, y)
|
dtrain = xgb.DMatrix(X, y)
|
||||||
param = {}
|
param = {}
|
||||||
param['tree_method'] = tree_method
|
param["tree_method"] = tree_method
|
||||||
param['monotone_constraints'] = "(" + str(constraint) + ")"
|
param["monotone_constraints"] = "(" + str(constraint) + ")"
|
||||||
bst = xgb.train(param, dtrain)
|
bst = xgb.train(param, dtrain)
|
||||||
dpredict = xgb.DMatrix(X[X[:, 0].argsort()])
|
dpredict = xgb.DMatrix(X[X[:, 0].argsort()])
|
||||||
pred = bst.predict(dpredict)
|
pred = bst.predict(dpredict)
|
||||||
@ -40,15 +41,15 @@ def assert_constraint(constraint, tree_method):
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_gpu_hist_basic():
|
def test_gpu_hist_basic():
|
||||||
assert_constraint(1, 'gpu_hist')
|
assert_constraint(1, "gpu_hist")
|
||||||
assert_constraint(-1, 'gpu_hist')
|
assert_constraint(-1, "gpu_hist")
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_hist_depthwise():
|
def test_gpu_hist_depthwise():
|
||||||
params = {
|
params = {
|
||||||
'tree_method': 'gpu_hist',
|
"tree_method": "gpu_hist",
|
||||||
'grow_policy': 'depthwise',
|
"grow_policy": "depthwise",
|
||||||
'monotone_constraints': '(1, -1)'
|
"monotone_constraints": "(1, -1)",
|
||||||
}
|
}
|
||||||
model = xgb.train(params, tmc.training_dset)
|
model = xgb.train(params, tmc.training_dset)
|
||||||
tmc.is_correctly_constrained(model)
|
tmc.is_correctly_constrained(model)
|
||||||
@ -56,9 +57,9 @@ def test_gpu_hist_depthwise():
|
|||||||
|
|
||||||
def test_gpu_hist_lossguide():
|
def test_gpu_hist_lossguide():
|
||||||
params = {
|
params = {
|
||||||
'tree_method': 'gpu_hist',
|
"tree_method": "gpu_hist",
|
||||||
'grow_policy': 'lossguide',
|
"grow_policy": "lossguide",
|
||||||
'monotone_constraints': '(1, -1)'
|
"monotone_constraints": "(1, -1)",
|
||||||
}
|
}
|
||||||
model = xgb.train(params, tmc.training_dset)
|
model = xgb.train(params, tmc.training_dset)
|
||||||
tmc.is_correctly_constrained(model)
|
tmc.is_correctly_constrained(model)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -12,59 +11,12 @@ from scipy.sparse import csr_matrix, rand
|
|||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
from xgboost.core import DataSplitMode
|
from xgboost.core import DataSplitMode
|
||||||
from xgboost.testing.data import np_dtypes
|
from xgboost.testing.data import np_dtypes, run_base_margin_info
|
||||||
|
|
||||||
rng = np.random.RandomState(1)
|
|
||||||
|
|
||||||
dpath = "demo/data/"
|
dpath = "demo/data/"
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
|
|
||||||
def set_base_margin_info(DType, DMatrixT, tm: str):
|
|
||||||
rng = np.random.default_rng()
|
|
||||||
X = DType(rng.normal(0, 1.0, size=100).astype(np.float32).reshape(50, 2))
|
|
||||||
if hasattr(X, "iloc"):
|
|
||||||
y = X.iloc[:, 0]
|
|
||||||
else:
|
|
||||||
y = X[:, 0]
|
|
||||||
base_margin = X
|
|
||||||
# no error at set
|
|
||||||
Xy = DMatrixT(X, y, base_margin=base_margin)
|
|
||||||
# Error at train, caused by check in predictor.
|
|
||||||
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
|
||||||
xgb.train({"tree_method": tm}, Xy)
|
|
||||||
|
|
||||||
if not hasattr(X, "iloc"):
|
|
||||||
# column major matrix
|
|
||||||
got = DType(Xy.get_base_margin().reshape(50, 2))
|
|
||||||
assert (got == base_margin).all()
|
|
||||||
|
|
||||||
assert base_margin.T.flags.c_contiguous is False
|
|
||||||
assert base_margin.T.flags.f_contiguous is True
|
|
||||||
Xy.set_info(base_margin=base_margin.T)
|
|
||||||
got = DType(Xy.get_base_margin().reshape(2, 50))
|
|
||||||
assert (got == base_margin.T).all()
|
|
||||||
|
|
||||||
# Row vs col vec.
|
|
||||||
base_margin = y
|
|
||||||
Xy.set_base_margin(base_margin)
|
|
||||||
bm_col = Xy.get_base_margin()
|
|
||||||
Xy.set_base_margin(base_margin.reshape(1, base_margin.size))
|
|
||||||
bm_row = Xy.get_base_margin()
|
|
||||||
assert (bm_row == bm_col).all()
|
|
||||||
|
|
||||||
# type
|
|
||||||
base_margin = base_margin.astype(np.float64)
|
|
||||||
Xy.set_base_margin(base_margin)
|
|
||||||
bm_f64 = Xy.get_base_margin()
|
|
||||||
assert (bm_f64 == bm_col).all()
|
|
||||||
|
|
||||||
# too many dimensions
|
|
||||||
base_margin = X.reshape(2, 5, 2, 5)
|
|
||||||
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
|
||||||
Xy.set_base_margin(base_margin)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDMatrix:
|
class TestDMatrix:
|
||||||
def test_warn_missing(self):
|
def test_warn_missing(self):
|
||||||
from xgboost import data
|
from xgboost import data
|
||||||
@ -417,8 +369,8 @@ class TestDMatrix:
|
|||||||
)
|
)
|
||||||
np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
|
np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
|
||||||
|
|
||||||
def test_base_margin(self):
|
def test_base_margin(self) -> None:
|
||||||
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")
|
run_base_margin_info(np.asarray, xgb.DMatrix, "cpu")
|
||||||
|
|
||||||
@given(
|
@given(
|
||||||
strategies.integers(0, 1000),
|
strategies.integers(0, 1000),
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from test_dmatrix import set_base_margin_info
|
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.data import run_base_margin_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import modin.pandas as md
|
import modin.pandas as md
|
||||||
@ -145,4 +145,4 @@ class TestModin:
|
|||||||
np.testing.assert_array_equal(data.get_weight(), w)
|
np.testing.assert_array_equal(data.get_weight(), w)
|
||||||
|
|
||||||
def test_base_margin(self):
|
def test_base_margin(self):
|
||||||
set_base_margin_info(md.DataFrame, xgb.DMatrix, "hist")
|
run_base_margin_info(md.DataFrame, xgb.DMatrix, "cpu")
|
||||||
|
|||||||
@ -1,14 +1,12 @@
|
|||||||
import sys
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from test_dmatrix import set_base_margin_info
|
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
from xgboost.core import DataSplitMode
|
from xgboost.core import DataSplitMode
|
||||||
from xgboost.testing.data import pd_arrow_dtypes, pd_dtypes
|
from xgboost.testing.data import pd_arrow_dtypes, pd_dtypes, run_base_margin_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -336,7 +334,7 @@ class TestPandas:
|
|||||||
np.testing.assert_array_equal(data.get_weight(), w)
|
np.testing.assert_array_equal(data.get_weight(), w)
|
||||||
|
|
||||||
def test_base_margin(self):
|
def test_base_margin(self):
|
||||||
set_base_margin_info(pd.DataFrame, xgb.DMatrix, "hist")
|
run_base_margin_info(pd.DataFrame, xgb.DMatrix, "cpu")
|
||||||
|
|
||||||
def test_cv_as_pandas(self):
|
def test_cv_as_pandas(self):
|
||||||
dm, _ = tm.load_agaricus(__file__)
|
dm, _ = tm.load_agaricus(__file__)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user