Support multiple batches in gpu_hist (#5014)
* Initial external memory training support for GPU Hist tree method.
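Not part of the commit, but useful context for reading the diff below: every device-side pass over the data (the root histogram, per-node histograms, position updates and the final leaf assignment) now loops over the EllpackPage batches produced by the DMatrix instead of assuming a single resident page, accumulating results across batches. The sketch below restates that idea with plain CPU types; Batch, GradientPairF and BuildNodeHistogram are made-up stand-ins for EllpackPage, GradientPair and BuildHist, not code from this change.

// Illustrative CPU sketch of the multi-batch histogram build (not from this commit).
#include <cstddef>
#include <vector>

struct GradientPairF { float grad = 0.0f; float hess = 0.0f; };

// Stand-in for one EllpackPage: a contiguous row range stored as quantised bin indices.
struct Batch {
  std::size_t base_rowid;   // global index of the first row in this batch
  std::size_t row_stride;   // bin entries per row
  std::vector<int> gidx;    // quantised feature values, row-major
};

// Accumulate one node's histogram over all batches, mirroring how BuildHistBatches
// below calls BuildHist once per EllpackPage so no single batch must hold the dataset.
void BuildNodeHistogram(const std::vector<Batch>& batches,
                        const std::vector<GradientPairF>& gpair,   // indexed by global row id
                        std::vector<GradientPairF>* node_hist) {   // indexed by bin id
  for (const Batch& batch : batches) {
    const std::size_t n_rows = batch.gidx.size() / batch.row_stride;
    for (std::size_t r = 0; r < n_rows; ++r) {
      const std::size_t ridx = batch.base_rowid + r;               // global row id
      for (std::size_t k = 0; k < batch.row_stride; ++k) {
        const int bin = batch.gidx[r * batch.row_stride + k];
        (*node_hist)[bin].grad += gpair[ridx].grad;
        (*node_hist)[bin].hess += gpair[ridx].hess;
      }
    }
  }
}

The real kernel additionally restricts rows to the node being built (via RowPartitioner) and skips the n_bins sentinel that marks missing values; both details appear in the SharedMemHistKernel hunks below.
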
@@ -33,6 +33,7 @@ class RowPartitioner {
  public:
   using RowIndexT = bst_uint;
   struct Segment;
+  static constexpr bst_node_t kIgnoredTreePosition = -1;

  private:
   int device_idx;
@@ -124,6 +125,7 @@ class RowPartitioner {
       idx += segment.begin;
       RowIndexT ridx = d_ridx[idx];
       bst_node_t new_position = op(ridx); // new node id
+      if (new_position == kIgnoredTreePosition) return;
       KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
       AtomicIncrement(d_left_count, new_position == left_nidx);
       d_position[idx] = new_position;
@@ -163,7 +165,9 @@ class RowPartitioner {
     dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
       auto position = d_position[idx];
       RowIndexT ridx = d_ridx[idx];
-      d_position[idx] = op(ridx, position);
+      bst_node_t new_position = op(ridx, position);
+      if (new_position == kIgnoredTreePosition) return;
+      d_position[idx] = new_position;
     });
   }

@@ -409,13 +409,16 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
   extern __shared__ char smem[];
   GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
   if (use_shared_memory_histograms) {
-    dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
+    dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
     __syncthreads();
   }
   for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
-    int ridx = d_ridx[idx / matrix.info.row_stride ];
-    int gidx =
-        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
+    int ridx = d_ridx[idx / matrix.info.row_stride];
+    if (!matrix.IsInRange(ridx)) {
+      continue;
+    }
+    int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride
+                                + idx % matrix.info.row_stride];
     if (gidx != matrix.info.n_bins) {
       // If we are not using shared memory, accumulate the values directly into
       // global memory
@@ -428,8 +431,7 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
   if (use_shared_memory_histograms) {
     // Write shared memory back to global memory
     __syncthreads();
-    for (auto i :
-         dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
+    for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
       dh::AtomicAddGpair(d_node_hist + i, smem_arr[i]);
     }
   }
@@ -440,6 +442,7 @@ template <typename GradientSumT>
 struct GPUHistMakerDevice {
   int device_id;
   EllpackPageImpl* page;
+  BatchParam batch_param;

   dh::BulkAllocator ba;

@@ -481,14 +484,16 @@ struct GPUHistMakerDevice {
                      bst_uint _n_rows,
                      TrainParam _param,
                      uint32_t column_sampler_seed,
-                     uint32_t n_features)
+                     uint32_t n_features,
+                     BatchParam _batch_param)
       : device_id(_device_id),
         page(_page),
         n_rows(_n_rows),
         param(std::move(_param)),
         prediction_cache_initialised(false),
         column_sampler(column_sampler_seed),
-        interaction_constraints(param, n_features) {
+        interaction_constraints(param, n_features),
+        batch_param(_batch_param) {
     monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
   }

@@ -626,6 +631,14 @@ struct GPUHistMakerDevice {
     return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
   }

+  // Build gradient histograms for a given node across all the batches in the DMatrix.
+  void BuildHistBatches(int nidx, DMatrix* p_fmat) {
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+      page = batch.Impl();
+      BuildHist(nidx);
+    }
+  }
+
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
     auto d_node_hist = hist.GetNodeHistogram(nidx);
@@ -636,7 +649,7 @@ struct GPUHistMakerDevice {

     const size_t smem_size =
         use_shared_memory_histograms
-            ? sizeof(GradientSumT) * page->matrix.BinCount()
+            ? sizeof(GradientSumT) * page->matrix.info.n_bins
            : 0;
     uint32_t items_per_thread = 8;
     uint32_t block_threads = 256;
@@ -673,7 +686,10 @@ struct GPUHistMakerDevice {

     row_partitioner->UpdatePosition(
         nidx, split_node.LeftChild(), split_node.RightChild(),
-        [=] __device__(bst_uint ridx) {
+        [=] __device__(size_t ridx) {
+          if (!d_matrix.IsInRange(ridx)) {
+            return RowPartitioner::kIgnoredTreePosition;
+          }
           // given a row index, returns the node id it belongs to
           bst_float cut_value =
               d_matrix.GetElement(ridx, split_node.SplitIndex());
@@ -693,35 +709,42 @@ struct GPUHistMakerDevice {
   }

   // After tree update is finished, update the position of all training
-  // instances to their final leaf This information is used later to update the
+  // instances to their final leaf. This information is used later to update the
   // prediction cache
-  void FinalisePosition(RegTree* p_tree) {
+  void FinalisePosition(RegTree* p_tree, DMatrix* p_fmat) {
     const auto d_nodes =
         temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
     dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
                              d_nodes.size() * sizeof(RegTree::Node),
                              cudaMemcpyHostToDevice));
-    auto d_matrix = page->matrix;
-    row_partitioner->FinalisePosition(
-        [=] __device__(bst_uint ridx, int position) {
-          auto node = d_nodes[position];
-
-          while (!node.IsLeaf()) {
-            bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
-            // Missing value
-            if (isnan(element)) {
-              position = node.DefaultChild();
-            } else {
-              if (element <= node.SplitCond()) {
-                position = node.LeftChild();
-              } else {
-                position = node.RightChild();
-              }
-            }
-            node = d_nodes[position];
-          }
-          return position;
-        });
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+      page = batch.Impl();
+      auto d_matrix = page->matrix;
+      row_partitioner->FinalisePosition(
+          [=] __device__(size_t row_id, int position) {
+            if (!d_matrix.IsInRange(row_id)) {
+              return RowPartitioner::kIgnoredTreePosition;
+            }
+            auto node = d_nodes[position];
+
+            while (!node.IsLeaf()) {
+              bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
+              // Missing value
+              if (isnan(element)) {
+                position = node.DefaultChild();
+              } else {
+                if (element <= node.SplitCond()) {
+                  position = node.LeftChild();
+                } else {
+                  position = node.RightChild();
+                }
+              }
+              node = d_nodes[position];
+            }
+            return position;
+          });
+    }
   }

   void UpdatePredictionCache(bst_float* out_preds_d) {
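Aside, not part of the diff: the FinalisePosition and UpdatePosition hunks above rely on the kIgnoredTreePosition convention introduced in the RowPartitioner hunks. A callback returns the sentinel for rows that lie outside the page currently in device memory, and the stored position for those rows is left untouched. A minimal CPU sketch of that convention follows; UpdatePositions and the loop structure are hypothetical names for illustration, not XGBoost API.

// Minimal sketch of the "ignored position" convention (not from this commit).
#include <cstddef>
#include <vector>

constexpr int kIgnoredTreePosition = -1;  // same sentinel value as RowPartitioner uses

template <typename UpdateOp>
void UpdatePositions(std::vector<int>* positions, UpdateOp op) {
  for (std::size_t ridx = 0; ridx < positions->size(); ++ridx) {
    int new_position = op(ridx, (*positions)[ridx]);
    if (new_position == kIgnoredTreePosition) {
      continue;  // row is not covered by the batch currently in memory; keep the old node id
    }
    (*positions)[ridx] = new_position;
  }
}

Calling this once per batch, with op returning the sentinel for rows outside the batch's row range, eventually repositions every row; that is what the loop over GetBatches<EllpackPage> achieves on the GPU.
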
@@ -764,7 +787,7 @@ struct GPUHistMakerDevice {
     reducer->AllReduceSum(
         reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
         reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
-        page->matrix.BinCount() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
+        page->matrix.info.n_bins * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
     reducer->Synchronize();

     monitor.StopCuda("AllReduce");
@@ -773,12 +796,10 @@ struct GPUHistMakerDevice {
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
-  void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left,
-                          int nidx_right, dh::AllReducer* reducer) {
+  void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) {
     auto build_hist_nidx = nidx_left;
     auto subtraction_trick_nidx = nidx_right;

     // Decide whether to build the left histogram or right histogram
     // Use sum of Hessian as a heuristic to select node with fewest training instances
     bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
@@ -787,22 +808,50 @@ struct GPUHistMakerDevice {
     }

     this->BuildHist(build_hist_nidx);
-    this->AllReduceHist(build_hist_nidx, reducer);

     // Check whether we can use the subtraction trick to calculate the other
     bool do_subtraction_trick = this->CanDoSubtractionTrick(
         candidate.nid, build_hist_nidx, subtraction_trick_nidx);

+    if (!do_subtraction_trick) {
+      // Calculate other histogram manually
+      this->BuildHist(subtraction_trick_nidx);
+    }
+  }
+
+  /**
+   * \brief AllReduce GPU histograms for the left and right child of some parent node.
+   */
+  void ReduceHistLeftRight(const ExpandEntry& candidate,
+                           int nidx_left,
+                           int nidx_right,
+                           dh::AllReducer* reducer) {
+    auto build_hist_nidx = nidx_left;
+    auto subtraction_trick_nidx = nidx_right;
+
+    // Decide whether to build the left histogram or right histogram
+    // Use sum of Hessian as a heuristic to select node with fewest training instances
+    bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
+    if (fewer_right) {
+      std::swap(build_hist_nidx, subtraction_trick_nidx);
+    }
+
+    this->AllReduceHist(build_hist_nidx, reducer);
+
+    // Check whether we can use the subtraction trick to calculate the other
+    bool do_subtraction_trick = this->CanDoSubtractionTrick(
+        candidate.nid, build_hist_nidx, subtraction_trick_nidx);
+
     if (do_subtraction_trick) {
       // Calculate other histogram using subtraction trick
       this->SubtractionTrick(candidate.nid, build_hist_nidx,
                              subtraction_trick_nidx);
     } else {
       // Calculate other histogram manually
       this->BuildHist(subtraction_trick_nidx);
       this->AllReduceHist(subtraction_trick_nidx, reducer);
     }
   }

   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
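Aside, not part of the diff: the subtraction trick referenced in the comments above exploits the fact that a parent's rows are exactly the union of its children's rows, so for every histogram bin the sibling's entry is the parent's entry minus the built child's entry. Splitting the old BuildHistLeftRight into a per-batch build step and a separate ReduceHistLeftRight lets the AllReduce and the subtraction run once, after all batches have been accumulated. A minimal sketch follows; GradientPairF and SubtractSibling are illustrative names, not XGBoost API.

// Illustrative sketch of the subtraction trick (not from this commit).
#include <cstddef>
#include <vector>

struct GradientPairF { float grad = 0.0f; float hess = 0.0f; };

// Derive the sibling's histogram from the parent's and the already-built child's.
void SubtractSibling(const std::vector<GradientPairF>& parent_hist,
                     const std::vector<GradientPairF>& built_child_hist,
                     std::vector<GradientPairF>* sibling_hist) {
  for (std::size_t bin = 0; bin < parent_hist.size(); ++bin) {
    (*sibling_hist)[bin].grad = parent_hist[bin].grad - built_child_hist[bin].grad;
    (*sibling_hist)[bin].hess = parent_hist[bin].hess - built_child_hist[bin].hess;
  }
}

The Hessian-sum heuristic in the code builds the child with the smaller Hessian sum (a proxy for fewer rows) directly, so the more expensive histogram is obtained by subtraction whenever the trick applies.
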
@@ -839,7 +888,7 @@ struct GPUHistMakerDevice {
                       tree[candidate.nid].RightChild());
   }

-  void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all,
+  void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
                 dh::AllReducer* reducer, int64_t num_columns) {
     constexpr int kRootNIdx = 0;

@@ -855,7 +904,7 @@ struct GPUHistMakerDevice {
                              node_sum_gradients_d.data(), sizeof(GradientPair),
                              cudaMemcpyDeviceToHost));

-    this->BuildHist(kRootNIdx);
+    this->BuildHistBatches(kRootNIdx, p_fmat);
     this->AllReduceHist(kRootNIdx, reducer);

     // Remember root stats
@@ -882,7 +931,7 @@ struct GPUHistMakerDevice {
     monitor.StopCuda("Reset");

     monitor.StartCuda("InitRoot");
-    this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
+    this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_);
     monitor.StopCuda("InitRoot");

     auto timestamp = qexpand->size();
@@ -901,15 +950,21 @@ struct GPUHistMakerDevice {
       int left_child_nidx = tree[candidate.nid].LeftChild();
       int right_child_nidx = tree[candidate.nid].RightChild();
       // Only create child entries if needed
-      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
-                                    num_leaves)) {
-        monitor.StartCuda("UpdatePosition");
-        this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
-        monitor.StopCuda("UpdatePosition");
-
-        monitor.StartCuda("BuildHist");
-        this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
-        monitor.StopCuda("BuildHist");
+      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) {
+        for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+          page = batch.Impl();
+
+          monitor.StartCuda("UpdatePosition");
+          this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
+          monitor.StopCuda("UpdatePosition");
+
+          monitor.StartCuda("BuildHist");
+          this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx);
+          monitor.StopCuda("BuildHist");
+        }
+        monitor.StartCuda("ReduceHist");
+        this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
+        monitor.StopCuda("ReduceHist");

         monitor.StartCuda("EvaluateSplits");
         auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
@@ -926,7 +981,7 @@ struct GPUHistMakerDevice {
     }

     monitor.StartCuda("FinalisePosition");
-    this->FinalisePosition(p_tree);
+    this->FinalisePosition(p_tree, p_fmat);
     monitor.StopCuda("FinalisePosition");
   }
 };
@@ -1016,21 +1071,21 @@ class GPUHistMakerSpecialised {
     uint32_t column_sampling_seed = common::GlobalRandom()();
     rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);

-    // TODO(rongou): support multiple Ellpack pages.
-    EllpackPageImpl* page{};
-    for (auto& batch : dmat->GetBatches<EllpackPage>({device_,
-                                                      param_.max_bin,
-                                                      hist_maker_param_.gpu_batch_nrows})) {
-      page = batch.Impl();
-    }
-
+    BatchParam batch_param{
+      device_,
+      param_.max_bin,
+      hist_maker_param_.gpu_batch_nrows,
+      generic_param_->gpu_page_size
+    };
+    auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
     dh::safe_cuda(cudaSetDevice(device_));
     maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
                                                      page,
                                                      info_->num_row_,
                                                      param_,
                                                      column_sampling_seed,
-                                                     info_->num_col_));
+                                                     info_->num_col_,
+                                                     batch_param));

     monitor_.StartCuda("InitHistogram");
     dh::safe_cuda(cudaSetDevice(device_));