From db396ee34046f29443cdc801903990341ea89fa8 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Thu, 4 Jan 2024 09:51:22 +0100 Subject: [PATCH] [R] make sure output fits into int32 (#9949) --- R-package/src/xgboost_R.cc | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 60a3fe68b..d7d4c49e1 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -167,21 +167,38 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) { [[nodiscard]] SEXP CopyArrayToR(const char *array_str, SEXP ctoken) { xgboost::ArrayInterface<1> array{xgboost::StringView{array_str}}; // R supports only int and double. - bool is_int = + bool is_int_type = xgboost::DispatchDType(array.type, [](auto t) { return std::is_integral_v; }); bool is_float = xgboost::DispatchDType( array.type, [](auto v) { return std::is_floating_point_v; }); - CHECK(is_int || is_float) << "Internal error: Invalid DType."; + CHECK(is_int_type || is_float) << "Internal error: Invalid DType."; CHECK(array.is_contiguous) << "Internal error: Return by XGBoost should be contiguous"; + // Note: the only case in which this will receive an integer type is + // for the 'indptr' part of the quantile cut outputs, which comes + // in sorted order, so the last element contains the maximum value. + bool fits_into_C_int = xgboost::DispatchDType(array.type, [&](auto t) { + using T = decltype(t); + if (!std::is_integral_v) { + return false; + } + auto ptr = static_cast(array.data); + T last_elt = ptr[array.n - 1]; + if (last_elt < 0) { + last_elt = -last_elt; // no std::abs overload for all possible types + } + return last_elt <= std::numeric_limits::max(); + }); + bool use_int = is_int_type && fits_into_C_int; + // Allocate memory in R SEXP out = - Rf_protect(is_int ? SafeAllocInteger(array.n, ctoken) : SafeAllocReal(array.n, ctoken)); + Rf_protect(use_int ? SafeAllocInteger(array.n, ctoken) : SafeAllocReal(array.n, ctoken)); xgboost::DispatchDType(array.type, [&](auto t) { using T = decltype(t); auto in_ptr = static_cast(array.data); - if (is_int) { + if (use_int) { auto out_ptr = INTEGER(out); std::copy_n(in_ptr, array.n, out_ptr); } else {