Move ellpack page construction into DMatrix (#4833)

This commit is contained in:
Rong Ou 2019-09-16 20:50:55 -07:00 committed by Jiaming Yuan
parent 512f037e55
commit 125bcec62e
17 changed files with 761 additions and 513 deletions

View File

@ -29,6 +29,7 @@
// data
#include "../src/data/data.cc"
#include "../src/data/ellpack_page.cc"
#include "../src/data/simple_csr_source.cc"
#include "../src/data/simple_dmatrix.cc"
#include "../src/data/sparse_page_raw_format.cc"

View File

@ -26,6 +26,8 @@
namespace xgboost {
// forward declare learner.
class LearnerImpl;
// forward declare dmatrix.
class DMatrix;
/*! \brief data type accepted by xgboost interface */
enum DataType {
@ -86,7 +88,7 @@ class MetaInfo {
* \return The pre-defined root index of i-th instance.
*/
inline unsigned GetRoot(size_t i) const {
return root_index_.size() != 0 ? root_index_[i] : 0U;
return !root_index_.empty() ? root_index_[i] : 0U;
}
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
inline const std::vector<size_t>& LabelAbsSort() const {
@ -166,7 +168,7 @@ class SparsePage {
/*! \brief the data of the segments */
HostDeviceVector<Entry> data;
size_t base_rowid;
size_t base_rowid{};
/*! \brief an instance of sparse vector in the batch */
using Inst = common::Span<Entry const>;
@ -215,23 +217,23 @@ class SparsePage {
const int nthread = omp_get_max_threads();
builder.InitBudget(num_columns, nthread);
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = (*this)[i];
for (bst_uint j = 0; j < inst.size(); ++j) {
builder.AddBudget(inst[j].index, tid);
for (const auto& entry : inst) {
builder.AddBudget(entry.index, tid);
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static)
#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = (*this)[i];
for (bst_uint j = 0; j < inst.size(); ++j) {
for (const auto& entry : inst) {
builder.Push(
inst[j].index,
Entry(static_cast<bst_uint>(this->base_rowid + i), inst[j].fvalue),
entry.index,
Entry(static_cast<bst_uint>(this->base_rowid + i), entry.fvalue),
tid);
}
}
@ -240,7 +242,7 @@ class SparsePage {
void SortRows() {
auto ncol = static_cast<bst_omp_uint>(this->Size());
#pragma omp parallel for schedule(dynamic, 1)
#pragma omp parallel for default(none) shared(ncol) schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
std::sort(
@ -287,10 +289,29 @@ class SortedCSCPage : public SparsePage {
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
class EllpackPageImpl;
/*!
* \brief A page stored in ELLPACK format.
*
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
* including CUDA-specific implementation details in the header.
*/
class EllpackPage {
public:
explicit EllpackPage(DMatrix* dmat);
~EllpackPage();
const EllpackPageImpl* Impl() const { return impl_.get(); }
EllpackPageImpl* Impl() { return impl_.get(); }
private:
std::unique_ptr<EllpackPageImpl> impl_;
};
template<typename T>
class BatchIteratorImpl {
public:
virtual ~BatchIteratorImpl() {}
virtual ~BatchIteratorImpl() = default;
virtual T& operator*() = 0;
virtual const T& operator*() const = 0;
virtual void operator++() = 0;
@ -412,7 +433,7 @@ class DMatrix {
bool silent,
bool load_row_split,
const std::string& file_format = "auto",
const size_t page_size = kPageSize);
size_t page_size = kPageSize);
/*!
* \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
@ -438,7 +459,7 @@ class DMatrix {
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
const std::string& cache_prefix = "",
const size_t page_size = kPageSize);
size_t page_size = kPageSize);
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;
@ -447,6 +468,7 @@ class DMatrix {
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
};
template<>
@ -463,6 +485,11 @@ template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}
template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches() {
return GetEllpackBatches();
}
} // namespace xgboost
namespace dmlc {

View File

@ -99,15 +99,15 @@ struct SketchContainer {
std::vector<std::mutex> col_locks_; // NOLINT
static constexpr int kOmpNumColsParallelizeLimit = 1000;
SketchContainer(const tree::TrainParam &param, DMatrix *dmat) :
SketchContainer(int max_bin, DMatrix *dmat) :
col_locks_(dmat->Info().num_col_) {
const MetaInfo &info = dmat->Info();
// Initialize Sketches for this dmatrix
sketches_.resize(info.num_col_);
#pragma omp parallel for default(none) shared(info, param) schedule(static) \
#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \
if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT
sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin));
}
}
@ -130,7 +130,7 @@ struct GPUSketcher {
bool has_weights_{false};
size_t row_stride_{0};
tree::TrainParam param_;
const int max_bin_;
SketchContainer *sketch_container_;
dh::device_vector<size_t> row_ptrs_{};
dh::device_vector<Entry> entries_{};
@ -148,11 +148,11 @@ struct GPUSketcher {
public:
DeviceShard(int device,
bst_uint n_rows,
tree::TrainParam param,
int max_bin,
SketchContainer* sketch_container) :
device_(device),
n_rows_(n_rows),
param_(std::move(param)),
max_bin_(max_bin),
sketch_container_(sketch_container) {
}
@ -183,7 +183,7 @@ struct GPUSketcher {
}
constexpr int kFactor = 8;
double eps = 1.0 / (kFactor * param_.max_bin);
double eps = 1.0 / (kFactor * max_bin_);
size_t dummy_nlevel;
WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
@ -362,7 +362,7 @@ struct GPUSketcher {
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
#pragma omp parallel for default(none) schedule(static) \
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) {
WXQSketch::SummaryContainer summary;
summary.Reserve(n_cuts_);
@ -403,10 +403,8 @@ struct GPUSketcher {
};
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
auto device = generic_param_.gpu_id;
// create device shard
shard_.reset(new DeviceShard(device, batch.Size(), param_, sketch_container_.get()));
shard_.reset(new DeviceShard(device_, batch.Size(), max_bin_, sketch_container_.get()));
// compute sketches for the shard
shard_->Init(batch, info, gpu_batch_nrows_);
@ -417,9 +415,8 @@ struct GPUSketcher {
row_stride_ = shard_->GetRowStride();
}
GPUSketcher(const tree::TrainParam &param, const GenericParameter &generic_param, int gpu_nrows)
: param_(param), generic_param_(generic_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {
}
GPUSketcher(int device, int max_bin, int gpu_nrows)
: device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {}
/* Builds the sketches on the GPU for the dmatrix and returns the row stride
* for the entire dataset */
@ -427,29 +424,31 @@ struct GPUSketcher {
const MetaInfo &info = dmat->Info();
row_stride_ = 0;
sketch_container_.reset(new SketchContainer(param_, dmat));
sketch_container_.reset(new SketchContainer(max_bin_, dmat));
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
this->SketchBatch(batch, info);
}
hmat->Init(&sketch_container_->sketches_, param_.max_bin);
hmat->Init(&sketch_container_->sketches_, max_bin_);
return row_stride_;
}
private:
std::unique_ptr<DeviceShard> shard_;
const tree::TrainParam &param_;
const GenericParameter &generic_param_;
const int device_;
const int max_bin_;
int gpu_batch_nrows_;
size_t row_stride_;
std::unique_ptr<SketchContainer> sketch_container_;
};
size_t DeviceSketch
(const tree::TrainParam &param, const GenericParameter &learner_param, int gpu_batch_nrows,
DMatrix *dmat, HistogramCuts *hmat) {
GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
size_t DeviceSketch(int device,
int max_bin,
int gpu_batch_nrows,
DMatrix* dmat,
HistogramCuts* hmat) {
GPUSketcher sketcher(device, max_bin, gpu_batch_nrows);
// We only need to return the result in HistogramCuts container, so it is safe to
// use a pointer of local HistogramCutsDense
DenseCuts dense_cuts(hmat);

View File

@ -290,10 +290,11 @@ class DenseCuts : public CutsBuilder {
*
* \return The row stride across the entire dataset.
*/
size_t DeviceSketch
(const tree::TrainParam& param, const GenericParameter &learner_param, int gpu_batch_nrows,
DMatrix* dmat, HistogramCuts* hmat);
size_t DeviceSketch(int device,
int max_bin,
int gpu_batch_nrows,
DMatrix* dmat,
HistogramCuts* hmat);
/*!
* \brief preprocessed global index matrix, in CSR format

25
src/data/ellpack_page.cc Normal file
View File

@ -0,0 +1,25 @@
/*!
* Copyright 2019 XGBoost contributors
*
* \file ellpack_page.cc
*/
#ifndef XGBOOST_USE_CUDA
#include <xgboost/data.h>
// dummy implementation of ELlpackPage in case CUDA is not used
namespace xgboost {
class EllpackPageImpl {};
EllpackPage::EllpackPage(DMatrix* dmat) {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
}
EllpackPage::~EllpackPage() {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
}
} // namespace xgboost
#endif // XGBOOST_USE_CUDA

197
src/data/ellpack_page.cu Normal file
View File

@ -0,0 +1,197 @@
/*!
* Copyright 2019 XGBoost contributors
*
* \file ellpack_page.cu
*/
#include <xgboost/data.h>
#include "./ellpack_page.cuh"
#include "../common/hist_util.h"
#include "../common/random.h"
namespace xgboost {
EllpackPage::EllpackPage(DMatrix* dmat) : impl_{new EllpackPageImpl(dmat)} {}
EllpackPage::~EllpackPage() = default;
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat) : dmat_{dmat} {}
// Bin each input data entry, store the bin indices in compressed form.
template<typename std::enable_if<true, int>::type = 0>
__global__ void CompressBinEllpackKernel(
common::CompressedBufferWriter wr,
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
unsigned int null_gidx_value) {
size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
if (irow >= n_rows || ifeature >= row_stride) {
return;
}
int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
unsigned int bin = null_gidx_value;
if (ifeature < row_length) {
Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature];
int feature = entry.index;
float fvalue = entry.fvalue;
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float *feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
if (bin >= ncuts) {
bin = ncuts - 1;
}
// Add the number of bins in previous features.
bin += cut_rows[feature];
}
// Write to gidx buffer.
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) {
if (initialised_) return;
monitor_.Init("ellpack_page");
dh::safe_cuda(cudaSetDevice(device));
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
common::HistogramCuts hmat;
size_t row_stride = common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat_, &hmat);
monitor_.StopCuda("Quantiles");
const auto& info = dmat_->Info();
auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_;
// Init global data for each shard
monitor_.StartCuda("InitCompressedData");
InitCompressedData(device, hmat, row_stride, is_dense);
monitor_.StopCuda("InitCompressedData");
monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(info.num_row_);
for (const auto& batch : dmat_->GetBatches<SparsePage>()) {
hist_builder_row_state.BeginBatch(batch);
CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
hist_builder_row_state.EndBatch();
}
monitor_.StopCuda("BinningCompression");
initialised_ = true;
}
void EllpackPageImpl::InitCompressedData(int device,
const common::HistogramCuts& hmat,
size_t row_stride,
bool is_dense) {
n_bins = hmat.Ptrs().back();
int null_gidx_value = hmat.Ptrs().back();
int num_symbols = n_bins + 1;
// minimum value for each feature.
common::Span<bst_float> min_fvalue;
// Required buffer size for storing data matrix in ELLPack format.
size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
row_stride * dmat_->Info().num_row_, num_symbols);
ba.Allocate(device,
&feature_segments, hmat.Ptrs().size(),
&gidx_fvalue_map, hmat.Values().size(),
&min_fvalue, hmat.MinValues().size(),
&gidx_buffer, compressed_size_bytes);
dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
thrust::fill(
thrust::device_pointer_cast(gidx_buffer.data()),
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
ellpack_matrix.Init(feature_segments,
min_fvalue,
gidx_fvalue_map,
row_stride,
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
is_dense,
null_gidx_value);
}
void EllpackPageImpl::CreateHistIndices(int device,
const SparsePage& row_batch,
const RowStateOnDevice& device_row_state) {
// Has any been allocated for me in this batch?
if (!device_row_state.rows_to_process_from_batch) return;
unsigned int null_gidx_value = n_bins;
size_t row_stride = this->ellpack_matrix.row_stride;
const auto &offset_vec = row_batch.offset.ConstHostVector();
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows = std::min(
dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > device_row_state.rows_to_process_from_batch) {
batch_row_end = device_row_state.rows_to_process_from_batch;
}
size_t batch_nrows = batch_row_end - batch_row_begin;
const auto ent_cnt_begin =
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
const auto ent_cnt_end =
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
/*! \brief row offset in SparsePage (the input data). */
dh::device_vector<size_t> row_ptrs(batch_nrows+1);
thrust::copy(
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
row_ptrs.begin());
// number of entries in this batch.
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry),
cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y),
1);
CompressBinEllpackKernel<<<grid3, block3>>>(
common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
row_ptrs.data().get(),
entries_d.data().get(),
gidx_fvalue_map.data(),
feature_segments.data(),
device_row_state.total_rows_processed + batch_row_begin,
batch_nrows,
row_stride,
null_gidx_value);
}
}
} // namespace xgboost

203
src/data/ellpack_page.cuh Normal file
View File

@ -0,0 +1,203 @@
/*!
* Copyright 2019 by XGBoost Contributors
*
* \file ellpack_page.cuh
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_H_
#include <xgboost/data.h>
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
namespace xgboost {
// Find a gidx value for a given feature otherwise return -1 if not found
__forceinline__ __device__ int BinarySearchRow(
bst_uint begin, bst_uint end,
common::CompressedIterator<uint32_t> data,
int const fidx_begin, int const fidx_end) {
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
if (middle == previous_middle) {
break;
}
previous_middle = middle;
auto gidx = data[middle];
if (gidx >= fidx_begin && gidx < fidx_end) {
return gidx;
} else if (gidx < fidx_begin) {
begin = middle;
} else {
end = middle;
}
}
// Value is missing
return -1;
}
/** \brief Struct for accessing and manipulating an ellpack matrix on the
* device. Does not own underlying memory and may be trivially copied into
* kernels.*/
struct ELLPackMatrix {
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief row length for ELLPack. */
size_t row_stride{0};
common::CompressedIterator<uint32_t> gidx_iter;
int null_gidx_value;
XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = gidx_iter[row_begin + fidx];
} else {
gidx =
BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]);
}
if (gidx == -1) {
return nan("");
}
return gidx_fvalue_map[gidx];
}
void Init(common::Span<uint32_t> feature_segments,
common::Span<bst_float> min_fvalue,
common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
int null_gidx_value) {
this->feature_segments = feature_segments;
this->min_fvalue = min_fvalue;
this->gidx_fvalue_map = gidx_fvalue_map;
this->row_stride = row_stride;
this->gidx_iter = gidx_iter;
this->is_dense = is_dense;
this->null_gidx_value = null_gidx_value;
}
private:
bool is_dense;
};
// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
// Number of rows assigned to this device
size_t total_rows_assigned_to_device;
// Number of rows processed thus far
size_t total_rows_processed;
// Number of rows to process from the current sparse page batch
size_t rows_to_process_from_batch;
// Offset from the current sparse page batch to begin processing
size_t row_offset_in_current_batch;
explicit RowStateOnDevice(size_t total_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
}
explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
}
// Advance the row state by the number of rows processed
void Advance() {
total_rows_processed += rows_to_process_from_batch;
CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
rows_to_process_from_batch = row_offset_in_current_batch = 0;
}
};
// An instance of this type is created which keeps track of total number of rows to process,
// rows processed thus far, rows to process and the offset from the current sparse page batch
// to begin processing on each device
class DeviceHistogramBuilderState {
public:
explicit DeviceHistogramBuilderState(int n_rows) : device_row_state_(n_rows) {}
const RowStateOnDevice& GetRowStateOnDevice() const {
return device_row_state_;
}
// This method is invoked at the beginning of each sparse page batch. This distributes
// the rows in the sparse page to the device.
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
void BeginBatch(const SparsePage &batch) {
size_t rem_rows = batch.Size();
size_t row_offset_in_current_batch = 0;
// Do we have anymore left to process from this batch on this device?
if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
// There are still some rows that needs to be assigned to this device
device_row_state_.rows_to_process_from_batch =
std::min(
device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state_.rows_to_process_from_batch = 0;
}
device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
rem_rows -= device_row_state_.rows_to_process_from_batch;
}
// This method is invoked after completion of each sparse page batch
void EndBatch() {
device_row_state_.Advance();
}
private:
RowStateOnDevice device_row_state_{0};
};
class EllpackPageImpl {
public:
ELLPackMatrix ellpack_matrix;
int n_bins{};
/*! \brief global index of histogram, which is stored in ELLPack format. */
common::Span<common::CompressedByteT> gidx_buffer;
explicit EllpackPageImpl(DMatrix* dmat);
void Init(int device, int max_bin, int gpu_batch_nrows);
void InitCompressedData(int device,
const common::HistogramCuts& hmat,
size_t row_stride,
bool is_dense);
void CreateHistIndices(int device,
const SparsePage& row_batch,
const RowStateOnDevice& device_row_state);
private:
bool initialised_{false};
DMatrix* dmat_;
common::Monitor monitor_;
dh::BulkAllocator ba;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief row_ptr form HistogramCuts. */
common::Span<uint32_t> feature_segments;
};
} // namespace xgboost
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_

