diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index c0bea5980..028d0540c 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -402,7 +402,6 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat, void GHistBuilder::BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - const std::vector& feat_set, GHistRow hist) { data_.resize(nbins_ * nthread_, GHistEntry()); std::fill(data_.begin(), data_.end(), GHistEntry()); @@ -461,7 +460,6 @@ void GHistBuilder::BuildHist(const std::vector& gpair, void GHistBuilder::BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexBlockMatrix& gmatb, - const std::vector& feat_set, GHistRow hist) { constexpr int kUnroll = 8; // loop unrolling factor const size_t nblock = gmatb.GetNumBlock(); diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 034b8f386..ff5a8b8c3 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -266,13 +266,11 @@ class GHistBuilder { void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - const std::vector& feat_set, GHistRow hist); // same, with feature grouping void BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexBlockMatrix& gmatb, - const std::vector& feat_set, GHistRow hist); // construct a histogram via subtraction trick void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent); diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index f30196b1b..e306119f0 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -102,6 +102,7 @@ void HostDeviceVector::Reshard(GPUSet devices) { } template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; +template class HostDeviceVector; } // namespace xgboost diff --git 
a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index e0d7dbb85..17d4953d5 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -77,7 +77,9 @@ struct HostDeviceVectorImpl { void LazySyncHost() { dh::safe_cuda(cudaSetDevice(device_)); - thrust::copy(data_.begin(), data_.end(), vec_->data_h_.begin() + start_); + dh::safe_cuda( /* copy the whole device shard back into its own slot of the host vector; the slot begins at start_, so the offset belongs on the host (destination) pointer, mirroring LazySyncDevice() */ + cudaMemcpy(vec_->data_h_.data() + start_, data_.data().get(), + data_.size() * sizeof(T), cudaMemcpyDeviceToHost)); on_d_ = false; } @@ -90,8 +92,9 @@ struct HostDeviceVectorImpl { size_t size_d = ShardSize(size_h, ndevices, index_); dh::safe_cuda(cudaSetDevice(device_)); data_.resize(size_d); - thrust::copy(vec_->data_h_.begin() + start_, - vec_->data_h_.begin() + start_ + size_d, data_.begin()); + dh::safe_cuda(cudaMemcpy(data_.data().get(), + vec_->data_h_.data() + start_, + size_d * sizeof(T), cudaMemcpyHostToDevice)); on_d_ = true; // this may cause a race condition if LazySyncDevice() is called // from multiple threads in parallel; @@ -186,18 +189,22 @@ struct HostDeviceVectorImpl { void ScatterFrom(thrust::device_ptr begin, thrust::device_ptr end) { CHECK_EQ(end - begin, Size()); if (on_h_) { - thrust::copy(begin, end, data_h_.begin()); + dh::safe_cuda(cudaMemcpy(data_h_.data(), begin.get(), + (end - begin) * sizeof(T), + cudaMemcpyDeviceToHost)); } else { dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { - shard.ScatterFrom(begin.get()); - }); + shard.ScatterFrom(begin.get()); + }); } } void GatherTo(thrust::device_ptr begin, thrust::device_ptr end) { CHECK_EQ(end - begin, Size()); if (on_h_) { - thrust::copy(data_h_.begin(), data_h_.end(), begin); + dh::safe_cuda(cudaMemcpy(begin.get(), data_h_.data(), + data_h_.size() * sizeof(T), + cudaMemcpyHostToDevice)); } else { dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); }); } @@ -400,5 +407,6 @@ void HostDeviceVector::Resize(size_t new_size, T v) { template class HostDeviceVector; template class 
HostDeviceVector; template class HostDeviceVector; +template class HostDeviceVector; } // namespace xgboost diff --git a/src/common/random.h b/src/common/random.h index bcfe2e904..93041e9d0 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -7,8 +7,14 @@ #ifndef XGBOOST_COMMON_RANDOM_H_ #define XGBOOST_COMMON_RANDOM_H_ -#include +#include +#include +#include #include +#include +#include +#include +#include "host_device_vector.h" namespace xgboost { namespace common { @@ -66,6 +72,78 @@ using GlobalRandomEngine = RandomEngine; */ GlobalRandomEngine& GlobalRandom(); // NOLINT(*) +/** + * \class ColumnSampler + * + * \brief Handles selection of columns due to colsample_bytree and + * colsample_bylevel parameters. Should be initialised before tree + * construction and to reset when tree construction is completed. + */ + +class ColumnSampler { + HostDeviceVector feature_set_tree_; + std::map> feature_set_level_; + float colsample_bylevel_{1.0f}; + float colsample_bytree_{1.0f}; + + std::vector ColSample(std::vector features, float colsample) const { + if (colsample == 1.0f) return features; + CHECK_GT(features.size(), 0); + int n = std::max(1, static_cast(colsample * features.size())); + + std::shuffle(features.begin(), features.end(), common::GlobalRandom()); + features.resize(n); + std::sort(features.begin(), features.end()); + + return features; + } + + public: + /** + * \brief Initialise this object before use. + * + * \param num_col + * \param colsample_bylevel + * \param colsample_bytree + * \param skip_index_0 (Optional) True to skip index 0. + */ + void Init(int64_t num_col, float colsample_bylevel, float colsample_bytree, + bool skip_index_0 = false) { + this->colsample_bylevel_ = colsample_bylevel; + this->colsample_bytree_ = colsample_bytree; + this->Reset(); + + int begin_idx = skip_index_0 ? 
1 : 0; + auto& feature_set_h = feature_set_tree_.HostVector(); + feature_set_h.resize(num_col - begin_idx); + + std::iota(feature_set_h.begin(), feature_set_h.end(), begin_idx); + feature_set_h = ColSample(feature_set_h, this->colsample_bytree_); + } + + /** + * \brief Resets this object. + */ + void Reset() { + feature_set_tree_.HostVector().clear(); + feature_set_level_.clear(); + } + + HostDeviceVector& GetFeatureSet(int depth) { + if (this->colsample_bylevel_ == 1.0f) { + return feature_set_tree_; + } + + if (feature_set_level_.count(depth) == 0) { + // Level sampling, level does not yet exist so generate it + auto& level = feature_set_level_[depth].HostVector(); + level = ColSample(feature_set_tree_.HostVector(), this->colsample_bylevel_); + } + // Level sampling + return feature_set_level_[depth]; + } +}; + } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_RANDOM_H_ diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index d4eaab7af..b28a87ff0 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -173,19 +173,8 @@ class ColMaker: public TreeUpdater { } } { - // initialize feature index - auto ncol = static_cast(fmat.Info().num_col_); - for (unsigned i = 0; i < ncol; ++i) { - if (fmat.GetColSize(i) != 0) { - feat_index_.push_back(i); - } - } - unsigned n = std::max(static_cast(1), - static_cast(param_.colsample_bytree * feat_index_.size())); - std::shuffle(feat_index_.begin(), feat_index_.end(), common::GlobalRandom()); - CHECK_GT(param_.colsample_bytree, 0U) - << "colsample_bytree cannot be zero."; - feat_index_.resize(n); + column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bylevel, + param_.colsample_bytree); } { // setup temp space for each thread @@ -601,7 +590,7 @@ class ColMaker: public TreeUpdater { // update the solution candidate virtual void UpdateSolution(const SparsePage &batch, - const std::vector &feat_set, + const std::vector &feat_set, const std::vector &gpair, const 
DMatrix &fmat) { const MetaInfo& info = fmat.Info(); @@ -643,15 +632,7 @@ class ColMaker: public TreeUpdater { const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree) { - std::vector feat_set = feat_index_; - if (param_.colsample_bylevel != 1.0f) { - std::shuffle(feat_set.begin(), feat_set.end(), common::GlobalRandom()); - unsigned n = std::max(static_cast(1), - static_cast(param_.colsample_bylevel * feat_index_.size())); - CHECK_GT(param_.colsample_bylevel, 0U) - << "colsample_bylevel cannot be zero."; - feat_set.resize(n); - } + const std::vector &feat_set = column_sampler_.GetFeatureSet(depth).HostVector(); auto iter = p_fmat->ColIterator(); while (iter->Next()) { this->UpdateSolution(iter->Value(), feat_set, gpair, *p_fmat); @@ -770,8 +751,7 @@ class ColMaker: public TreeUpdater { const TrainParam& param_; // number of omp thread used during training const int nthread_; - // Per feature: shuffle index of each feature index - std::vector feat_index_; + common::ColumnSampler column_sampler_; // Instance Data: current node position in the tree of each instance std::vector position_; // PerThread x PerTreeNode: statistics for per thread construction diff --git a/src/tree/updater_fast_hist.cc b/src/tree/updater_fast_hist.cc index 9c38c6ecc..cda6c30b4 100644 --- a/src/tree/updater_fast_hist.cc +++ b/src/tree/updater_fast_hist.cc @@ -170,7 +170,6 @@ class FastHistMaker: public TreeUpdater { tstart = dmlc::GetTime(); this->InitData(gmat, gpair_h, *p_fmat, *p_tree); - std::vector feat_set = feat_index_; time_init_data = dmlc::GetTime() - tstart; // FIXME(hcho3): this code is broken when param.num_roots > 1. 
Please fix it @@ -179,7 +178,7 @@ class FastHistMaker: public TreeUpdater { for (int nid = 0; nid < p_tree->param.num_roots; ++nid) { tstart = dmlc::GetTime(); hist_.AddHistRow(nid); - BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, feat_set, hist_[nid]); + BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid]); time_build_hist += dmlc::GetTime() - tstart; tstart = dmlc::GetTime(); @@ -187,7 +186,7 @@ class FastHistMaker: public TreeUpdater { time_init_new_node += dmlc::GetTime() - tstart; tstart = dmlc::GetTime(); - this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree, feat_set); + this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree); time_evaluate_split += dmlc::GetTime() - tstart; qexpand_->push(ExpandEntry(nid, p_tree->GetDepth(nid), snode_[nid].best.loss_chg, @@ -214,10 +213,10 @@ class FastHistMaker: public TreeUpdater { hist_.AddHistRow(cleft); hist_.AddHistRow(cright); if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) { - BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, feat_set, hist_[cleft]); + BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]); SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]); } else { - BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, feat_set, hist_[cright]); + BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]); SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]); } time_build_hist += dmlc::GetTime() - tstart; @@ -231,8 +230,8 @@ class FastHistMaker: public TreeUpdater { time_init_new_node += dmlc::GetTime() - tstart; tstart = dmlc::GetTime(); - this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree, feat_set); - this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree, feat_set); + this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree); + this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree); time_evaluate_split += dmlc::GetTime() - tstart; qexpand_->push(ExpandEntry(cleft, 
p_tree->GetDepth(cleft), @@ -296,12 +295,11 @@ class FastHistMaker: public TreeUpdater { const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, - const std::vector& feat_set, GHistRow hist) { if (fhparam_.enable_feature_grouping > 0) { - hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, feat_set, hist); + hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); } else { - hist_builder_.BuildHist(gpair, row_indices, gmat, feat_set, hist); + hist_builder_.BuildHist(gpair, row_indices, gmat, hist); } } @@ -427,23 +425,13 @@ class FastHistMaker: public TreeUpdater { // store a pointer to training data p_last_fmat_ = &fmat; // initialize feature index - auto ncol = static_cast(info.num_col_); - feat_index_.clear(); if (data_layout_ == kDenseDataOneBased) { - for (bst_uint i = 1; i < ncol; ++i) { - feat_index_.push_back(i); - } + column_sampler_.Init(info.num_col_, param_.colsample_bylevel, + param_.colsample_bytree, true); } else { - for (bst_uint i = 0; i < ncol; ++i) { - feat_index_.push_back(i); - } + column_sampler_.Init(info.num_col_, param_.colsample_bylevel, + param_.colsample_bytree, false); } - bst_uint n = std::max(static_cast(1), - static_cast(param_.colsample_bytree * feat_index_.size())); - std::shuffle(feat_index_.begin(), feat_index_.end(), common::GlobalRandom()); - CHECK_GT(param_.colsample_bytree, 0U) - << "colsample_bytree cannot be zero."; - feat_index_.resize(n); } if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { /* specialized code for dense data: @@ -481,11 +469,11 @@ class FastHistMaker: public TreeUpdater { const GHistIndexMatrix& gmat, const HistCollection& hist, const DMatrix& fmat, - const RegTree& tree, - const std::vector& feat_set) { + const RegTree& tree) { // start enumeration const MetaInfo& info = fmat.Info(); - const auto nfeature = static_cast(feat_set.size()); + const auto& feature_set = 
column_sampler_.GetFeatureSet(tree.GetDepth(nid)).HostVector(); + const auto nfeature = static_cast(feature_set.size()); const auto nthread = static_cast(this->nthread_); best_split_tloc_.resize(nthread); #pragma omp parallel for schedule(static) num_threads(nthread) @@ -494,7 +482,7 @@ class FastHistMaker: public TreeUpdater { } #pragma omp parallel for schedule(dynamic) num_threads(nthread) for (bst_omp_uint i = 0; i < nfeature; ++i) { - const bst_uint fid = feat_set[i]; + const bst_uint fid = feature_set[i]; const unsigned tid = omp_get_thread_num(); this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], info, &best_split_tloc_[tid], fid, nid); @@ -837,8 +825,7 @@ class FastHistMaker: public TreeUpdater { const FastHistParam& fhparam_; // number of omp thread used during training int nthread_; - // Per feature: shuffle index of each feature index - std::vector feat_index_; + common::ColumnSampler column_sampler_; // the internal row sets RowSetCollection row_set_collection_; // the temp space for split diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index ab83c0bb6..e3e9fade8 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -383,80 +383,5 @@ inline void SubsampleGradientPair(dh::DVec* p_gpair, float subsamp }); } -inline std::vector ColSample(std::vector features, float colsample) { - CHECK_GT(features.size(), 0); - int n = std::max(1, static_cast(colsample * features.size())); - - std::shuffle(features.begin(), features.end(), common::GlobalRandom()); - features.resize(n); - std::sort(features.begin(), features.end()); - - return features; -} - -/** - * \class ColumnSampler - * - * \brief Handles selection of columns due to colsample_bytree and - * colsample_bylevel parameters. Should be initialised the before tree - * construction and to reset When tree construction is completed. 
- */ - -class ColumnSampler { - std::vector feature_set_tree_; - std::map> feature_set_level_; - TrainParam param_; - - public: - /** - * \fn void Init(int64_t num_col, const TrainParam& param) - * - * \brief Initialise this object before use. - * - * \param num_col Number of cols. - * \param param The parameter. - */ - - void Init(int64_t num_col, const TrainParam& param) { - this->Reset(); - this->param_ = param; - feature_set_tree_.resize(num_col); - std::iota(feature_set_tree_.begin(), feature_set_tree_.end(), 0); - feature_set_tree_ = ColSample(feature_set_tree_, param.colsample_bytree); - } - - /** - * \fn void Reset() - * - * \brief Resets this object. - */ - - void Reset() { - feature_set_tree_.clear(); - feature_set_level_.clear(); - } - - /** - * \fn bool ColumnUsed(int column, int depth) - * - * \brief Whether the current column should be considered as a split. - * - * \param column The column index. - * \param depth The current tree depth. - * - * \return True if it should be used, false if it should not be used. 
- */ - - bool ColumnUsed(int column, int depth) { - if (feature_set_level_.count(depth) == 0) { - feature_set_level_[depth] = - ColSample(feature_set_tree_, param_.colsample_bylevel); - } - - return std::binary_search(feature_set_level_[depth].begin(), - feature_set_level_[depth].end(), column); - } -}; - } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index ba2761d4c..69c0b2ab4 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -124,7 +124,7 @@ __device__ void EvaluateFeature(int fidx, const GradientPairSumT* hist, template __global__ void evaluate_split_kernel( const GradientPairSumT* d_hist, int nidx, uint64_t n_features, - DeviceNodeStats nodes, const int* d_feature_segments, + int* feature_set, DeviceNodeStats nodes, const int* d_feature_segments, const float* d_fidx_min_map, const float* d_gidx_fvalue_map, GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split, ValueConstraint value_constraint, int* d_monotonic_constraints) { @@ -151,7 +151,7 @@ __global__ void evaluate_split_kernel( __syncthreads(); - auto fidx = blockIdx.x; + auto fidx = feature_set[blockIdx.x]; auto constraint = d_monotonic_constraints[fidx]; EvaluateFeature( fidx, d_hist, d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map, @@ -204,7 +204,8 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data, struct DeviceHistogram { std::map nidx_map; // Map nidx to starting index of its histogram - thrust::device_vector data; + thrust::device_vector data; + const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size int n_bins; int device_idx; void Init(int device_idx, int n_bins) { @@ -214,29 +215,42 @@ struct DeviceHistogram { void Reset() { dh::safe_cuda(cudaSetDevice(device_idx)); - thrust::fill(data.begin(), data.end(), GradientPairSumT()); + data.resize(0); + nidx_map.clear(); + } + + bool HistogramExists(int nidx) { + return nidx_map.find(nidx) != 
nidx_map.end(); + } + + void AllocateHistogram(int nidx) { + if (HistogramExists(nidx)) return; + + dh::safe_cuda(cudaSetDevice(device_idx)); /* select this histogram's device before either path below touches `data` (the recycle path's cudaMemset previously ran on whatever device was current), as Reset() already does */ + if (data.size() > kStopGrowingSize) { + // Recycle histogram memory + auto old_entry = *nidx_map.begin(); + nidx_map.erase(old_entry.first); + dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0, + n_bins * sizeof(GradientPairSumT))); + nidx_map[nidx] = old_entry.second; + } else { + // Append new node histogram + nidx_map[nidx] = data.size(); + data.resize(data.size() + (n_bins * 2)); + } } /** - * \summary Return pointer to histogram memory for a given node. Be aware that this function - * may reallocate the underlying memory, invalidating previous pointers. - * - * \author Rory - * \date 28/07/2018 - * + * \summary Return pointer to histogram memory for a given node. * \param nidx Tree node index. - * * \return hist pointer. */ GradientPairSumT* GetHistPtr(int nidx) { - if (nidx_map.find(nidx) == nidx_map.end()) { - // Append new node histogram - nidx_map[nidx] = data.size(); - dh::safe_cuda(cudaSetDevice(device_idx)); - data.resize(data.size() + n_bins, GradientPairSumT()); - } - return data.data().get() + nidx_map[nidx]; + CHECK(this->HistogramExists(nidx)); + auto ptr = data.data().get() + nidx_map[nidx]; + return reinterpret_cast(ptr); } }; @@ -576,6 +590,7 @@ struct DeviceShard { } void BuildHist(int nidx) { + hist.AllocateHistogram(nidx); if (can_use_smem_atomics) { BuildHistUsingSharedMem(nidx); } else { @@ -585,10 +600,6 @@ struct DeviceShard { void SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { - // Make sure histograms are already allocated - hist.GetHistPtr(nidx_parent); - hist.GetHistPtr(nidx_histogram); - hist.GetHistPtr(nidx_subtraction); auto d_node_hist_parent = hist.GetHistPtr(nidx_parent); auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram); auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction); @@ -599,6 +610,14 @@ struct DeviceShard { }); 
} + bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, + int nidx_subtraction) { + // Make sure histograms are already allocated + hist.AllocateHistogram(nidx_subtraction); + return hist.HistogramExists(nidx_histogram) && + hist.HistogramExists(nidx_parent); + } + __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { unsigned ballot = __ballot(val == left_nidx); if (threadIdx.x % 32 == 0) { @@ -817,7 +836,7 @@ class GPUHistMaker : public TreeUpdater { } monitor_.Stop("InitDataOnce", devices_); - column_sampler_.Init(info_->num_col_, param_); + column_sampler_.Init(info_->num_col_, param_.colsample_bylevel, param_.colsample_bytree); // Copy gpair & reset memory monitor_.Start("InitDataReset", devices_); @@ -860,16 +879,34 @@ class GPUHistMaker : public TreeUpdater { subtraction_trick_nidx = nidx_left; } + // Build histogram for node with the smallest number of training examples dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { shard->BuildHist(build_hist_nidx); }); this->AllReduceHist(build_hist_nidx); - dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { + // Check whether we can use the subtraction trick to calculate the other + bool do_subtraction_trick = true; + for (auto& shard : shards_) { + do_subtraction_trick &= shard->CanDoSubtractionTrick( + nidx_parent, build_hist_nidx, subtraction_trick_nidx); + } + + if (do_subtraction_trick) { + // Calculate other histogram using subtraction trick + dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { shard->SubtractionTrick(nidx_parent, build_hist_nidx, - subtraction_trick_nidx); + subtraction_trick_nidx); }); + } else { + // Calculate other histogram manually + dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { + shard->BuildHist(subtraction_trick_nidx); + }); + + this->AllReduceHist(subtraction_trick_nidx); + } } // Returns best loss @@ -877,8 +914,9 @@ class GPUHistMaker : public TreeUpdater { const std::vector& nidx_set, RegTree* p_tree) { auto columns = 
info_->num_col_; std::vector best_splits(nidx_set.size()); - std::vector candidate_splits(nidx_set.size() * - columns); + DeviceSplitCandidate* candidate_splits; + dh::safe_cuda(cudaMallocHost(&candidate_splits, nidx_set.size() * + columns * sizeof(DeviceSplitCandidate))); // Use first device auto& shard = shards_.front(); dh::safe_cuda(cudaSetDevice(shard->device_idx)); @@ -892,34 +930,37 @@ class GPUHistMaker : public TreeUpdater { for (auto i = 0; i < nidx_set.size(); i++) { auto nidx = nidx_set[i]; DeviceNodeStats node(shard->node_sum_gradients[nidx], nidx, param_); + auto depth = p_tree->GetDepth(nidx); + + auto& feature_set = column_sampler_.GetFeatureSet(depth); + feature_set.Reshard(GPUSet(shard->device_idx, 1)); const int BLOCK_THREADS = 256; evaluate_split_kernel - <<>>( - shard->hist.GetHistPtr(nidx), nidx, info_->num_col_, node, + <<>>( + shard->hist.GetHistPtr(nidx), nidx, info_->num_col_, + feature_set.DevicePointer(shard->device_idx), node, shard->feature_segments.Data(), shard->min_fvalue.Data(), shard->gidx_fvalue_map.Data(), GPUTrainingParam(param_), d_split + i * columns, node_value_constraints_[nidx], shard->monotone_constraints.Data()); } + dh::safe_cuda(cudaDeviceSynchronize()); dh::safe_cuda( - cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage, + cudaMemcpy(candidate_splits, shard->temp_memory.d_temp_storage, sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), cudaMemcpyDeviceToHost)); - for (auto i = 0; i < nidx_set.size(); i++) { - auto nidx = nidx_set[i]; + auto depth = p_tree->GetDepth(nidx_set[i]); DeviceSplitCandidate nidx_best; - for (auto fidx = 0; fidx < columns; fidx++) { + for (auto fidx : column_sampler_.GetFeatureSet(depth).HostVector()) { auto& candidate = candidate_splits[i * columns + fidx]; - if (column_sampler_.ColumnUsed(candidate.findex, - p_tree->GetDepth(nidx))) { - nidx_best.Update(candidate_splits[i * columns + fidx], param_); - } + nidx_best.Update(candidate, param_); } best_splits[i] = 
nidx_best; } + dh::safe_cuda(cudaFreeHost(candidate_splits)); return std::move(best_splits); } @@ -1113,8 +1154,8 @@ class GPUHistMaker : public TreeUpdater { static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { - if (param.max_depth > 0 && depth == param.max_depth) return false; - if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; return true; } @@ -1152,7 +1193,7 @@ class GPUHistMaker : public TreeUpdater { int n_bins_; std::vector> shards_; - ColumnSampler column_sampler_; + common::ColumnSampler column_sampler_; typedef std::priority_queue, std::function> ExpandQueue;