Gradient based sampling for GPU Hist (#5093)

* Implement gradient based sampling for GPU Hist tree method.
* Add samplers and handle compacted page in GPU Hist.
This commit is contained in:
Rong Ou 2020-02-03 18:31:27 -08:00 committed by GitHub
parent c74216f22c
commit e4b74c4d22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1187 additions and 175 deletions

View File

@ -193,6 +193,36 @@ class GradientPairInternal {
return g;
}
XGBOOST_DEVICE GradientPairInternal<T> &operator*=(float multiplier) {
grad_ *= multiplier;
hess_ *= multiplier;
return *this;
}
XGBOOST_DEVICE GradientPairInternal<T> operator*(float multiplier) const {
GradientPairInternal<T> g;
g.grad_ = grad_ * multiplier;
g.hess_ = hess_ * multiplier;
return g;
}
XGBOOST_DEVICE GradientPairInternal<T> &operator/=(float divisor) {
grad_ /= divisor;
hess_ /= divisor;
return *this;
}
XGBOOST_DEVICE GradientPairInternal<T> operator/(float divisor) const {
GradientPairInternal<T> g;
g.grad_ = grad_ / divisor;
g.hess_ = hess_ / divisor;
return g;
}
XGBOOST_DEVICE bool operator==(const GradientPairInternal<T> &rhs) const {
return grad_ == rhs.grad_ && hess_ == rhs.hess_;
}
XGBOOST_DEVICE explicit GradientPairInternal(int value) {
*this = GradientPairInternal<T>(static_cast<float>(value),
static_cast<float>(value));

View File

@ -63,7 +63,7 @@ class CompressedBufferWriter {
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
* num_elements, int num_symbols)
*
* \brief Calculates number of bytes requiredm for a given number of elements
* \brief Calculates number of bytes required for a given number of elements
* and a symbol range.
*
* \author Rory
@ -74,7 +74,6 @@ class CompressedBufferWriter {
*
* \return The calculated buffer size.
*/
static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
const int bits_per_byte = 8;
size_t compressed_size = static_cast<size_t>(std::ceil(
@ -188,7 +187,7 @@ class CompressedIterator {
public:
CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {}
CompressedIterator(CompressedByteT *buffer, int num_symbols)
CompressedIterator(CompressedByteT *buffer, size_t num_symbols)
: buffer_(buffer), offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}

View File

@ -1266,6 +1266,26 @@ thrust::device_ptr<T const> tcend(xgboost::HostDeviceVector<T> const& vector) {
return tcbegin(vector) + vector.Size();
}
template <typename T>
thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) { // NOLINT
return thrust::device_ptr<T>(span.data());
}
template <typename T>
thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) { // // NOLINT
return tbegin(span) + span.size();
}
template <typename T>
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) {
return thrust::device_ptr<T const>(span.data());
}
template <typename T>
thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {
return tcbegin(span) + span.size();
}
template <typename FunctionT>
class LauncherItr {
public:

View File

@ -64,6 +64,20 @@ __global__ void CompressBinEllpackKernel(
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
// Construct an ELLPACK matrix with the given number of empty rows.
EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) {
monitor_.Init("ellpack_page");
dh::safe_cuda(cudaSetDevice(device));
matrix.info = info;
matrix.base_rowid = 0;
matrix.n_rows = n_rows;
monitor_.StartCuda("InitCompressedData");
InitCompressedData(device, n_rows);
monitor_.StopCuda("InitCompressedData");
}
// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.Init("ellpack_page");
@ -96,6 +110,85 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.StopCuda("BinningCompression");
}
// A functor that copies the data from one EllpackPage to another.
struct CopyPage {
common::CompressedBufferWriter cbw;
common::CompressedByteT* dst_data_d;
common::CompressedIterator<uint32_t> src_iterator_d;
// The number of elements to skip.
size_t offset;
CopyPage(EllpackPageImpl* dst, EllpackPageImpl* src, size_t offset)
: cbw{dst->matrix.info.NumSymbols()},
dst_data_d{dst->gidx_buffer.data()},
src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()},
offset(offset) {}
__device__ void operator()(size_t element_id) {
cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset);
}
};
// Copy the data from the given EllpackPage to the current page.
size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) {
monitor_.StartCuda("Copy");
size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride;
CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride);
CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols());
CHECK_GE(matrix.n_rows * matrix.info.row_stride, offset + num_elements);
dh::LaunchN(device, num_elements, CopyPage(this, page, offset));
monitor_.StopCuda("Copy");
return num_elements;
}
// A functor that compacts the rows from one EllpackPage into another.
struct CompactPage {
common::CompressedBufferWriter cbw;
common::CompressedByteT* dst_data_d;
common::CompressedIterator<uint32_t> src_iterator_d;
/*! \brief An array that maps the rows from the full DMatrix to the compacted page.
*
* The total size is the number of rows in the original, uncompacted DMatrix. Elements are the
* row ids in the compacted page. Rows not needed are set to SIZE_MAX.
*
* An example compacting 16 rows to 8 rows:
* [SIZE_MAX, 0, 1, SIZE_MAX, SIZE_MAX, 2, SIZE_MAX, 3, 4, 5, SIZE_MAX, 6, SIZE_MAX, 7, SIZE_MAX,
* SIZE_MAX]
*/
common::Span<size_t> row_indexes;
size_t base_rowid;
size_t row_stride;
CompactPage(EllpackPageImpl* dst, EllpackPageImpl* src, common::Span<size_t> row_indexes)
: cbw{dst->matrix.info.NumSymbols()},
dst_data_d{dst->gidx_buffer.data()},
src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()},
row_indexes(row_indexes),
base_rowid{src->matrix.base_rowid},
row_stride{src->matrix.info.row_stride} {}
__device__ void operator()(size_t row_id) {
size_t src_row = base_rowid + row_id;
size_t dst_row = row_indexes[src_row];
if (dst_row == SIZE_MAX) return;
size_t dst_offset = dst_row * row_stride;
size_t src_offset = row_id * row_stride;
for (size_t j = 0; j < row_stride; j++) {
cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset + j], dst_offset + j);
}
}
};
// Compacts the data from the given EllpackPage into the current page.
void EllpackPageImpl::Compact(int device, EllpackPageImpl* page, common::Span<size_t> row_indexes) {
monitor_.StartCuda("Compact");
CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride);
CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols());
CHECK_LE(page->matrix.base_rowid + page->matrix.n_rows, row_indexes.size());
dh::LaunchN(device, page->matrix.n_rows, CompactPage(this, page, row_indexes));
monitor_.StopCuda("Compact");
}
// Construct an EllpackInfo based on histogram cuts of features.
EllpackInfo::EllpackInfo(int device,
bool is_dense,
@ -123,16 +216,14 @@ void EllpackPageImpl::InitInfo(int device,
// Initialize the buffer to stored compressed features.
void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) {
size_t num_symbols = matrix.info.n_bins + 1;
size_t num_symbols = matrix.info.NumSymbols();
// Required buffer size for storing data matrix in ELLPack format.
size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
matrix.info.row_stride * num_rows, num_symbols);
ba_.Allocate(device, &gidx_buffer, compressed_size_bytes);
thrust::fill(
thrust::device_pointer_cast(gidx_buffer.data()),
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
thrust::fill(dh::tbegin(gidx_buffer), dh::tend(gidx_buffer), 0);
matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols);
}
@ -149,7 +240,6 @@ void EllpackPageImpl::CreateHistIndices(int device,
const auto& offset_vec = row_batch.offset.ConstHostVector();
int num_symbols = matrix.info.n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows = std::min(
dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
@ -193,7 +283,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
1);
dh::LaunchKernel {grid3, block3} (
CompressBinEllpackKernel,
common::CompressedBufferWriter(num_symbols),
common::CompressedBufferWriter(matrix.info.NumSymbols()),
gidx_buffer.data(),
row_ptrs.data().get(),
entries_d.data().get(),
@ -254,11 +344,9 @@ void EllpackPageImpl::CompressSparsePage(int device) {
// Return the memory cost for storing the compressed features.
size_t EllpackPageImpl::MemCostBytes() const {
size_t num_symbols = matrix.info.n_bins + 1;
// Required buffer size for storing data matrix in ELLPack format.
size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
matrix.info.row_stride * matrix.n_rows, num_symbols);
matrix.info.row_stride * matrix.n_rows, matrix.info.NumSymbols());
return compressed_size_bytes;
}
@ -280,5 +368,4 @@ void EllpackPageImpl::InitDevice(int device, EllpackInfo info) {
device_initialized_ = true;
}
} // namespace xgboost

