[R] make sure output fits into int32 (#9949)
This commit is contained in:
parent
621348abb3
commit
db396ee340
@ -167,21 +167,38 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) {
|
|||||||
[[nodiscard]] SEXP CopyArrayToR(const char *array_str, SEXP ctoken) {
|
[[nodiscard]] SEXP CopyArrayToR(const char *array_str, SEXP ctoken) {
|
||||||
xgboost::ArrayInterface<1> array{xgboost::StringView{array_str}};
|
xgboost::ArrayInterface<1> array{xgboost::StringView{array_str}};
|
||||||
// R supports only int and double.
|
// R supports only int and double.
|
||||||
bool is_int =
|
bool is_int_type =
|
||||||
xgboost::DispatchDType(array.type, [](auto t) { return std::is_integral_v<decltype(t)>; });
|
xgboost::DispatchDType(array.type, [](auto t) { return std::is_integral_v<decltype(t)>; });
|
||||||
bool is_float = xgboost::DispatchDType(
|
bool is_float = xgboost::DispatchDType(
|
||||||
array.type, [](auto v) { return std::is_floating_point_v<decltype(v)>; });
|
array.type, [](auto v) { return std::is_floating_point_v<decltype(v)>; });
|
||||||
CHECK(is_int || is_float) << "Internal error: Invalid DType.";
|
CHECK(is_int_type || is_float) << "Internal error: Invalid DType.";
|
||||||
CHECK(array.is_contiguous) << "Internal error: Return by XGBoost should be contiguous";
|
CHECK(array.is_contiguous) << "Internal error: Return by XGBoost should be contiguous";
|
||||||
|
|
||||||
|
// Note: the only case in which this will receive an integer type is
|
||||||
|
// for the 'indptr' part of the quantile cut outputs, which comes
|
||||||
|
// in sorted order, so the last element contains the maximum value.
|
||||||
|
bool fits_into_C_int = xgboost::DispatchDType(array.type, [&](auto t) {
|
||||||
|
using T = decltype(t);
|
||||||
|
if (!std::is_integral_v<decltype(t)>) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto ptr = static_cast<T const *>(array.data);
|
||||||
|
T last_elt = ptr[array.n - 1];
|
||||||
|
if (last_elt < 0) {
|
||||||
|
last_elt = -last_elt; // no std::abs overload for all possible types
|
||||||
|
}
|
||||||
|
return last_elt <= std::numeric_limits<int>::max();
|
||||||
|
});
|
||||||
|
bool use_int = is_int_type && fits_into_C_int;
|
||||||
|
|
||||||
// Allocate memory in R
|
// Allocate memory in R
|
||||||
SEXP out =
|
SEXP out =
|
||||||
Rf_protect(is_int ? SafeAllocInteger(array.n, ctoken) : SafeAllocReal(array.n, ctoken));
|
Rf_protect(use_int ? SafeAllocInteger(array.n, ctoken) : SafeAllocReal(array.n, ctoken));
|
||||||
|
|
||||||
xgboost::DispatchDType(array.type, [&](auto t) {
|
xgboost::DispatchDType(array.type, [&](auto t) {
|
||||||
using T = decltype(t);
|
using T = decltype(t);
|
||||||
auto in_ptr = static_cast<T const *>(array.data);
|
auto in_ptr = static_cast<T const *>(array.data);
|
||||||
if (is_int) {
|
if (use_int) {
|
||||||
auto out_ptr = INTEGER(out);
|
auto out_ptr = INTEGER(out);
|
||||||
std::copy_n(in_ptr, array.n, out_ptr);
|
std::copy_n(in_ptr, array.n, out_ptr);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user