Gradient based sampling for GPU Hist (#5093)

* Implement gradient based sampling for GPU Hist tree method.
* Add samplers and handle compacted page in GPU Hist.
This commit is contained in:
Rong Ou
2020-02-03 18:31:27 -08:00
committed by GitHub
parent c74216f22c
commit e4b74c4d22
18 changed files with 1187 additions and 175 deletions

View File

@@ -63,7 +63,7 @@ class CompressedBufferWriter {
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
* num_elements, int num_symbols)
*
* \brief Calculates number of bytes requiredm for a given number of elements
* \brief Calculates number of bytes required for a given number of elements
* and a symbol range.
*
* \author Rory
@@ -74,7 +74,6 @@ class CompressedBufferWriter {
*
* \return The calculated buffer size.
*/
static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
const int bits_per_byte = 8;
size_t compressed_size = static_cast<size_t>(std::ceil(
@@ -188,7 +187,7 @@ class CompressedIterator {
public:
CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {}
CompressedIterator(CompressedByteT *buffer, int num_symbols)
CompressedIterator(CompressedByteT *buffer, size_t num_symbols)
: buffer_(buffer), offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}

View File

@@ -1266,6 +1266,26 @@ thrust::device_ptr<T const> tcend(xgboost::HostDeviceVector<T> const& vector) {
return tcbegin(vector) + vector.Size();
}
template <typename T>
thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) { // NOLINT
return thrust::device_ptr<T>(span.data());
}
template <typename T>
thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) { // // NOLINT
return tbegin(span) + span.size();
}
template <typename T>
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) {
return thrust::device_ptr<T const>(span.data());
}
template <typename T>
thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {
return tcbegin(span) + span.size();
}
template <typename FunctionT>
class LauncherItr {
public: