From 00465d243d02ec4b8030e9848026b42ec6799f91 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 20 Mar 2019 13:30:06 +1300 Subject: [PATCH] Optimisations for gpu_hist. (#4248) * Optimisations for gpu_hist. * Use streams to overlap operations. * ColumnSampler now uses HostDeviceVector to prevent repeatedly copying feature vectors to the device. --- src/common/device_helpers.cuh | 38 ++++- src/common/random.h | 37 +++-- src/tree/updater_colmaker.cc | 5 +- src/tree/updater_gpu_common.cuh | 12 ++ src/tree/updater_gpu_hist.cu | 259 +++++++++++++++++++++--------- src/tree/updater_quantile_hist.cc | 2 +- tests/cpp/common/test_random.cc | 24 +-- tests/cpp/tree/test_gpu_hist.cu | 20 ++- 8 files changed, 278 insertions(+), 119 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index a26342411..325130c28 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -208,16 +208,23 @@ __global__ void LaunchNKernel(int device_idx, size_t begin, size_t end, } template -inline void LaunchN(int device_idx, size_t n, L lambda) { +inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) { if (n == 0) { return; } safe_cuda(cudaSetDevice(device_idx)); + const int GRID_SIZE = static_cast(DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS)); - LaunchNKernel<<>>(static_cast(0), n, - lambda); + LaunchNKernel<<>>(static_cast(0), + n, lambda); +} + +// Default stream version +template +inline void LaunchN(int device_idx, size_t n, L lambda) { + LaunchN(device_idx, n, nullptr, lambda); } /* @@ -500,6 +507,31 @@ class BulkAllocator { } }; +// Keep track of pinned memory allocation +struct PinnedMemory { + void *temp_storage{nullptr}; + size_t temp_storage_bytes{0}; + + ~PinnedMemory() { Free(); } + + template + xgboost::common::Span GetSpan(size_t size) { + size_t num_bytes = size * sizeof(T); + if (num_bytes > temp_storage_bytes) { + Free(); + safe_cuda(cudaMallocHost(&temp_storage, num_bytes)); + temp_storage_bytes = num_bytes; + } + return xgboost::common::Span(static_cast(temp_storage), size); + } + + void Free() { + if (temp_storage != nullptr) { + safe_cuda(cudaFreeHost(temp_storage)); + } + } +}; + // Keep track of cub library device allocation struct CubMemory { void *d_temp_storage; diff --git a/src/common/random.h b/src/common/random.h index 00b7046de..5e28e8878 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -18,6 +18,7 @@ #include #include "io.h" +#include "host_device_vector.h" namespace xgboost { namespace common { @@ -84,26 +85,29 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*) */ class ColumnSampler { - std::shared_ptr> feature_set_tree_; - std::map>> feature_set_level_; + std::shared_ptr> feature_set_tree_; + std::map>> feature_set_level_; float colsample_bylevel_{1.0f}; float colsample_bytree_{1.0f}; float colsample_bynode_{1.0f}; GlobalRandomEngine rng_; - std::shared_ptr> ColSample - (std::shared_ptr> p_features, float colsample) { + std::shared_ptr> ColSample( + std::shared_ptr> p_features, float colsample) { if (colsample == 1.0f) return p_features; - const auto& features = *p_features; + const auto& features = p_features->HostVector(); CHECK_GT(features.size(), 0); int n = std::max(1, static_cast(colsample * features.size())); - auto p_new_features = std::make_shared>(); + auto p_new_features = std::make_shared>(); auto& new_features = *p_new_features; - new_features.resize(features.size()); - std::copy(features.begin(), features.end(), new_features.begin()); - std::shuffle(new_features.begin(), new_features.end(), rng_); - new_features.resize(n); - std::sort(new_features.begin(), new_features.end()); + new_features.Resize(features.size()); + std::copy(features.begin(), features.end(), + new_features.HostVector().begin()); + std::shuffle(new_features.HostVector().begin(), + new_features.HostVector().end(), rng_); + new_features.Resize(n); + std::sort(new_features.HostVector().begin(), + new_features.HostVector().end()); return p_new_features; } @@ -135,13 +139,14 @@ class ColumnSampler { colsample_bynode_ = colsample_bynode; if (feature_set_tree_ == nullptr) { - feature_set_tree_ = std::make_shared>(); + feature_set_tree_ = std::make_shared>(); } Reset(); int begin_idx = skip_index_0 ? 1 : 0; - feature_set_tree_->resize(num_col - begin_idx); - std::iota(feature_set_tree_->begin(), feature_set_tree_->end(), begin_idx); + feature_set_tree_->Resize(num_col - begin_idx); + std::iota(feature_set_tree_->HostVector().begin(), + feature_set_tree_->HostVector().end(), begin_idx); feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_); } @@ -150,7 +155,7 @@ class ColumnSampler { * \brief Resets this object. */ void Reset() { - feature_set_tree_->clear(); + feature_set_tree_->Resize(0); feature_set_level_.clear(); } @@ -165,7 +170,7 @@ class ColumnSampler { * construction of each tree node, and must be called the same number of times in each * process and with the same parameters to return the same feature set across processes. */ - std::shared_ptr> GetFeatureSet(int depth) { + std::shared_ptr> GetFeatureSet(int depth) { if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) { return feature_set_tree_; } diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 24b47ba65..32a66df62 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -632,10 +632,9 @@ class ColMaker: public TreeUpdater { const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree) { - auto p_feature_set = column_sampler_.GetFeatureSet(depth); - const auto& feat_set = *p_feature_set; + auto feat_set = column_sampler_.GetFeatureSet(depth); for (const auto &batch : p_fmat->GetSortedColumnBatches()) { - this->UpdateSolution(batch, feat_set, gpair, p_fmat); + this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat); } // after this each thread's stemp will get the best candidates, aggregate results this->SyncBestSolution(qexpand); diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index bdf309d47..4c928c26e 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -125,6 +125,18 @@ struct DeviceSplitCandidate { XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; } }; +struct DeviceSplitCandidateReduceOp { + GPUTrainingParam param; + DeviceSplitCandidateReduceOp(GPUTrainingParam param) : param(param) {} + XGBOOST_DEVICE DeviceSplitCandidate operator()( + const DeviceSplitCandidate& a, const DeviceSplitCandidate& b) const { + DeviceSplitCandidate best; + best.Update(a, param); + best.Update(b, param); + return best; + } +}; + struct DeviceNodeStats { GradientPair sum_gradients; float root_gain; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 11ba665ff..eeb211265 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -306,8 +306,8 @@ class DeviceHistogram { void AllocateHistogram(int nidx) { if (HistogramExists(nidx)) return; - size_t current_size = - nidx_map_.size() * n_bins_ * 2; // Number of items currently used in data + size_t current_size = nidx_map_.size() * n_bins_ * + 2; // Number of items currently used in data dh::safe_cuda(cudaSetDevice(device_id_)); if (data_.size() >= kStopGrowingSize) { // Recycle histogram memory @@ -452,7 +452,8 @@ struct IndicateLeftTransform { void SortPosition(dh::CubMemory* temp_memory, common::Span position, common::Span position_out, common::Span ridx, common::Span ridx_out, int left_nidx, - int right_nidx, int64_t left_count) { + int right_nidx, int64_t* d_left_count, + cudaStream_t stream = nullptr) { auto d_position_out = position_out.data(); auto d_position_in = position.data(); auto d_ridx_out = ridx_out.data(); @@ -462,7 +463,7 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span position, if (d_position_in[idx] == left_nidx) { scatter_address = ex_scan_result; } else { - scatter_address = (idx - ex_scan_result) + left_count; + scatter_address = (idx - ex_scan_result) + *d_left_count; } d_position_out[scatter_address] = d_position_in[idx]; d_ridx_out[scatter_address] = d_ridx_in[idx]; @@ -474,11 +475,20 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span position, dh::DiscardLambdaItr out_itr(write_results); size_t temp_storage_bytes = 0; cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr, - position.size()); + position.size(), stream); temp_memory->LazyAllocate(temp_storage_bytes); cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage, temp_memory->temp_storage_bytes, in_itr, - out_itr, position.size()); + out_itr, position.size(), stream); +} + +/*! \brief Count how many rows are assigned to left node. */ +__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { + unsigned ballot = __ballot(val == left_nidx); + if (threadIdx.x % 32 == 0) { + atomicAdd(reinterpret_cast(d_count), // NOLINT + static_cast(__popc(ballot))); // NOLINT + } } template @@ -539,6 +549,8 @@ struct DeviceShard { thrust::device_vector row_ptrs; /*! \brief On-device feature set, only actually used on one of the devices */ thrust::device_vector feature_set_d; + thrust::device_vector + left_counts; // Useful to keep a bunch of zeroed memory for sort position /*! The row offset for this shard. */ bst_uint row_begin_idx; bst_uint row_end_idx; @@ -548,6 +560,9 @@ struct DeviceShard { bool prediction_cache_initialised; dh::CubMemory temp_memory; + dh::PinnedMemory pinned_memory; + + std::vector streams; std::unique_ptr> hist_builder; @@ -597,7 +612,30 @@ struct DeviceShard { void CreateHistIndices(const SparsePage& row_batch); - ~DeviceShard() = default; + ~DeviceShard() { + dh::safe_cuda(cudaSetDevice(device_id)); + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + } + + // Get vector of at least n initialised streams + std::vector& GetStreams(int n) { + if (n > streams.size()) { + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + + streams.clear(); + streams.resize(n); + + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamCreate(&stream)); + } + } + + return streams; + } // Reset values for each update iteration void Reset(HostDeviceVector* dh_gpair) { @@ -605,7 +643,12 @@ struct DeviceShard { position.CurrentDVec().Fill(0); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); - + if (left_counts.size() < 256) { + left_counts.resize(256); + } else { + dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0, + sizeof(int64_t) * left_counts.size())); + } thrust::sequence(ridx.CurrentDVec().tbegin(), ridx.CurrentDVec().tend()); std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); @@ -616,38 +659,76 @@ struct DeviceShard { hist.Reset(); } - DeviceSplitCandidate EvaluateSplit(int nidx, - const std::vector& feature_set, - ValueConstraint value_constraint) { + std::vector EvaluateSplits( + std::vector nidxs, const RegTree& tree, + common::ColumnSampler* column_sampler, + const std::vector& value_constraints, + size_t num_columns) { dh::safe_cuda(cudaSetDevice(device_id)); - auto d_split_candidates = temp_memory.GetSpan(feature_set.size()); - feature_set_d.resize(feature_set.size()); - auto d_features = common::Span(feature_set_d.data().get(), - feature_set_d.size()); - dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(), - d_features.size_bytes(), cudaMemcpyDefault)); - DeviceNodeStats node(node_sum_gradients[nidx], nidx, param); + auto result = pinned_memory.GetSpan(nidxs.size()); - // One block for each feature - int constexpr kBlockThreads = 256; - EvaluateSplitKernel - <<>> - (hist.GetNodeHistogram(nidx), d_features, node, - d_cut.feature_segments.GetSpan(), d_cut.min_fvalue.GetSpan(), - d_cut.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param), - d_split_candidates, value_constraint, monotone_constraints.GetSpan()); + // Work out cub temporary memory requirement + GPUTrainingParam gpu_param(param); + DeviceSplitCandidateReduceOp op(gpu_param); + size_t temp_storage_bytes; + DeviceSplitCandidate*dummy = nullptr; + cub::DeviceReduce::Reduce( + nullptr, temp_storage_bytes, dummy, + dummy, num_columns, op, + DeviceSplitCandidate()); + // size in terms of DeviceSplitCandidate + size_t cub_memory_size = + std::ceil(static_cast(temp_storage_bytes) / + sizeof(DeviceSplitCandidate)); - std::vector split_candidates(feature_set.size()); - dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(), - split_candidates.size() * sizeof(DeviceSplitCandidate), - cudaMemcpyDeviceToHost)); + // Allocate enough temporary memory + // Result for each nidx + // + intermediate result for each column + // + cub reduce memory + auto temp_span = temp_memory.GetSpan( + nidxs.size() + nidxs.size() * num_columns +cub_memory_size*nidxs.size()); + auto d_result_all = temp_span.subspan(0, nidxs.size()); + auto d_split_candidates_all = + temp_span.subspan(d_result_all.size(), nidxs.size() * num_columns); + auto d_cub_memory_all = + temp_span.subspan(d_result_all.size() + d_split_candidates_all.size(), + cub_memory_size * nidxs.size()); - DeviceSplitCandidate best_split; - for (auto candidate : split_candidates) { - best_split.Update(candidate, param); + auto& streams = this->GetStreams(nidxs.size()); + for (auto i = 0ull; i < nidxs.size(); i++) { + auto nidx = nidxs[i]; + auto p_feature_set = column_sampler->GetFeatureSet(tree.GetDepth(nidx)); + p_feature_set->Reshard(GPUSet(device_id, 1)); + auto d_feature_set = p_feature_set->DeviceSpan(device_id); + auto d_split_candidates = + d_split_candidates_all.subspan(i * num_columns, d_feature_set.size()); + DeviceNodeStats node(node_sum_gradients[nidx], nidx, param); + + // One block for each feature + int constexpr kBlockThreads = 256; + EvaluateSplitKernel + <<>>( + hist.GetNodeHistogram(nidx), d_feature_set, node, + d_cut.feature_segments.GetSpan(), d_cut.min_fvalue.GetSpan(), + d_cut.gidx_fvalue_map.GetSpan(), gpu_param, d_split_candidates, + value_constraints[nidx], monotone_constraints.GetSpan()); + + // Reduce over features to find best feature + auto d_result = d_result_all.subspan(i, 1); + auto d_cub_memory = + d_cub_memory_all.subspan(i * cub_memory_size, cub_memory_size); + size_t cub_bytes = d_cub_memory.size() * sizeof(DeviceSplitCandidate); + cub::DeviceReduce::Reduce(reinterpret_cast(d_cub_memory.data()), + cub_bytes, d_split_candidates.data(), + d_result.data(), d_split_candidates.size(), op, + DeviceSplitCandidate(), streams[i]); } - return best_split; + dh::safe_cuda(cudaMemcpy(result.data(), d_result_all.data(), + sizeof(DeviceSplitCandidate) * d_result_all.size(), + cudaMemcpyDeviceToHost)); + + return std::vector(result.begin(), result.end()); } void BuildHist(int nidx) { @@ -685,6 +766,10 @@ struct DeviceShard { int* d_position = position.Current(); common::CompressedIterator d_gidx = gidx; size_t row_stride = this->row_stride; + if (left_counts.size() <= nidx) { + left_counts.resize((nidx * 2) + 1); + } + int64_t* d_left_count = left_counts.data().get() + nidx; // Launch 1 thread for each row dh::LaunchN<1, 128>( device_id, segment.Size(), [=] __device__(bst_uint idx) { @@ -710,18 +795,23 @@ struct DeviceShard { // Feature is missing position = default_dir_left ? left_nidx : right_nidx; } - + CountLeft(d_left_count, position, left_nidx); d_position[idx] = position; }); - IndicateLeftTransform conversion_op(left_nidx); - cub::TransformInputIterator left_itr( - d_position + segment.begin, conversion_op); - int left_count = dh::SumReduction(temp_memory, left_itr, segment.Size()); + + // Overlap device to host memory copy (left_count) with sort + auto& streams = this->GetStreams(2); + auto tmp_pinned = pinned_memory.GetSpan(1); + dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t), + cudaMemcpyDeviceToHost, streams[0])); + + SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, + streams[1]); + + dh::safe_cuda(cudaStreamSynchronize(streams[0])); + int64_t left_count = tmp_pinned[0]; CHECK_LE(left_count, segment.Size()); CHECK_GE(left_count, 0); - - SortPositionAndCopy(segment, left_nidx, right_nidx, left_count); - ridx_segments[left_nidx] = Segment(segment.begin, segment.begin + left_count); ridx_segments[right_nidx] = @@ -729,21 +819,22 @@ struct DeviceShard { } /*! \brief Sort row indices according to position. */ - void SortPositionAndCopy(const Segment& segment, int left_nidx, int right_nidx, - size_t left_count) { + void SortPositionAndCopy(const Segment& segment, int left_nidx, + int right_nidx, int64_t* d_left_count, + cudaStream_t stream) { SortPosition( &temp_memory, common::Span(position.Current() + segment.begin, segment.Size()), common::Span(position.other() + segment.begin, segment.Size()), common::Span(ridx.Current() + segment.begin, segment.Size()), common::Span(ridx.other() + segment.begin, segment.Size()), - left_nidx, right_nidx, left_count); + left_nidx, right_nidx, d_left_count, stream); // Copy back key/value const auto d_position_current = position.Current() + segment.begin; const auto d_position_other = position.other() + segment.begin; const auto d_ridx_current = ridx.Current() + segment.begin; const auto d_ridx_other = ridx.other() + segment.begin; - dh::LaunchN(device_id, segment.Size(), [=] __device__(size_t idx) { + dh::LaunchN(device_id, segment.Size(), stream, [=] __device__(size_t idx) { d_position_current[idx] = d_position_other[idx]; d_ridx_current[idx] = d_ridx_other[idx]; }); @@ -752,18 +843,18 @@ struct DeviceShard { void UpdatePredictionCache(bst_float* out_preds_d) { dh::safe_cuda(cudaSetDevice(device_id)); if (!prediction_cache_initialised) { - dh::safe_cuda(cudaMemcpyAsync( - prediction_cache.Data(), out_preds_d, - prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault)); + dh::safe_cuda(cudaMemcpyAsync(prediction_cache.Data(), out_preds_d, + prediction_cache.Size() * sizeof(bst_float), + cudaMemcpyDefault)); } prediction_cache_initialised = true; CalcWeightTrainParam param_d(param); - dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(), - node_sum_gradients.data(), - sizeof(GradientPair) * node_sum_gradients.size(), - cudaMemcpyHostToDevice)); + dh::safe_cuda( + cudaMemcpyAsync(node_sum_gradients_d.Data(), node_sum_gradients.data(), + sizeof(GradientPair) * node_sum_gradients.size(), + cudaMemcpyHostToDevice)); auto d_position = position.Current(); auto d_ridx = ridx.Current(); auto d_node_sum_gradients = node_sum_gradients_d.Data(); @@ -840,6 +931,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase { template inline void DeviceShard::InitCompressedData( const common::HistCutMatrix& hmat, const SparsePage& row_batch) { + dh::safe_cuda(cudaSetDevice(device_id)); n_bins = hmat.NumBins(); null_gidx_value = hmat.NumBins(); @@ -864,7 +956,6 @@ inline void DeviceShard::InitCompressedData( node_sum_gradients.resize(max_nodes); ridx_segments.resize(max_nodes); - dh::safe_cuda(cudaSetDevice(device_id)); // allocate compressed bin data int num_symbols = n_bins + 1; @@ -1011,14 +1102,17 @@ class GPUHistMakerSpecialised{ const SparsePage& batch = *batch_iter; // Create device shards shards_.resize(n_devices); - dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr>& shard) { - size_t start = dist_.ShardStart(info_->num_row_, i); - size_t size = dist_.ShardSize(info_->num_row_, i); - shard = std::unique_ptr> - (new DeviceShard(dist_.Devices().DeviceId(i), - start, start + size, param_)); - shard->InitRowPtrs(batch); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int i, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(dist_.Devices().DeviceId(i))); + size_t start = dist_.ShardStart(info_->num_row_, i); + size_t size = dist_.ShardSize(info_->num_row_, i); + shard = std::unique_ptr>( + new DeviceShard(dist_.Devices().DeviceId(i), start, + start + size, param_)); + shard->InitRowPtrs(batch); + }); // Find the cuts. monitor_.StartCuda("Quantiles"); @@ -1027,10 +1121,12 @@ class GPUHistMakerSpecialised{ monitor_.StopCuda("Quantiles"); monitor_.StartCuda("BinningCompression"); - dh::ExecuteIndexShards(&shards_, [&](int idx, - std::unique_ptr>& shard) { - shard->InitCompressedData(hmat_, batch); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); + shard->InitCompressedData(hmat_, batch); + }); monitor_.StopCuda("BinningCompression"); ++batch_iter; CHECK(batch_iter.AtEnd()) << "External memory not supported"; @@ -1056,6 +1152,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->Reset(gpair); }); monitor_.StopCuda("InitDataReset"); @@ -1110,6 +1207,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->BuildHist(build_hist_nidx); }); @@ -1127,6 +1225,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->SubtractionTrick(nidx_parent, build_hist_nidx, subtraction_trick_nidx); }); @@ -1135,6 +1234,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->BuildHist(subtraction_trick_nidx); }); @@ -1142,10 +1242,12 @@ class GPUHistMakerSpecialised{ } } - DeviceSplitCandidate EvaluateSplit(int nidx, RegTree* p_tree) { - return shards_.front()->EvaluateSplit( - nidx, *column_sampler_.GetFeatureSet(p_tree->GetDepth(nidx)), - node_value_constraints_[nidx]); + std::vector EvaluateSplits(std::vector nidx, + RegTree* p_tree) { + dh::safe_cuda(cudaSetDevice(shards_.front()->device_id)); + return shards_.front()->EvaluateSplits(nidx, *p_tree, &column_sampler_, + node_value_constraints_, + info_->num_col_); } void InitRoot(RegTree* p_tree) { @@ -1171,6 +1273,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->BuildHist(kRootNIdx); }); @@ -1191,9 +1294,9 @@ class GPUHistMakerSpecialised{ node_value_constraints_.resize(p_tree->GetNodes().size()); // Generate first split - auto split = this->EvaluateSplit(kRootNIdx, p_tree); + auto split = this->EvaluateSplits({ kRootNIdx }, p_tree); qexpand_->push( - ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0)); + ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split.at(0), 0)); } void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) { @@ -1219,6 +1322,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx, default_dir_left, is_dense, fidx_begin, fidx_end); @@ -1296,14 +1400,14 @@ class GPUHistMakerSpecialised{ monitor_.StopCuda("BuildHist"); monitor_.StartCuda("EvaluateSplits"); - auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree); - auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree); + auto splits = + this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree); qexpand_->push(ExpandEntry(left_child_nidx, - tree.GetDepth(left_child_nidx), - left_child_split, timestamp++)); + tree.GetDepth(left_child_nidx), splits.at(0), + timestamp++)); qexpand_->push(ExpandEntry(right_child_nidx, tree.GetDepth(right_child_nidx), - right_child_split, timestamp++)); + splits.at(1), timestamp++)); monitor_.StopCuda("EvaluateSplits"); } } @@ -1319,6 +1423,7 @@ class GPUHistMakerSpecialised{ dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); shard->UpdatePredictionCache( p_out_preds->DevicePointer(shard->device_id)); }); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 49a28fb15..140604efb 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -529,7 +529,7 @@ void QuantileHistMaker::Builder::EvaluateSplit(const int nid, // start enumeration const MetaInfo& info = fmat.Info(); auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid)); - const auto& feature_set = *p_feature_set; + const auto& feature_set = p_feature_set->HostVector(); const auto nfeature = static_cast(feature_set.size()); const auto nthread = static_cast(this->nthread_); best_split_tloc_.resize(nthread); diff --git a/tests/cpp/common/test_random.cc b/tests/cpp/common/test_random.cc index 7e92e2658..702f29907 100644 --- a/tests/cpp/common/test_random.cc +++ b/tests/cpp/common/test_random.cc @@ -11,38 +11,40 @@ TEST(ColumnSampler, Test) { // No node sampling cs.Init(n, 1.0f, 0.5f, 0.5f); auto set0 = *cs.GetFeatureSet(0); - ASSERT_EQ(set0.size(), 32); + ASSERT_EQ(set0.Size(), 32); auto set1 = *cs.GetFeatureSet(0); - ASSERT_EQ(set0, set1); + + ASSERT_EQ(set0.HostVector(), set1.HostVector()); auto set2 = *cs.GetFeatureSet(1); - ASSERT_NE(set1, set2); - ASSERT_EQ(set2.size(), 32); + ASSERT_NE(set1.HostVector(), set2.HostVector()); + ASSERT_EQ(set2.Size(), 32); // Node sampling cs.Init(n, 0.5f, 1.0f, 0.5f); auto set3 = *cs.GetFeatureSet(0); - ASSERT_EQ(set3.size(), 32); + ASSERT_EQ(set3.Size(), 32); auto set4 = *cs.GetFeatureSet(0); - ASSERT_NE(set3, set4); - ASSERT_EQ(set4.size(), 32); + + ASSERT_NE(set3.HostVector(), set4.HostVector()); + ASSERT_EQ(set4.Size(), 32); // No level or node sampling, should be the same at different depth cs.Init(n, 1.0f, 1.0f, 0.5f); - ASSERT_EQ(*cs.GetFeatureSet(0), *cs.GetFeatureSet(1)); + ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector()); cs.Init(n, 1.0f, 1.0f, 1.0f); auto set5 = *cs.GetFeatureSet(0); - ASSERT_EQ(set5.size(), n); + ASSERT_EQ(set5.Size(), n); cs.Init(n, 1.0f, 1.0f, 1.0f); auto set6 = *cs.GetFeatureSet(0); - ASSERT_EQ(set5, set6); + ASSERT_EQ(set5.HostVector(), set6.HostVector()); // Should always be a minimum of one feature cs.Init(n, 1e-16f, 1e-16f, 1e-16f); - ASSERT_EQ(cs.GetFeatureSet(0)->size(), 1); + ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1); } } // namespace common diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 21ad0efe8..12ef32917 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -304,11 +304,13 @@ TEST(GpuHist, EvaluateSplits) { hist_maker.node_value_constraints_[0].lower_bound = -1.0; hist_maker.node_value_constraints_[0].upper_bound = 1.0; - DeviceSplitCandidate res = - hist_maker.EvaluateSplit(0, &tree); + std::vector res = + hist_maker.EvaluateSplits({ 0,0 }, &tree); - ASSERT_EQ(res.findex, 7); - ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps); + ASSERT_EQ(res[0].findex, 7); + ASSERT_EQ(res[1].findex, 7); + ASSERT_NEAR(res[0].fvalue, 0.26, xgboost::kRtEps); + ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps); } TEST(GpuHist, ApplySplit) { @@ -400,7 +402,9 @@ TEST(GpuHist, ApplySplit) { void TestSortPosition(const std::vector& position_in, int left_idx, int right_idx) { - int left_count = std::count(position_in.begin(), position_in.end(), left_idx); + std::vector left_count = { + std::count(position_in.begin(), position_in.end(), left_idx)}; + thrust::device_vector d_left_count = left_count; thrust::device_vector position = position_in; thrust::device_vector position_out(position.size()); @@ -413,7 +417,7 @@ void TestSortPosition(const std::vector& position_in, int left_idx, common::Span(position_out.data().get(), position_out.size()), common::Span(ridx.data().get(), ridx.size()), common::Span(ridx_out.data().get(), ridx_out.size()), left_idx, - right_idx, left_count); + right_idx, d_left_count.data().get()); thrust::host_vector position_result = position_out; thrust::host_vector ridx_result = ridx_out; @@ -421,9 +425,9 @@ void TestSortPosition(const std::vector& position_in, int left_idx, EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end())); // Check row indices are sorted inside left and right segment EXPECT_TRUE( - std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count)); + std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0])); EXPECT_TRUE( - std::is_sorted(ridx_result.begin() + left_count, ridx_result.end())); + std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end())); // Check key value pairs are the same for (auto i = 0ull; i < ridx_result.size(); i++) {