Use quantised gradients in gpu_hist histograms (#8246)
This commit is contained in:
parent
4056974e37
commit
8f77677193
@ -259,10 +259,61 @@ class GradientPairInternal {
|
||||
using GradientPair = detail::GradientPairInternal<float>;
|
||||
/*! \brief High precision gradient statistics pair */
|
||||
using GradientPairPrecise = detail::GradientPairInternal<double>;
|
||||
/*! \brief Fixed point representation for gradient pair. */
|
||||
using GradientPairInt32 = detail::GradientPairInternal<int>;
|
||||
/*! \brief Fixed point representation for high precision gradient pair. */
|
||||
using GradientPairInt64 = detail::GradientPairInternal<int64_t>;
|
||||
|
||||
/*! \brief Fixed point representation for high precision gradient pair. Has a different interface so
|
||||
* we don't accidentally use it in gain calculations.*/
|
||||
class GradientPairInt64 {
|
||||
using T = int64_t;
|
||||
T grad_;
|
||||
T hess_;
|
||||
|
||||
public:
|
||||
using ValueT = T;
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64(T grad, T hess) : grad_(grad), hess_(hess) {}
|
||||
GradientPairInt64() = default;
|
||||
|
||||
// Copy constructor if of same value type, marked as default to be trivially_copyable
|
||||
GradientPairInt64(const GradientPairInt64 &g) = default;
|
||||
|
||||
XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; }
|
||||
XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; }
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 &operator+=(const GradientPairInt64 &rhs) {
|
||||
grad_ += rhs.grad_;
|
||||
hess_ += rhs.hess_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 operator+(const GradientPairInt64 &rhs) const {
|
||||
GradientPairInt64 g;
|
||||
g.grad_ = grad_ + rhs.grad_;
|
||||
g.hess_ = hess_ + rhs.hess_;
|
||||
return g;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 &operator-=(const GradientPairInt64 &rhs) {
|
||||
grad_ -= rhs.grad_;
|
||||
hess_ -= rhs.hess_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 operator-(const GradientPairInt64 &rhs) const {
|
||||
GradientPairInt64 g;
|
||||
g.grad_ = grad_ - rhs.grad_;
|
||||
g.hess_ = hess_ - rhs.hess_;
|
||||
return g;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bool operator==(const GradientPairInt64 &rhs) const {
|
||||
return grad_ == rhs.grad_ && hess_ == rhs.hess_;
|
||||
}
|
||||
friend std::ostream &operator<<(std::ostream &os,
|
||||
const GradientPairInt64 &g) {
|
||||
os << g.GetQuantisedGrad() << "/" << g.GetQuantisedHess();
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
using Args = std::vector<std::pair<std::string, std::string> >;
|
||||
|
||||
|
||||
@ -1511,44 +1511,6 @@ XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief An atomicAdd designed for gradient pair with better performance. For general
|
||||
* int64_t atomicAdd, one can simply cast it to unsigned long long.
|
||||
*/
|
||||
XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t *dst, int64_t src) {
|
||||
uint32_t* y_low = reinterpret_cast<uint32_t *>(dst);
|
||||
uint32_t *y_high = y_low + 1;
|
||||
|
||||
auto cast_src = reinterpret_cast<uint64_t *>(&src);
|
||||
|
||||
uint32_t const x_low = static_cast<uint32_t>(src);
|
||||
uint32_t const x_high = (*cast_src) >> 32;
|
||||
|
||||
auto const old = atomicAdd(y_low, x_low);
|
||||
uint32_t const carry = old > (std::numeric_limits<uint32_t>::max() - x_low) ? 1 : 0;
|
||||
uint32_t const sig = x_high + carry;
|
||||
atomicAdd(y_high, sig);
|
||||
}
|
||||
|
||||
XGBOOST_DEV_INLINE void
|
||||
AtomicAddGpair(xgboost::GradientPairInt64 *dest,
|
||||
xgboost::GradientPairInt64 const &gpair) {
|
||||
auto dst_ptr = reinterpret_cast<int64_t *>(dest);
|
||||
auto g = gpair.GetGrad();
|
||||
auto h = gpair.GetHess();
|
||||
|
||||
AtomicAdd64As32(dst_ptr, g);
|
||||
AtomicAdd64As32(dst_ptr + 1, h);
|
||||
}
|
||||
|
||||
XGBOOST_DEV_INLINE void
|
||||
AtomicAddGpair(xgboost::GradientPairInt32 *dest,
|
||||
xgboost::GradientPairInt32 const &gpair) {
|
||||
auto dst_ptr = reinterpret_cast<typename xgboost::GradientPairInt32::ValueT*>(dest);
|
||||
|
||||
::atomicAdd(dst_ptr, static_cast<int>(gpair.GetGrad()));
|
||||
::atomicAdd(dst_ptr + 1, static_cast<int>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
|
||||
@ -58,7 +58,8 @@ class EvaluateSplitAgent {
|
||||
const uint32_t gidx_begin; // beginning bin
|
||||
const uint32_t gidx_end; // end bin for i^th feature
|
||||
const dh::LDGIterator<float> feature_values;
|
||||
const GradientPairPrecise *node_histogram;
|
||||
const GradientPairInt64 *node_histogram;
|
||||
const GradientQuantizer &rounding;
|
||||
const GradientPairPrecise parent_sum;
|
||||
const GradientPairPrecise missing;
|
||||
const GPUTrainingParam ¶m;
|
||||
@ -79,6 +80,7 @@ class EvaluateSplitAgent {
|
||||
gidx_end(__ldg(shared_inputs.feature_segments.data() + fidx + 1)),
|
||||
feature_values(shared_inputs.feature_values.data()),
|
||||
node_histogram(inputs.gradient_histogram.data()),
|
||||
rounding(shared_inputs.rounding),
|
||||
parent_sum(dh::LDGIterator<GradientPairPrecise>(&inputs.parent_sum)[0]),
|
||||
param(shared_inputs.param),
|
||||
evaluator(evaluator),
|
||||
@ -98,11 +100,12 @@ class EvaluateSplitAgent {
|
||||
}
|
||||
|
||||
// Load using efficient 128 vector load instruction
|
||||
__device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairPrecise *ptr) {
|
||||
static_assert(sizeof(GradientPairPrecise) == sizeof(float4),
|
||||
"Vector type size does not match gradient pair size.");
|
||||
__device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairInt64 *ptr) {
|
||||
float4 tmp = *reinterpret_cast<const float4 *>(ptr);
|
||||
return *reinterpret_cast<const GradientPairPrecise *>(&tmp);
|
||||
auto gpair_int = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
|
||||
static_assert(sizeof(decltype(gpair_int)) == sizeof(float4),
|
||||
"Vector type size does not match gradient pair size.");
|
||||
return rounding.ToFloatingPoint(gpair_int);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void Numerical(DeviceSplitCandidate *__restrict__ best_split) {
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include "../split_evaluator.h"
|
||||
#include "../updater_gpu_common.cuh"
|
||||
#include "expand_entry.cuh"
|
||||
#include "histogram.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -24,12 +25,13 @@ struct EvaluateSplitInputs {
|
||||
int depth;
|
||||
GradientPairPrecise parent_sum;
|
||||
common::Span<const bst_feature_t> feature_set;
|
||||
common::Span<const GradientPairPrecise> gradient_histogram;
|
||||
common::Span<const GradientPairInt64> gradient_histogram;
|
||||
};
|
||||
|
||||
// Inputs necessary for all nodes
|
||||
struct EvaluateSplitSharedInputs {
|
||||
GPUTrainingParam param;
|
||||
GradientQuantizer rounding;
|
||||
common::Span<FeatureType const> feature_types;
|
||||
common::Span<const uint32_t> feature_segments;
|
||||
common::Span<const float> feature_values;
|
||||
|
||||
@ -83,8 +83,9 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
|
||||
auto j = i % total_bins;
|
||||
auto fidx = d_feature_idx[j];
|
||||
if (common::IsCat(shared_inputs.feature_types, fidx)) {
|
||||
auto lw = evaluator.CalcWeightCat(shared_inputs.param,
|
||||
input.gradient_histogram[j]);
|
||||
auto grad =
|
||||
shared_inputs.rounding.ToFloatingPoint(input.gradient_histogram[j]);
|
||||
auto lw = evaluator.CalcWeightCat(shared_inputs.param, grad);
|
||||
return thrust::make_tuple(i, lw);
|
||||
}
|
||||
return thrust::make_tuple(i, 0.0f);
|
||||
|
||||
@ -72,30 +72,35 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair) {
|
||||
GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
using T = typename GradientSumT::ValueT;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
thrust::device_ptr<GradientPair const> gpair_beg{gpair.data()};
|
||||
thrust::device_ptr<GradientPair const> gpair_end{gpair.data() + gpair.size()};
|
||||
auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
|
||||
auto end = thrust::make_transform_iterator(gpair_end, Clip());
|
||||
Pair p = dh::Reduce(thrust::cuda::par(alloc), beg, end, Pair{}, thrust::plus<Pair>{});
|
||||
Pair p =
|
||||
dh::Reduce(thrust::cuda::par(alloc), beg, beg + gpair.size(), Pair{}, thrust::plus<Pair>{});
|
||||
// Treat pair as array of 4 primitive types to allreduce
|
||||
using ReduceT = typename decltype(p.first)::ValueT;
|
||||
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
|
||||
rabit::Allreduce<rabit::op::Sum, ReduceT>(reinterpret_cast<ReduceT*>(&p), 4);
|
||||
GradientPair positive_sum{p.first}, negative_sum{p.second};
|
||||
|
||||
auto histogram_rounding =
|
||||
GradientSumT{CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()),
|
||||
gpair.size()),
|
||||
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
|
||||
gpair.size())};
|
||||
std::size_t total_rows = gpair.size();
|
||||
rabit::Allreduce<rabit::op::Sum>(&total_rows, 1);
|
||||
|
||||
using IntT = typename HistRounding<GradientSumT>::SharedSumT::ValueT;
|
||||
auto histogram_rounding = GradientSumT{
|
||||
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), total_rows),
|
||||
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
|
||||
total_rows)};
|
||||
|
||||
using IntT = typename GradientPairInt64::ValueT;
|
||||
|
||||
/**
|
||||
* Factor for converting gradients from fixed-point to floating-point.
|
||||
*/
|
||||
GradientSumT to_floating_point =
|
||||
to_floating_point_ =
|
||||
histogram_rounding /
|
||||
T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 - 2)); // keep 1 for sign bit
|
||||
/**
|
||||
@ -107,35 +112,55 @@ HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const>
|
||||
* rounding is calcuated as exp(m), see the rounding factor calcuation for
|
||||
* details.
|
||||
*/
|
||||
GradientSumT to_fixed_point =
|
||||
GradientSumT(T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess());
|
||||
|
||||
return {histogram_rounding, to_fixed_point, to_floating_point};
|
||||
to_fixed_point_ =
|
||||
GradientSumT(T(1) / to_floating_point_.GetGrad(), T(1) / to_floating_point_.GetHess());
|
||||
}
|
||||
|
||||
template HistRounding<GradientPairPrecise> CreateRoundingFactor(
|
||||
common::Span<GradientPair const> gpair);
|
||||
template HistRounding<GradientPair> CreateRoundingFactor(common::Span<GradientPair const> gpair);
|
||||
|
||||
template <typename GradientSumT, int kBlockThreads, int kItemsPerThread,
|
||||
XGBOOST_DEV_INLINE void
|
||||
AtomicAddGpairShared(xgboost::GradientPairInt64 *dest,
|
||||
xgboost::GradientPairInt64 const &gpair) {
|
||||
auto dst_ptr = reinterpret_cast<int64_t *>(dest);
|
||||
auto g = gpair.GetQuantisedGrad();
|
||||
auto h = gpair.GetQuantisedHess();
|
||||
|
||||
AtomicAdd64As32(dst_ptr, g);
|
||||
AtomicAdd64As32(dst_ptr + 1, h);
|
||||
}
|
||||
|
||||
// Global 64 bit integer atomics at the time of writing do not benefit from being separated into two
|
||||
// 32 bit atomics
|
||||
XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
|
||||
xgboost::GradientPairInt64 const& gpair) {
|
||||
auto dst_ptr = reinterpret_cast<uint64_t*>(dest);
|
||||
auto g = gpair.GetQuantisedGrad();
|
||||
auto h = gpair.GetQuantisedHess();
|
||||
|
||||
atomicAdd(dst_ptr,
|
||||
*reinterpret_cast<uint64_t*>(&g));
|
||||
atomicAdd(dst_ptr + 1,
|
||||
*reinterpret_cast<uint64_t*>(&h));
|
||||
}
|
||||
|
||||
template <int kBlockThreads, int kItemsPerThread,
|
||||
int kItemsPerTile = kBlockThreads* kItemsPerThread>
|
||||
class HistogramAgent {
|
||||
using SharedSumT = typename HistRounding<GradientSumT>::SharedSumT;
|
||||
SharedSumT* smem_arr_;
|
||||
GradientSumT* d_node_hist_;
|
||||
GradientPairInt64* smem_arr_;
|
||||
GradientPairInt64* d_node_hist_;
|
||||
dh::LDGIterator<const RowPartitioner::RowIndexT> d_ridx_;
|
||||
const GradientPair* d_gpair_;
|
||||
const FeatureGroup group_;
|
||||
const EllpackDeviceAccessor& matrix_;
|
||||
const int feature_stride_;
|
||||
const std::size_t n_elements_;
|
||||
const HistRounding<GradientSumT>& rounding_;
|
||||
const GradientQuantizer& rounding_;
|
||||
|
||||
public:
|
||||
__device__ HistogramAgent(SharedSumT* smem_arr, GradientSumT* __restrict__ d_node_hist,
|
||||
const FeatureGroup& group, const EllpackDeviceAccessor& matrix,
|
||||
__device__ HistogramAgent(GradientPairInt64* smem_arr,
|
||||
GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
|
||||
const EllpackDeviceAccessor& matrix,
|
||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
||||
const HistRounding<GradientSumT>& rounding, const GradientPair* d_gpair)
|
||||
const GradientQuantizer& rounding, const GradientPair* d_gpair)
|
||||
: smem_arr_(smem_arr),
|
||||
d_node_hist_(d_node_hist),
|
||||
d_ridx_(d_ridx.data()),
|
||||
@ -155,7 +180,7 @@ class HistogramAgent {
|
||||
group_.start_bin;
|
||||
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
|
||||
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
|
||||
dh::AtomicAddGpair(smem_arr_ + gidx, adjusted);
|
||||
AtomicAddGpairShared(smem_arr_ + gidx, adjusted);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -185,12 +210,12 @@ class HistogramAgent {
|
||||
for (int i = 0; i < kItemsPerThread; i++) {
|
||||
if ((matrix_.is_dense || gidx[i] != matrix_.NumBins())) {
|
||||
auto adjusted = rounding_.ToFixedPoint(gpair[i]);
|
||||
dh::AtomicAddGpair(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
|
||||
AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
|
||||
}
|
||||
}
|
||||
}
|
||||
__device__ void BuildHistogramWithShared() {
|
||||
dh::BlockFill(smem_arr_, group_.num_bins, SharedSumT());
|
||||
dh::BlockFill(smem_arr_, group_.num_bins, GradientPairInt64());
|
||||
__syncthreads();
|
||||
|
||||
std::size_t offset = blockIdx.x * kItemsPerTile;
|
||||
@ -203,8 +228,7 @@ class HistogramAgent {
|
||||
// Write shared memory back to global memory
|
||||
__syncthreads();
|
||||
for (auto i : dh::BlockStrideRange(0, group_.num_bins)) {
|
||||
auto truncated = rounding_.ToFloatingPoint(smem_arr_[i]);
|
||||
dh::AtomicAddGpair(d_node_hist_ + group_.start_bin + i, truncated);
|
||||
AtomicAddGpairGlobal(d_node_hist_ + group_.start_bin + i, smem_arr_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,36 +239,26 @@ class HistogramAgent {
|
||||
matrix_
|
||||
.gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_];
|
||||
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
|
||||
// If we are not using shared memory, accumulate the values directly into
|
||||
// global memory
|
||||
GradientSumT truncated{
|
||||
TruncateWithRoundingFactor<GradientSumT::ValueT>(rounding_.rounding.GetGrad(),
|
||||
d_gpair_[ridx].GetGrad()),
|
||||
TruncateWithRoundingFactor<GradientSumT::ValueT>(rounding_.rounding.GetHess(),
|
||||
d_gpair_[ridx].GetHess()),
|
||||
};
|
||||
dh::AtomicAddGpair(d_node_hist_ + gidx, truncated);
|
||||
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
|
||||
AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT, bool use_shared_memory_histograms, int kBlockThreads,
|
||||
template <bool use_shared_memory_histograms, int kBlockThreads,
|
||||
int kItemsPerThread>
|
||||
__global__ void __launch_bounds__(kBlockThreads)
|
||||
SharedMemHistKernel(const EllpackDeviceAccessor matrix,
|
||||
const FeatureGroupsAccessor feature_groups,
|
||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
||||
GradientSumT* __restrict__ d_node_hist,
|
||||
GradientPairInt64* __restrict__ d_node_hist,
|
||||
const GradientPair* __restrict__ d_gpair,
|
||||
HistRounding<GradientSumT> const rounding) {
|
||||
using SharedSumT = typename HistRounding<GradientSumT>::SharedSumT;
|
||||
using T = typename GradientSumT::ValueT;
|
||||
|
||||
GradientQuantizer const rounding) {
|
||||
extern __shared__ char smem[];
|
||||
const FeatureGroup group = feature_groups[blockIdx.y];
|
||||
SharedSumT* smem_arr = reinterpret_cast<SharedSumT*>(smem);
|
||||
auto agent = HistogramAgent<GradientSumT, kBlockThreads, kItemsPerThread>(
|
||||
auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
|
||||
auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(
|
||||
smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
|
||||
if (use_shared_memory_histograms) {
|
||||
agent.BuildHistogramWithShared();
|
||||
@ -253,13 +267,12 @@ __global__ void __launch_bounds__(kBlockThreads)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> d_ridx,
|
||||
common::Span<GradientSumT> histogram,
|
||||
HistRounding<GradientSumT> rounding, bool force_global_memory) {
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantizer rounding, bool force_global_memory) {
|
||||
// decide whether to use shared memory
|
||||
int device = 0;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
@ -267,7 +280,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
size_t max_shared_memory = dh::MaxSharedMemoryOptin(device);
|
||||
|
||||
size_t smem_size =
|
||||
sizeof(typename HistRounding<GradientSumT>::SharedSumT) * feature_groups.max_group_bins;
|
||||
sizeof(GradientPairInt64) * feature_groups.max_group_bins;
|
||||
bool shared = !force_global_memory && smem_size <= max_shared_memory;
|
||||
smem_size = shared ? smem_size : 0;
|
||||
|
||||
@ -311,19 +324,13 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
};
|
||||
|
||||
if (shared) {
|
||||
runit(SharedMemHistKernel<GradientSumT, true, kBlockThreads, kItemsPerThread>);
|
||||
runit(SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>);
|
||||
} else {
|
||||
runit(SharedMemHistKernel<GradientSumT, false, kBlockThreads, kItemsPerThread>);
|
||||
runit(SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>);
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
}
|
||||
|
||||
template void BuildGradientHistogram<GradientPairPrecise>(
|
||||
EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair, common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPairPrecise> histogram, HistRounding<GradientPairPrecise> rounding,
|
||||
bool force_global_memory);
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -12,56 +12,51 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
template <typename T, typename U>
|
||||
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) {
|
||||
static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision.");
|
||||
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
|
||||
/**
|
||||
* \brief An atomicAdd designed for gradient pair with better performance. For general
|
||||
* int64_t atomicAdd, one can simply cast it to unsigned long long. Exposed for testing.
|
||||
*/
|
||||
XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
|
||||
uint32_t* y_low = reinterpret_cast<uint32_t*>(dst);
|
||||
uint32_t* y_high = y_low + 1;
|
||||
|
||||
auto cast_src = reinterpret_cast<uint64_t *>(&src);
|
||||
|
||||
uint32_t const x_low = static_cast<uint32_t>(src);
|
||||
uint32_t const x_high = (*cast_src) >> 32;
|
||||
|
||||
auto const old = atomicAdd(y_low, x_low);
|
||||
uint32_t const carry = old > (std::numeric_limits<uint32_t>::max() - x_low) ? 1 : 0;
|
||||
uint32_t const sig = x_high + carry;
|
||||
atomicAdd(y_high, sig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncation factor for gradient, see comments in `CreateRoundingFactor()` for details.
|
||||
*/
|
||||
template <typename GradientSumT>
|
||||
struct HistRounding {
|
||||
/* Factor to truncate the gradient before building histogram for deterministic result. */
|
||||
GradientSumT rounding;
|
||||
class GradientQuantizer {
|
||||
private:
|
||||
/* Convert gradient to fixed point representation. */
|
||||
GradientSumT to_fixed_point;
|
||||
GradientPairPrecise to_fixed_point_;
|
||||
/* Convert fixed point representation back to floating point. */
|
||||
GradientSumT to_floating_point;
|
||||
|
||||
/* Type used in shared memory. */
|
||||
using SharedSumT = std::conditional_t<
|
||||
std::is_same<typename GradientSumT::ValueT, float>::value,
|
||||
GradientPairInt32, GradientPairInt64>;
|
||||
using T = typename GradientSumT::ValueT;
|
||||
|
||||
XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const {
|
||||
auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()),
|
||||
T(gpair.GetHess() * to_fixed_point.GetHess()));
|
||||
GradientPairPrecise to_floating_point_;
|
||||
public:
|
||||
explicit GradientQuantizer(common::Span<GradientPair const> gpair);
|
||||
XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
|
||||
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
|
||||
gpair.GetHess() * to_fixed_point_.GetHess());
|
||||
return adjusted;
|
||||
}
|
||||
XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const {
|
||||
auto g = gpair.GetGrad() * to_floating_point.GetGrad();
|
||||
auto h = gpair.GetHess() * to_floating_point.GetHess();
|
||||
GradientSumT truncated{
|
||||
TruncateWithRoundingFactor<T>(rounding.GetGrad(), g),
|
||||
TruncateWithRoundingFactor<T>(rounding.GetHess(), h),
|
||||
};
|
||||
return truncated;
|
||||
XGBOOST_DEVICE GradientPairPrecise ToFloatingPoint(const GradientPairInt64&gpair) const {
|
||||
auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
|
||||
auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
|
||||
return {g,h};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair);
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientSumT> histogram,
|
||||
HistRounding<GradientSumT> rounding,
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantizer rounding,
|
||||
bool force_global_memory = false);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -72,9 +72,10 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
|
||||
* \author Rory
|
||||
* \date 28/07/2018
|
||||
*/
|
||||
template <typename GradientSumT, size_t kStopGrowingSize = 1 << 28>
|
||||
template <size_t kStopGrowingSize = 1 << 28>
|
||||
class DeviceHistogramStorage {
|
||||
private:
|
||||
using GradientSumT = GradientPairInt64;
|
||||
/*! \brief Map nidx to starting index of its histogram. */
|
||||
std::map<int, size_t> nidx_map_;
|
||||
// Large buffer of zeroed memory, caches histograms
|
||||
@ -180,7 +181,7 @@ struct GPUHistMakerDevice {
|
||||
BatchParam batch_param;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogramStorage<GradientSumT> hist{};
|
||||
DeviceHistogramStorage<> hist{};
|
||||
|
||||
dh::device_vector<GradientPair> d_gpair; // storage for gpair;
|
||||
common::Span<GradientPair> gpair;
|
||||
@ -193,7 +194,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
TrainParam param;
|
||||
|
||||
HistRounding<GradientSumT> histogram_rounding;
|
||||
std::unique_ptr<GradientQuantizer> histogram_rounding;
|
||||
|
||||
dh::PinnedMemory pinned;
|
||||
dh::PinnedMemory pinned2;
|
||||
@ -265,7 +266,7 @@ struct GPUHistMakerDevice {
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
|
||||
histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
|
||||
histogram_rounding.reset(new GradientQuantizer(this->gpair));
|
||||
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
|
||||
@ -282,7 +283,11 @@ struct GPUHistMakerDevice {
|
||||
auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
|
||||
EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
gpu_param, feature_types, matrix.feature_segments, matrix.gidx_fvalue_map,
|
||||
gpu_param,
|
||||
*histogram_rounding,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
};
|
||||
auto split = this->evaluator_.EvaluateSingleSplit(inputs, shared_inputs);
|
||||
@ -298,7 +303,7 @@ struct GPUHistMakerDevice {
|
||||
auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
|
||||
auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
GPUTrainingParam{param}, feature_types, matrix.feature_segments,
|
||||
GPUTrainingParam{param}, *histogram_rounding, feature_types, matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map, matrix.min_fvalue,
|
||||
};
|
||||
dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
|
||||
@ -344,7 +349,7 @@ struct GPUHistMakerDevice {
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
|
||||
feature_groups->DeviceAccessor(ctx_->gpu_id), gpair,
|
||||
d_ridx, d_node_hist, histogram_rounding);
|
||||
d_ridx, d_node_hist, *histogram_rounding);
|
||||
}
|
||||
|
||||
// Attempt to do subtraction trick
|
||||
@ -526,11 +531,10 @@ struct GPUHistMakerDevice {
|
||||
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
|
||||
monitor.Start("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
reducer->AllReduceSum(reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() *
|
||||
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) *
|
||||
num_histograms);
|
||||
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
|
||||
reducer->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * 2 * num_histograms);
|
||||
|
||||
monitor.Stop("AllReduce");
|
||||
}
|
||||
|
||||
@ -22,23 +22,4 @@ inline std::vector<float> OneHotEncodeFeature(std::vector<float> x,
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void ValidateCategoricalHistogram(size_t n_categories,
|
||||
common::Span<GradientSumT> onehot,
|
||||
common::Span<GradientSumT> cat) {
|
||||
auto cat_sum = std::accumulate(cat.cbegin(), cat.cend(), GradientPairPrecise{});
|
||||
for (size_t c = 0; c < n_categories; ++c) {
|
||||
auto zero = onehot[c * 2];
|
||||
auto one = onehot[c * 2 + 1];
|
||||
|
||||
auto chosen = cat[c];
|
||||
auto not_chosen = cat_sum - chosen;
|
||||
|
||||
ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps);
|
||||
ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps);
|
||||
|
||||
ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps);
|
||||
ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -196,72 +196,4 @@ TEST(DeviceHelpers, ArgSort) {
|
||||
thrust::greater<size_t>{}));
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Atomic add as type cast for test.
|
||||
XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) { // NOLINT
|
||||
uint64_t* u_dst = reinterpret_cast<uint64_t*>(dst);
|
||||
uint64_t u_src = *reinterpret_cast<uint64_t*>(&src);
|
||||
uint64_t ret = ::atomicAdd(u_dst, u_src);
|
||||
return *reinterpret_cast<int64_t*>(&ret);
|
||||
}
|
||||
}
|
||||
|
||||
void TestAtomicAdd() {
|
||||
size_t n_elements = 1024;
|
||||
dh::device_vector<int64_t> result_a(1, 0);
|
||||
auto d_result_a = result_a.data().get();
|
||||
|
||||
dh::device_vector<int64_t> result_b(1, 0);
|
||||
auto d_result_b = result_b.data().get();
|
||||
|
||||
/**
|
||||
* Test for simple inputs
|
||||
*/
|
||||
std::vector<int64_t> h_inputs(n_elements);
|
||||
for (size_t i = 0; i < h_inputs.size(); ++i) {
|
||||
h_inputs[i] = (i % 2 == 0) ? i : -i;
|
||||
}
|
||||
dh::device_vector<int64_t> inputs(h_inputs);
|
||||
auto d_inputs = inputs.data().get();
|
||||
|
||||
dh::LaunchN(n_elements, [=] __device__(size_t i) {
|
||||
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
|
||||
atomicAdd(d_result_b, d_inputs[i]);
|
||||
});
|
||||
ASSERT_EQ(result_a[0], result_b[0]);
|
||||
|
||||
/**
|
||||
* Test for positive values that don't fit into 32 bit integer.
|
||||
*/
|
||||
thrust::fill(inputs.begin(), inputs.end(),
|
||||
(std::numeric_limits<uint32_t>::max() / 2));
|
||||
thrust::fill(result_a.begin(), result_a.end(), 0);
|
||||
thrust::fill(result_b.begin(), result_b.end(), 0);
|
||||
dh::LaunchN(n_elements, [=] __device__(size_t i) {
|
||||
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
|
||||
atomicAdd(d_result_b, d_inputs[i]);
|
||||
});
|
||||
ASSERT_EQ(result_a[0], result_b[0]);
|
||||
ASSERT_GT(result_a[0], std::numeric_limits<uint32_t>::max());
|
||||
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
|
||||
|
||||
/**
|
||||
* Test for negative values that don't fit into 32 bit integer.
|
||||
*/
|
||||
thrust::fill(inputs.begin(), inputs.end(),
|
||||
(std::numeric_limits<int32_t>::min() / 2));
|
||||
thrust::fill(result_a.begin(), result_a.end(), 0);
|
||||
thrust::fill(result_b.begin(), result_b.end(), 0);
|
||||
dh::LaunchN(n_elements, [=] __device__(size_t i) {
|
||||
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
|
||||
atomicAdd(d_result_b, d_inputs[i]);
|
||||
});
|
||||
ASSERT_EQ(result_a[0], result_b[0]);
|
||||
ASSERT_LT(result_a[0], std::numeric_limits<int32_t>::min());
|
||||
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
|
||||
}
|
||||
|
||||
TEST(AtomicAdd, Int64) {
|
||||
TestAtomicAdd();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "../../helpers.h"
|
||||
#include "../../histogram_helpers.h"
|
||||
#include "../test_evaluate_splits.h" // TestPartitionBasedSplit
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -21,13 +22,29 @@ auto ZeroParam() {
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
inline GradientQuantizer DummyRoundingFactor() {
|
||||
thrust::device_vector<GradientPair> gpair(1);
|
||||
gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000
|
||||
return GradientQuantizer(dh::ToSpan(gpair));
|
||||
}
|
||||
|
||||
thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
|
||||
auto r = DummyRoundingFactor();
|
||||
std::vector<GradientPairInt64> y(x.size());
|
||||
for (int i = 0; i < x.size(); i++) {
|
||||
y[i] = r.ToFixedPoint(GradientPair(x[i]));
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
|
||||
GPUTrainingParam param{param_};
|
||||
cuts_.cut_ptrs_.SetDevice(0);
|
||||
cuts_.cut_values_.SetDevice(0);
|
||||
cuts_.min_vals_.SetDevice(0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram{feature_histogram_};
|
||||
thrust::device_vector<GradientPairInt64> feature_histogram{ConvertToInteger(feature_histogram_)};
|
||||
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
||||
auto d_feature_types = dh::ToSpan(feature_types);
|
||||
@ -36,6 +53,7 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
d_feature_types,
|
||||
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts_.cut_values_.ConstDeviceSpan(),
|
||||
@ -76,6 +94,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -89,8 +108,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-5.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -105,8 +123,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-7.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -119,8 +136,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
{
|
||||
// All -1.0, gain from splitting should be 0.0
|
||||
GradientPairPrecise parent_sum(-3.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||
EvaluateSplitInputs input{2, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -133,8 +149,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
// Forward, first 2 categories are selected, while the last one go to left along with missing value
|
||||
{
|
||||
GradientPairPrecise parent_sum(0.0, 6.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||
EvaluateSplitInputs input{3, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -148,8 +163,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-5.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}});
|
||||
EvaluateSplitInputs input{4, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -163,8 +177,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
// -1.0s go right
|
||||
// -3.0s go left
|
||||
GradientPairPrecise parent_sum(-5.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
||||
EvaluateSplitInputs input{5, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -198,6 +211,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -209,8 +223,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
|
||||
{
|
||||
GradientPairPrecise parent_sum(-6.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram = std::vector<GradientPairPrecise>{
|
||||
{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({ {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -223,8 +236,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
|
||||
{
|
||||
GradientPairPrecise parent_sum(-6.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram = std::vector<GradientPairPrecise>{
|
||||
{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}};
|
||||
auto feature_histogram = ConvertToInteger({ {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -259,6 +271,7 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -270,14 +283,12 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
|
||||
{
|
||||
GradientPairPrecise parent_sum(-6.0, 3.0);
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram_a =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0},
|
||||
{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
|
||||
auto feature_histogram_a = ConvertToInteger({{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0},
|
||||
{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
thrust::device_vector<EvaluateSplitInputs> inputs(2);
|
||||
inputs[0] = EvaluateSplitInputs{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_a)};
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram_b =
|
||||
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
|
||||
auto feature_histogram_b = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
inputs[1] = EvaluateSplitInputs{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_b)};
|
||||
thrust::device_vector<GPUExpandEntry> results(2);
|
||||
@ -300,9 +311,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
||||
|
||||
// Setup gradients so that second feature gets higher gain
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
|
||||
auto feature_histogram = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(),
|
||||
FeatureType::kCategorical);
|
||||
@ -318,6 +327,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
@ -360,14 +370,14 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
std::vector<bst_row_t>{0, 2};
|
||||
thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0};
|
||||
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0};
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{{-0.5, 0.5}, {0.5, 0.5}};
|
||||
auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -388,7 +398,11 @@ TEST(GpuHist, EvaluateSingleSplitEmpty) {
|
||||
TrainParam tparam = ZeroParam();
|
||||
GPUHistEvaluator evaluator(tparam, 1, 0);
|
||||
DeviceSplitCandidate result =
|
||||
evaluator.EvaluateSingleSplit(EvaluateSplitInputs{}, EvaluateSplitSharedInputs{}).split;
|
||||
evaluator
|
||||
.EvaluateSingleSplit(
|
||||
EvaluateSplitInputs{},
|
||||
EvaluateSplitSharedInputs{GPUTrainingParam(tparam), DummyRoundingFactor()})
|
||||
.split;
|
||||
EXPECT_EQ(result.findex, -1);
|
||||
EXPECT_LT(result.loss_chg, 0.0f);
|
||||
}
|
||||
@ -408,15 +422,14 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 10.0};
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{
|
||||
{-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
|
||||
auto feature_histogram = ConvertToInteger({ {-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -447,15 +460,14 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 10.0};
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram =
|
||||
std::vector<GradientPairPrecise>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
|
||||
auto feature_histogram = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -484,12 +496,8 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 0.0};
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram_left =
|
||||
std::vector<GradientPairPrecise>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
|
||||
thrust::device_vector<GradientPairPrecise> feature_histogram_right =
|
||||
std::vector<GradientPairPrecise>{
|
||||
{-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
|
||||
auto feature_histogram_left = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
auto feature_histogram_right = ConvertToInteger({ {-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input_left{
|
||||
1,0,
|
||||
parent_sum,
|
||||
@ -502,6 +510,7 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
dh::ToSpan(feature_histogram_right)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
DummyRoundingFactor(),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
@ -533,20 +542,26 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
|
||||
|
||||
evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, 0);
|
||||
|
||||
dh::device_vector<GradientPairPrecise> d_hist(hist_[0].size());
|
||||
auto node_hist = hist_[0];
|
||||
dh::safe_cuda(cudaMemcpy(d_hist.data().get(), node_hist.data(), node_hist.size_bytes(),
|
||||
cudaMemcpyHostToDevice));
|
||||
// Convert the sample histogram to fixed point
|
||||
auto rounding = DummyRoundingFactor();
|
||||
thrust::host_vector<GradientPairInt64> h_hist;
|
||||
for(auto e: hist_[0]){
|
||||
h_hist.push_back(rounding.ToFixedPoint({float(e.GetGrad()),float(e.GetHess())}));
|
||||
}
|
||||
dh::device_vector<GradientPairInt64> d_hist = h_hist;
|
||||
dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};
|
||||
|
||||
EvaluateSplitInputs input{0, 0, total_gpair_, dh::ToSpan(feature_set), dh::ToSpan(d_hist)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
GPUTrainingParam{param_}, dh::ToSpan(ft),
|
||||
cuts_.cut_ptrs_.ConstDeviceSpan(), cuts_.cut_values_.ConstDeviceSpan(),
|
||||
GPUTrainingParam{param_},
|
||||
rounding,
|
||||
dh::ToSpan(ft),
|
||||
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts_.cut_values_.ConstDeviceSpan(),
|
||||
cuts_.min_vals_.ConstDeviceSpan(),
|
||||
};
|
||||
auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
|
||||
ASSERT_NEAR(split.loss_chg, best_score_, 1e-2);
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -10,7 +10,6 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
template <typename Gradient>
|
||||
void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
||||
float constexpr kLower = -1e-2, kUpper = 1e2;
|
||||
@ -26,41 +25,41 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
auto ridx = row_partitioner.GetRows(0);
|
||||
|
||||
int num_bins = kBins * kCols;
|
||||
dh::device_vector<Gradient> histogram(num_bins);
|
||||
dh::device_vector<GradientPairInt64> histogram(num_bins);
|
||||
auto d_histogram = dh::ToSpan(histogram);
|
||||
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
|
||||
gpair.SetDevice(0);
|
||||
|
||||
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
|
||||
sizeof(Gradient));
|
||||
sizeof(GradientPairInt64));
|
||||
|
||||
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
|
||||
auto rounding = GradientQuantizer(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0), gpair.DeviceSpan(),
|
||||
ridx, d_histogram, rounding);
|
||||
|
||||
std::vector<Gradient> histogram_h(num_bins);
|
||||
std::vector<GradientPairInt64> histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
|
||||
num_bins * sizeof(Gradient),
|
||||
num_bins * sizeof(GradientPairInt64),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < kRounds; ++i) {
|
||||
dh::device_vector<Gradient> new_histogram(num_bins);
|
||||
dh::device_vector<GradientPairInt64> new_histogram(num_bins);
|
||||
auto d_new_histogram = dh::ToSpan(new_histogram);
|
||||
|
||||
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
|
||||
auto rounding = GradientQuantizer(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, d_new_histogram,
|
||||
rounding);
|
||||
|
||||
std::vector<Gradient> new_histogram_h(num_bins);
|
||||
std::vector<GradientPairInt64> new_histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
|
||||
num_bins * sizeof(Gradient),
|
||||
num_bins * sizeof(GradientPairInt64),
|
||||
cudaMemcpyDeviceToHost));
|
||||
for (size_t j = 0; j < new_histogram_h.size(); ++j) {
|
||||
ASSERT_EQ(new_histogram_h[j].GetGrad(), histogram_h[j].GetGrad());
|
||||
ASSERT_EQ(new_histogram_h[j].GetHess(), histogram_h[j].GetHess());
|
||||
ASSERT_EQ(new_histogram_h[j].GetQuantisedGrad(), histogram_h[j].GetQuantisedGrad());
|
||||
ASSERT_EQ(new_histogram_h[j].GetQuantisedHess(), histogram_h[j].GetQuantisedHess());
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,20 +70,20 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
// Use a single feature group to compute the baseline.
|
||||
FeatureGroups single_group(page->Cuts());
|
||||
|
||||
dh::device_vector<Gradient> baseline(num_bins);
|
||||
dh::device_vector<GradientPairInt64> baseline(num_bins);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
single_group.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, dh::ToSpan(baseline),
|
||||
rounding);
|
||||
|
||||
std::vector<Gradient> baseline_h(num_bins);
|
||||
std::vector<GradientPairInt64> baseline_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
|
||||
num_bins * sizeof(Gradient),
|
||||
num_bins * sizeof(GradientPairInt64),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < baseline.size(); ++i) {
|
||||
EXPECT_NEAR(baseline_h[i].GetGrad(), histogram_h[i].GetGrad(),
|
||||
baseline_h[i].GetGrad() * 1e-3);
|
||||
EXPECT_NEAR(baseline_h[i].GetQuantisedGrad(), histogram_h[i].GetQuantisedGrad(),
|
||||
baseline_h[i].GetQuantisedGrad() * 1e-3);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -95,11 +94,25 @@ TEST(Histogram, GPUDeterministic) {
|
||||
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
|
||||
for (bool is_dense : is_dense_array) {
|
||||
for (int shm_size : shm_sizes) {
|
||||
TestDeterministicHistogram<GradientPairPrecise>(is_dense, shm_size);
|
||||
TestDeterministicHistogram(is_dense, shm_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ValidateCategoricalHistogram(size_t n_categories, common::Span<GradientPairInt64> onehot,
|
||||
common::Span<GradientPairInt64> cat) {
|
||||
auto cat_sum = std::accumulate(cat.cbegin(), cat.cend(), GradientPairInt64{});
|
||||
for (size_t c = 0; c < n_categories; ++c) {
|
||||
auto zero = onehot[c * 2];
|
||||
auto one = onehot[c * 2 + 1];
|
||||
|
||||
auto chosen = cat[c];
|
||||
auto not_chosen = cat_sum - chosen;
|
||||
ASSERT_EQ(zero, not_chosen);
|
||||
ASSERT_EQ(one, chosen);
|
||||
}
|
||||
}
|
||||
|
||||
// Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
|
||||
void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
size_t constexpr kRows = 340;
|
||||
@ -110,10 +123,10 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
||||
tree::RowPartitioner row_partitioner(0, kRows);
|
||||
auto ridx = row_partitioner.GetRows(0);
|
||||
dh::device_vector<GradientPairPrecise> cat_hist(num_categories);
|
||||
dh::device_vector<GradientPairInt64> cat_hist(num_categories);
|
||||
auto gpair = GenerateRandomGradients(kRows, 0, 2);
|
||||
gpair.SetDevice(0);
|
||||
auto rounding = CreateRoundingFactor<GradientPairPrecise>(gpair.DeviceSpan());
|
||||
auto rounding = GradientQuantizer(gpair.DeviceSpan());
|
||||
/**
|
||||
* Generate hist with cat data.
|
||||
*/
|
||||
@ -131,7 +144,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
*/
|
||||
auto x_encoded = OneHotEncodeFeature(x, num_categories);
|
||||
auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories);
|
||||
dh::device_vector<GradientPairPrecise> encode_hist(2 * num_categories);
|
||||
dh::device_vector<GradientPairInt64> encode_hist(2 * num_categories);
|
||||
for (auto const &batch : encode_m->GetBatches<EllpackPage>(batch_param)) {
|
||||
auto* page = batch.Impl();
|
||||
FeatureGroups single_group(page->Cuts());
|
||||
@ -141,14 +154,14 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
rounding);
|
||||
}
|
||||
|
||||
std::vector<GradientPairPrecise> h_cat_hist(cat_hist.size());
|
||||
std::vector<GradientPairInt64> h_cat_hist(cat_hist.size());
|
||||
thrust::copy(cat_hist.begin(), cat_hist.end(), h_cat_hist.begin());
|
||||
|
||||
std::vector<GradientPairPrecise> h_encode_hist(encode_hist.size());
|
||||
std::vector<GradientPairInt64> h_encode_hist(encode_hist.size());
|
||||
thrust::copy(encode_hist.begin(), encode_hist.end(), h_encode_hist.begin());
|
||||
ValidateCategoricalHistogram(num_categories,
|
||||
common::Span<GradientPairPrecise>{h_encode_hist},
|
||||
common::Span<GradientPairPrecise>{h_cat_hist});
|
||||
common::Span<GradientPairInt64>{h_encode_hist},
|
||||
common::Span<GradientPairInt64>{h_cat_hist});
|
||||
}
|
||||
|
||||
TEST(Histogram, GPUHistCategorical) {
|
||||
@ -156,5 +169,74 @@ TEST(Histogram, GPUHistCategorical) {
|
||||
TestGPUHistogramCategorical(num_categories);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Atomic add as type cast for test.
|
||||
XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) { // NOLINT
|
||||
uint64_t* u_dst = reinterpret_cast<uint64_t*>(dst);
|
||||
uint64_t u_src = *reinterpret_cast<uint64_t*>(&src);
|
||||
uint64_t ret = ::atomicAdd(u_dst, u_src);
|
||||
return *reinterpret_cast<int64_t*>(&ret);
|
||||
}
|
||||
}
|
||||
|
||||
void TestAtomicAdd() {
|
||||
size_t n_elements = 1024;
|
||||
dh::device_vector<int64_t> result_a(1, 0);
|
||||
auto d_result_a = result_a.data().get();
|
||||
|
||||
dh::device_vector<int64_t> result_b(1, 0);
|
||||
auto d_result_b = result_b.data().get();
|
||||
|
||||
/**
|
||||
* Test for simple inputs
|
||||
*/
|
||||
std::vector<int64_t> h_inputs(n_elements);
|
||||
for (size_t i = 0; i < h_inputs.size(); ++i) {
|
||||
h_inputs[i] = (i % 2 == 0) ? i : -i;
|
||||
}
|
||||
dh::device_vector<int64_t> inputs(h_inputs);
|
||||
auto d_inputs = inputs.data().get();
|
||||
|
||||
dh::LaunchN(n_elements, [=] __device__(size_t i) {
|
||||
AtomicAdd64As32(d_result_a, d_inputs[i]);
|
||||
atomicAdd(d_result_b, d_inputs[i]);
|
||||
});
|
||||
ASSERT_EQ(result_a[0], result_b[0]);
|
||||
|
||||
/**
|
||||
* Test for positive values that don't fit into 32 bit integer.
|
||||
*/
|
||||
thrust::fill(inputs.begin(), inputs.end(),
|
||||
(std::numeric_limits<uint32_t>::max() / 2));
|
||||
thrust::fill(result_a.begin(), result_a.end(), 0);
|
||||
thrust::fill(result_b.begin(), result_b.end(), 0);
|
||||
dh::LaunchN(n_elements, [=] __device__(size_t i) {
|
||||
AtomicAdd64As32(d_result_a, d_inputs[i]);
|
||||
atomicAdd(d_result_b, d_inputs[i]);
|
||||
});
|
||||
ASSERT_EQ(result_a[0], result_b[0]);
|
||||
ASSERT_GT(result_a[0], std::numeric_limits<uint32_t>::max());
|
||||
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
|
||||
|
||||
/**
|
||||
* Test for negative values that don't fit into 32 bit integer.
|
||||
*/
|
||||
thrust::fill(inputs.begin(), inputs.end(),
|
||||
(std::numeric_limits<int32_t>::min() / 2));
|
||||
thrust::fill(result_a.begin(), result_a.end(), 0);
|
||||
thrust::fill(result_b.begin(), result_b.end(), 0);
|
||||
dh::LaunchN(n_elements, [=] __device__(size_t i) {
|
||||
AtomicAdd64As32(d_result_a, d_inputs[i]);
|
||||
atomicAdd(d_result_b, d_inputs[i]);
|
||||
});
|
||||
ASSERT_EQ(result_a[0], result_b[0]);
|
||||
ASSERT_LT(result_a[0], std::numeric_limits<int32_t>::min());
|
||||
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
|
||||
}
|
||||
|
||||
TEST(Histogram, AtomicAddInt64) {
|
||||
TestAtomicAdd();
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -291,6 +291,26 @@ TEST(CPUHistogram, BuildHist) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename GradientSumT>
|
||||
void ValidateCategoricalHistogram(size_t n_categories,
|
||||
common::Span<GradientSumT> onehot,
|
||||
common::Span<GradientSumT> cat) {
|
||||
auto cat_sum = std::accumulate(cat.cbegin(), cat.cend(), GradientPairPrecise{});
|
||||
for (size_t c = 0; c < n_categories; ++c) {
|
||||
auto zero = onehot[c * 2];
|
||||
auto one = onehot[c * 2 + 1];
|
||||
|
||||
auto chosen = cat[c];
|
||||
auto not_chosen = cat_sum - chosen;
|
||||
|
||||
ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps);
|
||||
ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps);
|
||||
|
||||
ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps);
|
||||
ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps);
|
||||
}
|
||||
}
|
||||
|
||||
void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
||||
size_t constexpr kRows = 340;
|
||||
int32_t constexpr kBins = 256;
|
||||
|
||||
@ -29,7 +29,7 @@ TEST(GpuHist, DeviceHistogram) {
|
||||
constexpr size_t kNBins = 128;
|
||||
constexpr int kNNodes = 4;
|
||||
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
|
||||
DeviceHistogramStorage<GradientPairPrecise, kStopGrowing> histogram;
|
||||
DeviceHistogramStorage<kStopGrowing> histogram;
|
||||
histogram.Init(0, kNBins);
|
||||
for (int i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistograms({i});
|
||||
@ -107,32 +107,27 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
|
||||
maker.hist.AllocateHistograms({0});
|
||||
maker.gpair = gpair.DeviceSpan();
|
||||
maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);
|
||||
maker.histogram_rounding.reset(new GradientQuantizer(maker.gpair));
|
||||
|
||||
BuildGradientHistogram(
|
||||
page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
|
||||
maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
|
||||
maker.hist.GetNodeHistogram(0), *maker.histogram_rounding,
|
||||
!use_shared_memory_histograms);
|
||||
|
||||
DeviceHistogramStorage<GradientSumT>& d_hist = maker.hist;
|
||||
DeviceHistogramStorage<>& d_hist = maker.hist;
|
||||
|
||||
auto node_histogram = d_hist.GetNodeHistogram(0);
|
||||
// d_hist.data stored in float, not gradient pair
|
||||
thrust::host_vector<GradientSumT> h_result (d_hist.Data().size() / 2);
|
||||
size_t data_size =
|
||||
sizeof(GradientSumT) /
|
||||
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT));
|
||||
data_size *= d_hist.Data().size();
|
||||
dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), data_size,
|
||||
thrust::host_vector<GradientPairInt64> h_result (node_histogram.size());
|
||||
dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), node_histogram.size_bytes(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
|
||||
std::cout << std::fixed;
|
||||
for (size_t i = 0; i < h_result.size(); ++i) {
|
||||
ASSERT_FALSE(std::isnan(h_result[i].GetGrad()));
|
||||
EXPECT_NEAR(h_result[i].GetGrad(), solution[i].GetGrad(), 0.01f);
|
||||
EXPECT_NEAR(h_result[i].GetHess(), solution[i].GetHess(), 0.01f);
|
||||
auto result = maker.histogram_rounding->ToFloatingPoint(h_result[i]);
|
||||
EXPECT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
|
||||
EXPECT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,6 +156,12 @@ HistogramCutsWrapper GetHostCutMatrix () {
|
||||
return cmat;
|
||||
}
|
||||
|
||||
inline GradientQuantizer DummyRoundingFactor() {
|
||||
thrust::device_vector<GradientPair> gpair(1);
|
||||
gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000
|
||||
return GradientQuantizer(dh::ToSpan(gpair));
|
||||
}
|
||||
|
||||
// TODO(trivialfis): This test is over simplified.
|
||||
TEST(GpuHist, EvaluateRootSplit) {
|
||||
constexpr int kNRows = 16;
|
||||
@ -209,10 +210,12 @@ TEST(GpuHist, EvaluateRootSplit) {
|
||||
// Each row of hist_gpair represents gpairs for one feature.
|
||||
// Each entry represents a bin.
|
||||
std::vector<GradientPairPrecise> hist_gpair = GetHostHistGpair();
|
||||
std::vector<bst_float> hist;
|
||||
maker.histogram_rounding.reset(new GradientQuantizer(DummyRoundingFactor()));
|
||||
std::vector<int64_t> hist;
|
||||
for (auto pair : hist_gpair) {
|
||||
hist.push_back(pair.GetGrad());
|
||||
hist.push_back(pair.GetHess());
|
||||
auto grad = maker.histogram_rounding->ToFixedPoint({float(pair.GetGrad()),float(pair.GetHess())});
|
||||
hist.push_back(grad.GetQuantisedGrad());
|
||||
hist.push_back(grad.GetQuantisedHess());
|
||||
}
|
||||
|
||||
ASSERT_EQ(maker.hist.Data().size(), hist.size());
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user