Device dmatrix (#5420)
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "xgboost/learner.h"
|
||||
#include "c_api_error.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/device_dmatrix.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@@ -29,3 +30,25 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*!
 * \brief Create a DeviceDMatrix from a JSON array-interface string that is
 *        read through data::CudfAdapter.
 *
 * \param c_json_strs Null-terminated JSON string handed to the adapter. Must not be null.
 * \param missing     Value to be treated as missing; forwarded to the DeviceDMatrix constructor.
 * \param nthread     Forwarded to the DeviceDMatrix constructor.
 * \param max_bin     Forwarded to the DeviceDMatrix constructor.
 * \param out         Receives a heap-allocated std::shared_ptr<DMatrix>. Must not be null.
 *
 * \return Status code produced by API_BEGIN()/API_END() (defined outside this view).
 */
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
                                                                   bst_float missing, int nthread, int max_bin,
                                                                   DMatrixHandle* out) {
  API_BEGIN();
  // Building a std::string from a null pointer is undefined behaviour and `out`
  // is dereferenced below, so reject null arguments explicitly.
  // NOTE(review): presumably a failed CHECK is turned into an error return by
  // API_BEGIN()/API_END() - confirm against c_api_error.h.
  CHECK(c_json_strs != nullptr) << "Invalid pointer argument: c_json_strs";
  CHECK(out != nullptr) << "Invalid pointer argument: out";
  std::string json_str{c_json_strs};
  data::CudfAdapter adapter(json_str);
  *out =
      new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
  API_END();
}
|
||||
|
||||
/*!
 * \brief Create a DeviceDMatrix from a JSON array-interface string that is
 *        read through data::CupyAdapter.
 *
 * \param c_json_strs Null-terminated JSON string handed to the adapter. Must not be null.
 * \param missing     Value to be treated as missing; forwarded to the DeviceDMatrix constructor.
 * \param nthread     Forwarded to the DeviceDMatrix constructor.
 * \param max_bin     Forwarded to the DeviceDMatrix constructor.
 * \param out         Receives a heap-allocated std::shared_ptr<DMatrix>. Must not be null.
 *
 * \return Status code produced by API_BEGIN()/API_END() (defined outside this view).
 */
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs,
                                                            bst_float missing, int nthread, int max_bin,
                                                            DMatrixHandle* out) {
  API_BEGIN();
  // Building a std::string from a null pointer is undefined behaviour and `out`
  // is dereferenced below, so reject null arguments explicitly.
  // NOTE(review): presumably a failed CHECK is turned into an error return by
  // API_BEGIN()/API_END() - confirm against c_api_error.h.
  CHECK(c_json_strs != nullptr) << "Invalid pointer argument: c_json_strs";
  CHECK(out != nullptr) << "Invalid pointer argument: out";
  std::string json_str{c_json_strs};
  data::CupyAdapter adapter(json_str);
  *out =
      new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
  API_END();
}
|
||||
|
||||
@@ -32,8 +32,8 @@ static const int kPadding = 4; // Assign padding so we can read slightly off
|
||||
// the beginning of the array
|
||||
|
||||
// The number of bits required to represent a given unsigned range
|
||||
static size_t SymbolBits(size_t num_symbols) {
|
||||
auto bits = std::ceil(std::log2(num_symbols));
|
||||
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
|
||||
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
|
||||
return std::max(static_cast<size_t>(bits), size_t(1));
|
||||
}
|
||||
} // namespace detail
|
||||
@@ -50,14 +50,11 @@ static size_t SymbolBits(size_t num_symbols) {
|
||||
*/
|
||||
|
||||
class CompressedBufferWriter {
|
||||
private:
|
||||
size_t symbol_bits_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
|
||||
symbol_bits_ = detail::SymbolBits(num_symbols);
|
||||
}
|
||||
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
|
||||
: symbol_bits_(detail::SymbolBits(num_symbols)) {}
|
||||
|
||||
/**
|
||||
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
|
||||
@@ -164,18 +161,15 @@ class CompressedBufferWriter {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
/**
|
||||
* \class CompressedIterator
|
||||
*
|
||||
* \brief Read symbols from a bit compressed memory buffer. Usable on device and
|
||||
* host.
|
||||
* \brief Read symbols from a bit compressed memory buffer. Usable on device and host.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
class CompressedIterator {
|
||||
public:
|
||||
// Type definitions for thrust
|
||||
|
||||
@@ -1540,4 +1540,12 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
} // namespace dh
|
||||
|
||||
@@ -338,31 +338,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
return cuts;
|
||||
}
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const data::COOTuple& e) const {
|
||||
if (common::CheckNAN(e.value) || e.value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
__device__ bool operator()(const Entry& e) const {
|
||||
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
SketchContainer* sketch_container, int num_cuts) {
|
||||
@@ -372,10 +347,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
auto &batch = adapter->Value();
|
||||
// Enforce single batch
|
||||
CHECK(!adapter->Next());
|
||||
auto batch_iter = MakeTransformIterator<data::COOTuple>(
|
||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||
thrust::make_counting_iterator(0llu),
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||
auto entry_iter = MakeTransformIterator<Entry>(
|
||||
auto entry_iter = dh::MakeTransformIterator<Entry>(
|
||||
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||
return Entry(batch.GetElement(idx).column_idx,
|
||||
batch.GetElement(idx).value);
|
||||
@@ -385,7 +360,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
0);
|
||||
|
||||
auto d_column_sizes_scan = column_sizes_scan.data().get();
|
||||
IsValidFunctor is_valid(missing);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto e = batch_iter[begin + idx];
|
||||
if (is_valid(e)) {
|
||||
|
||||
@@ -105,10 +105,10 @@ class HistogramCuts {
|
||||
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
|
||||
const auto &values = cut_values_.ConstHostVector();
|
||||
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
||||
if (it == values.cend()) {
|
||||
it = values.cend() - 1;
|
||||
}
|
||||
BinIdx idx = it - values.cbegin();
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "array_interface.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
#include "simple_dmatrix.h"
|
||||
#include "device_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
|
||||
@@ -8,12 +8,31 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/math.h"
|
||||
#include "adapter.h"
|
||||
#include "array_interface.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
/*!
 * \brief Predicate that accepts an element only when its value is neither NaN
 *        nor equal to the configured `missing` value.
 *
 * Overloaded for both data::COOTuple (field `value`) and Entry (field `fvalue`).
 */
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
  // Value that marks an entry as absent.
  float missing;

  explicit IsValidFunctor(float missing) : missing(missing) {}

  __device__ bool operator()(const data::COOTuple& e) const {
    return !(common::CheckNAN(e.value) || e.value == missing);
  }
  __device__ bool operator()(const Entry& e) const {
    return !(common::CheckNAN(e.fvalue) || e.fvalue == missing);
  }
};
|
||||
|
||||
class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CudfAdapterBatch() = default;
|
||||
|
||||
238
src/data/device_dmatrix.cu
Normal file
238
src/data/device_dmatrix.cu
Normal file
@@ -0,0 +1,238 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file device_dmatrix.cu
|
||||
* \brief Device-memory version of DMatrix.
|
||||
*/
|
||||
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "../common/hist_util.h"
|
||||
#include "adapter.h"
|
||||
#include "device_adapter.cuh"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "device_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
// Returns maximum row length
|
||||
template <typename AdapterBatchT>
|
||||
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
}
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
size_t row_stride = thrust::reduce(
|
||||
thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
|
||||
thrust::maximum<size_t>());
|
||||
return row_stride;
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
struct WriteCompressedEllpackFunctor {
|
||||
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
|
||||
const common::CompressedBufferWriter& writer,
|
||||
const AdapterBatchT& batch,
|
||||
EllpackDeviceAccessor accessor,
|
||||
const IsValidFunctor& is_valid)
|
||||
: d_buffer(buffer),
|
||||
writer(writer),
|
||||
batch(batch),
|
||||
accessor(std::move(accessor)),
|
||||
is_valid(is_valid) {}
|
||||
|
||||
common::CompressedByteT* d_buffer;
|
||||
common::CompressedBufferWriter writer;
|
||||
AdapterBatchT batch;
|
||||
EllpackDeviceAccessor accessor;
|
||||
IsValidFunctor is_valid;
|
||||
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
__device__ size_t operator()(Tuple out) {
|
||||
auto e = batch.GetElement(out.get<2>());
|
||||
if (is_valid(e)) {
|
||||
// -1 because the scan is inclusive
|
||||
size_t output_position =
|
||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
||||
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
|
||||
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ellpack matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
// compaction needs to be modified to do this, so we manually define a
|
||||
// segmented stream compaction via operators on an inclusive scan. The output
|
||||
// of this inclusive scan is fed to a custom function which works out the
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
IsValidFunctor is_valid(missing);
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx).row_idx; });
|
||||
auto value_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting, [=] __device__(size_t idx) -> size_t {
|
||||
return is_valid(batch.GetElement(idx));
|
||||
});
|
||||
|
||||
auto key_value_index_iter = thrust::make_zip_iterator(
|
||||
thrust::make_tuple(key_iter, value_iter, counting));
|
||||
|
||||
// Tuple[0] = The row index of the input, used as a key to define segments
|
||||
// Tuple[1] = Scanned flags of valid elements for each row
|
||||
// Tuple[2] = The index in the input data
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
|
||||
// We redirect the scan output into this functor to do the actual writing
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||
d_compressed_buffer, writer, batch, device_accessor, is_valid);
|
||||
thrust::discard_iterator<size_t> discard;
|
||||
thrust::transform_output_iterator<
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||
out(discard, functor);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
|
||||
key_value_index_iter + batch.Size(), out,
|
||||
[=] __device__(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.get<0>() == b.get<0>()) {
|
||||
b.get<1>() += a.get<1>();
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
return b;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT, typename AdapterBatchT>
|
||||
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
|
||||
EllpackPageImpl* dst, float missing) {
|
||||
// Step 1: Get the sizes of the input columns
|
||||
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
|
||||
auto d_column_sizes = column_sizes.data().get();
|
||||
// Populate column sizes
|
||||
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes[e.column_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
|
||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
||||
|
||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
||||
// temporary row pointers
|
||||
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
auto row_stride = dst->row_stride;
|
||||
size_t begin = 0;
|
||||
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!is_valid(e)) return;
|
||||
size_t output_position =
|
||||
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
|
||||
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
|
||||
output_position);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Fill the unused tail of every ELLPACK row with the null symbol.
 *
 * Visits all row_stride * n_rows slots; a slot whose offset within its row is
 * at or beyond that row's count receives device_accessor.NullValue().
 *
 * \param dst        Page whose gidx_buffer is written.
 * \param device_idx Device ordinal used for the launch and the accessor.
 * \param row_counts Number of valid elements in each row, indexed by row.
 */
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
                     common::Span<size_t> row_counts) {
  auto device_accessor = dst->GetDeviceAccessor(device_idx);
  common::CompressedBufferWriter writer(device_accessor.NumSymbols());
  auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
  auto row_stride = dst->row_stride;
  dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
    size_t row_idx = idx / row_stride;
    size_t row_offset = idx % row_stride;
    if (row_offset < row_counts[row_idx]) {
      return;  // Slot is within the row's valid-element count.
    }
    // Work on a copy: the by-value capture is const inside the lambda.
    auto mutable_writer = writer;
    mutable_writer.AtomicWriteSymbol(d_compressed_buffer,
                                     device_accessor.NullValue(), idx);
  });
}
|
||||
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
template <typename AdapterT>
|
||||
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) {
|
||||
common::HistogramCuts cuts =
|
||||
common::AdapterDeviceSketch(adapter, max_bin, missing);
|
||||
auto& batch = adapter->Value();
|
||||
// Work out how many valid entries we have in each row
|
||||
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
|
||||
common::Span<size_t> row_counts_span(row_counts.data().get(),
|
||||
row_counts.size());
|
||||
size_t row_stride =
|
||||
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
info.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
|
||||
row_counts.begin(), row_counts.end());
|
||||
info.num_col_ = adapter->NumColumns();
|
||||
info.num_row_ = adapter->NumRows();
|
||||
ellpack_page_.reset(new EllpackPage());
|
||||
*ellpack_page_->Impl() =
|
||||
EllpackPageImpl(adapter->DeviceIdx(), cuts, this->IsDense(), row_stride,
|
||||
adapter->NumRows());
|
||||
if (adapter->IsRowMajor()) {
|
||||
CopyDataRowMajor(batch, ellpack_page_->Impl(), adapter->DeviceIdx(),
|
||||
missing);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, batch, ellpack_page_->Impl(), missing);
|
||||
}
|
||||
|
||||
WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(), row_counts_span);
|
||||
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
|
||||
}
|
||||
template DeviceDMatrix::DeviceDMatrix(CudfAdapter* adapter, float missing,
|
||||
int nthread, int max_bin);
|
||||
template DeviceDMatrix::DeviceDMatrix(CupyAdapter* adapter, float missing,
|
||||
int nthread, int max_bin);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
60
src/data/device_dmatrix.h
Normal file
60
src/data/device_dmatrix.h
Normal file
@@ -0,0 +1,60 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file device_dmatrix.h
|
||||
* \brief Device-memory version of DMatrix.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
#define XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "adapter.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "simple_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class DeviceDMatrix : public DMatrix {
|
||||
public:
|
||||
template <typename AdapterT>
|
||||
explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin);
|
||||
|
||||
MetaInfo& Info() override { return info; }
|
||||
|
||||
const MetaInfo& Info() const override { return info; }
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
|
||||
bool EllpackExists() const override { return true; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||
}
|
||||
BatchSet<CSCPage> GetColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
|
||||
auto begin_iter = BatchIterator<EllpackPage>(
|
||||
new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
MetaInfo info;
|
||||
// source data pointer.
|
||||
std::unique_ptr<EllpackPage> ellpack_page_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
@@ -26,7 +26,6 @@ void EllpackPage::SetBaseRowId(size_t row_id) {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
}
|
||||
|
||||
size_t EllpackPage::Size() const {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
|
||||
@@ -210,8 +210,8 @@ void EllpackPageImpl::InitCompressedData(int device) {
|
||||
|
||||
// Required buffer size for storing data matrix in ELLPack format.
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
gidx_buffer.SetDevice(device);
|
||||
// Don't call fill unnecessarily
|
||||
if (gidx_buffer.Size() == 0) {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "../common/compressed_iterator.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include <thrust/binary_search.h>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -90,6 +91,19 @@ struct EllpackDeviceAccessor {
|
||||
}
|
||||
return gidx;
|
||||
}
|
||||
|
||||
/*!
 * \brief Find the bin index of `value` within the cut range of one column.
 *
 * Runs a sequential upper-bound search over
 * gidx_fvalue_map[feature_segments[column_id], feature_segments[column_id + 1]).
 * When the search lands on the end of the range, the last bin of the range is
 * returned instead.
 *
 * \param value     Feature value to look up.
 * \param column_id Column whose segment of gidx_fvalue_map is searched.
 * \return Index into gidx_fvalue_map.
 */
__device__ uint32_t SearchBin(float value, size_t column_id) const {
  auto beg = feature_segments[column_id];
  auto end = feature_segments[column_id + 1];
  auto it = thrust::upper_bound(thrust::seq,
                                gidx_fvalue_map.cbegin() + beg,
                                gidx_fvalue_map.cbegin() + end,
                                value);
  uint32_t idx = it - gidx_fvalue_map.cbegin();
  // Clamp a past-the-end result back onto the final bin of the segment.
  return idx == end ? idx - 1 : idx;
}
|
||||
|
||||
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
|
||||
auto gidx = GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
@@ -104,7 +118,7 @@ struct EllpackDeviceAccessor {
|
||||
}
|
||||
/*! \brief Return the total number of symbols (total number of bins plus 1 for
|
||||
* not found). */
|
||||
size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
|
||||
XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
|
||||
|
||||
size_t NullValue() const { return gidx_fvalue_map.size(); }
|
||||
|
||||
|
||||
@@ -8,26 +8,20 @@
|
||||
#include <xgboost/data.h>
|
||||
#include "../common/random.h"
|
||||
#include "./simple_dmatrix.h"
|
||||
#include "../common/math.h"
|
||||
#include "device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
XGBOOST_DEVICE bool IsValid(float value, float missing) {
|
||||
if (common::CheckNAN(value) || value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (IsValid(element.value, missing)) {
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
@@ -66,11 +60,12 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
size_t begin = 0;
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!IsValid(e.value, missing)) return;
|
||||
if (!is_valid(e)) return;
|
||||
data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
@@ -79,15 +74,6 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
}
|
||||
}
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const Entry& x) const {
|
||||
return IsValid(x.fvalue, missing);
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
|
||||
@@ -400,10 +400,11 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
|
||||
auto on_device =
|
||||
f_dmat &&
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
|
||||
(f_dmat->PageExists<EllpackPage>() ||
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead());
|
||||
|
||||
// Use GPU Predictor if data is already on device.
|
||||
if (on_device) {
|
||||
// Use GPU Predictor if data is already on device and gpu_id is set.
|
||||
if (on_device && generic_param_->gpu_id >= 0) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
|
||||
Reference in New Issue
Block a user