Support building gradient index with cat data. (#7371)

Jiaming Yuan
2021-11-03 22:37:37 +08:00
committed by GitHub
parent 57a4b4ff64
commit ccdabe4512
10 changed files with 105 additions and 27 deletions


@@ -9,8 +9,9 @@
 namespace xgboost {
-void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin,
-                                 size_t prev_sum, uint32_t nbins,
+void GHistIndexMatrix::PushBatch(SparsePage const &batch,
+                                 common::Span<FeatureType const> ft,
+                                 size_t rbegin, size_t prev_sum, uint32_t nbins,
                                  int32_t n_threads) {
   // The number of threads is pegged to the batch size. If the OMP
   // block is parallelized on anything other than the batch/block size,
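The hunk above threads a span of feature types through PushBatch so the bin lookup can tell categorical features apart from numerical ones. As a rough, self-contained sketch of what categorical-aware bin search can look like (assumed, simplified logic; SearchBin, its parameters, and the direct category-to-bin mapping are illustrative stand-ins, not the actual SetIndexData internals):

#include <algorithm>
#include <cstdint>
#include <vector>

enum class FeatureType : std::uint8_t { kNumerical = 0, kCategorical = 1 };

// begin/end delimit this feature's cut values inside the flat `cuts` array.
std::uint32_t SearchBin(std::vector<float> const &cuts, std::uint32_t begin,
                        std::uint32_t end, float fvalue, FeatureType type) {
  if (type == FeatureType::kCategorical) {
    // Assumption: categories arrive as non-negative integral floats, and the
    // category value itself selects the bin inside this feature's range.
    return begin + static_cast<std::uint32_t>(fvalue);
  }
  // Numerical: index of the first cut strictly greater than fvalue, clamped
  // to the feature's last bin.
  auto it = std::upper_bound(cuts.cbegin() + begin, cuts.cbegin() + end, fvalue);
  auto idx = static_cast<std::uint32_t>(it - cuts.cbegin());
  return idx < end ? idx : end - 1;
}

The point of the branch is that categorical values are identities rather than ordered quantities, so they bypass the quantile cut search entirely.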
@@ -86,7 +87,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin,
     common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
     if (curent_bin_size == common::kUint8BinsTypeSize) {
       common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
-      SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+      SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
                    [offsets](auto idx, auto j) {
                      return static_cast<uint8_t>(idx - offsets[j]);
                    });
@@ -94,7 +95,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin,
     } else if (curent_bin_size == common::kUint16BinsTypeSize) {
       common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
                                                 n_index};
-      SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+      SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
                    [offsets](auto idx, auto j) {
                      return static_cast<uint16_t>(idx - offsets[j]);
                    });
@@ -102,7 +103,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin,
       CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
       common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
                                                 n_index};
-      SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+      SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
                    [offsets](auto idx, auto j) {
                      return static_cast<uint32_t>(idx - offsets[j]);
                    });
@@ -113,7 +114,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin,
        not reduced */
   } else {
     common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
-    SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+    SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
                 [](auto idx, auto) { return idx; });
   }
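The dense branches above store idx - offsets[j], the bin index relative to feature j's first bin, so the value fits in uint8 or uint16; the sparse branch keeps the global uint32 bin index via the identity lambda, since a sparse row does not visit every feature. A self-contained round-trip sketch of that compression (the two-feature layout here is an assumed example, not from this commit):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed layout: feature 0 owns global bins [0, 3), feature 1 owns [3, 7).
  std::vector<std::uint32_t> offsets{0, 3};
  std::uint32_t global_bin = 5;  // a bin belonging to feature j = 1
  std::size_t j = 1;
  // Dense path: store the feature-local bin so it fits in a narrow type.
  auto local = static_cast<std::uint8_t>(global_bin - offsets[j]);
  // Reading the index back adds the offset to recover the global bin.
  std::cout << "stored: " << static_cast<int>(local)
            << ", recovered: " << local + offsets[j] << '\n';  // 2, 5
  return 0;
}

The same offsets array recovers the global bin on the read side, so narrowing the storage type loses no information.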
@@ -147,15 +148,17 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span<float> h
   size_t prev_sum = 0;
   const bool isDense = p_fmat->IsDense();
   this->isDense_ = isDense;
+  auto ft = p_fmat->Info().feature_types.ConstHostSpan();
   for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
-    this->PushBatch(batch, rbegin, prev_sum, nbins, nthread);
+    this->PushBatch(batch, ft, rbegin, prev_sum, nbins, nthread);
     prev_sum = row_ptr[rbegin + batch.Size()];
     rbegin += batch.Size();
   }
 }
 
 void GHistIndexMatrix::Init(SparsePage const &batch,
+                            common::Span<FeatureType const> ft,
                             common::HistogramCuts const &cuts,
                             int32_t max_bins_per_feat, bool isDense,
                             int32_t n_threads) {
@@ -176,7 +179,7 @@ void GHistIndexMatrix::Init(SparsePage const &batch,
   size_t rbegin = 0;
   size_t prev_sum = 0;
-  this->PushBatch(batch, rbegin, prev_sum, nbins, n_threads);
+  this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
 }
 
 void GHistIndexMatrix::ResizeIndex(const size_t n_index,
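For callers of the new SparsePage-based overload, a hedged usage sketch (assumed surrounding setup: page, cuts, info, max_bins, is_dense, and n_threads are hypothetical local variables, not code from this commit); it mirrors how the DMatrix-based Init now fetches feature types from MetaInfo:

// Assumed context: `page` is a SparsePage, `cuts` a common::HistogramCuts
// built with the same number of bins, and `info` the owning DMatrix's MetaInfo.
auto ft = info.feature_types.ConstHostSpan();  // common::Span<FeatureType const>
GHistIndexMatrix gidx;
gidx.Init(page, ft, cuts, max_bins, is_dense, n_threads);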