Support multi-class with base margin. (#7381)
This is already partially supported but never properly tested. So the only possible way to use it is calling `numpy.ndarray.flatten` with `base_margin` before passing it into XGBoost. This PR adds proper support for most of the data types along with tests.
This commit is contained in:
@@ -210,27 +210,28 @@ class ArrayInterfaceHandler {
|
||||
}
|
||||
|
||||
static void ExtractStride(std::map<std::string, Json> const &column,
|
||||
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
|
||||
size_t *stride_r, size_t *stride_c, size_t rows,
|
||||
size_t cols, size_t itemsize) {
|
||||
auto strides_it = column.find("strides");
|
||||
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
|
||||
// default strides
|
||||
strides[0] = cols;
|
||||
strides[1] = 1;
|
||||
*stride_r = cols;
|
||||
*stride_c = 1;
|
||||
} else {
|
||||
// strides specified by the array interface
|
||||
auto const &j_strides = get<Array const>(strides_it->second);
|
||||
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
|
||||
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
|
||||
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
|
||||
size_t n = 1;
|
||||
if (j_strides.size() == 2) {
|
||||
n = get<Integer const>(j_strides[1]) / itemsize;
|
||||
}
|
||||
strides[1] = n;
|
||||
*stride_c = n;
|
||||
}
|
||||
|
||||
auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols);
|
||||
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
|
||||
CHECK(valid) << "Invalid strides in array."
|
||||
<< " strides: (" << strides[0] << "," << strides[1]
|
||||
<< " strides: (" << (*stride_r) << "," << (*stride_c)
|
||||
<< "), shape: (" << rows << ", " << cols << ")";
|
||||
}
|
||||
|
||||
@@ -281,8 +282,8 @@ class ArrayInterface {
|
||||
<< "Masked array is not yet supported.";
|
||||
}
|
||||
|
||||
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
|
||||
typestr[2] - '0');
|
||||
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
|
||||
num_rows, num_cols, typestr[2] - '0');
|
||||
|
||||
auto stream_it = array.find("stream");
|
||||
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
|
||||
@@ -323,8 +324,8 @@ class ArrayInterface {
|
||||
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
|
||||
num_cols = 1;
|
||||
|
||||
strides[0] = std::max(strides[0], strides[1]);
|
||||
strides[1] = 1;
|
||||
stride_row = std::max(stride_row, stride_col);
|
||||
stride_col = 1;
|
||||
}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
@@ -406,13 +407,14 @@ class ArrayInterface {
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
|
||||
return this->DispatchCall(
|
||||
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
|
||||
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
size_t strides[2]{0, 0};
|
||||
size_t stride_row{0};
|
||||
size_t stride_col{0};
|
||||
void* data;
|
||||
Type type;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user