Move ellpack page construction into DMatrix (#4833)

This commit is contained in:
Rong Ou
2019-09-16 20:50:55 -07:00
committed by Jiaming Yuan
parent 512f037e55
commit 125bcec62e
17 changed files with 761 additions and 513 deletions

25
src/data/ellpack_page.cc Normal file
View File

@@ -0,0 +1,25 @@
/*!
* Copyright 2019 XGBoost contributors
*
* \file ellpack_page.cc
*/
#ifndef XGBOOST_USE_CUDA
#include <xgboost/data.h>
// dummy implementation of ELlpackPage in case CUDA is not used
namespace xgboost {
class EllpackPageImpl {};
EllpackPage::EllpackPage(DMatrix* dmat) {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
}
EllpackPage::~EllpackPage() {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
}
} // namespace xgboost
#endif // XGBOOST_USE_CUDA

197
src/data/ellpack_page.cu Normal file
View File

@@ -0,0 +1,197 @@
/*!
* Copyright 2019 XGBoost contributors
*
* \file ellpack_page.cu
*/
#include <xgboost/data.h>
#include "./ellpack_page.cuh"
#include "../common/hist_util.h"
#include "../common/random.h"
namespace xgboost {
EllpackPage::EllpackPage(DMatrix* dmat) : impl_{new EllpackPageImpl(dmat)} {}
EllpackPage::~EllpackPage() = default;
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat) : dmat_{dmat} {}
// Bin each input data entry, store the bin indices in compressed form.
template<typename std::enable_if<true, int>::type = 0>
__global__ void CompressBinEllpackKernel(
common::CompressedBufferWriter wr,
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
unsigned int null_gidx_value) {
size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
if (irow >= n_rows || ifeature >= row_stride) {
return;
}
int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
unsigned int bin = null_gidx_value;
if (ifeature < row_length) {
Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature];
int feature = entry.index;
float fvalue = entry.fvalue;
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float *feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
if (bin >= ncuts) {
bin = ncuts - 1;
}
// Add the number of bins in previous features.
bin += cut_rows[feature];
}
// Write to gidx buffer.
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) {
if (initialised_) return;
monitor_.Init("ellpack_page");
dh::safe_cuda(cudaSetDevice(device));
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
common::HistogramCuts hmat;
size_t row_stride = common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat_, &hmat);
monitor_.StopCuda("Quantiles");
const auto& info = dmat_->Info();
auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_;
// Init global data for each shard
monitor_.StartCuda("InitCompressedData");
InitCompressedData(device, hmat, row_stride, is_dense);
monitor_.StopCuda("InitCompressedData");
monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(info.num_row_);
for (const auto& batch : dmat_->GetBatches<SparsePage>()) {
hist_builder_row_state.BeginBatch(batch);
CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
hist_builder_row_state.EndBatch();
}
monitor_.StopCuda("BinningCompression");
initialised_ = true;
}
void EllpackPageImpl::InitCompressedData(int device,
const common::HistogramCuts& hmat,
size_t row_stride,
bool is_dense) {
n_bins = hmat.Ptrs().back();
int null_gidx_value = hmat.Ptrs().back();
int num_symbols = n_bins + 1;
// minimum value for each feature.
common::Span<bst_float> min_fvalue;
// Required buffer size for storing data matrix in ELLPack format.
size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
row_stride * dmat_->Info().num_row_, num_symbols);
ba.Allocate(device,
&feature_segments, hmat.Ptrs().size(),
&gidx_fvalue_map, hmat.Values().size(),
&min_fvalue, hmat.MinValues().size(),
&gidx_buffer, compressed_size_bytes);
dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
thrust::fill(
thrust::device_pointer_cast(gidx_buffer.data()),
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
ellpack_matrix.Init(feature_segments,
min_fvalue,
gidx_fvalue_map,
row_stride,
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
is_dense,
null_gidx_value);
}
void EllpackPageImpl::CreateHistIndices(int device,
const SparsePage& row_batch,
const RowStateOnDevice& device_row_state) {
// Has any been allocated for me in this batch?
if (!device_row_state.rows_to_process_from_batch) return;
unsigned int null_gidx_value = n_bins;
size_t row_stride = this->ellpack_matrix.row_stride;
const auto &offset_vec = row_batch.offset.ConstHostVector();
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows = std::min(
dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > device_row_state.rows_to_process_from_batch) {
batch_row_end = device_row_state.rows_to_process_from_batch;
}
size_t batch_nrows = batch_row_end - batch_row_begin;
const auto ent_cnt_begin =
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
const auto ent_cnt_end =
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
/*! \brief row offset in SparsePage (the input data). */
dh::device_vector<size_t> row_ptrs(batch_nrows+1);
thrust::copy(
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
row_ptrs.begin());
// number of entries in this batch.
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry),
cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y),
1);
CompressBinEllpackKernel<<<grid3, block3>>>(
common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
row_ptrs.data().get(),
entries_d.data().get(),
gidx_fvalue_map.data(),
feature_segments.data(),
device_row_state.total_rows_processed + batch_row_begin,
batch_nrows,
row_stride,
null_gidx_value);
}
}
} // namespace xgboost

