Deterministic GPU histogram. (#5361)

* Use a pre-rounding based method to obtain reproducible floating point
  summation (a host-side sketch of the idea follows the commit metadata below).
* GPU Hist for regression and classification is now bit-by-bit reproducible.
* Add doc.
* Switch to thrust reduce for `node_sum_gradient`.
Jiaming Yuan 2020-03-04 15:13:28 +08:00 committed by GitHub
parent 9775da02d9
commit 8d06878bf9
18 changed files with 410 additions and 97 deletions
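
The core of the change is pre-rounding: clip the gradients into a positive and a negative sum, derive a power-of-two rounding factor from the larger of the two, and truncate every gradient against that factor before it is summed, so the summation order can no longer affect the result. A minimal host-side sketch of the idea (names and the CPU setting are illustrative only, not code from this commit):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

// Power-of-two rounding factor, following the Demmel & Nguyen style bound used here:
// delta = bound / (1 - 2 * n * eps), factor = 2^ceil(log2(delta)).
float RoundingFactor(float bound, int n) {
  float delta = bound / (1.0f - 2.0f * n * std::numeric_limits<float>::epsilon());
  int exp;
  std::frexp(delta, &exp);       // delta = x * 2^exp, x in [0.5, 1)
  return std::ldexp(1.0f, exp);  // 2^ceil(log2(delta))
}

// Truncation: the result is a multiple of the factor's ulp, so later sums stay exact.
float Truncate(float factor, float x) { return (factor + x) - factor; }

int main() {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(-1e4f, 1e4f);
  std::vector<float> grad(1 << 16);
  for (auto& g : grad) g = dist(rng);

  // Bound on any partial sum: max of the summed positive and summed negative parts.
  float pos = 0.0f, neg = 0.0f;
  for (float g : grad) (g > 0 ? pos : neg) += std::fabs(g);
  float factor = RoundingFactor(std::max(pos, neg), static_cast<int>(grad.size()));

  for (auto& g : grad) g = Truncate(factor, g);

  // After truncation, summing in any order gives bit-identical results.
  float forward = std::accumulate(grad.begin(), grad.end(), 0.0f);
  std::shuffle(grad.begin(), grad.end(), rng);
  float shuffled = std::accumulate(grad.begin(), grad.end(), 0.0f);
  std::printf("%d\n", forward == shuffled);  // prints 1
}

The GPU implementation added below computes the same bound with a clipped thrust reduction and applies the truncation inside the histogram kernel.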

@@ -230,6 +230,20 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information
Additional parameters for `gpu_hist` tree method
================================================
* ``single_precision_histogram``, [default=``false``]
- Use single precision to build histograms. See the document for GPU support for more details.
* ``deterministic_histogram``, [default=``true``]
- Build histograms on GPU deterministically. Histogram building is not deterministic due
to the non-associativity of floating point summation. We employ a pre-rounding
routine to mitigate the issue, which may lead to slightly lower accuracy. Set to
``false`` to disable it.
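
Both parameters are ordinary booster parameters and can be set like any other; a minimal sketch of setting them through the C API (`XGBoosterSetParam`; the helper name is made up and error handling is omitted):

#include <xgboost/c_api.h>

// Hypothetical helper: `booster` is an existing BoosterHandle created elsewhere.
void ConfigureGpuHist(BoosterHandle booster) {
  XGBoosterSetParam(booster, "tree_method", "gpu_hist");
  XGBoosterSetParam(booster, "deterministic_histogram", "true");      // the default
  XGBoosterSetParam(booster, "single_precision_histogram", "false");  // the default
}

The Python test at the end of this commit sets the same parameters through XGBClassifier.
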
Additional parameters for Dart Booster (``booster=dart``)
=========================================================

@@ -135,15 +135,15 @@ class GradientPairInternal {
/*! \brief second order gradient statistics */
T hess_;
XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(T h) { hess_ = h; }
public:
using ValueT = T;
XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}
XGBOOST_DEVICE GradientPairInternal(float grad, float hess) {
XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
SetGrad(grad);
SetHess(hess);
}
@@ -160,8 +160,8 @@ class GradientPairInternal {
SetHess(g.GetHess());
}
XGBOOST_DEVICE float GetGrad() const { return grad_; }
XGBOOST_DEVICE float GetHess() const { return hess_; }
XGBOOST_DEVICE T GetGrad() const { return grad_; }
XGBOOST_DEVICE T GetHess() const { return hess_; }
XGBOOST_DEVICE GradientPairInternal<T> &operator+=(
const GradientPairInternal<T> &rhs) {
@@ -234,24 +234,6 @@ class GradientPairInternal {
return os;
}
};
template<>
inline XGBOOST_DEVICE float GradientPairInternal<int64_t>::GetGrad() const {
return grad_ * 1e-4f;
}
template<>
inline XGBOOST_DEVICE float GradientPairInternal<int64_t>::GetHess() const {
return hess_ * 1e-4f;
}
template<>
inline XGBOOST_DEVICE void GradientPairInternal<int64_t>::SetGrad(float g) {
grad_ = static_cast<int64_t>(std::round(g * 1e4));
}
template<>
inline XGBOOST_DEVICE void GradientPairInternal<int64_t>::SetHess(float h) {
hess_ = static_cast<int64_t>(std::round(h * 1e4));
}
} // namespace detail
/*! \brief gradient statistics pair usually needed in gradient boosting */
@@ -260,11 +242,6 @@ using GradientPair = detail::GradientPairInternal<float>;
/*! \brief High precision gradient statistics pair */
using GradientPairPrecise = detail::GradientPairInternal<double>;
/*! \brief High precision gradient statistics pair with integer backed
* storage. Operators are associative where floating point versions are not
* associative. */
using GradientPairInteger = detail::GradientPairInternal<int64_t>;
using Args = std::vector<std::pair<std::string, std::string> >;
/*! \brief small eps gap for minimum split decision. */
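
For context on what is being removed here: floating point addition is not associative, and the integer-backed pair worked around that by storing gradients as fixed-point `int64_t` values scaled by 1e4, whose addition is associative. The pre-rounding histogram makes this workaround unnecessary. A tiny illustration of the difference (values chosen only to expose the effect):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Floating point addition is not associative.
  float a = 1e8f, b = -1e8f, c = 1.0f;
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));  // prints "1 vs 0"

  // The removed GradientPairInternal<int64_t> stored round(x * 1e4) instead;
  // integer addition is associative, so grouping cannot change the sum.
  auto encode = [](float x) { return static_cast<int64_t>(std::round(x * 1e4f)); };
  long long lhs = (encode(a) + encode(b)) + encode(c);
  long long rhs = encode(a) + (encode(b) + encode(c));
  std::printf("%lld vs %lld\n", lhs, rhs);  // always equal
}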

@@ -1682,7 +1682,6 @@ class Booster(object):
if importance_type == 'weight':
# do a simpler tree dump to save time
trees = self.get_dump(fmap, with_stats=False)
fmap = {}
for tree in trees:
for line in tree.split('\n'):

@@ -68,7 +68,9 @@ def plot_importance(booster, ax=None, height=0.2,
raise ValueError('tree must be Booster, XGBModel or dict instance')
if not importance:
raise ValueError('Booster.get_score() results in empty')
raise ValueError(
'Booster.get_score() results in empty. ' +
'This may be caused by having all trees as decision dumps.')
tuples = [(k, importance[k]) for k in importance]
if max_num_features is not None:

@@ -16,12 +16,12 @@
#include "xgboost/base.h"
#include "xgboost/tree_model.h"
#if defined(XGBOOST_STRICT_R_MODE)
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
#define OBSERVER_PRINT LOG(INFO)
#define OBSERVER_ENDL ""
#define OBSERVER_NEWLINE ""
#else
#define OBSERVER_PRINT std::cout
#define OBSERVER_PRINT std::cout << std::setprecision(17)
#define OBSERVER_ENDL std::endl
#define OBSERVER_NEWLINE "\n"
#endif // defined(XGBOOST_STRICT_R_MODE)
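
Printing with 17 significant digits matters here because that is `std::numeric_limits<double>::max_digits10`: the observer's output can be compared textually across runs without losing bits. A small standalone illustration (not part of the commit):

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

int main() {
  double x = 0.1 + 0.2;  // not exactly 0.3
  std::ostringstream lossy, full;
  lossy << x;                                                                  // "0.3" (6 digits)
  full << std::setprecision(std::numeric_limits<double>::max_digits10) << x;   // 17 digits
  double back;
  std::istringstream(full.str()) >> back;
  std::cout << lossy.str() << "\n" << full.str() << "\n"
            << (back == x) << "\n";  // the 17-digit form round-trips exactly, prints 1
}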

@@ -29,14 +29,14 @@ bool EllpackPageSource::Next() {
EllpackPage& EllpackPageSource::Value() {
LOG(FATAL) << "Internal Error: "
"XGBoost is not compiled with CUDA but EllpackPageSource is required";
EllpackPage* page;
EllpackPage* page { nullptr };
return *page;
}
const EllpackPage& EllpackPageSource::Value() const {
LOG(FATAL) << "Internal Error: "
"XGBoost is not compiled with CUDA but EllpackPageSource is required";
EllpackPage* page;
EllpackPage* page { nullptr };
return *page;
}

@@ -734,6 +734,7 @@ class LearnerImpl : public Learner {
monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true);
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
monitor_.Stop("PredictRaw");
monitor_.Start("GetGradient");

@@ -25,6 +25,7 @@ class SamplingStrategy {
public:
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
virtual ~SamplingStrategy() = default;
};
/*! \brief No sampling in in-memory mode. */
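
The added line matters because sampling strategies are owned and destroyed through `SamplingStrategy*`; without a virtual destructor that deletion is undefined behaviour and derived-class members would not be cleaned up. A minimal sketch of the pattern with made-up class names:

#include <memory>

// Stand-ins for SamplingStrategy and one concrete strategy.
class Strategy {
 public:
  virtual int Sample() = 0;
  virtual ~Strategy() = default;  // without this, deleting through Strategy* is UB
};

class NoSampling : public Strategy {
  std::unique_ptr<int[]> buffer_{new int[1024]{}};  // resource owned by the derived class
 public:
  int Sample() override { return buffer_[0]; }
};

int main() {
  std::unique_ptr<Strategy> s{new NoSampling};
  // When `s` goes out of scope it deletes a Strategy*; the virtual destructor
  // ensures ~NoSampling runs and buffer_ is released.
  return s->Sample();
}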

@@ -0,0 +1,184 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include <thrust/reduce.h>
#include <thrust/iterator/transform_iterator.h>
#include <algorithm>
#include <ctgmath>
#include <limits>
#include "xgboost/base.h"
#include "row_partitioner.cuh"
#include "histogram.cuh"
#include "../../data/ellpack_page.cuh"
#include "../../common/device_helpers.cuh"
namespace xgboost {
namespace tree {
// The following two functions are slightly modified versions of fbcuda.
/* \brief Constructs a rounding factor used to truncate elements in a sum such that the
sum of the truncated elements is the same no matter what the order of the sum is.
* Algorithm 5: Reproducible Sequential Sum in 'Fast Reproducible Floating-Point
* Summation' by Demmel and Nguyen
* In algorithm 5 the bound is calculated as $max(|v_i|) * n$. Here we use the bound
*
* \begin{equation}
* max( fl(\sum^{V}_{v_i>0}{v_i}), fl(\sum^{V}_{v_i<0}|v_i|) )
* \end{equation}
*
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
*/
template <typename T>
DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());
// Calculate ceil(log_2(delta)).
// std::frexp() calculates exp and returns `x` such that
// delta = x * 2^exp, where `x` is in (-1.0, -0.5] U [0.5, 1).
// Because 0.5 <= |x| < 1, exp is ceil(log_2(delta)).
int exp;
std::frexp(delta, &exp);
// return M = 2 ^ ceil(log_2(delta))
return std::ldexp(static_cast<T>(1.0), exp);
}
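
Written out in the comment's own notation, with n the number of gradient pairs and \epsilon the machine epsilon of T, the factor computed by this function is (a restatement of the code above, not an addition to it):

\begin{equation}
\delta = \frac{\max\left( fl(\sum^{V}_{v_i>0}{v_i}),\; fl(\sum^{V}_{v_i<0}{|v_i|}) \right)}{1 - 2n\epsilon},
\qquad
M = 2^{\lceil \log_2 \delta \rceil}.
\end{equation}

Each truncated value fl((M + v_i) - M) is an integer multiple of M's unit in the last place and every partial sum stays below 2M in magnitude, so all partial sums are exact and the total is independent of the summation order.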
namespace {
struct Pair {
GradientPair first;
GradientPair second;
};
DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
return {lhs.first + rhs.first, lhs.second + rhs.second};
}
} // anonymous namespace
struct Clip : public thrust::unary_function<GradientPair, Pair> {
static DEV_INLINE float Pclip(float v) {
return v > 0 ? v : 0;
}
static DEV_INLINE float Nclip(float v) {
return v < 0 ? abs(v) : 0;
}
DEV_INLINE Pair operator()(GradientPair x) const {
auto pg = Pclip(x.GetGrad());
auto ph = Pclip(x.GetHess());
auto ng = Nclip(x.GetGrad());
auto nh = Nclip(x.GetHess());
return { GradientPair{ pg, ph }, GradientPair{ ng, nh } };
}
};
template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
using T = typename GradientSumT::ValueT;
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::device_ptr<GradientPair const> gpair_beg {gpair.data()};
thrust::device_ptr<GradientPair const> gpair_end {gpair.data() + gpair.size()};
auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
auto end = thrust::make_transform_iterator(gpair_end, Clip());
Pair p = thrust::reduce(thrust::cuda::par(alloc), beg, end, Pair{});
GradientPair positive_sum {p.first}, negative_sum {p.second};
auto histogram_rounding = GradientSumT {
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()),
gpair.size()),
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
gpair.size()) };
return histogram_rounding;
}
template GradientPairPrecise CreateRoundingFactor(common::Span<GradientPair const> gpair);
template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpair);
template <typename GradientSumT>
__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
size_t n_elements,
GradientSumT const rounding,
bool use_shared_memory_histograms) {
using T = typename GradientSumT::ValueT;
extern __shared__ char smem[];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
__syncthreads();
}
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.info.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
if (gidx != matrix.info.n_bins) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), d_gpair[ridx].GetHess()),
};
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT* atomic_add_ptr =
use_shared_memory_histograms ? smem_arr : d_node_hist;
dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
}
}
if (use_shared_memory_histograms) {
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
};
dh::AtomicAddGpair(d_node_hist + i, truncated);
}
}
}
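
The kernel combines two standard pieces: a per-block histogram privatized in shared memory, and a single atomic merge into the global histogram, with every value truncated first so that the order of the atomic additions cannot change the result. A stripped-down standalone version of the same pattern, using scalar bins instead of gradient pairs (names are illustrative, not from this file):

#include <cuda_runtime.h>

__device__ float Truncate(float factor, float x) { return (factor + x) - factor; }

// One privatized histogram per block in shared memory, merged into the global
// histogram with atomics at the end of the block.
__global__ void HistKernel(const int* __restrict__ bin_idx, const float* __restrict__ value,
                           float* __restrict__ global_hist, int n, int n_bins, float factor) {
  extern __shared__ float smem_hist[];
  for (int i = threadIdx.x; i < n_bins; i += blockDim.x) smem_hist[i] = 0.0f;
  __syncthreads();

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    atomicAdd(&smem_hist[bin_idx[i]], Truncate(factor, value[i]));
  }
  __syncthreads();

  for (int i = threadIdx.x; i < n_bins; i += blockDim.x) {
    atomicAdd(&global_hist[i], Truncate(factor, smem_hist[i]));
  }
}
// Launched as: HistKernel<<<grid, block, n_bins * sizeof(float)>>>(bins, values, hist, n, n_bins, factor);
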
template <typename GradientSumT>
void BuildGradientHistogram(EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> d_ridx,
common::Span<GradientSumT> histogram,
GradientSumT rounding, bool shared) {
const size_t smem_size =
shared
? sizeof(GradientSumT) * matrix.info.n_bins
: 0;
auto n_elements = d_ridx.size() * matrix.info.row_stride;
uint32_t items_per_thread = 8;
uint32_t block_threads = 256;
auto grid_size = static_cast<uint32_t>(
common::DivRoundUp(n_elements, items_per_thread * block_threads));
dh::LaunchKernel {grid_size, block_threads, smem_size} (
SharedMemHistKernel<GradientSumT>,
matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
rounding, shared);
}
template void BuildGradientHistogram<GradientPair>(
EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPair> histogram,
GradientPair rounding, bool shared);
template void BuildGradientHistogram<GradientPairPrecise>(
EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram,
GradientPairPrecise rounding, bool shared);
} // namespace tree
} // namespace xgboost

@@ -0,0 +1,29 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef HISTOGRAM_CUH_
#define HISTOGRAM_CUH_
#include <thrust/transform.h>
#include "../../data/ellpack_page.cuh"
namespace xgboost {
namespace tree {
template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
template <typename T>
DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
}
template <typename GradientSumT>
void BuildGradientHistogram(EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientSumT> histogram,
GradientSumT rounding, bool shared);
} // namespace tree
} // namespace xgboost
#endif // HISTOGRAM_CUH_

@@ -91,6 +91,16 @@ struct DeviceSplitCandidate {
}
}
XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
os << "loss_chg:" << c.loss_chg << ", "
<< "dir: " << c.dir << ", "
<< "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", "
<< "left sum: " << c.left_sum << ", "
<< "right sum: " << c.right_sum << std::endl;
return os;
}
};
struct DeviceSplitCandidateReduceOp {
@@ -186,6 +196,5 @@ struct SumCallbackOp {
XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
return (1 << (depth + 1)) - 1;
}
} // namespace tree
} // namespace xgboost

@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2019 XGBoost contributors
* Copyright 2017-2020 XGBoost contributors
*/
#include <thrust/copy.h>
#include <thrust/functional.h>
@@ -31,10 +31,10 @@
#include "constraints.cuh"
#include "gpu_hist/gradient_based_sampler.cuh"
#include "gpu_hist/row_partitioner.cuh"
#include "gpu_hist/histogram.cuh"
namespace xgboost {
namespace tree {
#if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST)
@@ -43,6 +43,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
struct GPUHistMakerTrainParam
: public XGBoostParameter<GPUHistMakerTrainParam> {
bool single_precision_histogram;
bool deterministic_histogram;
// number of rows in a single GPU batch
int gpu_batch_nrows;
bool debug_synchronize;
@@ -50,6 +51,8 @@ struct GPUHistMakerTrainParam
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
"Pre-round the gradient for obtaining deterministic gradient histogram.");
DMLC_DECLARE_FIELD(gpu_batch_nrows)
.set_lower_bound(-1)
.set_default(0)
@@ -336,6 +339,9 @@ class DeviceHistogram {
bool HistogramExists(int nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend();
}
int Bins() const {
return n_bins_;
}
size_t HistogramSize() const {
return n_bins_ * kNumItemsInGradientSum;
}
@@ -402,40 +408,6 @@ struct CalcWeightTrainParam {
learning_rate(p.learning_rate) {}
};
template <typename GradientSumT>
__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* d_node_hist,
const GradientPair* d_gpair, size_t n_elements,
bool use_shared_memory_histograms) {
extern __shared__ char smem[];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
__syncthreads();
}
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.info.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
if (gidx != matrix.info.n_bins) {
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT* atomic_add_ptr =
use_shared_memory_histograms ? smem_arr : d_node_hist;
dh::AtomicAddGpair(atomic_add_ptr + gidx, d_gpair[ridx]);
}
}
if (use_shared_memory_histograms) {
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
dh::AtomicAddGpair(d_node_hist + i, smem_arr[i]);
}
}
}
// Manage memory for a single GPU
template <typename GradientSumT>
struct GPUHistMakerDevice {
@@ -460,9 +432,12 @@ struct GPUHistMakerDevice {
bst_uint n_rows;
TrainParam param;
bool deterministic_histogram;
bool prediction_cache_initialised;
bool use_shared_memory_histograms {false};
GradientSumT histogram_rounding;
dh::CubMemory temp_memory;
dh::PinnedMemory pinned_memory;
@@ -486,6 +461,7 @@ struct GPUHistMakerDevice {
TrainParam _param,
uint32_t column_sampler_seed,
uint32_t n_features,
bool deterministic_histogram,
BatchParam _batch_param)
: device_id(_device_id),
page(_page),
@@ -494,6 +470,7 @@ struct GPUHistMakerDevice {
prediction_cache_initialised(false),
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
deterministic_histogram{deterministic_histogram},
batch_param(_batch_param) {
sampler.reset(new GradientBasedSampler(page,
n_rows,
@@ -551,6 +528,12 @@ struct GPUHistMakerDevice {
page = sample.page;
gpair = sample.gpair;
if (deterministic_histogram) {
histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
} else {
histogram_rounding = GradientSumT{0.0, 0.0};
}
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
hist.Reset();
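
When `deterministic_histogram` is disabled, the rounding factor is simply zero, which turns the truncation in histogram.cuh into an identity, i.e. the gradients flow into the histogram unmodified, as before this change:

// Illustration only: TruncateWithRoundingFactor(0, x) == (0 + x) - 0 == x for any finite x,
// so a zero factor disables the pre-rounding without needing a separate code path.
template <typename T>
T TruncateWithZeroFactor(T x) { return (T{0} + x) - T{0}; }
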
@@ -644,20 +627,8 @@ struct GPUHistMakerDevice {
auto d_ridx = row_partitioner->GetRows(nidx);
auto d_gpair = gpair.data();
auto n_elements = d_ridx.size() * page->matrix.info.row_stride;
const size_t smem_size =
use_shared_memory_histograms
? sizeof(GradientSumT) * page->matrix.info.n_bins
: 0;
uint32_t items_per_thread = 8;
uint32_t block_threads = 256;
auto grid_size = static_cast<uint32_t>(
common::DivRoundUp(n_elements, items_per_thread * block_threads));
dh::LaunchKernel {grid_size, block_threads, smem_size} (
SharedMemHistKernel<GradientSumT>,
page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
use_shared_memory_histograms);
BuildGradientHistogram(page->matrix, gpair, d_ridx, d_node_hist,
histogram_rounding, use_shared_memory_histograms);
}
void SubtractionTrick(int nidx_parent, int nidx_histogram,
@@ -707,7 +678,7 @@ struct GPUHistMakerDevice {
// After tree update is finished, update the position of all training
// instances to their final leaf. This information is used later to update the
// prediction cache
void FinalisePosition(RegTree* p_tree, DMatrix* p_fmat) {
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
const auto d_nodes =
temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
@@ -870,14 +841,19 @@ struct GPUHistMakerDevice {
}
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) {
constexpr int kRootNIdx = 0;
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size());
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
GradientPair root_sum = thrust::reduce(
thrust::cuda::par(alloc),
thrust::device_ptr<GradientPair const>(gpair.data()),
thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()));
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.data(), &root_sum, sizeof(root_sum),
cudaMemcpyHostToDevice));
reducer->AllReduceSum(
reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
reducer->Synchronize();
dh::safe_cuda(cudaMemcpy(node_sum_gradients.data(),
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data(),
node_sum_gradients_d.data(), sizeof(GradientPair),
cudaMemcpyDeviceToHost));
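
The root gradient sum now goes through `thrust::reduce`, which runs a tree-shaped reduction on the device and therefore, per the comment in histogram.cu, is reproducible for a given input. A minimal self-contained sketch of the same call, with a plain struct standing in for `GradientPair`:

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

struct GradPair {
  float grad, hess;
  __host__ __device__ GradPair operator+(GradPair const& rhs) const {
    return GradPair{grad + rhs.grad, hess + rhs.hess};
  }
};

// Sum all gradient pairs (e.g. for the root node) on the device.
GradPair RootSum(thrust::device_vector<GradPair> const& gpair) {
  return thrust::reduce(thrust::device, gpair.begin(), gpair.end(), GradPair{0.f, 0.f});
}
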
@@ -1055,6 +1031,7 @@ class GPUHistMakerSpecialised {
param_,
column_sampling_seed,
info_->num_col_,
hist_maker_param_.deterministic_histogram,
batch_param));
monitor_.StartCuda("InitHistogram");

@@ -76,6 +76,20 @@ void TestDeviceSketch(bool use_external_memory) {
ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
}
// Deterministic
size_t constexpr kRounds { 100 };
for (size_t r = 0; r < kRounds; ++r) {
HistogramCuts new_sketch;
DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &new_sketch);
ASSERT_EQ(hmat_gpu.Values().size(), new_sketch.Values().size());
for (size_t i = 0; i < hmat_gpu.Values().size(); ++i) {
ASSERT_EQ(hmat_gpu.Values()[i], new_sketch.Values()[i]);
}
for (size_t i = 0; i < hmat_gpu.MinValues().size(); ++i) {
ASSERT_EQ(hmat_gpu.MinValues()[i], new_sketch.MinValues()[i]);
}
}
delete dmat;
}

@@ -224,9 +224,10 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) {
return tparam;
}
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows) {
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows,
float lower = 0.0f, float upper = 1.0f) {
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower, upper);
std::vector<GradientPair> h_gpair(n_rows);
for (auto &gpair : h_gpair) {
bst_float grad = dist(&gen);
@@ -288,6 +289,5 @@ inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
return page;
}
#endif
} // namespace xgboost
#endif

@@ -605,6 +605,10 @@ TEST_F(MultiClassesSerializationTest, GPU_Hist) {
{"seed", "0"},
{"nthread", "1"},
{"max_depth", std::to_string(kClasses)},
// Somehow rebuilding the cache can generate slightly
// different results (1e-7) with the CPU predictor for some
// entries.
{"predictor", "gpu_predictor"},
{"enable_experimental_json_serialization", "1"},
{"tree_method", "gpu_hist"}},
fmap_, *pp_dmat_);

@@ -0,0 +1,69 @@
#include <gtest/gtest.h>
#include "../../helpers.h"
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../../../src/tree/gpu_hist/histogram.cuh"
namespace xgboost {
namespace tree {
template <typename Gradient>
void TestDeterminsticHistogram() {
size_t constexpr kBins = 24, kCols = 8, kRows = 32768, kRounds = 16;
float constexpr kLower = -1e-2, kUpper = 1e2;
auto pp_m = CreateDMatrix(kRows, kCols, 0.5);
auto& matrix = **pp_m;
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0, 0};
for (auto const& batch : matrix.GetBatches<EllpackPage>(batch_param)) {
auto* page = batch.Impl();
tree::RowPartitioner row_partitioner(0, kRows);
auto ridx = row_partitioner.GetRows(0);
dh::device_vector<Gradient> histogram(kBins * kCols);
auto d_histogram = dh::ToSpan(histogram);
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
gpair.SetDevice(0);
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
BuildGradientHistogram(page->matrix, gpair.DeviceSpan(), ridx,
d_histogram, rounding, true);
for (size_t i = 0; i < kRounds; ++i) {
dh::device_vector<Gradient> new_histogram(kBins * kCols);
auto d_histogram = dh::ToSpan(new_histogram);
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
BuildGradientHistogram(page->matrix, gpair.DeviceSpan(), ridx,
d_histogram, rounding, true);
for (size_t j = 0; j < new_histogram.size(); ++j) {
ASSERT_EQ(((Gradient)new_histogram[j]).GetGrad(),
((Gradient)histogram[j]).GetGrad());
ASSERT_EQ(((Gradient)new_histogram[j]).GetHess(),
((Gradient)histogram[j]).GetHess());
}
}
{
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
gpair.SetDevice(0);
dh::device_vector<Gradient> baseline(kBins * kCols);
BuildGradientHistogram(page->matrix, gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), rounding, true);
for (size_t i = 0; i < baseline.size(); ++i) {
EXPECT_NEAR(((Gradient)baseline[i]).GetGrad(), ((Gradient)histogram[i]).GetGrad(),
((Gradient)baseline[i]).GetGrad() * 1e-3);
}
}
}
delete pp_m;
}
TEST(Histogram, GPUDeterminstic) {
TestDeterminsticHistogram<GradientPair>();
TestDeterminsticHistogram<GradientPairPrecise>();
}
} // namespace tree
} // namespace xgboost

@@ -83,7 +83,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
param.Init(args);
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols,
true, batch_param);
maker.InitHistogram();
xgboost::SimpleLCG gen;
@@ -187,7 +188,7 @@ TEST(GpuHist, EvaluateSplits) {
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientPairPrecise>
maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
// Initialize GPUHistMakerDevice::node_sum_gradients
maker.node_sum_gradients = {{6.4f, 12.8f}};

@@ -1,6 +1,8 @@
import sys
import os
import unittest
import numpy as np
import xgboost as xgb
sys.path.append("tests/python")
# Don't import the test class, otherwise they will run twice.
import test_basic_models as test_bm # noqa
@@ -12,3 +14,33 @@ class TestGPUBasicModels(unittest.TestCase):
def test_eta_decay_gpu_hist(self):
self.cputest.run_eta_decay('gpu_hist')
def test_deterministic_gpu_hist(self):
kRows = 1000
kCols = 64
kClasses = 4
# Create large values to force rounding.
X = np.random.randn(kRows, kCols) * 1e4
y = np.random.randint(0, kClasses, size=kRows)
cls = xgb.XGBClassifier(tree_method='gpu_hist',
deterministic_histogram=True,
single_precision_histogram=True)
cls.fit(X, y)
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
cls = xgb.XGBClassifier(tree_method='gpu_hist',
deterministic_histogram=True,
single_precision_histogram=True)
cls.fit(X, y)
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
with open('test_deterministic_gpu_hist-0.json', 'r') as fd:
model_0 = fd.read()
with open('test_deterministic_gpu_hist-1.json', 'r') as fd:
model_1 = fd.read()
assert hash(model_0) == hash(model_1)
os.remove('test_deterministic_gpu_hist-0.json')
os.remove('test_deterministic_gpu_hist-1.json')