Support multi-class with base margin. (#7381)

This is already partially supported but never properly tested. So the only possible way to use it is calling `numpy.ndarray.flatten` with `base_margin` before passing it into XGBoost. This PR adds proper support
for most of the data types along with tests.
This commit is contained in:
Jiaming Yuan
2021-11-02 13:38:00 +08:00
committed by GitHub
parent 6295dc3b67
commit a13321148a
18 changed files with 274 additions and 92 deletions

View File

@@ -210,27 +210,28 @@ class ArrayInterfaceHandler {
}
static void ExtractStride(std::map<std::string, Json> const &column,
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
size_t *stride_r, size_t *stride_c, size_t rows,
size_t cols, size_t itemsize) {
auto strides_it = column.find("strides");
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
// default strides
strides[0] = cols;
strides[1] = 1;
*stride_r = cols;
*stride_c = 1;
} else {
// strides specified by the array interface
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
size_t n = 1;
if (j_strides.size() == 2) {
n = get<Integer const>(j_strides[1]) / itemsize;
}
strides[1] = n;
*stride_c = n;
}
auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols);
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
CHECK(valid) << "Invalid strides in array."
<< " strides: (" << strides[0] << "," << strides[1]
<< " strides: (" << (*stride_r) << "," << (*stride_c)
<< "), shape: (" << rows << ", " << cols << ")";
}
@@ -281,8 +282,8 @@ class ArrayInterface {
<< "Masked array is not yet supported.";
}
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
typestr[2] - '0');
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
num_rows, num_cols, typestr[2] - '0');
auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
@@ -323,8 +324,8 @@ class ArrayInterface {
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
num_cols = 1;
strides[0] = std::max(strides[0], strides[1]);
strides[1] = 1;
stride_row = std::max(stride_row, stride_col);
stride_col = 1;
}
void AssignType(StringView typestr) {
@@ -406,13 +407,14 @@ class ArrayInterface {
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
return this->DispatchCall(
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
}
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
size_t strides[2]{0, 0};
size_t stride_row{0};
size_t stride_col{0};
void* data;
Type type;
};

View File

@@ -30,12 +30,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
return;
}
out->SetDevice(ptr_device);
out->Resize(column.num_rows);
size_t size = column.num_rows * column.num_cols;
CHECK_NE(size, 0);
out->Resize(size);
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
dh::LaunchN(column.num_rows, [=] __device__(size_t idx) {
p_dst[idx] = column.GetElement(idx, 0);
dh::LaunchN(size, [=] __device__(size_t idx) {
size_t ridx = idx / column.num_cols;
size_t cidx = idx - (ridx * column.num_cols);
p_dst[idx] = column.GetElement(ridx, cidx);
});
}
@@ -126,16 +130,8 @@ void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_);
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
CHECK_EQ(j_arr.size(), 1)
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
ArrayInterface array_interface(interface_str);
std::string key{c_key};
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
// Not an empty column, transform it.
array_interface.AsColumnVector();
}
CHECK(!array_interface.valid.Data())
<< "Meta info " << key << " should be dense, found validity mask";
@@ -143,6 +139,18 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
return;
}
if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
return;
}
CHECK(array_interface.num_cols == 1 || array_interface.num_rows == 1)
<< "MetaInfo: " << c_key << " has invalid shape";
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
// Not an empty column, transform it.
array_interface.AsColumnVector();
}
if (key == "label") {
CopyInfoImpl(array_interface, &labels_);
auto ptr = labels_.ConstDevicePointer();
@@ -155,8 +163,6 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
WeightsCheck{});
CHECK(valid) << "Weights must be positive values.";
} else if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
} else if (key == "group") {
CopyGroupInfoImpl(array_interface, &group_ptr_);
ValidateQueryGroup(group_ptr_);