View File

@ -71,6 +71,11 @@ struct EllpackInfo {
size_t row_stride,
const common::HistogramCuts& hmat,
dh::BulkAllocator* ba);
/*! \brief Return the total number of symbols (total number of bins plus 1 for not found). */
size_t NumSymbols() const {
return n_bins + 1;
}
};
/** \brief Struct for accessing and manipulating an ellpack matrix on the
@ -200,6 +205,14 @@ class EllpackPageImpl {
*/
EllpackPageImpl() = default;
/*!
* \brief Constructor from an existing EllpackInfo.
*
* This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo
* and the given number of rows.
*/
explicit EllpackPageImpl(int device, EllpackInfo info, size_t n_rows);
/*!
* \brief Constructor from an existing DMatrix.
*
@ -208,6 +221,23 @@ class EllpackPageImpl {
*/
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
/*! \brief Copy the elements of the given ELLPACK page into this page.
*
* @param device The GPU device to use.
* @param page The ELLPACK page to copy from.
* @param offset The number of elements to skip before copying.
* @returns The number of elements copied.
*/
size_t Copy(int device, EllpackPageImpl* page, size_t offset);
/*! \brief Compact the given ELLPACK page into the current page.
*
* @param device The GPU device to use.
* @param page The ELLPACK page to compact from.
* @param row_indexes Row indexes for the compacted page.
*/
void Compact(int device, EllpackPageImpl* page, common::Span<size_t> row_indexes);
/*!
* \brief Initialize the EllpackInfo contained in the EllpackMatrix.
*

View File

@ -0,0 +1,380 @@
/*!
* Copyright 2019 by XGBoost Contributors
*/
#include <thrust/functional.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <limits>
#include "../../common/compressed_iterator.h"
#include "../../common/random.h"
#include "gradient_based_sampler.cuh"
namespace xgboost {
namespace tree {
/*! \brief A functor that returns random weights. */
class RandomWeight : public thrust::unary_function<size_t, float> {
public:
explicit RandomWeight(size_t seed) : seed_(seed) {}
XGBOOST_DEVICE float operator()(size_t i) const {
thrust::default_random_engine rng(seed_);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng);
}
private:
uint32_t seed_;
};
/*! \brief A functor that performs a Bernoulli trial to discard a gradient pair. */
class BernoulliTrial : public thrust::unary_function<size_t, bool> {
public:
BernoulliTrial(size_t seed, float p) : rnd_(seed), p_(p) {}
XGBOOST_DEVICE bool operator()(size_t i) const {
return rnd_(i) > p_;
}
private:
RandomWeight rnd_;
float p_;
};
/*! \brief A functor that returns true if the gradient pair is non-zero. */
struct IsNonZero : public thrust::unary_function<GradientPair, bool> {
XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const {
return gpair.GetGrad() != 0 || gpair.GetHess() != 0;
}
};
/*! \brief A functor that clears the row indexes with empty gradient. */
struct ClearEmptyRows : public thrust::binary_function<GradientPair, size_t, size_t> {
XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const {
if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) {
return row_index;
} else {
return std::numeric_limits<std::size_t>::max();
}
}
};
/*! \brief A functor that combines the gradient pair into a single float.
*
* The approach here is based on Minimal Variance Sampling (MVS), with lambda set to 0.1.
*
* \see Ibragimov, B., & Gusev, G. (2019). Minimal Variance Sampling in Stochastic Gradient
* Boosting. In Advances in Neural Information Processing Systems (pp. 15061-15071).
*/
class CombineGradientPair : public thrust::unary_function<GradientPair, float> {
public:
XGBOOST_DEVICE float operator()(const GradientPair& gpair) const {
return sqrtf(powf(gpair.GetGrad(), 2) + kLambda * powf(gpair.GetHess(), 2));
}
private:
static constexpr float kLambda = 0.1f;
};
/*! \brief A functor that calculates the difference between the sample rate and the desired sample
* rows, given a cumulative gradient sum.
*/
class SampleRateDelta : public thrust::binary_function<float, size_t, float> {
public:
SampleRateDelta(common::Span<float> threshold, size_t n_rows, size_t sample_rows)
: threshold_(threshold), n_rows_(n_rows), sample_rows_(sample_rows) {}
XGBOOST_DEVICE float operator()(float gradient_sum, size_t row_index) const {
float lower = threshold_[row_index];
float upper = threshold_[row_index + 1];
float u = gradient_sum / static_cast<float>(sample_rows_ - n_rows_ + row_index + 1);
if (u > lower && u <= upper) {
threshold_[row_index + 1] = u;
return 0.0f;
} else {
return std::numeric_limits<float>::max();
}
}
private:
common::Span<float> threshold_;
size_t n_rows_;
size_t sample_rows_;
};
/*! \brief A functor that performs Poisson sampling, and scales gradient pairs by 1/p_i. */
class PoissonSampling : public thrust::binary_function<GradientPair, size_t, GradientPair> {
public:
PoissonSampling(common::Span<float> threshold, size_t threshold_index, RandomWeight rnd)
: threshold_(threshold), threshold_index_(threshold_index), rnd_(rnd) {}
XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) {
// If the gradient and hessian are both empty, we should never select this row.
if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) {
return gpair;
}
float combined_gradient = combine_(gpair);
float u = threshold_[threshold_index_];
float p = combined_gradient / u;
if (p >= 1) {
// Always select this row.
return gpair;
} else {
// Select this row randomly with probability proportional to the combined gradient.
// Scale gpair by 1/p.
if (rnd_(i) <= p) {
return gpair / p;
} else {
return GradientPair();
}
}
}
private:
common::Span<float> threshold_;
size_t threshold_index_;
RandomWeight rnd_;
CombineGradientPair combine_;
};
NoSampling::NoSampling(EllpackPageImpl* page) : page_(page) {}
GradientBasedSample NoSampling::Sample(common::Span<GradientPair> gpair, DMatrix* dmat) {
return {dmat->Info().num_row_, page_, gpair};
}
ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param)
: batch_param_(batch_param),
page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {}
GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
if (!page_concatenated_) {
// Concatenate all the external memory ELLPACK pages into a single in-memory page.
size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
auto page = batch.Impl();
size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset);
offset += num_elements;
}
page_concatenated_ = true;
}
return {dmat->Info().num_row_, page_.get(), gpair};
}
UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample)
: page_(page), subsample_(subsample) {}
GradientBasedSample UniformSampling::Sample(common::Span<GradientPair> gpair, DMatrix* dmat) {
// Set gradient pair to 0 with p = 1 - subsample
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_),
GradientPair());
return {dmat->Info().num_row_, page_, gpair};
}
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample)
: original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows);
}
GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
// Set gradient pair to 0 with p = 1 - subsample
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_),
GradientPair());
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
size_t n_rows = dmat->Info().num_row_;
// Compact gradient pairs.
gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
ClearEmptyRows());
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(batch_param_.gpu_id,
original_page_->matrix.info,
sample_rows));
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
}
GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : page_(page), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows);
}
GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
threshold_index,
RandomWeight(common::GlobalRandom()())));
return {n_rows, page_, gpair};
}
ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows,
&sample_row_index_, n_rows);
}
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
threshold_index,
RandomWeight(common::GlobalRandom()())));
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
// Compact gradient pairs.
gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
ClearEmptyRows());
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(batch_param_.gpu_id,
original_page_->matrix.info,
sample_rows));
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
}
GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample,
int sampling_method) {
monitor_.Init("gradient_based_sampler");
bool is_sampling = subsample < 1.0;
bool is_external_memory = page->matrix.n_rows != n_rows;
if (is_sampling) {
switch (sampling_method) {
case TrainParam::kUniform:
if (is_external_memory) {
strategy_.reset(new ExternalMemoryUniformSampling(page, n_rows, batch_param, subsample));
} else {
strategy_.reset(new UniformSampling(page, subsample));
}
break;
case TrainParam::kGradientBased:
if (is_external_memory) {
strategy_.reset(
new ExternalMemoryGradientBasedSampling(page, n_rows, batch_param, subsample));
} else {
strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample));
}
break;
default:LOG(FATAL) << "unknown sampling method";
}
} else {
if (is_external_memory) {
strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param));
} else {
strategy_.reset(new NoSampling(page));
}
}
}
// Sample a DMatrix based on the given gradient pairs.
GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
monitor_.StartCuda("Sample");
GradientBasedSample sample = strategy_->Sample(gpair, dmat);
monitor_.StopCuda("Sample");
return sample;
}
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows) {
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(threshold),
CombineGradientPair());
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum));
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
thrust::counting_iterator<size_t>(0),
dh::tbegin(grad_sum),
SampleRateDelta(threshold, gpair.size(), sample_rows));
thrust::device_ptr<float> min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
}
}; // namespace tree
}; // namespace xgboost

View File

@ -0,0 +1,153 @@
/*!
* Copyright 2019 by XGBoost Contributors
*/
#pragma once
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/span.h>
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
namespace xgboost {
namespace tree {
struct GradientBasedSample {
/*!\brief Number of sampled rows. */
size_t sample_rows;
/*!\brief Sampled rows in ELLPACK format. */
EllpackPageImpl* page;
/*!\brief Gradient pairs for the sampled rows. */
common::Span<GradientPair> gpair;
};
class SamplingStrategy {
public:
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
};
/*! \brief No sampling in in-memory mode. */
class NoSampling : public SamplingStrategy {
public:
explicit NoSampling(EllpackPageImpl* page);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
EllpackPageImpl* page_;
};
/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
public:
ExternalMemoryNoSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
BatchParam batch_param_;
std::unique_ptr<EllpackPageImpl> page_;
bool page_concatenated_{false};
};
/*! \brief Uniform sampling in in-memory mode. */
class UniformSampling : public SamplingStrategy {
public:
UniformSampling(EllpackPageImpl* page, float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
EllpackPageImpl* page_;
float subsample_;
};
/*! \brief No sampling in external memory mode. */
class ExternalMemoryUniformSampling : public SamplingStrategy {
public:
ExternalMemoryUniformSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_{};
common::Span<size_t> sample_row_index_;
};
/*! \brief Gradient-based sampling in in-memory mode.. */
class GradientBasedSampling : public SamplingStrategy {
public:
GradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
EllpackPageImpl* page_;
float subsample_;
dh::BulkAllocator ba_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
};
/*! \brief Gradient-based sampling in external memory mode.. */
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
public:
ExternalMemoryGradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_;
common::Span<size_t> sample_row_index_;
};
/*! \brief Draw a sample of rows from a DMatrix.
*
* \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017).
* Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information
* Processing Systems (pp. 3146-3154).
* \see Zhu, R. (2016). Gradient-based sampling: An adaptive importance sampling for least-squares.
* In Advances in Neural Information Processing Systems (pp. 406-414).
* \see Ohlsson, E. (1998). Sequential poisson sampling. Journal of official Statistics, 14(2), 149.
*/
class GradientBasedSampler {
public:
GradientBasedSampler(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample,
int sampling_method);
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat);
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows);
private:
common::Monitor monitor_;
std::unique_ptr<SamplingStrategy> strategy_;
};
}; // namespace tree
}; // namespace xgboost

