Support multi-class with base margin. (#7381)
This is already partially supported but never properly tested. So the only possible way to use it is calling `numpy.ndarray.flatten` with `base_margin` before passing it into XGBoost. This PR adds proper support for most of the data types along with tests.
This commit is contained in:
@@ -210,27 +210,28 @@ class ArrayInterfaceHandler {
|
||||
}
|
||||
|
||||
static void ExtractStride(std::map<std::string, Json> const &column,
|
||||
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
|
||||
size_t *stride_r, size_t *stride_c, size_t rows,
|
||||
size_t cols, size_t itemsize) {
|
||||
auto strides_it = column.find("strides");
|
||||
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
|
||||
// default strides
|
||||
strides[0] = cols;
|
||||
strides[1] = 1;
|
||||
*stride_r = cols;
|
||||
*stride_c = 1;
|
||||
} else {
|
||||
// strides specified by the array interface
|
||||
auto const &j_strides = get<Array const>(strides_it->second);
|
||||
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
|
||||
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
|
||||
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
|
||||
size_t n = 1;
|
||||
if (j_strides.size() == 2) {
|
||||
n = get<Integer const>(j_strides[1]) / itemsize;
|
||||
}
|
||||
strides[1] = n;
|
||||
*stride_c = n;
|
||||
}
|
||||
|
||||
auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols);
|
||||
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
|
||||
CHECK(valid) << "Invalid strides in array."
|
||||
<< " strides: (" << strides[0] << "," << strides[1]
|
||||
<< " strides: (" << (*stride_r) << "," << (*stride_c)
|
||||
<< "), shape: (" << rows << ", " << cols << ")";
|
||||
}
|
||||
|
||||
@@ -281,8 +282,8 @@ class ArrayInterface {
|
||||
<< "Masked array is not yet supported.";
|
||||
}
|
||||
|
||||
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
|
||||
typestr[2] - '0');
|
||||
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
|
||||
num_rows, num_cols, typestr[2] - '0');
|
||||
|
||||
auto stream_it = array.find("stream");
|
||||
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
|
||||
@@ -323,8 +324,8 @@ class ArrayInterface {
|
||||
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
|
||||
num_cols = 1;
|
||||
|
||||
strides[0] = std::max(strides[0], strides[1]);
|
||||
strides[1] = 1;
|
||||
stride_row = std::max(stride_row, stride_col);
|
||||
stride_col = 1;
|
||||
}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
@@ -406,13 +407,14 @@ class ArrayInterface {
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
|
||||
return this->DispatchCall(
|
||||
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
|
||||
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
size_t strides[2]{0, 0};
|
||||
size_t stride_row{0};
|
||||
size_t stride_col{0};
|
||||
void* data;
|
||||
Type type;
|
||||
};
|
||||
|
||||
@@ -30,12 +30,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
|
||||
return;
|
||||
}
|
||||
out->SetDevice(ptr_device);
|
||||
out->Resize(column.num_rows);
|
||||
|
||||
size_t size = column.num_rows * column.num_cols;
|
||||
CHECK_NE(size, 0);
|
||||
out->Resize(size);
|
||||
|
||||
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
|
||||
|
||||
dh::LaunchN(column.num_rows, [=] __device__(size_t idx) {
|
||||
p_dst[idx] = column.GetElement(idx, 0);
|
||||
dh::LaunchN(size, [=] __device__(size_t idx) {
|
||||
size_t ridx = idx / column.num_cols;
|
||||
size_t cidx = idx - (ridx * column.num_cols);
|
||||
p_dst[idx] = column.GetElement(ridx, cidx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -126,16 +130,8 @@ void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_);
|
||||
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
auto const& j_arr = get<Array>(j_interface);
|
||||
CHECK_EQ(j_arr.size(), 1)
|
||||
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
|
||||
ArrayInterface array_interface(interface_str);
|
||||
std::string key{c_key};
|
||||
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
|
||||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
|
||||
// Not an empty column, transform it.
|
||||
array_interface.AsColumnVector();
|
||||
}
|
||||
|
||||
CHECK(!array_interface.valid.Data())
|
||||
<< "Meta info " << key << " should be dense, found validity mask";
|
||||
@@ -143,6 +139,18 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (key == "base_margin") {
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(array_interface.num_cols == 1 || array_interface.num_rows == 1)
|
||||
<< "MetaInfo: " << c_key << " has invalid shape";
|
||||
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
|
||||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
|
||||
// Not an empty column, transform it.
|
||||
array_interface.AsColumnVector();
|
||||
}
|
||||
if (key == "label") {
|
||||
CopyInfoImpl(array_interface, &labels_);
|
||||
auto ptr = labels_.ConstDevicePointer();
|
||||
@@ -155,8 +163,6 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
|
||||
WeightsCheck{});
|
||||
CHECK(valid) << "Weights must be positive values.";
|
||||
} else if (key == "base_margin") {
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
} else if (key == "group") {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
ValidateQueryGroup(group_ptr_);
|
||||
|
||||
Reference in New Issue
Block a user