Support building gradient index with cat data. (#7371)
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <vector>
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
@@ -18,8 +19,9 @@ namespace xgboost {
|
||||
* index for CPU histogram. On GPU ellpack page is used.
|
||||
*/
|
||||
class GHistIndexMatrix {
|
||||
void PushBatch(SparsePage const &batch, size_t rbegin, size_t prev_sum,
|
||||
uint32_t nbins, int32_t n_threads);
|
||||
void PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
size_t rbegin, size_t prev_sum, uint32_t nbins,
|
||||
int32_t n_threads);
|
||||
|
||||
public:
|
||||
/*! \brief row pointer to rows by element position */
|
||||
@@ -40,12 +42,14 @@ class GHistIndexMatrix {
|
||||
}
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
|
||||
void Init(SparsePage const &page, common::HistogramCuts const &cuts,
|
||||
int32_t max_bins_per_feat, bool is_dense, int32_t n_threads);
|
||||
void Init(SparsePage const &page, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
|
||||
bool is_dense, int32_t n_threads);
|
||||
|
||||
// specific method for sparse data as no possibility to reduce allocated memory
|
||||
template <typename BinIdxType, typename GetOffset>
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span,
|
||||
common::Span<FeatureType const> ft,
|
||||
size_t batch_threads, const SparsePage &batch,
|
||||
size_t rbegin, size_t nbins, GetOffset get_offset) {
|
||||
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
|
||||
@@ -61,9 +65,16 @@ class GHistIndexMatrix {
|
||||
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
|
||||
CHECK_EQ(ibegin + inst.size(), iend);
|
||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||
uint32_t idx = cut.SearchBin(inst[j]);
|
||||
index_data[ibegin + j] = get_offset(idx, j);
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
auto e = inst[j];
|
||||
if (common::IsCat(ft, e.index)) {
|
||||
auto bin_idx = cut.SearchCatBin(e);
|
||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
||||
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||
} else {
|
||||
uint32_t idx = cut.SearchBin(inst[j]);
|
||||
index_data[ibegin + j] = get_offset(idx, j);
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user