View File

@ -125,7 +125,6 @@ class RowPartitioner {
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx); // new node id
if (new_position == kIgnoredTreePosition) return;
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
AtomicIncrement(d_left_count, new_position == left_nidx);
d_position[idx] = new_position;

View File

@ -50,6 +50,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
float max_delta_step;
// whether we want to do subsample
float subsample;
// sampling method
enum SamplingMethod { kUniform = 0, kGradientBased = 1 };
int sampling_method;
// whether to subsample columns in each split (node)
float colsample_bynode;
// whether to subsample columns in each level
@ -144,6 +147,14 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
.set_range(0.0f, 1.0f)
.set_default(1.0f)
.describe("Row subsample ratio of training instance.");
DMLC_DECLARE_FIELD(sampling_method)
.set_default(kUniform)
.add_enum("uniform", kUniform)
.add_enum("gradient_based", kGradientBased)
.describe(
"Sampling method. 0: select random training instances uniformly. "
"1: select random training instances with higher probability when the "
"gradient and hessian are larger. (cf. CatBoost)");
DMLC_DECLARE_FIELD(colsample_bynode)
.set_range(0.0f, 1.0f)
.set_default(1.0f)

View File

@ -148,6 +148,9 @@ class BaseMaker: public TreeUpdater {
}
// mark subsample
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
for (size_t i = 0; i < position_.size(); ++i) {

View File

@ -202,6 +202,9 @@ class ColMaker: public TreeUpdater {
}
// mark subsample
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
for (size_t ridx = 0; ridx < position_.size(); ++ridx) {

View File

@ -187,41 +187,5 @@ XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
return (1 << (depth + 1)) - 1;
}
/*
* Random
*/
struct BernoulliRng {
float p;
uint32_t seed;
XGBOOST_DEVICE BernoulliRng(float p, size_t seed_) : p(p) {
seed = static_cast<uint32_t>(seed_);
}
XGBOOST_DEVICE bool operator()(const int i) const {
thrust::default_random_engine rng(seed);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng) <= p;
}
};
// Set gradient pair to 0 with p = 1 - subsample
inline void SubsampleGradientPair(int device_idx,
common::Span<GradientPair> d_gpair,
float subsample, int offset = 0) {
if (subsample == 1.0) {
return;
}
BernoulliRng rng(subsample, common::GlobalRandom()());
dh::LaunchN(device_idx, d_gpair.size(), [=] XGBOOST_DEVICE(int i) {
if (!rng(i + offset)) {
d_gpair[i] = GradientPair();
}
});
}
} // namespace tree
} // namespace xgboost