203
src/data/ellpack_page.cuh Normal file
View File

@@ -0,0 +1,203 @@
/*!
* Copyright 2019 by XGBoost Contributors
*
* \file ellpack_page.cuh
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_H_
#include <xgboost/data.h>
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
namespace xgboost {
// Find a gidx value for a given feature otherwise return -1 if not found
__forceinline__ __device__ int BinarySearchRow(
bst_uint begin, bst_uint end,
common::CompressedIterator<uint32_t> data,
int const fidx_begin, int const fidx_end) {
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
if (middle == previous_middle) {
break;
}
previous_middle = middle;
auto gidx = data[middle];
if (gidx >= fidx_begin && gidx < fidx_end) {
return gidx;
} else if (gidx < fidx_begin) {
begin = middle;
} else {
end = middle;
}
}
// Value is missing
return -1;
}
/** \brief Struct for accessing and manipulating an ellpack matrix on the
* device. Does not own underlying memory and may be trivially copied into
* kernels.*/
struct ELLPackMatrix {
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief row length for ELLPack. */
size_t row_stride{0};
common::CompressedIterator<uint32_t> gidx_iter;
int null_gidx_value;
XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = gidx_iter[row_begin + fidx];
} else {
gidx =
BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]);
}
if (gidx == -1) {
return nan("");
}
return gidx_fvalue_map[gidx];
}
void Init(common::Span<uint32_t> feature_segments,
common::Span<bst_float> min_fvalue,
common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
int null_gidx_value) {
this->feature_segments = feature_segments;
this->min_fvalue = min_fvalue;
this->gidx_fvalue_map = gidx_fvalue_map;
this->row_stride = row_stride;
this->gidx_iter = gidx_iter;
this->is_dense = is_dense;
this->null_gidx_value = null_gidx_value;
}
private:
bool is_dense;
};
// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
// Number of rows assigned to this device
size_t total_rows_assigned_to_device;
// Number of rows processed thus far
size_t total_rows_processed;
// Number of rows to process from the current sparse page batch
size_t rows_to_process_from_batch;
// Offset from the current sparse page batch to begin processing
size_t row_offset_in_current_batch;
explicit RowStateOnDevice(size_t total_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
}
explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
}
// Advance the row state by the number of rows processed
void Advance() {
total_rows_processed += rows_to_process_from_batch;
CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
rows_to_process_from_batch = row_offset_in_current_batch = 0;
}
};
// An instance of this type is created which keeps track of total number of rows to process,
// rows processed thus far, rows to process and the offset from the current sparse page batch
// to begin processing on each device
class DeviceHistogramBuilderState {
public:
explicit DeviceHistogramBuilderState(int n_rows) : device_row_state_(n_rows) {}
const RowStateOnDevice& GetRowStateOnDevice() const {
return device_row_state_;
}
// This method is invoked at the beginning of each sparse page batch. This distributes
// the rows in the sparse page to the device.
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
void BeginBatch(const SparsePage &batch) {
size_t rem_rows = batch.Size();
size_t row_offset_in_current_batch = 0;
// Do we have anymore left to process from this batch on this device?
if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
// There are still some rows that needs to be assigned to this device
device_row_state_.rows_to_process_from_batch =
std::min(
device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state_.rows_to_process_from_batch = 0;
}
device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
rem_rows -= device_row_state_.rows_to_process_from_batch;
}
// This method is invoked after completion of each sparse page batch
void EndBatch() {
device_row_state_.Advance();
}
private:
RowStateOnDevice device_row_state_{0};
};
class EllpackPageImpl {
public:
ELLPackMatrix ellpack_matrix;
int n_bins{};
/*! \brief global index of histogram, which is stored in ELLPack format. */
common::Span<common::CompressedByteT> gidx_buffer;
explicit EllpackPageImpl(DMatrix* dmat);
void Init(int device, int max_bin, int gpu_batch_nrows);
void InitCompressedData(int device,
const common::HistogramCuts& hmat,
size_t row_stride,
bool is_dense);
void CreateHistIndices(int device,
const SparsePage& row_batch,
const RowStateOnDevice& device_row_state);
private:
bool initialised_{false};
DMatrix* dmat_;
common::Monitor monitor_;
dh::BulkAllocator ba;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief row_ptr form HistogramCuts. */
common::Span<uint32_t> feature_segments;
};
} // namespace xgboost
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_

