Support multi-class with base margin. (#7381)
This is already partially supported but never properly tested. So the only possible way to use it is calling `numpy.ndarray.flatten` with `base_margin` before passing it into XGBoost. This PR adds proper support for most of the data types along with tests.
This commit is contained in:
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_dmatrix import set_base_margin_info
|
||||
|
||||
|
||||
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
||||
@@ -142,6 +143,8 @@ def _test_cudf_metainfo(DMatrixT):
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
set_base_margin_info(df, DMatrixT, "gpu_hist")
|
||||
|
||||
|
||||
class TestFromColumnar:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_dmatrix import set_base_margin_info
|
||||
|
||||
|
||||
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||
@@ -107,6 +108,8 @@ def _test_cupy_metainfo(DMatrixT):
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'),
|
||||
dmat_cupy.get_uint_info('group_ptr'))
|
||||
|
||||
set_base_margin_info(cp.asarray, DMatrixT, "gpu_hist")
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
|
||||
@@ -22,6 +22,7 @@ from test_with_dask import run_empty_dmatrix_reg # noqa
|
||||
from test_with_dask import run_empty_dmatrix_auc # noqa
|
||||
from test_with_dask import run_auc # noqa
|
||||
from test_with_dask import run_boost_from_prediction # noqa
|
||||
from test_with_dask import run_boost_from_prediction_multi_clasas # noqa
|
||||
from test_with_dask import run_dask_classifier # noqa
|
||||
from test_with_dask import run_empty_dmatrix_cls # noqa
|
||||
from test_with_dask import _get_client_workers # noqa
|
||||
@@ -297,13 +298,18 @@ def run_gpu_hist(
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
|
||||
import cudf
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
with Client(local_cuda_cluster) as client:
|
||||
X_, y_ = load_breast_cancer(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction(X, y, "gpu_hist", client)
|
||||
|
||||
X_, y_ = load_digits(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction_multi_clasas(X, y, "gpu_hist", client)
|
||||
|
||||
|
||||
class TestDistributedGPU:
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
|
||||
@@ -35,8 +35,25 @@ def test_gpu_binary_classification():
|
||||
assert err < 0.1
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_boost_from_prediction_gpu_hist():
|
||||
twskl.run_boost_from_prediction('gpu_hist')
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
import cupy as cp
|
||||
import cudf
|
||||
|
||||
tree_method = "gpu_hist"
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = cp.array(X), cp.array(y)
|
||||
|
||||
twskl.run_boost_from_prediction_binary(tree_method, X, y, None)
|
||||
twskl.run_boost_from_prediction_binary(tree_method, X, y, cudf.DataFrame)
|
||||
|
||||
X, y = load_digits(return_X_y=True)
|
||||
X, y = cp.array(X), cp.array(y)
|
||||
|
||||
twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
|
||||
twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, cudf.DataFrame)
|
||||
|
||||
|
||||
def test_num_parallel_tree():
|
||||
|
||||
Reference in New Issue
Block a user