View File

@ -29,6 +29,7 @@
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
#include "gpu_hist/gradient_based_sampler.cuh"
#include "gpu_hist/row_partitioner.cuh"
namespace xgboost {
@ -415,11 +416,8 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
}
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.info.row_stride];
if (!matrix.IsInRange(ridx)) {
continue;
}
int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride
+ idx % matrix.info.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
if (gidx != matrix.info.n_bins) {
// If we are not using shared memory, accumulate the values directly into
// global memory
@ -480,6 +478,8 @@ struct GPUHistMakerDevice {
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
std::unique_ptr<GradientBasedSampler> sampler;
GPUHistMakerDevice(int _device_id,
EllpackPageImpl* _page,
bst_uint _n_rows,
@ -495,6 +495,11 @@ struct GPUHistMakerDevice {
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
batch_param(_batch_param) {
sampler.reset(new GradientBasedSampler(page,
n_rows,
batch_param,
param.subsample,
param.sampling_method));
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
}
@ -528,7 +533,7 @@ struct GPUHistMakerDevice {
// Reset values for each update iteration
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, int64_t num_columns) {
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand.reset(new ExpandQueue(LossGuide));
} else {
@ -540,13 +545,14 @@ struct GPUHistMakerDevice {
this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
GradientPair());
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
n_rows = sample.sample_rows;
page = sample.page;
gpair = sample.gpair;
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
dh::safe_cuda(cudaMemcpyAsync(
gpair.data(), dh_gpair->ConstDevicePointer(),
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
SubsampleGradientPair(device_id, gpair, param.subsample);
hist.Reset();
}
@ -632,14 +638,6 @@ struct GPUHistMakerDevice {
return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
}
// Build gradient histograms for a given node across all the batches in the DMatrix.
void BuildHistBatches(int nidx, DMatrix* p_fmat) {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
page = batch.Impl();
BuildHist(nidx);
}
}
void BuildHist(int nidx) {
hist.AllocateHistogram(nidx);
auto d_node_hist = hist.GetNodeHistogram(nidx);
@ -687,10 +685,7 @@ struct GPUHistMakerDevice {
row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(),
[=] __device__(size_t ridx) {
if (!d_matrix.IsInRange(ridx)) {
return RowPartitioner::kIgnoredTreePosition;
}
[=] __device__(bst_uint ridx) {
// given a row index, returns the node id it belongs to
bst_float cut_value =
d_matrix.GetElement(ridx, split_node.SplitIndex());
@ -719,8 +714,20 @@ struct GPUHistMakerDevice {
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
}
if (page->matrix.n_rows == p_fmat->Info().num_row_) {
FinalisePositionInPage(page, d_nodes);
} else {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
page = batch.Impl();
FinalisePositionInPage(batch.Impl(), d_nodes);
}
}
}
void FinalisePositionInPage(EllpackPageImpl* page, const common::Span<RegTree::Node> d_nodes) {
auto d_matrix = page->matrix;
row_partitioner->FinalisePosition(
[=] __device__(size_t row_id, int position) {
@ -746,7 +753,6 @@ struct GPUHistMakerDevice {
return position;
});
}
}
void UpdatePredictionCache(bst_float* out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id));
@ -797,7 +803,8 @@ struct GPUHistMakerDevice {
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) {
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left,
int nidx_right, dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
@ -809,34 +816,6 @@ struct GPUHistMakerDevice {
}
this->BuildHist(build_hist_nidx);
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = this->CanDoSubtractionTrick(
candidate.nid, build_hist_nidx, subtraction_trick_nidx);
if (!do_subtraction_trick) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
}
}
/**
* \brief AllReduce GPU histograms for the left and right child of some parent node.
*/
void ReduceHistLeftRight(const ExpandEntry& candidate,
int nidx_left,
int nidx_right,
dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
if (fewer_right) {
std::swap(build_hist_nidx, subtraction_trick_nidx);
}
this->AllReduceHist(build_hist_nidx, reducer);
// Check whether we can use the subtraction trick to calculate the other
@ -849,6 +828,7 @@ struct GPUHistMakerDevice {
subtraction_trick_nidx);
} else {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer);
}
}
@ -889,14 +869,10 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild());
}
void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
dh::AllReducer* reducer, int64_t num_columns) {
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) {
constexpr int kRootNIdx = 0;
const auto &gpair = gpair_all->DeviceSpan();
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
gpair.size());
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size());
reducer->AllReduceSum(
reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
@ -905,7 +881,7 @@ struct GPUHistMakerDevice {
node_sum_gradients_d.data(), sizeof(GradientPair),
cudaMemcpyDeviceToHost));
this->BuildHistBatches(kRootNIdx, p_fmat);
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer);
// Remember root stats
@ -928,11 +904,11 @@ struct GPUHistMakerDevice {
auto& tree = *p_tree;
monitor.StartCuda("Reset");
this->Reset(gpair_all, p_fmat->Info().num_col_);
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
monitor.StopCuda("Reset");
monitor.StartCuda("InitRoot");
this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_);
this->InitRoot(p_tree, reducer, p_fmat->Info().num_col_);
monitor.StopCuda("InitRoot");
auto timestamp = qexpand->size();
@ -951,21 +927,15 @@ struct GPUHistMakerDevice {
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
page = batch.Impl();
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.StartCuda("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.StopCuda("UpdatePosition");
monitor.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx);
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("BuildHist");
}
monitor.StartCuda("ReduceHist");
this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("ReduceHist");
monitor.StartCuda("EvaluateSplits");
auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
@ -997,7 +967,6 @@ inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
ba.Allocate(device_id,
&gpair, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&monotone_constraints, param.monotone_constraints.size());

