Device dmatrix (#5420)
This commit is contained in:
@@ -32,8 +32,8 @@ static const int kPadding = 4; // Assign padding so we can read slightly off
|
||||
// the beginning of the array
|
||||
|
||||
// The number of bits required to represent a given unsigned range
|
||||
// Computes ceil(log2(num_symbols)) and returns it as size_t, never less than 1.
// NOTE(review): num_symbols == 0 makes log2 return -infinity, and casting that
// to size_t is undefined behaviour — confirm callers never pass 0.
static size_t SymbolBits(size_t num_symbols) {
|
||||
auto bits = std::ceil(std::log2(num_symbols));
|
||||
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
|
||||
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
|
||||
return std::max(static_cast<size_t>(bits), size_t(1));
|
||||
}
|
||||
} // namespace detail
|
||||
@@ -50,14 +50,11 @@ static size_t SymbolBits(size_t num_symbols) {
|
||||
*/
|
||||
|
||||
class CompressedBufferWriter {
|
||||
private:
|
||||
size_t symbol_bits_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
// Constructs a writer for an alphabet of `num_symbols` symbols: offset_ starts
// at 0 and symbol_bits_ is set to detail::SymbolBits(num_symbols).
explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
|
||||
symbol_bits_ = detail::SymbolBits(num_symbols);
|
||||
}
|
||||
// Constructs a writer for an alphabet of `num_symbols` symbols; the only state
// initialised here is symbol_bits_ = detail::SymbolBits(num_symbols).
// NOTE(review): XGBOOST_DEVICE presumably makes this callable from device
// code as well as host — confirm against the macro's definition.
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
|
||||
: symbol_bits_(detail::SymbolBits(num_symbols)) {}
|
||||
|
||||
/**
|
||||
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
|
||||
@@ -164,18 +161,15 @@ class CompressedBufferWriter {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
/**
|
||||
* \class CompressedIterator
|
||||
*
|
||||
* \brief Read symbols from a bit compressed memory buffer. Usable on device and
|
||||
* host.
|
||||
* \brief Read symbols from a bit compressed memory buffer. Usable on device and host.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
class CompressedIterator {
|
||||
public:
|
||||
// Type definitions for thrust
|
||||
|
||||
@@ -1540,4 +1540,12 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
|
||||
// Wraps `iter` in a thrust::transform_iterator that applies `func`, with
// ReturnT supplied explicitly as the iterator's third template argument.
// ReturnT cannot be deduced from the arguments, so callers must name it,
// e.g. MakeTransformIterator<Entry>(iter, func).
// Thrust version of this function causes error on Windows
// (presumably thrust::make_transform_iterator — TODO confirm).
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
} // namespace dh
|
||||
|
||||
@@ -338,31 +338,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
return cuts;
|
||||
}
|
||||
|
||||
// Predicate deciding whether an element carries a usable feature value.
// An element is rejected (false) when its value is NaN according to
// common::CheckNAN or compares equal to `missing`; otherwise it is accepted.
// The NaN test is separate because NaN never compares equal to anything, so
// `== missing` alone could not reject NaN values even when `missing` is NaN.
// Both call operators are declared __device__ only.
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
// Stores the sentinel value that marks an absent entry.
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
// Sentinel compared against each element's value.
float missing;
|
||||
// Overload for data::COOTuple: inspects e.value.
__device__ bool operator()(const data::COOTuple& e) const {
|
||||
if (common::CheckNAN(e.value) || e.value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Overload for Entry: inspects e.fvalue.
__device__ bool operator()(const Entry& e) const {
|
||||
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Wraps `iter` in a thrust::transform_iterator that applies `func`, with
// ReturnT supplied explicitly as the iterator's third template argument.
// ReturnT cannot be deduced from the arguments, so callers must name it.
// Thrust version of this function causes error on Windows
// (presumably thrust::make_transform_iterator — TODO confirm).
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
SketchContainer* sketch_container, int num_cuts) {
|
||||
@@ -372,10 +347,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
auto &batch = adapter->Value();
|
||||
// Enforce single batch
|
||||
CHECK(!adapter->Next());
|
||||
auto batch_iter = MakeTransformIterator<data::COOTuple>(
|
||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||
thrust::make_counting_iterator(0llu),
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||
auto entry_iter = MakeTransformIterator<Entry>(
|
||||
auto entry_iter = dh::MakeTransformIterator<Entry>(
|
||||
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||
return Entry(batch.GetElement(idx).column_idx,
|
||||
batch.GetElement(idx).value);
|
||||
@@ -385,7 +360,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
0);
|
||||
|
||||
auto d_column_sizes_scan = column_sizes_scan.data().get();
|
||||
IsValidFunctor is_valid(missing);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto e = batch_iter[begin + idx];
|
||||
if (is_valid(e)) {
|
||||
|
||||
@@ -105,10 +105,10 @@ class HistogramCuts {
|
||||
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
|
||||
const auto &values = cut_values_.ConstHostVector();
|
||||
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
||||
if (it == values.cend()) {
|
||||
it = values.cend() - 1;
|
||||
}
|
||||
BinIdx idx = it - values.cbegin();
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user