GPU memory usage fixes + column sampling refactor (#3635)

* Remove thrust copy calls

* Fix histogram memory usage

* Cap extreme histogram memory usage

* More efficient column sampling

* Use column sampler across updaters

* More efficient split evaluation on GPU with column sampling
Rory Mitchell 2018-08-27 16:26:46 +12:00 committed by GitHub
parent 60787ecebc
commit 686e990ffc
9 changed files with 198 additions and 182 deletions
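At a glance: the per-updater feature-index shuffling in ColMaker, FastHistMaker and GPUHistMaker is replaced by one shared common::ColumnSampler (added to src/common/random.h below). A minimal usage sketch, with placeholder hyperparameter values rather than anything taken from this commit:

// Sketch only: the shared column sampler introduced by this commit.
common::ColumnSampler sampler;
sampler.Init(/*num_col=*/100, /*colsample_bylevel=*/0.5f, /*colsample_bytree=*/0.5f);
// Per-depth feature subsets are drawn lazily and cached per depth:
const std::vector<int>& feats = sampler.GetFeatureSet(/*depth=*/0).HostVector();
// Reset (or re-Init) once the tree is finished.
sampler.Reset();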

View File

@ -402,7 +402,6 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
const std::vector<bst_uint>& feat_set,
GHistRow hist) {
data_.resize(nbins_ * nthread_, GHistEntry());
std::fill(data_.begin(), data_.end(), GHistEntry());
@ -461,7 +460,6 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexBlockMatrix& gmatb,
const std::vector<bst_uint>& feat_set,
GHistRow hist) {
constexpr int kUnroll = 8; // loop unrolling factor
const size_t nblock = gmatb.GetNumBlock();

View File

@ -266,13 +266,11 @@ class GHistBuilder {
void BuildHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
const std::vector<bst_uint>& feat_set,
GHistRow hist);
// same, with feature grouping
void BuildBlockHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexBlockMatrix& gmatb,
const std::vector<bst_uint>& feat_set,
GHistRow hist);
// construct a histogram via subtraction trick
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);

View File

@ -102,6 +102,7 @@ void HostDeviceVector<T>::Reshard(GPUSet devices) { }
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<GradientPair>;
template class HostDeviceVector<unsigned int>;
template class HostDeviceVector<int>;
} // namespace xgboost

View File

@ -77,7 +77,9 @@ struct HostDeviceVectorImpl {
void LazySyncHost() {
dh::safe_cuda(cudaSetDevice(device_));
thrust::copy(data_.begin(), data_.end(), vec_->data_h_.begin() + start_);
dh::safe_cuda(
cudaMemcpy(vec_->data_h_.data(), data_.data().get() + start_,
data_.size() * sizeof(T), cudaMemcpyDeviceToHost));
on_d_ = false;
}
@ -90,8 +92,9 @@ struct HostDeviceVectorImpl {
size_t size_d = ShardSize(size_h, ndevices, index_);
dh::safe_cuda(cudaSetDevice(device_));
data_.resize(size_d);
thrust::copy(vec_->data_h_.begin() + start_,
vec_->data_h_.begin() + start_ + size_d, data_.begin());
dh::safe_cuda(cudaMemcpy(data_.data().get(),
vec_->data_h_.data() + start_,
size_d * sizeof(T), cudaMemcpyHostToDevice));
on_d_ = true;
// this may cause a race condition if LazySyncDevice() is called
// from multiple threads in parallel;
@ -186,7 +189,9 @@ struct HostDeviceVectorImpl {
void ScatterFrom(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) {
CHECK_EQ(end - begin, Size());
if (on_h_) {
thrust::copy(begin, end, data_h_.begin());
dh::safe_cuda(cudaMemcpy(data_h_.data(), begin.get(),
(end - begin) * sizeof(T),
cudaMemcpyDeviceToHost));
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
shard.ScatterFrom(begin.get());
@ -197,7 +202,9 @@ struct HostDeviceVectorImpl {
void GatherTo(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) {
CHECK_EQ(end - begin, Size());
if (on_h_) {
thrust::copy(data_h_.begin(), data_h_.end(), begin);
dh::safe_cuda(cudaMemcpy(begin.get(), data_h_.data(),
data_h_.size() * sizeof(T),
cudaMemcpyHostToDevice));
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
}
@ -400,5 +407,6 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<GradientPair>;
template class HostDeviceVector<unsigned int>;
template class HostDeviceVector<int>;
} // namespace xgboost
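The HostDeviceVector changes above all follow one pattern: a thrust::copy between host and device ranges is replaced by a single explicit cudaMemcpy of the whole buffer. A stand-alone sketch of the device-to-host direction, outside of HostDeviceVectorImpl and with error checking omitted:

#include <thrust/device_vector.h>
#include <vector>

template <typename T>
void CopyToHost(const thrust::device_vector<T>& d_vec, std::vector<T>* h_vec) {
  // Before: thrust::copy(d_vec.begin(), d_vec.end(), h_vec->begin());
  // After: one bulk transfer.
  cudaMemcpy(h_vec->data(), d_vec.data().get(),
             d_vec.size() * sizeof(T), cudaMemcpyDeviceToHost);
}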

View File

@ -7,8 +7,14 @@
#ifndef XGBOOST_COMMON_RANDOM_H_
#define XGBOOST_COMMON_RANDOM_H_
#include <random>
#include <xgboost/logging.h>
#include <algorithm>
#include <vector>
#include <limits>
#include <map>
#include <numeric>
#include <random>
#include "host_device_vector.h"
namespace xgboost {
namespace common {
@ -66,6 +72,78 @@ using GlobalRandomEngine = RandomEngine;
*/
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
/**
* \class ColumnSampler
*
* \brief Handles selection of columns due to colsample_bytree and
* colsample_bylevel parameters. Should be initialised before tree
* construction and reset when tree construction is completed.
*/
class ColumnSampler {
HostDeviceVector<int> feature_set_tree_;
std::map<int, HostDeviceVector<int>> feature_set_level_;
float colsample_bylevel_{1.0f};
float colsample_bytree_{1.0f};
std::vector<int> ColSample(std::vector<int> features, float colsample) const {
if (colsample == 1.0f) return features;
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
std::sort(features.begin(), features.end());
return features;
}
public:
/**
* \brief Initialise this object before use.
*
* \param num_col
* \param colsample_bylevel
* \param colsample_bytree
* \param skip_index_0 (Optional) True to skip index 0.
*/
void Init(int64_t num_col, float colsample_bylevel, float colsample_bytree,
bool skip_index_0 = false) {
this->colsample_bylevel_ = colsample_bylevel;
this->colsample_bytree_ = colsample_bytree;
this->Reset();
int begin_idx = skip_index_0 ? 1 : 0;
auto& feature_set_h = feature_set_tree_.HostVector();
feature_set_h.resize(num_col - begin_idx);
std::iota(feature_set_h.begin(), feature_set_h.end(), begin_idx);
feature_set_h = ColSample(feature_set_h, this->colsample_bytree_);
}
/**
* \brief Resets this object.
*/
void Reset() {
feature_set_tree_.HostVector().clear();
feature_set_level_.clear();
}
HostDeviceVector<int>& GetFeatureSet(int depth) {
if (this->colsample_bylevel_ == 1.0f) {
return feature_set_tree_;
}
if (feature_set_level_.count(depth) == 0) {
// Level sampling, level does not yet exist so generate it
auto& level = feature_set_level_[depth].HostVector();
level = ColSample(feature_set_tree_.HostVector(), this->colsample_bylevel_);
}
// Level sampling
return feature_set_level_[depth];
}
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_RANDOM_H_
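GetFeatureSet draws the per-depth subset from the per-tree subset the first time a depth is requested and caches it in feature_set_level_, so every node at a given depth evaluates the same columns and repeated calls are cheap. A small illustrative snippet (the feature counts are hypothetical):

common::ColumnSampler sampler;
sampler.Init(/*num_col=*/100, /*colsample_bylevel=*/0.5f, /*colsample_bytree=*/0.5f);
// Tree-level set: max(1, 0.5 * 100) = 50 features, drawn once in Init().
// Depth-level set: max(1, 0.5 * 50) = 25 of those, drawn lazily per depth.
auto& depth2_first  = sampler.GetFeatureSet(2).HostVector();
auto& depth2_second = sampler.GetFeatureSet(2).HostVector();
// Both references point at the same cached vector for depth 2.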

View File

@ -173,19 +173,8 @@ class ColMaker: public TreeUpdater {
}
}
{
// initialize feature index
auto ncol = static_cast<unsigned>(fmat.Info().num_col_);
for (unsigned i = 0; i < ncol; ++i) {
if (fmat.GetColSize(i) != 0) {
feat_index_.push_back(i);
}
}
unsigned n = std::max(static_cast<unsigned>(1),
static_cast<unsigned>(param_.colsample_bytree * feat_index_.size()));
std::shuffle(feat_index_.begin(), feat_index_.end(), common::GlobalRandom());
CHECK_GT(param_.colsample_bytree, 0U)
<< "colsample_bytree cannot be zero.";
feat_index_.resize(n);
column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bylevel,
param_.colsample_bytree);
}
{
// setup temp space for each thread
@ -601,7 +590,7 @@ class ColMaker: public TreeUpdater {
// update the solution candidate
virtual void UpdateSolution(const SparsePage &batch,
const std::vector<bst_uint> &feat_set,
const std::vector<int> &feat_set,
const std::vector<GradientPair> &gpair,
const DMatrix &fmat) {
const MetaInfo& info = fmat.Info();
@ -643,15 +632,7 @@ class ColMaker: public TreeUpdater {
const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
RegTree *p_tree) {
std::vector<bst_uint> feat_set = feat_index_;
if (param_.colsample_bylevel != 1.0f) {
std::shuffle(feat_set.begin(), feat_set.end(), common::GlobalRandom());
unsigned n = std::max(static_cast<unsigned>(1),
static_cast<unsigned>(param_.colsample_bylevel * feat_index_.size()));
CHECK_GT(param_.colsample_bylevel, 0U)
<< "colsample_bylevel cannot be zero.";
feat_set.resize(n);
}
const std::vector<int> &feat_set = column_sampler_.GetFeatureSet(depth).HostVector();
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
this->UpdateSolution(iter->Value(), feat_set, gpair, *p_fmat);
@ -770,8 +751,7 @@ class ColMaker: public TreeUpdater {
const TrainParam& param_;
// number of omp thread used during training
const int nthread_;
// Per feature: shuffle index of each feature index
std::vector<bst_uint> feat_index_;
common::ColumnSampler column_sampler_;
// Instance Data: current node position in the tree of each instance
std::vector<int> position_;
// PerThread x PerTreeNode: statistics for per thread construction
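ColMaker's hand-rolled colsample_bytree / colsample_bylevel shuffling above is gone; the subset size, however, is computed with the same formula in both the removed code and the new ColSample helper. A worked example (values purely illustrative):

#include <algorithm>

// n = max(1, static_cast<int>(colsample * num_features)), as in ColSample above.
int SubsetSize(float colsample, int num_features) {
  return std::max(1, static_cast<int>(colsample * num_features));
}
// SubsetSize(0.3f, 10) == 3; SubsetSize(0.01f, 10) == 1 (the subset is never empty).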

View File

@ -170,7 +170,6 @@ class FastHistMaker: public TreeUpdater {
tstart = dmlc::GetTime();
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
std::vector<bst_uint> feat_set = feat_index_;
time_init_data = dmlc::GetTime() - tstart;
// FIXME(hcho3): this code is broken when param.num_roots > 1. Please fix it
@ -179,7 +178,7 @@ class FastHistMaker: public TreeUpdater {
for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
tstart = dmlc::GetTime();
hist_.AddHistRow(nid);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, feat_set, hist_[nid]);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid]);
time_build_hist += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
@ -187,7 +186,7 @@ class FastHistMaker: public TreeUpdater {
time_init_new_node += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree, feat_set);
this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
time_evaluate_split += dmlc::GetTime() - tstart;
qexpand_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
snode_[nid].best.loss_chg,
@ -214,10 +213,10 @@ class FastHistMaker: public TreeUpdater {
hist_.AddHistRow(cleft);
hist_.AddHistRow(cright);
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, feat_set, hist_[cleft]);
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else {
BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, feat_set, hist_[cright]);
BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]);
SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
}
time_build_hist += dmlc::GetTime() - tstart;
@ -231,8 +230,8 @@ class FastHistMaker: public TreeUpdater {
time_init_new_node += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree, feat_set);
this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree, feat_set);
this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree);
this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree);
time_evaluate_split += dmlc::GetTime() - tstart;
qexpand_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft),
@ -296,12 +295,11 @@ class FastHistMaker: public TreeUpdater {
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
const GHistIndexBlockMatrix& gmatb,
const std::vector<bst_uint>& feat_set,
GHistRow hist) {
if (fhparam_.enable_feature_grouping > 0) {
hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, feat_set, hist);
hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist);
} else {
hist_builder_.BuildHist(gpair, row_indices, gmat, feat_set, hist);
hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
}
}
@ -427,24 +425,14 @@ class FastHistMaker: public TreeUpdater {
// store a pointer to training data
p_last_fmat_ = &fmat;
// initialize feature index
auto ncol = static_cast<bst_uint>(info.num_col_);
feat_index_.clear();
if (data_layout_ == kDenseDataOneBased) {
for (bst_uint i = 1; i < ncol; ++i) {
feat_index_.push_back(i);
}
column_sampler_.Init(info.num_col_, param_.colsample_bylevel,
param_.colsample_bytree, true);
} else {
for (bst_uint i = 0; i < ncol; ++i) {
feat_index_.push_back(i);
column_sampler_.Init(info.num_col_, param_.colsample_bylevel,
param_.colsample_bytree, false);
}
}
bst_uint n = std::max(static_cast<bst_uint>(1),
static_cast<bst_uint>(param_.colsample_bytree * feat_index_.size()));
std::shuffle(feat_index_.begin(), feat_index_.end(), common::GlobalRandom());
CHECK_GT(param_.colsample_bytree, 0U)
<< "colsample_bytree cannot be zero.";
feat_index_.resize(n);
}
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data:
choose the column that has a least positive number of discrete bins.
@ -481,11 +469,11 @@ class FastHistMaker: public TreeUpdater {
const GHistIndexMatrix& gmat,
const HistCollection& hist,
const DMatrix& fmat,
const RegTree& tree,
const std::vector<bst_uint>& feat_set) {
const RegTree& tree) {
// start enumeration
const MetaInfo& info = fmat.Info();
const auto nfeature = static_cast<bst_uint>(feat_set.size());
const auto& feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid)).HostVector();
const auto nfeature = static_cast<bst_uint>(feature_set.size());
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
best_split_tloc_.resize(nthread);
#pragma omp parallel for schedule(static) num_threads(nthread)
@ -494,7 +482,7 @@ class FastHistMaker: public TreeUpdater {
}
#pragma omp parallel for schedule(dynamic) num_threads(nthread)
for (bst_omp_uint i = 0; i < nfeature; ++i) {
const bst_uint fid = feat_set[i];
const bst_uint fid = feature_set[i];
const unsigned tid = omp_get_thread_num();
this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], info,
&best_split_tloc_[tid], fid, nid);
@ -837,8 +825,7 @@ class FastHistMaker: public TreeUpdater {
const FastHistParam& fhparam_;
// number of omp thread used during training
int nthread_;
// Per feature: shuffle index of each feature index
std::vector<bst_uint> feat_index_;
common::ColumnSampler column_sampler_;
// the internal row sets
RowSetCollection row_set_collection_;
// the temp space for split

View File

@ -383,80 +383,5 @@ inline void SubsampleGradientPair(dh::DVec<GradientPair>* p_gpair, float subsamp
});
}
inline std::vector<int> ColSample(std::vector<int> features, float colsample) {
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
std::sort(features.begin(), features.end());
return features;
}
/**
* \class ColumnSampler
*
* \brief Handles selection of columns due to colsample_bytree and
* colsample_bylevel parameters. Should be initialised the before tree
* construction and to reset When tree construction is completed.
*/
class ColumnSampler {
std::vector<int> feature_set_tree_;
std::map<int, std::vector<int>> feature_set_level_;
TrainParam param_;
public:
/**
* \fn void Init(int64_t num_col, const TrainParam& param)
*
* \brief Initialise this object before use.
*
* \param num_col Number of cols.
* \param param The parameter.
*/
void Init(int64_t num_col, const TrainParam& param) {
this->Reset();
this->param_ = param;
feature_set_tree_.resize(num_col);
std::iota(feature_set_tree_.begin(), feature_set_tree_.end(), 0);
feature_set_tree_ = ColSample(feature_set_tree_, param.colsample_bytree);
}
/**
* \fn void Reset()
*
* \brief Resets this object.
*/
void Reset() {
feature_set_tree_.clear();
feature_set_level_.clear();
}
/**
* \fn bool ColumnUsed(int column, int depth)
*
* \brief Whether the current column should be considered as a split.
*
* \param column The column index.
* \param depth The current tree depth.
*
* \return True if it should be used, false if it should not be used.
*/
bool ColumnUsed(int column, int depth) {
if (feature_set_level_.count(depth) == 0) {
feature_set_level_[depth] =
ColSample(feature_set_tree_, param_.colsample_bylevel);
}
return std::binary_search(feature_set_level_[depth].begin(),
feature_set_level_[depth].end(), column);
}
};
} // namespace tree
} // namespace xgboost

View File

@ -124,7 +124,7 @@ __device__ void EvaluateFeature(int fidx, const GradientPairSumT* hist,
template <int BLOCK_THREADS>
__global__ void evaluate_split_kernel(
const GradientPairSumT* d_hist, int nidx, uint64_t n_features,
DeviceNodeStats nodes, const int* d_feature_segments,
int* feature_set, DeviceNodeStats nodes, const int* d_feature_segments,
const float* d_fidx_min_map, const float* d_gidx_fvalue_map,
GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split,
ValueConstraint value_constraint, int* d_monotonic_constraints) {
@ -151,7 +151,7 @@ __global__ void evaluate_split_kernel(
__syncthreads();
auto fidx = blockIdx.x;
auto fidx = feature_set[blockIdx.x];
auto constraint = d_monotonic_constraints[fidx];
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
fidx, d_hist, d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map,
@ -204,7 +204,8 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data,
struct DeviceHistogram {
std::map<int, size_t>
nidx_map; // Map nidx to starting index of its histogram
thrust::device_vector<GradientPairSumT> data;
thrust::device_vector<GradientPairSumT::ValueT> data;
const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size
int n_bins;
int device_idx;
void Init(int device_idx, int n_bins) {
@ -214,29 +215,42 @@ struct DeviceHistogram {
void Reset() {
dh::safe_cuda(cudaSetDevice(device_idx));
thrust::fill(data.begin(), data.end(), GradientPairSumT());
data.resize(0);
nidx_map.clear();
}
bool HistogramExists(int nidx) {
return nidx_map.find(nidx) != nidx_map.end();
}
void AllocateHistogram(int nidx) {
if (HistogramExists(nidx)) return;
if (data.size() > kStopGrowingSize) {
// Recycle histogram memory
auto old_entry = *nidx_map.begin();
nidx_map.erase(old_entry.first);
dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
n_bins * sizeof(GradientPairSumT)));
nidx_map[nidx] = old_entry.second;
} else {
// Append new node histogram
nidx_map[nidx] = data.size();
dh::safe_cuda(cudaSetDevice(device_idx));
data.resize(data.size() + (n_bins * 2));
}
}
/**
* \summary Return pointer to histogram memory for a given node. Be aware that this function
* may reallocate the underlying memory, invalidating previous pointers.
*
* \author Rory
* \date 28/07/2018
*
* \summary Return pointer to histogram memory for a given node.
* \param nidx Tree node index.
*
* \return hist pointer.
*/
GradientPairSumT* GetHistPtr(int nidx) {
if (nidx_map.find(nidx) == nidx_map.end()) {
// Append new node histogram
nidx_map[nidx] = data.size();
dh::safe_cuda(cudaSetDevice(device_idx));
data.resize(data.size() + n_bins, GradientPairSumT());
}
return data.data().get() + nidx_map[nidx];
CHECK(this->HistogramExists(nidx));
auto ptr = data.data().get() + nidx_map[nidx];
return reinterpret_cast<GradientPairSumT*>(ptr);
}
};
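The capping logic above works as follows: histogram storage grows by one node's worth of bins (2 * n_bins values of GradientPairSumT::ValueT, i.e. one gradient/hessian pair per bin) until the backing vector exceeds kStopGrowingSize; past that point, the slot of an existing node is zeroed with cudaMemset and handed to the new node instead of growing further. A rough stand-alone sketch of the same strategy, using a plain struct instead of the actual xgboost types:

#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <map>

struct HistEntry { float grad; float hess; };  // stand-in for GradientPairSumT

struct RecyclingHistogramPool {
  std::map<int, size_t> nidx_map;      // node index -> offset (in floats) into data
  thrust::device_vector<float> data;   // 2 floats per bin, as in the commit
  size_t stop_growing_size = 1 << 26;  // analogous to kStopGrowingSize
  int n_bins = 0;

  void AllocateHistogram(int nidx) {
    if (nidx_map.count(nidx)) return;
    if (data.size() > stop_growing_size) {
      // Recycle: take over the slot of some already-allocated node and zero it.
      auto old_entry = *nidx_map.begin();
      nidx_map.erase(old_entry.first);
      cudaMemset(data.data().get() + old_entry.second, 0,
                 n_bins * sizeof(HistEntry));
      nidx_map[nidx] = old_entry.second;
    } else {
      // Grow: append space for one more node's histogram.
      nidx_map[nidx] = data.size();
      data.resize(data.size() + n_bins * 2);
    }
  }
};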
@ -576,6 +590,7 @@ struct DeviceShard {
}
void BuildHist(int nidx) {
hist.AllocateHistogram(nidx);
if (can_use_smem_atomics) {
BuildHistUsingSharedMem(nidx);
} else {
@ -585,10 +600,6 @@ struct DeviceShard {
void SubtractionTrick(int nidx_parent, int nidx_histogram,
int nidx_subtraction) {
// Make sure histograms are already allocated
hist.GetHistPtr(nidx_parent);
hist.GetHistPtr(nidx_histogram);
hist.GetHistPtr(nidx_subtraction);
auto d_node_hist_parent = hist.GetHistPtr(nidx_parent);
auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram);
auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction);
@ -599,6 +610,14 @@ struct DeviceShard {
});
}
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram,
int nidx_subtraction) {
// Make sure histograms are already allocated
hist.AllocateHistogram(nidx_subtraction);
return hist.HistogramExists(nidx_histogram) &&
hist.HistogramExists(nidx_parent);
}
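For reference, the subtraction trick derives the second child's histogram from the parent and the already-built sibling instead of rescanning its rows; CanDoSubtractionTrick above only permits it when the parent's and the built child's histograms still exist, since the capped pool may have recycled them. A scalar sketch of the per-bin arithmetic behind SubtractionTrick (the real code runs it as a device kernel over n_bins elements):

// hist[subtraction][bin] = hist[parent][bin] - hist[built][bin]
for (int bin = 0; bin < n_bins; ++bin) {
  d_node_hist_subtraction[bin] =
      d_node_hist_parent[bin] - d_node_hist_histogram[bin];
}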
__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
@ -817,7 +836,7 @@ class GPUHistMaker : public TreeUpdater {
}
monitor_.Stop("InitDataOnce", devices_);
column_sampler_.Init(info_->num_col_, param_);
column_sampler_.Init(info_->num_col_, param_.colsample_bylevel, param_.colsample_bytree);
// Copy gpair & reset memory
monitor_.Start("InitDataReset", devices_);
@ -860,16 +879,34 @@ class GPUHistMaker : public TreeUpdater {
subtraction_trick_nidx = nidx_left;
}
// Build histogram for node with the smallest number of training examples
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->BuildHist(build_hist_nidx);
});
this->AllReduceHist(build_hist_nidx);
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = true;
for (auto& shard : shards_) {
do_subtraction_trick &= shard->CanDoSubtractionTrick(
nidx_parent, build_hist_nidx, subtraction_trick_nidx);
}
if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->SubtractionTrick(nidx_parent, build_hist_nidx,
subtraction_trick_nidx);
});
} else {
// Calculate other histogram manually
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->BuildHist(subtraction_trick_nidx);
});
this->AllReduceHist(subtraction_trick_nidx);
}
}
// Returns best loss
@ -877,8 +914,9 @@ class GPUHistMaker : public TreeUpdater {
const std::vector<int>& nidx_set, RegTree* p_tree) {
auto columns = info_->num_col_;
std::vector<DeviceSplitCandidate> best_splits(nidx_set.size());
std::vector<DeviceSplitCandidate> candidate_splits(nidx_set.size() *
columns);
DeviceSplitCandidate* candidate_splits;
dh::safe_cuda(cudaMallocHost(&candidate_splits, nidx_set.size() *
columns * sizeof(DeviceSplitCandidate)));
// Use first device
auto& shard = shards_.front();
dh::safe_cuda(cudaSetDevice(shard->device_idx));
@ -892,34 +930,37 @@ class GPUHistMaker : public TreeUpdater {
for (auto i = 0; i < nidx_set.size(); i++) {
auto nidx = nidx_set[i];
DeviceNodeStats node(shard->node_sum_gradients[nidx], nidx, param_);
auto depth = p_tree->GetDepth(nidx);
auto& feature_set = column_sampler_.GetFeatureSet(depth);
feature_set.Reshard(GPUSet(shard->device_idx, 1));
const int BLOCK_THREADS = 256;
evaluate_split_kernel<BLOCK_THREADS>
<<<uint32_t(columns), BLOCK_THREADS, 0, streams[i]>>>(
shard->hist.GetHistPtr(nidx), nidx, info_->num_col_, node,
<<<uint32_t(feature_set.Size()), BLOCK_THREADS, 0, streams[i]>>>(
shard->hist.GetHistPtr(nidx), nidx, info_->num_col_,
feature_set.DevicePointer(shard->device_idx), node,
shard->feature_segments.Data(), shard->min_fvalue.Data(),
shard->gidx_fvalue_map.Data(), GPUTrainingParam(param_),
d_split + i * columns, node_value_constraints_[nidx],
shard->monotone_constraints.Data());
}
dh::safe_cuda(cudaDeviceSynchronize());
dh::safe_cuda(
cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage,
cudaMemcpy(candidate_splits, shard->temp_memory.d_temp_storage,
sizeof(DeviceSplitCandidate) * columns * nidx_set.size(),
cudaMemcpyDeviceToHost));
for (auto i = 0; i < nidx_set.size(); i++) {
auto nidx = nidx_set[i];
auto depth = p_tree->GetDepth(nidx_set[i]);
DeviceSplitCandidate nidx_best;
for (auto fidx = 0; fidx < columns; fidx++) {
for (auto fidx : column_sampler_.GetFeatureSet(depth).HostVector()) {
auto& candidate = candidate_splits[i * columns + fidx];
if (column_sampler_.ColumnUsed(candidate.findex,
p_tree->GetDepth(nidx))) {
nidx_best.Update(candidate_splits[i * columns + fidx], param_);
}
nidx_best.Update(candidate, param_);
}
best_splits[i] = nidx_best;
}
dh::safe_cuda(cudaFreeHost(candidate_splits));
return std::move(best_splits);
}
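Two details of the EvaluateSplits changes above: the candidate buffer is now page-locked host memory from cudaMallocHost (released with cudaFreeHost) rather than a std::vector, and evaluate_split_kernel is launched with one block per sampled feature, reading fidx = feature_set[blockIdx.x] instead of fidx = blockIdx.x. A minimal sketch of the pinned-buffer pattern, with error checking omitted and a stand-in result type:

#include <cuda_runtime.h>

struct Result { float loss_chg; int findex; };  // stand-in for DeviceSplitCandidate

void CopyResultsToHost(const Result* d_results, size_t n) {
  Result* h_results = nullptr;
  cudaMallocHost(&h_results, n * sizeof(Result));  // page-locked host allocation
  cudaMemcpy(h_results, d_results, n * sizeof(Result), cudaMemcpyDeviceToHost);
  // ... consume h_results ...
  cudaFreeHost(h_results);  // must pair with cudaMallocHost
}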
@ -1113,8 +1154,8 @@ class GPUHistMaker : public TreeUpdater {
static bool ChildIsValid(const TrainParam& param, int depth,
int num_leaves) {
if (param.max_depth > 0 && depth == param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}
@ -1152,7 +1193,7 @@ class GPUHistMaker : public TreeUpdater {
int n_bins_;
std::vector<std::unique_ptr<DeviceShard>> shards_;
ColumnSampler column_sampler_;
common::ColumnSampler column_sampler_;
typedef std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>
ExpandQueue;