Add basic unittests for gpu-hist method. (#3785)
* Split building histogram into separated class. * Extract `InitCompressedRow` definition. * Basic tests for gpu-hist. * Document the code more verbosely. * Removed `HistCutUnit`. * Removed some duplicated copies in `GPUHistMaker`. * Implement LCG and use it in tests.
This commit is contained in:
committed by
Rory Mitchell
parent
184efff9f9
commit
516457fadc
@@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2016-2018 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TESTS_CPP_HELPERS_H_
|
||||
#define XGBOOST_TESTS_CPP_HELPERS_H_
|
||||
|
||||
@@ -56,7 +59,86 @@ namespace xgboost {
|
||||
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _end1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _beg2);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Linear congruential generator.
|
||||
*
|
||||
* The distribution defined in std is not portable. Given the same seed, it
|
||||
* migth produce different outputs on different platforms or with different
|
||||
* compilers. The SimpleLCG implemented here is to make sure all tests are
|
||||
* reproducible.
|
||||
*/
|
||||
class SimpleLCG {
|
||||
private:
|
||||
using StateType = int64_t;
|
||||
static StateType constexpr default_init_ = 3;
|
||||
static StateType constexpr default_alpha_ = 61;
|
||||
static StateType constexpr max_value_ = ((StateType)1 << 32) - 1;
|
||||
|
||||
StateType state_;
|
||||
StateType const alpha_;
|
||||
StateType const mod_;
|
||||
|
||||
StateType const seed_;
|
||||
|
||||
public:
|
||||
SimpleLCG() : state_{default_init_},
|
||||
alpha_{default_alpha_}, mod_{max_value_}, seed_{state_}{}
|
||||
/*!
|
||||
* \brief Initialize SimpleLCG.
|
||||
*
|
||||
* \param state Initial state, can also be considered as seed. If set to
|
||||
* zero, SimpleLCG will use internal default value.
|
||||
* \param alpha multiplier
|
||||
* \param mod modulo
|
||||
*/
|
||||
SimpleLCG(StateType state,
|
||||
StateType alpha=default_alpha_, StateType mod=max_value_)
|
||||
: state_{state == 0 ? default_init_ : state},
|
||||
alpha_{alpha}, mod_{mod} , seed_{state} {}
|
||||
|
||||
StateType operator()();
|
||||
StateType Min() const;
|
||||
StateType Max() const;
|
||||
};
|
||||
|
||||
template <typename ResultT>
|
||||
class SimpleRealUniformDistribution {
|
||||
private:
|
||||
ResultT const lower;
|
||||
ResultT const upper;
|
||||
|
||||
/*! \brief Over-simplified version of std::generate_canonical. */
|
||||
template <size_t Bits, typename GeneratorT>
|
||||
ResultT GenerateCanonical(GeneratorT* rng) const {
|
||||
static_assert(std::is_floating_point<ResultT>::value,
|
||||
"Result type must be floating point.");
|
||||
long double const r = (static_cast<long double>(rng->Max())
|
||||
- static_cast<long double>(rng->Min())) + 1.0L;
|
||||
size_t const log2r = std::log(r) / std::log(2.0L);
|
||||
size_t m = std::max<size_t>(1UL, (Bits + log2r - 1UL) / log2r);
|
||||
ResultT sum_value = 0, r_k = 1;
|
||||
|
||||
for (size_t k = m; k != 0; --k) {
|
||||
sum_value += ResultT((*rng)() - rng->Min()) * r_k;
|
||||
r_k *= r;
|
||||
}
|
||||
|
||||
ResultT res = sum_value / r_k;
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
SimpleRealUniformDistribution(ResultT l, ResultT u) :
|
||||
lower{l}, upper{u} {}
|
||||
|
||||
template <typename GeneratorT>
|
||||
ResultT operator()(GeneratorT* rng) const {
|
||||
ResultT tmp = GenerateCanonical<std::numeric_limits<ResultT>::digits,
|
||||
GeneratorT>(rng);
|
||||
return (tmp * (upper - lower)) + lower;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \fn std::shared_ptr<xgboost::DMatrix> CreateDMatrix(int rows, int columns, float sparsity, int seed);
|
||||
@@ -70,7 +152,8 @@ bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
|
||||
*
|
||||
* \return The new d matrix.
|
||||
*/
|
||||
|
||||
std::shared_ptr<xgboost::DMatrix> *CreateDMatrix(int rows, int columns,
|
||||
float sparsity, int seed = 0);
|
||||
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user