Use integer gradients in gpu_hist split evaluation (#8274)
This commit is contained in:
parent
c68684ff4c
commit
210915c985
@ -264,8 +264,8 @@ using GradientPairPrecise = detail::GradientPairInternal<double>;
|
||||
* we don't accidentally use it in gain calculations.*/
|
||||
class GradientPairInt64 {
|
||||
using T = int64_t;
|
||||
T grad_;
|
||||
T hess_;
|
||||
T grad_ = 0;
|
||||
T hess_ = 0;
|
||||
|
||||
public:
|
||||
using ValueT = T;
|
||||
|
||||
@ -15,17 +15,20 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
// With constraints
|
||||
XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
|
||||
const GradientPairPrecise &missing,
|
||||
const GradientPairPrecise &parent_sum,
|
||||
XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan,
|
||||
const GradientPairInt64 &missing,
|
||||
const GradientPairInt64 &parent_sum,
|
||||
const GPUTrainingParam ¶m, bst_node_t nidx,
|
||||
bst_feature_t fidx,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
bool &missing_left_out) { // NOLINT
|
||||
bool &missing_left_out, const GradientQuantiser& quantiser) { // NOLINT
|
||||
const auto left_sum = scan + missing;
|
||||
float missing_left_gain =
|
||||
evaluator.CalcSplitGain(param, nidx, fidx, left_sum, parent_sum - left_sum);
|
||||
float missing_right_gain = evaluator.CalcSplitGain(param, nidx, fidx, scan, parent_sum - scan);
|
||||
float missing_left_gain = evaluator.CalcSplitGain(
|
||||
param, nidx, fidx, quantiser.ToFloatingPoint(left_sum),
|
||||
quantiser.ToFloatingPoint(parent_sum - left_sum));
|
||||
float missing_right_gain = evaluator.CalcSplitGain(
|
||||
param, nidx, fidx, quantiser.ToFloatingPoint(scan),
|
||||
quantiser.ToFloatingPoint(parent_sum - scan));
|
||||
|
||||
missing_left_out = missing_left_gain > missing_right_gain;
|
||||
return missing_left_out?missing_left_gain:missing_right_gain;
|
||||
@ -42,9 +45,9 @@ template <int kBlockSize>
|
||||
class EvaluateSplitAgent {
|
||||
public:
|
||||
using ArgMaxT = cub::KeyValuePair<int, float>;
|
||||
using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
|
||||
using BlockScanT = cub::BlockScan<GradientPairInt64, kBlockSize>;
|
||||
using MaxReduceT = cub::WarpReduce<ArgMaxT>;
|
||||
using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
|
||||
using SumReduceT = cub::WarpReduce<GradientPairInt64>;
|
||||
|
||||
struct TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
@ -59,67 +62,67 @@ class EvaluateSplitAgent {
|
||||
const uint32_t gidx_end; // end bin for i^th feature
|
||||
const dh::LDGIterator<float> feature_values;
|
||||
const GradientPairInt64 *node_histogram;
|
||||
const GradientQuantizer &rounding;
|
||||
const GradientPairPrecise parent_sum;
|
||||
const GradientPairPrecise missing;
|
||||
const GradientQuantiser &rounding;
|
||||
const GradientPairInt64 parent_sum;
|
||||
const GradientPairInt64 missing;
|
||||
const GPUTrainingParam ¶m;
|
||||
const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator;
|
||||
TempStorage *temp_storage;
|
||||
SumCallbackOp<GradientPairPrecise> prefix_op;
|
||||
SumCallbackOp<GradientPairInt64> prefix_op;
|
||||
static float constexpr kNullGain = -std::numeric_limits<bst_float>::infinity();
|
||||
|
||||
__device__ EvaluateSplitAgent(TempStorage *temp_storage, int fidx,
|
||||
const EvaluateSplitInputs &inputs,
|
||||
const EvaluateSplitSharedInputs &shared_inputs,
|
||||
const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
|
||||
: temp_storage(temp_storage),
|
||||
nidx(inputs.nidx),
|
||||
fidx(fidx),
|
||||
__device__ EvaluateSplitAgent(
|
||||
TempStorage *temp_storage, int fidx, const EvaluateSplitInputs &inputs,
|
||||
const EvaluateSplitSharedInputs &shared_inputs,
|
||||
const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
|
||||
: temp_storage(temp_storage), nidx(inputs.nidx), fidx(fidx),
|
||||
min_fvalue(__ldg(shared_inputs.min_fvalue.data() + fidx)),
|
||||
gidx_begin(__ldg(shared_inputs.feature_segments.data() + fidx)),
|
||||
gidx_end(__ldg(shared_inputs.feature_segments.data() + fidx + 1)),
|
||||
feature_values(shared_inputs.feature_values.data()),
|
||||
node_histogram(inputs.gradient_histogram.data()),
|
||||
rounding(shared_inputs.rounding),
|
||||
parent_sum(dh::LDGIterator<GradientPairPrecise>(&inputs.parent_sum)[0]),
|
||||
param(shared_inputs.param),
|
||||
evaluator(evaluator),
|
||||
parent_sum(dh::LDGIterator<GradientPairInt64>(&inputs.parent_sum)[0]),
|
||||
param(shared_inputs.param), evaluator(evaluator),
|
||||
missing(parent_sum - ReduceFeature()) {
|
||||
static_assert(kBlockSize == 32,
|
||||
"This kernel relies on the assumption block_size == warp_size");
|
||||
static_assert(
|
||||
kBlockSize == 32,
|
||||
"This kernel relies on the assumption block_size == warp_size");
|
||||
// There should be no missing value gradients for a dense matrix
|
||||
KERNEL_CHECK(!shared_inputs.is_dense || missing.GetQuantisedHess() == 0);
|
||||
}
|
||||
__device__ GradientPairPrecise ReduceFeature() {
|
||||
GradientPairPrecise local_sum;
|
||||
for (int idx = gidx_begin + threadIdx.x; idx < gidx_end; idx += kBlockSize) {
|
||||
__device__ GradientPairInt64 ReduceFeature() {
|
||||
GradientPairInt64 local_sum;
|
||||
for (int idx = gidx_begin + threadIdx.x; idx < gidx_end;
|
||||
idx += kBlockSize) {
|
||||
local_sum += LoadGpair(node_histogram + idx);
|
||||
}
|
||||
local_sum = SumReduceT(temp_storage->sum_reduce).Sum(local_sum);
|
||||
// Broadcast result from thread 0
|
||||
return {__shfl_sync(0xffffffff, local_sum.GetGrad(), 0),
|
||||
__shfl_sync(0xffffffff, local_sum.GetHess(), 0)};
|
||||
return {__shfl_sync(0xffffffff, local_sum.GetQuantisedGrad(), 0),
|
||||
__shfl_sync(0xffffffff, local_sum.GetQuantisedHess(), 0)};
|
||||
}
|
||||
|
||||
// Load using efficient 128 vector load instruction
|
||||
__device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairInt64 *ptr) {
|
||||
__device__ __forceinline__ GradientPairInt64 LoadGpair(const GradientPairInt64 *ptr) {
|
||||
float4 tmp = *reinterpret_cast<const float4 *>(ptr);
|
||||
auto gpair_int = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
|
||||
static_assert(sizeof(decltype(gpair_int)) == sizeof(float4),
|
||||
auto gpair = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
|
||||
static_assert(sizeof(decltype(gpair)) == sizeof(float4),
|
||||
"Vector type size does not match gradient pair size.");
|
||||
return rounding.ToFloatingPoint(gpair_int);
|
||||
return gpair;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void Numerical(DeviceSplitCandidate *__restrict__ best_split) {
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
GradientPairPrecise bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
|
||||
: GradientPairPrecise();
|
||||
GradientPairInt64 bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
|
||||
: GradientPairInt64();
|
||||
BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
|
||||
evaluator, missing_left)
|
||||
evaluator, missing_left, rounding)
|
||||
: kNullGain;
|
||||
|
||||
// Find thread with best gain
|
||||
auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
|
||||
// This reduce result is only valid in thread 0
|
||||
@ -132,10 +135,10 @@ class EvaluateSplitAgent {
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue =
|
||||
split_gidx < static_cast<int>(gidx_begin) ? min_fvalue : feature_values[split_gidx];
|
||||
GradientPairPrecise left = missing_left ? bin + missing : bin;
|
||||
GradientPairPrecise right = parent_sum - left;
|
||||
GradientPairInt64 left = missing_left ? bin + missing : bin;
|
||||
GradientPairInt64 right = parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
|
||||
false, param);
|
||||
false, param, rounding);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -145,12 +148,12 @@ class EvaluateSplitAgent {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
auto rest = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
|
||||
: GradientPairPrecise();
|
||||
GradientPairPrecise bin = parent_sum - rest - missing;
|
||||
: GradientPairInt64();
|
||||
GradientPairInt64 bin = parent_sum - rest - missing;
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
|
||||
evaluator, missing_left)
|
||||
evaluator, missing_left, rounding)
|
||||
: kNullGain;
|
||||
|
||||
// Find thread with best gain
|
||||
@ -162,10 +165,10 @@ class EvaluateSplitAgent {
|
||||
if (threadIdx.x == best_thread) {
|
||||
int32_t split_gidx = (scan_begin + threadIdx.x);
|
||||
float fvalue = feature_values[split_gidx];
|
||||
GradientPairPrecise left = missing_left ? bin + missing : bin;
|
||||
GradientPairPrecise right = parent_sum - left;
|
||||
GradientPairInt64 left = missing_left ? bin + missing : bin;
|
||||
GradientPairInt64 right = parent_sum - left;
|
||||
best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir,
|
||||
static_cast<bst_cat_t>(fvalue), fidx, left, right, param);
|
||||
static_cast<bst_cat_t>(fvalue), fidx, left, right, param, rounding);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -174,11 +177,13 @@ class EvaluateSplitAgent {
|
||||
*/
|
||||
__device__ __forceinline__ void PartitionUpdate(bst_bin_t scan_begin, bool thread_active,
|
||||
bool missing_left, bst_bin_t it,
|
||||
GradientPairPrecise const &left_sum,
|
||||
GradientPairPrecise const &right_sum,
|
||||
GradientPairInt64 const &left_sum,
|
||||
GradientPairInt64 const &right_sum,
|
||||
DeviceSplitCandidate *__restrict__ best_split) {
|
||||
auto gain =
|
||||
thread_active ? evaluator.CalcSplitGain(param, nidx, fidx, left_sum, right_sum) : kNullGain;
|
||||
auto gain = thread_active
|
||||
? evaluator.CalcSplitGain(param, nidx, fidx, rounding.ToFloatingPoint(left_sum),
|
||||
rounding.ToFloatingPoint(right_sum))
|
||||
: kNullGain;
|
||||
|
||||
// Find thread with best gain
|
||||
auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
|
||||
@ -191,7 +196,7 @@ class EvaluateSplitAgent {
|
||||
// index of best threshold inside a feature.
|
||||
auto best_thresh = it - gidx_begin;
|
||||
best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left_sum,
|
||||
right_sum, param);
|
||||
right_sum, param, rounding);
|
||||
}
|
||||
}
|
||||
/**
|
||||
@ -213,10 +218,10 @@ class EvaluateSplitAgent {
|
||||
bool thread_active = it < it_end;
|
||||
|
||||
auto right_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
|
||||
: GradientPairPrecise();
|
||||
: GradientPairInt64();
|
||||
// No min value for cat feature, use inclusive scan.
|
||||
BlockScanT(temp_storage->scan).InclusiveSum(right_sum, right_sum, prefix_op);
|
||||
GradientPairPrecise left_sum = parent_sum - right_sum;
|
||||
GradientPairInt64 left_sum = parent_sum - right_sum;
|
||||
|
||||
PartitionUpdate(scan_begin, thread_active, true, it, left_sum, right_sum, best_split);
|
||||
}
|
||||
@ -224,17 +229,17 @@ class EvaluateSplitAgent {
|
||||
// backward
|
||||
it_begin = gidx_end - 1;
|
||||
it_end = it_begin - n_bins + 1;
|
||||
prefix_op = SumCallbackOp<GradientPairPrecise>{}; // reset
|
||||
prefix_op = SumCallbackOp<GradientPairInt64>{}; // reset
|
||||
|
||||
for (bst_bin_t scan_begin = it_begin; scan_begin > it_end; scan_begin -= kBlockSize) {
|
||||
auto it = scan_begin - static_cast<bst_bin_t>(threadIdx.x);
|
||||
bool thread_active = it > it_end;
|
||||
|
||||
auto left_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
|
||||
: GradientPairPrecise();
|
||||
: GradientPairInt64();
|
||||
// No min value for cat feature, use inclusive scan.
|
||||
BlockScanT(temp_storage->scan).InclusiveSum(left_sum, left_sum, prefix_op);
|
||||
GradientPairPrecise right_sum = parent_sum - left_sum;
|
||||
GradientPairInt64 right_sum = parent_sum - left_sum;
|
||||
|
||||
PartitionUpdate(scan_begin, thread_active, false, it, left_sum, right_sum, best_split);
|
||||
}
|
||||
@ -399,22 +404,30 @@ void GPUHistEvaluator::EvaluateSplits(
|
||||
auto const input = d_inputs[i];
|
||||
auto &split = out_splits[i];
|
||||
// Subtract parent gain here
|
||||
// As it is constant, this is more efficient than doing it during every split evaluation
|
||||
float parent_gain = CalcGain(shared_inputs.param, input.parent_sum);
|
||||
// As it is constant, this is more efficient than doing it during every
|
||||
// split evaluation
|
||||
float parent_gain =
|
||||
CalcGain(shared_inputs.param,
|
||||
shared_inputs.rounding.ToFloatingPoint(input.parent_sum));
|
||||
split.loss_chg -= parent_gain;
|
||||
auto fidx = out_splits[i].findex;
|
||||
|
||||
if (split.is_cat) {
|
||||
SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i,
|
||||
device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]);
|
||||
device_cats_accessor.GetNodeCatStorage(input.nidx),
|
||||
&out_splits[i]);
|
||||
}
|
||||
|
||||
float base_weight = evaluator.CalcWeight(input.nidx, shared_inputs.param,
|
||||
GradStats{split.left_sum + split.right_sum});
|
||||
float left_weight =
|
||||
evaluator.CalcWeight(input.nidx, shared_inputs.param, GradStats{split.left_sum});
|
||||
float right_weight =
|
||||
evaluator.CalcWeight(input.nidx, shared_inputs.param, GradStats{split.right_sum});
|
||||
float base_weight =
|
||||
evaluator.CalcWeight(input.nidx, shared_inputs.param,
|
||||
shared_inputs.rounding.ToFloatingPoint(
|
||||
split.left_sum + split.right_sum));
|
||||
float left_weight = evaluator.CalcWeight(
|
||||
input.nidx, shared_inputs.param,
|
||||
shared_inputs.rounding.ToFloatingPoint(split.left_sum));
|
||||
float right_weight = evaluator.CalcWeight(
|
||||
input.nidx, shared_inputs.param,
|
||||
shared_inputs.rounding.ToFloatingPoint(split.right_sum));
|
||||
|
||||
d_entries[i] = GPUExpandEntry{input.nidx, input.depth, out_splits[i],
|
||||
base_weight, left_weight, right_weight};
|
||||
|
||||
@ -23,7 +23,7 @@ namespace tree {
|
||||
struct EvaluateSplitInputs {
|
||||
int nidx;
|
||||
int depth;
|
||||
GradientPairPrecise parent_sum;
|
||||
GradientPairInt64 parent_sum;
|
||||
common::Span<const bst_feature_t> feature_set;
|
||||
common::Span<const GradientPairInt64> gradient_histogram;
|
||||
};
|
||||
@ -31,11 +31,12 @@ struct EvaluateSplitInputs {
|
||||
// Inputs necessary for all nodes
|
||||
struct EvaluateSplitSharedInputs {
|
||||
GPUTrainingParam param;
|
||||
GradientQuantizer rounding;
|
||||
GradientQuantiser rounding;
|
||||
common::Span<FeatureType const> feature_types;
|
||||
common::Span<const uint32_t> feature_segments;
|
||||
common::Span<const float> feature_values;
|
||||
common::Span<const float> min_fvalue;
|
||||
bool is_dense;
|
||||
XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
|
||||
__device__ auto FeatureBins(bst_feature_t fidx) const {
|
||||
return feature_segments[fidx + 1] - feature_segments[fidx];
|
||||
|
||||
@ -27,7 +27,7 @@ struct GPUExpandEntry {
|
||||
left_weight{left}, right_weight{right} {}
|
||||
bool IsValid(const TrainParam& param, int num_leaves) const {
|
||||
if (split.loss_chg <= kRtEps) return false;
|
||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||
if (split.left_sum.GetQuantisedHess() == 0 || split.right_sum.GetQuantisedHess() == 0) {
|
||||
return false;
|
||||
}
|
||||
if (split.loss_chg < param.min_split_loss) {
|
||||
|
||||
@ -72,7 +72,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
|
||||
}
|
||||
};
|
||||
|
||||
GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
|
||||
GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
using T = typename GradientSumT::ValueT;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
@ -153,14 +153,14 @@ class HistogramAgent {
|
||||
const EllpackDeviceAccessor& matrix_;
|
||||
const int feature_stride_;
|
||||
const std::size_t n_elements_;
|
||||
const GradientQuantizer& rounding_;
|
||||
const GradientQuantiser& rounding_;
|
||||
|
||||
public:
|
||||
__device__ HistogramAgent(GradientPairInt64* smem_arr,
|
||||
GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
|
||||
const EllpackDeviceAccessor& matrix,
|
||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
||||
const GradientQuantizer& rounding, const GradientPair* d_gpair)
|
||||
const GradientQuantiser& rounding, const GradientPair* d_gpair)
|
||||
: smem_arr_(smem_arr),
|
||||
d_node_hist_(d_node_hist),
|
||||
d_ridx_(d_ridx.data()),
|
||||
@ -254,7 +254,7 @@ __global__ void __launch_bounds__(kBlockThreads)
|
||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
||||
GradientPairInt64* __restrict__ d_node_hist,
|
||||
const GradientPair* __restrict__ d_gpair,
|
||||
GradientQuantizer const rounding) {
|
||||
GradientQuantiser const rounding) {
|
||||
extern __shared__ char smem[];
|
||||
const FeatureGroup group = feature_groups[blockIdx.y];
|
||||
auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
|
||||
@ -272,7 +272,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> d_ridx,
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantizer rounding, bool force_global_memory) {
|
||||
GradientQuantiser rounding, bool force_global_memory) {
|
||||
// decide whether to use shared memory
|
||||
int device = 0;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
|
||||
@ -31,19 +31,24 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
|
||||
atomicAdd(y_high, sig);
|
||||
}
|
||||
|
||||
class GradientQuantizer {
|
||||
class GradientQuantiser {
|
||||
private:
|
||||
/* Convert gradient to fixed point representation. */
|
||||
GradientPairPrecise to_fixed_point_;
|
||||
/* Convert fixed point representation back to floating point. */
|
||||
GradientPairPrecise to_floating_point_;
|
||||
public:
|
||||
explicit GradientQuantizer(common::Span<GradientPair const> gpair);
|
||||
explicit GradientQuantiser(common::Span<GradientPair const> gpair);
|
||||
XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
|
||||
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
|
||||
gpair.GetHess() * to_fixed_point_.GetHess());
|
||||
return adjusted;
|
||||
}
|
||||
XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPairPrecise const& gpair) const {
|
||||
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
|
||||
gpair.GetHess() * to_fixed_point_.GetHess());
|
||||
return adjusted;
|
||||
}
|
||||
XGBOOST_DEVICE GradientPairPrecise ToFloatingPoint(const GradientPairInt64&gpair) const {
|
||||
auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
|
||||
auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
|
||||
@ -56,7 +61,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantizer rounding,
|
||||
GradientQuantiser rounding,
|
||||
bool force_global_memory = false);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/random.h"
|
||||
#include "gpu_hist/histogram.cuh"
|
||||
#include "param.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -66,54 +67,43 @@ struct DeviceSplitCandidate {
|
||||
common::CatBitField split_cats;
|
||||
bool is_cat { false };
|
||||
|
||||
GradientPairPrecise left_sum;
|
||||
GradientPairPrecise right_sum;
|
||||
GradientPairInt64 left_sum;
|
||||
GradientPairInt64 right_sum;
|
||||
|
||||
XGBOOST_DEVICE DeviceSplitCandidate() {} // NOLINT
|
||||
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE void Update(const DeviceSplitCandidate& other,
|
||||
const ParamT& param) {
|
||||
if (other.loss_chg > loss_chg &&
|
||||
other.left_sum.GetHess() >= param.min_child_weight &&
|
||||
other.right_sum.GetHess() >= param.min_child_weight) {
|
||||
*this = other;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
XGBOOST_DEVICE void SetCat(T c) {
|
||||
this->split_cats.Set(common::AsCat(c));
|
||||
fvalue = std::max(this->fvalue, static_cast<float>(c));
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in,
|
||||
float fvalue_in, int findex_in,
|
||||
GradientPairPrecise left_sum_in,
|
||||
GradientPairPrecise right_sum_in,
|
||||
bool cat,
|
||||
const GPUTrainingParam& param) {
|
||||
XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, float fvalue_in,
|
||||
int findex_in, GradientPairInt64 left_sum_in,
|
||||
GradientPairInt64 right_sum_in, bool cat,
|
||||
const GPUTrainingParam& param, const GradientQuantiser& quantiser) {
|
||||
if (loss_chg_in > loss_chg &&
|
||||
left_sum_in.GetHess() >= param.min_child_weight &&
|
||||
right_sum_in.GetHess() >= param.min_child_weight) {
|
||||
loss_chg = loss_chg_in;
|
||||
dir = dir_in;
|
||||
fvalue = fvalue_in;
|
||||
is_cat = cat;
|
||||
left_sum = left_sum_in;
|
||||
right_sum = right_sum_in;
|
||||
findex = findex_in;
|
||||
}
|
||||
quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight &&
|
||||
quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) {
|
||||
loss_chg = loss_chg_in;
|
||||
dir = dir_in;
|
||||
fvalue = fvalue_in;
|
||||
is_cat = cat;
|
||||
left_sum = left_sum_in;
|
||||
right_sum = right_sum_in;
|
||||
findex = findex_in;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Update for partition-based splits.
|
||||
*/
|
||||
XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in,
|
||||
bst_feature_t findex_in, GradientPairPrecise left_sum_in,
|
||||
GradientPairPrecise right_sum_in, GPUTrainingParam const& param) {
|
||||
if (loss_chg_in > loss_chg && left_sum_in.GetHess() >= param.min_child_weight &&
|
||||
right_sum_in.GetHess() >= param.min_child_weight) {
|
||||
bst_feature_t findex_in, GradientPairInt64 left_sum_in,
|
||||
GradientPairInt64 right_sum_in, GPUTrainingParam const& param, const GradientQuantiser& quantiser) {
|
||||
if (loss_chg_in > loss_chg &&
|
||||
quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight &&
|
||||
quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) {
|
||||
loss_chg = loss_chg_in;
|
||||
dir = dir_in;
|
||||
fvalue = std::numeric_limits<float>::quiet_NaN();
|
||||
|
||||
@ -190,12 +190,9 @@ struct GPUHistMakerDevice {
|
||||
dh::device_vector<int> monotone_constraints;
|
||||
dh::device_vector<float> update_predictions;
|
||||
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPairPrecise> node_sum_gradients;
|
||||
|
||||
TrainParam param;
|
||||
|
||||
std::unique_ptr<GradientQuantizer> histogram_rounding;
|
||||
std::unique_ptr<GradientQuantiser> quantiser;
|
||||
|
||||
dh::PinnedMemory pinned;
|
||||
dh::PinnedMemory pinned2;
|
||||
@ -227,7 +224,6 @@ struct GPUHistMakerDevice {
|
||||
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
}
|
||||
node_sum_gradients.resize(256);
|
||||
|
||||
// Init histogram
|
||||
hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
|
||||
@ -255,7 +251,6 @@ struct GPUHistMakerDevice {
|
||||
ctx_->gpu_id);
|
||||
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
|
||||
|
||||
if (d_gpair.size() != dh_gpair->Size()) {
|
||||
d_gpair.resize(dh_gpair->Size());
|
||||
@ -267,14 +262,14 @@ struct GPUHistMakerDevice {
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
|
||||
histogram_rounding.reset(new GradientQuantizer(this->gpair));
|
||||
quantiser.reset(new GradientQuantiser(this->gpair));
|
||||
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
|
||||
hist.Reset();
|
||||
}
|
||||
|
||||
GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum) {
|
||||
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
|
||||
int nidx = RegTree::kRoot;
|
||||
GPUTrainingParam gpu_param(param);
|
||||
auto sampled_features = column_sampler.GetFeatureSet(0);
|
||||
@ -285,11 +280,12 @@ struct GPUHistMakerDevice {
|
||||
EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
gpu_param,
|
||||
*histogram_rounding,
|
||||
*quantiser,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
matrix.is_dense
|
||||
};
|
||||
auto split = this->evaluator_.EvaluateSingleSplit(inputs, shared_inputs);
|
||||
return split;
|
||||
@ -304,8 +300,9 @@ struct GPUHistMakerDevice {
|
||||
auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
|
||||
auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
GPUTrainingParam{param}, *histogram_rounding, feature_types, matrix.feature_segments,
|
||||
GPUTrainingParam{param}, *quantiser, feature_types, matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map, matrix.min_fvalue,
|
||||
matrix.is_dense
|
||||
};
|
||||
dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
|
||||
for (size_t i = 0; i < candidates.size(); i++) {
|
||||
@ -350,7 +347,7 @@ struct GPUHistMakerDevice {
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
|
||||
feature_groups->DeviceAccessor(ctx_->gpu_id), gpair,
|
||||
d_ridx, d_node_hist, *histogram_rounding);
|
||||
d_ridx, d_node_hist, *quantiser);
|
||||
}
|
||||
|
||||
// Attempt to do subtraction trick
|
||||
@ -552,7 +549,7 @@ struct GPUHistMakerDevice {
|
||||
for (auto& e : candidates) {
|
||||
// Decide whether to build the left histogram or right histogram
|
||||
// Use sum of Hessian as a heuristic to select node with fewest training instances
|
||||
bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess();
|
||||
bool fewer_right = e.split.right_sum.GetQuantisedHess() < e.split.left_sum.GetQuantisedHess();
|
||||
if (fewer_right) {
|
||||
hist_nidx.emplace_back(tree[e.nid].RightChild());
|
||||
subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
|
||||
@ -598,10 +595,17 @@ struct GPUHistMakerDevice {
|
||||
<< "No training instances in this leaf!";
|
||||
}
|
||||
|
||||
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
|
||||
auto base_weight = candidate.base_weight;
|
||||
auto left_weight = candidate.left_weight * param.learning_rate;
|
||||
auto right_weight = candidate.right_weight * param.learning_rate;
|
||||
auto parent_hess = quantiser
|
||||
->ToFloatingPoint(candidate.split.left_sum +
|
||||
candidate.split.right_sum)
|
||||
.GetHess();
|
||||
auto left_hess =
|
||||
quantiser->ToFloatingPoint(candidate.split.left_sum).GetHess();
|
||||
auto right_hess =
|
||||
quantiser->ToFloatingPoint(candidate.split.right_sum).GetHess();
|
||||
|
||||
auto is_cat = candidate.split.is_cat;
|
||||
if (is_cat) {
|
||||
@ -618,26 +622,19 @@ struct GPUHistMakerDevice {
|
||||
|
||||
tree.ExpandCategorical(
|
||||
candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_hess,
|
||||
left_hess, right_hess);
|
||||
} else {
|
||||
CHECK(!common::CheckNAN(candidate.split.fvalue));
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
|
||||
candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||
candidate.split.loss_chg, parent_hess,
|
||||
left_hess, right_hess);
|
||||
}
|
||||
evaluator_.ApplyTreeSplit(candidate, p_tree);
|
||||
|
||||
const auto& parent = tree[candidate.nid];
|
||||
std::size_t max_nidx = std::max(parent.LeftChild(), parent.RightChild());
|
||||
// Grow as needed
|
||||
if (node_sum_gradients.size() <= max_nidx) {
|
||||
node_sum_gradients.resize(max_nidx * 2 + 1);
|
||||
}
|
||||
node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
|
||||
node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
|
||||
|
||||
interaction_constraints.Split(candidate.nid, parent.SplitIndex(), parent.LeftChild(),
|
||||
parent.RightChild());
|
||||
}
|
||||
@ -645,26 +642,31 @@ struct GPUHistMakerDevice {
|
||||
GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
|
||||
constexpr bst_node_t kRootNIdx = 0;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
|
||||
dh::tbegin(gpair), [] __device__(auto const& gpair) { return GradientPairPrecise{gpair}; });
|
||||
GradientPairPrecise root_sum =
|
||||
auto quantiser = *this->quantiser;
|
||||
auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
|
||||
dh::tbegin(gpair), [=] __device__(auto const &gpair) {
|
||||
return quantiser.ToFixedPoint(gpair);
|
||||
});
|
||||
GradientPairInt64 root_sum_quantised =
|
||||
dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
|
||||
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double*>(&root_sum), 2);
|
||||
GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
|
||||
using ReduceT = typename decltype(root_sum_quantised)::ValueT;
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<ReduceT *>(&root_sum_quantised), 2);
|
||||
|
||||
hist.AllocateHistograms({kRootNIdx});
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, communicator, 1);
|
||||
|
||||
// Remember root stats
|
||||
node_sum_gradients[kRootNIdx] = root_sum;
|
||||
auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
|
||||
p_tree->Stat(kRootNIdx).sum_hess = root_sum.GetHess();
|
||||
auto weight = CalcWeight(param, root_sum);
|
||||
p_tree->Stat(kRootNIdx).base_weight = weight;
|
||||
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
|
||||
|
||||
// Generate first split
|
||||
auto root_entry = this->EvaluateRootSplit(root_sum);
|
||||
auto root_entry = this->EvaluateRootSplit(root_sum_quantised);
|
||||
return root_entry;
|
||||
}
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@ TEST(GpuHist, DriverDepthWise) {
|
||||
EXPECT_TRUE(driver.Pop().empty());
|
||||
DeviceSplitCandidate split;
|
||||
split.loss_chg = 1.0f;
|
||||
split.left_sum = {0.0f, 1.0f};
|
||||
split.right_sum = {0.0f, 1.0f};
|
||||
split.left_sum = {0, 1};
|
||||
split.right_sum = {0, 1};
|
||||
GPUExpandEntry root(0, 0, split, 2.0f, 1.0f, 1.0f);
|
||||
driver.Push({root});
|
||||
EXPECT_EQ(driver.Pop().front().nid, 0);
|
||||
@ -42,8 +42,8 @@ TEST(GpuHist, DriverDepthWise) {
|
||||
|
||||
TEST(GpuHist, DriverLossGuided) {
|
||||
DeviceSplitCandidate high_gain;
|
||||
high_gain.left_sum = {0.0f, 1.0f};
|
||||
high_gain.right_sum = {0.0f, 1.0f};
|
||||
high_gain.left_sum = {0, 1};
|
||||
high_gain.right_sum = {0, 1};
|
||||
high_gain.loss_chg = 5.0f;
|
||||
DeviceSplitCandidate low_gain = high_gain;
|
||||
low_gain.loss_chg = 1.0f;
|
||||
|
||||
@ -22,10 +22,10 @@ auto ZeroParam() {
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
inline GradientQuantizer DummyRoundingFactor() {
|
||||
inline GradientQuantiser DummyRoundingFactor() {
|
||||
thrust::device_vector<GradientPair> gpair(1);
|
||||
gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000
|
||||
return GradientQuantizer(dh::ToSpan(gpair));
|
||||
return GradientQuantiser(dh::ToSpan(gpair));
|
||||
}
|
||||
|
||||
thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
|
||||
@ -48,16 +48,16 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
||||
auto d_feature_types = dh::ToSpan(feature_types);
|
||||
|
||||
EvaluateSplitInputs input{1, 0, parent_sum_, dh::ToSpan(feature_set),
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
EvaluateSplitInputs input{1, 0, quantiser.ToFixedPoint(parent_sum_), dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
d_feature_types,
|
||||
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts_.cut_values_.ConstDeviceSpan(),
|
||||
cuts_.min_vals_.ConstDeviceSpan(),
|
||||
cuts_.min_vals_.ConstDeviceSpan(), false
|
||||
};
|
||||
|
||||
GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
@ -67,7 +67,7 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
|
||||
ASSERT_EQ(result.thresh, 1);
|
||||
this->CheckResult(result.loss_chg, result.findex, result.fvalue, result.is_cat,
|
||||
result.dir == kLeftDir, result.left_sum, result.right_sum);
|
||||
result.dir == kLeftDir, quantiser.ToFloatingPoint(result.left_sum), quantiser.ToFloatingPoint(result.right_sum));
|
||||
}
|
||||
|
||||
TEST(GpuHist, PartitionBasic) {
|
||||
@ -91,10 +91,10 @@ TEST(GpuHist, PartitionBasic) {
|
||||
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
||||
cuts.SetCategorical(true, max_cat);
|
||||
d_feature_types = dh::ToSpan(feature_types);
|
||||
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -107,7 +107,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
{
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-5.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
@ -115,14 +115,13 @@ TEST(GpuHist, PartitionBasic) {
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(result.dir, kLeftDir);
|
||||
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
|
||||
{
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-7.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-7.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
@ -130,25 +129,23 @@ TEST(GpuHist, PartitionBasic) {
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(result.dir, kLeftDir);
|
||||
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
{
|
||||
// All -1.0, gain from splitting should be 0.0
|
||||
GradientPairPrecise parent_sum(-3.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-3.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||
EvaluateSplitInputs input{2, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
EXPECT_EQ(result.dir, kLeftDir);
|
||||
EXPECT_FLOAT_EQ(result.loss_chg, 0.0f);
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
// With 3.0/3.0 missing values
|
||||
// Forward, first 2 categories are selected, while the last one go to left along with missing value
|
||||
{
|
||||
GradientPairPrecise parent_sum(0.0, 6.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 6.0});
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||
EvaluateSplitInputs input{3, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
@ -156,13 +153,12 @@ TEST(GpuHist, PartitionBasic) {
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
||||
EXPECT_EQ(result.dir, kLeftDir);
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
{
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-5.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}});
|
||||
EvaluateSplitInputs input{4, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
@ -170,21 +166,19 @@ TEST(GpuHist, PartitionBasic) {
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(result.dir, kLeftDir);
|
||||
EXPECT_EQ(cats, std::bitset<32>("10100000000000000000000000000000"));
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
{
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-5.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
||||
EvaluateSplitInputs input{5, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(cats, std::bitset<32>("01000000000000000000000000000000"));
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
}
|
||||
|
||||
@ -209,9 +203,10 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
||||
cuts.SetCategorical(true, max_cat);
|
||||
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -222,7 +217,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
|
||||
{
|
||||
GradientPairPrecise parent_sum(-6.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({ {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
@ -230,12 +225,11 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
|
||||
{
|
||||
GradientPairPrecise parent_sum(-6.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({ {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
@ -243,8 +237,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
}
|
||||
|
||||
@ -269,9 +262,10 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
||||
cuts.SetCategorical(true, max_cat);
|
||||
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -282,7 +276,7 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
|
||||
{
|
||||
GradientPairPrecise parent_sum(-6.0, 3.0);
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
auto feature_histogram_a = ConvertToInteger({{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0},
|
||||
{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
thrust::device_vector<EvaluateSplitInputs> inputs(2);
|
||||
@ -303,7 +297,8 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
}
|
||||
|
||||
void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
GradientPairPrecise parent_sum(0.0, 1.0);
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
@ -327,7 +322,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -345,10 +340,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
} else {
|
||||
EXPECT_EQ(result.fvalue, 11.0);
|
||||
}
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
|
||||
parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(),
|
||||
parent_sum.GetHess());
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplit) {
|
||||
@ -360,7 +352,8 @@ TEST(GpuHist, EvaluateSingleCategoricalSplit) {
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
GradientPairPrecise parent_sum(1.0, 1.5);
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{1.0, 1.5});
|
||||
TrainParam tparam = ZeroParam();
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
@ -377,7 +370,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -390,8 +383,8 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
EXPECT_EQ(result.findex, 0);
|
||||
EXPECT_EQ(result.fvalue, 1.0);
|
||||
EXPECT_EQ(result.dir, kRightDir);
|
||||
EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5));
|
||||
EXPECT_EQ(result.right_sum, GradientPairPrecise(1.5, 1.0));
|
||||
EXPECT_EQ(result.left_sum,quantiser.ToFixedPoint(GradientPairPrecise(-0.5, 0.5)));
|
||||
EXPECT_EQ(result.right_sum, quantiser.ToFixedPoint(GradientPairPrecise(1.5, 1.0)));
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplitEmpty) {
|
||||
@ -409,7 +402,8 @@ TEST(GpuHist, EvaluateSingleSplitEmpty) {
|
||||
|
||||
// Feature 0 has a better split, but the algorithm must select feature 1
|
||||
TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
GradientPairPrecise parent_sum(0.0, 1.0);
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
tparam.UpdateAllowUnknown(Args{});
|
||||
GPUTrainingParam param{tparam};
|
||||
@ -429,7 +423,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -441,13 +435,14 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(result.fvalue, 11.0);
|
||||
EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5));
|
||||
EXPECT_EQ(result.right_sum, GradientPairPrecise(0.5, 0.5));
|
||||
EXPECT_EQ(result.left_sum,quantiser.ToFixedPoint(GradientPairPrecise(-0.5, 0.5)));
|
||||
EXPECT_EQ(result.right_sum, quantiser.ToFixedPoint(GradientPairPrecise(0.5, 0.5)));
|
||||
}
|
||||
|
||||
// Features 0 and 1 have identical gain, the algorithm must select 0
|
||||
TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
GradientPairPrecise parent_sum(0.0, 1.0);
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
tparam.UpdateAllowUnknown(Args{});
|
||||
GPUTrainingParam param{tparam};
|
||||
@ -467,7 +462,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -483,7 +478,8 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
|
||||
TEST(GpuHist, EvaluateSplits) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(2);
|
||||
GradientPairPrecise parent_sum(0.0, 1.0);
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
tparam.UpdateAllowUnknown(Args{});
|
||||
GPUTrainingParam param{tparam};
|
||||
@ -510,7 +506,7 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
dh::ToSpan(feature_histogram_right)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
quantiser,
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -543,18 +539,18 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
|
||||
evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, 0);
|
||||
|
||||
// Convert the sample histogram to fixed point
|
||||
auto rounding = DummyRoundingFactor();
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
thrust::host_vector<GradientPairInt64> h_hist;
|
||||
for(auto e: hist_[0]){
|
||||
h_hist.push_back(rounding.ToFixedPoint({float(e.GetGrad()),float(e.GetHess())}));
|
||||
h_hist.push_back(quantiser.ToFixedPoint(e));
|
||||
}
|
||||
dh::device_vector<GradientPairInt64> d_hist = h_hist;
|
||||
dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};
|
||||
|
||||
EvaluateSplitInputs input{0, 0, total_gpair_, dh::ToSpan(feature_set), dh::ToSpan(d_hist)};
|
||||
EvaluateSplitInputs input{0, 0, quantiser.ToFixedPoint(total_gpair_), dh::ToSpan(feature_set), dh::ToSpan(d_hist)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
GPUTrainingParam{param_},
|
||||
rounding,
|
||||
quantiser,
|
||||
dh::ToSpan(ft),
|
||||
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts_.cut_values_.ConstDeviceSpan(),
|
||||
|
||||
@ -33,10 +33,10 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
|
||||
sizeof(GradientPairInt64));
|
||||
|
||||
auto rounding = GradientQuantizer(gpair.DeviceSpan());
|
||||
auto quantiser = GradientQuantiser(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0), gpair.DeviceSpan(),
|
||||
ridx, d_histogram, rounding);
|
||||
ridx, d_histogram, quantiser);
|
||||
|
||||
std::vector<GradientPairInt64> histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
|
||||
@ -47,11 +47,11 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
dh::device_vector<GradientPairInt64> new_histogram(num_bins);
|
||||
auto d_new_histogram = dh::ToSpan(new_histogram);
|
||||
|
||||
auto rounding = GradientQuantizer(gpair.DeviceSpan());
|
||||
auto quantiser = GradientQuantiser(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, d_new_histogram,
|
||||
rounding);
|
||||
quantiser);
|
||||
|
||||
std::vector<GradientPairInt64> new_histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
|
||||
@ -74,7 +74,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
single_group.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, dh::ToSpan(baseline),
|
||||
rounding);
|
||||
quantiser);
|
||||
|
||||
std::vector<GradientPairInt64> baseline_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
|
||||
@ -126,7 +126,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
dh::device_vector<GradientPairInt64> cat_hist(num_categories);
|
||||
auto gpair = GenerateRandomGradients(kRows, 0, 2);
|
||||
gpair.SetDevice(0);
|
||||
auto rounding = GradientQuantizer(gpair.DeviceSpan());
|
||||
auto quantiser = GradientQuantiser(gpair.DeviceSpan());
|
||||
/**
|
||||
* Generate hist with cat data.
|
||||
*/
|
||||
@ -136,7 +136,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
single_group.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist),
|
||||
rounding);
|
||||
quantiser);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -151,7 +151,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
single_group.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist),
|
||||
rounding);
|
||||
quantiser);
|
||||
}
|
||||
|
||||
std::vector<GradientPairInt64> h_cat_hist(cat_hist.size());
|
||||
|
||||
@ -107,12 +107,12 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
|
||||
maker.hist.AllocateHistograms({0});
|
||||
maker.gpair = gpair.DeviceSpan();
|
||||
maker.histogram_rounding.reset(new GradientQuantizer(maker.gpair));
|
||||
maker.quantiser.reset(new GradientQuantiser(maker.gpair));
|
||||
|
||||
BuildGradientHistogram(
|
||||
page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
|
||||
maker.hist.GetNodeHistogram(0), *maker.histogram_rounding,
|
||||
maker.hist.GetNodeHistogram(0), *maker.quantiser,
|
||||
!use_shared_memory_histograms);
|
||||
|
||||
DeviceHistogramStorage<>& d_hist = maker.hist;
|
||||
@ -125,7 +125,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
|
||||
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
|
||||
for (size_t i = 0; i < h_result.size(); ++i) {
|
||||
auto result = maker.histogram_rounding->ToFloatingPoint(h_result[i]);
|
||||
auto result = maker.quantiser->ToFloatingPoint(h_result[i]);
|
||||
EXPECT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
|
||||
EXPECT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
|
||||
}
|
||||
@ -156,85 +156,10 @@ HistogramCutsWrapper GetHostCutMatrix () {
|
||||
return cmat;
|
||||
}
|
||||
|
||||
inline GradientQuantizer DummyRoundingFactor() {
|
||||
inline GradientQuantiser DummyRoundingFactor() {
|
||||
thrust::device_vector<GradientPair> gpair(1);
|
||||
gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000
|
||||
return GradientQuantizer(dh::ToSpan(gpair));
|
||||
}
|
||||
|
||||
// TODO(trivialfis): This test is over simplified.
|
||||
TEST(GpuHist, EvaluateRootSplit) {
|
||||
constexpr int kNRows = 16;
|
||||
constexpr int kNCols = 8;
|
||||
|
||||
TrainParam param;
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> args{
|
||||
{"max_depth", "1"},
|
||||
{"max_leaves", "0"},
|
||||
|
||||
// Disable all other parameters.
|
||||
{"colsample_bynode", "1"},
|
||||
{"colsample_bylevel", "1"},
|
||||
{"colsample_bytree", "1"},
|
||||
{"min_child_weight", "0.01"},
|
||||
{"reg_alpha", "0"},
|
||||
{"reg_lambda", "0"},
|
||||
{"max_delta_step", "0"}};
|
||||
param.Init(args);
|
||||
for (size_t i = 0; i < kNCols; ++i) {
|
||||
param.monotone_constraints.emplace_back(0);
|
||||
}
|
||||
|
||||
int max_bins = 4;
|
||||
|
||||
// Initialize GPUHistMakerDevice
|
||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||
BatchParam batch_param{};
|
||||
Context ctx{CreateEmptyGenericParam(0)};
|
||||
GPUHistMakerDevice<GradientPairPrecise> maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols,
|
||||
batch_param);
|
||||
// Initialize GPUHistMakerDevice::node_sum_gradients
|
||||
maker.node_sum_gradients = {};
|
||||
|
||||
// Initialize GPUHistMakerDevice::cut
|
||||
auto cmat = GetHostCutMatrix();
|
||||
|
||||
// Copy cut matrix to device.
|
||||
page->Cuts() = cmat;
|
||||
maker.monotone_constraints = param.monotone_constraints;
|
||||
|
||||
// Initialize GPUHistMakerDevice::hist
|
||||
maker.hist.Init(0, (max_bins - 1) * kNCols);
|
||||
maker.hist.AllocateHistograms({0});
|
||||
// Each row of hist_gpair represents gpairs for one feature.
|
||||
// Each entry represents a bin.
|
||||
std::vector<GradientPairPrecise> hist_gpair = GetHostHistGpair();
|
||||
maker.histogram_rounding.reset(new GradientQuantizer(DummyRoundingFactor()));
|
||||
std::vector<int64_t> hist;
|
||||
for (auto pair : hist_gpair) {
|
||||
auto grad = maker.histogram_rounding->ToFixedPoint({float(pair.GetGrad()),float(pair.GetHess())});
|
||||
hist.push_back(grad.GetQuantisedGrad());
|
||||
hist.push_back(grad.GetQuantisedHess());
|
||||
}
|
||||
|
||||
ASSERT_EQ(maker.hist.Data().size(), hist.size());
|
||||
thrust::copy(hist.begin(), hist.end(),
|
||||
maker.hist.Data().begin());
|
||||
std::vector<float> feature_weights;
|
||||
|
||||
maker.column_sampler.Init(kNCols, feature_weights, param.colsample_bynode,
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
|
||||
RegTree tree;
|
||||
MetaInfo info;
|
||||
info.num_row_ = kNRows;
|
||||
info.num_col_ = kNCols;
|
||||
|
||||
DeviceSplitCandidate res = maker.EvaluateRootSplit({6.4f, 12.8f}).split;
|
||||
|
||||
ASSERT_EQ(res.findex, 7);
|
||||
ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
|
||||
return GradientQuantiser(dh::ToSpan(gpair));
|
||||
}
|
||||
|
||||
void TestHistogramIndexImpl() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user