Support exporting cut values (#9356)
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
*/
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <algorithm> // for copy, transform
|
||||
#include <cinttypes> // for strtoimax
|
||||
#include <cmath> // for nan
|
||||
#include <cstring> // for strcmp
|
||||
@@ -20,9 +20,11 @@
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
|
||||
#include "../data/ellpack_page.h" // for EllpackPage
|
||||
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||
@@ -785,6 +787,104 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
|
||||
API_END();
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename Page>
|
||||
void GetCutImpl(Context const *ctx, std::shared_ptr<DMatrix> p_m,
|
||||
std::vector<std::uint64_t> *p_indptr, std::vector<float> *p_data) {
|
||||
auto &indptr = *p_indptr;
|
||||
auto &data = *p_data;
|
||||
for (auto const &page : p_m->GetBatches<Page>(ctx, {})) {
|
||||
auto const &cut = page.Cuts();
|
||||
|
||||
auto const &ptrs = cut.Ptrs();
|
||||
indptr.resize(ptrs.size());
|
||||
|
||||
auto const &vals = cut.Values();
|
||||
auto const &mins = cut.MinValues();
|
||||
|
||||
bst_feature_t n_features = p_m->Info().num_col_;
|
||||
auto ft = p_m->Info().feature_types.ConstHostSpan();
|
||||
std::size_t n_categories = std::count_if(ft.cbegin(), ft.cend(),
|
||||
[](auto t) { return t == FeatureType::kCategorical; });
|
||||
data.resize(vals.size() + n_features - n_categories); // |vals| + |mins|
|
||||
std::size_t i{0}, n_numeric{0};
|
||||
for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) {
|
||||
CHECK_LT(i, data.size());
|
||||
bool is_numeric = !common::IsCat(ft, fidx);
|
||||
if (is_numeric) {
|
||||
data[i] = mins[fidx];
|
||||
i++;
|
||||
}
|
||||
auto beg = ptrs[fidx];
|
||||
auto end = ptrs[fidx + 1];
|
||||
CHECK_LE(end, data.size());
|
||||
std::copy(vals.cbegin() + beg, vals.cbegin() + end, data.begin() + i);
|
||||
i += (end - beg);
|
||||
// shift by min values.
|
||||
indptr[fidx] = ptrs[fidx] + n_numeric;
|
||||
if (is_numeric) {
|
||||
n_numeric++;
|
||||
}
|
||||
}
|
||||
CHECK_EQ(n_numeric, n_features - n_categories);
|
||||
|
||||
indptr.back() = data.size();
|
||||
CHECK_EQ(indptr.back(), vals.size() + mins.size() - n_categories);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/**
 * @brief Export the quantile cuts of a DMatrix as two array-interface strings.
 *
 * Fails with a fatal log if neither a `GHistIndexMatrix` page nor an `EllpackPage`
 * exists on the DMatrix.  When a `GHistIndexMatrix` page exists it is preferred and
 * read with a CPU context; otherwise the `EllpackPage` is read with a CUDA context.
 *
 * The returned strings are produced by `linalg::ArrayInterfaceStr` and are stored in
 * the thread-local entry obtained from the DMatrix, as are the buffers they describe.
 *
 * @param handle      DMatrix handle.
 * @param config      JSON string; it is parsed here, but the parsed value is not
 *                    otherwise used in this function.
 * @param out_indptr  Output: array-interface string for the per-feature offsets.
 * @param out_data    Output: array-interface string for the cut values.
 * @return Status code produced by the API_BEGIN/API_END wrappers.
 */
XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
                                    char const **out_indptr, char const **out_data) {
  API_BEGIN();
  CHECK_HANDLE();

  auto p_m = CastDMatrixHandle(handle);

  xgboost_CHECK_C_ARG_PTR(config);
  xgboost_CHECK_C_ARG_PTR(out_indptr);
  xgboost_CHECK_C_ARG_PTR(out_data);

  auto jconfig = Json::Load(StringView{config});

  bool const has_ghist = p_m->PageExists<GHistIndexMatrix>();
  if (!has_ghist && !p_m->PageExists<EllpackPage>()) {
    LOG(FATAL) << "The quantile cut hasn't been generated yet. Unless this is a `QuantileDMatrix`, "
                  "quantile cut is generated during training.";
  }

  // All return buffers live in the DMatrix's thread-local entry.
  auto &local = p_m->GetThreadLocal();
  auto &cut_values = local.ret_vec_float;
  auto &cut_indptr = local.ret_vec_u64;

  if (has_ghist) {
    auto cpu_ctx = p_m->Ctx()->IsCPU() ? *p_m->Ctx() : p_m->Ctx()->MakeCPU();
    GetCutImpl<GHistIndexMatrix>(&cpu_ctx, p_m, &cut_indptr, &cut_values);
  } else {
    auto cuda_ctx = p_m->Ctx()->IsCUDA() ? *p_m->Ctx() : p_m->Ctx()->MakeCUDA(0);
    GetCutImpl<EllpackPage>(&cuda_ctx, p_m, &cut_indptr, &cut_values);
  }

  // Default-constructed context used to build the tensor views over the host buffers.
  Context host_ctx;

  auto &interface_strs = local.ret_vec_str;
  interface_strs.clear();

  auto const indptr_view = linalg::MakeTensorView(
      &host_ctx, common::Span{cut_indptr.data(), cut_indptr.size()}, cut_indptr.size());
  interface_strs.emplace_back(linalg::ArrayInterfaceStr(indptr_view));

  auto const values_view = linalg::MakeTensorView(
      &host_ctx, common::Span{cut_values.data(), cut_values.size()}, cut_values.size());
  interface_strs.emplace_back(linalg::ArrayInterfaceStr(values_view));

  // Expose the stored strings as C strings: index 0 is indptr, index 1 is data.
  auto &c_strs = local.ret_vec_charp;
  c_strs.resize(interface_strs.size());
  for (std::size_t k = 0; k < interface_strs.size(); ++k) {
    c_strs[k] = interface_strs[k].c_str();
  }

  *out_indptr = c_strs[0];
  *out_data = c_strs[1];
  API_END();
}
|
||||
|
||||
// xgboost implementation
|
||||
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
xgboost::bst_ulong len,
|
||||
|
||||
Reference in New Issue
Block a user