View File

@@ -0,0 +1,33 @@
/*!
* Copyright 2019 XGBoost contributors
*/
#ifndef XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#define XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#include <xgboost/data.h>
namespace xgboost {
namespace data {
template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
T& operator*() override {
CHECK(page_ != nullptr);
return *page_;
}
const T& operator*() const override {
CHECK(page_ != nullptr);
return *page_;
}
void operator++() override { page_ = nullptr; }
bool AtEnd() const override { return page_ == nullptr; }
private:
T* page_{nullptr};
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_

View File

@@ -6,6 +6,7 @@
*/
#include "./simple_dmatrix.h"
#include <xgboost/data.h>
#include "./simple_batch_iterator.h"
#include "../common/random.h"
namespace xgboost {
@@ -29,25 +30,6 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
}
template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
T& operator*() override {
CHECK(page_ != nullptr);
return *page_;
}
const T& operator*() const override {
CHECK(page_ != nullptr);
return *page_;
}
void operator++() override { page_ = nullptr; }
bool AtEnd() const override { return page_ == nullptr; }
private:
T* page_{nullptr};
};
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
@@ -80,6 +62,16 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
return BatchSet<SortedCSCPage>(begin_iter);
}
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
ellpack_page_.reset(new EllpackPage(this));
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
return BatchSet<EllpackPage>(begin_iter);
}
bool SimpleDMatrix::SingleColBlock() const { return true; }
} // namespace data
} // namespace xgboost

View File

@@ -38,12 +38,14 @@ class SimpleDMatrix : public DMatrix {
BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches() override;
// source data pointer.
std::unique_ptr<DataSource<SparsePage>> source_;
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
};
} // namespace data
} // namespace xgboost

View File

@@ -10,6 +10,8 @@
#if DMLC_ENABLE_STD_THREAD
#include "./sparse_page_dmatrix.h"
#include "./simple_batch_iterator.h"
namespace xgboost {
namespace data {
@@ -72,6 +74,16 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
return BatchSet<SortedCSCPage>(begin_iter);
}
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches() {
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
ellpack_page_.reset(new EllpackPage(this));
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
return BatchSet<EllpackPage>(begin_iter);
}
float SparsePageDMatrix::GetColDensity(size_t cidx) {
// Finds densities if we don't already have them
if (col_density_.empty()) {

View File

@@ -24,7 +24,7 @@ class SparsePageDMatrix : public DMatrix {
explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
std::string cache_info)
: row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
virtual ~SparsePageDMatrix() = default;
~SparsePageDMatrix() override = default;
MetaInfo& Info() override;
@@ -38,11 +38,13 @@ class SparsePageDMatrix : public DMatrix {
BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches() override;
// source data pointers.
std::unique_ptr<DataSource<SparsePage>> row_source_;
std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
std::unique_ptr<EllpackPage> ellpack_page_;
// the cache prefix
std::string cache_info_;
// Store column densities to avoid recalculating