[R] make sure output fits into int32 (#9949)

This commit is contained in:
david-cortes 2024-01-04 09:51:22 +01:00 committed by GitHub
parent 621348abb3
commit db396ee340
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -167,21 +167,38 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) {
[[nodiscard]] SEXP CopyArrayToR(const char *array_str, SEXP ctoken) {
xgboost::ArrayInterface<1> array{xgboost::StringView{array_str}};
// R supports only int and double.
bool is_int =
bool is_int_type =
xgboost::DispatchDType(array.type, [](auto t) { return std::is_integral_v<decltype(t)>; });
bool is_float = xgboost::DispatchDType(
array.type, [](auto v) { return std::is_floating_point_v<decltype(v)>; });
CHECK(is_int || is_float) << "Internal error: Invalid DType.";
CHECK(is_int_type || is_float) << "Internal error: Invalid DType.";
CHECK(array.is_contiguous) << "Internal error: Return by XGBoost should be contiguous";
// Note: the only case in which this will receive an integer type is
// for the 'indptr' part of the quantile cut outputs, which comes
// in sorted order, so the last element contains the maximum value.
bool fits_into_C_int = xgboost::DispatchDType(array.type, [&](auto t) {
using T = decltype(t);
if (!std::is_integral_v<decltype(t)>) {
return false;
}
auto ptr = static_cast<T const *>(array.data);
T last_elt = ptr[array.n - 1];
if (last_elt < 0) {
last_elt = -last_elt; // no std::abs overload for all possible types
}
return last_elt <= std::numeric_limits<int>::max();
});
bool use_int = is_int_type && fits_into_C_int;
// Allocate memory in R
SEXP out =
Rf_protect(is_int ? SafeAllocInteger(array.n, ctoken) : SafeAllocReal(array.n, ctoken));
Rf_protect(use_int ? SafeAllocInteger(array.n, ctoken) : SafeAllocReal(array.n, ctoken));
xgboost::DispatchDType(array.type, [&](auto t) {
using T = decltype(t);
auto in_ptr = static_cast<T const *>(array.data);
if (is_int) {
if (use_int) {
auto out_ptr = INTEGER(out);
std::copy_n(in_ptr, array.n, out_ptr);
} else {