Gradient based sampling for GPU Hist (#5093)

* Implement gradient based sampling for GPU Hist tree method.
* Add samplers and handle compacted page in GPU Hist.
Rong Ou
2020-02-03 18:31:27 -08:00
committed by GitHub
parent c74216f22c
commit e4b74c4d22
18 changed files with 1187 additions and 175 deletions
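As a quick orientation, here is a minimal sketch of how the new parameter is expected to be enabled through the C API (booster construction is elided; `booster` is an assumed, already-created BoosterHandle; the parameter names match the TrainParam changes below):

#include <xgboost/c_api.h>

// Minimal sketch: opt an existing booster into gradient-based sampling.
void EnableGradientBasedSampling(BoosterHandle booster) {
  XGBoosterSetParam(booster, "tree_method", "gpu_hist");
  XGBoosterSetParam(booster, "sampling_method", "gradient_based");
  XGBoosterSetParam(booster, "subsample", "0.1");  // keep ~10% of rows per tree
}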

View File

@@ -0,0 +1,380 @@
/*!
* Copyright 2019 by XGBoost Contributors
*/
#include <thrust/copy.h>
#include <thrust/count.h>
#include <thrust/device_ptr.h>
#include <thrust/extrema.h>
#include <thrust/fill.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/replace.h>
#include <thrust/scan.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <limits>
#include "../../common/compressed_iterator.h"
#include "../../common/random.h"
#include "gradient_based_sampler.cuh"
namespace xgboost {
namespace tree {
/*! \brief A functor that returns random weights. */
class RandomWeight : public thrust::unary_function<size_t, float> {
public:
explicit RandomWeight(size_t seed) : seed_(seed) {}
XGBOOST_DEVICE float operator()(size_t i) const {
thrust::default_random_engine rng(seed_);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng);
}
private:
uint32_t seed_;
};
/*! \brief A functor that performs a Bernoulli trial to discard a gradient pair. */
class BernoulliTrial : public thrust::unary_function<size_t, bool> {
public:
BernoulliTrial(size_t seed, float p) : rnd_(seed), p_(p) {}
XGBOOST_DEVICE bool operator()(size_t i) const {
return rnd_(i) > p_;
}
private:
RandomWeight rnd_;
float p_;
};
/*! \brief A functor that returns true if the gradient pair is non-zero. */
struct IsNonZero : public thrust::unary_function<GradientPair, bool> {
XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const {
return gpair.GetGrad() != 0 || gpair.GetHess() != 0;
}
};
/*! \brief A functor that clears the row index for rows with an empty gradient pair. */
struct ClearEmptyRows : public thrust::binary_function<GradientPair, size_t, size_t> {
XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const {
if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) {
return row_index;
} else {
return std::numeric_limits<std::size_t>::max();
}
}
};
/*! \brief A functor that combines the gradient pair into a single float.
*
* The approach here is based on Minimal Variance Sampling (MVS), with lambda set to 0.1.
*
* \see Ibragimov, B., & Gusev, G. (2019). Minimal Variance Sampling in Stochastic Gradient
* Boosting. In Advances in Neural Information Processing Systems (pp. 15061-15071).
*/
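// For a pair (g, h) this evaluates sqrt(g^2 + lambda * h^2), the "regularized
// absolute value" from the MVS paper; e.g. g = 0.6, h = 1.0 gives
// sqrt(0.36 + 0.1) ~= 0.678.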
class CombineGradientPair : public thrust::unary_function<GradientPair, float> {
public:
XGBOOST_DEVICE float operator()(const GradientPair& gpair) const {
return sqrtf(powf(gpair.GetGrad(), 2) + kLambda * powf(gpair.GetHess(), 2));
}
private:
static constexpr float kLambda = 0.1f;
};
/*! \brief A functor that, given a cumulative gradient sum, measures how far a candidate
 * threshold is from yielding the desired number of sampled rows.
 */
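// With the combined gradients sorted ascending and prefix-summed, each
// row_index is a candidate cut: the n_rows - row_index - 1 rows above it are
// always kept, while rows below it are kept with probability grad / u,
// contributing gradient_sum / u rows in expectation. The functor returns 0
// only at the cut where u = gradient_sum / (sample_rows - (n_rows - row_index - 1))
// falls between adjacent sorted gradients, i.e. where the expected sample
// size matches sample_rows.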
class SampleRateDelta : public thrust::binary_function<float, size_t, float> {
public:
SampleRateDelta(common::Span<float> threshold, size_t n_rows, size_t sample_rows)
: threshold_(threshold), n_rows_(n_rows), sample_rows_(sample_rows) {}
XGBOOST_DEVICE float operator()(float gradient_sum, size_t row_index) const {
float lower = threshold_[row_index];
float upper = threshold_[row_index + 1];
float u = gradient_sum / static_cast<float>(sample_rows_ - n_rows_ + row_index + 1);
if (u > lower && u <= upper) {
threshold_[row_index + 1] = u;
return 0.0f;
} else {
return std::numeric_limits<float>::max();
}
}
private:
common::Span<float> threshold_;
size_t n_rows_;
size_t sample_rows_;
};
/*! \brief A functor that performs Poisson sampling, and scales gradient pairs by 1/p_i. */
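// Keeping row i with probability p_i = min(1, v_i / u) and scaling the kept
// gradient pair by 1 / p_i leaves the total gradient sum unbiased:
// E[keep_i * g_i / p_i] = g_i.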
class PoissonSampling : public thrust::binary_function<GradientPair, size_t, GradientPair> {
public:
PoissonSampling(common::Span<float> threshold, size_t threshold_index, RandomWeight rnd)
: threshold_(threshold), threshold_index_(threshold_index), rnd_(rnd) {}
XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) {
// If the gradient and hessian are both empty, we should never select this row.
if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) {
return gpair;
}
float combined_gradient = combine_(gpair);
float u = threshold_[threshold_index_];
float p = combined_gradient / u;
if (p >= 1) {
// Always select this row.
return gpair;
} else {
// Select this row randomly with probability proportional to the combined gradient.
// Scale gpair by 1/p.
if (rnd_(i) <= p) {
return gpair / p;
} else {
return GradientPair();
}
}
}
private:
common::Span<float> threshold_;
size_t threshold_index_;
RandomWeight rnd_;
CombineGradientPair combine_;
};
NoSampling::NoSampling(EllpackPageImpl* page) : page_(page) {}
GradientBasedSample NoSampling::Sample(common::Span<GradientPair> gpair, DMatrix* dmat) {
return {dmat->Info().num_row_, page_, gpair};
}
ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param)
: batch_param_(batch_param),
page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {}
GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
if (!page_concatenated_) {
// Concatenate all the external memory ELLPACK pages into a single in-memory page.
size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
auto page = batch.Impl();
size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset);
offset += num_elements;
}
page_concatenated_ = true;
}
return {dmat->Info().num_row_, page_.get(), gpair};
}
UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample)
: page_(page), subsample_(subsample) {}
GradientBasedSample UniformSampling::Sample(common::Span<GradientPair> gpair, DMatrix* dmat) {
// Set gradient pair to 0 with p = 1 - subsample
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_),
GradientPair());
return {dmat->Info().num_row_, page_, gpair};
}
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample)
: original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows);
}
GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
// Set gradient pair to 0 with p = 1 - subsample
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_),
GradientPair());
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
size_t n_rows = dmat->Info().num_row_;
// Compact gradient pairs.
gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
ClearEmptyRows());
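// Worked example with a hypothetical gpair [g0, 0, g2, 0]: IsNonZero gives
// [1, 0, 1, 0], the exclusive scan gives [0, 1, 1, 2], and ClearEmptyRows
// leaves [0, SIZE_MAX, 1, SIZE_MAX], mapping the surviving rows to compacted
// slots 0 and 1.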
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(batch_param_.gpu_id,
original_page_->matrix.info,
sample_rows));
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
}
GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : page_(page), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows);
}
GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
threshold_index,
RandomWeight(common::GlobalRandom()())));
return {n_rows, page_, gpair};
}
ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows,
&sample_row_index_, n_rows);
}
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
threshold_index,
RandomWeight(common::GlobalRandom()())));
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
// Compact gradient pairs.
gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
ClearEmptyRows());
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(batch_param_.gpu_id,
original_page_->matrix.info,
sample_rows));
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
}
GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample,
int sampling_method) {
monitor_.Init("gradient_based_sampler");
bool is_sampling = subsample < 1.0;
bool is_external_memory = page->matrix.n_rows != n_rows;
if (is_sampling) {
switch (sampling_method) {
case TrainParam::kUniform:
if (is_external_memory) {
strategy_.reset(new ExternalMemoryUniformSampling(page, n_rows, batch_param, subsample));
} else {
strategy_.reset(new UniformSampling(page, subsample));
}
break;
case TrainParam::kGradientBased:
if (is_external_memory) {
strategy_.reset(
new ExternalMemoryGradientBasedSampling(page, n_rows, batch_param, subsample));
} else {
strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample));
}
break;
default: LOG(FATAL) << "unknown sampling method";
}
} else {
if (is_external_memory) {
strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param));
} else {
strategy_.reset(new NoSampling(page));
}
}
}
// Sample a DMatrix based on the given gradient pairs.
GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
monitor_.StartCuda("Sample");
GradientBasedSample sample = strategy_->Sample(gpair, dmat);
monitor_.StopCuda("Sample");
return sample;
}
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows) {
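// Plan: write each row's combined gradient into threshold (with a float-max
// sentinel at the end), sort ascending, prefix-sum into grad_sum, score every
// candidate cut with SampleRateDelta, and return the position just past the
// minimum, which is the slot where SampleRateDelta stored the winning u.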
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(threshold),
CombineGradientPair());
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum));
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
thrust::counting_iterator<size_t>(0),
dh::tbegin(grad_sum),
SampleRateDelta(threshold, gpair.size(), sample_rows));
thrust::device_ptr<float> min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
}
}  // namespace tree
}  // namespace xgboost

View File

@@ -0,0 +1,153 @@
/*!
* Copyright 2019 by XGBoost Contributors
*/
#pragma once
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/span.h>
#include <memory>
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
namespace xgboost {
namespace tree {
struct GradientBasedSample {
/*!\brief Number of sampled rows. */
size_t sample_rows;
/*!\brief Sampled rows in ELLPACK format. */
EllpackPageImpl* page;
/*!\brief Gradient pairs for the sampled rows. */
common::Span<GradientPair> gpair;
};
class SamplingStrategy {
public:
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
// Virtual destructor, since strategies are deleted through the base pointer.
virtual ~SamplingStrategy() = default;
};
/*! \brief No sampling in in-memory mode. */
class NoSampling : public SamplingStrategy {
public:
explicit NoSampling(EllpackPageImpl* page);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
EllpackPageImpl* page_;
};
/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
public:
ExternalMemoryNoSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
BatchParam batch_param_;
std::unique_ptr<EllpackPageImpl> page_;
bool page_concatenated_{false};
};
/*! \brief Uniform sampling in in-memory mode. */
class UniformSampling : public SamplingStrategy {
public:
UniformSampling(EllpackPageImpl* page, float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
EllpackPageImpl* page_;
float subsample_;
};
/*! \brief Uniform sampling in external memory mode. */
class ExternalMemoryUniformSampling : public SamplingStrategy {
public:
ExternalMemoryUniformSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_{};
common::Span<size_t> sample_row_index_;
};
/*! \brief Gradient-based sampling in in-memory mode. */
class GradientBasedSampling : public SamplingStrategy {
public:
GradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
EllpackPageImpl* page_;
float subsample_;
dh::BulkAllocator ba_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
};
/*! \brief Gradient-based sampling in external memory mode. */
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
public:
ExternalMemoryGradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_;
common::Span<size_t> sample_row_index_;
};
/*! \brief Draw a sample of rows from a DMatrix.
*
* \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017).
* Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information
* Processing Systems (pp. 3146-3154).
* \see Zhu, R. (2016). Gradient-based sampling: An adaptive importance sampling for least-squares.
* In Advances in Neural Information Processing Systems (pp. 406-414).
* \see Ohlsson, E. (1998). Sequential Poisson sampling. Journal of Official Statistics, 14(2), 149.
*/
class GradientBasedSampler {
public:
GradientBasedSampler(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample,
int sampling_method);
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat);
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows);
private:
common::Monitor monitor_;
std::unique_ptr<SamplingStrategy> strategy_;
};
}  // namespace tree
}  // namespace xgboost

View File

@@ -125,7 +125,6 @@ class RowPartitioner {
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx); // new node id
if (new_position == kIgnoredTreePosition) return;
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
AtomicIncrement(d_left_count, new_position == left_nidx);
d_position[idx] = new_position;

View File

@@ -50,6 +50,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
float max_delta_step;
// whether we want to do subsample
float subsample;
// sampling method
enum SamplingMethod { kUniform = 0, kGradientBased = 1 };
int sampling_method;
// whether to subsample columns in each split (node)
float colsample_bynode;
// whether to subsample columns in each level
@@ -144,6 +147,14 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
.set_range(0.0f, 1.0f)
.set_default(1.0f)
.describe("Row subsample ratio of training instance.");
DMLC_DECLARE_FIELD(sampling_method)
.set_default(kUniform)
.add_enum("uniform", kUniform)
.add_enum("gradient_based", kGradientBased)
.describe(
"Sampling method. 0: select random training instances uniformly. "
"1: select random training instances with higher probability when the "
"gradient and hessian are larger. (cf. CatBoost)");
DMLC_DECLARE_FIELD(colsample_bynode)
.set_range(0.0f, 1.0f)
.set_default(1.0f)

View File

@@ -148,6 +148,9 @@ class BaseMaker: public TreeUpdater {
}
// mark subsample
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
for (size_t i = 0; i < position_.size(); ++i) {

View File

@@ -202,6 +202,9 @@ class ColMaker: public TreeUpdater {
}
// mark subsample
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
for (size_t ridx = 0; ridx < position_.size(); ++ridx) {

View File

@@ -187,41 +187,5 @@ XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
return (1 << (depth + 1)) - 1;
}
/*
* Random
*/
struct BernoulliRng {
float p;
uint32_t seed;
XGBOOST_DEVICE BernoulliRng(float p, size_t seed_) : p(p) {
seed = static_cast<uint32_t>(seed_);
}
XGBOOST_DEVICE bool operator()(const int i) const {
thrust::default_random_engine rng(seed);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng) <= p;
}
};
// Set gradient pair to 0 with p = 1 - subsample
inline void SubsampleGradientPair(int device_idx,
common::Span<GradientPair> d_gpair,
float subsample, int offset = 0) {
if (subsample == 1.0) {
return;
}
BernoulliRng rng(subsample, common::GlobalRandom()());
dh::LaunchN(device_idx, d_gpair.size(), [=] XGBOOST_DEVICE(int i) {
if (!rng(i + offset)) {
d_gpair[i] = GradientPair();
}
});
}
} // namespace tree
} // namespace xgboost

View File

@@ -29,6 +29,7 @@
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
#include "gpu_hist/gradient_based_sampler.cuh"
#include "gpu_hist/row_partitioner.cuh"
namespace xgboost {
@@ -415,11 +416,8 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
}
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.info.row_stride];
if (!matrix.IsInRange(ridx)) {
continue;
}
int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride
+ idx % matrix.info.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
if (gidx != matrix.info.n_bins) {
// If we are not using shared memory, accumulate the values directly into
// global memory
@@ -480,6 +478,8 @@ struct GPUHistMakerDevice {
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
std::unique_ptr<GradientBasedSampler> sampler;
GPUHistMakerDevice(int _device_id,
EllpackPageImpl* _page,
bst_uint _n_rows,
@@ -495,6 +495,11 @@ struct GPUHistMakerDevice {
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
batch_param(_batch_param) {
sampler.reset(new GradientBasedSampler(page,
n_rows,
batch_param,
param.subsample,
param.sampling_method));
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
}
@@ -528,7 +533,7 @@ struct GPUHistMakerDevice {
// Reset values for each update iteration
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, int64_t num_columns) {
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand.reset(new ExpandQueue(LossGuide));
} else {
@@ -540,13 +545,14 @@ struct GPUHistMakerDevice {
this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
GradientPair());
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
n_rows = sample.sample_rows;
page = sample.page;
gpair = sample.gpair;
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
dh::safe_cuda(cudaMemcpyAsync(
gpair.data(), dh_gpair->ConstDevicePointer(),
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
SubsampleGradientPair(device_id, gpair, param.subsample);
hist.Reset();
}
@@ -632,14 +638,6 @@ struct GPUHistMakerDevice {
return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
}
// Build gradient histograms for a given node across all the batches in the DMatrix.
void BuildHistBatches(int nidx, DMatrix* p_fmat) {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
page = batch.Impl();
BuildHist(nidx);
}
}
void BuildHist(int nidx) {
hist.AllocateHistogram(nidx);
auto d_node_hist = hist.GetNodeHistogram(nidx);
@@ -687,10 +685,7 @@ struct GPUHistMakerDevice {
row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(),
[=] __device__(size_t ridx) {
if (!d_matrix.IsInRange(ridx)) {
return RowPartitioner::kIgnoredTreePosition;
}
[=] __device__(bst_uint ridx) {
// given a row index, returns the node id it belongs to
bst_float cut_value =
d_matrix.GetElement(ridx, split_node.SplitIndex());
@@ -719,33 +714,44 @@ struct GPUHistMakerDevice {
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
page = batch.Impl();
auto d_matrix = page->matrix;
row_partitioner->FinalisePosition(
[=] __device__(size_t row_id, int position) {
if (!d_matrix.IsInRange(row_id)) {
return RowPartitioner::kIgnoredTreePosition;
}
auto node = d_nodes[position];
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
if (element <= node.SplitCond()) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
return position;
});
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
}
if (page->matrix.n_rows == p_fmat->Info().num_row_) {
FinalisePositionInPage(page, d_nodes);
} else {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
FinalisePositionInPage(batch.Impl(), d_nodes);
}
}
}
void FinalisePositionInPage(EllpackPageImpl* page, const common::Span<RegTree::Node> d_nodes) {
auto d_matrix = page->matrix;
row_partitioner->FinalisePosition(
[=] __device__(size_t row_id, int position) {
if (!d_matrix.IsInRange(row_id)) {
return RowPartitioner::kIgnoredTreePosition;
}
auto node = d_nodes[position];
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
if (element <= node.SplitCond()) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
return position;
});
}
void UpdatePredictionCache(bst_float* out_preds_d) {
@@ -797,7 +803,8 @@ struct GPUHistMakerDevice {
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) {
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left,
int nidx_right, dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
@@ -809,34 +816,6 @@ struct GPUHistMakerDevice {
}
this->BuildHist(build_hist_nidx);
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = this->CanDoSubtractionTrick(
candidate.nid, build_hist_nidx, subtraction_trick_nidx);
if (!do_subtraction_trick) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
}
}
/**
* \brief AllReduce GPU histograms for the left and right child of some parent node.
*/
void ReduceHistLeftRight(const ExpandEntry& candidate,
int nidx_left,
int nidx_right,
dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
if (fewer_right) {
std::swap(build_hist_nidx, subtraction_trick_nidx);
}
this->AllReduceHist(build_hist_nidx, reducer);
// Check whether we can use the subtraction trick to calculate the other
@@ -849,6 +828,7 @@ struct GPUHistMakerDevice {
subtraction_trick_nidx);
} else {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer);
}
}
@@ -889,14 +869,10 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild());
}
void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
dh::AllReducer* reducer, int64_t num_columns) {
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) {
constexpr int kRootNIdx = 0;
const auto &gpair = gpair_all->DeviceSpan();
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
gpair.size());
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size());
reducer->AllReduceSum(
reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
@@ -905,7 +881,7 @@ struct GPUHistMakerDevice {
node_sum_gradients_d.data(), sizeof(GradientPair),
cudaMemcpyDeviceToHost));
this->BuildHistBatches(kRootNIdx, p_fmat);
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer);
// Remember root stats
@@ -928,11 +904,11 @@ struct GPUHistMakerDevice {
auto& tree = *p_tree;
monitor.StartCuda("Reset");
this->Reset(gpair_all, p_fmat->Info().num_col_);
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
monitor.StopCuda("Reset");
monitor.StartCuda("InitRoot");
this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_);
this->InitRoot(p_tree, reducer, p_fmat->Info().num_col_);
monitor.StopCuda("InitRoot");
auto timestamp = qexpand->size();
@@ -951,21 +927,15 @@ struct GPUHistMakerDevice {
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
page = batch.Impl();
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.StartCuda("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.StopCuda("UpdatePosition");
monitor.StartCuda("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.StopCuda("UpdatePosition");
monitor.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx);
monitor.StopCuda("BuildHist");
}
monitor.StartCuda("ReduceHist");
this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("ReduceHist");
monitor.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("BuildHist");
monitor.StartCuda("EvaluateSplits");
auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
@@ -997,7 +967,6 @@ inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
ba.Allocate(device_id,
&gpair, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&monotone_constraints, param.monotone_constraints.size());

View File

@@ -534,6 +534,9 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
// mark subsample and build list of member rows
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
size_t j = 0;