Fix CPU hist init for sparse dataset. (#4625)

* Fix CPU hist init for sparse dataset.

* Implement sparse histogram cut.
* Allow empty features.

* Fix Windows build; don't use sparse cuts in a distributed environment.

* Comments.

* Smaller threshold.

* Fix Windows OpenMP.

* Fix msvc lambda capture.

* Fix MSVC macro.

* Fix MSVC initialization list.

* Fix MSVC initialization list x2.

* Preserve categorical feature behavior.

* Rename matrix to sparse cuts.
* Reuse UseGroup.
* Check for categorical data when adding cut.

Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>

* Sanity check.

* Fix comments.

* Fix comment.
Author: Jiaming Yuan
Date: 2019-07-04 19:27:03 -04:00
Committed by: Philip Hyunsu Cho
Parent: b7a1f22d24
Commit: d9a47794a5
33 changed files with 681 additions and 299 deletions
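The diff below consistently replaces the old HistCutMatrix struct, whose raw members (row_ptr, cut, min_val) were read directly, with the new HistogramCuts class and its accessors (Ptrs(), Values(), MinValues()). A minimal sketch of that accessor shape, inferred only from the call sites in this diff and not from the full class:

#include <cstdint>
#include <vector>

namespace common {
// Hypothetical skeleton; only the accessor names and element types are taken from the diff.
class HistogramCuts {
 public:
  std::vector<uint32_t> const& Ptrs() const { return ptrs_; }        // was HistCutMatrix::row_ptr
  std::vector<float> const& Values() const { return values_; }       // was HistCutMatrix::cut
  std::vector<float> const& MinValues() const { return min_vals_; }  // was HistCutMatrix::min_val

 private:
  std::vector<uint32_t> ptrs_;   // per-feature offsets into values_; back() == total number of bins
  std::vector<float> values_;    // flattened cut values for all features
  std::vector<float> min_vals_;  // per-feature minimum value
};
}  // namespace common

Call sites therefore change from hmat.row_ptr.back() to hmat.Ptrs().back(), and so on.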


@@ -480,8 +480,8 @@ __global__ void CompressBinEllpackKernel(
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
-const float* __restrict__ cuts, // HistCutMatrix::cut
-const uint32_t* __restrict__ cut_rows, // HistCutMatrix::row_ptrs
+const float* __restrict__ cuts, // HistogramCuts::cut
+const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
@@ -593,7 +593,7 @@ struct DeviceShard {
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist;
-/*! \brief row_ptr form HistCutMatrix. */
+/*! \brief row_ptr form HistogramCuts. */
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
@@ -654,10 +654,10 @@ struct DeviceShard {
}
void InitCompressedData(
-const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense);
+const common::HistogramCuts& hmat, size_t row_stride, bool is_dense);
void CreateHistIndices(
-const SparsePage &row_batch, const common::HistCutMatrix &hmat,
+const SparsePage &row_batch, const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state, int rows_per_batch);
~DeviceShard() {
@@ -718,7 +718,7 @@ struct DeviceShard {
// Work out cub temporary memory requirement
GPUTrainingParam gpu_param(param);
DeviceSplitCandidateReduceOp op(gpu_param);
-size_t temp_storage_bytes;
+size_t temp_storage_bytes = 0;
DeviceSplitCandidate*dummy = nullptr;
cub::DeviceReduce::Reduce(
nullptr, temp_storage_bytes, dummy,
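The change above initializes temp_storage_bytes before the cub size query. For context, a minimal sketch of the standard two-pass cub temporary-storage pattern (a generic illustration, not the XGBoost call site):

#include <cub/cub.cuh>

void SumOnDevice(const float* d_in, float* d_out, int num_items) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;  // start from a defined value, as the fix above does
  // First call: d_temp_storage is nullptr, so cub only writes the required byte count.
  cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Second call: performs the actual reduction using the allocated scratch space.
  cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  cudaFree(d_temp_storage);
}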
@@ -806,7 +806,7 @@ struct DeviceShard {
const int items_per_thread = 8;
const int block_threads = 256;
const int grid_size = static_cast<int>(
-dh::DivRoundUp(n_elements, items_per_thread * block_threads));
+common::DivRoundUp(n_elements, items_per_thread * block_threads));
if (grid_size <= 0) {
return;
}
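This hunk, like two later ones, moves DivRoundUp from the dh namespace into common. The helper is plain ceil division used to size CUDA launch grids; a minimal sketch of the assumed semantics (not the XGBoost implementation):

#include <cstddef>

// Smallest integer n such that n * b >= a.
constexpr std::size_t DivRoundUp(std::size_t a, std::size_t b) {
  return (a + b - 1) / b;
}

// Example: 1000 elements, 8 items per thread, 256 threads per block
// -> DivRoundUp(1000, 8 * 256) == 1 block.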
@@ -1106,9 +1106,9 @@ struct DeviceShard {
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
-const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) {
-n_bins = hmat.row_ptr.back();
-int null_gidx_value = hmat.row_ptr.back();
+const common::HistogramCuts &hmat, size_t row_stride, bool is_dense) {
+n_bins = hmat.Ptrs().back();
+int null_gidx_value = hmat.Ptrs().back();
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
@@ -1121,14 +1121,14 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
&gpair, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
-&feature_segments, hmat.row_ptr.size(),
-&gidx_fvalue_map, hmat.cut.size(),
-&min_fvalue, hmat.min_val.size(),
+&feature_segments, hmat.Ptrs().size(),
+&gidx_fvalue_map, hmat.Values().size(),
+&min_fvalue, hmat.MinValues().size(),
&monotone_constraints, param.monotone_constraints.size());
-dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut);
-dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val);
-dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr);
+dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
+dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
+dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
node_sum_gradients.resize(max_nodes);
@@ -1153,26 +1153,26 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
// check if we can use shared memory for building histograms
// (assuming atleast we need 2 CTAs per SM to maintain decent latency
// hiding)
-auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
+auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back();
auto max_smem = dh::MaxSharedMemory(device_id);
if (histogram_size <= max_smem) {
use_shared_memory_histograms = true;
}
// Init histogram
-hist.Init(device_id, hmat.NumBins());
+hist.Init(device_id, hmat.Ptrs().back());
}
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(
const SparsePage &row_batch,
-const common::HistCutMatrix &hmat,
+const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state,
int rows_per_batch) {
// Has any been allocated for me in this batch?
if (!device_row_state.rows_to_process_from_batch) return;
-unsigned int null_gidx_value = hmat.row_ptr.back();
+unsigned int null_gidx_value = hmat.Ptrs().back();
size_t row_stride = this->ellpack_matrix.row_stride;
const auto &offset_vec = row_batch.offset.ConstHostVector();
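The hunk above also shows InitCompressedData comparing sizeof(GradientSumT) * hmat.Ptrs().back() against dh::MaxSharedMemory(device_id) to decide whether histograms fit in shared memory. A minimal sketch of how such a per-block limit can be queried with the CUDA runtime (an illustration with a hypothetical helper name, not XGBoost's dh::MaxSharedMemory):

#include <cuda_runtime.h>
#include <cstddef>

// Returns true when a histogram of n_bins entries of bin_bytes each fits in
// the device's shared memory per block.
bool HistogramFitsInSharedMemory(std::size_t n_bins, std::size_t bin_bytes, int device_id) {
  int max_smem = 0;
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device_id);
  return n_bins * bin_bytes <= static_cast<std::size_t>(max_smem);
}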
@@ -1184,8 +1184,8 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
-size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
-gpu_batch_nrows);
+size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
+gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
@@ -1216,8 +1216,8 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
(entries_d.data().get(), data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
-const dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
-dh::DivRoundUp(row_stride, block3.y), 1);
+const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
+common::DivRoundUp(row_stride, block3.y), 1);
CompressBinEllpackKernel<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
@@ -1361,13 +1361,13 @@ class GPUHistMakerSpecialised {
});
monitor_.StartCuda("Quantiles");
-// Create the quantile sketches for the dmatrix and initialize HistCutMatrix
+// Create the quantile sketches for the dmatrix and initialize HistogramCuts
size_t row_stride = common::DeviceSketch(param_, *learner_param_,
hist_maker_param_.gpu_batch_nrows,
dmat, &hmat_);
monitor_.StopCuda("Quantiles");
-n_bins_ = hmat_.row_ptr.back();
+n_bins_ = hmat_.Ptrs().back();
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
@@ -1475,9 +1475,9 @@ class GPUHistMakerSpecialised {
return true;
}
-TrainParam param_; // NOLINT
-common::HistCutMatrix hmat_; // NOLINT
-MetaInfo* info_; // NOLINT
+TrainParam param_; // NOLINT
+common::HistogramCuts hmat_; // NOLINT
+MetaInfo* info_; // NOLINT
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT