Refactor split evaluation kernel (#8073)
parent cb40bbdadd
commit 1be09848a7
@@ -1949,7 +1949,7 @@ class LDGIterator {
   const T *ptr_;

  public:
-  explicit LDGIterator(const T *ptr) : ptr_(ptr) {}
+  XGBOOST_DEVICE explicit LDGIterator(const T *ptr) : ptr_(ptr) {}
   __device__ T operator[](std::size_t idx) const {
     DeviceWordT tmp[kNumWords];
     static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
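For readers unfamiliar with `dh::LDGIterator`: it wraps pointer reads in `__ldg`, CUDA's read-only data cache load (compute capability 3.5+), and the change above adds `XGBOOST_DEVICE` (xgboost's macro for `__host__ __device__` when building with CUDA) so the iterator can also be constructed on the host. A minimal sketch of the idea, simplified to `float` and with an illustrative name — the real class decomposes an arbitrary word-sized `T` into `DeviceWordT` words, as the `static_assert` in the hunk shows:

    #include <cstddef>

    // Illustrative sketch only: route loads through the read-only data cache.
    class ReadOnlyFloatIter {
      const float *ptr_;

     public:
      __host__ __device__ explicit ReadOnlyFloatIter(const float *ptr) : ptr_(ptr) {}
      __device__ float operator[](std::size_t idx) const {
        // __ldg promises the data is immutable for the kernel's lifetime,
        // letting the hardware use the read-only (texture) cache path.
        return __ldg(ptr_ + idx);
      }
    };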
@@ -22,202 +22,199 @@ XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
                                        bst_feature_t fidx,
                                        TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
                                        bool &missing_left_out) {  // NOLINT
-  float parent_gain = CalcGain(param, parent_sum);
-  float missing_left_gain = evaluator.CalcSplitGain(param, nidx, fidx, GradStats(scan + missing),
-                                                    GradStats(parent_sum - (scan + missing)));
-  float missing_right_gain =
-      evaluator.CalcSplitGain(param, nidx, fidx, GradStats(scan), GradStats(parent_sum - scan));
+  const auto left_sum = scan + missing;
+  float missing_left_gain =
+      evaluator.CalcSplitGain(param, nidx, fidx, left_sum, parent_sum - left_sum);
+  float missing_right_gain = evaluator.CalcSplitGain(param, nidx, fidx, scan, parent_sum - scan);

-  if (missing_left_gain > missing_right_gain) {
-    missing_left_out = true;
-    return missing_left_gain - parent_gain;
-  } else {
-    missing_left_out = false;
-    return missing_right_gain - parent_gain;
-  }
+  missing_left_out = missing_left_gain > missing_right_gain;
+  return missing_left_out ? missing_left_gain : missing_right_gain;
 }

-/*!
- * \brief
- *
- * \tparam ReduceT BlockReduce Type.
- * \tparam TempStorage Cub Shared memory
- *
- * \param begin
- * \param end
- * \param temp_storage Shared memory for intermediate result.
- */
-template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT, typename GradientSumT>
-__device__ GradientSumT ReduceFeature(common::Span<const GradientSumT> feature_histogram,
-                                      TempStorageT *temp_storage) {
-  __shared__ cub::Uninitialized<GradientSumT> uninitialized_sum;
-  GradientSumT &shared_sum = uninitialized_sum.Alias();
-
-  GradientSumT local_sum = GradientSumT();
-  // For loop sums features into one block size
-  auto begin = feature_histogram.data();
-  auto end = begin + feature_histogram.size();
-  for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
-    bool thread_active = itr + threadIdx.x < end;
-    // Scan histogram
-    GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT();
-    local_sum += bin;
-  }
-  local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
-  // Reduction result is stored in thread 0.
-  if (threadIdx.x == 0) {
-    shared_sum = local_sum;
-  }
-  cub::CTA_SYNC();
-  return shared_sum;
-}
-
-/*! \brief Find the thread with best gain. */
-template <int BLOCK_THREADS, typename ReduceT, typename ScanT, typename MaxReduceT,
-          typename TempStorageT, typename GradientSumT, SplitType type>
-__device__ void EvaluateFeature(
-    int fidx, const EvaluateSplitInputs &inputs, const EvaluateSplitSharedInputs &shared_inputs,
-    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
-    common::Span<bst_feature_t> sorted_idx, size_t offset,
-    DeviceSplitCandidate *best_split,  // shared memory storing best split
-    TempStorageT *temp_storage         // temp memory for cub operations
-    ) {
-  // Use pointer from cut to indicate begin and end of bins for each feature.
-  uint32_t gidx_begin = shared_inputs.feature_segments[fidx];    // beginning bin
-  uint32_t gidx_end = shared_inputs.feature_segments[fidx + 1];  // end bin for i^th feature
-  auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
-
-  // Sum histogram bins for current feature
-  GradientSumT const feature_sum =
-      ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(feature_hist, temp_storage);
-
-  GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum};
-  float const null_gain = -std::numeric_limits<bst_float>::infinity();
-
-  SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
-  for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) {
-    bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
-
-    auto calc_bin_value = [&]() {
-      GradientSumT bin;
-      switch (type) {
-        case kOneHot: {
-          auto rest =
-              thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
-          bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing};  // NOLINT
-          break;
-        }
-        case kNum: {
-          bin =
-              thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
-          ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
-          break;
-        }
-        case kPart: {
-          auto rest = thread_active
-                          ? inputs.gradient_histogram[sorted_idx[scan_begin + threadIdx.x] - offset]
-                          : GradientSumT();
-          // No min value for cat feature, use inclusive scan.
-          ScanT(temp_storage->scan).InclusiveScan(rest, rest, cub::Sum(), prefix_op);
-          bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing};  // NOLINT
-          break;
-        }
-      }
-      return bin;
-    };
-    auto bin = calc_bin_value();
-    // Whether the gradient of missing values is put to the left side.
-    bool missing_left = true;
-    float gain = null_gain;
-    if (thread_active) {
-      gain = LossChangeMissing(GradientPairPrecise{bin}, missing, inputs.parent_sum,
-                               shared_inputs.param, inputs.nidx, fidx, evaluator, missing_left);
-    }
-
-    __syncthreads();
-
-    // Find thread with best gain
-    cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
-    cub::KeyValuePair<int, float> best =
-        MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
-
-    __shared__ cub::KeyValuePair<int, float> block_max;
-    if (threadIdx.x == 0) {
-      block_max = best;
-    }
-
-    cub::CTA_SYNC();
-
-    // Best thread updates the split
-    if (threadIdx.x == block_max.key) {
-      switch (type) {
-        case kNum: {
-          // Use pointer from cut to indicate begin and end of bins for each feature.
-          uint32_t gidx_begin = shared_inputs.feature_segments[fidx];  // beginning bin
-          int split_gidx = (scan_begin + threadIdx.x) - 1;
-          float fvalue;
-          if (split_gidx < static_cast<int>(gidx_begin)) {
-            fvalue = shared_inputs.min_fvalue[fidx];
-          } else {
-            fvalue = shared_inputs.feature_values[split_gidx];
-          }
-          GradientPairPrecise left =
-              missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
-          GradientPairPrecise right = inputs.parent_sum - left;
-          best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
-                             false, shared_inputs.param);
-          break;
-        }
-        case kOneHot: {
-          int32_t split_gidx = (scan_begin + threadIdx.x);
-          float fvalue = shared_inputs.feature_values[split_gidx];
-          GradientPairPrecise left =
-              missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
-          GradientPairPrecise right = inputs.parent_sum - left;
-          best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
-                             true, shared_inputs.param);
-          break;
-        }
-        case kPart: {
-          int32_t split_gidx = (scan_begin + threadIdx.x);
-          float fvalue = shared_inputs.feature_values[split_gidx];
-          GradientPairPrecise left =
-              missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
-          GradientPairPrecise right = inputs.parent_sum - left;
-          auto best_thresh = block_max.key;  // index of best threshold inside a feature.
-          best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
-                             right, true, shared_inputs.param);
-          break;
-        }
-      }
-    }
-    cub::CTA_SYNC();
-  }
-}
-
-template <int BLOCK_THREADS, typename GradientSumT>
-__global__ __launch_bounds__(BLOCK_THREADS) void EvaluateSplitsKernel(
-    bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,
-    const EvaluateSplitSharedInputs shared_inputs, common::Span<bst_feature_t> sorted_idx,
-    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
-    common::Span<DeviceSplitCandidate> out_candidates) {
-  // KeyValuePair here used as threadIdx.x -> gain_value
+// This kernel uses block_size == warp_size. This is an unusually small block size for a cuda kernel
+// - normally a larger block size is preferred to increase the number of resident warps on each SM
+// (occupancy). In the below case each thread has a very large amount of work per thread relative to
+// typical cuda kernels. Thus the SM can be highly utilised by a small number of threads. It was
+// discovered by experiments that a small block size here is significantly faster. Furthermore,
+// using only a single warp, synchronisation barriers are eliminated and broadcasts can be performed
+// using warp intrinsics instead of slower shared memory.
+template <int kBlockSize>
+class EvaluateSplitAgent {
+ public:
   using ArgMaxT = cub::KeyValuePair<int, float>;
-  using BlockScanT = cub::BlockScan<GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
-  using MaxReduceT = cub::BlockReduce<ArgMaxT, BLOCK_THREADS>;
-  using SumReduceT = cub::BlockReduce<GradientSumT, BLOCK_THREADS>;
-  union TempStorage {
+  using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
+  using MaxReduceT = cub::WarpReduce<ArgMaxT>;
+  using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
+  struct TempStorage {
     typename BlockScanT::TempStorage scan;
     typename MaxReduceT::TempStorage max_reduce;
     typename SumReduceT::TempStorage sum_reduce;
   };

+  const int fidx;
+  const int nidx;
+  const float min_fvalue;
+  const uint32_t gidx_begin;  // beginning bin
+  const uint32_t gidx_end;    // end bin for i^th feature
+  const dh::LDGIterator<float> feature_values;
+  const GradientPairPrecise *node_histogram;
+  const GradientPairPrecise parent_sum;
+  const GradientPairPrecise missing;
+  const GPUTrainingParam &param;
+  const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator;
+  TempStorage *temp_storage;
+  SumCallbackOp<GradientPairPrecise> prefix_op;
+  static float constexpr kNullGain = -std::numeric_limits<bst_float>::infinity();
+
+  __device__ EvaluateSplitAgent(TempStorage *temp_storage, int fidx,
+                                const EvaluateSplitInputs &inputs,
+                                const EvaluateSplitSharedInputs &shared_inputs,
+                                const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
+      : temp_storage(temp_storage),
+        nidx(inputs.nidx),
+        fidx(fidx),
+        min_fvalue(__ldg(shared_inputs.min_fvalue.data() + fidx)),
+        gidx_begin(__ldg(shared_inputs.feature_segments.data() + fidx)),
+        gidx_end(__ldg(shared_inputs.feature_segments.data() + fidx + 1)),
+        feature_values(shared_inputs.feature_values.data()),
+        node_histogram(inputs.gradient_histogram.data()),
+        parent_sum(dh::LDGIterator<GradientPairPrecise>(&inputs.parent_sum)[0]),
+        param(shared_inputs.param),
+        evaluator(evaluator),
+        missing(parent_sum - ReduceFeature()) {
+    static_assert(kBlockSize == 32,
+                  "This kernel relies on the assumption block_size == warp_size");
+  }
+  __device__ GradientPairPrecise ReduceFeature() {
+    GradientPairPrecise local_sum;
+    for (int idx = gidx_begin + threadIdx.x; idx < gidx_end; idx += kBlockSize) {
+      local_sum += LoadGpair(node_histogram + idx);
+    }
+    local_sum = SumReduceT(temp_storage->sum_reduce).Sum(local_sum);
+    // Broadcast result from thread 0
+    return {__shfl_sync(0xffffffff, local_sum.GetGrad(), 0),
+            __shfl_sync(0xffffffff, local_sum.GetHess(), 0)};
+  }
+
+  // Load using efficient 128-bit vector load instruction
+  __device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairPrecise *ptr) {
+    static_assert(sizeof(GradientPairPrecise) == sizeof(float4),
+                  "Vector type size does not match gradient pair size.");
+    float4 tmp = *reinterpret_cast<const float4 *>(ptr);
+    return *reinterpret_cast<const GradientPairPrecise *>(&tmp);
+  }
+
+  __device__ __forceinline__ void Numerical(DeviceSplitCandidate *__restrict__ best_split) {
+    for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
+      bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
+      GradientPairPrecise bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
+                                              : GradientPairPrecise();
+      BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
+      // Whether the gradient of missing values is put to the left side.
+      bool missing_left = true;
+      float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
+                                                     evaluator, missing_left)
+                                 : kNullGain;
+
+      // Find thread with best gain
+      auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
+      // This reduce result is only valid in thread 0
+      // broadcast to the rest of the warp
+      auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
+
+      // Best thread updates the split
+      if (threadIdx.x == best_thread) {
+        // Use pointer from cut to indicate begin and end of bins for each feature.
+        int split_gidx = (scan_begin + threadIdx.x) - 1;
+        float fvalue =
+            split_gidx < static_cast<int>(gidx_begin) ? min_fvalue : feature_values[split_gidx];
+        GradientPairPrecise left = missing_left ? bin + missing : bin;
+        GradientPairPrecise right = parent_sum - left;
+        best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
+                           false, param);
+      }
+    }
+  }
+
+  __device__ __forceinline__ void OneHot(DeviceSplitCandidate *__restrict__ best_split) {
+    for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
+      bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
+
+      auto rest = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
+                                : GradientPairPrecise();
+      GradientPairPrecise bin = parent_sum - rest - missing;
+      // Whether the gradient of missing values is put to the left side.
+      bool missing_left = true;
+      float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
+                                                     evaluator, missing_left)
+                                 : kNullGain;
+
+      // Find thread with best gain
+      auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
+      // This reduce result is only valid in thread 0
+      // broadcast to the rest of the warp
+      auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
+      // Best thread updates the split
+      if (threadIdx.x == best_thread) {
+        int32_t split_gidx = (scan_begin + threadIdx.x);
+        float fvalue = feature_values[split_gidx];
+        GradientPairPrecise left = missing_left ? bin + missing : bin;
+        GradientPairPrecise right = parent_sum - left;
+        best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
+                           true, param);
+      }
+    }
+  }
+  __device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split,
+                                            bst_feature_t *__restrict__ sorted_idx,
+                                            std::size_t offset) {
+    for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
+      bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
+
+      auto rest = thread_active
+                      ? LoadGpair(node_histogram + sorted_idx[scan_begin + threadIdx.x] - offset)
+                      : GradientPairPrecise();
+      // No min value for cat feature, use inclusive scan.
+      BlockScanT(temp_storage->scan).InclusiveSum(rest, rest, prefix_op);
+      GradientPairPrecise bin = parent_sum - rest - missing;
+
+      // Whether the gradient of missing values is put to the left side.
+      bool missing_left = true;
+      float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
+                                                     evaluator, missing_left)
+                                 : kNullGain;
+
+      // Find thread with best gain
+      auto best =
+          MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
+      // This reduce result is only valid in thread 0
+      // broadcast to the rest of the warp
+      auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
+      // Best thread updates the split
+      if (threadIdx.x == best_thread) {
+        GradientPairPrecise left = missing_left ? bin + missing : bin;
+        GradientPairPrecise right = parent_sum - left;
+        auto best_thresh =
+            threadIdx.x + (scan_begin - gidx_begin);  // index of best threshold inside a feature.
+        best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
+                           right, true, param);
+      }
+    }
+  }
+};
+
+template <int kBlockSize>
+__global__ __launch_bounds__(kBlockSize) void EvaluateSplitsKernel(
+    bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,
+    const EvaluateSplitSharedInputs shared_inputs, common::Span<bst_feature_t> sorted_idx,
+    const TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
+    common::Span<DeviceSplitCandidate> out_candidates) {
   // Aligned && shared storage for best_split
   __shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
   DeviceSplitCandidate &best_split = uninitialized_split.Alias();
-  __shared__ TempStorage temp_storage;
-
   if (threadIdx.x == 0) {
     best_split = DeviceSplitCandidate();
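The comment block introducing `EvaluateSplitAgent` is the heart of this rewrite: with exactly one warp per block, block-wide reductions become warp reductions, and the winner can be shared through a register shuffle instead of a shared-memory store plus `__syncthreads()`. A standalone sketch of that argmax-and-broadcast pattern (kernel name and launch shape are illustrative; like the agent's `static_assert`, it assumes a 32-thread block):

    #include <cmath>
    #include <cub/cub.cuh>

    __global__ void WarpArgMaxDemo(const float *gain, int n, int *out_best) {
      using ArgMaxT = cub::KeyValuePair<int, float>;
      using MaxReduceT = cub::WarpReduce<ArgMaxT>;
      __shared__ typename MaxReduceT::TempStorage temp;

      int tid = static_cast<int>(threadIdx.x);
      float g = tid < n ? gain[tid] : -INFINITY;
      ArgMaxT best = MaxReduceT(temp).Reduce({tid, g}, cub::ArgMax());
      // Reduce() defines the result only in lane 0; a shuffle broadcasts it
      // to the whole warp with no barrier, since one warp runs in lockstep.
      int best_thread = __shfl_sync(0xffffffff, best.key, 0);
      if (tid == best_thread) {
        *out_best = best_thread;
      }
    }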
@@ -232,25 +229,23 @@ __global__ __launch_bounds__(BLOCK_THREADS) void EvaluateSplitsKernel(

   int fidx = inputs.feature_set[blockIdx.x % number_active_features];

+  using AgentT = EvaluateSplitAgent<kBlockSize>;
+  __shared__ typename AgentT::TempStorage temp_storage;
+  AgentT agent(&temp_storage, fidx, inputs, shared_inputs, evaluator);
+
   if (common::IsCat(shared_inputs.feature_types, fidx)) {
     auto n_bins_in_feat =
         shared_inputs.feature_segments[fidx + 1] - shared_inputs.feature_segments[fidx];
     if (common::UseOneHot(n_bins_in_feat, shared_inputs.param.max_cat_to_onehot)) {
-      EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
-                      kOneHot>(fidx, inputs, shared_inputs, evaluator, sorted_idx, 0, &best_split,
-                               &temp_storage);
+      agent.OneHot(&best_split);
     } else {
       auto total_bins = shared_inputs.feature_values.size();
       size_t offset = total_bins * input_idx;
       auto node_sorted_idx = sorted_idx.subspan(offset, total_bins);
-      EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
-                      kPart>(fidx, inputs, shared_inputs, evaluator, node_sorted_idx, offset,
-                             &best_split, &temp_storage);
+      agent.Partition(&best_split, node_sorted_idx.data(), offset);
     }
   } else {
-    EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
-                    kNum>(fidx, inputs, shared_inputs, evaluator, sorted_idx, 0, &best_split,
-                          &temp_storage);
+    agent.Numerical(&best_split);
   }

   cub::CTA_SYNC();
@@ -310,8 +305,7 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
   }
 }

-template <typename GradientSumT>
-void GPUHistEvaluator<GradientSumT>::LaunchEvaluateSplits(
+void GPUHistEvaluator::LaunchEvaluateSplits(
     bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,
     EvaluateSplitSharedInputs shared_inputs,
     TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
@@ -326,7 +320,7 @@ void GPUHistEvaluator<GradientSumT>::LaunchEvaluateSplits(
   // One block for each feature
   uint32_t constexpr kBlockThreads = 32;
   dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
-      EvaluateSplitsKernel<kBlockThreads, GradientSumT>, number_active_features, d_inputs,
+      EvaluateSplitsKernel<kBlockThreads>, number_active_features, d_inputs,
       shared_inputs, this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size()),
       evaluator, dh::ToSpan(feature_best_splits));

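The launch above pairs one 32-thread block with one (node, feature) combination: `combined_num_features` is presumably `d_inputs.size()` times the per-node active feature count, matching the `blockIdx.x % number_active_features` lookup and the `input_idx`-based offset inside the kernel. Roughly (helper name hypothetical):

    // gridDim.x = #nodes * number_active_features, blockDim.x = 32.
    __device__ void MapBlockToWork(int number_active_features,
                                   int *input_idx, int *feat_pos) {
      *input_idx = blockIdx.x / number_active_features;  // which node's inputs
      *feat_pos = blockIdx.x % number_active_features;   // index into feature_set
    }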
@@ -345,8 +339,7 @@ void GPUHistEvaluator<GradientSumT>::LaunchEvaluateSplits(
                       reduce_offset + 1);
 }

-template <typename GradientSumT>
-void GPUHistEvaluator<GradientSumT>::CopyToHost(const std::vector<bst_node_t> &nidx) {
+void GPUHistEvaluator::CopyToHost(const std::vector<bst_node_t> &nidx) {
   if (!has_categoricals_) return;
   auto d_cats = this->DeviceCatStorage(nidx);
   auto h_cats = this->HostCatStorage(nidx);
@@ -360,8 +353,7 @@ void GPUHistEvaluator<GradientSumT>::CopyToHost(const std::vector<bst_node_t> &n
   }
 }

-template <typename GradientSumT>
-void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
+void GPUHistEvaluator::EvaluateSplits(
     const std::vector<bst_node_t> &nidx, bst_feature_t number_active_features,
     common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
     common::Span<GPUExpandEntry> out_entries) {
@@ -379,6 +371,10 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
   dh::LaunchN(d_inputs.size(), [=] __device__(size_t i) mutable {
     auto const input = d_inputs[i];
     auto &split = out_splits[i];
+    // Subtract parent gain here
+    // As it is constant, this is more efficient than doing it during every split evaluation
+    float parent_gain = CalcGain(shared_inputs.param, input.parent_sum);
+    split.loss_chg -= parent_gain;
     auto fidx = out_splits[i].findex;

     if (split.is_cat) {
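The four added lines rest on a small algebraic fact. Writing G, H for the parent's gradient/Hessian sums and G_L(t), H_L(t) for those of the left child under candidate split t, the loss change is

    \mathrm{loss\_chg}(t) =
        \underbrace{\mathrm{Gain}\bigl(G_L(t), H_L(t)\bigr)
                  + \mathrm{Gain}\bigl(G - G_L(t),\, H - H_L(t)\bigr)}_{\text{depends on } t}
      \;-\; \underbrace{\mathrm{Gain}(G, H)}_{\text{constant per node}}

Since the parent term does not depend on t, it cannot change which candidate maximises the expression, so the kernel ranks candidates by the child gains alone and this lambda subtracts the constant once per node — which is also why `LossChangeMissing` no longer subtracts `parent_gain` itself.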
@@ -400,8 +396,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
   this->CopyToHost(nidx);
 }

-template <typename GradientSumT>
-GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
+GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(
     EvaluateSplitInputs input, EvaluateSplitSharedInputs shared_inputs) {
   dh::device_vector<EvaluateSplitInputs> inputs = std::vector<EvaluateSplitInputs>{input};
   dh::TemporaryArray<GPUExpandEntry> out_entries(1);
@@ -413,6 +408,5 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
   return root_entry;
 }

-template class GPUHistEvaluator<GradientPairPrecise>;
 }  // namespace tree
 }  // namespace xgboost
@@ -51,7 +51,6 @@ struct CatAccessor {
   }
 };

-template <typename GradientSumT>
 class GPUHistEvaluator {
   using CatST = common::CatBitField::value_type;  // categorical storage type
   // use pinned memory to stage the categories, used for sort based splits.
@@ -14,8 +14,7 @@

 namespace xgboost {
 namespace tree {
-template <typename GradientSumT>
-void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
+void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
                                            common::Span<FeatureType const> ft,
                                            bst_feature_t n_features, TrainParam const &param,
                                            int32_t device) {
@@ -68,8 +67,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
   }
 }

-template <typename GradientSumT>
-common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
+common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
     common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
     TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
   dh::XGBCachingDeviceAllocator<char> alloc;
@@ -128,7 +126,5 @@ common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
   return dh::ToSpan(cat_sorted_idx_);
 }

-template class GPUHistEvaluator<GradientPair>;
-template class GPUHistEvaluator<GradientPairPrecise>;
 }  // namespace tree
 }  // namespace xgboost
@@ -256,7 +256,7 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
 // calculate the cost of loss function
 template <typename TrainingParams, typename T>
 XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
-  if (sum_hess < p.min_child_weight) {
+  if (sum_hess < p.min_child_weight || sum_hess <= 0.0) {
     return T(0.0);
   }
   if (p.max_delta_step == 0.0f) {
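The extra `sum_hess <= 0.0` condition matters because the gain expression divides by the regularised Hessian sum; with `min_child_weight` at zero, the old check let an empty node's `sum_hess == 0` through. In the unconstrained L2-only case (see `CalcGainGivenWeight` in the next file), the guarded quantity is roughly

    \mathrm{Gain}(G, H) = \frac{\mathrm{ThresholdL1}(G, \alpha)^{2}}{H + \lambda},

which is only well-defined when the denominator is strictly positive.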
@@ -71,11 +71,10 @@ class TreeEvaluator {
     const float* upper;
     bool has_constraint;

-    XGBOOST_DEVICE float CalcSplitGain(const ParamT &param, bst_node_t nidx,
-                                       bst_feature_t fidx,
-                                       tree::GradStats const& left,
-                                       tree::GradStats const& right) const {
-      int constraint = constraints[fidx];
+    template <typename GradientSumT>
+    XGBOOST_DEVICE float CalcSplitGain(const ParamT& param, bst_node_t nidx, bst_feature_t fidx,
+                                       GradientSumT const& left, GradientSumT const& right) const {
+      int constraint = has_constraint ? constraints[fidx] : 0;
       const float negative_infinity = -std::numeric_limits<float>::infinity();
       float wleft = this->CalcWeight(nidx, param, left);
       float wright = this->CalcWeight(nidx, param, right);
@@ -92,8 +91,9 @@ class TreeEvaluator {
       }
     }

+    template <typename GradientSumT>
     XGBOOST_DEVICE float CalcWeight(bst_node_t nodeid, const ParamT &param,
-                                    tree::GradStats const& stats) const {
+                                    GradientSumT const& stats) const {
       float w = ::xgboost::tree::CalcWeight(param, stats);
       if (!has_constraint) {
         return w;
@@ -118,21 +118,32 @@ class TreeEvaluator {
       return ::xgboost::tree::CalcWeight(param, stats);
     }

-    XGBOOST_DEVICE float
-    CalcGainGivenWeight(ParamT const &p, tree::GradStats const& stats, float w) const {
+    // Fast floating point division instruction on device
+    XGBOOST_DEVICE float Divide(float a, float b) const {
+#ifdef __CUDA_ARCH__
+      return __fdividef(a, b);
+#else
+      return a / b;
+#endif
+    }
+
+    template <typename GradientSumT>
+    XGBOOST_DEVICE float CalcGainGivenWeight(ParamT const& p, GradientSumT const& stats,
+                                             float w) const {
       if (stats.GetHess() <= 0) {
         return .0f;
       }
       // Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error.
       if (p.max_delta_step == 0.0f && has_constraint == false) {
-        return common::Sqr(ThresholdL1(stats.sum_grad, p.reg_alpha)) /
-               (stats.sum_hess + p.reg_lambda);
+        return Divide(common::Sqr(ThresholdL1(stats.GetGrad(), p.reg_alpha)),
+                      (stats.GetHess() + p.reg_lambda));
       }
-      return tree::CalcGainGivenWeight<ParamT, float>(p, stats.sum_grad,
-                                                      stats.sum_hess, w);
+      return tree::CalcGainGivenWeight<ParamT, float>(p, stats.GetGrad(),
+                                                      stats.GetHess(), w);
     }

+    template <typename GradientSumT>
     XGBOOST_DEVICE float CalcGain(bst_node_t nid, ParamT const &p,
-                                  tree::GradStats const& stats) const {
+                                  GradientSumT const& stats) const {
       return this->CalcGainGivenWeight(p, stats, this->CalcWeight(nid, p, stats));
     }
   };
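A note on the `Divide` helper above: `__fdividef` is CUDA's fast single-precision division intrinsic, which (per the CUDA programming guide) trades a couple of ulp of accuracy, and well-defined results for denominators of extreme magnitude, for a much cheaper instruction sequence. That trade is acceptable here because the quotient only ranks split candidates, and the `__CUDA_ARCH__` guard keeps host-side compilation on exact IEEE division.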
@@ -171,7 +171,7 @@ class DeviceHistogramStorage {
 template <typename GradientSumT>
 struct GPUHistMakerDevice {
  private:
-  GPUHistEvaluator<GradientSumT> evaluator_;
+  GPUHistEvaluator evaluator_;
   Context const* ctx_;

  public:
@@ -62,7 +62,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
       cuts.min_vals_.ConstDeviceSpan(),
   };

-  GPUHistEvaluator<GradientPairPrecise> evaluator{
+  GPUHistEvaluator evaluator{
       tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
   evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
@@ -109,7 +109,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
       dh::ToSpan(feature_min_values),
   };

-  GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, feature_set.size(), 0);
+  GPUHistEvaluator evaluator(tparam, feature_set.size(), 0);
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;

   EXPECT_EQ(result.findex, 0);
@@ -121,7 +121,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {

 TEST(GpuHist, EvaluateSingleSplitEmpty) {
   TrainParam tparam = ZeroParam();
-  GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, 1, 0);
+  GPUHistEvaluator evaluator(tparam, 1, 0);
   DeviceSplitCandidate result =
       evaluator.EvaluateSingleSplit(EvaluateSplitInputs{}, EvaluateSplitSharedInputs{}).split;
   EXPECT_EQ(result.findex, -1);
@@ -159,7 +159,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
       dh::ToSpan(feature_min_values),
   };

-  GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, feature_min_values.size(), 0);
+  GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;

   EXPECT_EQ(result.findex, 1);
@@ -199,7 +199,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
       dh::ToSpan(feature_min_values),
   };

-  GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, feature_min_values.size(), 0);
+  GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;

   EXPECT_EQ(result.findex, 0);
@@ -246,7 +246,7 @@ TEST(GpuHist, EvaluateSplits) {
       dh::ToSpan(feature_min_values),
   };

-  GPUHistEvaluator<GradientPairPrecise> evaluator{
+  GPUHistEvaluator evaluator{
       tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
   dh::device_vector<EvaluateSplitInputs> inputs = std::vector<EvaluateSplitInputs>{input_left, input_right};
   evaluator.LaunchEvaluateSplits(input_left.feature_set.size(), dh::ToSpan(inputs), shared_inputs, evaluator.GetEvaluator(),
@@ -263,7 +263,7 @@ TEST(GpuHist, EvaluateSplits) {

 TEST_F(TestPartitionBasedSplit, GpuHist) {
   dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}};
-  GPUHistEvaluator<GradientPairPrecise> evaluator{param_,
+  GPUHistEvaluator evaluator{param_,
                                                   static_cast<bst_feature_t>(info_.num_col_), 0};

   cuts_.cut_ptrs_.SetDevice(0);
@@ -287,5 +287,6 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
   auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
   ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
 }
+
 }  // namespace tree
 }  // namespace xgboost
@@ -43,6 +43,8 @@ class TestPartitionBasedSplit : public ::testing::Test {
     auto &h_vals = cuts_.cut_values_.HostVector();
     h_vals.resize(n_bins_);
     std::iota(h_vals.begin(), h_vals.end(), 0.0);

+    cuts_.min_vals_.Resize(1);
+
     hist_.Init(cuts_.TotalBins());
     hist_.AddHistRow(0);