Gradient based sampling for GPU Hist (#5093)
* Implement gradient based sampling for GPU Hist tree method. * Add samplers and handle compacted page in GPU Hist.
This commit is contained in:
@@ -29,6 +29,7 @@
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "constraints.cuh"
|
||||
#include "gpu_hist/gradient_based_sampler.cuh"
|
||||
#include "gpu_hist/row_partitioner.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -415,11 +416,8 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
|
||||
}
|
||||
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
|
||||
int ridx = d_ridx[idx / matrix.info.row_stride];
|
||||
if (!matrix.IsInRange(ridx)) {
|
||||
continue;
|
||||
}
|
||||
int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride
|
||||
+ idx % matrix.info.row_stride];
|
||||
int gidx =
|
||||
matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
|
||||
if (gidx != matrix.info.n_bins) {
|
||||
// If we are not using shared memory, accumulate the values directly into
|
||||
// global memory
|
||||
@@ -480,6 +478,8 @@ struct GPUHistMakerDevice {
|
||||
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
||||
std::unique_ptr<ExpandQueue> qexpand;
|
||||
|
||||
std::unique_ptr<GradientBasedSampler> sampler;
|
||||
|
||||
GPUHistMakerDevice(int _device_id,
|
||||
EllpackPageImpl* _page,
|
||||
bst_uint _n_rows,
|
||||
@@ -495,6 +495,11 @@ struct GPUHistMakerDevice {
|
||||
column_sampler(column_sampler_seed),
|
||||
interaction_constraints(param, n_features),
|
||||
batch_param(_batch_param) {
|
||||
sampler.reset(new GradientBasedSampler(page,
|
||||
n_rows,
|
||||
batch_param,
|
||||
param.subsample,
|
||||
param.sampling_method));
|
||||
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
|
||||
}
|
||||
|
||||
@@ -528,7 +533,7 @@ struct GPUHistMakerDevice {
|
||||
// Reset values for each update iteration
|
||||
// Note that the column sampler must be passed by value because it is not
|
||||
// thread safe
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, int64_t num_columns) {
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
|
||||
if (param.grow_policy == TrainParam::kLossGuide) {
|
||||
qexpand.reset(new ExpandQueue(LossGuide));
|
||||
} else {
|
||||
@@ -540,13 +545,14 @@ struct GPUHistMakerDevice {
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
|
||||
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
|
||||
n_rows = sample.sample_rows;
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
|
||||
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair.data(), dh_gpair->ConstDevicePointer(),
|
||||
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
|
||||
SubsampleGradientPair(device_id, gpair, param.subsample);
|
||||
hist.Reset();
|
||||
}
|
||||
|
||||
@@ -632,14 +638,6 @@ struct GPUHistMakerDevice {
|
||||
return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
|
||||
}
|
||||
|
||||
// Build gradient histograms for a given node across all the batches in the DMatrix.
|
||||
void BuildHistBatches(int nidx, DMatrix* p_fmat) {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
|
||||
page = batch.Impl();
|
||||
BuildHist(nidx);
|
||||
}
|
||||
}
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
hist.AllocateHistogram(nidx);
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
@@ -687,10 +685,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
row_partitioner->UpdatePosition(
|
||||
nidx, split_node.LeftChild(), split_node.RightChild(),
|
||||
[=] __device__(size_t ridx) {
|
||||
if (!d_matrix.IsInRange(ridx)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
[=] __device__(bst_uint ridx) {
|
||||
// given a row index, returns the node id it belongs to
|
||||
bst_float cut_value =
|
||||
d_matrix.GetElement(ridx, split_node.SplitIndex());
|
||||
@@ -719,33 +714,44 @@ struct GPUHistMakerDevice {
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
|
||||
page = batch.Impl();
|
||||
auto d_matrix = page->matrix;
|
||||
row_partitioner->FinalisePosition(
|
||||
[=] __device__(size_t row_id, int position) {
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
if (element <= node.SplitCond()) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
return position;
|
||||
});
|
||||
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
|
||||
}
|
||||
if (page->matrix.n_rows == p_fmat->Info().num_row_) {
|
||||
FinalisePositionInPage(page, d_nodes);
|
||||
} else {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
|
||||
FinalisePositionInPage(batch.Impl(), d_nodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FinalisePositionInPage(EllpackPageImpl* page, const common::Span<RegTree::Node> d_nodes) {
|
||||
auto d_matrix = page->matrix;
|
||||
row_partitioner->FinalisePosition(
|
||||
[=] __device__(size_t row_id, int position) {
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
if (element <= node.SplitCond()) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
return position;
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
@@ -797,7 +803,8 @@ struct GPUHistMakerDevice {
|
||||
/**
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) {
|
||||
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left,
|
||||
int nidx_right, dh::AllReducer* reducer) {
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
|
||||
@@ -809,34 +816,6 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
this->BuildHist(build_hist_nidx);
|
||||
|
||||
// Check whether we can use the subtraction trick to calculate the other
|
||||
bool do_subtraction_trick = this->CanDoSubtractionTrick(
|
||||
candidate.nid, build_hist_nidx, subtraction_trick_nidx);
|
||||
|
||||
if (!do_subtraction_trick) {
|
||||
// Calculate other histogram manually
|
||||
this->BuildHist(subtraction_trick_nidx);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief AllReduce GPU histograms for the left and right child of some parent node.
|
||||
*/
|
||||
void ReduceHistLeftRight(const ExpandEntry& candidate,
|
||||
int nidx_left,
|
||||
int nidx_right,
|
||||
dh::AllReducer* reducer) {
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
|
||||
// Decide whether to build the left histogram or right histogram
|
||||
// Use sum of Hessian as a heuristic to select node with fewest training instances
|
||||
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
|
||||
if (fewer_right) {
|
||||
std::swap(build_hist_nidx, subtraction_trick_nidx);
|
||||
}
|
||||
|
||||
this->AllReduceHist(build_hist_nidx, reducer);
|
||||
|
||||
// Check whether we can use the subtraction trick to calculate the other
|
||||
@@ -849,6 +828,7 @@ struct GPUHistMakerDevice {
|
||||
subtraction_trick_nidx);
|
||||
} else {
|
||||
// Calculate other histogram manually
|
||||
this->BuildHist(subtraction_trick_nidx);
|
||||
this->AllReduceHist(subtraction_trick_nidx, reducer);
|
||||
}
|
||||
}
|
||||
@@ -889,14 +869,10 @@ struct GPUHistMakerDevice {
|
||||
tree[candidate.nid].RightChild());
|
||||
}
|
||||
|
||||
void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
dh::AllReducer* reducer, int64_t num_columns) {
|
||||
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) {
|
||||
constexpr int kRootNIdx = 0;
|
||||
|
||||
const auto &gpair = gpair_all->DeviceSpan();
|
||||
|
||||
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
|
||||
gpair.size());
|
||||
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size());
|
||||
reducer->AllReduceSum(
|
||||
reinterpret_cast<float*>(node_sum_gradients_d.data()),
|
||||
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
|
||||
@@ -905,7 +881,7 @@ struct GPUHistMakerDevice {
|
||||
node_sum_gradients_d.data(), sizeof(GradientPair),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
this->BuildHistBatches(kRootNIdx, p_fmat);
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, reducer);
|
||||
|
||||
// Remember root stats
|
||||
@@ -928,11 +904,11 @@ struct GPUHistMakerDevice {
|
||||
auto& tree = *p_tree;
|
||||
|
||||
monitor.StartCuda("Reset");
|
||||
this->Reset(gpair_all, p_fmat->Info().num_col_);
|
||||
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
|
||||
monitor.StopCuda("Reset");
|
||||
|
||||
monitor.StartCuda("InitRoot");
|
||||
this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_);
|
||||
this->InitRoot(p_tree, reducer, p_fmat->Info().num_col_);
|
||||
monitor.StopCuda("InitRoot");
|
||||
|
||||
auto timestamp = qexpand->size();
|
||||
@@ -951,21 +927,15 @@ struct GPUHistMakerDevice {
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||
// Only create child entries if needed
|
||||
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
|
||||
page = batch.Impl();
|
||||
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
|
||||
num_leaves)) {
|
||||
monitor.StartCuda("UpdatePosition");
|
||||
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
|
||||
monitor.StopCuda("UpdatePosition");
|
||||
|
||||
monitor.StartCuda("UpdatePosition");
|
||||
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
|
||||
monitor.StopCuda("UpdatePosition");
|
||||
|
||||
monitor.StartCuda("BuildHist");
|
||||
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx);
|
||||
monitor.StopCuda("BuildHist");
|
||||
}
|
||||
monitor.StartCuda("ReduceHist");
|
||||
this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
|
||||
monitor.StopCuda("ReduceHist");
|
||||
monitor.StartCuda("BuildHist");
|
||||
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
|
||||
monitor.StopCuda("BuildHist");
|
||||
|
||||
monitor.StartCuda("EvaluateSplits");
|
||||
auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
|
||||
@@ -997,7 +967,6 @@ inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
|
||||
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
|
||||
|
||||
ba.Allocate(device_id,
|
||||
&gpair, n_rows,
|
||||
&prediction_cache, n_rows,
|
||||
&node_sum_gradients_d, max_nodes,
|
||||
&monotone_constraints, param.monotone_constraints.size());
|
||||
|
||||
Reference in New Issue
Block a user