Optimisations for gpu_hist. (#4248)
* Optimisations for gpu_hist.
* Use streams to overlap operations.
* ColumnSampler now uses HostDeviceVector to prevent repeatedly copying feature vectors to the device.
This commit is contained in:
parent 7814183199
commit 00465d243d
@@ -208,16 +208,23 @@ __global__ void LaunchNKernel(int device_idx, size_t begin, size_t end,
 }
 
 template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
-inline void LaunchN(int device_idx, size_t n, L lambda) {
+inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
   if (n == 0) {
     return;
   }
 
   safe_cuda(cudaSetDevice(device_idx));
 
   const int GRID_SIZE =
       static_cast<int>(DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
-  LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
-                                              lambda);
+  LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0),
+                                                         n, lambda);
 }
+
+// Default stream version
+template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
+inline void LaunchN(int device_idx, size_t n, L lambda) {
+  LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
+}
 
 /*
@@ -500,6 +507,31 @@ class BulkAllocator {
   }
 };
 
+// Keep track of pinned memory allocation
+struct PinnedMemory {
+  void *temp_storage{nullptr};
+  size_t temp_storage_bytes{0};
+
+  ~PinnedMemory() { Free(); }
+
+  template <typename T>
+  xgboost::common::Span<T> GetSpan(size_t size) {
+    size_t num_bytes = size * sizeof(T);
+    if (num_bytes > temp_storage_bytes) {
+      Free();
+      safe_cuda(cudaMallocHost(&temp_storage, num_bytes));
+      temp_storage_bytes = num_bytes;
+    }
+    return xgboost::common::Span<T>(static_cast<T *>(temp_storage), size);
+  }
+
+  void Free() {
+    if (temp_storage != nullptr) {
+      safe_cuda(cudaFreeHost(temp_storage));
+    }
+  }
+};
+
 // Keep track of cub library device allocation
 struct CubMemory {
   void *d_temp_storage;
 
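The PinnedMemory helper added above caches a single cudaMallocHost allocation and hands out typed spans over it, so repeated small device-to-host copies do not pay for a fresh pinned allocation each time. Below is a minimal standalone sketch of the same idea; PinnedBuffer and CHECK_CUDA are illustrative names rather than xgboost APIs, and the point is only that cudaMemcpyAsync needs pinned host memory to be genuinely asynchronous.

// Illustrative sketch only: PinnedBuffer and CHECK_CUDA are hypothetical names,
// not part of xgboost. It shows the grow-only pinned-buffer pattern and why a
// cudaMallocHost allocation is needed for async device-to-host copies.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDA(call)                                     \
  do {                                                       \
    cudaError_t err = (call);                                \
    if (err != cudaSuccess) {                                \
      std::fprintf(stderr, "%s\n", cudaGetErrorString(err)); \
      std::exit(1);                                          \
    }                                                        \
  } while (0)

struct PinnedBuffer {
  void* data{nullptr};
  size_t bytes{0};
  // Grow-only: reallocate pinned memory only when the request exceeds the cache.
  void* Get(size_t n) {
    if (n > bytes) {
      if (data) CHECK_CUDA(cudaFreeHost(data));
      CHECK_CUDA(cudaMallocHost(&data, n));
      bytes = n;
    }
    return data;
  }
  ~PinnedBuffer() {
    if (data) cudaFreeHost(data);
  }
};

int main() {
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));

  long long* d_value;
  CHECK_CUDA(cudaMalloc(&d_value, sizeof(long long)));
  CHECK_CUDA(cudaMemsetAsync(d_value, 0, sizeof(long long), stream));

  PinnedBuffer pinned;
  auto* h_value = static_cast<long long*>(pinned.Get(sizeof(long long)));

  // With pinned host memory the copy can overlap work on other streams;
  // with pageable memory it would be silently synchronous.
  CHECK_CUDA(cudaMemcpyAsync(h_value, d_value, sizeof(long long),
                             cudaMemcpyDeviceToHost, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));
  std::printf("value = %lld\n", *h_value);

  CHECK_CUDA(cudaFree(d_value));
  CHECK_CUDA(cudaStreamDestroy(stream));
  return 0;
}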
@@ -18,6 +18,7 @@
 #include <random>
 
 #include "io.h"
+#include "host_device_vector.h"
 
 namespace xgboost {
 namespace common {
@@ -84,26 +85,29 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
  */
 
 class ColumnSampler {
-  std::shared_ptr<std::vector<int>> feature_set_tree_;
-  std::map<int, std::shared_ptr<std::vector<int>>> feature_set_level_;
+  std::shared_ptr<HostDeviceVector<int>> feature_set_tree_;
+  std::map<int, std::shared_ptr<HostDeviceVector<int>>> feature_set_level_;
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
   float colsample_bynode_{1.0f};
   GlobalRandomEngine rng_;
 
-  std::shared_ptr<std::vector<int>> ColSample
-    (std::shared_ptr<std::vector<int>> p_features, float colsample) {
+  std::shared_ptr<HostDeviceVector<int>> ColSample(
+      std::shared_ptr<HostDeviceVector<int>> p_features, float colsample) {
     if (colsample == 1.0f) return p_features;
-    const auto& features = *p_features;
+    const auto& features = p_features->HostVector();
     CHECK_GT(features.size(), 0);
     int n = std::max(1, static_cast<int>(colsample * features.size()));
-    auto p_new_features = std::make_shared<std::vector<int>>();
+    auto p_new_features = std::make_shared<HostDeviceVector<int>>();
     auto& new_features = *p_new_features;
-    new_features.resize(features.size());
-    std::copy(features.begin(), features.end(), new_features.begin());
-    std::shuffle(new_features.begin(), new_features.end(), rng_);
-    new_features.resize(n);
-    std::sort(new_features.begin(), new_features.end());
+    new_features.Resize(features.size());
+    std::copy(features.begin(), features.end(),
+              new_features.HostVector().begin());
+    std::shuffle(new_features.HostVector().begin(),
+                 new_features.HostVector().end(), rng_);
+    new_features.Resize(n);
+    std::sort(new_features.HostVector().begin(),
+              new_features.HostVector().end());
 
     return p_new_features;
   }
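The switch from std::vector to HostDeviceVector above means the sampled feature list can be read on the device without being re-uploaded on every call. HostDeviceCache below is a hypothetical, much-simplified stand-in written with thrust to sketch that caching idea; it is not xgboost's HostDeviceVector and omits sharding, spans, and lazy host/device synchronisation.

// Sketch of the caching idea behind using HostDeviceVector for feature sets.
// HostDeviceCache is an illustrative name: the host vector is the source of
// truth and the device copy is refreshed only when the host side may have
// changed, so repeated device reads do not trigger repeated H2D transfers.
#include <thrust/device_vector.h>
#include <vector>

template <typename T>
class HostDeviceCache {
 public:
  std::vector<T>& HostVector() {
    device_dirty_ = true;  // assume the caller may modify the host data
    return host_;
  }
  const thrust::device_vector<T>& DeviceVector() {
    if (device_dirty_) {
      device_ = host_;  // single host-to-device copy
      device_dirty_ = false;
    }
    return device_;
  }

 private:
  std::vector<T> host_;
  thrust::device_vector<T> device_;
  bool device_dirty_{true};
};

int main() {
  HostDeviceCache<int> features;
  features.HostVector() = {0, 1, 2, 3, 4, 5, 6, 7};
  // Both calls below reuse the same device buffer; only the first one copies.
  const auto& d1 = features.DeviceVector();
  const auto& d2 = features.DeviceVector();
  return d1.size() == d2.size() ? 0 : 1;
}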
@@ -135,13 +139,14 @@ class ColumnSampler {
     colsample_bynode_ = colsample_bynode;
 
     if (feature_set_tree_ == nullptr) {
-      feature_set_tree_ = std::make_shared<std::vector<int>>();
+      feature_set_tree_ = std::make_shared<HostDeviceVector<int>>();
     }
     Reset();
 
     int begin_idx = skip_index_0 ? 1 : 0;
-    feature_set_tree_->resize(num_col - begin_idx);
-    std::iota(feature_set_tree_->begin(), feature_set_tree_->end(), begin_idx);
+    feature_set_tree_->Resize(num_col - begin_idx);
+    std::iota(feature_set_tree_->HostVector().begin(),
+              feature_set_tree_->HostVector().end(), begin_idx);
 
     feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
   }
@@ -150,7 +155,7 @@ class ColumnSampler {
   * \brief Resets this object.
   */
   void Reset() {
-    feature_set_tree_->clear();
+    feature_set_tree_->Resize(0);
     feature_set_level_.clear();
   }
 
@@ -165,7 +170,7 @@ class ColumnSampler {
   * construction of each tree node, and must be called the same number of times in each
   * process and with the same parameters to return the same feature set across processes.
   */
-  std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
+  std::shared_ptr<HostDeviceVector<int>> GetFeatureSet(int depth) {
     if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
       return feature_set_tree_;
     }
 
@@ -632,10 +632,9 @@ class ColMaker: public TreeUpdater {
                          const std::vector<GradientPair> &gpair,
                          DMatrix *p_fmat,
                          RegTree *p_tree) {
-    auto p_feature_set = column_sampler_.GetFeatureSet(depth);
-    const auto& feat_set = *p_feature_set;
+    auto feat_set = column_sampler_.GetFeatureSet(depth);
     for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
-      this->UpdateSolution(batch, feat_set, gpair, p_fmat);
+      this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
     }
     // after this each thread's stemp will get the best candidates, aggregate results
     this->SyncBestSolution(qexpand);
 
@@ -125,6 +125,18 @@ struct DeviceSplitCandidate {
   XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
 };
 
+struct DeviceSplitCandidateReduceOp {
+  GPUTrainingParam param;
+  DeviceSplitCandidateReduceOp(GPUTrainingParam param) : param(param) {}
+  XGBOOST_DEVICE DeviceSplitCandidate operator()(
+      const DeviceSplitCandidate& a, const DeviceSplitCandidate& b) const {
+    DeviceSplitCandidate best;
+    best.Update(a, param);
+    best.Update(b, param);
+    return best;
+  }
+};
+
 struct DeviceNodeStats {
   GradientPair sum_gradients;
   float root_gain;
 
@@ -306,8 +306,8 @@ class DeviceHistogram {
 
   void AllocateHistogram(int nidx) {
     if (HistogramExists(nidx)) return;
-    size_t current_size =
-        nidx_map_.size() * n_bins_ * 2;  // Number of items currently used in data
+    size_t current_size = nidx_map_.size() * n_bins_ *
+                          2;  // Number of items currently used in data
     dh::safe_cuda(cudaSetDevice(device_id_));
     if (data_.size() >= kStopGrowingSize) {
       // Recycle histogram memory
@@ -452,7 +452,8 @@ struct IndicateLeftTransform {
 void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
                   common::Span<int> position_out, common::Span<bst_uint> ridx,
                   common::Span<bst_uint> ridx_out, int left_nidx,
-                  int right_nidx, int64_t left_count) {
+                  int right_nidx, int64_t* d_left_count,
+                  cudaStream_t stream = nullptr) {
   auto d_position_out = position_out.data();
   auto d_position_in = position.data();
   auto d_ridx_out = ridx_out.data();
@@ -462,7 +463,7 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
     if (d_position_in[idx] == left_nidx) {
       scatter_address = ex_scan_result;
     } else {
-      scatter_address = (idx - ex_scan_result) + left_count;
+      scatter_address = (idx - ex_scan_result) + *d_left_count;
     }
     d_position_out[scatter_address] = d_position_in[idx];
     d_ridx_out[scatter_address] = d_ridx_in[idx];
@@ -474,11 +475,20 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
   dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
   size_t temp_storage_bytes = 0;
   cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
-                                position.size());
+                                position.size(), stream);
   temp_memory->LazyAllocate(temp_storage_bytes);
   cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage,
                                 temp_memory->temp_storage_bytes, in_itr,
-                                out_itr, position.size());
+                                out_itr, position.size(), stream);
 }
 
+/*! \brief Count how many rows are assigned to left node. */
+__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
+  unsigned ballot = __ballot(val == left_nidx);
+  if (threadIdx.x % 32 == 0) {
+    atomicAdd(reinterpret_cast<unsigned long long*>(d_count),    // NOLINT
+              static_cast<unsigned long long>(__popc(ballot)));  // NOLINT
+  }
+}
+
 template <typename GradientSumT>
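CountLeft above counts left-assigned rows with one warp-wide ballot and one atomicAdd per warp instead of one atomic per matching thread. The standalone sketch below shows the same trick; it uses the CUDA 9+ __ballot_sync/__activemask intrinsics rather than the older __ballot that appears in the diff, and CountMatches is an illustrative kernel, not xgboost code.

// Sketch: per-warp ballot + popcount to reduce atomic traffic when counting.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void CountMatches(const int* values, int n, int target,
                             unsigned long long* d_count) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  bool match = idx < n && values[idx] == target;
  unsigned mask = __activemask();
  unsigned ballot = __ballot_sync(mask, match);
  // One lane per warp publishes the warp's popcount.
  if ((threadIdx.x % 32) == 0) {
    atomicAdd(d_count, static_cast<unsigned long long>(__popc(ballot)));
  }
}

int main() {
  const int n = 1 << 20;
  int* d_values;
  unsigned long long* d_count;
  cudaMalloc(&d_values, n * sizeof(int));
  cudaMalloc(&d_count, sizeof(unsigned long long));
  cudaMemset(d_values, 0, n * sizeof(int));  // every value is 0
  cudaMemset(d_count, 0, sizeof(unsigned long long));

  CountMatches<<<(n + 255) / 256, 256>>>(d_values, n, /*target=*/0, d_count);

  unsigned long long h_count = 0;
  cudaMemcpy(&h_count, d_count, sizeof(h_count), cudaMemcpyDeviceToHost);
  std::printf("matches: %llu (expected %d)\n", h_count, n);

  cudaFree(d_values);
  cudaFree(d_count);
  return 0;
}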
@@ -539,6 +549,8 @@ struct DeviceShard {
   thrust::device_vector<size_t> row_ptrs;
   /*! \brief On-device feature set, only actually used on one of the devices */
   thrust::device_vector<int> feature_set_d;
+  thrust::device_vector<int64_t>
+      left_counts;  // Useful to keep a bunch of zeroed memory for sort position
   /*! The row offset for this shard. */
   bst_uint row_begin_idx;
   bst_uint row_end_idx;
@@ -548,6 +560,9 @@ struct DeviceShard {
   bool prediction_cache_initialised;
 
   dh::CubMemory temp_memory;
+  dh::PinnedMemory pinned_memory;
+
+  std::vector<cudaStream_t> streams;
 
   std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
 
@@ -597,7 +612,30 @@ struct DeviceShard {
 
   void CreateHistIndices(const SparsePage& row_batch);
 
-  ~DeviceShard() = default;
+  ~DeviceShard() {
+    dh::safe_cuda(cudaSetDevice(device_id));
+    for (auto& stream : streams) {
+      dh::safe_cuda(cudaStreamDestroy(stream));
+    }
+  }
+
+  // Get vector of at least n initialised streams
+  std::vector<cudaStream_t>& GetStreams(int n) {
+    if (n > streams.size()) {
+      for (auto& stream : streams) {
+        dh::safe_cuda(cudaStreamDestroy(stream));
+      }
+
+      streams.clear();
+      streams.resize(n);
+
+      for (auto& stream : streams) {
+        dh::safe_cuda(cudaStreamCreate(&stream));
+      }
+    }
+
+    return streams;
+  }
 
   // Reset values for each update iteration
   void Reset(HostDeviceVector<GradientPair>* dh_gpair) {
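GetStreams above keeps a lazily grown pool of streams on the shard, so repeated calls reuse the same streams instead of creating and destroying them every iteration. A self-contained sketch of that pattern follows; StreamPool and SpinKernel are made-up names, and whether the two launches actually overlap depends on the device and on each kernel's resource usage.

// Sketch: lazily grown stream pool, reused across calls, destroyed together.
#include <cuda_runtime.h>
#include <vector>

__global__ void SpinKernel(float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    float x = out[idx];
    for (int i = 0; i < 1000; ++i) x = x * 1.0000001f + 1e-7f;
    out[idx] = x;
  }
}

class StreamPool {
 public:
  // Return at least n initialised streams, recreating the pool only on growth.
  std::vector<cudaStream_t>& Get(size_t n) {
    if (n > streams_.size()) {
      for (auto s : streams_) cudaStreamDestroy(s);
      streams_.assign(n, nullptr);
      for (auto& s : streams_) cudaStreamCreate(&s);
    }
    return streams_;
  }
  ~StreamPool() {
    for (auto s : streams_) cudaStreamDestroy(s);
  }

 private:
  std::vector<cudaStream_t> streams_;
};

int main() {
  const int n = 1 << 20;
  float *a, *b;
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));

  StreamPool pool;
  auto& streams = pool.Get(2);
  // Two independent launches; on separate streams they may overlap on the GPU.
  SpinKernel<<<(n + 255) / 256, 256, 0, streams[0]>>>(a, n);
  SpinKernel<<<(n + 255) / 256, 256, 0, streams[1]>>>(b, n);
  cudaDeviceSynchronize();

  cudaFree(a);
  cudaFree(b);
  return 0;
}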
@@ -605,7 +643,12 @@ struct DeviceShard {
     position.CurrentDVec().Fill(0);
     std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
               GradientPair());
+
+    if (left_counts.size() < 256) {
+      left_counts.resize(256);
+    } else {
+      dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0,
+                                    sizeof(int64_t) * left_counts.size()));
+    }
     thrust::sequence(ridx.CurrentDVec().tbegin(), ridx.CurrentDVec().tend());
 
     std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
@@ -616,38 +659,76 @@ struct DeviceShard {
     hist.Reset();
   }
 
-  DeviceSplitCandidate EvaluateSplit(int nidx,
-                                     const std::vector<int>& feature_set,
-                                     ValueConstraint value_constraint) {
+  std::vector<DeviceSplitCandidate> EvaluateSplits(
+      std::vector<int> nidxs, const RegTree& tree,
+      common::ColumnSampler* column_sampler,
+      const std::vector<ValueConstraint>& value_constraints,
+      size_t num_columns) {
     dh::safe_cuda(cudaSetDevice(device_id));
-    auto d_split_candidates = temp_memory.GetSpan<DeviceSplitCandidate>(feature_set.size());
-    feature_set_d.resize(feature_set.size());
-    auto d_features = common::Span<int>(feature_set_d.data().get(),
-                                        feature_set_d.size());
-    dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(),
-                                  d_features.size_bytes(), cudaMemcpyDefault));
+    auto result = pinned_memory.GetSpan<DeviceSplitCandidate>(nidxs.size());
+
+    // Work out cub temporary memory requirement
+    GPUTrainingParam gpu_param(param);
+    DeviceSplitCandidateReduceOp op(gpu_param);
+    size_t temp_storage_bytes;
+    DeviceSplitCandidate*dummy = nullptr;
+    cub::DeviceReduce::Reduce(
+        nullptr, temp_storage_bytes, dummy,
+        dummy, num_columns, op,
+        DeviceSplitCandidate());
+    // size in terms of DeviceSplitCandidate
+    size_t cub_memory_size =
+        std::ceil(static_cast<double>(temp_storage_bytes) /
+                  sizeof(DeviceSplitCandidate));
+
+    // Allocate enough temporary memory
+    // Result for each nidx
+    // + intermediate result for each column
+    // + cub reduce memory
+    auto temp_span = temp_memory.GetSpan<DeviceSplitCandidate>(
+        nidxs.size() + nidxs.size() * num_columns +cub_memory_size*nidxs.size());
+    auto d_result_all = temp_span.subspan(0, nidxs.size());
+    auto d_split_candidates_all =
+        temp_span.subspan(d_result_all.size(), nidxs.size() * num_columns);
+    auto d_cub_memory_all =
+        temp_span.subspan(d_result_all.size() + d_split_candidates_all.size(),
+                          cub_memory_size * nidxs.size());
+
+    auto& streams = this->GetStreams(nidxs.size());
+    for (auto i = 0ull; i < nidxs.size(); i++) {
+      auto nidx = nidxs[i];
+      auto p_feature_set = column_sampler->GetFeatureSet(tree.GetDepth(nidx));
+      p_feature_set->Reshard(GPUSet(device_id, 1));
+      auto d_feature_set = p_feature_set->DeviceSpan(device_id);
+      auto d_split_candidates =
+          d_split_candidates_all.subspan(i * num_columns, d_feature_set.size());
       DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
 
       // One block for each feature
       int constexpr kBlockThreads = 256;
       EvaluateSplitKernel<kBlockThreads, GradientSumT>
-          <<<uint32_t(feature_set.size()), kBlockThreads, 0>>>
-          (hist.GetNodeHistogram(nidx), d_features, node,
+          <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
+              hist.GetNodeHistogram(nidx), d_feature_set, node,
               d_cut.feature_segments.GetSpan(), d_cut.min_fvalue.GetSpan(),
-              d_cut.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
-              d_split_candidates, value_constraint, monotone_constraints.GetSpan());
+              d_cut.gidx_fvalue_map.GetSpan(), gpu_param, d_split_candidates,
+              value_constraints[nidx], monotone_constraints.GetSpan());
 
-    std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
-    dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
-                             split_candidates.size() * sizeof(DeviceSplitCandidate),
-                             cudaMemcpyDeviceToHost));
-
-    DeviceSplitCandidate best_split;
-    for (auto candidate : split_candidates) {
-      best_split.Update(candidate, param);
+      // Reduce over features to find best feature
+      auto d_result = d_result_all.subspan(i, 1);
+      auto d_cub_memory =
+          d_cub_memory_all.subspan(i * cub_memory_size, cub_memory_size);
+      size_t cub_bytes = d_cub_memory.size() * sizeof(DeviceSplitCandidate);
+      cub::DeviceReduce::Reduce(reinterpret_cast<void*>(d_cub_memory.data()),
+                                cub_bytes, d_split_candidates.data(),
+                                d_result.data(), d_split_candidates.size(), op,
+                                DeviceSplitCandidate(), streams[i]);
     }
 
-    return best_split;
+    dh::safe_cuda(cudaMemcpy(result.data(), d_result_all.data(),
+                             sizeof(DeviceSplitCandidate) * d_result_all.size(),
+                             cudaMemcpyDeviceToHost));
+
+    return std::vector<DeviceSplitCandidate>(result.begin(), result.end());
   }
 
   void BuildHist(int nidx) {
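EvaluateSplits above sizes the cub reduction's temporary storage once with a null-buffer query, then runs one cub::DeviceReduce::Reduce per node on that node's stream with DeviceSplitCandidateReduceOp as the reduction functor. The sketch below shows the same query-allocate-reduce-on-a-stream sequence with a custom functor; Candidate and ArgMaxOp are illustrative stand-ins, not the xgboost types.

// Sketch: cub reduction with a custom functor, run on an explicit stream.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

struct Candidate {
  float gain;
  int findex;
};

struct ArgMaxOp {
  __host__ __device__ Candidate operator()(const Candidate& a,
                                           const Candidate& b) const {
    return a.gain >= b.gain ? a : b;
  }
};

int main() {
  std::vector<Candidate> h_in;
  for (int i = 0; i < 64; ++i) h_in.push_back({static_cast<float>(i % 17), i});

  Candidate *d_in, *d_out;
  cudaMalloc(&d_in, h_in.size() * sizeof(Candidate));
  cudaMalloc(&d_out, sizeof(Candidate));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(Candidate),
             cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Query temp storage size, allocate it once, then reduce on the stream.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Reduce(nullptr, temp_bytes, d_in, d_out,
                            static_cast<int>(h_in.size()), ArgMaxOp(),
                            Candidate{-1.0f, -1}, stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Reduce(d_temp, temp_bytes, d_in, d_out,
                            static_cast<int>(h_in.size()), ArgMaxOp(),
                            Candidate{-1.0f, -1}, stream);

  Candidate best;
  cudaMemcpyAsync(&best, d_out, sizeof(Candidate), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);
  std::printf("best gain %.1f at feature %d\n", best.gain, best.findex);

  cudaFree(d_temp);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
  return 0;
}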
@@ -685,6 +766,10 @@ struct DeviceShard {
     int* d_position = position.Current();
     common::CompressedIterator<uint32_t> d_gidx = gidx;
     size_t row_stride = this->row_stride;
+    if (left_counts.size() <= nidx) {
+      left_counts.resize((nidx * 2) + 1);
+    }
+    int64_t* d_left_count = left_counts.data().get() + nidx;
     // Launch 1 thread for each row
     dh::LaunchN<1, 128>(
         device_id, segment.Size(), [=] __device__(bst_uint idx) {
@@ -710,18 +795,23 @@ struct DeviceShard {
             // Feature is missing
             position = default_dir_left ? left_nidx : right_nidx;
           }
 
+          CountLeft(d_left_count, position, left_nidx);
           d_position[idx] = position;
         });
-    IndicateLeftTransform conversion_op(left_nidx);
-    cub::TransformInputIterator<int, IndicateLeftTransform, int*> left_itr(
-        d_position + segment.begin, conversion_op);
-    int left_count = dh::SumReduction(temp_memory, left_itr, segment.Size());
+
+    // Overlap device to host memory copy (left_count) with sort
+    auto& streams = this->GetStreams(2);
+    auto tmp_pinned = pinned_memory.GetSpan<int64_t>(1);
+    dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t),
+                                  cudaMemcpyDeviceToHost, streams[0]));
+
+    SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
+                        streams[1]);
+
+    dh::safe_cuda(cudaStreamSynchronize(streams[0]));
+    int64_t left_count = tmp_pinned[0];
     CHECK_LE(left_count, segment.Size());
     CHECK_GE(left_count, 0);
 
-    SortPositionAndCopy(segment, left_nidx, right_nidx, left_count);
-
     ridx_segments[left_nidx] =
         Segment(segment.begin, segment.begin + left_count);
     ridx_segments[right_nidx] =
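The block above hides the latency of reading back left_count: the device-to-host copy of a single int64_t is issued on one stream, the sort runs on another, and the host synchronises only with the copy stream. A minimal standalone version of that overlap is sketched below with made-up kernel and variable names; the pinned allocation is what lets the small copy proceed asynchronously.

// Sketch: overlap a small D2H copy (stream 0) with an independent kernel (stream 1).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void BusyKernel(float* data, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    float x = data[idx];
    for (int i = 0; i < 2000; ++i) x = x * 1.0000001f + 1e-7f;
    data[idx] = x;
  }
}

int main() {
  const int n = 1 << 20;
  float* d_data;
  long long* d_count;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMalloc(&d_count, sizeof(long long));
  cudaMemset(d_data, 0, n * sizeof(float));
  cudaMemset(d_count, 0, sizeof(long long));

  long long* h_count;  // pinned host memory
  cudaMallocHost(&h_count, sizeof(long long));

  cudaStream_t streams[2];
  cudaStreamCreate(&streams[0]);
  cudaStreamCreate(&streams[1]);

  // Copy on streams[0] can overlap with the kernel on streams[1].
  cudaMemcpyAsync(h_count, d_count, sizeof(long long), cudaMemcpyDeviceToHost,
                  streams[0]);
  BusyKernel<<<(n + 255) / 256, 256, 0, streams[1]>>>(d_data, n);

  // Only wait for the copy before using the host value; the kernel keeps running.
  cudaStreamSynchronize(streams[0]);
  std::printf("count = %lld\n", *h_count);

  cudaDeviceSynchronize();
  cudaStreamDestroy(streams[0]);
  cudaStreamDestroy(streams[1]);
  cudaFreeHost(h_count);
  cudaFree(d_data);
  cudaFree(d_count);
  return 0;
}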
@@ -729,21 +819,22 @@ struct DeviceShard {
   }
 
   /*! \brief Sort row indices according to position. */
-  void SortPositionAndCopy(const Segment& segment, int left_nidx, int right_nidx,
-                           size_t left_count) {
+  void SortPositionAndCopy(const Segment& segment, int left_nidx,
+                           int right_nidx, int64_t* d_left_count,
+                           cudaStream_t stream) {
     SortPosition(
         &temp_memory,
         common::Span<int>(position.Current() + segment.begin, segment.Size()),
         common::Span<int>(position.other() + segment.begin, segment.Size()),
        common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
        common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
-        left_nidx, right_nidx, left_count);
+        left_nidx, right_nidx, d_left_count, stream);
     // Copy back key/value
     const auto d_position_current = position.Current() + segment.begin;
     const auto d_position_other = position.other() + segment.begin;
     const auto d_ridx_current = ridx.Current() + segment.begin;
     const auto d_ridx_other = ridx.other() + segment.begin;
-    dh::LaunchN(device_id, segment.Size(), [=] __device__(size_t idx) {
+    dh::LaunchN(device_id, segment.Size(), stream, [=] __device__(size_t idx) {
       d_position_current[idx] = d_position_other[idx];
       d_ridx_current[idx] = d_ridx_other[idx];
     });
@@ -752,16 +843,16 @@ struct DeviceShard {
   void UpdatePredictionCache(bst_float* out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id));
     if (!prediction_cache_initialised) {
-      dh::safe_cuda(cudaMemcpyAsync(
-          prediction_cache.Data(), out_preds_d,
-          prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
+      dh::safe_cuda(cudaMemcpyAsync(prediction_cache.Data(), out_preds_d,
+                                    prediction_cache.Size() * sizeof(bst_float),
+                                    cudaMemcpyDefault));
     }
     prediction_cache_initialised = true;
 
     CalcWeightTrainParam param_d(param);
 
-    dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(),
-                                  node_sum_gradients.data(),
+    dh::safe_cuda(
+        cudaMemcpyAsync(node_sum_gradients_d.Data(), node_sum_gradients.data(),
                         sizeof(GradientPair) * node_sum_gradients.size(),
                         cudaMemcpyHostToDevice));
     auto d_position = position.Current();
@@ -840,6 +931,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::InitCompressedData(
     const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
+  dh::safe_cuda(cudaSetDevice(device_id));
   n_bins = hmat.NumBins();
   null_gidx_value = hmat.NumBins();
 
@@ -864,7 +956,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   node_sum_gradients.resize(max_nodes);
   ridx_segments.resize(max_nodes);
 
-  dh::safe_cuda(cudaSetDevice(device_id));
 
   // allocate compressed bin data
   int num_symbols = n_bins + 1;
@@ -1011,12 +1102,15 @@ class GPUHistMakerSpecialised{
     const SparsePage& batch = *batch_iter;
     // Create device shards
     shards_.resize(n_devices);
-    dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+    dh::ExecuteIndexShards(
+        &shards_,
+        [&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(dist_.Devices().DeviceId(i)));
           size_t start = dist_.ShardStart(info_->num_row_, i);
           size_t size = dist_.ShardSize(info_->num_row_, i);
-          shard = std::unique_ptr<DeviceShard<GradientSumT>>
-            (new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(i),
-                                           start, start + size, param_));
+          shard = std::unique_ptr<DeviceShard<GradientSumT>>(
+              new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(i), start,
+                                            start + size, param_));
           shard->InitRowPtrs(batch);
         });
 
@@ -1027,8 +1121,10 @@ class GPUHistMakerSpecialised{
     monitor_.StopCuda("Quantiles");
 
     monitor_.StartCuda("BinningCompression");
-    dh::ExecuteIndexShards(&shards_, [&](int idx,
-        std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+    dh::ExecuteIndexShards(
+        &shards_,
+        [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
           shard->InitCompressedData(hmat_, batch);
         });
     monitor_.StopCuda("BinningCompression");
@@ -1056,6 +1152,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
           shard->Reset(gpair);
         });
     monitor_.StopCuda("InitDataReset");
@@ -1110,6 +1207,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
          shard->BuildHist(build_hist_nidx);
        });
 
@@ -1127,6 +1225,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
           shard->SubtractionTrick(nidx_parent, build_hist_nidx,
                                   subtraction_trick_nidx);
         });
@@ -1135,6 +1234,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
          shard->BuildHist(subtraction_trick_nidx);
        });
 
@@ -1142,10 +1242,12 @@ class GPUHistMakerSpecialised{
       }
     }
 
-  DeviceSplitCandidate EvaluateSplit(int nidx, RegTree* p_tree) {
-    return shards_.front()->EvaluateSplit(
-        nidx, *column_sampler_.GetFeatureSet(p_tree->GetDepth(nidx)),
-        node_value_constraints_[nidx]);
+  std::vector<DeviceSplitCandidate> EvaluateSplits(std::vector<int> nidx,
+                                                   RegTree* p_tree) {
+    dh::safe_cuda(cudaSetDevice(shards_.front()->device_id));
+    return shards_.front()->EvaluateSplits(nidx, *p_tree, &column_sampler_,
+                                           node_value_constraints_,
+                                           info_->num_col_);
   }
 
   void InitRoot(RegTree* p_tree) {
@@ -1171,6 +1273,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
          shard->BuildHist(kRootNIdx);
        });
 
@@ -1191,9 +1294,9 @@ class GPUHistMakerSpecialised{
     node_value_constraints_.resize(p_tree->GetNodes().size());
 
     // Generate first split
-    auto split = this->EvaluateSplit(kRootNIdx, p_tree);
+    auto split = this->EvaluateSplits({ kRootNIdx }, p_tree);
     qexpand_->push(
-        ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
+        ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split.at(0), 0));
   }
 
   void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) {
@@ -1219,6 +1322,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
           shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx,
                                 default_dir_left, is_dense, fidx_begin,
                                 fidx_end);
@@ -1296,14 +1400,14 @@ class GPUHistMakerSpecialised{
       monitor_.StopCuda("BuildHist");
 
       monitor_.StartCuda("EvaluateSplits");
-      auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree);
-      auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree);
+      auto splits =
+          this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
       qexpand_->push(ExpandEntry(left_child_nidx,
-                                 tree.GetDepth(left_child_nidx),
-                                 left_child_split, timestamp++));
+                                 tree.GetDepth(left_child_nidx), splits.at(0),
+                                 timestamp++));
       qexpand_->push(ExpandEntry(right_child_nidx,
                                  tree.GetDepth(right_child_nidx),
-                                 right_child_split, timestamp++));
+                                 splits.at(1), timestamp++));
       monitor_.StopCuda("EvaluateSplits");
     }
   }
@@ -1319,6 +1423,7 @@ class GPUHistMakerSpecialised{
     dh::ExecuteIndexShards(
         &shards_,
        [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
+          dh::safe_cuda(cudaSetDevice(shard->device_id));
          shard->UpdatePredictionCache(
              p_out_preds->DevicePointer(shard->device_id));
        });
 
@@ -529,7 +529,7 @@ void QuantileHistMaker::Builder::EvaluateSplit(const int nid,
   // start enumeration
   const MetaInfo& info = fmat.Info();
   auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
-  const auto& feature_set = *p_feature_set;
+  const auto& feature_set = p_feature_set->HostVector();
   const auto nfeature = static_cast<bst_uint>(feature_set.size());
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
   best_split_tloc_.resize(nthread);
 
@@ -11,38 +11,40 @@ TEST(ColumnSampler, Test) {
   // No node sampling
   cs.Init(n, 1.0f, 0.5f, 0.5f);
   auto set0 = *cs.GetFeatureSet(0);
-  ASSERT_EQ(set0.size(), 32);
+  ASSERT_EQ(set0.Size(), 32);
 
   auto set1 = *cs.GetFeatureSet(0);
-  ASSERT_EQ(set0, set1);
+
+  ASSERT_EQ(set0.HostVector(), set1.HostVector());
+
   auto set2 = *cs.GetFeatureSet(1);
-  ASSERT_NE(set1, set2);
-  ASSERT_EQ(set2.size(), 32);
+  ASSERT_NE(set1.HostVector(), set2.HostVector());
+  ASSERT_EQ(set2.Size(), 32);
 
   // Node sampling
   cs.Init(n, 0.5f, 1.0f, 0.5f);
   auto set3 = *cs.GetFeatureSet(0);
-  ASSERT_EQ(set3.size(), 32);
+  ASSERT_EQ(set3.Size(), 32);
 
   auto set4 = *cs.GetFeatureSet(0);
-  ASSERT_NE(set3, set4);
-  ASSERT_EQ(set4.size(), 32);
+
+  ASSERT_NE(set3.HostVector(), set4.HostVector());
+  ASSERT_EQ(set4.Size(), 32);
 
   // No level or node sampling, should be the same at different depth
   cs.Init(n, 1.0f, 1.0f, 0.5f);
-  ASSERT_EQ(*cs.GetFeatureSet(0), *cs.GetFeatureSet(1));
+  ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector());
 
   cs.Init(n, 1.0f, 1.0f, 1.0f);
   auto set5 = *cs.GetFeatureSet(0);
-  ASSERT_EQ(set5.size(), n);
+  ASSERT_EQ(set5.Size(), n);
   cs.Init(n, 1.0f, 1.0f, 1.0f);
   auto set6 = *cs.GetFeatureSet(0);
-  ASSERT_EQ(set5, set6);
+  ASSERT_EQ(set5.HostVector(), set6.HostVector());
 
   // Should always be a minimum of one feature
   cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
-  ASSERT_EQ(cs.GetFeatureSet(0)->size(), 1);
+  ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
 
 }
 } // namespace common
 
@@ -304,11 +304,13 @@ TEST(GpuHist, EvaluateSplits) {
   hist_maker.node_value_constraints_[0].lower_bound = -1.0;
   hist_maker.node_value_constraints_[0].upper_bound = 1.0;
 
-  DeviceSplitCandidate res =
-      hist_maker.EvaluateSplit(0, &tree);
+  std::vector<DeviceSplitCandidate> res =
+      hist_maker.EvaluateSplits({ 0,0 }, &tree);
 
-  ASSERT_EQ(res.findex, 7);
-  ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
+  ASSERT_EQ(res[0].findex, 7);
+  ASSERT_EQ(res[1].findex, 7);
+  ASSERT_NEAR(res[0].fvalue, 0.26, xgboost::kRtEps);
+  ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
 }
 
 TEST(GpuHist, ApplySplit) {
@@ -400,7 +402,9 @@ TEST(GpuHist, ApplySplit) {
 
 void TestSortPosition(const std::vector<int>& position_in, int left_idx,
                       int right_idx) {
-  int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
+  std::vector<int64_t> left_count = {
+      std::count(position_in.begin(), position_in.end(), left_idx)};
+  thrust::device_vector<int64_t> d_left_count = left_count;
   thrust::device_vector<int> position = position_in;
   thrust::device_vector<int> position_out(position.size());
 
@@ -413,7 +417,7 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
       common::Span<int>(position_out.data().get(), position_out.size()),
       common::Span<bst_uint>(ridx.data().get(), ridx.size()),
       common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
-      right_idx, left_count);
+      right_idx, d_left_count.data().get());
   thrust::host_vector<int> position_result = position_out;
   thrust::host_vector<int> ridx_result = ridx_out;
 
@@ -421,9 +425,9 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
   EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
   // Check row indices are sorted inside left and right segment
   EXPECT_TRUE(
-      std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
+      std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
   EXPECT_TRUE(
-      std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));
+      std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
 
   // Check key value pairs are the same
   for (auto i = 0ull; i < ridx_result.size(); i++) {