Support multi-class with base margin. (#7381)

This is already partially supported but never properly tested. So the only possible way to use it is calling `numpy.ndarray.flatten` with `base_margin` before passing it into XGBoost. This PR adds proper support
for most of the data types along with tests.
This commit is contained in:
Jiaming Yuan 2021-11-02 13:38:00 +08:00 committed by GitHub
parent 6295dc3b67
commit a13321148a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 274 additions and 92 deletions

View File

@ -577,7 +577,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
# force into void_p, mac need to pass things in as void_p
if data is None:
self.handle = None
self.handle: Optional[ctypes.c_void_p] = None
return
from .data import dispatch_data_backend, _is_iter

View File

@ -1432,9 +1432,7 @@ def inplace_predict( # pylint: disable=unused-argument
Value in the input data which needs to be present as a missing
value. If None, defaults to np.nan.
base_margin:
See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well
supported with base_margin as it requires the size of base margin to be `n_classes
* n_samples`.
See :py:obj:`xgboost.DMatrix` for details.
.. versionadded:: 1.4.0

View File

@ -18,6 +18,11 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
CAT_T = "c"
# meta info that can be a matrix instead of vector.
# For now it's base_margin for multi-class, but it can be extended to label once we have
# multi-output.
_matrix_meta = {"base_margin"}
def _warn_unused_missing(data, missing):
if (missing is not None) and (not np.isnan(missing)):
@ -217,7 +222,7 @@ _pandas_dtype_mapper = {
}
def _invalid_dataframe_dtype(data) -> None:
def _invalid_dataframe_dtype(data: Any) -> None:
# pandas series has `dtypes` but it's just a single object
# cudf series doesn't have `dtypes`.
if hasattr(data, "dtypes") and hasattr(data.dtypes, "__iter__"):
@ -291,7 +296,7 @@ def _transform_pandas_df(
else:
transformed = data
if meta and len(data.columns) > 1:
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")
dtype = meta_type if meta_type else np.float32
@ -323,6 +328,18 @@ def _is_pandas_series(data):
return isinstance(data, pd.Series)
def _meta_from_pandas_series(
data, name: str, dtype: Optional[str], handle: ctypes.c_void_p
) -> None:
"""Help transform pandas series for meta data like labels"""
data = data.values.astype('float')
from pandas.api.types import is_sparse
if is_sparse(data):
data = data.to_dense()
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
_meta_from_numpy(data, name, dtype, handle)
def _is_modin_series(data):
try:
import modin.pandas as pd
@ -374,9 +391,9 @@ def _transform_dt_df(
):
"""Validate feature names and types if data table"""
if meta and data.shape[1] > 1:
raise ValueError(
'DataTable for label or weight cannot have multiple columns')
raise ValueError('DataTable for meta info cannot have multiple columns')
if meta:
meta_type = "float" if meta_type is None else meta_type
# below requires new dt version
# extract first column
data = data.to_numpy()[:, 0].astype(meta_type)
@ -820,19 +837,27 @@ def _to_data_type(dtype: str, name: str):
return dtype_map[dtype]
def _validate_meta_shape(data, name: str) -> None:
def _validate_meta_shape(data: Any, name: str) -> None:
if hasattr(data, "shape"):
msg = f"Invalid shape: {data.shape} for {name}"
if name in _matrix_meta:
if len(data.shape) > 2:
raise ValueError(msg)
return
if len(data.shape) > 2 or (
len(data.shape) == 2 and (data.shape[1] != 0 and data.shape[1] != 1)
):
raise ValueError(f"Invalid shape: {data.shape} for {name}")
def _meta_from_numpy(data, field, dtype, handle):
def _meta_from_numpy(
data: np.ndarray, field: str, dtype, handle: ctypes.c_void_p
) -> None:
data = _maybe_np_slice(data, dtype)
interface = data.__array_interface__
assert interface.get('mask', None) is None, 'Masked array is not supported'
size = data.shape[0]
size = data.size
c_type = _to_data_type(str(data.dtype), field)
ptr = interface['data'][0]
@ -855,17 +880,13 @@ def _meta_from_tuple(data, field, dtype, handle):
return _meta_from_list(data, field, dtype, handle)
def _meta_from_cudf_df(data, field, handle):
if len(data.columns) != 1:
raise ValueError(
'Expecting meta-info to contain a single column')
data = data[data.columns[0]]
interface = bytes(json.dumps([data.__cuda_array_interface__],
indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
c_str(field),
interface))
def _meta_from_cudf_df(data, field: str, handle: ctypes.c_void_p) -> None:
if field not in _matrix_meta:
_meta_from_cudf_series(data.iloc[:, 0], field, handle)
else:
data = data.values
interface = _cuda_array_interface(data)
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))
def _meta_from_cudf_series(data, field, handle):
@ -885,14 +906,15 @@ def _meta_from_cupy_array(data, field, handle):
interface))
def _meta_from_dt(data, field, dtype, handle):
data, _, _ = _transform_dt_df(data, None, None)
def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p):
data, _, _ = _transform_dt_df(data, None, None, field, dtype)
_meta_from_numpy(data, field, dtype, handle)
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
'''Dispatch for meta info.'''
handle = matrix.handle
assert handle is not None
_validate_meta_shape(data, name)
if data is None:
return
@ -911,9 +933,7 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
_meta_from_numpy(data, name, dtype, handle)
return
if _is_pandas_series(data):
data = data.values.astype('float')
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
_meta_from_numpy(data, name, dtype, handle)
_meta_from_pandas_series(data, name, dtype, handle)
return
if _is_dlpack(data):
data = _transform_dlpack(data)

