Support multi-class with base margin. (#7381)
Passing a multi-class `base_margin` is already partially supported but was never properly tested, so until now the only way to use it has been to call `numpy.ndarray.flatten` on `base_margin` before passing it into XGBoost. This PR adds proper support for most of the input data types, along with tests.
This commit is contained in:
parent
6295dc3b67
commit
a13321148a
@ -577,7 +577,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
# force into void_p, mac need to pass things in as void_p
|
||||
if data is None:
|
||||
self.handle = None
|
||||
self.handle: Optional[ctypes.c_void_p] = None
|
||||
return
|
||||
|
||||
from .data import dispatch_data_backend, _is_iter
|
||||
|
||||
@ -1432,9 +1432,7 @@ def inplace_predict( # pylint: disable=unused-argument
|
||||
Value in the input data which needs to be present as a missing
|
||||
value. If None, defaults to np.nan.
|
||||
base_margin:
|
||||
See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well
|
||||
supported with base_margin as it requires the size of base margin to be `n_classes
|
||||
* n_samples`.
|
||||
See :py:obj:`xgboost.DMatrix` for details.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
|
||||
@ -18,6 +18,11 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
|
||||
|
||||
CAT_T = "c"
|
||||
|
||||
# meta info that can be a matrix instead of vector.
|
||||
# For now it's base_margin for multi-class, but it can be extended to label once we have
|
||||
# multi-output.
|
||||
_matrix_meta = {"base_margin"}
|
||||
|
||||
|
||||
def _warn_unused_missing(data, missing):
|
||||
if (missing is not None) and (not np.isnan(missing)):
|
||||
@ -217,7 +222,7 @@ _pandas_dtype_mapper = {
|
||||
}
|
||||
|
||||
|
||||
def _invalid_dataframe_dtype(data) -> None:
|
||||
def _invalid_dataframe_dtype(data: Any) -> None:
|
||||
# pandas series has `dtypes` but it's just a single object
|
||||
# cudf series doesn't have `dtypes`.
|
||||
if hasattr(data, "dtypes") and hasattr(data.dtypes, "__iter__"):
|
||||
@ -291,7 +296,7 @@ def _transform_pandas_df(
|
||||
else:
|
||||
transformed = data
|
||||
|
||||
if meta and len(data.columns) > 1:
|
||||
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
|
||||
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")
|
||||
|
||||
dtype = meta_type if meta_type else np.float32
|
||||
@ -323,6 +328,18 @@ def _is_pandas_series(data):
|
||||
return isinstance(data, pd.Series)
|
||||
|
||||
|
||||
def _meta_from_pandas_series(
|
||||
data, name: str, dtype: Optional[str], handle: ctypes.c_void_p
|
||||
) -> None:
|
||||
"""Help transform pandas series for meta data like labels"""
|
||||
data = data.values.astype('float')
|
||||
from pandas.api.types import is_sparse
|
||||
if is_sparse(data):
|
||||
data = data.to_dense()
|
||||
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
|
||||
|
||||
def _is_modin_series(data):
|
||||
try:
|
||||
import modin.pandas as pd
|
||||
@ -374,9 +391,9 @@ def _transform_dt_df(
|
||||
):
|
||||
"""Validate feature names and types if data table"""
|
||||
if meta and data.shape[1] > 1:
|
||||
raise ValueError(
|
||||
'DataTable for label or weight cannot have multiple columns')
|
||||
raise ValueError('DataTable for meta info cannot have multiple columns')
|
||||
if meta:
|
||||
meta_type = "float" if meta_type is None else meta_type
|
||||
# below requires new dt version
|
||||
# extract first column
|
||||
data = data.to_numpy()[:, 0].astype(meta_type)
|
||||
@ -820,19 +837,27 @@ def _to_data_type(dtype: str, name: str):
|
||||
return dtype_map[dtype]
|
||||
|
||||
|
||||
def _validate_meta_shape(data, name: str) -> None:
|
||||
def _validate_meta_shape(data: Any, name: str) -> None:
|
||||
if hasattr(data, "shape"):
|
||||
msg = f"Invalid shape: {data.shape} for {name}"
|
||||
if name in _matrix_meta:
|
||||
if len(data.shape) > 2:
|
||||
raise ValueError(msg)
|
||||
return
|
||||
|
||||
if len(data.shape) > 2 or (
|
||||
len(data.shape) == 2 and (data.shape[1] != 0 and data.shape[1] != 1)
|
||||
):
|
||||
raise ValueError(f"Invalid shape: {data.shape} for {name}")
|
||||
|
||||
|
||||
def _meta_from_numpy(data, field, dtype, handle):
|
||||
def _meta_from_numpy(
|
||||
data: np.ndarray, field: str, dtype, handle: ctypes.c_void_p
|
||||
) -> None:
|
||||
data = _maybe_np_slice(data, dtype)
|
||||
interface = data.__array_interface__
|
||||
assert interface.get('mask', None) is None, 'Masked array is not supported'
|
||||
size = data.shape[0]
|
||||
size = data.size
|
||||
|
||||
c_type = _to_data_type(str(data.dtype), field)
|
||||
ptr = interface['data'][0]
|
||||
@ -855,17 +880,13 @@ def _meta_from_tuple(data, field, dtype, handle):
|
||||
return _meta_from_list(data, field, dtype, handle)
|
||||
|
||||
|
||||
def _meta_from_cudf_df(data, field, handle):
|
||||
if len(data.columns) != 1:
|
||||
raise ValueError(
|
||||
'Expecting meta-info to contain a single column')
|
||||
data = data[data.columns[0]]
|
||||
|
||||
interface = bytes(json.dumps([data.__cuda_array_interface__],
|
||||
indent=2), 'utf-8')
|
||||
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
|
||||
c_str(field),
|
||||
interface))
|
||||
def _meta_from_cudf_df(data, field: str, handle: ctypes.c_void_p) -> None:
|
||||
if field not in _matrix_meta:
|
||||
_meta_from_cudf_series(data.iloc[:, 0], field, handle)
|
||||
else:
|
||||
data = data.values
|
||||
interface = _cuda_array_interface(data)
|
||||
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))
|
||||
|
||||
|
||||
def _meta_from_cudf_series(data, field, handle):
|
||||
@ -885,14 +906,15 @@ def _meta_from_cupy_array(data, field, handle):
|
||||
interface))
|
||||
|
||||
|
||||
def _meta_from_dt(data, field, dtype, handle):
|
||||
data, _, _ = _transform_dt_df(data, None, None)
|
||||
def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p):
|
||||
data, _, _ = _transform_dt_df(data, None, None, field, dtype)
|
||||
_meta_from_numpy(data, field, dtype, handle)
|
||||
|
||||
|
||||
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
||||
'''Dispatch for meta info.'''
|
||||
handle = matrix.handle
|
||||
assert handle is not None
|
||||
_validate_meta_shape(data, name)
|
||||
if data is None:
|
||||
return
|
||||
@ -911,9 +933,7 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
return
|
||||
if _is_pandas_series(data):
|
||||
data = data.values.astype('float')
|
||||
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
_meta_from_pandas_series(data, name, dtype, handle)
|
||||
return
|
||||
if _is_dlpack(data):
|
||||
data = _transform_dlpack(data)
|
||||
|
||||
@ -210,27 +210,28 @@ class ArrayInterfaceHandler {
|
||||
}
|
||||
|
||||
static void ExtractStride(std::map<std::string, Json> const &column,
|
||||
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
|
||||
size_t *stride_r, size_t *stride_c, size_t rows,
|
||||
size_t cols, size_t itemsize) {
|
||||
auto strides_it = column.find("strides");
|
||||
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
|
||||
// default strides
|
||||
strides[0] = cols;
|
||||
strides[1] = 1;
|
||||
*stride_r = cols;
|
||||
*stride_c = 1;
|
||||
} else {
|
||||
// strides specified by the array interface
|
||||
auto const &j_strides = get<Array const>(strides_it->second);
|
||||
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
|
||||
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
|
||||
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
|
||||
size_t n = 1;
|
||||
if (j_strides.size() == 2) {
|
||||
n = get<Integer const>(j_strides[1]) / itemsize;
|
||||
}
|
||||
strides[1] = n;
|
||||
*stride_c = n;
|
||||
}
|
||||
|
||||
auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols);
|
||||
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
|
||||
CHECK(valid) << "Invalid strides in array."
|
||||
<< " strides: (" << strides[0] << "," << strides[1]
|
||||
<< " strides: (" << (*stride_r) << "," << (*stride_c)
|
||||
<< "), shape: (" << rows << ", " << cols << ")";
|
||||
}
|
||||
|
||||
@ -281,8 +282,8 @@ class ArrayInterface {
|
||||
<< "Masked array is not yet supported.";
|
||||
}
|
||||
|
||||
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
|
||||
typestr[2] - '0');
|
||||
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
|
||||
num_rows, num_cols, typestr[2] - '0');
|
||||
|
||||
auto stream_it = array.find("stream");
|
||||
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
|
||||
@ -323,8 +324,8 @@ class ArrayInterface {
|
||||
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
|
||||
num_cols = 1;
|
||||
|
||||
strides[0] = std::max(strides[0], strides[1]);
|
||||
strides[1] = 1;
|
||||
stride_row = std::max(stride_row, stride_col);
|
||||
stride_col = 1;
|
||||
}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
@ -406,13 +407,14 @@ class ArrayInterface {
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
|
||||
return this->DispatchCall(
|
||||
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
|
||||
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
size_t strides[2]{0, 0};
|
||||
size_t stride_row{0};
|
||||
size_t stride_col{0};
|
||||
void* data;
|
||||
Type type;
|
||||
};
|
||||
|
||||
@ -30,12 +30,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
|
||||
return;
|
||||
}
|
||||
out->SetDevice(ptr_device);
|
||||
out->Resize(column.num_rows);
|
||||
|
||||
size_t size = column.num_rows * column.num_cols;
|
||||
CHECK_NE(size, 0);
|
||||
out->Resize(size);
|
||||
|
||||
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
|
||||
|
||||
dh::LaunchN(column.num_rows, [=] __device__(size_t idx) {
|
||||
p_dst[idx] = column.GetElement(idx, 0);
|
||||
dh::LaunchN(size, [=] __device__(size_t idx) {
|
||||
size_t ridx = idx / column.num_cols;
|
||||
size_t cidx = idx - (ridx * column.num_cols);
|
||||
p_dst[idx] = column.GetElement(ridx, cidx);
|
||||
});
|
||||
}
|
||||
|
||||
@ -126,16 +130,8 @@ void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_);
|
||||
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
auto const& j_arr = get<Array>(j_interface);
|
||||
CHECK_EQ(j_arr.size(), 1)
|
||||
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
|
||||
ArrayInterface array_interface(interface_str);
|
||||
std::string key{c_key};
|
||||
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
|
||||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
|
||||
// Not an empty column, transform it.
|
||||
array_interface.AsColumnVector();
|
||||
}
|
||||
|
||||
CHECK(!array_interface.valid.Data())
|
||||
<< "Meta info " << key << " should be dense, found validity mask";
|
||||
@ -143,6 +139,18 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (key == "base_margin") {
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(array_interface.num_cols == 1 || array_interface.num_rows == 1)
|
||||
<< "MetaInfo: " << c_key << " has invalid shape";
|
||||
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
|
||||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
|
||||
// Not an empty column, transform it.
|
||||
array_interface.AsColumnVector();
|
||||
}
|
||||
if (key == "label") {
|
||||
CopyInfoImpl(array_interface, &labels_);
|
||||
auto ptr = labels_.ConstDevicePointer();
|
||||
@ -155,8 +163,6 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
|
||||
WeightsCheck{});
|
||||
CHECK(valid) << "Weights must be positive values.";
|
||||
} else if (key == "base_margin") {
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
} else if (key == "group") {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
ValidateQueryGroup(group_ptr_);
|
||||
|
||||
@ -290,27 +290,16 @@ class CPUPredictor : public Predictor {
|
||||
const auto& base_margin = info.base_margin_.HostVector();
|
||||
out_preds->Resize(n);
|
||||
std::vector<bst_float>& out_preds_h = out_preds->HostVector();
|
||||
if (base_margin.size() == n) {
|
||||
CHECK_EQ(out_preds->Size(), n);
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
||||
} else {
|
||||
if (!base_margin.empty()) {
|
||||
std::ostringstream oss;
|
||||
oss << "Ignoring the base margin, since it has incorrect length. "
|
||||
<< "The base margin must be an array of length ";
|
||||
if (model.learner_model_param->num_output_group > 1) {
|
||||
oss << "[num_class] * [number of data points], i.e. "
|
||||
<< model.learner_model_param->num_output_group << " * " << info.num_row_
|
||||
<< " = " << n << ". ";
|
||||
} else {
|
||||
oss << "[number of data points], i.e. " << info.num_row_ << ". ";
|
||||
}
|
||||
oss << "Instead, all data points will use "
|
||||
<< "base_score = " << model.learner_model_param->base_score;
|
||||
LOG(WARNING) << oss.str();
|
||||
}
|
||||
if (base_margin.empty()) {
|
||||
std::fill(out_preds_h.begin(), out_preds_h.end(),
|
||||
model.learner_model_param->base_score);
|
||||
} else {
|
||||
std::string expected{
|
||||
"(" + std::to_string(info.num_row_) + ", " +
|
||||
std::to_string(model.learner_model_param->num_output_group) + ")"};
|
||||
CHECK_EQ(base_margin.size(), n)
|
||||
<< "Invalid shape of base_margin. Expected:" << expected;
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -938,7 +938,11 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
out_preds->Resize(n);
|
||||
if (base_margin.Size() != 0) {
|
||||
CHECK_EQ(base_margin.Size(), n);
|
||||
std::string expected{
|
||||
"(" + std::to_string(info.num_row_) + ", " +
|
||||
std::to_string(model.learner_model_param->num_output_group) + ")"};
|
||||
CHECK_EQ(base_margin.Size(), n)
|
||||
<< "Invalid shape of base_margin. Expected:" << expected;
|
||||
out_preds->Copy(base_margin);
|
||||
} else {
|
||||
out_preds->Fill(model.learner_model_param->base_score);
|
||||
|
||||
@ -252,6 +252,8 @@ TEST(MetaInfo, Validate) {
|
||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||
|
||||
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
|
||||
d_groups.SetDevice(0);
|
||||
d_groups.DevicePointer(); // pull to device
|
||||
auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1);
|
||||
std::string arr_interface_str;
|
||||
xgboost::Json::Dump(arr_interface, &arr_interface_str);
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_dmatrix import set_base_margin_info
|
||||
|
||||
|
||||
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
||||
@ -142,6 +143,8 @@ def _test_cudf_metainfo(DMatrixT):
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
set_base_margin_info(df, DMatrixT, "gpu_hist")
|
||||
|
||||
|
||||
class TestFromColumnar:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_dmatrix import set_base_margin_info
|
||||
|
||||
|
||||
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||
@ -107,6 +108,8 @@ def _test_cupy_metainfo(DMatrixT):
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'),
|
||||
dmat_cupy.get_uint_info('group_ptr'))
|
||||
|
||||
set_base_margin_info(cp.asarray, DMatrixT, "gpu_hist")
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
|
||||
@ -22,6 +22,7 @@ from test_with_dask import run_empty_dmatrix_reg # noqa
|
||||
from test_with_dask import run_empty_dmatrix_auc # noqa
|
||||
from test_with_dask import run_auc # noqa
|
||||
from test_with_dask import run_boost_from_prediction # noqa
|
||||
from test_with_dask import run_boost_from_prediction_multi_clasas # noqa
|
||||
from test_with_dask import run_dask_classifier # noqa
|
||||
from test_with_dask import run_empty_dmatrix_cls # noqa
|
||||
from test_with_dask import _get_client_workers # noqa
|
||||
@ -297,13 +298,18 @@ def run_gpu_hist(
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
|
||||
import cudf
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
with Client(local_cuda_cluster) as client:
|
||||
X_, y_ = load_breast_cancer(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction(X, y, "gpu_hist", client)
|
||||
|
||||
X_, y_ = load_digits(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction_multi_clasas(X, y, "gpu_hist", client)
|
||||
|
||||
|
||||
class TestDistributedGPU:
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
|
||||
@ -35,8 +35,25 @@ def test_gpu_binary_classification():
|
||||
assert err < 0.1
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_boost_from_prediction_gpu_hist():
|
||||
twskl.run_boost_from_prediction('gpu_hist')
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
import cupy as cp
|
||||
import cudf
|
||||
|
||||
tree_method = "gpu_hist"
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = cp.array(X), cp.array(y)
|
||||
|
||||
twskl.run_boost_from_prediction_binary(tree_method, X, y, None)
|
||||
twskl.run_boost_from_prediction_binary(tree_method, X, y, cudf.DataFrame)
|
||||
|
||||
X, y = load_digits(return_X_y=True)
|
||||
X, y = cp.array(X), cp.array(y)
|
||||
|
||||
twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
|
||||
twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, cudf.DataFrame)
|
||||
|
||||
|
||||
def test_num_parallel_tree():
|
||||
|
||||
@ -15,6 +15,24 @@ dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
def set_base_margin_info(DType, DMatrixT, tm: str):
|
||||
rng = np.random.default_rng()
|
||||
X = DType(rng.normal(0, 1.0, size=100).reshape(50, 2))
|
||||
if hasattr(X, "iloc"):
|
||||
y = X.iloc[:, 0]
|
||||
else:
|
||||
y = X[:, 0]
|
||||
base_margin = X
|
||||
# no error at set
|
||||
Xy = DMatrixT(X, y, base_margin=base_margin)
|
||||
# Error at train, caused by check in predictor.
|
||||
with pytest.raises(ValueError, match=r".*base_margin.*"):
|
||||
xgb.train({"tree_method": tm}, Xy)
|
||||
|
||||
# FIXME(jiamingy): Currently the metainfo has no concept of shape. If you pass a
|
||||
# base_margin with shape (n_classes, n_samples) to XGBoost the result is undefined.
|
||||
|
||||
|
||||
class TestDMatrix:
|
||||
def test_warn_missing(self):
|
||||
from xgboost import data
|
||||
@ -122,7 +140,7 @@ class TestDMatrix:
|
||||
|
||||
# base margin is per-class in multi-class classifier
|
||||
base_margin = rng.randn(100, 3).astype(np.float32)
|
||||
d.set_base_margin(base_margin.flatten())
|
||||
d.set_base_margin(base_margin)
|
||||
|
||||
ridxs = [1, 2, 3, 4, 5, 6]
|
||||
sliced = d.slice(ridxs)
|
||||
@ -380,3 +398,6 @@ class TestDMatrix:
|
||||
feature_types = ["q"] * 5 + ["c"] + ["q"] * 120
|
||||
Xy = xgb.DMatrix(path + "?indexing_mode=1", feature_types=feature_types)
|
||||
np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
|
||||
|
||||
def test_base_margin(self):
|
||||
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")
|
||||
|
||||
@ -7,7 +7,7 @@ import sys
|
||||
import numpy as np
|
||||
import scipy
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Optional, Type, Any
|
||||
from typing import List, Tuple, Dict, Optional, Type, Any, Callable
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@ -182,6 +182,50 @@ def test_dask_predict_shape_infer(client: "Client") -> None:
|
||||
assert prediction.shape[1] == 3
|
||||
|
||||
|
||||
def run_boost_from_prediction_multi_clasas(
|
||||
X: xgb.dask._DaskCollection,
|
||||
y: xgb.dask._DaskCollection,
|
||||
tree_method: str,
|
||||
client: "Client"
|
||||
) -> None:
|
||||
model_0 = xgb.dask.DaskXGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
|
||||
)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = xgb.dask.inplace_predict(
|
||||
client, model_0.get_booster(), X, predict_type="margin"
|
||||
)
|
||||
|
||||
model_1 = xgb.dask.DaskXGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
|
||||
)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = xgb.dask.predict(
|
||||
client,
|
||||
model_1.get_booster(),
|
||||
xgb.dask.DaskDMatrix(client, X, base_margin=margin),
|
||||
output_margin=True
|
||||
)
|
||||
|
||||
model_2 = xgb.dask.DaskXGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
|
||||
)
|
||||
model_2.fit(X=X, y=y)
|
||||
predictions_2 = xgb.dask.inplace_predict(
|
||||
client, model_2.get_booster(), X, predict_type="margin"
|
||||
)
|
||||
a = predictions_1.compute()
|
||||
b = predictions_2.compute()
|
||||
# cupy/cudf
|
||||
if hasattr(a, "get"):
|
||||
a = a.get()
|
||||
if hasattr(b, "values"):
|
||||
b = b.values
|
||||
if hasattr(b, "get"):
|
||||
b = b.get()
|
||||
np.testing.assert_allclose(a, b, atol=1e-5)
|
||||
|
||||
|
||||
def run_boost_from_prediction(
|
||||
X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client"
|
||||
) -> None:
|
||||
@ -227,11 +271,15 @@ def run_boost_from_prediction(
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
X_, y_ = load_breast_cancer(return_X_y=True)
|
||||
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
|
||||
run_boost_from_prediction(X, y, tree_method, client)
|
||||
|
||||
X_, y_ = load_digits(return_X_y=True)
|
||||
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
|
||||
run_boost_from_prediction_multi_clasas(X, y, tree_method, client)
|
||||
|
||||
|
||||
def test_inplace_predict(client: "Client") -> None:
|
||||
from sklearn.datasets import load_boston
|
||||
|
||||
@ -3,6 +3,7 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import pytest
|
||||
from test_dmatrix import set_base_margin_info
|
||||
|
||||
try:
|
||||
import modin.pandas as md
|
||||
@ -144,3 +145,6 @@ class TestModin:
|
||||
assert data.num_col() == kCols
|
||||
|
||||
np.testing.assert_array_equal(data.get_weight(), w)
|
||||
|
||||
def test_base_margin(self):
|
||||
set_base_margin_info(md.DataFrame, xgb.DMatrix, "hist")
|
||||
|
||||
@ -3,6 +3,7 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import pytest
|
||||
from test_dmatrix import set_base_margin_info
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
@ -205,6 +206,9 @@ class TestPandas:
|
||||
|
||||
np.testing.assert_array_equal(data.get_weight(), w)
|
||||
|
||||
def test_base_margin(self):
|
||||
set_base_margin_info(pd.DataFrame, xgb.DMatrix, "hist")
|
||||
|
||||
def test_cv_as_pandas(self):
|
||||
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from typing import Callable, Optional
|
||||
import collections
|
||||
import importlib.util
|
||||
import numpy as np
|
||||
@ -1147,32 +1148,83 @@ def test_feature_weights():
|
||||
assert poly_decreasing[0] < -0.08
|
||||
|
||||
|
||||
def run_boost_from_prediction(tree_method):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
def run_boost_from_prediction_binary(tree_method, X, y, as_frame: Optional[Callable]):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
||||
as_frame: A callable function to convert margin into DataFrame, useful for different
|
||||
df implementations.
|
||||
"""
|
||||
|
||||
model_0 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4,
|
||||
tree_method=tree_method)
|
||||
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
|
||||
)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = model_0.predict(X, output_margin=True)
|
||||
if as_frame is not None:
|
||||
margin = as_frame(margin)
|
||||
|
||||
model_1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4,
|
||||
tree_method=tree_method)
|
||||
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
|
||||
)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = model_1.predict(X, base_margin=margin)
|
||||
|
||||
cls_2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8,
|
||||
tree_method=tree_method)
|
||||
learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
|
||||
)
|
||||
cls_2.fit(X=X, y=y)
|
||||
predictions_2 = cls_2.predict(X)
|
||||
assert np.all(predictions_1 == predictions_2)
|
||||
np.testing.assert_allclose(predictions_1, predictions_2)
|
||||
|
||||
|
||||
def run_boost_from_prediction_multi_clasas(
|
||||
tree_method, X, y, as_frame: Optional[Callable]
|
||||
):
|
||||
# Multi-class
|
||||
model_0 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
|
||||
)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = model_0.get_booster().inplace_predict(X, predict_type="margin")
|
||||
if as_frame is not None:
|
||||
margin = as_frame(margin)
|
||||
|
||||
model_1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
|
||||
)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = model_1.get_booster().predict(
|
||||
xgb.DMatrix(X, base_margin=margin), output_margin=True
|
||||
)
|
||||
|
||||
model_2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
|
||||
)
|
||||
model_2.fit(X=X, y=y)
|
||||
predictions_2 = model_2.get_booster().inplace_predict(X, predict_type="margin")
|
||||
|
||||
if hasattr(predictions_1, "get"):
|
||||
predictions_1 = predictions_1.get()
|
||||
if hasattr(predictions_2, "get"):
|
||||
predictions_2 = predictions_2.get()
|
||||
np.testing.assert_allclose(predictions_1, predictions_2, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||
def test_boost_from_prediction(tree_method):
|
||||
run_boost_from_prediction(tree_method)
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
import pandas as pd
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
|
||||
run_boost_from_prediction_binary(tree_method, X, y, None)
|
||||
run_boost_from_prediction_binary(tree_method, X, y, pd.DataFrame)
|
||||
|
||||
X, y = load_digits(return_X_y=True)
|
||||
|
||||
run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
|
||||
run_boost_from_prediction_multi_clasas(tree_method, X, y, pd.DataFrame)
|
||||
|
||||
|
||||
def test_estimator_type():
|
||||
|
||||
@ -3,6 +3,7 @@ import os
|
||||
import urllib
|
||||
import zipfile
|
||||
import sys
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
@ -177,7 +178,7 @@ class TestDataset:
|
||||
self.metric = metric
|
||||
self.X, self.y = get_dataset()
|
||||
self.w = None
|
||||
self.margin = None
|
||||
self.margin: Optional[np.ndarray] = None
|
||||
|
||||
def set_params(self, params_in):
|
||||
params_in['objective'] = self.objective
|
||||
@ -315,7 +316,7 @@ _unweighted_datasets_strategy = strategies.sampled_from(
|
||||
|
||||
@strategies.composite
|
||||
def _dataset_weight_margin(draw):
|
||||
data = draw(_unweighted_datasets_strategy)
|
||||
data: TestDataset = draw(_unweighted_datasets_strategy)
|
||||
if draw(strategies.booleans()):
|
||||
data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
|
||||
if draw(strategies.booleans()):
|
||||
@ -324,6 +325,8 @@ def _dataset_weight_margin(draw):
|
||||
num_class = int(np.max(data.y) + 1)
|
||||
data.margin = draw(
|
||||
arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0)))
|
||||
if num_class != 1:
|
||||
data.margin = data.margin.reshape(data.y.shape[0], num_class)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user