Device dmatrix (#5420)

This commit is contained in:
Rory Mitchell
2020-03-28 14:42:21 +13:00
committed by GitHub
parent 780de49ddb
commit 13b10a6370
24 changed files with 915 additions and 310 deletions

View File

@@ -32,8 +32,8 @@ static const int kPadding = 4; // Assign padding so we can read slightly off
// the beginning of the array
// The number of bits required to represent a given unsigned range
static size_t SymbolBits(size_t num_symbols) {
auto bits = std::ceil(std::log2(num_symbols));
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
return std::max(static_cast<size_t>(bits), size_t(1));
}
} // namespace detail
@@ -50,14 +50,11 @@ static size_t SymbolBits(size_t num_symbols) {
*/
class CompressedBufferWriter {
private:
size_t symbol_bits_;
size_t offset_;
public:
explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
: symbol_bits_(detail::SymbolBits(num_symbols)) {}
/**
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
@@ -164,18 +161,15 @@ class CompressedBufferWriter {
}
};
template <typename T>
/**
* \class CompressedIterator
*
* \brief Read symbols from a bit compressed memory buffer. Usable on device and
* host.
* \brief Read symbols from a bit compressed memory buffer. Usable on device and host.
*
* \author Rory
* \date 7/9/2017
*
* \tparam T Generic type parameter.
*/
template <typename T>
class CompressedIterator {
public:
// Type definitions for thrust

View File

@@ -1540,4 +1540,12 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
} // namespace dh

View File

@@ -338,31 +338,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
return cuts;
}
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
explicit IsValidFunctor(float missing) : missing(missing) {}
float missing;
__device__ bool operator()(const data::COOTuple& e) const {
if (common::CheckNAN(e.value) || e.value == missing) {
return false;
}
return true;
}
__device__ bool operator()(const Entry& e) const {
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
return false;
}
return true;
}
};
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
template <typename AdapterT>
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts) {
@@ -372,10 +347,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
auto &batch = adapter->Value();
// Enforce single batch
CHECK(!adapter->Next());
auto batch_iter = MakeTransformIterator<data::COOTuple>(
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto entry_iter = MakeTransformIterator<Entry>(
auto entry_iter = dh::MakeTransformIterator<Entry>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return Entry(batch.GetElement(idx).column_idx,
batch.GetElement(idx).value);
@@ -385,7 +360,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
0);
auto d_column_sizes_scan = column_sizes_scan.data().get();
IsValidFunctor is_valid(missing);
data::IsValidFunctor is_valid(missing);
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx];
if (is_valid(e)) {

View File

@@ -105,10 +105,10 @@ class HistogramCuts {
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
const auto &values = cut_values_.ConstHostVector();
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
if (it == values.cend()) {
it = values.cend() - 1;
}
BinIdx idx = it - values.cbegin();
if (idx == end) {
idx -= 1;
}
return idx;
}