View File

@ -534,6 +534,9 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
// mark subsample and build list of member rows
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
size_t j = 0;

View File

@ -81,4 +81,119 @@ TEST(EllpackPage, BuildGidxSparse) {
}
}
struct ReadRowFunction {
EllpackMatrix matrix;
int row;
bst_float* row_data_d;
ReadRowFunction(EllpackMatrix matrix, int row, bst_float* row_data_d)
: matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
__device__ void operator()(size_t col) {
auto value = matrix.GetElement(row, col);
if (isnan(value)) {
value = -1;
}
row_data_d[col] = value;
}
};
TEST(EllpackPage, Copy) {
constexpr size_t kRows = 1024;
constexpr size_t kCols = 16;
constexpr size_t kPageSize = 1024;
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
BatchParam param{0, 256, 0, kPageSize};
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
// Create an empty result page.
EllpackPageImpl result(0, page->matrix.info, kRows);
// Copy batch pages into the result page.
size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
size_t num_elements = result.Copy(0, batch.Impl(), offset);
offset += num_elements;
}
size_t current_row = 0;
thrust::device_vector<bst_float> row_d(kCols);
thrust::device_vector<bst_float> row_result_d(kCols);
std::vector<bst_float> row(kCols);
std::vector<bst_float> row_result(kCols);
for (auto& page : dmat->GetBatches<EllpackPage>(param)) {
auto impl = page.Impl();
EXPECT_EQ(impl->matrix.base_rowid, current_row);
for (size_t i = 0; i < impl->Size(); i++) {
dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(0, kCols, ReadRowFunction(result.matrix, current_row, row_result_d.data().get()));
thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());
EXPECT_EQ(row, row_result);
current_row++;
}
}
}
TEST(EllpackPage, Compact) {
constexpr size_t kRows = 16;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1;
constexpr size_t kCompactedRows = 8;
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
BatchParam param{0, 256, 0, kPageSize};
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
// Create an empty result page.
EllpackPageImpl result(0, page->matrix.info, kCompactedRows);
// Compact batch pages into the result page.
std::vector<size_t> row_indexes_h {
SIZE_MAX, 0, 1, 2, SIZE_MAX, 3, SIZE_MAX, 4, 5, SIZE_MAX, 6, SIZE_MAX, 7, SIZE_MAX, SIZE_MAX,
SIZE_MAX};
thrust::device_vector<size_t> row_indexes_d = row_indexes_h;
common::Span<size_t> row_indexes_span(row_indexes_d.data().get(), kRows);
for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
result.Compact(0, batch.Impl(), row_indexes_span);
}
size_t current_row = 0;
thrust::device_vector<bst_float> row_d(kCols);
thrust::device_vector<bst_float> row_result_d(kCols);
std::vector<bst_float> row(kCols);
std::vector<bst_float> row_result(kCols);
for (auto& page : dmat->GetBatches<EllpackPage>(param)) {
auto impl = page.Impl();
EXPECT_EQ(impl->matrix.base_rowid, current_row);
for (size_t i = 0; i < impl->Size(); i++) {
size_t compacted_row = row_indexes_h[current_row];
if (compacted_row == SIZE_MAX) {
current_row++;
continue;
}
dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(0, kCols,
ReadRowFunction(result.matrix, compacted_row, row_result_d.data().get()));
thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());
EXPECT_EQ(row, row_result);
current_row++;
}
}
}
} // namespace xgboost