View File

@ -0,0 +1,33 @@
/*!
* Copyright 2019 XGBoost contributors
*/
#ifndef XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#define XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#include <xgboost/data.h>
namespace xgboost {
namespace data {
template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
T& operator*() override {
CHECK(page_ != nullptr);
return *page_;
}
const T& operator*() const override {
CHECK(page_ != nullptr);
return *page_;
}
void operator++() override { page_ = nullptr; }
bool AtEnd() const override { return page_ == nullptr; }
private:
T* page_{nullptr};
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_

View File

@ -6,6 +6,7 @@
*/
#include "./simple_dmatrix.h"
#include <xgboost/data.h>
#include "./simple_batch_iterator.h"
#include "../common/random.h"
namespace xgboost {
@ -29,25 +30,6 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
}
template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
T& operator*() override {
CHECK(page_ != nullptr);
return *page_;
}
const T& operator*() const override {
CHECK(page_ != nullptr);
return *page_;
}
void operator++() override { page_ = nullptr; }
bool AtEnd() const override { return page_ == nullptr; }
private:
T* page_{nullptr};
};
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
@ -80,6 +62,16 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
return BatchSet<SortedCSCPage>(begin_iter);
}
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
ellpack_page_.reset(new EllpackPage(this));
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
return BatchSet<EllpackPage>(begin_iter);
}
bool SimpleDMatrix::SingleColBlock() const { return true; }
} // namespace data
} // namespace xgboost

