Fix clang-tidy warnings. (#4149)
* Upgrade gtest for clang-tidy. * Use CMake to install GTest instead of mv. * Don't enforce clang-tidy to return 0 due to errors in thrust. * Add a small test for tidy itself. * Reformat.
This commit is contained in:
@@ -108,7 +108,7 @@ __device__ GradientSumT ReduceFeature(common::Span<const GradientSumT> feature_h
|
||||
}
|
||||
|
||||
/*! \brief Find the thread with best gain. */
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename scan_t,
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
|
||||
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
|
||||
__device__ void EvaluateFeature(
|
||||
int fidx,
|
||||
@@ -142,7 +142,7 @@ __device__ void EvaluateFeature(
|
||||
// Gradient value for current bin.
|
||||
GradientSumT bin =
|
||||
thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientSumT();
|
||||
scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
@@ -198,12 +198,12 @@ __global__ void EvaluateSplitKernel(
|
||||
ValueConstraint value_constraint,
|
||||
common::Span<int> d_monotonic_constraints) {
|
||||
// KeyValuePair here used as threadIdx.x -> gain_value
|
||||
typedef cub::KeyValuePair<int, float> ArgMaxT;
|
||||
typedef cub::BlockScan<
|
||||
GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT;
|
||||
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
|
||||
using ArgMaxT = cub::KeyValuePair<int, float>;
|
||||
using BlockScanT =
|
||||
cub::BlockScan<GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
using MaxReduceT = cub::BlockReduce<ArgMaxT, BLOCK_THREADS>;
|
||||
|
||||
typedef cub::BlockReduce<GradientSumT, BLOCK_THREADS> SumReduceT;
|
||||
using SumReduceT = cub::BlockReduce<GradientSumT, BLOCK_THREADS>;
|
||||
|
||||
union TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
@@ -274,51 +274,56 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data,
|
||||
* \date 28/07/2018
|
||||
*/
|
||||
template <typename GradientSumT>
|
||||
struct DeviceHistogram {
|
||||
class DeviceHistogram {
|
||||
private:
|
||||
/*! \brief Map nidx to starting index of its histogram. */
|
||||
std::map<int, size_t> nidx_map;
|
||||
thrust::device_vector<typename GradientSumT::ValueT> data;
|
||||
const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size
|
||||
int n_bins;
|
||||
std::map<int, size_t> nidx_map_;
|
||||
thrust::device_vector<typename GradientSumT::ValueT> data_;
|
||||
static constexpr size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size
|
||||
int n_bins_;
|
||||
int device_id_;
|
||||
|
||||
public:
|
||||
void Init(int device_id, int n_bins) {
|
||||
this->n_bins = n_bins;
|
||||
this->n_bins_ = n_bins;
|
||||
this->device_id_ = device_id;
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaMemsetAsync(
|
||||
data.data().get(), 0,
|
||||
data.size() * sizeof(typename decltype(data)::value_type)));
|
||||
nidx_map.clear();
|
||||
data_.data().get(), 0,
|
||||
data_.size() * sizeof(typename decltype(data_)::value_type)));
|
||||
nidx_map_.clear();
|
||||
}
|
||||
bool HistogramExists(int nidx) {
|
||||
return nidx_map_.find(nidx) != nidx_map_.end();
|
||||
}
|
||||
|
||||
bool HistogramExists(int nidx) {
|
||||
return nidx_map.find(nidx) != nidx_map.end();
|
||||
thrust::device_vector<typename GradientSumT::ValueT> &Data() {
|
||||
return data_;
|
||||
}
|
||||
|
||||
void AllocateHistogram(int nidx) {
|
||||
if (HistogramExists(nidx)) return;
|
||||
size_t current_size =
|
||||
nidx_map.size() * n_bins * 2; // Number of items currently used in data
|
||||
nidx_map_.size() * n_bins_ * 2; // Number of items currently used in data
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
if (data.size() >= kStopGrowingSize) {
|
||||
if (data_.size() >= kStopGrowingSize) {
|
||||
// Recycle histogram memory
|
||||
std::pair<int, size_t> old_entry = *nidx_map.begin();
|
||||
nidx_map.erase(old_entry.first);
|
||||
dh::safe_cuda(cudaMemsetAsync(data.data().get() + old_entry.second, 0,
|
||||
n_bins * sizeof(GradientSumT)));
|
||||
nidx_map[nidx] = old_entry.second;
|
||||
std::pair<int, size_t> old_entry = *nidx_map_.begin();
|
||||
nidx_map_.erase(old_entry.first);
|
||||
dh::safe_cuda(cudaMemsetAsync(data_.data().get() + old_entry.second, 0,
|
||||
n_bins_ * sizeof(GradientSumT)));
|
||||
nidx_map_[nidx] = old_entry.second;
|
||||
} else {
|
||||
// Append new node histogram
|
||||
nidx_map[nidx] = current_size;
|
||||
if (data.size() < current_size + n_bins * 2) {
|
||||
nidx_map_[nidx] = current_size;
|
||||
if (data_.size() < current_size + n_bins_ * 2) {
|
||||
size_t new_size = current_size * 2; // Double in size
|
||||
new_size = std::max(static_cast<size_t>(n_bins * 2),
|
||||
new_size = std::max(static_cast<size_t>(n_bins_ * 2),
|
||||
new_size); // Have at least one histogram
|
||||
data.resize(new_size);
|
||||
data_.resize(new_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -330,9 +335,9 @@ struct DeviceHistogram {
|
||||
*/
|
||||
common::Span<GradientSumT> GetNodeHistogram(int nidx) {
|
||||
CHECK(this->HistogramExists(nidx));
|
||||
auto ptr = data.data().get() + nidx_map[nidx];
|
||||
auto ptr = data_.data().get() + nidx_map_[nidx];
|
||||
return common::Span<GradientSumT>(
|
||||
reinterpret_cast<GradientSumT*>(ptr), n_bins);
|
||||
reinterpret_cast<GradientSumT*>(ptr), n_bins_);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -351,7 +356,7 @@ struct CalcWeightTrainParam {
|
||||
};
|
||||
|
||||
// Bin each input data entry, store the bin indices in compressed form.
|
||||
__global__ void compress_bin_ellpack_k(
|
||||
__global__ void CompressBinEllpackKernel(
|
||||
common::CompressedBufferWriter wr,
|
||||
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
|
||||
const size_t* __restrict__ row_ptrs, // row offset of input data
|
||||
@@ -366,8 +371,9 @@ __global__ void compress_bin_ellpack_k(
|
||||
unsigned int null_gidx_value) {
|
||||
size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
if (irow >= n_rows || ifeature >= row_stride)
|
||||
if (irow >= n_rows || ifeature >= row_stride) {
|
||||
return;
|
||||
}
|
||||
int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
|
||||
unsigned int bin = null_gidx_value;
|
||||
if (ifeature < row_length) {
|
||||
@@ -380,8 +386,9 @@ __global__ void compress_bin_ellpack_k(
|
||||
// Assigning the bin in current entry.
|
||||
// S.t.: fvalue < feature_cuts[bin]
|
||||
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
|
||||
if (bin >= ncuts)
|
||||
if (bin >= ncuts) {
|
||||
bin = ncuts - 1;
|
||||
}
|
||||
// Add the number of bins in previous features.
|
||||
bin += cut_rows[feature];
|
||||
}
|
||||
@@ -419,7 +426,7 @@ struct Segment {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
|
||||
Segment() : begin(0), end(0) {}
|
||||
Segment() : begin{0}, end{0} {}
|
||||
|
||||
Segment(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
CHECK_GE(end, begin);
|
||||
@@ -487,7 +494,9 @@ struct GPUHistBuilderBase {
|
||||
// Manage memory for a single GPU
|
||||
template <typename GradientSumT>
|
||||
struct DeviceShard {
|
||||
int device_id_;
|
||||
int n_bins;
|
||||
int device_id;
|
||||
|
||||
dh::BulkAllocator<dh::MemoryType::kDevice> ba;
|
||||
|
||||
/*! \brief HistCutMatrix stored in device. */
|
||||
@@ -498,14 +507,12 @@ struct DeviceShard {
|
||||
dh::DVec<bst_float> min_fvalue;
|
||||
/*! \brief Cut. */
|
||||
dh::DVec<bst_float> gidx_fvalue_map;
|
||||
} cut_;
|
||||
} d_cut;
|
||||
|
||||
/*! \brief Range of rows for each node. */
|
||||
std::vector<Segment> ridx_segments;
|
||||
DeviceHistogram<GradientSumT> hist;
|
||||
|
||||
/*! \brief global index of histogram, which is stored in ELLPack format. */
|
||||
dh::DVec<common::CompressedByteT> gidx_buffer;
|
||||
/*! \brief row length for ELLPack. */
|
||||
size_t row_stride;
|
||||
common::CompressedIterator<uint32_t> gidx;
|
||||
@@ -526,6 +533,8 @@ struct DeviceShard {
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
dh::DVec<GradientPair> node_sum_gradients_d;
|
||||
/*! \brief global index of histogram, which is stored in ELLPack format. */
|
||||
dh::DVec<common::CompressedByteT> gidx_buffer;
|
||||
/*! \brief row offset in SparsePage (the input data). */
|
||||
thrust::device_vector<size_t> row_ptrs;
|
||||
/*! \brief On-device feature set, only actually used on one of the devices */
|
||||
@@ -534,7 +543,6 @@ struct DeviceShard {
|
||||
bst_uint row_begin_idx;
|
||||
bst_uint row_end_idx;
|
||||
bst_uint n_rows;
|
||||
int n_bins;
|
||||
|
||||
TrainParam param;
|
||||
bool prediction_cache_initialised;
|
||||
@@ -544,21 +552,21 @@ struct DeviceShard {
|
||||
std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
|
||||
|
||||
// TODO(canonizer): do add support multi-batch DMatrix here
|
||||
DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end,
|
||||
DeviceShard(int _device_id, bst_uint row_begin, bst_uint row_end,
|
||||
TrainParam _param)
|
||||
: device_id_(device_id),
|
||||
: device_id(_device_id),
|
||||
row_begin_idx(row_begin),
|
||||
row_end_idx(row_end),
|
||||
row_stride(0),
|
||||
n_rows(row_end - row_begin),
|
||||
n_bins(0),
|
||||
n_bins{0},
|
||||
null_gidx_value(0),
|
||||
param(_param),
|
||||
param(std::move(_param)),
|
||||
prediction_cache_initialised(false) {}
|
||||
|
||||
/* Init row_ptrs and row_stride */
|
||||
void InitRowPtrs(const SparsePage& row_batch) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
const auto& offset_vec = row_batch.offset.HostVector();
|
||||
row_ptrs.resize(n_rows + 1);
|
||||
thrust::copy(offset_vec.data() + row_begin_idx,
|
||||
@@ -589,12 +597,11 @@ struct DeviceShard {
|
||||
|
||||
void CreateHistIndices(const SparsePage& row_batch);
|
||||
|
||||
~DeviceShard() {
|
||||
}
|
||||
~DeviceShard() = default;
|
||||
|
||||
// Reset values for each update iteration
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
position.CurrentDVec().Fill(0);
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
@@ -603,8 +610,8 @@ struct DeviceShard {
|
||||
|
||||
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
|
||||
ridx_segments.front() = Segment(0, ridx.Size());
|
||||
this->gpair.copy(dh_gpair->tcbegin(device_id_),
|
||||
dh_gpair->tcend(device_id_));
|
||||
this->gpair.copy(dh_gpair->tcbegin(device_id),
|
||||
dh_gpair->tcend(device_id));
|
||||
SubsampleGradientPair(&gpair, param.subsample, row_begin_idx);
|
||||
hist.Reset();
|
||||
}
|
||||
@@ -612,7 +619,7 @@ struct DeviceShard {
|
||||
DeviceSplitCandidate EvaluateSplit(int nidx,
|
||||
const std::vector<int>& feature_set,
|
||||
ValueConstraint value_constraint) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
auto d_split_candidates = temp_memory.GetSpan<DeviceSplitCandidate>(feature_set.size());
|
||||
feature_set_d.resize(feature_set.size());
|
||||
auto d_features = common::Span<int>(feature_set_d.data().get(),
|
||||
@@ -622,14 +629,13 @@ struct DeviceShard {
|
||||
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
|
||||
|
||||
// One block for each feature
|
||||
int constexpr BLOCK_THREADS = 256;
|
||||
EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
|
||||
<<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>(
|
||||
hist.GetNodeHistogram(nidx), d_features, node,
|
||||
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
|
||||
cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
|
||||
d_split_candidates, value_constraint,
|
||||
monotone_constraints.GetSpan());
|
||||
int constexpr kBlockThreads = 256;
|
||||
EvaluateSplitKernel<kBlockThreads, GradientSumT>
|
||||
<<<uint32_t(feature_set.size()), kBlockThreads, 0>>>
|
||||
(hist.GetNodeHistogram(nidx), d_features, node,
|
||||
d_cut.feature_segments.GetSpan(), d_cut.min_fvalue.GetSpan(),
|
||||
d_cut.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
|
||||
d_split_candidates, value_constraint, monotone_constraints.GetSpan());
|
||||
|
||||
std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
|
||||
dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
|
||||
@@ -655,7 +661,7 @@ struct DeviceShard {
|
||||
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
|
||||
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
|
||||
|
||||
dh::LaunchN(device_id_, hist.n_bins, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(device_id, n_bins, [=] __device__(size_t idx) {
|
||||
d_node_hist_subtraction[idx] =
|
||||
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
|
||||
});
|
||||
@@ -673,7 +679,7 @@ struct DeviceShard {
|
||||
int64_t split_gidx, bool default_dir_left, bool is_dense,
|
||||
int fidx_begin, // cut.row_ptr[fidx]
|
||||
int fidx_end) { // cut.row_ptr[fidx + 1]
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
Segment segment = ridx_segments[nidx];
|
||||
bst_uint* d_ridx = ridx.Current();
|
||||
int* d_position = position.Current();
|
||||
@@ -681,7 +687,7 @@ struct DeviceShard {
|
||||
size_t row_stride = this->row_stride;
|
||||
// Launch 1 thread for each row
|
||||
dh::LaunchN<1, 128>(
|
||||
device_id_, segment.Size(), [=] __device__(bst_uint idx) {
|
||||
device_id, segment.Size(), [=] __device__(bst_uint idx) {
|
||||
idx += segment.begin;
|
||||
bst_uint ridx = d_ridx[idx];
|
||||
auto row_begin = row_stride * ridx;
|
||||
@@ -724,7 +730,7 @@ struct DeviceShard {
|
||||
|
||||
/*! \brief Sort row indices according to position. */
|
||||
void SortPositionAndCopy(const Segment& segment, int left_nidx, int right_nidx,
|
||||
size_t left_count) {
|
||||
size_t left_count) {
|
||||
SortPosition(
|
||||
&temp_memory,
|
||||
common::Span<int>(position.Current() + segment.begin, segment.Size()),
|
||||
@@ -737,14 +743,14 @@ struct DeviceShard {
|
||||
const auto d_position_other = position.other() + segment.begin;
|
||||
const auto d_ridx_current = ridx.Current() + segment.begin;
|
||||
const auto d_ridx_other = ridx.other() + segment.begin;
|
||||
dh::LaunchN(device_id_, segment.Size(), [=] __device__(size_t idx) {
|
||||
dh::LaunchN(device_id, segment.Size(), [=] __device__(size_t idx) {
|
||||
d_position_current[idx] = d_position_other[idx];
|
||||
d_ridx_current[idx] = d_ridx_other[idx];
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
if (!prediction_cache_initialised) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
prediction_cache.Data(), out_preds_d,
|
||||
@@ -764,7 +770,7 @@ struct DeviceShard {
|
||||
auto d_prediction_cache = prediction_cache.Data();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id_, prediction_cache.Size(), [=] __device__(int local_idx) {
|
||||
device_id, prediction_cache.Size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
|
||||
d_prediction_cache[d_ridx[local_idx]] +=
|
||||
@@ -799,7 +805,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
if (grid_size <= 0) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id));
|
||||
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
|
||||
(shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist.data(), d_gpair,
|
||||
segment_begin, n_elements);
|
||||
@@ -819,7 +825,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
size_t const row_stride = shard->row_stride;
|
||||
int const null_gidx_value = shard->null_gidx_value;
|
||||
|
||||
dh::LaunchN(shard->device_id_, n_elements, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) {
|
||||
int ridx = d_ridx[(idx / row_stride) + segment.begin];
|
||||
// lookup the index (bin) of histogram.
|
||||
int gidx = d_gidx[ridx * row_stride + idx % row_stride];
|
||||
@@ -834,31 +840,31 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
template <typename GradientSumT>
|
||||
inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
|
||||
n_bins = hmat.row_ptr.back();
|
||||
null_gidx_value = hmat.row_ptr.back();
|
||||
n_bins = hmat.NumBins();
|
||||
null_gidx_value = hmat.NumBins();
|
||||
|
||||
int max_nodes =
|
||||
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
|
||||
|
||||
ba.Allocate(device_id_,
|
||||
ba.Allocate(device_id,
|
||||
&gpair, n_rows,
|
||||
&ridx, n_rows,
|
||||
&position, n_rows,
|
||||
&prediction_cache, n_rows,
|
||||
&node_sum_gradients_d, max_nodes,
|
||||
&cut_.feature_segments, hmat.row_ptr.size(),
|
||||
&cut_.gidx_fvalue_map, hmat.cut.size(),
|
||||
&cut_.min_fvalue, hmat.min_val.size(),
|
||||
&d_cut.feature_segments, hmat.row_ptr.size(),
|
||||
&d_cut.gidx_fvalue_map, hmat.cut.size(),
|
||||
&d_cut.min_fvalue, hmat.min_val.size(),
|
||||
&monotone_constraints, param.monotone_constraints.size());
|
||||
cut_.gidx_fvalue_map = hmat.cut;
|
||||
cut_.min_fvalue = hmat.min_val;
|
||||
cut_.feature_segments = hmat.row_ptr;
|
||||
d_cut.gidx_fvalue_map = hmat.cut;
|
||||
d_cut.min_fvalue = hmat.min_val;
|
||||
d_cut.feature_segments = hmat.row_ptr;
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
|
||||
node_sum_gradients.resize(max_nodes);
|
||||
ridx_segments.resize(max_nodes);
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
|
||||
// allocate compressed bin data
|
||||
int num_symbols = n_bins + 1;
|
||||
@@ -870,7 +876,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
ba.Allocate(device_id_, &gidx_buffer, compressed_size_bytes);
|
||||
ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
|
||||
gidx_buffer.Fill(0);
|
||||
|
||||
int nbits = common::detail::SymbolBits(num_symbols);
|
||||
@@ -882,7 +888,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
// check if we can use shared memory for building histograms
|
||||
// (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding)
|
||||
auto histogram_size = sizeof(GradientSumT) * null_gidx_value;
|
||||
auto max_smem = dh::MaxSharedMemory(device_id_);
|
||||
auto max_smem = dh::MaxSharedMemory(device_id);
|
||||
if (histogram_size <= max_smem) {
|
||||
hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);
|
||||
} else {
|
||||
@@ -890,7 +896,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
}
|
||||
|
||||
// Init histogram
|
||||
hist.Init(device_id_, hmat.row_ptr.back());
|
||||
hist.Init(device_id, hmat.NumBins());
|
||||
}
|
||||
|
||||
|
||||
@@ -900,7 +906,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
|
||||
// bin and compress entries in batches of rows
|
||||
size_t gpu_batch_nrows =
|
||||
std::min
|
||||
(dh::TotalMemory(device_id_) / (16 * row_stride * sizeof(Entry)),
|
||||
(dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
|
||||
static_cast<size_t>(n_rows));
|
||||
const std::vector<Entry>& data_vec = row_batch.data.HostVector();
|
||||
|
||||
@@ -924,12 +930,12 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
|
||||
const dim3 block3(32, 8, 1); // 256 threads
|
||||
const dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
|
||||
dh::DivRoundUp(row_stride, block3.y), 1);
|
||||
compress_bin_ellpack_k<<<grid3, block3>>>
|
||||
CompressBinEllpackKernel<<<grid3, block3>>>
|
||||
(common::CompressedBufferWriter(num_symbols),
|
||||
gidx_buffer.Data(),
|
||||
row_ptrs.data().get() + batch_row_begin,
|
||||
entries_d.data().get(),
|
||||
cut_.gidx_fvalue_map.Data(), cut_.feature_segments.Data(),
|
||||
d_cut.gidx_fvalue_map.Data(), d_cut.feature_segments.Data(),
|
||||
batch_row_begin, batch_nrows,
|
||||
row_ptrs[batch_row_begin],
|
||||
row_stride, null_gidx_value);
|
||||
@@ -948,7 +954,7 @@ class GPUHistMakerSpecialised{
|
||||
public:
|
||||
struct ExpandEntry;
|
||||
|
||||
GPUHistMakerSpecialised() : initialised_(false), p_last_fmat_(nullptr) {}
|
||||
GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {}
|
||||
void Init(
|
||||
const std::vector<std::pair<std::string, std::string>>& args) {
|
||||
param_.InitAllowUnknown(args);
|
||||
@@ -977,8 +983,8 @@ class GPUHistMakerSpecialised{
|
||||
ValueConstraint::Init(¶m_, dmat->Info().num_col_);
|
||||
// build tree
|
||||
try {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
this->UpdateTree(gpair, dmat, trees[i]);
|
||||
for (xgboost::RegTree* tree : trees) {
|
||||
this->UpdateTree(gpair, dmat, tree);
|
||||
}
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
} catch (const std::exception& e) {
|
||||
@@ -1056,14 +1062,16 @@ class GPUHistMakerSpecialised{
|
||||
}
|
||||
|
||||
void AllReduceHist(int nidx) {
|
||||
if (shards_.size() == 1 && !rabit::IsDistributed()) return;
|
||||
if (shards_.size() == 1 && !rabit::IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
monitor_.StartCuda("AllReduce");
|
||||
|
||||
reducer_.GroupStart();
|
||||
for (auto& shard : shards_) {
|
||||
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
|
||||
reducer_.AllReduceSum(
|
||||
dist_.Devices().Index(shard->device_id_),
|
||||
dist_.Devices().Index(shard->device_id),
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
n_bins_ * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
|
||||
@@ -1141,14 +1149,14 @@ class GPUHistMakerSpecialised{
|
||||
}
|
||||
|
||||
void InitRoot(RegTree* p_tree) {
|
||||
constexpr int root_nidx = 0;
|
||||
constexpr int kRootNIdx = 0;
|
||||
// Sum gradients
|
||||
std::vector<GradientPair> tmp_sums(shards_.size());
|
||||
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id));
|
||||
tmp_sums[i] = dh::SumReduction(
|
||||
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
|
||||
});
|
||||
@@ -1156,35 +1164,36 @@ class GPUHistMakerSpecialised{
|
||||
GradientPair sum_gradient =
|
||||
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>((GradientPair::ValueT*)&sum_gradient, 2);
|
||||
rabit::Allreduce<rabit::op::Sum>(
|
||||
reinterpret_cast<GradientPair::ValueT*>(&sum_gradient), 2);
|
||||
|
||||
// Generate root histogram
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
shard->BuildHist(root_nidx);
|
||||
shard->BuildHist(kRootNIdx);
|
||||
});
|
||||
|
||||
this->AllReduceHist(root_nidx);
|
||||
this->AllReduceHist(kRootNIdx);
|
||||
|
||||
// Remember root stats
|
||||
p_tree->Stat(root_nidx).sum_hess = sum_gradient.GetHess();
|
||||
p_tree->Stat(kRootNIdx).sum_hess = sum_gradient.GetHess();
|
||||
auto weight = CalcWeight(param_, sum_gradient);
|
||||
p_tree->Stat(root_nidx).base_weight = weight;
|
||||
(*p_tree)[root_nidx].SetLeaf(param_.learning_rate * weight);
|
||||
p_tree->Stat(kRootNIdx).base_weight = weight;
|
||||
(*p_tree)[kRootNIdx].SetLeaf(param_.learning_rate * weight);
|
||||
|
||||
// Store sum gradients
|
||||
for (auto& shard : shards_) {
|
||||
shard->node_sum_gradients[root_nidx] = sum_gradient;
|
||||
shard->node_sum_gradients[kRootNIdx] = sum_gradient;
|
||||
}
|
||||
|
||||
// Initialise root constraint
|
||||
node_value_constraints_.resize(p_tree->GetNodes().size());
|
||||
|
||||
// Generate first split
|
||||
auto split = this->EvaluateSplit(root_nidx, p_tree);
|
||||
auto split = this->EvaluateSplit(kRootNIdx, p_tree);
|
||||
qexpand_->push(
|
||||
ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), split, 0));
|
||||
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
|
||||
}
|
||||
|
||||
void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
@@ -1302,15 +1311,16 @@ class GPUHistMakerSpecialised{
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
|
||||
return false;
|
||||
monitor_.StartCuda("UpdatePredictionCache");
|
||||
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
p_out_preds->Reshard(dist_.Devices());
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
shard->UpdatePredictionCache(
|
||||
p_out_preds->DevicePointer(shard->device_id_));
|
||||
p_out_preds->DevicePointer(shard->device_id));
|
||||
});
|
||||
monitor_.StopCuda("UpdatePredictionCache");
|
||||
return true;
|
||||
@@ -1321,15 +1331,23 @@ class GPUHistMakerSpecialised{
|
||||
int depth;
|
||||
DeviceSplitCandidate split;
|
||||
uint64_t timestamp;
|
||||
ExpandEntry(int nid, int depth, const DeviceSplitCandidate& split,
|
||||
uint64_t timestamp)
|
||||
: nid(nid), depth(depth), split(split), timestamp(timestamp) {}
|
||||
ExpandEntry(int _nid, int _depth, const DeviceSplitCandidate _split,
|
||||
uint64_t _timestamp) :
|
||||
nid{_nid}, depth{_depth}, split(std::move(_split)),
|
||||
timestamp{_timestamp} {}
|
||||
bool IsValid(const TrainParam& param, int num_leaves) const {
|
||||
if (split.loss_chg <= kRtEps) return false;
|
||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0)
|
||||
if (split.loss_chg <= kRtEps) {
|
||||
return false;
|
||||
if (param.max_depth > 0 && depth == param.max_depth) return false;
|
||||
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
|
||||
}
|
||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||
return false;
|
||||
}
|
||||
if (param.max_depth > 0 && depth == param.max_depth) {
|
||||
return false;
|
||||
}
|
||||
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1365,28 +1383,36 @@ class GPUHistMakerSpecialised{
|
||||
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
|
||||
}
|
||||
}
|
||||
TrainParam param_;
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
common::HistCutMatrix hmat_;
|
||||
common::GHistIndexMatrix gmat_;
|
||||
MetaInfo* info_;
|
||||
|
||||
TrainParam param_; // NOLINT
|
||||
common::HistCutMatrix hmat_; // NOLINT
|
||||
MetaInfo* info_; // NOLINT
|
||||
|
||||
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT
|
||||
common::ColumnSampler column_sampler_; // NOLINT
|
||||
|
||||
std::vector<ValueConstraint> node_value_constraints_; // NOLINT
|
||||
|
||||
private:
|
||||
bool initialised_;
|
||||
|
||||
int n_devices_;
|
||||
int n_bins_;
|
||||
|
||||
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_;
|
||||
common::ColumnSampler column_sampler_;
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
common::GHistIndexMatrix gmat_;
|
||||
|
||||
using ExpandQueue = std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
|
||||
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
||||
std::unique_ptr<ExpandQueue> qexpand_;
|
||||
common::Monitor monitor_;
|
||||
dh::AllReducer reducer_;
|
||||
std::vector<ValueConstraint> node_value_constraints_;
|
||||
/*! List storing device id. */
|
||||
std::vector<int> device_list_;
|
||||
|
||||
DMatrix* p_last_fmat_;
|
||||
GPUDistribution dist_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
/*! List storing device id. */
|
||||
std::vector<int> device_list_;
|
||||
};
|
||||
|
||||
class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
Reference in New Issue
Block a user