Support exporting cut values (#9356)

This commit is contained in:
Jiaming Yuan
2023-07-08 15:32:41 +08:00
committed by GitHub
parent c3124813e8
commit 20c52f07d2
28 changed files with 722 additions and 101 deletions

View File

@@ -3,7 +3,7 @@
*/
#include "xgboost/c_api.h"
#include <algorithm> // for copy
#include <algorithm> // for copy, transform
#include <cinttypes> // for strtoimax
#include <cmath> // for nan
#include <cstring> // for strcmp
@@ -20,9 +20,11 @@
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
#include "../data/ellpack_page.h" // for EllpackPage
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
@@ -785,6 +787,104 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
API_END();
}
namespace {
/**
 * \brief Copy the quantile cuts held by the first batch of `Page` into two flat buffers.
 *
 * The output is laid out feature by feature. For a numeric feature its entry in
 * `MinValues()` is written first, followed by that feature's cut values; for a
 * categorical feature only the cut values are written. `indptr[fidx]` is the offset in
 * `data` at which feature `fidx` begins, and the last element of `indptr` equals
 * `data.size()`.
 *
 * Only the first batch returned by `GetBatches<Page>` is inspected.
 *
 * \tparam Page     Page type to read the cuts from; it must provide `Cuts()`.
 * \param ctx       Context forwarded to `DMatrix::GetBatches`.
 * \param p_m       The matrix whose cuts are exported.
 * \param p_indptr  Output offsets, resized to `Ptrs().size()`.
 * \param p_data    Output values, resized to |cut values| + number of numeric features.
 */
template <typename Page>
void GetCutImpl(Context const *ctx, std::shared_ptr<DMatrix> p_m,
                std::vector<std::uint64_t> *p_indptr, std::vector<float> *p_data) {
  auto &indptr = *p_indptr;
  auto &data = *p_data;
  for (auto const &page : p_m->GetBatches<Page>(ctx, {})) {
    auto const &cut = page.Cuts();

    auto const &ptrs = cut.Ptrs();
    indptr.resize(ptrs.size());

    auto const &vals = cut.Values();
    auto const &mins = cut.MinValues();

    bst_feature_t n_features = p_m->Info().num_col_;
    auto ft = p_m->Info().feature_types.ConstHostSpan();
    std::size_t n_categories = std::count_if(ft.cbegin(), ft.cend(),
                                             [](auto t) { return t == FeatureType::kCategorical; });
    // Every numeric feature contributes one extra slot for its minimum value;
    // categorical features contribute none.
    data.resize(vals.size() + n_features - n_categories);
    // `i` is the write cursor into `data`; `n_numeric` counts the numeric features
    // (and therefore the min values) emitted so far.
    std::size_t i{0}, n_numeric{0};
    for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) {
      CHECK_LT(i, data.size());
      bool is_numeric = !common::IsCat(ft, fidx);
      if (is_numeric) {
        data[i] = mins[fidx];
        i++;
      }
      auto beg = ptrs[fidx];
      auto end = ptrs[fidx + 1];
      // Validate both sides of the copy: the source range must be ordered and lie
      // within `vals`, and the destination range [i, i + (end - beg)) must fit in `data`.
      CHECK_LE(beg, end);
      CHECK_LE(end, vals.size());
      CHECK_LE(i + (end - beg), data.size());
      std::copy(vals.cbegin() + beg, vals.cbegin() + end, data.begin() + i);
      i += (end - beg);
      // shift by min values.
      indptr[fidx] = ptrs[fidx] + n_numeric;
      if (is_numeric) {
        n_numeric++;
      }
    }
    CHECK_EQ(n_numeric, n_features - n_categories);
    indptr.back() = data.size();
    CHECK_EQ(indptr.back(), vals.size() + mins.size() - n_categories);
    // The cuts are read from the first batch only.
    break;
  }
}
}  // namespace
/**
 * \brief Export the quantile cuts of a DMatrix as two array-interface strings.
 *
 * Fails if neither a `GHistIndexMatrix` page nor an `EllpackPage` exists for the matrix.
 * When a `GHistIndexMatrix` page exists it is preferred; otherwise the `EllpackPage`
 * is used.
 *
 * \param handle      Handle of the DMatrix to query.
 * \param config      JSON document; it is parsed, but no field of it is read here.
 * \param out_indptr  Receives the string produced by `ArrayInterfaceStr` for the
 *                    feature offsets.
 * \param out_data    Receives the string produced by `ArrayInterfaceStr` for the
 *                    cut values.
 *
 * The returned pointers refer to strings stored in the `ret_vec_str` buffer obtained
 * from `GetThreadLocal()`; they stay valid only until that buffer is modified again.
 */
XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
                                    char const **out_indptr, char const **out_data) {
  API_BEGIN();
  CHECK_HANDLE();

  auto p_m = CastDMatrixHandle(handle);

  xgboost_CHECK_C_ARG_PTR(config);
  xgboost_CHECK_C_ARG_PTR(out_indptr);
  xgboost_CHECK_C_ARG_PTR(out_data);

  // Parsed so that a malformed document is rejected; nothing is read from it below.
  auto jconfig = Json::Load(StringView{config});

  bool const has_ghist = p_m->PageExists<GHistIndexMatrix>();
  if (!has_ghist && !p_m->PageExists<EllpackPage>()) {
    LOG(FATAL) << "The quantile cut hasn't been generated yet. Unless this is a `QuantileDMatrix`, "
                  "quantile cut is generated during training.";
  }

  // Buffers that back the returned arrays.
  auto &tls = p_m->GetThreadLocal();
  auto &data = tls.ret_vec_float;
  auto &indptr = tls.ret_vec_u64;

  if (has_ghist) {
    // Read the cuts through a CPU context.
    auto cpu_ctx = p_m->Ctx()->IsCPU() ? *p_m->Ctx() : p_m->Ctx()->MakeCPU();
    GetCutImpl<GHistIndexMatrix>(&cpu_ctx, p_m, &indptr, &data);
  } else {
    // Only the ellpack page exists; read the cuts through a CUDA context.
    auto cuda_ctx = p_m->Ctx()->IsCUDA() ? *p_m->Ctx() : p_m->Ctx()->MakeCUDA(0);
    GetCutImpl<EllpackPage>(&cuda_ctx, p_m, &indptr, &data);
  }

  // Default-constructed context used to build the views over the output buffers.
  Context ctx;

  auto &ret_vec_str = tls.ret_vec_str;
  ret_vec_str.clear();

  auto indptr_view =
      linalg::MakeTensorView(&ctx, common::Span{indptr.data(), indptr.size()}, indptr.size());
  ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(indptr_view));

  auto data_view =
      linalg::MakeTensorView(&ctx, common::Span{data.data(), data.size()}, data.size());
  ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(data_view));

  // Expose the strings as C pointers, in the same order they were appended.
  auto &charp_vecs = tls.ret_vec_charp;
  charp_vecs.resize(ret_vec_str.size());
  for (std::size_t k = 0; k < ret_vec_str.size(); ++k) {
    charp_vecs[k] = ret_vec_str[k].c_str();
  }

  *out_indptr = charp_vecs[0];
  *out_data = charp_vecs[1];
  API_END();
}
// xgboost implementation
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
xgboost::bst_ulong len,