View File

@ -38,12 +38,14 @@ class SimpleDMatrix : public DMatrix {
BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches() override;
// source data pointer.
std::unique_ptr<DataSource<SparsePage>> source_;
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
};
} // namespace data
} // namespace xgboost

View File

@ -10,6 +10,8 @@
#if DMLC_ENABLE_STD_THREAD
#include "./sparse_page_dmatrix.h"
#include "./simple_batch_iterator.h"
namespace xgboost {
namespace data {
@ -72,6 +74,16 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
return BatchSet<SortedCSCPage>(begin_iter);
}
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches() {
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
ellpack_page_.reset(new EllpackPage(this));
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
return BatchSet<EllpackPage>(begin_iter);
}
float SparsePageDMatrix::GetColDensity(size_t cidx) {
// Finds densities if we don't already have them
if (col_density_.empty()) {

View File

@ -24,7 +24,7 @@ class SparsePageDMatrix : public DMatrix {
explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
std::string cache_info)
: row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
virtual ~SparsePageDMatrix() = default;
~SparsePageDMatrix() override = default;
MetaInfo& Info() override;
@ -38,11 +38,13 @@ class SparsePageDMatrix : public DMatrix {
BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches() override;
// source data pointers.
std::unique_ptr<DataSource<SparsePage>> row_source_;
std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
std::unique_ptr<EllpackPage> ellpack_page_;
// the cache prefix
std::string cache_info_;
// Store column densities to avoid recalculating

View File

@ -21,6 +21,7 @@
#include "../common/host_device_vector.h"
#include "../common/timer.h"
#include "../common/span.h"
#include "../data/ellpack_page.cuh"
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
@ -108,83 +109,6 @@ inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
}
}
// Find a gidx value for a given feature otherwise return -1 if not found
__forceinline__ __device__ int BinarySearchRow(
bst_uint begin, bst_uint end,
common::CompressedIterator<uint32_t> data,
int const fidx_begin, int const fidx_end) {
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
if (middle == previous_middle) {
break;
}
previous_middle = middle;
auto gidx = data[middle];
if (gidx >= fidx_begin && gidx < fidx_end) {
return gidx;
} else if (gidx < fidx_begin) {
begin = middle;
} else {
end = middle;
}
}
// Value is missing
return -1;
}
/** \brief Struct for accessing and manipulating an ellpack matrix on the
* device. Does not own underlying memory and may be trivially copied into
* kernels.*/
struct ELLPackMatrix {
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief row length for ELLPack. */
size_t row_stride{0};
common::CompressedIterator<uint32_t> gidx_iter;
bool is_dense;
int null_gidx_value;
XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = gidx_iter[row_begin + fidx];
} else {
gidx =
BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]);
}
if (gidx == -1) {
return nan("");
}
return gidx_fvalue_map[gidx];
}
void Init(common::Span<uint32_t> feature_segments,
common::Span<bst_float> min_fvalue,
common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
int null_gidx_value) {
this->feature_segments = feature_segments;
this->min_fvalue = min_fvalue;
this->gidx_fvalue_map = gidx_fvalue_map;
this->row_stride = row_stride;
this->gidx_iter = gidx_iter;
this->is_dense = is_dense;
this->null_gidx_value = null_gidx_value;
}
};
// With constraints
template <typename GradientPairT>
XGBOOST_DEVICE float inline LossChangeMissing(
@ -247,7 +171,7 @@ template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
int fidx, common::Span<const GradientSumT> node_histogram,
const ELLPackMatrix& matrix,
const xgboost::ELLPackMatrix& matrix,
DeviceSplitCandidate* best_split, // shared memory storing best split
const DeviceNodeStats& node, const GPUTrainingParam& param,
TempStorageT* temp_storage, // temp memory for cub operations
@ -322,7 +246,7 @@ __global__ void EvaluateSplitKernel(
common::Span<const GradientSumT> node_histogram, // histogram for gradients
common::Span<const int> feature_set, // Selected features
DeviceNodeStats node,
ELLPackMatrix matrix,
xgboost::ELLPackMatrix matrix,
GPUTrainingParam gpu_param,
common::Span<DeviceSplitCandidate> split_candidates, // resulting split
ValueConstraint value_constraint,
@ -473,48 +397,8 @@ struct CalcWeightTrainParam {
learning_rate(p.learning_rate) {}
};
// Bin each input data entry, store the bin indices in compressed form.
template<typename std::enable_if<true, int>::type = 0>
__global__ void CompressBinEllpackKernel(
common::CompressedBufferWriter wr,
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
unsigned int null_gidx_value) {
size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
if (irow >= n_rows || ifeature >= row_stride) {
return;
}
int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
unsigned int bin = null_gidx_value;
if (ifeature < row_length) {
Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature];
int feature = entry.index;
float fvalue = entry.fvalue;
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float *feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
if (bin >= ncuts) {
bin = ncuts - 1;
}
// Add the number of bins in previous features.
bin += cut_rows[feature];
}
// Write to gidx buffer.
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
template <typename GradientSumT>
__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
__global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* d_node_hist,
const GradientPair* d_gpair, size_t n_elements,
@ -548,59 +432,17 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
}
}
// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
// Number of rows assigned to this device
size_t total_rows_assigned_to_device;
// Number of rows processed thus far
size_t total_rows_processed;
// Number of rows to process from the current sparse page batch
size_t rows_to_process_from_batch;
// Offset from the current sparse page batch to begin processing
size_t row_offset_in_current_batch;
explicit RowStateOnDevice(size_t total_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
}
explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
}
// Advance the row state by the number of rows processed
void Advance() {
total_rows_processed += rows_to_process_from_batch;
CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
rows_to_process_from_batch = row_offset_in_current_batch = 0;
}
};
// Manage memory for a single GPU
template <typename GradientSumT>
struct DeviceShard {
int n_bins;
int device_id;
EllpackPageImpl* page;
dh::BulkAllocator ba;
ELLPackMatrix ellpack_matrix;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist{};
/*! \brief row_ptr form HistogramCuts. */
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief global index of histogram, which is stored in ELLPack format. */
common::Span<common::CompressedByteT> gidx_buffer;
/*! \brief Gradient pair for each row. */
common::Span<GradientPair> gpair;
@ -631,11 +473,15 @@ struct DeviceShard {
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
DeviceShard(int _device_id, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed,
DeviceShard(int _device_id,
EllpackPageImpl* _page,
bst_uint _n_rows,
TrainParam _param,
uint32_t column_sampler_seed,
uint32_t n_features)
: device_id(_device_id),
page(_page),
n_rows(_n_rows),
n_bins(0),
param(std::move(_param)),
prediction_cache_initialised(false),
column_sampler(column_sampler_seed),
@ -643,12 +489,7 @@ struct DeviceShard {
monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
}
void InitCompressedData(
const common::HistogramCuts& hmat, size_t row_stride, bool is_dense);
void CreateHistIndices(
const SparsePage &row_batch, const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state, int rows_per_batch);
void InitHistogram();
~DeviceShard() { // NOLINT
dh::safe_cuda(cudaSetDevice(device_id));
@ -762,7 +603,7 @@ struct DeviceShard {
int constexpr kBlockThreads = 256;
EvaluateSplitKernel<kBlockThreads, GradientSumT>
<<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix,
hist.GetNodeHistogram(nidx), d_feature_set, node, page->ellpack_matrix,
gpu_param, d_split_candidates, node_value_constraints[nidx],
monotone_constraints);
@ -788,11 +629,11 @@ struct DeviceShard {
auto d_ridx = row_partitioner->GetRows(nidx);
auto d_gpair = gpair.data();
auto n_elements = d_ridx.size() * ellpack_matrix.row_stride;
auto n_elements = d_ridx.size() * page->ellpack_matrix.row_stride;
const size_t smem_size =
use_shared_memory_histograms
? sizeof(GradientSumT) * ellpack_matrix.BinCount()
? sizeof(GradientSumT) * page->ellpack_matrix.BinCount()
: 0;
const int items_per_thread = 8;
const int block_threads = 256;
@ -802,7 +643,7 @@ struct DeviceShard {
return;
}
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
page->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
use_shared_memory_histograms);
}
@ -812,7 +653,7 @@ struct DeviceShard {
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
dh::LaunchN(device_id, n_bins, [=] __device__(size_t idx) {
dh::LaunchN(device_id, page->n_bins, [=] __device__(size_t idx) {
d_node_hist_subtraction[idx] =
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
});
@ -827,7 +668,7 @@ struct DeviceShard {
}
void UpdatePosition(int nidx, RegTree::Node split_node) {
auto d_matrix = ellpack_matrix;
auto d_matrix = page->ellpack_matrix;
row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(),
@ -859,7 +700,7 @@ struct DeviceShard {
dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
auto d_matrix = ellpack_matrix;
auto d_matrix = page->ellpack_matrix;
row_partitioner->FinalisePosition(
[=] __device__(bst_uint ridx, int position) {
auto node = d_nodes[position];
@ -922,7 +763,7 @@ struct DeviceShard {
reducer->AllReduceSum(
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
ellpack_matrix.BinCount() *
page->ellpack_matrix.BinCount() *
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
reducer->Synchronize();
@ -1097,11 +938,7 @@ struct DeviceShard {
};
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
const common::HistogramCuts &hmat, size_t row_stride, bool is_dense) {
n_bins = hmat.Ptrs().back();
int null_gidx_value = hmat.Ptrs().back();
inline void DeviceShard<GradientSumT>::InitHistogram() {
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
@ -1113,163 +950,25 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
&gpair, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&feature_segments, hmat.Ptrs().size(),
&gidx_fvalue_map, hmat.Values().size(),
&min_fvalue, hmat.MinValues().size(),
&monotone_constraints, param.monotone_constraints.size());
dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
node_sum_gradients.resize(max_nodes);
// allocate compressed bin data
int num_symbols = n_bins + 1;
// Required buffer size for storing data matrix in ELLPack format.
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
num_symbols);
ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
thrust::fill(
thrust::device_pointer_cast(gidx_buffer.data()),
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
ellpack_matrix.Init(
feature_segments, min_fvalue,
gidx_fvalue_map, row_stride,
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
is_dense, null_gidx_value);
// check if we can use shared memory for building histograms
// (assuming atleast we need 2 CTAs per SM to maintain decent latency
// hiding)
auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back();
auto histogram_size = sizeof(GradientSumT) * page->n_bins;
auto max_smem = dh::MaxSharedMemory(device_id);
if (histogram_size <= max_smem) {
use_shared_memory_histograms = true;
}
// Init histogram
hist.Init(device_id, hmat.Ptrs().back());
hist.Init(device_id, page->n_bins);
}
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(
const SparsePage &row_batch,
const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state,
int rows_per_batch) {
// Has any been allocated for me in this batch?
if (!device_row_state.rows_to_process_from_batch) return;
unsigned int null_gidx_value = hmat.Ptrs().back();
size_t row_stride = this->ellpack_matrix.row_stride;
const auto &offset_vec = row_batch.offset.ConstHostVector();
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows = std::min(
dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > device_row_state.rows_to_process_from_batch) {
batch_row_end = device_row_state.rows_to_process_from_batch;
}
size_t batch_nrows = batch_row_end - batch_row_begin;
const auto ent_cnt_begin =
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
const auto ent_cnt_end =
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
/*! \brief row offset in SparsePage (the input data). */
dh::device_vector<size_t> row_ptrs(batch_nrows+1);
thrust::copy(
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
row_ptrs.begin());
// number of entries in this batch.
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda
(cudaMemcpy
(entries_d.data().get(), data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1);
CompressBinEllpackKernel<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
row_ptrs.data().get(),
entries_d.data().get(),
gidx_fvalue_map.data(),
feature_segments.data(),
device_row_state.total_rows_processed + batch_row_begin,
batch_nrows,
row_stride,
null_gidx_value);
}
}
// An instance of this type is created which keeps track of total number of rows to process,
// rows processed thus far, rows to process and the offset from the current sparse page batch
// to begin processing on each device
class DeviceHistogramBuilderState {
public:
template <typename GradientSumT>
explicit DeviceHistogramBuilderState(const std::unique_ptr<DeviceShard<GradientSumT>>& shard)
: device_row_state_(shard->n_rows) {}
const RowStateOnDevice& GetRowStateOnDevice() const {
return device_row_state_;
}
// This method is invoked at the beginning of each sparse page batch. This distributes
// the rows in the sparse page to the device.
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
void BeginBatch(const SparsePage &batch) {
size_t rem_rows = batch.Size();
size_t row_offset_in_current_batch = 0;
// Do we have anymore left to process from this batch on this device?
if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
// There are still some rows that needs to be assigned to this device
device_row_state_.rows_to_process_from_batch =
std::min(
device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state_.rows_to_process_from_batch = 0;
}
device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
rem_rows -= device_row_state_.rows_to_process_from_batch;
}
// This method is invoked after completion of each sparse page batch
void EndBatch() {
device_row_state_.Advance();
}
private:
RowStateOnDevice device_row_state_{0};
};
template <typename GradientSumT>
class GPUHistMakerSpecialised {
public:
@ -1319,47 +1018,33 @@ class GPUHistMakerSpecialised {
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
// TODO(rongou): support multiple Ellpack pages.
EllpackPageImpl* page{};
for (auto& batch : dmat->GetBatches<EllpackPage>()) {
page = batch.Impl();
page->Init(device_, param_.max_bin, hist_maker_param_.gpu_batch_nrows);
}
// Create device shard
dh::safe_cuda(cudaSetDevice(device_));
shard_.reset(new DeviceShard<GradientSumT>(device_,
page,
info_->num_row_,
param_,
column_sampling_seed,
info_->num_col_));
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts
size_t row_stride = common::DeviceSketch(param_, *generic_param_,
hist_maker_param_.gpu_batch_nrows,
dmat, &hmat_);
monitor_.StopCuda("Quantiles");
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
// Init global data for each shard
monitor_.StartCuda("InitCompressedData");
dh::safe_cuda(cudaSetDevice(shard_->device_id));
shard_->InitCompressedData(hmat_, row_stride, is_dense);
monitor_.StopCuda("InitCompressedData");
monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(shard_);
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
hist_builder_row_state.BeginBatch(batch);
dh::safe_cuda(cudaSetDevice(shard_->device_id));
shard_->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(),
hist_maker_param_.gpu_batch_nrows);
hist_builder_row_state.EndBatch();
}
monitor_.StopCuda("BinningCompression");
monitor_.StartCuda("InitHistogram");
dh::safe_cuda(cudaSetDevice(device_));
shard_->InitHistogram();
monitor_.StopCuda("InitHistogram");
p_last_fmat_ = dmat;
initialised_ = true;
}
void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat) {
void InitData(DMatrix* dmat) {
if (!initialised_) {
monitor_.StartCuda("InitDataOnce");
this->InitDataOnce(dmat);
@ -1387,7 +1072,7 @@ class GPUHistMakerSpecialised {
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree) {
monitor_.StartCuda("InitData");
this->InitData(gpair, p_fmat);
this->InitData(p_fmat);
monitor_.StopCuda("InitData");
gpair->SetDevice(device_);
@ -1408,7 +1093,6 @@ class GPUHistMakerSpecialised {
}
TrainParam param_; // NOLINT
common::HistogramCuts hmat_; // NOLINT
MetaInfo* info_{}; // NOLINT
std::unique_ptr<DeviceShard<GradientSumT>> shard_; // NOLINT

View File

@ -43,18 +43,17 @@ void TestDeviceSketch(bool use_external_memory) {
dmat = static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
}
tree::TrainParam p;
p.max_bin = 20;
int gpu_batch_nrows = 0;
int device{0};
int max_bin{20};
int gpu_batch_nrows{0};
// find quantiles on the CPU
HistogramCuts hmat_cpu;
hmat_cpu.Build((*dmat).get(), p.max_bin);
hmat_cpu.Build((*dmat).get(), max_bin);
// find the cuts on the GPU
HistogramCuts hmat_gpu;
size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0), gpu_batch_nrows,
dmat->get(), &hmat_gpu);
size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu);
// compare the row stride with the one obtained from the dmatrix
size_t expected_row_stride = 0;

View File

@ -0,0 +1,86 @@
/*!
* Copyright 2019 XGBoost contributors
*/
#include <xgboost/base.h>
#include <utility>
#include "../helpers.h"
#include "gtest/gtest.h"
#include "../../../src/common/hist_util.h"
#include "../../../src/data/ellpack_page.cuh"
namespace xgboost {
TEST(EllpackPage, EmptyDMatrix) {
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256, kGpuBatchNRows = 64;
constexpr float kSparsity = 0;
auto dmat = *CreateDMatrix(kNRows, kNCols, kSparsity);
auto& page = *dmat->GetBatches<EllpackPage>().begin();
auto impl = page.Impl();
impl->Init(0, kMaxBin, kGpuBatchNRows);
ASSERT_EQ(impl->ellpack_matrix.feature_segments.size(), 1);
ASSERT_EQ(impl->ellpack_matrix.min_fvalue.size(), 0);
ASSERT_EQ(impl->ellpack_matrix.gidx_fvalue_map.size(), 0);
ASSERT_EQ(impl->ellpack_matrix.row_stride, 0);
ASSERT_EQ(impl->ellpack_matrix.null_gidx_value, 0);
ASSERT_EQ(impl->n_bins, 0);
ASSERT_EQ(impl->gidx_buffer.size(), 4);
}
TEST(EllpackPage, BuildGidxDense) {
int constexpr kNRows = 16, kNCols = 8;
auto page = BuildEllpackPage(kNRows, kNCols);
std::vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
ASSERT_EQ(page->ellpack_matrix.row_stride, kNCols);
std::vector<uint32_t> solution = {
0, 3, 8, 9, 14, 17, 20, 21,
0, 4, 7, 10, 14, 16, 19, 22,
1, 3, 7, 11, 14, 15, 19, 21,
2, 3, 7, 9, 13, 16, 20, 22,
2, 3, 6, 9, 12, 16, 20, 21,
1, 5, 6, 10, 13, 16, 20, 21,
2, 5, 8, 9, 13, 17, 19, 22,
2, 4, 6, 10, 14, 17, 19, 21,
2, 5, 7, 9, 13, 16, 19, 22,
0, 3, 8, 10, 12, 16, 19, 22,
1, 3, 7, 10, 13, 16, 19, 21,
1, 3, 8, 10, 13, 17, 20, 22,
2, 4, 6, 9, 14, 15, 19, 22,
1, 4, 6, 9, 13, 16, 19, 21,
2, 4, 8, 10, 14, 15, 19, 22,
1, 4, 7, 10, 14, 16, 19, 21,
};
for (size_t i = 0; i < kNRows * kNCols; ++i) {
ASSERT_EQ(solution[i], gidx[i]);
}
}
TEST(EllpackPage, BuildGidxSparse) {
int constexpr kNRows = 16, kNCols = 8;
auto page = BuildEllpackPage(kNRows, kNCols, 0.9f);
std::vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
ASSERT_LE(page->ellpack_matrix.row_stride, 3);
// row_stride = 3, 16 rows, 48 entries for ELLPack
std::vector<uint32_t> solution = {
15, 24, 24, 0, 24, 24, 24, 24, 24, 24, 24, 24, 20, 24, 24, 24,
24, 24, 24, 24, 24, 5, 24, 24, 0, 16, 24, 15, 24, 24, 24, 24,
24, 7, 14, 16, 4, 24, 24, 24, 24, 24, 9, 24, 24, 1, 24, 24
};
for (size_t i = 0; i < kNRows * page->ellpack_matrix.row_stride; ++i) {
ASSERT_EQ(solution[i], gidx[i]);
}
}
} // namespace xgboost

View File

@ -21,6 +21,10 @@
#include <xgboost/generic_parameters.h>
#include "../../src/common/common.h"
#include "../../src/common/hist_util.h"
#if defined(__CUDACC__)
#include "../../src/data/ellpack_page.cuh"
#endif
#if defined(__CUDACC__)
#define DeclareUnifiedTest(name) GPU ## name
@ -197,5 +201,58 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) {
return tparam;
}
#if defined(__CUDACC__)
namespace {
class HistogramCutsWrapper : public common::HistogramCuts {
public:
using SuperT = common::HistogramCuts;
void SetValues(std::vector<float> cuts) {
SuperT::cut_values_ = std::move(cuts);
}
void SetPtrs(std::vector<uint32_t> ptrs) {
SuperT::cut_ptrs_ = std::move(ptrs);
}
void SetMins(std::vector<float> mins) {
SuperT::min_vals_ = std::move(mins);
}
};
} // anonymous namespace
inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
int n_rows, int n_cols, bst_float sparsity= 0) {
auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
const SparsePage& batch = *(*dmat)->GetBatches<xgboost::SparsePage>().begin();
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
// 24 cut fields, 3 cut fields for each feature (column).
cmat.SetValues({0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f});
cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
size_t row_stride = 0;
const auto &offset_vec = batch.offset.ConstHostVector();
for (size_t i = 1; i < offset_vec.size(); ++i) {
row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
}
auto page = std::unique_ptr<EllpackPageImpl>(new EllpackPageImpl(dmat->get()));
page->InitCompressedData(0, cmat, row_stride, is_dense);
page->CreateHistIndices(0, batch, RowStateOnDevice(batch.Size(), batch.Size()));
delete dmat;
return page;
}
#endif
} // namespace xgboost
#endif

View File

@ -98,82 +98,13 @@ void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
for (size_t i = 1; i < offset_vec.size(); ++i) {
row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
}
shard->InitCompressedData(cmat, row_stride, is_dense);
shard->InitHistogram(cmat, row_stride, is_dense);
shard->CreateHistIndices(
batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1);
delete dmat;
}
TEST(GpuHist, BuildGidxDense) {
int constexpr kNRows = 16, kNCols = 8;
tree::TrainParam param;
std::vector<std::pair<std::string, std::string>> args {
{"max_depth", "1"},
{"max_leaves", "0"},
};
param.Init(args);
DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
BuildGidx(&shard, kNRows, kNCols);
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer);
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols);
std::vector<uint32_t> solution = {
0, 3, 8, 9, 14, 17, 20, 21,
0, 4, 7, 10, 14, 16, 19, 22,
1, 3, 7, 11, 14, 15, 19, 21,
2, 3, 7, 9, 13, 16, 20, 22,
2, 3, 6, 9, 12, 16, 20, 21,
1, 5, 6, 10, 13, 16, 20, 21,
2, 5, 8, 9, 13, 17, 19, 22,
2, 4, 6, 10, 14, 17, 19, 21,
2, 5, 7, 9, 13, 16, 19, 22,
0, 3, 8, 10, 12, 16, 19, 22,
1, 3, 7, 10, 13, 16, 19, 21,
1, 3, 8, 10, 13, 17, 20, 22,
2, 4, 6, 9, 14, 15, 19, 22,
1, 4, 6, 9, 13, 16, 19, 21,
2, 4, 8, 10, 14, 15, 19, 22,
1, 4, 7, 10, 14, 16, 19, 21,
};
for (size_t i = 0; i < kNRows * kNCols; ++i) {
ASSERT_EQ(solution[i], gidx[i]);
}
}
TEST(GpuHist, BuildGidxSparse) {
int constexpr kNRows = 16, kNCols = 8;
TrainParam param;
std::vector<std::pair<std::string, std::string>> args {
{"max_depth", "1"},
{"max_leaves", "0"},
};
param.Init(args);
DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
BuildGidx(&shard, kNRows, kNCols, 0.9f);
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer);
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
ASSERT_LE(shard.ellpack_matrix.row_stride, 3);
// row_stride = 3, 16 rows, 48 entries for ELLPack
std::vector<uint32_t> solution = {
15, 24, 24, 0, 24, 24, 24, 24, 24, 24, 24, 24, 20, 24, 24, 24,
24, 24, 24, 24, 24, 5, 24, 24, 0, 16, 24, 15, 24, 24, 24, 24,
24, 7, 14, 16, 4, 24, 24, 24, 24, 24, 9, 24, 24, 1, 24, 24
};
for (size_t i = 0; i < kNRows * shard.ellpack_matrix.row_stride; ++i) {
ASSERT_EQ(solution[i], gidx[i]);
}
}
std::vector<GradientPairPrecise> GetHostHistGpair() {
// 24 bins, 3 bins for each feature (column).
std::vector<GradientPairPrecise> hist_gpair = {
@ -199,9 +130,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
{"max_leaves", "0"},
};
param.Init(args);
DeviceShard<GradientSumT> shard(0, kNRows, param, kNCols, kNCols);
BuildGidx(&shard, kNRows, kNCols);
auto page = BuildEllpackPage(kNRows, kNCols);
DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols);
shard.InitHistogram();
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
std::vector<GradientPair> h_gpair(kNRows);
@ -211,12 +143,11 @@ void TestBuildHist(bool use_shared_memory_histograms) {
gpair = GradientPair(grad, hess);
}
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (
shard.gidx_buffer.size());
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.size());
common::CompressedByteT* d_gidx_buffer_ptr = shard.gidx_buffer.data();
common::CompressedByteT* d_gidx_buffer_ptr = page->gidx_buffer.data();
dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr,
sizeof(common::CompressedByteT) * shard.gidx_buffer.size(),
sizeof(common::CompressedByteT) * page->gidx_buffer.size(),
cudaMemcpyDeviceToHost));
shard.row_partitioner.reset(new RowPartitioner(0, kNRows));
@ -300,8 +231,9 @@ TEST(GpuHist, EvaluateSplits) {
int max_bins = 4;
// Initialize DeviceShard
auto page = BuildEllpackPage(kNRows, kNCols);
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
new DeviceShard<GradientPairPrecise>(0, kNRows, param, kNCols, kNCols)};
new DeviceShard<GradientPairPrecise>(0, page.get(), kNRows, param, kNCols, kNCols)};
// Initialize DeviceShard::node_sum_gradients
shard->node_sum_gradients = {{6.4f, 12.8f}};
@ -310,18 +242,14 @@ TEST(GpuHist, EvaluateSplits) {
// Copy cut matrix to device.
shard->ba.Allocate(0,
&(shard->feature_segments), cmat.Ptrs().size(),
&(shard->min_fvalue), cmat.MinValues().size(),
&(shard->gidx_fvalue_map), 24,
&(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(),
&(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(),
&(page->ellpack_matrix.gidx_fvalue_map), 24,
&(shard->monotone_constraints), kNCols);
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.Ptrs());
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.Values());
dh::CopyVectorToDeviceSpan(shard->monotone_constraints,
param.monotone_constraints);
shard->ellpack_matrix.feature_segments = shard->feature_segments;
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.MinValues());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
dh::CopyVectorToDeviceSpan(page->ellpack_matrix.feature_segments, cmat.Ptrs());
dh::CopyVectorToDeviceSpan(page->ellpack_matrix.gidx_fvalue_map, cmat.Values());
dh::CopyVectorToDeviceSpan(shard->monotone_constraints, param.monotone_constraints);
dh::CopyVectorToDeviceSpan(page->ellpack_matrix.min_fvalue, cmat.MinValues());
// Initialize DeviceShard::hist
shard->hist.Init(0, (max_bins - 1) * kNCols);
@ -391,15 +319,15 @@ void TestHistogramIndexImpl() {
// Extract the device shard from the histogram makers and from that its compressed
// histogram index
const auto &dev_shard = hist_maker.shard_;
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->page->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->page->gidx_buffer);
const auto &dev_shard_ext = hist_maker_ext.shard_;
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->page->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->page->gidx_buffer);
ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
ASSERT_EQ(dev_shard->page->n_bins, dev_shard_ext->page->n_bins);
ASSERT_EQ(dev_shard->page->gidx_buffer.size(), dev_shard_ext->page->gidx_buffer.size());
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}