Move ellpack page construction into DMatrix (#4833)

parent 512f037e55
commit 125bcec62e
@@ -29,6 +29,7 @@

 // data
 #include "../src/data/data.cc"
+#include "../src/data/ellpack_page.cc"
 #include "../src/data/simple_csr_source.cc"
 #include "../src/data/simple_dmatrix.cc"
 #include "../src/data/sparse_page_raw_format.cc"
@@ -26,6 +26,8 @@
 namespace xgboost {
 // forward declare learner.
 class LearnerImpl;
+// forward declare dmatrix.
+class DMatrix;

 /*! \brief data type accepted by xgboost interface */
 enum DataType {
@@ -86,7 +88,7 @@ class MetaInfo {
   * \return The pre-defined root index of i-th instance.
   */
  inline unsigned GetRoot(size_t i) const {
-    return root_index_.size() != 0 ? root_index_[i] : 0U;
+    return !root_index_.empty() ? root_index_[i] : 0U;
  }
  /*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
  inline const std::vector<size_t>& LabelAbsSort() const {
@@ -166,7 +168,7 @@ class SparsePage {
  /*! \brief the data of the segments */
  HostDeviceVector<Entry> data;

-  size_t base_rowid;
+  size_t base_rowid{};

  /*! \brief an instance of sparse vector in the batch */
  using Inst = common::Span<Entry const>;
@@ -215,23 +217,23 @@ class SparsePage {
    const int nthread = omp_get_max_threads();
    builder.InitBudget(num_columns, nthread);
    long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
-#pragma omp parallel for schedule(static)
+#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
    for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
      int tid = omp_get_thread_num();
      auto inst = (*this)[i];
-      for (bst_uint j = 0; j < inst.size(); ++j) {
-        builder.AddBudget(inst[j].index, tid);
+      for (const auto& entry : inst) {
+        builder.AddBudget(entry.index, tid);
      }
    }
    builder.InitStorage();
-#pragma omp parallel for schedule(static)
+#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
    for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
      int tid = omp_get_thread_num();
      auto inst = (*this)[i];
-      for (bst_uint j = 0; j < inst.size(); ++j) {
+      for (const auto& entry : inst) {
        builder.Push(
-            inst[j].index,
-            Entry(static_cast<bst_uint>(this->base_rowid + i), inst[j].fvalue),
+            entry.index,
+            Entry(static_cast<bst_uint>(this->base_rowid + i), entry.fvalue),
            tid);
      }
    }
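The pragmas above now use default(none), which turns off implicit data sharing. A minimal, standalone illustration (not part of this commit) of why every variable touched inside such a loop must then appear in an explicit clause:

// With default(none), omitting shared(batch_size) here is a compile error;
// the loop index i is predetermined private and needs no clause.
#include <cstdio>

int main() {
  long batch_size = 8;  // stand-in for this->Size()
  long total = 0;
#pragma omp parallel for default(none) shared(batch_size) reduction(+ : total) schedule(static)
  for (long i = 0; i < batch_size; ++i) {  // NOLINT(*)
    total += i;
  }
  std::printf("%ld\n", total);
  return 0;
}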
@@ -240,7 +242,7 @@ class SparsePage {

  void SortRows() {
    auto ncol = static_cast<bst_omp_uint>(this->Size());
-#pragma omp parallel for schedule(dynamic, 1)
+#pragma omp parallel for default(none) shared(ncol) schedule(dynamic, 1)
    for (bst_omp_uint i = 0; i < ncol; ++i) {
      if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
        std::sort(
@@ -287,10 +289,29 @@ class SortedCSCPage : public SparsePage {
  explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};

+class EllpackPageImpl;
+/*!
+ * \brief A page stored in ELLPACK format.
+ *
+ * This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
+ * including CUDA-specific implementation details in the header.
+ */
+class EllpackPage {
+ public:
+  explicit EllpackPage(DMatrix* dmat);
+  ~EllpackPage();
+
+  const EllpackPageImpl* Impl() const { return impl_.get(); }
+  EllpackPageImpl* Impl() { return impl_.get(); }
+
+ private:
+  std::unique_ptr<EllpackPageImpl> impl_;
+};
+
 template<typename T>
 class BatchIteratorImpl {
  public:
-  virtual ~BatchIteratorImpl() {}
+  virtual ~BatchIteratorImpl() = default;
  virtual T& operator*() = 0;
  virtual const T& operator*() const = 0;
  virtual void operator++() = 0;
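The new EllpackPage class follows the PImpl split referenced in its comment. A minimal, self-contained sketch of the same pattern (hypothetical Page/PageImpl names, not taken from the patch):

#include <memory>

class PageImpl;  // only forward-declared in the public header

class Page {
 public:
  Page();
  ~Page();  // must be defined where PageImpl is a complete type
  PageImpl* Impl() { return impl_.get(); }

 private:
  std::unique_ptr<PageImpl> impl_;
};

// In the single translation unit that knows the implementation
// (for EllpackPage this is the CUDA file ellpack_page.cu):
class PageImpl {
 public:
  int value{0};
};

Page::Page() : impl_{new PageImpl()} {}
Page::~Page() = default;  // unique_ptr can now destroy the complete type

This keeps CUDA-only members out of the public header, so CPU-only translation units can still include it.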
@@ -412,7 +433,7 @@ class DMatrix {
                          bool silent,
                          bool load_row_split,
                          const std::string& file_format = "auto",
-                         const size_t page_size = kPageSize);
+                         size_t page_size = kPageSize);

  /*!
   * \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
@@ -438,7 +459,7 @@ class DMatrix {
   */
  static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
                         const std::string& cache_prefix = "",
-                        const size_t page_size = kPageSize);
+                        size_t page_size = kPageSize);

  /*! \brief page size 32 MB */
  static const size_t kPageSize = 32UL << 20UL;
@@ -447,6 +468,7 @@ class DMatrix {
  virtual BatchSet<SparsePage> GetRowBatches() = 0;
  virtual BatchSet<CSCPage> GetColumnBatches() = 0;
  virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
+  virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
 };

 template<>
@@ -463,6 +485,11 @@ template<>
 inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
   return GetSortedColumnBatches();
 }

+template<>
+inline BatchSet<EllpackPage> DMatrix::GetBatches() {
+  return GetEllpackBatches();
+}
 }  // namespace xgboost

 namespace dmlc {
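With the specialization above, ELLPACK pages are requested through the same GetBatches interface as the other page types. A sketch of a call site, assuming only a valid DMatrix pointer (not code from this commit):

#include <xgboost/data.h>

void VisitEllpack(xgboost::DMatrix* dmat) {
  // The page is generated lazily on first request and cached by the DMatrix;
  // device-side data stays behind the PImpl pointer.
  for (auto& page : dmat->GetBatches<xgboost::EllpackPage>()) {
    xgboost::EllpackPageImpl* impl = page.Impl();
    (void)impl;  // hand off to CUDA code that understands the implementation
  }
}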
@@ -99,15 +99,15 @@ struct SketchContainer {
  std::vector<std::mutex> col_locks_; // NOLINT
  static constexpr int kOmpNumColsParallelizeLimit = 1000;

-  SketchContainer(const tree::TrainParam &param, DMatrix *dmat) :
+  SketchContainer(int max_bin, DMatrix *dmat) :
      col_locks_(dmat->Info().num_col_) {
    const MetaInfo &info = dmat->Info();
    // Initialize Sketches for this dmatrix
    sketches_.resize(info.num_col_);
-#pragma omp parallel for default(none) shared(info, param) schedule(static) \
+#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \
 if (info.num_col_ > kOmpNumColsParallelizeLimit)  // NOLINT
    for (int icol = 0; icol < info.num_col_; ++icol) {  // NOLINT
-      sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
+      sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin));
    }
  }
@@ -130,7 +130,7 @@ struct GPUSketcher {
    bool has_weights_{false};
    size_t row_stride_{0};

-    tree::TrainParam param_;
+    const int max_bin_;
    SketchContainer *sketch_container_;
    dh::device_vector<size_t> row_ptrs_{};
    dh::device_vector<Entry> entries_{};
@@ -148,11 +148,11 @@ struct GPUSketcher {
   public:
    DeviceShard(int device,
                bst_uint n_rows,
-                tree::TrainParam param,
+                int max_bin,
                SketchContainer* sketch_container) :
        device_(device),
        n_rows_(n_rows),
-        param_(std::move(param)),
+        max_bin_(max_bin),
        sketch_container_(sketch_container) {
    }

@@ -183,7 +183,7 @@ struct GPUSketcher {
    }

    constexpr int kFactor = 8;
-    double eps = 1.0 / (kFactor * param_.max_bin);
+    double eps = 1.0 / (kFactor * max_bin_);
    size_t dummy_nlevel;
    WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);

@@ -362,7 +362,7 @@ struct GPUSketcher {
    // add cuts into sketches
    thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
 #pragma omp parallel for default(none) schedule(static) \
 if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit)  // NOLINT
    for (int icol = 0; icol < num_cols_; ++icol) {
      WXQSketch::SummaryContainer summary;
      summary.Reserve(n_cuts_);
@@ -403,10 +403,8 @@ struct GPUSketcher {
  };

  void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
-    auto device = generic_param_.gpu_id;
-
    // create device shard
-    shard_.reset(new DeviceShard(device, batch.Size(), param_, sketch_container_.get()));
+    shard_.reset(new DeviceShard(device_, batch.Size(), max_bin_, sketch_container_.get()));

    // compute sketches for the shard
    shard_->Init(batch, info, gpu_batch_nrows_);
@@ -417,9 +415,8 @@ struct GPUSketcher {
    row_stride_ = shard_->GetRowStride();
  }

-  GPUSketcher(const tree::TrainParam &param, const GenericParameter &generic_param, int gpu_nrows)
-      : param_(param), generic_param_(generic_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {
-  }
+  GPUSketcher(int device, int max_bin, int gpu_nrows)
+      : device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {}

  /* Builds the sketches on the GPU for the dmatrix and returns the row stride
   * for the entire dataset */
@@ -427,29 +424,31 @@ struct GPUSketcher {
    const MetaInfo &info = dmat->Info();

    row_stride_ = 0;
-    sketch_container_.reset(new SketchContainer(param_, dmat));
+    sketch_container_.reset(new SketchContainer(max_bin_, dmat));
    for (const auto &batch : dmat->GetBatches<SparsePage>()) {
      this->SketchBatch(batch, info);
    }

-    hmat->Init(&sketch_container_->sketches_, param_.max_bin);
+    hmat->Init(&sketch_container_->sketches_, max_bin_);

    return row_stride_;
  }

 private:
  std::unique_ptr<DeviceShard> shard_;
-  const tree::TrainParam &param_;
-  const GenericParameter &generic_param_;
+  const int device_;
+  const int max_bin_;
  int gpu_batch_nrows_;
  size_t row_stride_;
  std::unique_ptr<SketchContainer> sketch_container_;
};

-size_t DeviceSketch
-    (const tree::TrainParam &param, const GenericParameter &learner_param, int gpu_batch_nrows,
-     DMatrix *dmat, HistogramCuts *hmat) {
-  GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
+size_t DeviceSketch(int device,
+                    int max_bin,
+                    int gpu_batch_nrows,
+                    DMatrix* dmat,
+                    HistogramCuts* hmat) {
+  GPUSketcher sketcher(device, max_bin, gpu_batch_nrows);
  // We only need to return the result in HistogramCuts container, so it is safe to
  // use a pointer of local HistogramCutsDense
  DenseCuts dense_cuts(hmat);
@@ -290,10 +290,11 @@ class DenseCuts : public CutsBuilder {
 *
 * \return The row stride across the entire dataset.
 */
-size_t DeviceSketch
-    (const tree::TrainParam& param, const GenericParameter &learner_param, int gpu_batch_nrows,
-     DMatrix* dmat, HistogramCuts* hmat);
+size_t DeviceSketch(int device,
+                    int max_bin,
+                    int gpu_batch_nrows,
+                    DMatrix* dmat,
+                    HistogramCuts* hmat);

/*!
 * \brief preprocessed global index matrix, in CSR format
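A sketch of the new call-site shape for DeviceSketch, mirroring the call made from EllpackPageImpl::Init below; the wrapper name and include paths are illustrative only:

#include <xgboost/data.h>
#include "../common/hist_util.h"

size_t BuildCuts(xgboost::DMatrix* dmat, int device, int max_bin, int gpu_batch_nrows,
                 xgboost::common::HistogramCuts* out_cuts) {
  // Returns the row stride across the entire dataset, per the comment above.
  return xgboost::common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat, out_cuts);
}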
src/data/ellpack_page.cc — new file (25 lines)

/*!
 * Copyright 2019 XGBoost contributors
 *
 * \file ellpack_page.cc
 */
#ifndef XGBOOST_USE_CUDA

#include <xgboost/data.h>

// dummy implementation of EllpackPage in case CUDA is not used
namespace xgboost {

class EllpackPageImpl {};

EllpackPage::EllpackPage(DMatrix* dmat) {
  LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
}

EllpackPage::~EllpackPage() {
  LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
}

}  // namespace xgboost

#endif  // XGBOOST_USE_CUDA
src/data/ellpack_page.cu — new file (197 lines)

/*!
 * Copyright 2019 XGBoost contributors
 *
 * \file ellpack_page.cu
 */

#include <xgboost/data.h>

#include "./ellpack_page.cuh"
#include "../common/hist_util.h"
#include "../common/random.h"

namespace xgboost {

EllpackPage::EllpackPage(DMatrix* dmat) : impl_{new EllpackPageImpl(dmat)} {}

EllpackPage::~EllpackPage() = default;

EllpackPageImpl::EllpackPageImpl(DMatrix* dmat) : dmat_{dmat} {}

// Bin each input data entry, store the bin indices in compressed form.
template<typename std::enable_if<true, int>::type = 0>
__global__ void CompressBinEllpackKernel(
    common::CompressedBufferWriter wr,
    common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
    const size_t* __restrict__ row_ptrs,           // row offset of input data
    const Entry* __restrict__ entries,             // One batch of input data
    const float* __restrict__ cuts,                // HistogramCuts::cut
    const uint32_t* __restrict__ cut_rows,         // HistogramCuts::row_ptrs
    size_t base_row,                               // batch_row_begin
    size_t n_rows,
    size_t row_stride,
    unsigned int null_gidx_value) {
  size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
  int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
  if (irow >= n_rows || ifeature >= row_stride) {
    return;
  }
  int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
  unsigned int bin = null_gidx_value;
  if (ifeature < row_length) {
    Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature];
    int feature = entry.index;
    float fvalue = entry.fvalue;
    // {feature_cuts, ncuts} forms the array of cuts of `feature'.
    const float *feature_cuts = &cuts[cut_rows[feature]];
    int ncuts = cut_rows[feature + 1] - cut_rows[feature];
    // Assigning the bin in current entry.
    // S.t.: fvalue < feature_cuts[bin]
    bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
    if (bin >= ncuts) {
      bin = ncuts - 1;
    }
    // Add the number of bins in previous features.
    bin += cut_rows[feature];
  }
  // Write to gidx buffer.
  wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}

void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) {
  if (initialised_) return;

  monitor_.Init("ellpack_page");
  dh::safe_cuda(cudaSetDevice(device));

  monitor_.StartCuda("Quantiles");
  // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
  common::HistogramCuts hmat;
  size_t row_stride = common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat_, &hmat);
  monitor_.StopCuda("Quantiles");

  const auto& info = dmat_->Info();
  auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_;

  // Init global data for each shard
  monitor_.StartCuda("InitCompressedData");
  InitCompressedData(device, hmat, row_stride, is_dense);
  monitor_.StopCuda("InitCompressedData");

  monitor_.StartCuda("BinningCompression");
  DeviceHistogramBuilderState hist_builder_row_state(info.num_row_);
  for (const auto& batch : dmat_->GetBatches<SparsePage>()) {
    hist_builder_row_state.BeginBatch(batch);
    CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
    hist_builder_row_state.EndBatch();
  }
  monitor_.StopCuda("BinningCompression");

  initialised_ = true;
}

void EllpackPageImpl::InitCompressedData(int device,
                                         const common::HistogramCuts& hmat,
                                         size_t row_stride,
                                         bool is_dense) {
  n_bins = hmat.Ptrs().back();
  int null_gidx_value = hmat.Ptrs().back();
  int num_symbols = n_bins + 1;

  // minimum value for each feature.
  common::Span<bst_float> min_fvalue;

  // Required buffer size for storing data matrix in ELLPack format.
  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
      row_stride * dmat_->Info().num_row_, num_symbols);

  ba.Allocate(device,
              &feature_segments, hmat.Ptrs().size(),
              &gidx_fvalue_map, hmat.Values().size(),
              &min_fvalue, hmat.MinValues().size(),
              &gidx_buffer, compressed_size_bytes);

  dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
  dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
  dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
  thrust::fill(
      thrust::device_pointer_cast(gidx_buffer.data()),
      thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);

  ellpack_matrix.Init(feature_segments,
                      min_fvalue,
                      gidx_fvalue_map,
                      row_stride,
                      common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
                      is_dense,
                      null_gidx_value);
}

void EllpackPageImpl::CreateHistIndices(int device,
                                        const SparsePage& row_batch,
                                        const RowStateOnDevice& device_row_state) {
  // Has any been allocated for me in this batch?
  if (!device_row_state.rows_to_process_from_batch) return;

  unsigned int null_gidx_value = n_bins;
  size_t row_stride = this->ellpack_matrix.row_stride;

  const auto &offset_vec = row_batch.offset.ConstHostVector();

  int num_symbols = n_bins + 1;
  // bin and compress entries in batches of rows
  size_t gpu_batch_nrows = std::min(
      dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
      static_cast<size_t>(device_row_state.rows_to_process_from_batch));
  const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();

  size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
                                           gpu_batch_nrows);

  for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
    size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
    size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
    if (batch_row_end > device_row_state.rows_to_process_from_batch) {
      batch_row_end = device_row_state.rows_to_process_from_batch;
    }
    size_t batch_nrows = batch_row_end - batch_row_begin;

    const auto ent_cnt_begin =
        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
    const auto ent_cnt_end =
        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];

    /*! \brief row offset in SparsePage (the input data). */
    dh::device_vector<size_t> row_ptrs(batch_nrows+1);
    thrust::copy(
        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
        row_ptrs.begin());

    // number of entries in this batch.
    size_t n_entries = ent_cnt_end - ent_cnt_begin;
    dh::device_vector<Entry> entries_d(n_entries);
    // copy data entries to device.
    dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
                             data_vec.data() + ent_cnt_begin,
                             n_entries * sizeof(Entry),
                             cudaMemcpyDefault));
    const dim3 block3(32, 8, 1);  // 256 threads
    const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                     common::DivRoundUp(row_stride, block3.y),
                     1);
    CompressBinEllpackKernel<<<grid3, block3>>>(
        common::CompressedBufferWriter(num_symbols),
        gidx_buffer.data(),
        row_ptrs.data().get(),
        entries_d.data().get(),
        gidx_fvalue_map.data(),
        feature_segments.data(),
        device_row_state.total_rows_processed + batch_row_begin,
        batch_nrows,
        row_stride,
        null_gidx_value);
  }
}

}  // namespace xgboost
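A host-side sketch of the per-entry binning rule that CompressBinEllpackKernel applies: upper bound over the feature's cuts, clamp to the feature's last cut, then offset by the feature's first global bin. Plain C++ for illustration only; the kernel does this per thread on the device:

#include <algorithm>
#include <cstdint>
#include <vector>

uint32_t BinForEntry(const std::vector<float>& cuts,         // all features' cut values
                     const std::vector<uint32_t>& cut_rows,  // per-feature offsets into cuts
                     int feature, float fvalue) {
  const float* begin = cuts.data() + cut_rows[feature];
  const float* end = cuts.data() + cut_rows[feature + 1];
  // First cut strictly greater than fvalue, i.e. fvalue < feature_cuts[bin].
  auto bin = static_cast<uint32_t>(std::upper_bound(begin, end, fvalue) - begin);
  uint32_t ncuts = cut_rows[feature + 1] - cut_rows[feature];
  if (bin >= ncuts) {
    bin = ncuts - 1;  // values past the last cut land in the last bin
  }
  return bin + cut_rows[feature];  // make the bin index global across features
}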
src/data/ellpack_page.cuh — new file (203 lines)

/*!
 * Copyright 2019 by XGBoost Contributors
 *
 * \file ellpack_page.cuh
 */

#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_H_

#include <xgboost/data.h>

#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"

namespace xgboost {

// Find a gidx value for a given feature otherwise return -1 if not found
__forceinline__ __device__ int BinarySearchRow(
    bst_uint begin, bst_uint end,
    common::CompressedIterator<uint32_t> data,
    int const fidx_begin, int const fidx_end) {
  bst_uint previous_middle = UINT32_MAX;
  while (end != begin) {
    auto middle = begin + (end - begin) / 2;
    if (middle == previous_middle) {
      break;
    }
    previous_middle = middle;

    auto gidx = data[middle];

    if (gidx >= fidx_begin && gidx < fidx_end) {
      return gidx;
    } else if (gidx < fidx_begin) {
      begin = middle;
    } else {
      end = middle;
    }
  }
  // Value is missing
  return -1;
}

/** \brief Struct for accessing and manipulating an ellpack matrix on the
 * device. Does not own underlying memory and may be trivially copied into
 * kernels.*/
struct ELLPackMatrix {
  common::Span<uint32_t> feature_segments;
  /*! \brief minimum value for each feature. */
  common::Span<bst_float> min_fvalue;
  /*! \brief Cut. */
  common::Span<bst_float> gidx_fvalue_map;
  /*! \brief row length for ELLPack. */
  size_t row_stride{0};
  common::CompressedIterator<uint32_t> gidx_iter;
  int null_gidx_value;

  XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }

  // Get a matrix element, uses binary search for look up. Return NaN if missing.
  // Given a row index and a feature index, returns the corresponding cut value.
  __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
    auto row_begin = row_stride * ridx;
    auto row_end = row_begin + row_stride;
    auto gidx = -1;
    if (is_dense) {
      gidx = gidx_iter[row_begin + fidx];
    } else {
      gidx =
          BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
                          feature_segments[fidx + 1]);
    }
    if (gidx == -1) {
      return nan("");
    }
    return gidx_fvalue_map[gidx];
  }
  void Init(common::Span<uint32_t> feature_segments,
            common::Span<bst_float> min_fvalue,
            common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
            common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
            int null_gidx_value) {
    this->feature_segments = feature_segments;
    this->min_fvalue = min_fvalue;
    this->gidx_fvalue_map = gidx_fvalue_map;
    this->row_stride = row_stride;
    this->gidx_iter = gidx_iter;
    this->is_dense = is_dense;
    this->null_gidx_value = null_gidx_value;
  }

 private:
  bool is_dense;
};

// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
  // Number of rows assigned to this device
  size_t total_rows_assigned_to_device;
  // Number of rows processed thus far
  size_t total_rows_processed;
  // Number of rows to process from the current sparse page batch
  size_t rows_to_process_from_batch;
  // Offset from the current sparse page batch to begin processing
  size_t row_offset_in_current_batch;

  explicit RowStateOnDevice(size_t total_rows)
      : total_rows_assigned_to_device(total_rows), total_rows_processed(0),
        rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
  }

  explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
      : total_rows_assigned_to_device(total_rows), total_rows_processed(0),
        rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
  }

  // Advance the row state by the number of rows processed
  void Advance() {
    total_rows_processed += rows_to_process_from_batch;
    CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
    rows_to_process_from_batch = row_offset_in_current_batch = 0;
  }
};

// An instance of this type is created which keeps track of total number of rows to process,
// rows processed thus far, rows to process and the offset from the current sparse page batch
// to begin processing on each device
class DeviceHistogramBuilderState {
 public:
  explicit DeviceHistogramBuilderState(int n_rows) : device_row_state_(n_rows) {}

  const RowStateOnDevice& GetRowStateOnDevice() const {
    return device_row_state_;
  }

  // This method is invoked at the beginning of each sparse page batch. This distributes
  // the rows in the sparse page to the device.
  // TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
  void BeginBatch(const SparsePage &batch) {
    size_t rem_rows = batch.Size();
    size_t row_offset_in_current_batch = 0;

    // Do we have anymore left to process from this batch on this device?
    if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
      // There are still some rows that need to be assigned to this device
      device_row_state_.rows_to_process_from_batch =
          std::min(
              device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
              rem_rows);
    } else {
      // All rows have been assigned to this device
      device_row_state_.rows_to_process_from_batch = 0;
    }

    device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
    row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
    rem_rows -= device_row_state_.rows_to_process_from_batch;
  }

  // This method is invoked after completion of each sparse page batch
  void EndBatch() {
    device_row_state_.Advance();
  }

 private:
  RowStateOnDevice device_row_state_{0};
};

class EllpackPageImpl {
 public:
  ELLPackMatrix ellpack_matrix;
  int n_bins{};
  /*! \brief global index of histogram, which is stored in ELLPack format. */
  common::Span<common::CompressedByteT> gidx_buffer;

  explicit EllpackPageImpl(DMatrix* dmat);
  void Init(int device, int max_bin, int gpu_batch_nrows);
  void InitCompressedData(int device,
                          const common::HistogramCuts& hmat,
                          size_t row_stride,
                          bool is_dense);
  void CreateHistIndices(int device,
                         const SparsePage& row_batch,
                         const RowStateOnDevice& device_row_state);

 private:
  bool initialised_{false};
  DMatrix* dmat_;
  common::Monitor monitor_;
  dh::BulkAllocator ba;

  /*! \brief Cut. */
  common::Span<bst_float> gidx_fvalue_map;
  /*! \brief row_ptr from HistogramCuts. */
  common::Span<uint32_t> feature_segments;
};

}  // namespace xgboost

#endif  // XGBOOST_DATA_ELLPACK_PAGE_H_
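A host-side illustration of the lookup that ELLPackMatrix::GetElement performs: each row occupies exactly row_stride slots, dense matrices index directly, and sparse rows are searched for a bin that falls inside the feature's [begin, end) segment (a linear scan stands in for BinarySearchRow here). Not part of the commit:

#include <cmath>
#include <cstdint>
#include <vector>

float GetElementSketch(const std::vector<uint32_t>& gidx,          // row_stride slots per row
                       const std::vector<float>& gidx_fvalue_map,  // cut value per global bin
                       const std::vector<uint32_t>& feature_segments,
                       size_t row_stride, bool is_dense, size_t ridx, size_t fidx) {
  size_t row_begin = row_stride * ridx;
  int found = -1;
  if (is_dense) {
    found = static_cast<int>(gidx[row_begin + fidx]);
  } else {
    for (size_t i = row_begin; i < row_begin + row_stride; ++i) {
      uint32_t g = gidx[i];
      if (g >= feature_segments[fidx] && g < feature_segments[fidx + 1]) {
        found = static_cast<int>(g);
        break;
      }
    }
  }
  return found == -1 ? std::nanf("") : gidx_fvalue_map[found];  // NaN means missing
}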
src/data/simple_batch_iterator.h — new file (33 lines)

/*!
 * Copyright 2019 XGBoost contributors
 */
#ifndef XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#define XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_

#include <xgboost/data.h>

namespace xgboost {
namespace data {

template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
 public:
  explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
  T& operator*() override {
    CHECK(page_ != nullptr);
    return *page_;
  }
  const T& operator*() const override {
    CHECK(page_ != nullptr);
    return *page_;
  }
  void operator++() override { page_ = nullptr; }
  bool AtEnd() const override { return page_ == nullptr; }

 private:
  T* page_{nullptr};
};

}  // namespace data
}  // namespace xgboost
#endif  // XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
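The header above defines a single-page iterator: dereferencing yields the cached page, operator++ drops the pointer, and AtEnd() then reports completion, so a batch loop visits the page exactly once. A standalone illustration of that contract (not part of the commit):

#include <cassert>

template <typename T>
class OnePageIter {
 public:
  explicit OnePageIter(T* page) : page_(page) {}
  T& operator*() { assert(page_ != nullptr); return *page_; }
  void operator++() { page_ = nullptr; }
  bool AtEnd() const { return page_ == nullptr; }

 private:
  T* page_{nullptr};
};

int main() {
  int page = 7;
  OnePageIter<int> it(&page);
  int visits = 0;
  for (; !it.AtEnd(); ++it) {
    visits += *it;  // runs exactly once for the single cached page
  }
  assert(visits == 7);
  return 0;
}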
@@ -6,6 +6,7 @@
 */
#include "./simple_dmatrix.h"
#include <xgboost/data.h>
+#include "./simple_batch_iterator.h"
#include "../common/random.h"

namespace xgboost {

@@ -29,25 +30,6 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
  return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
}

-template<typename T>
-class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
- public:
-  explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
-  T& operator*() override {
-    CHECK(page_ != nullptr);
-    return *page_;
-  }
-  const T& operator*() const override {
-    CHECK(page_ != nullptr);
-    return *page_;
-  }
-  void operator++() override { page_ = nullptr; }
-  bool AtEnd() const override { return page_ == nullptr; }
-
- private:
-  T* page_{nullptr};
-};
-
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
  // since csr is the default data structure so `source_` is always available.
  auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());

@@ -80,6 +62,16 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
  return BatchSet<SortedCSCPage>(begin_iter);
}

+BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
+  // ELLPACK page doesn't exist, generate it
+  if (!ellpack_page_) {
+    ellpack_page_.reset(new EllpackPage(this));
+  }
+  auto begin_iter =
+      BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
+  return BatchSet<EllpackPage>(begin_iter);
+}
+
bool SimpleDMatrix::SingleColBlock() const { return true; }
}  // namespace data
}  // namespace xgboost
@@ -38,12 +38,14 @@ class SimpleDMatrix : public DMatrix {
  BatchSet<SparsePage> GetRowBatches() override;
  BatchSet<CSCPage> GetColumnBatches() override;
  BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
+  BatchSet<EllpackPage> GetEllpackBatches() override;

  // source data pointer.
  std::unique_ptr<DataSource<SparsePage>> source_;

  std::unique_ptr<CSCPage> column_page_;
  std::unique_ptr<SortedCSCPage> sorted_column_page_;
+  std::unique_ptr<EllpackPage> ellpack_page_;
};
}  // namespace data
}  // namespace xgboost
@@ -10,6 +10,8 @@
#if DMLC_ENABLE_STD_THREAD
#include "./sparse_page_dmatrix.h"

+#include "./simple_batch_iterator.h"
+
namespace xgboost {
namespace data {

@@ -72,6 +74,16 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
  return BatchSet<SortedCSCPage>(begin_iter);
}

+BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches() {
+  // ELLPACK page doesn't exist, generate it
+  if (!ellpack_page_) {
+    ellpack_page_.reset(new EllpackPage(this));
+  }
+  auto begin_iter =
+      BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
+  return BatchSet<EllpackPage>(begin_iter);
+}
+
float SparsePageDMatrix::GetColDensity(size_t cidx) {
  // Finds densities if we don't already have them
  if (col_density_.empty()) {
@@ -24,7 +24,7 @@ class SparsePageDMatrix : public DMatrix {
  explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
                             std::string cache_info)
      : row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
-  virtual ~SparsePageDMatrix() = default;
+  ~SparsePageDMatrix() override = default;

  MetaInfo& Info() override;

@@ -38,11 +38,13 @@ class SparsePageDMatrix : public DMatrix {
  BatchSet<SparsePage> GetRowBatches() override;
  BatchSet<CSCPage> GetColumnBatches() override;
  BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
+  BatchSet<EllpackPage> GetEllpackBatches() override;

  // source data pointers.
  std::unique_ptr<DataSource<SparsePage>> row_source_;
  std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
  std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
+  std::unique_ptr<EllpackPage> ellpack_page_;
  // the cache prefix
  std::string cache_info_;
  // Store column densities to avoid recalculating
@@ -21,6 +21,7 @@
#include "../common/host_device_vector.h"
#include "../common/timer.h"
#include "../common/span.h"
+#include "../data/ellpack_page.cuh"
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
@@ -108,83 +109,6 @@ inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
  }
}

-// Find a gidx value for a given feature otherwise return -1 if not found
-__forceinline__ __device__ int BinarySearchRow(
-    bst_uint begin, bst_uint end,
-    common::CompressedIterator<uint32_t> data,
-    int const fidx_begin, int const fidx_end) {
-  ... (removed; identical to the BinarySearchRow definition now in src/data/ellpack_page.cuh)
-}
-
-/** \brief Struct for accessing and manipulating an ellpack matrix on the
- * device. Does not own underlying memory and may be trivially copied into
- * kernels.*/
-struct ELLPackMatrix {
-  ... (removed; identical to the ELLPackMatrix definition now in src/data/ellpack_page.cuh,
-       except that `bool is_dense` was a public member here)
-};
-
// With constraints
template <typename GradientPairT>
XGBOOST_DEVICE float inline LossChangeMissing(
@@ -247,7 +171,7 @@ template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
          typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
    int fidx, common::Span<const GradientSumT> node_histogram,
-    const ELLPackMatrix& matrix,
+    const xgboost::ELLPackMatrix& matrix,
    DeviceSplitCandidate* best_split,  // shared memory storing best split
    const DeviceNodeStats& node, const GPUTrainingParam& param,
    TempStorageT* temp_storage,  // temp memory for cub operations
@@ -322,7 +246,7 @@ __global__ void EvaluateSplitKernel(
    common::Span<const GradientSumT> node_histogram,  // histogram for gradients
    common::Span<const int> feature_set,               // Selected features
    DeviceNodeStats node,
-    ELLPackMatrix matrix,
+    xgboost::ELLPackMatrix matrix,
    GPUTrainingParam gpu_param,
    common::Span<DeviceSplitCandidate> split_candidates,  // resulting split
    ValueConstraint value_constraint,
@@ -473,48 +397,8 @@ struct CalcWeightTrainParam {
      learning_rate(p.learning_rate) {}
};

-// Bin each input data entry, store the bin indices in compressed form.
-template<typename std::enable_if<true, int>::type = 0>
-__global__ void CompressBinEllpackKernel(
-    common::CompressedBufferWriter wr,
-    common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
-    ... (removed; identical to the CompressBinEllpackKernel now defined in
-         src/data/ellpack_page.cu)
-}
-
template <typename GradientSumT>
-__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
+__global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
                                    common::Span<const RowPartitioner::RowIndexT> d_ridx,
                                    GradientSumT* d_node_hist,
                                    const GradientPair* d_gpair, size_t n_elements,
@@ -548,59 +432,17 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
  }
}

-// Instances of this type are created while creating the histogram bins for the
-// entire dataset across multiple sparse page batches. This keeps track of the number
-// of rows to process from a batch and the position from which to process on each device.
-struct RowStateOnDevice {
-  ... (removed; identical to the RowStateOnDevice definition now in src/data/ellpack_page.cuh)
-};
-
// Manage memory for a single GPU
template <typename GradientSumT>
struct DeviceShard {
-  int n_bins;
  int device_id;
+  EllpackPageImpl* page;

  dh::BulkAllocator ba;

-  ELLPackMatrix ellpack_matrix;
-
  std::unique_ptr<RowPartitioner> row_partitioner;
  DeviceHistogram<GradientSumT> hist{};

-  /*! \brief row_ptr form HistogramCuts. */
-  common::Span<uint32_t> feature_segments;
-  /*! \brief minimum value for each feature. */
-  common::Span<bst_float> min_fvalue;
-  /*! \brief Cut. */
-  common::Span<bst_float> gidx_fvalue_map;
-  /*! \brief global index of histogram, which is stored in ELLPack format. */
-  common::Span<common::CompressedByteT> gidx_buffer;
-
  /*! \brief Gradient pair for each row. */
  common::Span<GradientPair> gpair;
@@ -631,11 +473,15 @@ struct DeviceShard {
                      std::function<bool(ExpandEntry, ExpandEntry)>>;
  std::unique_ptr<ExpandQueue> qexpand;

-  DeviceShard(int _device_id, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed,
+  DeviceShard(int _device_id,
+              EllpackPageImpl* _page,
+              bst_uint _n_rows,
+              TrainParam _param,
+              uint32_t column_sampler_seed,
              uint32_t n_features)
      : device_id(_device_id),
+        page(_page),
        n_rows(_n_rows),
-        n_bins(0),
        param(std::move(_param)),
        prediction_cache_initialised(false),
        column_sampler(column_sampler_seed),
@@ -643,12 +489,7 @@ struct DeviceShard {
    monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
  }

-  void InitCompressedData(
-      const common::HistogramCuts& hmat, size_t row_stride, bool is_dense);
-
-  void CreateHistIndices(
-      const SparsePage &row_batch, const common::HistogramCuts &hmat,
-      const RowStateOnDevice &device_row_state, int rows_per_batch);
+  void InitHistogram();

  ~DeviceShard() {  // NOLINT
    dh::safe_cuda(cudaSetDevice(device_id));
@@ -762,7 +603,7 @@ struct DeviceShard {
    int constexpr kBlockThreads = 256;
    EvaluateSplitKernel<kBlockThreads, GradientSumT>
        <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
-            hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix,
+            hist.GetNodeHistogram(nidx), d_feature_set, node, page->ellpack_matrix,
            gpu_param, d_split_candidates, node_value_constraints[nidx],
            monotone_constraints);

@@ -788,11 +629,11 @@ struct DeviceShard {
    auto d_ridx = row_partitioner->GetRows(nidx);
    auto d_gpair = gpair.data();

-    auto n_elements = d_ridx.size() * ellpack_matrix.row_stride;
+    auto n_elements = d_ridx.size() * page->ellpack_matrix.row_stride;

    const size_t smem_size =
        use_shared_memory_histograms
-            ? sizeof(GradientSumT) * ellpack_matrix.BinCount()
+            ? sizeof(GradientSumT) * page->ellpack_matrix.BinCount()
            : 0;
    const int items_per_thread = 8;
    const int block_threads = 256;
@ -802,7 +643,7 @@ struct DeviceShard {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
|
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
|
||||||
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
|
page->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
|
||||||
use_shared_memory_histograms);
|
use_shared_memory_histograms);
|
||||||
}
|
}
|
||||||
|
|
||||||
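For context on what the kernel above iterates over: in ELLPACK every row occupies exactly `row_stride` slots, short rows are padded with a null bin index, and histogram construction accumulates one gradient per non-null slot. A standalone single-threaded CPU sketch of that accumulation using only standard C++ (the real kernel does this in CUDA, optionally with shared-memory accumulators); all names here are illustrative.

```cpp
#include <cstddef>
#include <vector>

// Build a histogram from an ELLPACK-encoded matrix on the CPU.
// gidx holds n_rows * row_stride bin indices; padded slots store null_gidx.
std::vector<double> BuildHist(const std::vector<int>& gidx,
                              const std::vector<double>& gpair,  // one gradient per row
                              std::size_t n_rows, std::size_t row_stride,
                              int null_gidx, std::size_t n_bins) {
  std::vector<double> hist(n_bins, 0.0);
  for (std::size_t ridx = 0; ridx < n_rows; ++ridx) {
    for (std::size_t k = 0; k < row_stride; ++k) {
      int bin = gidx[ridx * row_stride + k];
      if (bin == null_gidx) continue;  // padding slot, no feature value here
      hist[bin] += gpair[ridx];
    }
  }
  return hist;
}
```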
@@ -812,7 +653,7 @@ struct DeviceShard {
     auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
     auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
 
-    dh::LaunchN(device_id, n_bins, [=] __device__(size_t idx) {
+    dh::LaunchN(device_id, page->n_bins, [=] __device__(size_t idx) {
       d_node_hist_subtraction[idx] =
           d_node_hist_parent[idx] - d_node_hist_histogram[idx];
     });
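The launch above is the usual subtraction trick: a parent's histogram equals the sum of its two children's, so only one child is built explicitly and the sibling is derived by element-wise subtraction. A standalone CPU sketch of the same idea, with illustrative names only:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Derive the sibling histogram from the parent and the explicitly built child.
std::vector<double> SubtractionTrick(const std::vector<double>& parent,
                                     const std::vector<double>& built_child) {
  assert(parent.size() == built_child.size());
  std::vector<double> sibling(parent.size());
  for (std::size_t bin = 0; bin < parent.size(); ++bin) {
    sibling[bin] = parent[bin] - built_child[bin];  // parent = left + right
  }
  return sibling;
}
```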
@@ -827,7 +668,7 @@ struct DeviceShard {
   }
 
   void UpdatePosition(int nidx, RegTree::Node split_node) {
-    auto d_matrix = ellpack_matrix;
+    auto d_matrix = page->ellpack_matrix;
 
     row_partitioner->UpdatePosition(
         nidx, split_node.LeftChild(), split_node.RightChild(),
@@ -859,7 +700,7 @@ struct DeviceShard {
     dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
                              d_nodes.size() * sizeof(RegTree::Node),
                              cudaMemcpyHostToDevice));
-    auto d_matrix = ellpack_matrix;
+    auto d_matrix = page->ellpack_matrix;
     row_partitioner->FinalisePosition(
         [=] __device__(bst_uint ridx, int position) {
           auto node = d_nodes[position];
@@ -922,7 +763,7 @@ struct DeviceShard {
     reducer->AllReduceSum(
         reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
         reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
-        ellpack_matrix.BinCount() *
+        page->ellpack_matrix.BinCount() *
         (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
     reducer->Synchronize();
 
@@ -1097,11 +938,7 @@ struct DeviceShard {
 };
 
 template <typename GradientSumT>
-inline void DeviceShard<GradientSumT>::InitCompressedData(
-    const common::HistogramCuts &hmat, size_t row_stride, bool is_dense) {
-  n_bins = hmat.Ptrs().back();
-  int null_gidx_value = hmat.Ptrs().back();
-
+inline void DeviceShard<GradientSumT>::InitHistogram() {
   CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
       << "Max leaves and max depth cannot both be unconstrained for "
          "gpu_hist.";
@@ -1113,163 +950,25 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
       &gpair, n_rows,
       &prediction_cache, n_rows,
       &node_sum_gradients_d, max_nodes,
-      &feature_segments, hmat.Ptrs().size(),
-      &gidx_fvalue_map, hmat.Values().size(),
-      &min_fvalue, hmat.MinValues().size(),
       &monotone_constraints, param.monotone_constraints.size());
 
-  dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
-  dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
-  dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
   dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
 
   node_sum_gradients.resize(max_nodes);
 
-  // allocate compressed bin data
-  int num_symbols = n_bins + 1;
-  // Required buffer size for storing data matrix in ELLPack format.
-  size_t compressed_size_bytes =
-      common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
-                                                          num_symbols);
-
-  ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
-  thrust::fill(
-      thrust::device_pointer_cast(gidx_buffer.data()),
-      thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
-
-  ellpack_matrix.Init(
-      feature_segments, min_fvalue,
-      gidx_fvalue_map, row_stride,
-      common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
-      is_dense, null_gidx_value);
   // check if we can use shared memory for building histograms
   // (assuming atleast we need 2 CTAs per SM to maintain decent latency
   // hiding)
-  auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back();
+  auto histogram_size = sizeof(GradientSumT) * page->n_bins;
   auto max_smem = dh::MaxSharedMemory(device_id);
   if (histogram_size <= max_smem) {
     use_shared_memory_histograms = true;
   }
 
   // Init histogram
-  hist.Init(device_id, hmat.Ptrs().back());
+  hist.Init(device_id, page->n_bins);
 }
 
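The allocation removed above sized the compressed buffer from the number of ELLPACK slots and the symbol alphabet (`n_bins + 1`, the extra symbol being the null/padding index). A standalone sketch of that sizing arithmetic follows, assuming the common bit-packing scheme of ceil(log2(num_symbols)) bits per slot; the real `CalculateBufferSize` may round up and add alignment padding beyond this, so treat it as a lower bound, not the library's exact formula.

```cpp
#include <cstddef>

// Bits needed to distinguish num_symbols values (num_symbols >= 1).
inline std::size_t SymbolBits(std::size_t num_symbols) {
  std::size_t bits = 0;
  while ((std::size_t{1} << bits) < num_symbols) ++bits;
  return bits;
}

// Rough lower bound on the compressed ELLPACK buffer size in bytes:
// n_rows * row_stride slots, each packed into SymbolBits(n_bins + 1) bits.
// The real writer may add rounding/alignment padding on top of this.
inline std::size_t ApproxBufferBytes(std::size_t n_rows, std::size_t row_stride,
                                     std::size_t n_bins) {
  std::size_t num_symbols = n_bins + 1;  // +1 for the null (padding) index
  std::size_t total_bits = n_rows * row_stride * SymbolBits(num_symbols);
  return (total_bits + 7) / 8;           // round up to whole bytes
}
```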
-template <typename GradientSumT>
-inline void DeviceShard<GradientSumT>::CreateHistIndices(
-    const SparsePage &row_batch,
-    const common::HistogramCuts &hmat,
-    const RowStateOnDevice &device_row_state,
-    int rows_per_batch) {
-  // Has any been allocated for me in this batch?
-  if (!device_row_state.rows_to_process_from_batch) return;
-
-  unsigned int null_gidx_value = hmat.Ptrs().back();
-  size_t row_stride = this->ellpack_matrix.row_stride;
-
-  const auto &offset_vec = row_batch.offset.ConstHostVector();
-
-  int num_symbols = n_bins + 1;
-  // bin and compress entries in batches of rows
-  size_t gpu_batch_nrows = std::min(
-      dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
-      static_cast<size_t>(device_row_state.rows_to_process_from_batch));
-  const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
-
-  size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
-                                           gpu_batch_nrows);
-
-  for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
-    size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
-    size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
-    if (batch_row_end > device_row_state.rows_to_process_from_batch) {
-      batch_row_end = device_row_state.rows_to_process_from_batch;
-    }
-    size_t batch_nrows = batch_row_end - batch_row_begin;
-
-    const auto ent_cnt_begin =
-        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
-    const auto ent_cnt_end =
-        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
-
-    /*! \brief row offset in SparsePage (the input data). */
-    dh::device_vector<size_t> row_ptrs(batch_nrows+1);
-    thrust::copy(
-        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
-        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
-        row_ptrs.begin());
-
-    // number of entries in this batch.
-    size_t n_entries = ent_cnt_end - ent_cnt_begin;
-    dh::device_vector<Entry> entries_d(n_entries);
-    // copy data entries to device.
-    dh::safe_cuda
-        (cudaMemcpy
-         (entries_d.data().get(), data_vec.data() + ent_cnt_begin,
-          n_entries * sizeof(Entry), cudaMemcpyDefault));
-    const dim3 block3(32, 8, 1);  // 256 threads
-    const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
-                     common::DivRoundUp(row_stride, block3.y), 1);
-    CompressBinEllpackKernel<<<grid3, block3>>>
-        (common::CompressedBufferWriter(num_symbols),
-         gidx_buffer.data(),
-         row_ptrs.data().get(),
-         entries_d.data().get(),
-         gidx_fvalue_map.data(),
-         feature_segments.data(),
-         device_row_state.total_rows_processed + batch_row_begin,
-         batch_nrows,
-         row_stride,
-         null_gidx_value);
-  }
-}
-
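The binning loop deleted above (its responsibilities now live with the EllpackPage code) processed the input in fixed-size chunks of rows so that the staging buffers fit in device memory. A standalone sketch of that chunking arithmetic; the one-sixteenth-of-device-memory heuristic mirrors the removed code, everything else (names, the lower bound of one row, the printing) is illustrative.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

inline std::size_t DivRoundUp(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// Pick how many rows to stage per chunk and how many chunks are needed.
void PlanChunks(std::size_t total_rows, std::size_t row_stride,
                std::size_t entry_bytes, std::size_t device_memory_bytes) {
  // Cap staging at roughly 1/16th of device memory, but never below one row
  // and never above the number of rows actually left to process.
  std::size_t rows_per_chunk = std::min(
      std::max<std::size_t>(1, device_memory_bytes / (16 * row_stride * entry_bytes)),
      total_rows);
  std::size_t n_chunks = DivRoundUp(total_rows, rows_per_chunk);
  for (std::size_t c = 0; c < n_chunks; ++c) {
    std::size_t begin = c * rows_per_chunk;
    std::size_t end = std::min(begin + rows_per_chunk, total_rows);
    std::cout << "chunk " << c << ": rows [" << begin << ", " << end << ")\n";
  }
}
```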
-// An instance of this type is created which keeps track of total number of rows to process,
-// rows processed thus far, rows to process and the offset from the current sparse page batch
-// to begin processing on each device
-class DeviceHistogramBuilderState {
- public:
-  template <typename GradientSumT>
-  explicit DeviceHistogramBuilderState(const std::unique_ptr<DeviceShard<GradientSumT>>& shard)
-      : device_row_state_(shard->n_rows) {}
-
-  const RowStateOnDevice& GetRowStateOnDevice() const {
-    return device_row_state_;
-  }
-
-  // This method is invoked at the beginning of each sparse page batch. This distributes
-  // the rows in the sparse page to the device.
-  // TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
-  void BeginBatch(const SparsePage &batch) {
-    size_t rem_rows = batch.Size();
-    size_t row_offset_in_current_batch = 0;
-
-    // Do we have anymore left to process from this batch on this device?
-    if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
-      // There are still some rows that needs to be assigned to this device
-      device_row_state_.rows_to_process_from_batch =
-          std::min(
-              device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
-              rem_rows);
-    } else {
-      // All rows have been assigned to this device
-      device_row_state_.rows_to_process_from_batch = 0;
-    }
-
-    device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
-    row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
-    rem_rows -= device_row_state_.rows_to_process_from_batch;
-  }
-
-  // This method is invoked after completion of each sparse page batch
-  void EndBatch() {
-    device_row_state_.Advance();
-  }
-
- private:
-  RowStateOnDevice device_row_state_{0};
-};
-
 template <typename GradientSumT>
 class GPUHistMakerSpecialised {
  public:
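The deleted bookkeeping class tracked how many rows of each incoming SparsePage batch still needed to be binned on the device. A compact standalone sketch of that per-batch accounting, with illustrative names; the assumption that `Advance()` simply folded the current batch's rows into the processed total is mine, not stated in the diff.

```cpp
#include <algorithm>
#include <cstddef>

// Tracks how many of the rows assigned to one device have been binned so far.
struct RowState {
  std::size_t assigned{0};   // total rows this device is responsible for
  std::size_t processed{0};  // rows binned in earlier batches
  std::size_t in_batch{0};   // rows to take from the current batch

  // Called at the start of each incoming batch of `batch_rows` rows.
  void BeginBatch(std::size_t batch_rows) {
    in_batch = (assigned > processed) ? std::min(assigned - processed, batch_rows) : 0;
  }

  // Called once the current batch has been binned.
  void EndBatch() {
    processed += in_batch;
    in_batch = 0;
  }
};
```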
@@ -1319,47 +1018,33 @@ class GPUHistMakerSpecialised {
     uint32_t column_sampling_seed = common::GlobalRandom()();
     rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
 
+    // TODO(rongou): support multiple Ellpack pages.
+    EllpackPageImpl* page{};
+    for (auto& batch : dmat->GetBatches<EllpackPage>()) {
+      page = batch.Impl();
+      page->Init(device_, param_.max_bin, hist_maker_param_.gpu_batch_nrows);
+    }
+
     // Create device shard
     dh::safe_cuda(cudaSetDevice(device_));
     shard_.reset(new DeviceShard<GradientSumT>(device_,
+                                               page,
                                                info_->num_row_,
                                                param_,
                                                column_sampling_seed,
                                                info_->num_col_));
 
-    monitor_.StartCuda("Quantiles");
-    // Create the quantile sketches for the dmatrix and initialize HistogramCuts
-    size_t row_stride = common::DeviceSketch(param_, *generic_param_,
-                                             hist_maker_param_.gpu_batch_nrows,
-                                             dmat, &hmat_);
-    monitor_.StopCuda("Quantiles");
-
-    auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
-
     // Init global data for each shard
-    monitor_.StartCuda("InitCompressedData");
-    dh::safe_cuda(cudaSetDevice(shard_->device_id));
-    shard_->InitCompressedData(hmat_, row_stride, is_dense);
-    monitor_.StopCuda("InitCompressedData");
-
-    monitor_.StartCuda("BinningCompression");
-    DeviceHistogramBuilderState hist_builder_row_state(shard_);
-    for (const auto &batch : dmat->GetBatches<SparsePage>()) {
-      hist_builder_row_state.BeginBatch(batch);
-
-      dh::safe_cuda(cudaSetDevice(shard_->device_id));
-      shard_->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(),
-                                hist_maker_param_.gpu_batch_nrows);
-
-      hist_builder_row_state.EndBatch();
-    }
-    monitor_.StopCuda("BinningCompression");
-
+    monitor_.StartCuda("InitHistogram");
+    dh::safe_cuda(cudaSetDevice(device_));
+    shard_->InitHistogram();
+    monitor_.StopCuda("InitHistogram");
 
     p_last_fmat_ = dmat;
     initialised_ = true;
   }
 
-  void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat) {
+  void InitData(DMatrix* dmat) {
     if (!initialised_) {
       monitor_.StartCuda("InitDataOnce");
       this->InitDataOnce(dmat);
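The added loop above is how the updater now obtains the compressed page: the DMatrix builds (and can cache) the EllpackPage, and the updater only borrows `batch.Impl()`. The standalone sketch below illustrates that lazy build-and-cache pattern with plain C++; the class and method names are stand-ins for illustration, not the XGBoost API.

```cpp
#include <iostream>
#include <memory>

// Stand-in for the compressed page.
struct EllpackLikePage {
  bool initialised{false};
  void Init(int device, int max_bin) {
    (void)device;
    (void)max_bin;
    initialised = true;
  }
};

// Stand-in for a data matrix that owns and lazily caches its compressed page.
class Matrix {
 public:
  // Returns the cached page, building it on first use.
  EllpackLikePage* GetEllpackPage() {
    if (!page_) page_ = std::make_unique<EllpackLikePage>();
    return page_.get();
  }

 private:
  std::unique_ptr<EllpackLikePage> page_;  // owned by the matrix, borrowed by updaters
};

int main() {
  Matrix dmat;
  EllpackLikePage* page = dmat.GetEllpackPage();  // built once, owned by dmat
  page->Init(/*device=*/0, /*max_bin=*/256);
  std::cout << std::boolalpha << page->initialised << "\n";    // true
  std::cout << (page == dmat.GetEllpackPage()) << "\n";        // same cached page: true
}
```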
@@ -1387,7 +1072,7 @@ class GPUHistMakerSpecialised {
   void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
                   RegTree* p_tree) {
     monitor_.StartCuda("InitData");
-    this->InitData(gpair, p_fmat);
+    this->InitData(p_fmat);
     monitor_.StopCuda("InitData");
 
     gpair->SetDevice(device_);
@@ -1408,7 +1093,6 @@ class GPUHistMakerSpecialised {
   }
 
   TrainParam param_;            // NOLINT
-  common::HistogramCuts hmat_;  // NOLINT
   MetaInfo* info_{};            // NOLINT
 
   std::unique_ptr<DeviceShard<GradientSumT>> shard_;  // NOLINT
@@ -43,18 +43,17 @@ void TestDeviceSketch(bool use_external_memory) {
     dmat = static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
   }
 
-  tree::TrainParam p;
-  p.max_bin = 20;
-  int gpu_batch_nrows = 0;
+  int device{0};
+  int max_bin{20};
+  int gpu_batch_nrows{0};
 
   // find quantiles on the CPU
   HistogramCuts hmat_cpu;
-  hmat_cpu.Build((*dmat).get(), p.max_bin);
+  hmat_cpu.Build((*dmat).get(), max_bin);
 
   // find the cuts on the GPU
   HistogramCuts hmat_gpu;
-  size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0), gpu_batch_nrows,
-                                   dmat->get(), &hmat_gpu);
+  size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu);
 
   // compare the row stride with the one obtained from the dmatrix
   size_t expected_row_stride = 0;
tests/cpp/data/test_ellpack_page.cu (new file, 86 lines)
@@ -0,0 +1,86 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+#include <xgboost/base.h>
+
+#include <utility>
+
+#include "../helpers.h"
+#include "gtest/gtest.h"
+
+#include "../../../src/common/hist_util.h"
+#include "../../../src/data/ellpack_page.cuh"
+
+namespace xgboost {
+
+TEST(EllpackPage, EmptyDMatrix) {
+  constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256, kGpuBatchNRows = 64;
+  constexpr float kSparsity = 0;
+  auto dmat = *CreateDMatrix(kNRows, kNCols, kSparsity);
+  auto& page = *dmat->GetBatches<EllpackPage>().begin();
+  auto impl = page.Impl();
+  impl->Init(0, kMaxBin, kGpuBatchNRows);
+  ASSERT_EQ(impl->ellpack_matrix.feature_segments.size(), 1);
+  ASSERT_EQ(impl->ellpack_matrix.min_fvalue.size(), 0);
+  ASSERT_EQ(impl->ellpack_matrix.gidx_fvalue_map.size(), 0);
+  ASSERT_EQ(impl->ellpack_matrix.row_stride, 0);
+  ASSERT_EQ(impl->ellpack_matrix.null_gidx_value, 0);
+  ASSERT_EQ(impl->n_bins, 0);
+  ASSERT_EQ(impl->gidx_buffer.size(), 4);
+}
+
+TEST(EllpackPage, BuildGidxDense) {
+  int constexpr kNRows = 16, kNCols = 8;
+  auto page = BuildEllpackPage(kNRows, kNCols);
+
+  std::vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
+  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
+
+  ASSERT_EQ(page->ellpack_matrix.row_stride, kNCols);
+
+  std::vector<uint32_t> solution = {
+    0, 3, 8,  9, 14, 17, 20, 21,
+    0, 4, 7, 10, 14, 16, 19, 22,
+    1, 3, 7, 11, 14, 15, 19, 21,
+    2, 3, 7,  9, 13, 16, 20, 22,
+    2, 3, 6,  9, 12, 16, 20, 21,
+    1, 5, 6, 10, 13, 16, 20, 21,
+    2, 5, 8,  9, 13, 17, 19, 22,
+    2, 4, 6, 10, 14, 17, 19, 21,
+    2, 5, 7,  9, 13, 16, 19, 22,
+    0, 3, 8, 10, 12, 16, 19, 22,
+    1, 3, 7, 10, 13, 16, 19, 21,
+    1, 3, 8, 10, 13, 17, 20, 22,
+    2, 4, 6,  9, 14, 15, 19, 22,
+    1, 4, 6,  9, 13, 16, 19, 21,
+    2, 4, 8, 10, 14, 15, 19, 22,
+    1, 4, 7, 10, 14, 16, 19, 21,
+  };
+  for (size_t i = 0; i < kNRows * kNCols; ++i) {
+    ASSERT_EQ(solution[i], gidx[i]);
+  }
+}
+
+TEST(EllpackPage, BuildGidxSparse) {
+  int constexpr kNRows = 16, kNCols = 8;
+  auto page = BuildEllpackPage(kNRows, kNCols, 0.9f);
+
+  std::vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
+  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
+
+  ASSERT_LE(page->ellpack_matrix.row_stride, 3);
+
+  // row_stride = 3, 16 rows, 48 entries for ELLPack
+  std::vector<uint32_t> solution = {
+    15, 24, 24,  0, 24, 24, 24, 24, 24, 24, 24, 24, 20, 24, 24, 24,
+    24, 24, 24, 24, 24,  5, 24, 24,  0, 16, 24, 15, 24, 24, 24, 24,
+    24,  7, 14, 16,  4, 24, 24, 24, 24, 24,  9, 24, 24,  1, 24, 24
+  };
+  for (size_t i = 0; i < kNRows * page->ellpack_matrix.row_stride; ++i) {
+    ASSERT_EQ(solution[i], gidx[i]);
+  }
+}
+
+}  // namespace xgboost
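To make the expectations in `BuildGidxSparse` easier to read: with 24 bins the null index is 24 (the bin count), every row is padded out to `row_stride` slots, and the non-null entries are the global bin ids of the surviving values. A standalone sketch of that CSR-to-ELLPACK padding, using a toy three-row matrix rather than the test's data; all names are illustrative.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Convert CSR bin indices into a padded ELLPACK layout.
std::vector<int> CsrToEllpack(const std::vector<std::size_t>& row_ptr,
                              const std::vector<int>& bin_idx,
                              int null_gidx) {
  // Row stride is the length of the longest row.
  std::size_t row_stride = 0;
  for (std::size_t i = 1; i < row_ptr.size(); ++i) {
    row_stride = std::max(row_stride, row_ptr[i] - row_ptr[i - 1]);
  }
  std::size_t n_rows = row_ptr.size() - 1;
  std::vector<int> ellpack(n_rows * row_stride, null_gidx);  // pad with the null index
  for (std::size_t r = 0; r < n_rows; ++r) {
    for (std::size_t j = row_ptr[r]; j < row_ptr[r + 1]; ++j) {
      ellpack[r * row_stride + (j - row_ptr[r])] = bin_idx[j];
    }
  }
  return ellpack;
}

int main() {
  // Toy matrix: rows with 2, 0, and 1 entries; 24 bins, so the null index is 24.
  std::vector<std::size_t> row_ptr = {0, 2, 2, 3};
  std::vector<int> bin_idx = {15, 3, 20};
  for (int v : CsrToEllpack(row_ptr, bin_idx, /*null_gidx=*/24)) std::cout << v << ' ';
  std::cout << "\n";  // prints: 15 3 24 24 20 24
}
```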
@@ -21,6 +21,10 @@
 #include <xgboost/generic_parameters.h>
 
 #include "../../src/common/common.h"
+#include "../../src/common/hist_util.h"
+#if defined(__CUDACC__)
+#include "../../src/data/ellpack_page.cuh"
+#endif
 
 #if defined(__CUDACC__)
 #define DeclareUnifiedTest(name)  GPU ## name
@@ -197,5 +201,58 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) {
   return tparam;
 }
 
+#if defined(__CUDACC__)
+namespace {
+class HistogramCutsWrapper : public common::HistogramCuts {
+ public:
+  using SuperT = common::HistogramCuts;
+  void SetValues(std::vector<float> cuts) {
+    SuperT::cut_values_ = std::move(cuts);
+  }
+  void SetPtrs(std::vector<uint32_t> ptrs) {
+    SuperT::cut_ptrs_ = std::move(ptrs);
+  }
+  void SetMins(std::vector<float> mins) {
+    SuperT::min_vals_ = std::move(mins);
+  }
+};
+}  // anonymous namespace
+
+inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
+    int n_rows, int n_cols, bst_float sparsity = 0) {
+  auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
+  const SparsePage& batch = *(*dmat)->GetBatches<xgboost::SparsePage>().begin();
+
+  HistogramCutsWrapper cmat;
+  cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
+  // 24 cut fields, 3 cut fields for each feature (column).
+  cmat.SetValues({0.30f, 0.67f, 1.64f,
+                  0.32f, 0.77f, 1.95f,
+                  0.29f, 0.70f, 1.80f,
+                  0.32f, 0.75f, 1.85f,
+                  0.18f, 0.59f, 1.69f,
+                  0.25f, 0.74f, 2.00f,
+                  0.26f, 0.74f, 1.98f,
+                  0.26f, 0.71f, 1.83f});
+  cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
+
+  auto is_dense = (*dmat)->Info().num_nonzero_ ==
+                  (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
+  size_t row_stride = 0;
+  const auto &offset_vec = batch.offset.ConstHostVector();
+  for (size_t i = 1; i < offset_vec.size(); ++i) {
+    row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
+  }
+
+  auto page = std::unique_ptr<EllpackPageImpl>(new EllpackPageImpl(dmat->get()));
+  page->InitCompressedData(0, cmat, row_stride, is_dense);
+  page->CreateHistIndices(0, batch, RowStateOnDevice(batch.Size(), batch.Size()));
+
+  delete dmat;
+
+  return page;
+}
+#endif
+
 }  // namespace xgboost
 #endif
@@ -98,82 +98,13 @@ void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
   for (size_t i = 1; i < offset_vec.size(); ++i) {
     row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
   }
-  shard->InitCompressedData(cmat, row_stride, is_dense);
+  shard->InitHistogram(cmat, row_stride, is_dense);
   shard->CreateHistIndices(
       batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1);
 
   delete dmat;
 }
 
-TEST(GpuHist, BuildGidxDense) {
-  int constexpr kNRows = 16, kNCols = 8;
-  tree::TrainParam param;
-  std::vector<std::pair<std::string, std::string>> args {
-    {"max_depth", "1"},
-    {"max_leaves", "0"},
-  };
-  param.Init(args);
-  DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
-  BuildGidx(&shard, kNRows, kNCols);
-
-  std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer);
-  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
-
-  ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols);
-
-  std::vector<uint32_t> solution = {
-    0, 3, 8,  9, 14, 17, 20, 21,
-    0, 4, 7, 10, 14, 16, 19, 22,
-    1, 3, 7, 11, 14, 15, 19, 21,
-    2, 3, 7,  9, 13, 16, 20, 22,
-    2, 3, 6,  9, 12, 16, 20, 21,
-    1, 5, 6, 10, 13, 16, 20, 21,
-    2, 5, 8,  9, 13, 17, 19, 22,
-    2, 4, 6, 10, 14, 17, 19, 21,
-    2, 5, 7,  9, 13, 16, 19, 22,
-    0, 3, 8, 10, 12, 16, 19, 22,
-    1, 3, 7, 10, 13, 16, 19, 21,
-    1, 3, 8, 10, 13, 17, 20, 22,
-    2, 4, 6,  9, 14, 15, 19, 22,
-    1, 4, 6,  9, 13, 16, 19, 21,
-    2, 4, 8, 10, 14, 15, 19, 22,
-    1, 4, 7, 10, 14, 16, 19, 21,
-  };
-  for (size_t i = 0; i < kNRows * kNCols; ++i) {
-    ASSERT_EQ(solution[i], gidx[i]);
-  }
-}
-
-TEST(GpuHist, BuildGidxSparse) {
-  int constexpr kNRows = 16, kNCols = 8;
-  TrainParam param;
-  std::vector<std::pair<std::string, std::string>> args {
-    {"max_depth", "1"},
-    {"max_leaves", "0"},
-  };
-  param.Init(args);
-
-  DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
-  BuildGidx(&shard, kNRows, kNCols, 0.9f);
-
-  std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer);
-  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
-
-  ASSERT_LE(shard.ellpack_matrix.row_stride, 3);
-
-  // row_stride = 3, 16 rows, 48 entries for ELLPack
-  std::vector<uint32_t> solution = {
-    15, 24, 24,  0, 24, 24, 24, 24, 24, 24, 24, 24, 20, 24, 24, 24,
-    24, 24, 24, 24, 24,  5, 24, 24,  0, 16, 24, 15, 24, 24, 24, 24,
-    24,  7, 14, 16,  4, 24, 24, 24, 24, 24,  9, 24, 24,  1, 24, 24
-  };
-  for (size_t i = 0; i < kNRows * shard.ellpack_matrix.row_stride; ++i) {
-    ASSERT_EQ(solution[i], gidx[i]);
-  }
-}
-
 std::vector<GradientPairPrecise> GetHostHistGpair() {
   // 24 bins, 3 bins for each feature (column).
   std::vector<GradientPairPrecise> hist_gpair = {
@@ -199,9 +130,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
     {"max_leaves", "0"},
   };
   param.Init(args);
-  DeviceShard<GradientSumT> shard(0, kNRows, param, kNCols, kNCols);
-  BuildGidx(&shard, kNRows, kNCols);
+  auto page = BuildEllpackPage(kNRows, kNCols);
+  DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols);
+  shard.InitHistogram();
 
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   std::vector<GradientPair> h_gpair(kNRows);
@@ -211,12 +143,11 @@ void TestBuildHist(bool use_shared_memory_histograms) {
     gpair = GradientPair(grad, hess);
   }
 
-  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (
-      shard.gidx_buffer.size());
+  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.size());
 
-  common::CompressedByteT* d_gidx_buffer_ptr = shard.gidx_buffer.data();
+  common::CompressedByteT* d_gidx_buffer_ptr = page->gidx_buffer.data();
   dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr,
-                           sizeof(common::CompressedByteT) * shard.gidx_buffer.size(),
+                           sizeof(common::CompressedByteT) * page->gidx_buffer.size(),
                            cudaMemcpyDeviceToHost));
 
   shard.row_partitioner.reset(new RowPartitioner(0, kNRows));
@@ -300,8 +231,9 @@ TEST(GpuHist, EvaluateSplits) {
   int max_bins = 4;
 
   // Initialize DeviceShard
+  auto page = BuildEllpackPage(kNRows, kNCols);
   std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
-      new DeviceShard<GradientPairPrecise>(0, kNRows, param, kNCols, kNCols)};
+      new DeviceShard<GradientPairPrecise>(0, page.get(), kNRows, param, kNCols, kNCols)};
   // Initialize DeviceShard::node_sum_gradients
   shard->node_sum_gradients = {{6.4f, 12.8f}};
 
@@ -310,18 +242,14 @@ TEST(GpuHist, EvaluateSplits) {
 
   // Copy cut matrix to device.
   shard->ba.Allocate(0,
-                     &(shard->feature_segments), cmat.Ptrs().size(),
-                     &(shard->min_fvalue), cmat.MinValues().size(),
-                     &(shard->gidx_fvalue_map), 24,
+                     &(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(),
+                     &(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(),
+                     &(page->ellpack_matrix.gidx_fvalue_map), 24,
                      &(shard->monotone_constraints), kNCols);
-  dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.Ptrs());
-  dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.Values());
-  dh::CopyVectorToDeviceSpan(shard->monotone_constraints,
-                             param.monotone_constraints);
-  shard->ellpack_matrix.feature_segments = shard->feature_segments;
-  shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
-  dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.MinValues());
-  shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
+  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.feature_segments, cmat.Ptrs());
+  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.gidx_fvalue_map, cmat.Values());
+  dh::CopyVectorToDeviceSpan(shard->monotone_constraints, param.monotone_constraints);
+  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.min_fvalue, cmat.MinValues());
 
   // Initialize DeviceShard::hist
   shard->hist.Init(0, (max_bins - 1) * kNCols);
@@ -391,15 +319,15 @@ void TestHistogramIndexImpl() {
   // Extract the device shard from the histogram makers and from that its compressed
   // histogram index
   const auto &dev_shard = hist_maker.shard_;
-  std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
+  std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->page->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->page->gidx_buffer);
 
   const auto &dev_shard_ext = hist_maker_ext.shard_;
-  std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
+  std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->page->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->page->gidx_buffer);
 
-  ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
-  ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
+  ASSERT_EQ(dev_shard->page->n_bins, dev_shard_ext->page->n_bins);
+  ASSERT_EQ(dev_shard->page->gidx_buffer.size(), dev_shard_ext->page->gidx_buffer.size());
 
   ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
 }