View File

@ -221,6 +221,19 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) {
return tparam;
}
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows) {
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
std::vector<GradientPair> h_gpair(n_rows);
for (auto &gpair : h_gpair) {
bst_float grad = dist(&gen);
bst_float hess = dist(&gen);
gpair = GradientPair(grad, hess);
}
HostDeviceVector<GradientPair> gpair(h_gpair);
return gpair;
}
#if defined(__CUDACC__)
namespace {
class HistogramCutsWrapper : public common::HistogramCuts {

View File

@ -0,0 +1,150 @@
#include <gtest/gtest.h>
#include "../../../../src/data/ellpack_page.cuh"
#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
#include "../../helpers.h"
namespace xgboost {
namespace tree {
void VerifySampling(size_t page_size,
float subsample,
int sampling_method,
bool fixed_size_sampling = true,
bool check_sum = true) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 1;
size_t sample_rows = kRows * subsample;
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix> dmat(
CreateSparsePageDMatrixWithRC(kRows, kCols, page_size, true, tmpdir));
auto gpair = GenerateRandomGradients(kRows);
GradientPair sum_gpair{};
for (const auto& gp : gpair.ConstHostVector()) {
sum_gpair += gp;
}
gpair.SetDevice(0);
BatchParam param{0, 256, 0, page_size};
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
if (page_size != 0) {
EXPECT_NE(page->matrix.n_rows, kRows);
}
GradientBasedSampler sampler(page, kRows, param, subsample, sampling_method);
auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get());
if (fixed_size_sampling) {
EXPECT_EQ(sample.sample_rows, kRows);
EXPECT_EQ(sample.page->matrix.n_rows, kRows);
EXPECT_EQ(sample.gpair.size(), kRows);
} else {
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012f);
EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012f);
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012f);
}
GradientPair sum_sampled_gpair{};
std::vector<GradientPair> sampled_gpair_h(sample.gpair.size());
dh::CopyDeviceSpanToVector(&sampled_gpair_h, sample.gpair);
for (const auto& gp : sampled_gpair_h) {
sum_sampled_gpair += gp;
}
if (check_sum) {
EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows);
EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows);
} else {
EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.02f);
EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.02f);
}
}
TEST(GradientBasedSampler, NoSampling) {
constexpr size_t kPageSize = 0;
constexpr float kSubsample = 1.0f;
constexpr int kSamplingMethod = TrainParam::kUniform;
VerifySampling(kPageSize, kSubsample, kSamplingMethod);
}
// In external mode, when not sampling, we concatenate the pages together.
TEST(GradientBasedSampler, NoSampling_ExternalMemory) {
constexpr size_t kRows = 2048;
constexpr size_t kCols = 1;
constexpr float kSubsample = 1.0f;
constexpr size_t kPageSize = 1024;
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto gpair = GenerateRandomGradients(kRows);
gpair.SetDevice(0);
BatchParam param{0, 256, 0, kPageSize};
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
EXPECT_NE(page->matrix.n_rows, kRows);
GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform);
auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get());
auto sampled_page = sample.page;
EXPECT_EQ(sample.sample_rows, kRows);
EXPECT_EQ(sample.gpair.size(), gpair.Size());
EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
EXPECT_EQ(sampled_page->matrix.n_rows, kRows);
std::vector<common::CompressedByteT> buffer(sampled_page->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&buffer, sampled_page->gidx_buffer);
common::CompressedIterator<common::CompressedByteT>
ci(buffer.data(), sampled_page->matrix.info.NumSymbols());
size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
auto page = batch.Impl();
std::vector<common::CompressedByteT> page_buffer(page->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&page_buffer, page->gidx_buffer);
common::CompressedIterator<common::CompressedByteT>
page_ci(page_buffer.data(), page->matrix.info.NumSymbols());
size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride;
for (size_t i = 0; i < num_elements; i++) {
EXPECT_EQ(ci[i + offset], page_ci[i]);
}
offset += num_elements;
}
}
TEST(GradientBasedSampler, UniformSampling) {
constexpr size_t kPageSize = 0;
constexpr float kSubsample = 0.5;
constexpr int kSamplingMethod = TrainParam::kUniform;
constexpr bool kFixedSizeSampling = true;
constexpr bool kCheckSum = false;
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling, kCheckSum);
}
TEST(GradientBasedSampler, UniformSampling_ExternalMemory) {
constexpr size_t kPageSize = 1024;
constexpr float kSubsample = 0.5;
constexpr int kSamplingMethod = TrainParam::kUniform;
constexpr bool kFixedSizeSampling = false;
constexpr bool kCheckSum = false;
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling, kCheckSum);
}
TEST(GradientBasedSampler, GradientBasedSampling) {
constexpr size_t kPageSize = 0;
constexpr float kSubsample = 0.8;
constexpr int kSamplingMethod = TrainParam::kGradientBased;
VerifySampling(kPageSize, kSubsample, kSamplingMethod);
}
TEST(GradientBasedSampler, GradientBasedSampling_ExternalMemory) {
constexpr size_t kPageSize = 1024;
constexpr float kSubsample = 0.8;
constexpr int kSamplingMethod = TrainParam::kGradientBased;
constexpr bool kFixedSizeSampling = false;
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
}
}; // namespace tree
}; // namespace xgboost

