Support building gradient index with cat data. (#7371)

This commit is contained in:
Jiaming Yuan
2021-11-03 22:37:37 +08:00
committed by GitHub
parent 57a4b4ff64
commit ccdabe4512
10 changed files with 105 additions and 27 deletions

View File

@@ -7,6 +7,7 @@
#include <vector>
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "../common/categorical.h"
#include "../common/hist_util.h"
#include "../common/threading_utils.h"
@@ -18,8 +19,9 @@ namespace xgboost {
* index for CPU histogram. On GPU ellpack page is used.
*/
class GHistIndexMatrix {
void PushBatch(SparsePage const &batch, size_t rbegin, size_t prev_sum,
uint32_t nbins, int32_t n_threads);
void PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
size_t rbegin, size_t prev_sum, uint32_t nbins,
int32_t n_threads);
public:
/*! \brief row pointer to rows by element position */
@@ -40,12 +42,14 @@ class GHistIndexMatrix {
}
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
void Init(SparsePage const &page, common::HistogramCuts const &cuts,
int32_t max_bins_per_feat, bool is_dense, int32_t n_threads);
void Init(SparsePage const &page, common::Span<FeatureType const> ft,
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
bool is_dense, int32_t n_threads);
// specific method for sparse data as no possibility to reduce allocated memory
template <typename BinIdxType, typename GetOffset>
void SetIndexData(common::Span<BinIdxType> index_data_span,
common::Span<FeatureType const> ft,
size_t batch_threads, const SparsePage &batch,
size_t rbegin, size_t nbins, GetOffset get_offset) {
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
@@ -61,9 +65,16 @@ class GHistIndexMatrix {
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
uint32_t idx = cut.SearchBin(inst[j]);
index_data[ibegin + j] = get_offset(idx, j);
++hit_count_tloc_[tid * nbins + idx];
auto e = inst[j];
if (common::IsCat(ft, e.index)) {
auto bin_idx = cut.SearchCatBin(e);
index_data[ibegin + j] = get_offset(bin_idx, j);
++hit_count_tloc_[tid * nbins + bin_idx];
} else {
uint32_t idx = cut.SearchBin(inst[j]);
index_data[ibegin + j] = get_offset(idx, j);
++hit_count_tloc_[tid * nbins + idx];
}
}
});
}