View File

@ -210,27 +210,28 @@ class ArrayInterfaceHandler {
}
static void ExtractStride(std::map<std::string, Json> const &column,
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
size_t *stride_r, size_t *stride_c, size_t rows,
size_t cols, size_t itemsize) {
auto strides_it = column.find("strides");
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
// default strides
strides[0] = cols;
strides[1] = 1;
*stride_r = cols;
*stride_c = 1;
} else {
// strides specified by the array interface
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
size_t n = 1;
if (j_strides.size() == 2) {
n = get<Integer const>(j_strides[1]) / itemsize;
}
strides[1] = n;
*stride_c = n;
}
auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols);
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
CHECK(valid) << "Invalid strides in array."
<< " strides: (" << strides[0] << "," << strides[1]
<< " strides: (" << (*stride_r) << "," << (*stride_c)
<< "), shape: (" << rows << ", " << cols << ")";
}
@ -281,8 +282,8 @@ class ArrayInterface {
<< "Masked array is not yet supported.";
}
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
typestr[2] - '0');
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
num_rows, num_cols, typestr[2] - '0');
auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
@ -323,8 +324,8 @@ class ArrayInterface {
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
num_cols = 1;
strides[0] = std::max(strides[0], strides[1]);
strides[1] = 1;
stride_row = std::max(stride_row, stride_col);
stride_col = 1;
}
void AssignType(StringView typestr) {
@ -406,13 +407,14 @@ class ArrayInterface {
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
return this->DispatchCall(
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
}
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
size_t strides[2]{0, 0};
size_t stride_row{0};
size_t stride_col{0};
void* data;
Type type;
};

View File

@ -30,12 +30,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
return;
}
out->SetDevice(ptr_device);
out->Resize(column.num_rows);
size_t size = column.num_rows * column.num_cols;
CHECK_NE(size, 0);
out->Resize(size);
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
dh::LaunchN(column.num_rows, [=] __device__(size_t idx) {
p_dst[idx] = column.GetElement(idx, 0);
dh::LaunchN(size, [=] __device__(size_t idx) {
size_t ridx = idx / column.num_cols;
size_t cidx = idx - (ridx * column.num_cols);
p_dst[idx] = column.GetElement(ridx, cidx);
});
}
@ -126,16 +130,8 @@ void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_);
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
CHECK_EQ(j_arr.size(), 1)
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
ArrayInterface array_interface(interface_str);
std::string key{c_key};
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
// Not an empty column, transform it.
array_interface.AsColumnVector();
}
CHECK(!array_interface.valid.Data())
<< "Meta info " << key << " should be dense, found validity mask";
@ -143,6 +139,18 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
return;
}
if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
return;
}
CHECK(array_interface.num_cols == 1 || array_interface.num_rows == 1)
<< "MetaInfo: " << c_key << " has invalid shape";
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
// Not an empty column, transform it.
array_interface.AsColumnVector();
}
if (key == "label") {
CopyInfoImpl(array_interface, &labels_);
auto ptr = labels_.ConstDevicePointer();
@ -155,8 +163,6 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
WeightsCheck{});
CHECK(valid) << "Weights must be positive values.";
} else if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
} else if (key == "group") {
CopyGroupInfoImpl(array_interface, &group_ptr_);
ValidateQueryGroup(group_ptr_);

View File

@ -290,27 +290,16 @@ class CPUPredictor : public Predictor {
const auto& base_margin = info.base_margin_.HostVector();
out_preds->Resize(n);
std::vector<bst_float>& out_preds_h = out_preds->HostVector();
if (base_margin.size() == n) {
CHECK_EQ(out_preds->Size(), n);
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
} else {
if (!base_margin.empty()) {
std::ostringstream oss;
oss << "Ignoring the base margin, since it has incorrect length. "
<< "The base margin must be an array of length ";
if (model.learner_model_param->num_output_group > 1) {
oss << "[num_class] * [number of data points], i.e. "
<< model.learner_model_param->num_output_group << " * " << info.num_row_
<< " = " << n << ". ";
} else {
oss << "[number of data points], i.e. " << info.num_row_ << ". ";
}
oss << "Instead, all data points will use "
<< "base_score = " << model.learner_model_param->base_score;
LOG(WARNING) << oss.str();
}
if (base_margin.empty()) {
std::fill(out_preds_h.begin(), out_preds_h.end(),
model.learner_model_param->base_score);
} else {
std::string expected{
"(" + std::to_string(info.num_row_) + ", " +
std::to_string(model.learner_model_param->num_output_group) + ")"};
CHECK_EQ(base_margin.size(), n)
<< "Invalid shape of base_margin. Expected:" << expected;
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
}
}

View File

@ -938,7 +938,11 @@ class GPUPredictor : public xgboost::Predictor {
out_preds->SetDevice(generic_param_->gpu_id);
out_preds->Resize(n);
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Size(), n);
std::string expected{
"(" + std::to_string(info.num_row_) + ", " +
std::to_string(model.learner_model_param->num_output_group) + ")"};
CHECK_EQ(base_margin.Size(), n)
<< "Invalid shape of base_margin. Expected:" << expected;
out_preds->Copy(base_margin);
} else {
out_preds->Fill(model.learner_model_param->base_score);

View File

@ -252,6 +252,8 @@ TEST(MetaInfo, Validate) {
EXPECT_THROW(info.Validate(1), dmlc::Error);
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
d_groups.SetDevice(0);
d_groups.DevicePointer(); // pull to device
auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1);
std::string arr_interface_str;
xgboost::Json::Dump(arr_interface, &arr_interface_str);

View File

@ -5,6 +5,7 @@ import pytest
sys.path.append("tests/python")
import testing as tm
from test_dmatrix import set_base_margin_info
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
@ -142,6 +143,8 @@ def _test_cudf_metainfo(DMatrixT):
dmat_cudf.get_float_info('base_margin'))
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
set_base_margin_info(df, DMatrixT, "gpu_hist")
class TestFromColumnar:
'''Tests for constructing DMatrix from data structure conforming Apache

View File

@ -5,6 +5,7 @@ import pytest
sys.path.append("tests/python")
import testing as tm
from test_dmatrix import set_base_margin_info
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
@ -107,6 +108,8 @@ def _test_cupy_metainfo(DMatrixT):
assert np.array_equal(dmat.get_uint_info('group_ptr'),
dmat_cupy.get_uint_info('group_ptr'))
set_base_margin_info(cp.asarray, DMatrixT, "gpu_hist")
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_sklearn())

View File

@ -22,6 +22,7 @@ from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_auc # noqa
from test_with_dask import run_auc # noqa
from test_with_dask import run_boost_from_prediction # noqa
from test_with_dask import run_boost_from_prediction_multi_clasas # noqa
from test_with_dask import run_dask_classifier # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
@ -297,13 +298,18 @@ def run_gpu_hist(
@pytest.mark.skipif(**tm.no_cudf())
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
import cudf
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_breast_cancer, load_digits
with Client(local_cuda_cluster) as client:
X_, y_ = load_breast_cancer(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction(X, y, "gpu_hist", client)
X_, y_ = load_digits(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction_multi_clasas(X, y, "gpu_hist", client)
class TestDistributedGPU:
@pytest.mark.skipif(**tm.no_dask())

View File

@ -35,8 +35,25 @@ def test_gpu_binary_classification():
assert err < 0.1
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_cudf())
def test_boost_from_prediction_gpu_hist():
twskl.run_boost_from_prediction('gpu_hist')
from sklearn.datasets import load_breast_cancer, load_digits
import cupy as cp
import cudf
tree_method = "gpu_hist"
X, y = load_breast_cancer(return_X_y=True)
X, y = cp.array(X), cp.array(y)
twskl.run_boost_from_prediction_binary(tree_method, X, y, None)
twskl.run_boost_from_prediction_binary(tree_method, X, y, cudf.DataFrame)
X, y = load_digits(return_X_y=True)
X, y = cp.array(X), cp.array(y)
twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, cudf.DataFrame)
def test_num_parallel_tree():

View File

@ -15,6 +15,24 @@ dpath = 'demo/data/'
rng = np.random.RandomState(1994)
def set_base_margin_info(DType, DMatrixT, tm: str):
rng = np.random.default_rng()
X = DType(rng.normal(0, 1.0, size=100).reshape(50, 2))
if hasattr(X, "iloc"):
y = X.iloc[:, 0]
else:
y = X[:, 0]
base_margin = X
# no error at set
Xy = DMatrixT(X, y, base_margin=base_margin)
# Error at train, caused by check in predictor.
with pytest.raises(ValueError, match=r".*base_margin.*"):
xgb.train({"tree_method": tm}, Xy)
# FIXME(jiamingy): Currently the metainfo has no concept of shape. If you pass a
# base_margin with shape (n_classes, n_samples) to XGBoost the result is undefined.
class TestDMatrix:
def test_warn_missing(self):
from xgboost import data
@ -122,7 +140,7 @@ class TestDMatrix:
# base margin is per-class in multi-class classifier
base_margin = rng.randn(100, 3).astype(np.float32)
d.set_base_margin(base_margin.flatten())
d.set_base_margin(base_margin)
ridxs = [1, 2, 3, 4, 5, 6]
sliced = d.slice(ridxs)
@ -380,3 +398,6 @@ class TestDMatrix:
feature_types = ["q"] * 5 + ["c"] + ["q"] * 120
Xy = xgb.DMatrix(path + "?indexing_mode=1", feature_types=feature_types)
np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
def test_base_margin(self):
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")

View File

@ -7,7 +7,7 @@ import sys
import numpy as np
import scipy
import json
from typing import List, Tuple, Dict, Optional, Type, Any
from typing import List, Tuple, Dict, Optional, Type, Any, Callable
import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor
@ -182,6 +182,50 @@ def test_dask_predict_shape_infer(client: "Client") -> None:
assert prediction.shape[1] == 3
def run_boost_from_prediction_multi_clasas(
X: xgb.dask._DaskCollection,
y: xgb.dask._DaskCollection,
tree_method: str,
client: "Client"
) -> None:
model_0 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
)
model_0.fit(X=X, y=y)
margin = xgb.dask.inplace_predict(
client, model_0.get_booster(), X, predict_type="margin"
)
model_1 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
)
model_1.fit(X=X, y=y, base_margin=margin)
predictions_1 = xgb.dask.predict(
client,
model_1.get_booster(),
xgb.dask.DaskDMatrix(client, X, base_margin=margin),
output_margin=True
)
model_2 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
)
model_2.fit(X=X, y=y)
predictions_2 = xgb.dask.inplace_predict(
client, model_2.get_booster(), X, predict_type="margin"
)
a = predictions_1.compute()
b = predictions_2.compute()
# cupy/cudf
if hasattr(a, "get"):
a = a.get()
if hasattr(b, "values"):
b = b.values
if hasattr(b, "get"):
b = b.get()
np.testing.assert_allclose(a, b, atol=1e-5)
def run_boost_from_prediction(
X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client"
) -> None:
@ -227,11 +271,15 @@ def run_boost_from_prediction(
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_breast_cancer, load_digits
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
run_boost_from_prediction(X, y, tree_method, client)
X_, y_ = load_digits(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
run_boost_from_prediction_multi_clasas(X, y, tree_method, client)
def test_inplace_predict(client: "Client") -> None:
from sklearn.datasets import load_boston

View File

@ -3,6 +3,7 @@ import numpy as np
import xgboost as xgb
import testing as tm
import pytest
from test_dmatrix import set_base_margin_info
try:
import modin.pandas as md
@ -144,3 +145,6 @@ class TestModin:
assert data.num_col() == kCols
np.testing.assert_array_equal(data.get_weight(), w)
def test_base_margin(self):
set_base_margin_info(md.DataFrame, xgb.DMatrix, "hist")

View File

@ -3,6 +3,7 @@ import numpy as np
import xgboost as xgb
import testing as tm
import pytest
from test_dmatrix import set_base_margin_info
try:
import pandas as pd
@ -205,6 +206,9 @@ class TestPandas:
np.testing.assert_array_equal(data.get_weight(), w)
def test_base_margin(self):
set_base_margin_info(pd.DataFrame, xgb.DMatrix, "hist")
def test_cv_as_pandas(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,

View File

@ -1,3 +1,4 @@
from typing import Callable, Optional
import collections
import importlib.util
import numpy as np
@ -1147,32 +1148,83 @@ def test_feature_weights():
assert poly_decreasing[0] < -0.08
def run_boost_from_prediction(tree_method):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
def run_boost_from_prediction_binary(tree_method, X, y, as_frame: Optional[Callable]):
"""
Parameters
----------
as_frame: A callable function to convert margin into DataFrame, useful for different
df implementations.
"""
model_0 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4,
tree_method=tree_method)
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
)
model_0.fit(X=X, y=y)
margin = model_0.predict(X, output_margin=True)
if as_frame is not None:
margin = as_frame(margin)
model_1 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4,
tree_method=tree_method)
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
)
model_1.fit(X=X, y=y, base_margin=margin)
predictions_1 = model_1.predict(X, base_margin=margin)
cls_2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8,
tree_method=tree_method)
learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
)
cls_2.fit(X=X, y=y)
predictions_2 = cls_2.predict(X)
assert np.all(predictions_1 == predictions_2)
np.testing.assert_allclose(predictions_1, predictions_2)
def run_boost_from_prediction_multi_clasas(
tree_method, X, y, as_frame: Optional[Callable]
):
# Multi-class
model_0 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
)
model_0.fit(X=X, y=y)
margin = model_0.get_booster().inplace_predict(X, predict_type="margin")
if as_frame is not None:
margin = as_frame(margin)
model_1 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
)
model_1.fit(X=X, y=y, base_margin=margin)
predictions_1 = model_1.get_booster().predict(
xgb.DMatrix(X, base_margin=margin), output_margin=True
)
model_2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
)
model_2.fit(X=X, y=y)
predictions_2 = model_2.get_booster().inplace_predict(X, predict_type="margin")
if hasattr(predictions_1, "get"):
predictions_1 = predictions_1.get()
if hasattr(predictions_2, "get"):
predictions_2 = predictions_2.get()
np.testing.assert_allclose(predictions_1, predictions_2, atol=1e-6)
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
def test_boost_from_prediction(tree_method):
run_boost_from_prediction(tree_method)
from sklearn.datasets import load_breast_cancer, load_digits
import pandas as pd
X, y = load_breast_cancer(return_X_y=True)
run_boost_from_prediction_binary(tree_method, X, y, None)
run_boost_from_prediction_binary(tree_method, X, y, pd.DataFrame)
X, y = load_digits(return_X_y=True)
run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
run_boost_from_prediction_multi_clasas(tree_method, X, y, pd.DataFrame)
def test_estimator_type():

View File

@ -3,6 +3,7 @@ import os
import urllib
import zipfile
import sys
from typing import Optional
from contextlib import contextmanager
from io import StringIO
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
@ -177,7 +178,7 @@ class TestDataset:
self.metric = metric
self.X, self.y = get_dataset()
self.w = None
self.margin = None
self.margin: Optional[np.ndarray] = None
def set_params(self, params_in):
params_in['objective'] = self.objective
@ -315,7 +316,7 @@ _unweighted_datasets_strategy = strategies.sampled_from(
@strategies.composite
def _dataset_weight_margin(draw):
data = draw(_unweighted_datasets_strategy)
data: TestDataset = draw(_unweighted_datasets_strategy)
if draw(strategies.booleans()):
data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
if draw(strategies.booleans()):
@ -324,6 +325,8 @@ def _dataset_weight_margin(draw):
num_class = int(np.max(data.y) + 1)
data.margin = draw(
arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0)))
if num_class != 1:
data.margin = data.margin.reshape(data.y.shape[0], num_class)
return data