Support categorical data in ellpack. (#6140)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
/*!
|
||||
* Copyright 2019 XGBoost contributors
|
||||
* Copyright 2019-2020 XGBoost contributors
|
||||
*/
|
||||
|
||||
#include <xgboost/data.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "../common/random.h"
|
||||
#include "./ellpack_page.cuh"
|
||||
@@ -33,6 +33,7 @@ __global__ void CompressBinEllpackKernel(
|
||||
const Entry* __restrict__ entries, // One batch of input data
|
||||
const float* __restrict__ cuts, // HistogramCuts::cut_values_
|
||||
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
|
||||
common::Span<FeatureType const> feature_types,
|
||||
size_t base_row, // batch_row_begin
|
||||
size_t n_rows,
|
||||
size_t row_stride,
|
||||
@@ -51,11 +52,19 @@ __global__ void CompressBinEllpackKernel(
|
||||
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
|
||||
const float* feature_cuts = &cuts[cut_rows[feature]];
|
||||
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
|
||||
bool is_cat = common::IsCat(feature_types, ifeature);
|
||||
// Assigning the bin in current entry.
|
||||
// S.t.: fvalue < feature_cuts[bin]
|
||||
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
|
||||
fvalue) -
|
||||
feature_cuts;
|
||||
if (is_cat) {
|
||||
auto it = dh::MakeTransformIterator<int>(
|
||||
feature_cuts, [](float v) { return common::AsCat(v); });
|
||||
bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it;
|
||||
} else {
|
||||
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
|
||||
fvalue) -
|
||||
feature_cuts;
|
||||
}
|
||||
|
||||
if (bin >= ncuts) {
|
||||
bin = ncuts - 1;
|
||||
}
|
||||
@@ -83,14 +92,13 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
|
||||
}
|
||||
|
||||
EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
|
||||
const SparsePage& page, bool is_dense,
|
||||
size_t row_stride)
|
||||
: cuts_(std::move(cuts)),
|
||||
is_dense(is_dense),
|
||||
n_rows(page.Size()),
|
||||
const SparsePage &page, bool is_dense,
|
||||
size_t row_stride,
|
||||
common::Span<FeatureType const> feature_types)
|
||||
: cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()),
|
||||
row_stride(row_stride) {
|
||||
this->InitCompressedData(device);
|
||||
this->CreateHistIndices(device, page);
|
||||
this->CreateHistIndices(device, page, feature_types);
|
||||
}
|
||||
|
||||
// Construct an ELLPACK matrix in memory.
|
||||
@@ -108,12 +116,14 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
|
||||
monitor_.Stop("Quantiles");
|
||||
|
||||
monitor_.Start("InitCompressedData");
|
||||
InitCompressedData(param.gpu_id);
|
||||
this->InitCompressedData(param.gpu_id);
|
||||
monitor_.Stop("InitCompressedData");
|
||||
|
||||
dmat->Info().feature_types.SetDevice(param.gpu_id);
|
||||
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
|
||||
monitor_.Start("BinningCompression");
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
CreateHistIndices(param.gpu_id, batch);
|
||||
CreateHistIndices(param.gpu_id, batch, ft);
|
||||
}
|
||||
monitor_.Stop("BinningCompression");
|
||||
}
|
||||
@@ -365,7 +375,8 @@ void EllpackPageImpl::InitCompressedData(int device) {
|
||||
|
||||
// Compress a CSR page into ELLPACK.
|
||||
void EllpackPageImpl::CreateHistIndices(int device,
|
||||
const SparsePage& row_batch) {
|
||||
const SparsePage& row_batch,
|
||||
common::Span<FeatureType const> feature_types) {
|
||||
if (row_batch.Size() == 0) return;
|
||||
unsigned int null_gidx_value = NumSymbols() - 1;
|
||||
|
||||
@@ -397,9 +408,9 @@ void EllpackPageImpl::CreateHistIndices(int device,
|
||||
size_t n_entries = ent_cnt_end - ent_cnt_begin;
|
||||
dh::device_vector<Entry> entries_d(n_entries);
|
||||
// copy data entries to device.
|
||||
dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
|
||||
data_vec.data() + ent_cnt_begin,
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
|
||||
data_vec.data() + ent_cnt_begin,
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
const dim3 block3(32, 8, 1); // 256 threads
|
||||
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
|
||||
common::DivRoundUp(row_stride, block3.y), 1);
|
||||
@@ -408,7 +419,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
|
||||
CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()),
|
||||
gidx_buffer.DevicePointer(), row_ptrs.data().get(),
|
||||
entries_d.data().get(), device_accessor.gidx_fvalue_map.data(),
|
||||
device_accessor.feature_segments.data(),
|
||||
device_accessor.feature_segments.data(), feature_types,
|
||||
row_batch.base_rowid + batch_row_begin, batch_nrows, row_stride,
|
||||
null_gidx_value);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user