Combine thread launches into single launch per tree for gpu_hist (#4343)

* Combine thread launches into a single launch per tree for the gpu_hist algorithm.

* Address deprecation warning

* Add manual column sampler constructor

* Turn off omp dynamic to get a guaranteed number of threads

* Enable openmp in cuda code
Rory Mitchell 2019-04-29 09:58:34 +12:00 committed by GitHub
parent 146e83f3b3
commit 5e582b0fa7
10 changed files with 402 additions and 325 deletions
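
As a rough illustration of the combined launch pattern (a sketch, not taken from the diff below; Shard and BuildWholeTree are hypothetical stand-ins for DeviceShard and its new UpdateTree method): one OpenMP thread is launched per GPU shard, each thread builds a complete tree, and omp_set_dynamic(false) guarantees the runtime does not hand back fewer threads than shards.

#include <omp.h>
#include <cstdio>
#include <vector>

struct Shard { int device_id; };           // hypothetical stand-in for DeviceShard
void BuildWholeTree(long idx, Shard& s) {  // hypothetical stand-in for shard->UpdateTree(...)
  std::printf("shard %ld builds an entire tree on device %d\n", idx, s.device_id);
}

int main() {
  std::vector<Shard> shards{{0}, {1}};
  bool dynamic = omp_get_dynamic();        // remember the previous setting
  omp_set_dynamic(false);                  // guarantee one thread per shard
  const long n = static_cast<long>(shards.size());
#pragma omp parallel for schedule(static, 1) if (n > 1)
  for (long i = 0; i < n; ++i) {
    BuildWholeTree(i, shards[i]);          // one launch per tree rather than one per node
  }
  omp_set_dynamic(dynamic);
  return 0;
}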

View File

@ -62,6 +62,13 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
.describe("Size of leaf vector, reserved for vector tree");
}
bool operator==(const TreeParam& b) const {
return num_roots == b.num_roots && num_nodes == b.num_nodes &&
num_deleted == b.num_deleted && max_depth == b.max_depth &&
num_feature == b.num_feature &&
size_leaf_vector == b.size_leaf_vector;
}
};
/*! \brief node statistics used in regression tree */
@ -74,6 +81,10 @@ struct RTreeNodeStat {
bst_float base_weight;
/*! \brief number of child that is leaf node known up to now */
int leaf_child_cnt;
bool operator==(const RTreeNodeStat& b) const {
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
}
};
/*!
@ -188,6 +199,11 @@ class RegTree {
if (is_left_child) pidx |= (1U << 31);
this->parent_ = pidx;
}
bool operator==(const Node& b) const {
return parent_ == b.parent_ && cleft_ == b.cleft_ &&
cright_ == b.cright_ && sindex_ == b.sindex_ &&
info_.leaf_value == b.info_.leaf_value;
}
private:
/*!
@ -304,6 +320,11 @@ class RegTree {
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
}
bool operator==(const RegTree& b) const {
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
}
/**
* \brief Expands a leaf node into two additional leaf nodes.
*

View File

@ -57,6 +57,12 @@ if (USE_CUDA)
target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_NVTX=1)
endif (USE_NVTX)
# OpenMP is mandatory for cuda version
find_package(OpenMP REQUIRED)
target_compile_options(objxgboost PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${OpenMP_CXX_FLAGS}>
)
set_target_properties(objxgboost PROPERTIES
CUDA_SEPARABLE_COMPILATION OFF)
else (USE_CUDA)

View File

@ -12,6 +12,7 @@
#include "span.h"
#include <algorithm>
#include <omp.h>
#include <chrono>
#include <ctime>
#include <cub/cub.cuh>
@ -752,6 +753,29 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
});
}
class SaveCudaContext {
private:
int saved_device_;
public:
template <typename Functor>
explicit SaveCudaContext (Functor func) : saved_device_{-1} {
// When compiled with CUDA but running on CPU only device,
// cudaGetDevice will fail.
try {
safe_cuda(cudaGetDevice(&saved_device_));
} catch (const dmlc::Error &except) {
saved_device_ = -1;
}
func();
}
~SaveCudaContext() {
if (saved_device_ != -1) {
safe_cuda(cudaSetDevice(saved_device_));
}
}
};
/**
* \class AllReducer
*
@ -777,8 +801,18 @@ class AllReducer {
allreduce_calls_(0) {}
/**
* \fn void Init(const std::vector<int> &device_ordinals)
*
* \brief Returns true if we are using a single GPU only.
*/
bool IsSingleGPU() {
#ifdef XGBOOST_USE_NCCL
CHECK(device_counts.size() > 0) << "AllReducer not initialised.";
return device_counts.size() <= 1 && device_counts.at(0) == 1;
#else
return true;
#endif
}
/**
* \brief Initialise with the desired device ordinals for this communication
* group.
*
@ -956,6 +990,22 @@ class AllReducer {
#endif
};
/**
* \brief Synchronizes the device
*
* \param device_id Identifier for the device.
*/
void Synchronize(int device_id) {
#ifdef XGBOOST_USE_NCCL
SaveCudaContext([&]() {
dh::safe_cuda(cudaSetDevice(device_id));
int idx = std::find(device_ordinals.begin(), device_ordinals.end(), device_id) - device_ordinals.begin();
CHECK(idx < device_ordinals.size());
dh::safe_cuda(cudaStreamSynchronize(streams[idx]));
});
#endif
};
#ifdef XGBOOST_USE_NCCL
/**
* \fn ncclUniqueId GetUniqueId()
@ -980,29 +1030,6 @@ class AllReducer {
#endif
};
class SaveCudaContext {
private:
int saved_device_;
public:
template <typename Functor>
explicit SaveCudaContext (Functor func) : saved_device_{-1} {
// When compiled with CUDA but running on CPU only device,
// cudaGetDevice will fail.
try {
safe_cuda(cudaGetDevice(&saved_device_));
} catch (const dmlc::Error &except) {
saved_device_ = -1;
}
func();
}
~SaveCudaContext() {
if (saved_device_ != -1) {
safe_cuda(cudaSetDevice(saved_device_));
}
}
};
/**
* \brief Executes some operation on each element of the input vector, using a
* single controlling thread for each element. In addition, passes the shard index
@ -1017,11 +1044,15 @@ class SaveCudaContext {
template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext{[&]() {
// Temporarily turn off dynamic so we have a guaranteed number of threads
bool dynamic = omp_get_dynamic();
omp_set_dynamic(false);
const long shards_size = static_cast<long>(shards->size());
#pragma omp parallel for schedule(static, 1) if (shards_size > 1)
for (long shard = 0; shard < shards_size; ++shard) {
f(shard, shards->at(shard));
}
omp_set_dynamic(dynamic);
}};
}
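
A minimal usage sketch of ExecuteIndexShards (illustrative; it assumes this header, device_helpers.cuh, is on the include path and that OpenMP is enabled): the functor receives the shard index and a mutable reference to the shard element, and each element gets its own controlling thread.

#include <cstdio>
#include <vector>
#include "device_helpers.cuh"  // assumed include path for dh::ExecuteIndexShards

int main() {
  std::vector<int> shards{10, 20, 30};
  dh::ExecuteIndexShards(&shards, [&](int idx, int& shard) {
    // runs in its own OpenMP thread; idx is the position in the vector
    std::printf("thread handles shard %d with value %d\n", idx, shard);
  });
  return 0;
}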

View File

@ -113,6 +113,14 @@ class ColumnSampler {
}
public:
/**
* \brief Column sampler constructor.
* \note This constructor manually sets the rng seed
*/
explicit ColumnSampler(uint32_t seed) {
rng_.seed(seed);
}
/**
* \brief Column sampler constructor.
* \note This constructor synchronizes the RNG seed across processes.
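
A rough sketch of the intent behind the seeded constructor (illustrative; the include path is an assumption, and broadcasting the seed from rank 0 is what InitDataOnce does later in this diff): every worker and shard constructs its sampler from the same seed, so each level's feature subset comes out identical everywhere.

#include <cassert>
#include <cstdint>
#include "common/random.h"  // assumed include path for xgboost::common::ColumnSampler

int main() {
  std::uint32_t seed = 1234;  // in distributed training this seed would be broadcast from rank 0
  xgboost::common::ColumnSampler a(seed);
  xgboost::common::ColumnSampler b(seed);
  a.Init(/*num_col=*/64, 0.5f, 0.5f, 0.5f);
  b.Init(/*num_col=*/64, 0.5f, 0.5f, 0.5f);
  // identical seeds give identical column subsets at every depth
  assert(a.GetFeatureSet(0)->ConstHostVector() == b.GetFeatureSet(0)->ConstHostVector());
  return 0;
}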

View File

@ -342,7 +342,8 @@ class GPUPredictor : public xgboost::Predictor {
}
public:
GPUPredictor() : cpu_predictor_(Predictor::Create("cpu_predictor")) {}
GPUPredictor() // NOLINT
: cpu_predictor_(Predictor::Create("cpu_predictor")) {} // NOLINT
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,

View File

@ -38,6 +38,7 @@ struct GPUHistMakerTrainParam
bool single_precision_histogram;
// number of rows in a single GPU batch
int gpu_batch_nrows;
bool debug_synchronize;
// declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
@ -47,6 +48,8 @@ struct GPUHistMakerTrainParam
.set_default(0)
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
"Check if all distributed tree are identical after tree construction.");
}
};
#if !defined(GTEST_TEST)
@ -598,12 +601,23 @@ inline void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
}
/*! \brief Count how many rows are assigned to left node. */
__forceinline__ __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
__forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
int left_nidx) {
#if __CUDACC_VER_MAJOR__ > 8
int mask = __activemask();
unsigned ballot = __ballot_sync(mask, val == left_nidx);
int leader = __ffs(mask) - 1;
if (threadIdx.x % 32 == leader) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
#else
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
#endif
}
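
For context, a self-contained sketch of the same warp-aggregation idea (illustrative, not part of the change): every lane in the warp votes on the predicate, the vote mask is popcounted, and a single leader lane performs one atomicAdd, so the global counter is touched once per warp instead of once per matching thread.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void CountMatches(const int* vals, int n, int target,
                             unsigned long long* d_count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  bool pred = (i < n) && (vals[i] == target);
#if __CUDACC_VER_MAJOR__ > 8
  unsigned mask = __activemask();               // lanes currently executing
  unsigned ballot = __ballot_sync(mask, pred);  // one bit per lane whose predicate holds
  int leader = __ffs(mask) - 1;                 // lowest active lane does the update
  if (threadIdx.x % 32 == leader) {
    atomicAdd(d_count, static_cast<unsigned long long>(__popc(ballot)));
  }
#else
  unsigned ballot = __ballot(pred);             // pre-CUDA 9 warp vote
  if (threadIdx.x % 32 == 0) {
    atomicAdd(d_count, static_cast<unsigned long long>(__popc(ballot)));
  }
#endif
}

int main() {
  const int n = 1024;
  std::vector<int> h(n, 0);
  for (int i = 0; i < n; i += 3) h[i] = 1;  // every third element matches the target
  int* d_vals; unsigned long long* d_count;
  cudaMalloc(&d_vals, n * sizeof(int));
  cudaMalloc(&d_count, sizeof(unsigned long long));
  cudaMemcpy(d_vals, h.data(), n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemset(d_count, 0, sizeof(unsigned long long));
  CountMatches<<<(n + 255) / 256, 256>>>(d_vals, n, /*target=*/1, d_count);
  unsigned long long count = 0;
  cudaMemcpy(&count, d_count, sizeof(count), cudaMemcpyDeviceToHost);
  std::printf("matches: %llu\n", count);
  cudaFree(d_vals);
  cudaFree(d_count);
  return 0;
}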
template <typename GradientSumT>
@ -621,6 +635,7 @@ template <typename GradientSumT>
struct DeviceShard {
int n_bins;
int device_id;
int shard_idx; // Position in the local array of shards
dh::BulkAllocator ba;
@ -670,18 +685,31 @@ struct DeviceShard {
std::vector<cudaStream_t> streams;
common::Monitor monitor;
std::vector<ValueConstraint> node_value_constraints;
common::ColumnSampler column_sampler;
std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
// TODO(canonizer): add support for multi-batch DMatrix here
DeviceShard(int _device_id, bst_uint row_begin, bst_uint row_end,
TrainParam _param)
DeviceShard(int _device_id, int shard_idx, bst_uint row_begin,
bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed)
: device_id(_device_id),
shard_idx(shard_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(0),
param(std::move(_param)),
prediction_cache_initialised(false) {}
prediction_cache_initialised(false),
column_sampler(column_sampler_seed) {
monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
}
/* Init row_ptrs and row_stride */
size_t InitRowPtrs(const SparsePage& row_batch) {
@ -736,7 +764,16 @@ struct DeviceShard {
}
// Reset values for each update iteration
void Reset(HostDeviceVector<GradientPair>* dh_gpair) {
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, int64_t num_columns) {
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand.reset(new ExpandQueue(LossGuide));
} else {
qexpand.reset(new ExpandQueue(DepthWise));
}
this->column_sampler.Init(num_columns, param.colsample_bynode,
param.colsample_bylevel, param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
thrust::fill(
thrust::device_pointer_cast(position.Current()),
@ -764,8 +801,6 @@ struct DeviceShard {
std::vector<DeviceSplitCandidate> EvaluateSplits(
std::vector<int> nidxs, const RegTree& tree,
common::ColumnSampler* column_sampler,
const std::vector<ValueConstraint>& value_constraints,
size_t num_columns) {
dh::safe_cuda(cudaSetDevice(device_id));
auto result = pinned_memory.GetSpan<DeviceSplitCandidate>(nidxs.size());
@ -800,7 +835,7 @@ struct DeviceShard {
auto& streams = this->GetStreams(nidxs.size());
for (auto i = 0ull; i < nidxs.size(); i++) {
auto nidx = nidxs[i];
auto p_feature_set = column_sampler->GetFeatureSet(tree.GetDepth(nidx));
auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
p_feature_set->Reshard(GPUSet(device_id, 1));
auto d_feature_set = p_feature_set->DeviceSpan(device_id);
auto d_split_candidates =
@ -812,7 +847,7 @@ struct DeviceShard {
EvaluateSplitKernel<kBlockThreads, GradientSumT>
<<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix,
gpu_param, d_split_candidates, value_constraints[nidx],
gpu_param, d_split_candidates, node_value_constraints[nidx],
monotone_constraints);
// Reduce over features to find best feature
@ -997,6 +1032,179 @@ struct DeviceShard {
out_preds_d, prediction_cache.data(),
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
}
void AllReduceHist(int nidx, dh::AllReducer* reducer) {
monitor.StartCuda("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
reducer->AllReduceSum(
shard_idx,
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
ellpack_matrix.BinCount() *
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
reducer->Synchronize(device_id);
monitor.StopCuda("AllReduce");
}
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right, dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
// If we are using a single GPU, build the histogram for the node with the
// fewest training instances
// If we are distributed, don't bother
if (reducer->IsSingleGPU()) {
bool fewer_right =
ridx_segments[nidx_right].Size() < ridx_segments[nidx_left].Size();
if (fewer_right) {
std::swap(build_hist_nidx, subtraction_trick_nidx);
}
}
this->BuildHist(build_hist_nidx);
this->AllReduceHist(build_hist_nidx, reducer);
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = this->CanDoSubtractionTrick(
nidx_parent, build_hist_nidx, subtraction_trick_nidx);
if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick
this->SubtractionTrick(nidx_parent, build_hist_nidx,
subtraction_trick_nidx);
} else {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer);
}
}
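
To make the subtraction trick concrete, a simplified host-side sketch (an illustration, not the device implementation; GradPairSketch is a hypothetical stand-in for the gradient pair type): because every row of the parent lands in exactly one child, the sibling's histogram can be recovered bin by bin as parent minus the child that was actually built, avoiding a second pass over the data.

#include <cstddef>
#include <vector>

struct GradPairSketch { double grad = 0.0, hess = 0.0; };

// other_child[bin] = parent[bin] - built_child[bin]
void SubtractionTrickSketch(const std::vector<GradPairSketch>& parent,
                            const std::vector<GradPairSketch>& built_child,
                            std::vector<GradPairSketch>* other_child) {
  other_child->resize(parent.size());
  for (std::size_t bin = 0; bin < parent.size(); ++bin) {
    (*other_child)[bin].grad = parent[bin].grad - built_child[bin].grad;
    (*other_child)[bin].hess = parent[bin].hess - built_child[bin].hess;
  }
}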
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;
GradStats left_stats;
left_stats.Add(candidate.split.left_sum);
GradStats right_stats;
right_stats.Add(candidate.split.right_sum);
GradStats parent_sum;
parent_sum.Add(left_stats);
parent_sum.Add(right_stats);
node_value_constraints.resize(tree.GetNodes().size());
auto base_weight = node_value_constraints[candidate.nid].CalcWeight(param, parent_sum);
auto left_weight =
node_value_constraints[candidate.nid].CalcWeight(param, left_stats)*param.learning_rate;
auto right_weight =
node_value_constraints[candidate.nid].CalcWeight(param, right_stats)*param.learning_rate;
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.sum_hess);
// Set up child constraints
node_value_constraints.resize(tree.GetNodes().size());
node_value_constraints[candidate.nid].SetChild(
param, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
&node_value_constraints[tree[candidate.nid].LeftChild()],
&node_value_constraints[tree[candidate.nid].RightChild()]);
node_sum_gradients[tree[candidate.nid].LeftChild()] =
candidate.split.left_sum;
node_sum_gradients[tree[candidate.nid].RightChild()] =
candidate.split.right_sum;
}
void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all,
dh::AllReducer* reducer, int64_t num_columns) {
constexpr int kRootNIdx = 0;
const auto &gpair = gpair_all->DeviceSpan(device_id);
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
gpair.size());
reducer->AllReduceSum(
shard_idx, reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
reducer->Synchronize(device_id);
dh::safe_cuda(cudaMemcpy(node_sum_gradients.data(),
node_sum_gradients_d.data(), sizeof(GradientPair),
cudaMemcpyDeviceToHost));
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer);
// Remember root stats
p_tree->Stat(kRootNIdx).sum_hess = node_sum_gradients[kRootNIdx].GetHess();
auto weight = CalcWeight(param, node_sum_gradients[kRootNIdx]);
p_tree->Stat(kRootNIdx).base_weight = weight;
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
// Initialise root constraint
node_value_constraints.resize(p_tree->GetNodes().size());
// Generate first split
auto split = this->EvaluateSplits({kRootNIdx}, *p_tree, num_columns);
qexpand->push(
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split.at(0), 0));
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree;
monitor.StartCuda("Reset");
this->Reset(gpair_all, p_fmat->Info().num_col_);
monitor.StopCuda("Reset");
monitor.StartCuda("InitRoot");
this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
monitor.StopCuda("InitRoot");
auto timestamp = qexpand->size();
auto num_leaves = 1;
while (!qexpand->empty()) {
ExpandEntry candidate = qexpand->top();
qexpand->pop();
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.StartCuda("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.StopCuda("UpdatePosition");
monitor.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate.nid, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("BuildHist");
monitor.StartCuda("EvaluateSplits");
auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
*p_tree, p_fmat->Info().num_col_);
monitor.StopCuda("EvaluateSplits");
qexpand->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), splits.at(0),
timestamp++));
qexpand->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx),
splits.at(1), timestamp++));
}
}
monitor.StartCuda("FinalisePosition");
this->FinalisePosition(p_tree);
monitor.StopCuda("FinalisePosition");
}
};
template <typename GradientSumT>
@ -1179,12 +1387,6 @@ class GPUHistMakerSpecialised{
dh::CheckComputeCapability();
if (param_.grow_policy == TrainParam::kLossGuide) {
qexpand_.reset(new ExpandQueue(LossGuide));
} else {
qexpand_.reset(new ExpandQueue(DepthWise));
}
monitor_.Init("updater_gpu_hist");
}
@ -1223,17 +1425,23 @@ class GPUHistMakerSpecialised{
auto batch_iter = dmat->GetRowBatches().begin();
const SparsePage& batch = *batch_iter;
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
// Create device shards
shards_.resize(n_devices);
dh::ExecuteIndexShards(
&shards_,
[&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(dist_.Devices().DeviceId(i)));
size_t start = dist_.ShardStart(info_->num_row_, i);
size_t size = dist_.ShardSize(info_->num_row_, i);
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(dist_.Devices().DeviceId(idx)));
size_t start = dist_.ShardStart(info_->num_row_, idx);
size_t size = dist_.ShardSize(info_->num_row_, idx);
shard = std::unique_ptr<DeviceShard<GradientSumT>>(
new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(i), start,
start + size, param_));
new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(idx), idx,
start, start + size, param_,
column_sampling_seed));
});
// Find the cuts.
@ -1264,277 +1472,61 @@ class GPUHistMakerSpecialised{
this->InitDataOnce(dmat);
monitor_.StopCuda("InitDataOnce");
}
column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
// Copy gpair & reset memory
monitor_.StartCuda("InitDataReset");
gpair->Reshard(dist_);
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->Reset(gpair);
});
monitor_.StopCuda("InitDataReset");
}
void AllReduceHist(int nidx) {
if (shards_.size() == 1 && !rabit::IsDistributed()) {
return;
// Only call this method for testing
void CheckTreesSynchronized(const std::vector<RegTree>& local_trees) const {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
if (rank == 0) {
local_trees.front().Save(&fs);
}
monitor_.StartCuda("AllReduce");
reducer_.GroupStart();
for (auto& shard : shards_) {
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
reducer_.AllReduceSum(
dist_.Devices().Index(shard->device_id),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
n_bins_ * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
}
reducer_.GroupEnd();
reducer_.Synchronize();
monitor_.StopCuda("AllReduce");
}
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) {
size_t left_node_max_elements = 0;
size_t right_node_max_elements = 0;
for (auto& shard : shards_) {
left_node_max_elements = (std::max)(
left_node_max_elements, shard->ridx_segments[nidx_left].Size());
right_node_max_elements = (std::max)(
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
}
rabit::Allreduce<rabit::op::Max, size_t>(&left_node_max_elements, 1);
rabit::Allreduce<rabit::op::Max, size_t>(&right_node_max_elements, 1);
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
if (right_node_max_elements < left_node_max_elements) {
build_hist_nidx = nidx_right;
subtraction_trick_nidx = nidx_left;
}
// Build histogram for node with the smallest number of training examples
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->BuildHist(build_hist_nidx);
});
this->AllReduceHist(build_hist_nidx);
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = true;
for (auto& shard : shards_) {
do_subtraction_trick &= shard->CanDoSubtractionTrick(
nidx_parent, build_hist_nidx, subtraction_trick_nidx);
}
if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->SubtractionTrick(nidx_parent, build_hist_nidx,
subtraction_trick_nidx);
});
} else {
// Calculate other histogram manually
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->BuildHist(subtraction_trick_nidx);
});
this->AllReduceHist(subtraction_trick_nidx);
}
}
std::vector<DeviceSplitCandidate> EvaluateSplits(std::vector<int> nidx,
RegTree* p_tree) {
dh::safe_cuda(cudaSetDevice(shards_.front()->device_id));
return shards_.front()->EvaluateSplits(nidx, *p_tree, &column_sampler_,
node_value_constraints_,
info_->num_col_);
}
void InitRoot(RegTree* p_tree) {
constexpr int kRootNIdx = 0;
// Sum gradients
std::vector<GradientPair> tmp_sums(shards_.size());
dh::ExecuteIndexShards(
&shards_,
[&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
tmp_sums[i] = dh::SumReduction(
shard->temp_memory, shard->gpair.data(), shard->gpair.size());
});
GradientPair sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
rabit::Allreduce<rabit::op::Sum>(
reinterpret_cast<GradientPair::ValueT*>(&sum_gradient), 2);
// Generate root histogram
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->BuildHist(kRootNIdx);
});
this->AllReduceHist(kRootNIdx);
// Remember root stats
p_tree->Stat(kRootNIdx).sum_hess = sum_gradient.GetHess();
auto weight = CalcWeight(param_, sum_gradient);
p_tree->Stat(kRootNIdx).base_weight = weight;
(*p_tree)[kRootNIdx].SetLeaf(param_.learning_rate * weight);
// Store sum gradients
for (auto& shard : shards_) {
shard->node_sum_gradients[kRootNIdx] = sum_gradient;
}
// Initialise root constraint
node_value_constraints_.resize(p_tree->GetNodes().size());
// Generate first split
auto split = this->EvaluateSplits({ kRootNIdx }, p_tree);
qexpand_->push(
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split.at(0), 0));
}
void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->UpdatePosition(candidate.nid,
p_tree->GetNodes()[candidate.nid]);
});
}
void FinalisePosition(RegTree* p_tree) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->FinalisePosition(p_tree);
});
}
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;
GradStats left_stats;
left_stats.Add(candidate.split.left_sum);
GradStats right_stats;
right_stats.Add(candidate.split.right_sum);
GradStats parent_sum;
parent_sum.Add(left_stats);
parent_sum.Add(right_stats);
node_value_constraints_.resize(tree.GetNodes().size());
auto base_weight = node_value_constraints_[candidate.nid].CalcWeight(param_, parent_sum);
auto left_weight =
node_value_constraints_[candidate.nid].CalcWeight(param_, left_stats)*param_.learning_rate;
auto right_weight =
node_value_constraints_[candidate.nid].CalcWeight(param_, right_stats)*param_.learning_rate;
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.sum_hess);
// Set up child constraints
node_value_constraints_.resize(tree.GetNodes().size());
node_value_constraints_[candidate.nid].SetChild(
param_, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
&node_value_constraints_[tree[candidate.nid].LeftChild()],
&node_value_constraints_[tree[candidate.nid].RightChild()]);
// Store sum gradients
for (auto& shard : shards_) {
shard->node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
shard->node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
RegTree reference_tree;
reference_tree.Load(&fs);
for (const auto& tree : local_trees) {
CHECK(tree == reference_tree);
}
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree) {
auto& tree = *p_tree;
monitor_.StartCuda("InitData");
this->InitData(gpair, p_fmat);
monitor_.StopCuda("InitData");
monitor_.StartCuda("InitRoot");
this->InitRoot(p_tree);
monitor_.StopCuda("InitRoot");
auto timestamp = qexpand_->size();
auto num_leaves = 1;
std::vector<RegTree> trees(shards_.size());
for (auto& tree : trees) {
tree = *p_tree;
}
gpair->Reshard(dist_);
while (!qexpand_->empty()) {
ExpandEntry candidate = qexpand_->top();
qexpand_->pop();
if (!candidate.IsValid(param_, num_leaves)) {
continue;
// Launch one thread for each device "shard" containing a subset of rows.
// Threads will cooperatively build the tree, synchronising over histograms.
// Each thread will redundantly build its own copy of the tree
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->UpdateTree(gpair, p_fmat, &trees.at(idx), &reducer_);
});
// All trees are expected to be identical
if (hist_maker_param_.debug_synchronize) {
this->CheckTreesSynchronized(trees);
}
this->ApplySplit(candidate, p_tree);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor_.StartCuda("UpdatePosition");
this->UpdatePosition(candidate, p_tree);
monitor_.StopCuda("UpdatePosition");
monitor_.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate.nid, left_child_nidx,
right_child_nidx);
monitor_.StopCuda("BuildHist");
monitor_.StartCuda("EvaluateSplits");
auto splits =
this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
qexpand_->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), splits.at(0),
timestamp++));
qexpand_->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx),
splits.at(1), timestamp++));
monitor_.StopCuda("EvaluateSplits");
}
}
monitor_.StartCuda("FinalisePosition");
this->FinalisePosition(p_tree);
monitor_.StopCuda("FinalisePosition");
// Write the output tree
*p_tree = trees.front();
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
monitor_.StartCuda("UpdatePredictionCache");
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
monitor_.StartCuda("UpdatePredictionCache");
p_out_preds->Reshard(dist_.Devices());
dh::ExecuteIndexShards(
&shards_,
@ -1552,9 +1544,6 @@ class GPUHistMakerSpecialised{
MetaInfo* info_; // NOLINT
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT
common::ColumnSampler column_sampler_; // NOLINT
std::vector<ValueConstraint> node_value_constraints_; // NOLINT
private:
bool initialised_;
@ -1565,10 +1554,6 @@ class GPUHistMakerSpecialised{
GPUHistMakerTrainParam hist_maker_param_;
common::GHistIndexMatrix gmat_;
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand_;
dh::AllReducer reducer_;
DMatrix* p_last_fmat_;

View File

@ -1,3 +1,4 @@
#include <valarray>
#include "../../../src/common/random.h"
#include "../helpers.h"
#include "gtest/gtest.h"
@ -33,7 +34,8 @@ TEST(ColumnSampler, Test) {
// No level or node sampling, should be the same at different depth
cs.Init(n, 1.0f, 1.0f, 0.5f);
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector());
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
cs.GetFeatureSet(1)->HostVector());
cs.Init(n, 1.0f, 1.0f, 1.0f);
auto set5 = *cs.GetFeatureSet(0);
@ -45,7 +47,34 @@ TEST(ColumnSampler, Test) {
// Should always be a minimum of one feature
cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
}
// Test if different threads using the same seed produce the same result
TEST(ColumnSampler, ThreadSynchronisation) {
const int64_t num_threads = 100;
int n = 128;
int iterations = 10;
int levels = 5;
std::vector<int> reference_result;
bool success =
true; // Cannot use google test asserts in multithreaded region
#pragma omp parallel num_threads(num_threads)
{
for (auto j = 0ull; j < iterations; j++) {
ColumnSampler cs(j);
cs.Init(n, 0.5f, 0.5f, 0.5f);
for (auto level = 0ull; level < levels; level++) {
auto result = cs.GetFeatureSet(level)->ConstHostVector();
#pragma omp single
{ reference_result = result; }
if (result != reference_result) {
success = false;
}
#pragma omp barrier
}
}
}
ASSERT_TRUE(success);
}
} // namespace common
} // namespace xgboost

View File

@ -89,7 +89,7 @@ TEST(GpuHist, BuildGidxDense) {
param.n_gpus = 1;
param.max_leaves = 0;
DeviceShard<GradientPairPrecise> shard(0, 0, kNRows, param);
DeviceShard<GradientPairPrecise> shard(0, 0, 0, kNRows, param, kNCols);
BuildGidx(&shard, kNRows, kNCols);
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
@ -128,7 +128,7 @@ TEST(GpuHist, BuildGidxSparse) {
param.n_gpus = 1;
param.max_leaves = 0;
DeviceShard<GradientPairPrecise> shard(0, 0, kNRows, param);
DeviceShard<GradientPairPrecise> shard(0, 0, 0, kNRows, param, kNCols);
BuildGidx(&shard, kNRows, kNCols, 0.9f);
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
@ -172,7 +172,7 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
param.n_gpus = 1;
param.max_leaves = 0;
DeviceShard<GradientSumT> shard(0, 0, kNRows, param);
DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
BuildGidx(&shard, kNRows, kNCols);
@ -282,8 +282,8 @@ TEST(GpuHist, EvaluateSplits) {
int max_bins = 4;
// Initialize DeviceShard
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard {
new DeviceShard<GradientPairPrecise>(0, 0, kNRows, param)};
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols)};
// Initialize DeviceShard::node_sum_gradients
shard->node_sum_gradients = {{6.4f, 12.8f}};
@ -321,12 +321,7 @@ TEST(GpuHist, EvaluateSplits) {
thrust::copy(hist.begin(), hist.end(),
shard->hist.Data().begin());
// Initialize GPUHistMaker
GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
GPUHistMakerSpecialised<GradientPairPrecise>();
hist_maker.param_ = param;
hist_maker.shards_.push_back(std::move(shard));
hist_maker.column_sampler_.Init(kNCols,
shard->column_sampler.Init(kNCols,
param.colsample_bynode,
param.colsample_bylevel,
param.colsample_bytree,
@ -337,13 +332,12 @@ TEST(GpuHist, EvaluateSplits) {
info.num_row_ = kNRows;
info.num_col_ = kNCols;
hist_maker.info_ = &info;
hist_maker.node_value_constraints_.resize(1);
hist_maker.node_value_constraints_[0].lower_bound = -1.0;
hist_maker.node_value_constraints_[0].upper_bound = 1.0;
shard->node_value_constraints.resize(1);
shard->node_value_constraints[0].lower_bound = -1.0;
shard->node_value_constraints[0].upper_bound = 1.0;
std::vector<DeviceSplitCandidate> res =
hist_maker.EvaluateSplits({ 0,0 }, &tree);
shard->EvaluateSplits({ 0,0 }, tree, kNCols);
ASSERT_EQ(res[0].findex, 7);
ASSERT_EQ(res[1].findex, 7);
@ -368,7 +362,8 @@ TEST(GpuHist, ApplySplit) {
}
hist_maker.shards_.resize(1);
hist_maker.shards_[0].reset(new DeviceShard<GradientPairPrecise>(0, 0, kNRows, param));
hist_maker.shards_[0].reset(
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols));
auto& shard = hist_maker.shards_.at(0);
shard->ridx_segments.resize(3); // 3 nodes.
@ -435,8 +430,8 @@ TEST(GpuHist, ApplySplit) {
shard->gidx_buffer.data(), num_symbols);
hist_maker.info_ = &info;
hist_maker.ApplySplit(candidate_entry, &tree);
hist_maker.UpdatePosition(candidate_entry, &tree);
shard->ApplySplit(candidate_entry, &tree);
shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);
ASSERT_FALSE(tree[kNId].IsLeaf());

View File

@ -54,7 +54,8 @@ base_params = {
'max_depth': 2,
'eta': 1,
'verbosity': 0,
'objective': 'binary:logistic'
'objective': 'binary:logistic',
'debug_synchronize': True
}

View File

@ -51,7 +51,7 @@ class TestGPU(unittest.TestCase):
variable_param = {'n_gpus': [-1], 'max_depth': [2, 10],
'max_leaves': [255, 4],
'max_bin': [2, 256],
'grow_policy': ['lossguide']}
'grow_policy': ['lossguide'], 'debug_synchronize': [True]}
for param in parameter_combinations(variable_param):
param['tree_method'] = 'gpu_hist'
gpu_results = run_suite(param, select_datasets=datasets)