View File

@ -88,12 +88,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
std::vector<GradientPair> h_gpair(kNRows);
for (auto &gpair : h_gpair) {
HostDeviceVector<GradientPair> gpair(kNRows);
for (auto &gp : gpair.HostVector()) {
bst_float grad = dist(&gen);
bst_float hess = dist(&gen);
gpair = GradientPair(grad, hess);
gp = GradientPair(grad, hess);
}
gpair.SetDevice(0);
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.size());
@ -104,7 +105,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
maker.hist.AllocateHistogram(0);
dh::CopyVectorToDeviceSpan(maker.gpair, h_gpair);
maker.gpair = gpair.DeviceSpan();
maker.use_shared_memory_histograms = use_shared_memory_histograms;
maker.BuildHist(0);
@ -319,19 +320,6 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPa
return n_nodes;
}
HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows) {
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
std::vector<GradientPair> h_gpair(n_rows);
for (auto &gpair : h_gpair) {
bst_float grad = dist(&gen);
bst_float hess = dist(&gen);
gpair = GradientPair(grad, hess);
}
HostDeviceVector<GradientPair> gpair(h_gpair);
return gpair;
}
TEST(GpuHist, MinSplitLoss) {
constexpr size_t kRows = 32;
constexpr size_t kCols = 16;
@ -358,7 +346,9 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair,
DMatrix* dmat,
size_t gpu_page_size,
RegTree* tree,
HostDeviceVector<bst_float>* preds) {
HostDeviceVector<bst_float>* preds,
float subsample = 1.0f,
const std::string& sampling_method = "uniform") {
constexpr size_t kMaxBin = 2;
if (gpu_page_size > 0) {
@ -379,7 +369,9 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair,
{"max_bin", std::to_string(kMaxBin)},
{"min_child_weight", "0.0"},
{"reg_alpha", "0"},
{"reg_lambda", "0"}
{"reg_lambda", "0"},
{"subsample", std::to_string(subsample)},
{"sampling_method", sampling_method},
};
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
@ -391,10 +383,66 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair,
hist_maker.UpdatePredictionCache(dmat, preds);
}
TEST(GpuHist, ExternalMemory) {
constexpr size_t kRows = 6;
TEST(GpuHist, UniformSampling) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1;
constexpr float kSubsample = 0.99;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto gpair = GenerateRandomGradients(kRows);
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
auto preds_sampling_h = preds_sampling.ConstHostVector();
for (int i = 0; i < kRows; i++) {
EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-3);
}
}
TEST(GpuHist, GradientBasedSampling) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr float kSubsample = 0.99;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto gpair = GenerateRandomGradients(kRows);
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "gradient_based");
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
auto preds_sampling_h = preds_sampling.ConstHostVector();
for (int i = 0; i < kRows; i++) {
EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-3);
}
}
TEST(GpuHist, ExternalMemory) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1024;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@ -420,7 +468,42 @@ TEST(GpuHist, ExternalMemory) {
auto preds_h = preds.ConstHostVector();
auto preds_ext_h = preds_ext.ConstHostVector();
for (int i = 0; i < kRows; i++) {
ASSERT_FLOAT_EQ(preds_h[i], preds_ext_h[i]);
EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6);
}
}
TEST(GpuHist, ExternalMemoryWithSampling) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1024;
constexpr float kSubsample = 0.5;
const std::string kSamplingMethod = "gradient_based";
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto gpair = GenerateRandomGradients(kRows);
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod);
// Build another tree using multiple ELLPACK pages.
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample, kSamplingMethod);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
auto preds_ext_h = preds_ext.ConstHostVector();
for (int i = 0; i < kRows; i++) {
EXPECT_NEAR(preds_h[i], preds_ext_h[i], 3e-3);
}
}