Refactor split valuation kernel (#8073)
This commit is contained in:
@@ -22,202 +22,199 @@ XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
|
||||
bst_feature_t fidx,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
bool &missing_left_out) { // NOLINT
|
||||
float parent_gain = CalcGain(param, parent_sum);
|
||||
float missing_left_gain = evaluator.CalcSplitGain(param, nidx, fidx, GradStats(scan + missing),
|
||||
GradStats(parent_sum - (scan + missing)));
|
||||
float missing_right_gain =
|
||||
evaluator.CalcSplitGain(param, nidx, fidx, GradStats(scan), GradStats(parent_sum - scan));
|
||||
const auto left_sum = scan + missing;
|
||||
float missing_left_gain =
|
||||
evaluator.CalcSplitGain(param, nidx, fidx, left_sum, parent_sum - left_sum);
|
||||
float missing_right_gain = evaluator.CalcSplitGain(param, nidx, fidx, scan, parent_sum - scan);
|
||||
|
||||
if (missing_left_gain > missing_right_gain) {
|
||||
missing_left_out = true;
|
||||
return missing_left_gain - parent_gain;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_gain - parent_gain;
|
||||
}
|
||||
missing_left_out = missing_left_gain > missing_right_gain;
|
||||
return missing_left_out?missing_left_gain:missing_right_gain;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief
|
||||
*
|
||||
* \tparam ReduceT BlockReduce Type.
|
||||
* \tparam TempStorage Cub Shared memory
|
||||
*
|
||||
* \param begin
|
||||
* \param end
|
||||
* \param temp_storage Shared memory for intermediate result.
|
||||
*/
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT, typename GradientSumT>
|
||||
__device__ GradientSumT ReduceFeature(common::Span<const GradientSumT> feature_histogram,
|
||||
TempStorageT *temp_storage) {
|
||||
__shared__ cub::Uninitialized<GradientSumT> uninitialized_sum;
|
||||
GradientSumT &shared_sum = uninitialized_sum.Alias();
|
||||
|
||||
GradientSumT local_sum = GradientSumT();
|
||||
// For loop sums features into one block size
|
||||
auto begin = feature_histogram.data();
|
||||
auto end = begin + feature_histogram.size();
|
||||
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
|
||||
bool thread_active = itr + threadIdx.x < end;
|
||||
// Scan histogram
|
||||
GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT();
|
||||
local_sum += bin;
|
||||
}
|
||||
local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
|
||||
// Reduction result is stored in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
shared_sum = local_sum;
|
||||
}
|
||||
cub::CTA_SYNC();
|
||||
return shared_sum;
|
||||
}
|
||||
|
||||
/*! \brief Find the thread with best gain. */
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename ScanT, typename MaxReduceT,
|
||||
typename TempStorageT, typename GradientSumT, SplitType type>
|
||||
__device__ void EvaluateFeature(
|
||||
int fidx, const EvaluateSplitInputs &inputs, const EvaluateSplitSharedInputs &shared_inputs,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<bst_feature_t> sorted_idx, size_t offset,
|
||||
DeviceSplitCandidate *best_split, // shared memory storing best split
|
||||
TempStorageT *temp_storage // temp memory for cub operations
|
||||
) {
|
||||
// Use pointer from cut to indicate begin and end of bins for each feature.
|
||||
uint32_t gidx_begin = shared_inputs.feature_segments[fidx]; // beginning bin
|
||||
uint32_t gidx_end = shared_inputs.feature_segments[fidx + 1]; // end bin for i^th feature
|
||||
auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
|
||||
|
||||
// Sum histogram bins for current feature
|
||||
GradientSumT const feature_sum =
|
||||
ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(feature_hist, temp_storage);
|
||||
|
||||
GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum};
|
||||
float const null_gain = -std::numeric_limits<bst_float>::infinity();
|
||||
|
||||
SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
auto calc_bin_value = [&]() {
|
||||
GradientSumT bin;
|
||||
switch (type) {
|
||||
case kOneHot: {
|
||||
auto rest =
|
||||
thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
|
||||
bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing}; // NOLINT
|
||||
break;
|
||||
}
|
||||
case kNum: {
|
||||
bin =
|
||||
thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
|
||||
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
break;
|
||||
}
|
||||
case kPart: {
|
||||
auto rest = thread_active
|
||||
? inputs.gradient_histogram[sorted_idx[scan_begin + threadIdx.x] - offset]
|
||||
: GradientSumT();
|
||||
// No min value for cat feature, use inclusive scan.
|
||||
ScanT(temp_storage->scan).InclusiveScan(rest, rest, cub::Sum(), prefix_op);
|
||||
bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing}; // NOLINT
|
||||
break;
|
||||
}
|
||||
}
|
||||
return bin;
|
||||
};
|
||||
auto bin = calc_bin_value();
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = null_gain;
|
||||
if (thread_active) {
|
||||
gain = LossChangeMissing(GradientPairPrecise{bin}, missing, inputs.parent_sum,
|
||||
shared_inputs.param, inputs.nidx, fidx, evaluator, missing_left);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find thread with best gain
|
||||
cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
|
||||
cub::KeyValuePair<int, float> best =
|
||||
MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
|
||||
|
||||
__shared__ cub::KeyValuePair<int, float> block_max;
|
||||
if (threadIdx.x == 0) {
|
||||
block_max = best;
|
||||
}
|
||||
|
||||
cub::CTA_SYNC();
|
||||
|
||||
// Best thread updates the split
|
||||
if (threadIdx.x == block_max.key) {
|
||||
switch (type) {
|
||||
case kNum: {
|
||||
// Use pointer from cut to indicate begin and end of bins for each feature.
|
||||
uint32_t gidx_begin = shared_inputs.feature_segments[fidx]; // beginning bin
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue;
|
||||
if (split_gidx < static_cast<int>(gidx_begin)) {
|
||||
fvalue = shared_inputs.min_fvalue[fidx];
|
||||
} else {
|
||||
fvalue = shared_inputs.feature_values[split_gidx];
|
||||
}
|
||||
GradientPairPrecise left =
|
||||
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
|
||||
GradientPairPrecise right = inputs.parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
|
||||
false, shared_inputs.param);
|
||||
break;
|
||||
}
|
||||
case kOneHot: {
|
||||
int32_t split_gidx = (scan_begin + threadIdx.x);
|
||||
float fvalue = shared_inputs.feature_values[split_gidx];
|
||||
GradientPairPrecise left =
|
||||
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
|
||||
GradientPairPrecise right = inputs.parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
|
||||
true, shared_inputs.param);
|
||||
break;
|
||||
}
|
||||
case kPart: {
|
||||
int32_t split_gidx = (scan_begin + threadIdx.x);
|
||||
float fvalue = shared_inputs.feature_values[split_gidx];
|
||||
GradientPairPrecise left =
|
||||
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
|
||||
GradientPairPrecise right = inputs.parent_sum - left;
|
||||
auto best_thresh = block_max.key; // index of best threshold inside a feature.
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
|
||||
right, true, shared_inputs.param);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
cub::CTA_SYNC();
|
||||
}
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS, typename GradientSumT>
|
||||
__global__ __launch_bounds__(BLOCK_THREADS) void EvaluateSplitsKernel(
|
||||
bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,
|
||||
const EvaluateSplitSharedInputs shared_inputs, common::Span<bst_feature_t> sorted_idx,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<DeviceSplitCandidate> out_candidates) {
|
||||
// KeyValuePair here used as threadIdx.x -> gain_value
|
||||
// This kernel uses block_size == warp_size. This is an unusually small block size for a cuda kernel
|
||||
// - normally a larger block size is preferred to increase the number of resident warps on each SM
|
||||
// (occupancy). In the below case each thread has a very large amount of work per thread relative to
|
||||
// typical cuda kernels. Thus the SM can be highly utilised by a small number of threads. It was
|
||||
// discovered by experiments that a small block size here is significantly faster. Furthermore,
|
||||
// using only a single warp, synchronisation barriers are eliminated and broadcasts can be performed
|
||||
// using warp intrinsics instead of slower shared memory.
|
||||
template <int kBlockSize>
|
||||
class EvaluateSplitAgent {
|
||||
public:
|
||||
using ArgMaxT = cub::KeyValuePair<int, float>;
|
||||
using BlockScanT = cub::BlockScan<GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
using MaxReduceT = cub::BlockReduce<ArgMaxT, BLOCK_THREADS>;
|
||||
|
||||
using SumReduceT = cub::BlockReduce<GradientSumT, BLOCK_THREADS>;
|
||||
|
||||
union TempStorage {
|
||||
using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
|
||||
using MaxReduceT =
|
||||
cub::WarpReduce<ArgMaxT>;
|
||||
using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
|
||||
struct TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
typename MaxReduceT::TempStorage max_reduce;
|
||||
typename SumReduceT::TempStorage sum_reduce;
|
||||
};
|
||||
|
||||
const int fidx;
|
||||
const int nidx;
|
||||
const float min_fvalue;
|
||||
const uint32_t gidx_begin; // beginning bin
|
||||
const uint32_t gidx_end; // end bin for i^th feature
|
||||
const dh::LDGIterator<float> feature_values;
|
||||
const GradientPairPrecise *node_histogram;
|
||||
const GradientPairPrecise parent_sum;
|
||||
const GradientPairPrecise missing;
|
||||
const GPUTrainingParam &param;
|
||||
const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator;
|
||||
TempStorage *temp_storage;
|
||||
SumCallbackOp<GradientPairPrecise> prefix_op;
|
||||
static float constexpr kNullGain = -std::numeric_limits<bst_float>::infinity();
|
||||
|
||||
__device__ EvaluateSplitAgent(TempStorage *temp_storage, int fidx,
|
||||
const EvaluateSplitInputs &inputs,
|
||||
const EvaluateSplitSharedInputs &shared_inputs,
|
||||
const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
|
||||
: temp_storage(temp_storage),
|
||||
nidx(inputs.nidx),
|
||||
fidx(fidx),
|
||||
min_fvalue(__ldg(shared_inputs.min_fvalue.data() + fidx)),
|
||||
gidx_begin(__ldg(shared_inputs.feature_segments.data() + fidx)),
|
||||
gidx_end(__ldg(shared_inputs.feature_segments.data() + fidx + 1)),
|
||||
feature_values(shared_inputs.feature_values.data()),
|
||||
node_histogram(inputs.gradient_histogram.data()),
|
||||
parent_sum(dh::LDGIterator<GradientPairPrecise>(&inputs.parent_sum)[0]),
|
||||
param(shared_inputs.param),
|
||||
evaluator(evaluator),
|
||||
missing(parent_sum - ReduceFeature()) {
|
||||
static_assert(kBlockSize == 32,
|
||||
"This kernel relies on the assumption block_size == warp_size");
|
||||
}
|
||||
__device__ GradientPairPrecise ReduceFeature() {
|
||||
GradientPairPrecise local_sum;
|
||||
for (int idx = gidx_begin + threadIdx.x; idx < gidx_end; idx += kBlockSize) {
|
||||
local_sum += LoadGpair(node_histogram + idx);
|
||||
}
|
||||
local_sum = SumReduceT(temp_storage->sum_reduce).Sum(local_sum);
|
||||
// Broadcast result from thread 0
|
||||
return {__shfl_sync(0xffffffff, local_sum.GetGrad(), 0),
|
||||
__shfl_sync(0xffffffff, local_sum.GetHess(), 0)};
|
||||
}
|
||||
|
||||
// Load using an efficient 128-bit vector load instruction
|
||||
__device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairPrecise *ptr) {
|
||||
static_assert(sizeof(GradientPairPrecise) == sizeof(float4),
|
||||
"Vector type size does not match gradient pair size.");
|
||||
float4 tmp = *reinterpret_cast<const float4 *>(ptr);
|
||||
return *reinterpret_cast<const GradientPairPrecise *>(&tmp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void Numerical(DeviceSplitCandidate *__restrict__ best_split) {
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
GradientPairPrecise bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
|
||||
: GradientPairPrecise();
|
||||
BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
|
||||
evaluator, missing_left)
|
||||
: kNullGain;
|
||||
|
||||
// Find thread with best gain
|
||||
auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
|
||||
// This reduce result is only valid in thread 0
|
||||
// broadcast to the rest of the warp
|
||||
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
|
||||
|
||||
// Best thread updates the split
|
||||
if (threadIdx.x == best_thread) {
|
||||
// Use pointer from cut to indicate begin and end of bins for each feature.
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue =
|
||||
split_gidx < static_cast<int>(gidx_begin) ? min_fvalue : feature_values[split_gidx];
|
||||
GradientPairPrecise left = missing_left ? bin + missing : bin;
|
||||
GradientPairPrecise right = parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
|
||||
false, param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void OneHot(DeviceSplitCandidate *__restrict__ best_split) {
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
auto rest = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
|
||||
: GradientPairPrecise();
|
||||
GradientPairPrecise bin = parent_sum - rest - missing;
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
|
||||
evaluator, missing_left)
|
||||
: kNullGain;
|
||||
|
||||
// Find thread with best gain
|
||||
auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
|
||||
// This reduce result is only valid in thread 0
|
||||
// broadcast to the rest of the warp
|
||||
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
|
||||
// Best thread updates the split
|
||||
if (threadIdx.x == best_thread) {
|
||||
int32_t split_gidx = (scan_begin + threadIdx.x);
|
||||
float fvalue = feature_values[split_gidx];
|
||||
GradientPairPrecise left =
|
||||
missing_left ? bin + missing : bin;
|
||||
GradientPairPrecise right = parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
|
||||
true, param);
|
||||
}
|
||||
}
|
||||
}
|
||||
__device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split,
|
||||
bst_feature_t * __restrict__ sorted_idx,
|
||||
std::size_t offset) {
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
auto rest = thread_active
|
||||
? LoadGpair(node_histogram + sorted_idx[scan_begin + threadIdx.x] - offset)
|
||||
: GradientPairPrecise();
|
||||
// No min value for cat feature, use inclusive scan.
|
||||
BlockScanT(temp_storage->scan).InclusiveSum(rest, rest, prefix_op);
|
||||
GradientPairPrecise bin = parent_sum - rest - missing;
|
||||
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
|
||||
evaluator, missing_left)
|
||||
: kNullGain;
|
||||
|
||||
|
||||
// Find thread with best gain
|
||||
auto best =
|
||||
MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
|
||||
// This reduce result is only valid in thread 0
|
||||
// broadcast to the rest of the warp
|
||||
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
|
||||
// Best thread updates the split
|
||||
if (threadIdx.x == best_thread) {
|
||||
GradientPairPrecise left = missing_left ? bin + missing : bin;
|
||||
GradientPairPrecise right = parent_sum - left;
|
||||
auto best_thresh =
|
||||
threadIdx.x + (scan_begin - gidx_begin); // index of best threshold inside a feature.
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
|
||||
right, true, param);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int kBlockSize>
|
||||
__global__ __launch_bounds__(kBlockSize) void EvaluateSplitsKernel(
|
||||
bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,
|
||||
const EvaluateSplitSharedInputs shared_inputs, common::Span<bst_feature_t> sorted_idx,
|
||||
const TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<DeviceSplitCandidate> out_candidates) {
|
||||
// Aligned && shared storage for best_split
|
||||
__shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
|
||||
DeviceSplitCandidate &best_split = uninitialized_split.Alias();
|
||||
__shared__ TempStorage temp_storage;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
best_split = DeviceSplitCandidate();
|
||||
@@ -232,25 +229,23 @@ __global__ __launch_bounds__(BLOCK_THREADS) void EvaluateSplitsKernel(
|
||||
|
||||
int fidx = inputs.feature_set[blockIdx.x % number_active_features];
|
||||
|
||||
using AgentT = EvaluateSplitAgent<kBlockSize>;
|
||||
__shared__ typename AgentT::TempStorage temp_storage;
|
||||
AgentT agent(&temp_storage, fidx, inputs, shared_inputs, evaluator);
|
||||
|
||||
if (common::IsCat(shared_inputs.feature_types, fidx)) {
|
||||
auto n_bins_in_feat =
|
||||
shared_inputs.feature_segments[fidx + 1] - shared_inputs.feature_segments[fidx];
|
||||
if (common::UseOneHot(n_bins_in_feat, shared_inputs.param.max_cat_to_onehot)) {
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
|
||||
kOneHot>(fidx, inputs, shared_inputs, evaluator, sorted_idx, 0, &best_split,
|
||||
&temp_storage);
|
||||
agent.OneHot(&best_split);
|
||||
} else {
|
||||
auto total_bins = shared_inputs.feature_values.size();
|
||||
size_t offset = total_bins * input_idx;
|
||||
auto node_sorted_idx = sorted_idx.subspan(offset, total_bins);
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
|
||||
kPart>(fidx, inputs, shared_inputs, evaluator, node_sorted_idx, offset,
|
||||
&best_split, &temp_storage);
|
||||
agent.Partition(&best_split, node_sorted_idx.data(), offset);
|
||||
}
|
||||
} else {
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
|
||||
kNum>(fidx, inputs, shared_inputs, evaluator, sorted_idx, 0, &best_split,
|
||||
&temp_storage);
|
||||
agent.Numerical(&best_split);
|
||||
}
|
||||
|
||||
cub::CTA_SYNC();
|
||||
@@ -310,8 +305,7 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::LaunchEvaluateSplits(
|
||||
void GPUHistEvaluator::LaunchEvaluateSplits(
|
||||
bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,
|
||||
EvaluateSplitSharedInputs shared_inputs,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
@@ -326,7 +320,7 @@ void GPUHistEvaluator<GradientSumT>::LaunchEvaluateSplits(
|
||||
// One block for each feature
|
||||
uint32_t constexpr kBlockThreads = 32;
|
||||
dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
|
||||
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, number_active_features, d_inputs,
|
||||
EvaluateSplitsKernel<kBlockThreads>, number_active_features, d_inputs,
|
||||
shared_inputs, this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size()),
|
||||
evaluator, dh::ToSpan(feature_best_splits));
|
||||
|
||||
@@ -345,8 +339,7 @@ void GPUHistEvaluator<GradientSumT>::LaunchEvaluateSplits(
|
||||
reduce_offset + 1);
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::CopyToHost(const std::vector<bst_node_t> &nidx) {
|
||||
void GPUHistEvaluator::CopyToHost(const std::vector<bst_node_t> &nidx) {
|
||||
if (!has_categoricals_) return;
|
||||
auto d_cats = this->DeviceCatStorage(nidx);
|
||||
auto h_cats = this->HostCatStorage(nidx);
|
||||
@@ -360,8 +353,7 @@ void GPUHistEvaluator<GradientSumT>::CopyToHost(const std::vector<bst_node_t> &n
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
|
||||
void GPUHistEvaluator::EvaluateSplits(
|
||||
const std::vector<bst_node_t> &nidx, bst_feature_t number_active_features,
|
||||
common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
|
||||
common::Span<GPUExpandEntry> out_entries) {
|
||||
@@ -379,6 +371,10 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
|
||||
dh::LaunchN(d_inputs.size(), [=] __device__(size_t i) mutable {
|
||||
auto const input = d_inputs[i];
|
||||
auto &split = out_splits[i];
|
||||
// Subtract parent gain here
|
||||
// As it is constant, this is more efficient than doing it during every split evaluation
|
||||
float parent_gain = CalcGain(shared_inputs.param, input.parent_sum);
|
||||
split.loss_chg -= parent_gain;
|
||||
auto fidx = out_splits[i].findex;
|
||||
|
||||
if (split.is_cat) {
|
||||
@@ -400,8 +396,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
|
||||
this->CopyToHost(nidx);
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
|
||||
GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(
|
||||
EvaluateSplitInputs input, EvaluateSplitSharedInputs shared_inputs) {
|
||||
dh::device_vector<EvaluateSplitInputs> inputs = std::vector<EvaluateSplitInputs>{input};
|
||||
dh::TemporaryArray<GPUExpandEntry> out_entries(1);
|
||||
@@ -413,6 +408,5 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
|
||||
return root_entry;
|
||||
}
|
||||
|
||||
template class GPUHistEvaluator<GradientPairPrecise>;
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -51,7 +51,6 @@ struct CatAccessor {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
class GPUHistEvaluator {
|
||||
using CatST = common::CatBitField::value_type; // categorical storage type
|
||||
// use pinned memory to stage the categories, used for sort based splits.
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
|
||||
common::Span<FeatureType const> ft,
|
||||
bst_feature_t n_features, TrainParam const &param,
|
||||
int32_t device) {
|
||||
@@ -68,8 +67,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
|
||||
common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
|
||||
common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
@@ -128,7 +126,5 @@ common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
|
||||
return dh::ToSpan(cat_sorted_idx_);
|
||||
}
|
||||
|
||||
template class GPUHistEvaluator<GradientPair>;
|
||||
template class GPUHistEvaluator<GradientPairPrecise>;
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user