Write ELLPACK pages to disk (#4879)

* add ellpack source
* add batch param
* extract function to parse cache info
* construct ellpack info separately
* push batch to ellpack page
* write ellpack page
* make sparse page source reusable

This commit is contained in:
parent 310fe60b35
commit 5b1715d97c
@@ -40,7 +40,6 @@
 #if DMLC_ENABLE_STD_THREAD
 #include "../src/data/sparse_page_dmatrix.cc"
-#include "../src/data/sparse_page_writer.cc"
 #endif

 // tress
@@ -156,6 +156,18 @@ struct Entry {
   }
 };

+/*!
+ * \brief Parameters for constructing batches.
+ */
+struct BatchParam {
+  /*! \brief The GPU device to use. */
+  int gpu_id;
+  /*! \brief Maximum number of bins per feature for histograms. */
+  int max_bin;
+  /*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
+  int gpu_batch_nrows;
+};
+
 /*!
  * \brief In-memory storage unit of sparse batch, stored in CSR format.
  */
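For orientation, a minimal sketch of how a caller is expected to fill in the new struct before requesting GPU batches (values illustrative; the real call sites appear further down in this diff):

// Hypothetical call site, assuming a valid DMatrix* dmat:
BatchParam param;
param.gpu_id = 0;           // first CUDA device
param.max_bin = 256;        // histogram resolution used by GPU hist
param.gpu_batch_nrows = 0;  // let the sketching code pick a batch size
for (auto& page : dmat->GetBatches<EllpackPage>(param)) {
  // each page is an ELLPACK-compressed chunk of the matrix
}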
@@ -191,14 +203,17 @@ class SparsePage {
   SparsePage() {
     this->Clear();
   }
-  /*! \return number of instance in the page */
+
+  /*! \return Number of instances in the page. */
   inline size_t Size() const {
     return offset.Size() - 1;
   }
+
   /*! \return estimation of memory cost of this page */
   inline size_t MemCostBytes() const {
     return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
   }
+
   /*! \brief clear the page */
   inline void Clear() {
     base_rowid = 0;
@@ -208,6 +223,11 @@ class SparsePage {
     data.HostVector().clear();
   }
+
+  /*! \brief Set the base row id for this page. */
+  inline void SetBaseRowId(size_t row_id) {
+    base_rowid = row_id;
+  }

   SparsePage GetTranspose(int num_columns) const;

   void SortRows() {
@@ -238,13 +258,6 @@ class SparsePage {
   * \param batch The row batch to be pushed
   */
  void PushCSC(const SparsePage& batch);
-  /*!
-   * \brief Push one instance into page
-   * \param inst an instance row
-   */
-  void Push(const Inst &inst);
-
-  size_t Size() { return offset.Size() - 1; }
 };
@@ -268,9 +281,31 @@ class EllpackPageImpl;
  */
 class EllpackPage {
  public:
-  explicit EllpackPage(DMatrix* dmat);
+  /*!
+   * \brief Default constructor.
+   *
+   * This is used in the external memory case. An empty ELLPACK page is constructed with its content
+   * set later by the reader.
+   */
+  EllpackPage();
+
+  /*!
+   * \brief Constructor from an existing DMatrix.
+   *
+   * This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
+   * in CSR format.
+   */
+  explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
+
+  /*! \brief Destructor. */
+  ~EllpackPage();
+
+  /*! \return Number of instances in the page. */
+  size_t Size() const;
+
+  /*! \brief Set the base row id for this page. */
+  void SetBaseRowId(size_t row_id);
+
   const EllpackPageImpl* Impl() const { return impl_.get(); }
   EllpackPageImpl* Impl() { return impl_.get(); }

@ -356,7 +391,8 @@ class DataSource : public dmlc::DataIter<T> {
|
||||
* There are two ways to create a customized DMatrix that reads in user defined-format.
|
||||
*
|
||||
* - Provide a dmlc::Parser and pass into the DMatrix::Create
|
||||
* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER;
|
||||
* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by
|
||||
* DMLC_REGISTER_DATA_PARSER;
|
||||
* - This works best for user defined data input source, such as data-base, filesystem.
|
||||
* - Provide a DataSource, that can be passed to DMatrix::Create
|
||||
* This can be used to re-use inmemory data structure into DMatrix.
|
||||
@@ -373,7 +409,7 @@ class DMatrix {
   * \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
   */
  template<typename T>
-  BatchSet<T> GetBatches();
+  BatchSet<T> GetBatches(const BatchParam& param = {});
  // the following are column meta data, should be able to answer them fast.
  /*! \return Whether the data columns single column block. */
  virtual bool SingleColBlock() const = 0;
@@ -389,6 +425,12 @@ class DMatrix {
   * \return The created DMatrix.
   */
  virtual void SaveToLocalFile(const std::string& fname);

+  /*! \brief Whether the matrix is dense. */
+  bool IsDense() const {
+    return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
+  }
+
  /*!
   * \brief Load DMatrix from URI.
   * \param uri The URI of input.
@@ -438,27 +480,27 @@ class DMatrix {
   virtual BatchSet<SparsePage> GetRowBatches() = 0;
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
   virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
-  virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
+  virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
 };

 template<>
-inline BatchSet<SparsePage> DMatrix::GetBatches() {
+inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
   return GetRowBatches();
 }

 template<>
-inline BatchSet<CSCPage> DMatrix::GetBatches() {
+inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
   return GetColumnBatches();
 }

 template<>
-inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
+inline BatchSet<SortedCSCPage> DMatrix::GetBatches(const BatchParam&) {
   return GetSortedColumnBatches();
 }

 template<>
-inline BatchSet<EllpackPage> DMatrix::GetBatches() {
-  return GetEllpackBatches();
+inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
+  return GetEllpackBatches(param);
 }
 }  // namespace xgboost

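The specializations above are plain tag dispatch: the page type chosen by the caller selects the virtual getter at compile time, while the BatchParam argument is forwarded (or ignored for the CPU page types). A self-contained sketch of the same pattern, with hypothetical names:

#include <iostream>

struct RowTag {};
struct ColTag {};

// Primary template is only declared; each page type gets a specialization
// that routes to the right getter, mirroring DMatrix::GetBatches<T>().
template <typename T> void GetBatches();
template <> void GetBatches<RowTag>() { std::cout << "row batches\n"; }
template <> void GetBatches<ColTag>() { std::cout << "column batches\n"; }

int main() {
  GetBatches<RowTag>();  // dispatch decided by the template argument
  GetBatches<ColTag>();
}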
@@ -540,16 +540,21 @@ class BulkAllocator {
   }

  public:
   BulkAllocator() = default;
   // prevent accidental copying, moving or assignment of this object
   BulkAllocator(const BulkAllocator&) = delete;
   BulkAllocator(BulkAllocator&&) = delete;
   void operator=(const BulkAllocator&) = delete;
   void operator=(BulkAllocator&&) = delete;

-  ~BulkAllocator() {
-    for (size_t i = 0; i < d_ptr_.size(); i++) {
-      if (!(d_ptr_[i] == nullptr)) {
+  /*!
+   * \brief Clear the bulk allocator.
+   *
+   * This frees the GPU memory managed by this allocator.
+   */
+  void Clear() {
+    for (size_t i = 0; i < d_ptr_.size(); i++) {  // NOLINT(modernize-loop-convert)
+      if (d_ptr_[i] != nullptr) {
         safe_cuda(cudaSetDevice(device_idx_[i]));
         XGBDeviceAllocator<char> allocator;
         allocator.deallocate(thrust::device_ptr<char>(d_ptr_[i]), size_[i]);
@@ -558,6 +563,10 @@ class BulkAllocator {
       }
     }
   }

+  ~BulkAllocator() {
+    Clear();
+  }
+
   // returns sum of bytes for all allocations
   size_t Size() {
     return std::accumulate(size_.begin(), size_.end(), static_cast<size_t>(0));
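Splitting the destructor into Clear() plus a destructor that calls it is what later lets EllpackPageImpl::Push() release device memory between batches while keeping the RAII guarantee. A minimal CPU-only analog of the pattern, not the actual allocator:

#include <cstddef>
#include <cstdlib>
#include <vector>

class Allocator {
 public:
  void* Allocate(std::size_t n) {
    ptrs_.push_back(std::malloc(n));
    return ptrs_.back();
  }
  // Callers may free early and keep using the allocator afterwards.
  void Clear() {
    for (void* p : ptrs_) std::free(p);
    ptrs_.clear();
  }
  ~Allocator() { Clear(); }  // destructor keeps the old guarantee
 private:
  std::vector<void*> ptrs_;
};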
@@ -21,7 +21,10 @@
 #endif  // DMLC_ENABLE_STD_THREAD

 namespace dmlc {
-DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::CSCPage>);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPage>);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
 }  // namespace dmlc

 namespace xgboost {
@@ -329,31 +332,6 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
 }  // namespace xgboost

 namespace xgboost {
-data::SparsePageFormat* data::SparsePageFormat::Create(const std::string& name) {
-  auto *e = ::dmlc::Registry< ::xgboost::data::SparsePageFormatReg>::Get()->Find(name);
-  if (e == nullptr) {
-    LOG(FATAL) << "Unknown format type " << name;
-  }
-  return (e->body)();
-}
-
-std::pair<std::string, std::string>
-data::SparsePageFormat::DecideFormat(const std::string& cache_prefix) {
-  size_t pos = cache_prefix.rfind(".fmt-");
-
-  if (pos != std::string::npos) {
-    std::string fmt = cache_prefix.substr(pos + 5, cache_prefix.length());
-    size_t cpos = fmt.rfind('-');
-    if (cpos != std::string::npos) {
-      return std::make_pair(fmt.substr(0, cpos), fmt.substr(cpos + 1, fmt.length()));
-    } else {
-      return std::make_pair(fmt, fmt);
-    }
-  } else {
-    std::string raw = "raw";
-    return std::make_pair(raw, raw);
-  }
-}
 SparsePage SparsePage::GetTranspose(int num_columns) const {
   SparsePage transpose;
   common::ParallelGroupBuilder<Entry> builder(&transpose.offset.HostVector(),
@@ -476,18 +454,6 @@ void SparsePage::PushCSC(const SparsePage &batch) {
   self_offset = std::move(offset);
 }

-void SparsePage::Push(const Inst &inst) {
-  auto& data_vec = data.HostVector();
-  auto& offset_vec = offset.HostVector();
-  offset_vec.push_back(offset_vec.back() + inst.size());
-  size_t begin = data_vec.size();
-  data_vec.resize(begin + inst.size());
-  if (inst.size() != 0) {
-    std::memcpy(dmlc::BeginPtr(data_vec) + begin, inst.data(),
-                sizeof(Entry) * inst.size());
-  }
-}
-
 namespace data {
 // List of files that will be force linked in static links.
 DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
@@ -1,18 +1,16 @@
 /*!
  * Copyright 2019 XGBoost contributors
- *
- * \file ellpack_page.cc
  */
 #ifndef XGBOOST_USE_CUDA

 #include <xgboost/data.h>

-// dummy implementation of ELlpackPage in case CUDA is not used
+// dummy implementation of EllpackPage in case CUDA is not used
 namespace xgboost {

 class EllpackPageImpl {};

-EllpackPage::EllpackPage(DMatrix* dmat) {
+EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param) {
   LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
 }

@@ -1,7 +1,5 @@
 /*!
  * Copyright 2019 XGBoost contributors
- *
- * \file ellpack_page.cu
  */

 #include <xgboost/data.h>
@@ -12,14 +10,22 @@

 namespace xgboost {

-EllpackPage::EllpackPage(DMatrix* dmat) : impl_{new EllpackPageImpl(dmat)} {}
+EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
+
+EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
+    : impl_{new EllpackPageImpl(dmat, param)} {}

 EllpackPage::~EllpackPage() = default;

-EllpackPageImpl::EllpackPageImpl(DMatrix* dmat) : dmat_{dmat} {}
+size_t EllpackPage::Size() const {
+  return impl_->Size();
+}
+
+void EllpackPage::SetBaseRowId(size_t row_id) {
+  impl_->SetBaseRowId(row_id);
+}

 // Bin each input data entry, store the bin indices in compressed form.
 template<typename std::enable_if<true, int>::type = 0>
 __global__ void CompressBinEllpackKernel(
     common::CompressedBufferWriter wr,
     common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
@@ -43,7 +49,7 @@ __global__ void CompressBinEllpackKernel(
   int feature = entry.index;
   float fvalue = entry.fvalue;
   // {feature_cuts, ncuts} forms the array of cuts of `feature'.
-  const float *feature_cuts = &cuts[cut_rows[feature]];
+  const float* feature_cuts = &cuts[cut_rows[feature]];
   int ncuts = cut_rows[feature + 1] - cut_rows[feature];
   // Assigning the bin in current entry.
   // S.t.: fvalue < feature_cuts[bin]
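The search the kernel performs per entry is, by my reading of the "fvalue < feature_cuts[bin]" condition above, an upper-bound over that feature's cut values; a host-side sketch of the same computation:

#include <algorithm>
#include <vector>

// Host-side sketch of the kernel's bin assignment: the bin is the index
// of the first cut strictly greater than the feature value, i.e. the
// smallest bin satisfying fvalue < feature_cuts[bin].
int AssignBin(float fvalue, const std::vector<float>& feature_cuts) {
  return static_cast<int>(
      std::upper_bound(feature_cuts.begin(), feature_cuts.end(), fvalue) -
      feature_cuts.begin());
}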
@@ -58,87 +64,90 @@ __global__ void CompressBinEllpackKernel(
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }

-void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) {
-  if (initialised_) return;
-
+// Construct an ELLPACK matrix in memory.
+EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
   monitor_.Init("ellpack_page");
-  dh::safe_cuda(cudaSetDevice(device));
+  dh::safe_cuda(cudaSetDevice(param.gpu_id));

   monitor_.StartCuda("Quantiles");
+  // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
   common::HistogramCuts hmat;
-  size_t row_stride = common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat_, &hmat);
+  size_t row_stride =
+      common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
   monitor_.StopCuda("Quantiles");

-  const auto& info = dmat_->Info();
-  auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_;
+  monitor_.StartCuda("InitEllpackInfo");
+  InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat);
+  monitor_.StopCuda("InitEllpackInfo");

-  // Init global data
   monitor_.StartCuda("InitCompressedData");
-  InitCompressedData(device, hmat, row_stride, is_dense);
+  InitCompressedData(param.gpu_id, dmat->Info().num_row_);
   monitor_.StopCuda("InitCompressedData");

   monitor_.StartCuda("BinningCompression");
-  DeviceHistogramBuilderState hist_builder_row_state(info.num_row_);
-  for (const auto& batch : dmat_->GetBatches<SparsePage>()) {
+  DeviceHistogramBuilderState hist_builder_row_state(dmat->Info().num_row_);
+  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
     hist_builder_row_state.BeginBatch(batch);
-    CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
+    CreateHistIndices(param.gpu_id, batch, hist_builder_row_state.GetRowStateOnDevice());
     hist_builder_row_state.EndBatch();
   }
   monitor_.StopCuda("BinningCompression");
-
-  initialised_ = true;
 }

-void EllpackPageImpl::InitCompressedData(int device,
-                                         const common::HistogramCuts& hmat,
-                                         size_t row_stride,
-                                         bool is_dense) {
-  n_bins = hmat.Ptrs().back();
-  int null_gidx_value = hmat.Ptrs().back();
-  int num_symbols = n_bins + 1;
-
-  // minimum value for each feature.
-  common::Span<bst_float> min_fvalue;
-
-  // Required buffer size for storing data matrix in ELLPack format.
-  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
-      row_stride * dmat_->Info().num_row_, num_symbols);
+// Construct an EllpackInfo based on histogram cuts of features.
+EllpackInfo::EllpackInfo(int device,
+                         bool is_dense,
+                         size_t row_stride,
+                         const common::HistogramCuts& hmat,
+                         dh::BulkAllocator& ba)
+    : is_dense(is_dense), row_stride(row_stride), n_bins(hmat.Ptrs().back()) {
   ba.Allocate(device,
               &feature_segments, hmat.Ptrs().size(),
               &gidx_fvalue_map, hmat.Values().size(),
-              &min_fvalue, hmat.MinValues().size(),
-              &gidx_buffer, compressed_size_bytes);
-
+              &min_fvalue, hmat.MinValues().size());
   dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
   dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
   dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
 }

+// Initialize the EllpackInfo for this page.
+void EllpackPageImpl::InitInfo(int device,
+                               bool is_dense,
+                               size_t row_stride,
+                               const common::HistogramCuts& hmat) {
+  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, ba_);
+}
+
+// Initialize the buffer to store compressed features.
+void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) {
+  int num_symbols = matrix.info.n_bins + 1;
+
+  // Required buffer size for storing data matrix in ELLPack format.
+  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
+      matrix.info.row_stride * num_rows, num_symbols);
+  ba_.Allocate(device, &gidx_buffer, compressed_size_bytes);

   thrust::fill(
       thrust::device_pointer_cast(gidx_buffer.data()),
       thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);

-  ellpack_matrix.Init(feature_segments,
-                      min_fvalue,
-                      gidx_fvalue_map,
-                      row_stride,
-                      common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
-                      is_dense,
-                      null_gidx_value);
+  matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols);
 }

+// Compress a CSR page into ELLPACK.
 void EllpackPageImpl::CreateHistIndices(int device,
                                         const SparsePage& row_batch,
                                         const RowStateOnDevice& device_row_state) {
   // Has any been allocated for me in this batch?
   if (!device_row_state.rows_to_process_from_batch) return;

-  unsigned int null_gidx_value = n_bins;
-  size_t row_stride = this->ellpack_matrix.row_stride;
+  unsigned int null_gidx_value = matrix.info.n_bins;
+  size_t row_stride = matrix.info.row_stride;

-  const auto &offset_vec = row_batch.offset.ConstHostVector();
+  const auto& offset_vec = row_batch.offset.ConstHostVector();

-  int num_symbols = n_bins + 1;
+  int num_symbols = matrix.info.n_bins + 1;
   // bin and compress entries in batches of rows
   size_t gpu_batch_nrows = std::min(
       dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
|
||||
offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
|
||||
|
||||
/*! \brief row offset in SparsePage (the input data). */
|
||||
dh::device_vector<size_t> row_ptrs(batch_nrows+1);
|
||||
dh::device_vector<size_t> row_ptrs(batch_nrows + 1);
|
||||
thrust::copy(
|
||||
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
|
||||
offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
|
||||
@ -185,8 +194,8 @@ void EllpackPageImpl::CreateHistIndices(int device,
|
||||
gidx_buffer.data(),
|
||||
row_ptrs.data().get(),
|
||||
entries_d.data().get(),
|
||||
gidx_fvalue_map.data(),
|
||||
feature_segments.data(),
|
||||
matrix.info.gidx_fvalue_map.data(),
|
||||
matrix.info.feature_segments.data(),
|
||||
device_row_state.total_rows_processed + batch_row_begin,
|
||||
batch_nrows,
|
||||
row_stride,
|
||||
@ -194,4 +203,73 @@ void EllpackPageImpl::CreateHistIndices(int device,
|
||||
}
|
||||
}
|
||||
|
||||
// Return the number of rows contained in this page.
|
||||
size_t EllpackPageImpl::Size() const {
|
||||
return n_rows;
|
||||
}
|
||||
|
||||
// Clear the current page.
|
||||
void EllpackPageImpl::Clear() {
|
||||
ba_.Clear();
|
||||
gidx_buffer = {};
|
||||
idx_buffer.clear();
|
||||
n_rows = 0;
|
||||
}
|
||||
|
||||
// Push a CSR page to the current page.
|
||||
//
|
||||
// First compress the CSR page into ELLPACK, then the compressed buffer is copied to host and
|
||||
// appended to the existing host vector.
|
||||
void EllpackPageImpl::Push(int device, const SparsePage& batch) {
|
||||
monitor_.StartCuda("InitCompressedData");
|
||||
InitCompressedData(device, batch.Size());
|
||||
monitor_.StopCuda("InitCompressedData");
|
||||
|
||||
monitor_.StartCuda("BinningCompression");
|
||||
DeviceHistogramBuilderState hist_builder_row_state(batch.Size());
|
||||
hist_builder_row_state.BeginBatch(batch);
|
||||
CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
|
||||
hist_builder_row_state.EndBatch();
|
||||
monitor_.StopCuda("BinningCompression");
|
||||
|
||||
monitor_.StartCuda("CopyDeviceToHost");
|
||||
std::vector<common::CompressedByteT> buffer(gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&buffer, gidx_buffer);
|
||||
int offset = 0;
|
||||
if (!idx_buffer.empty()) {
|
||||
offset = ::xgboost::common::detail::kPadding;
|
||||
}
|
||||
idx_buffer.reserve(idx_buffer.size() + buffer.size() - offset);
|
||||
idx_buffer.insert(idx_buffer.end(), buffer.begin() + offset, buffer.end());
|
||||
ba_.Clear();
|
||||
gidx_buffer = {};
|
||||
monitor_.StopCuda("CopyDeviceToHost");
|
||||
|
||||
n_rows += batch.Size();
|
||||
}
|
||||
|
||||
// Return the memory cost for storing the compressed features.
|
||||
size_t EllpackPageImpl::MemCostBytes() const {
|
||||
return idx_buffer.size() * sizeof(common::CompressedByteT);
|
||||
}
|
||||
|
||||
// Copy the compressed features to GPU.
|
||||
void EllpackPageImpl::InitDevice(int device, EllpackInfo info) {
|
||||
if (device_initialized_) return;
|
||||
|
||||
monitor_.StartCuda("CopyPageToDevice");
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
|
||||
gidx_buffer = {};
|
||||
ba_.Allocate(device, &gidx_buffer, idx_buffer.size());
|
||||
dh::CopyVectorToDeviceSpan(gidx_buffer, idx_buffer);
|
||||
|
||||
matrix.info = info;
|
||||
matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), info.n_bins + 1);
|
||||
|
||||
monitor_.StopCuda("CopyPageToDevice");
|
||||
|
||||
device_initialized_ = true;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
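The offset handling in Push() concatenates per-batch compressed buffers while dropping a header on every buffer after the first, assuming (as the code suggests) that kPadding is the size of the compressed stream's header. The same idea with plain bytes:

#include <cstddef>
#include <vector>

// Append `next` to `accum`, skipping `padding` header bytes on all
// buffers after the first, mirroring the idx_buffer logic in Push().
void AppendCompressed(std::vector<unsigned char>* accum,
                      const std::vector<unsigned char>& next,
                      std::size_t padding) {
  std::size_t offset = accum->empty() ? 0 : padding;
  accum->insert(accum->end(), next.begin() + offset, next.end());
}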
@@ -1,7 +1,5 @@
 /*!
  * Copyright 2019 by XGBoost Contributors
- *
- * \file ellpack_page.cuh
  */

 #ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
@@ -42,56 +40,68 @@ __forceinline__ __device__ int BinarySearchRow(
   return -1;
 }

+/** \brief Meta information about the ELLPACK matrix. */
+struct EllpackInfo {
+  /*! \brief Whether or not the matrix is dense. */
+  bool is_dense;
+  /*! \brief Row length for ELLPack, equal to number of features. */
+  size_t row_stride;
+  /*! \brief Total number of bins, also used as the null index value. */
+  size_t n_bins;
+  /*! \brief Minimum value for each feature. Size equals to number of features. */
+  common::Span<bst_float> min_fvalue;
+  /*! \brief Histogram cut pointers. Size equals to (number of features + 1). */
+  common::Span<uint32_t> feature_segments;
+  /*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
+  common::Span<bst_float> gidx_fvalue_map;
+
+  EllpackInfo() = default;
+
+  /*!
+   * \brief Constructor.
+   *
+   * @param device The GPU device to use.
+   * @param is_dense Whether the matrix is dense.
+   * @param row_stride The number of features between starts of consecutive rows.
+   * @param hmat The histogram cuts of all the features.
+   * @param ba The BulkAllocator that owns the GPU memory.
+   */
+  explicit EllpackInfo(int device,
+                       bool is_dense,
+                       size_t row_stride,
+                       const common::HistogramCuts& hmat,
+                       dh::BulkAllocator& ba);
+};
+
 /** \brief Struct for accessing and manipulating an ellpack matrix on the
  * device. Does not own underlying memory and may be trivially copied into
  * kernels.*/
-struct ELLPackMatrix {
-  common::Span<uint32_t> feature_segments;
-  /*! \brief minimum value for each feature. */
-  common::Span<bst_float> min_fvalue;
-  /*! \brief Cut. */
-  common::Span<bst_float> gidx_fvalue_map;
-  /*! \brief row length for ELLPack. */
-  size_t row_stride{0};
+struct EllpackMatrix {
+  EllpackInfo info;
   common::CompressedIterator<uint32_t> gidx_iter;
-  int null_gidx_value;

-  XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
+  XGBOOST_DEVICE size_t BinCount() const { return info.gidx_fvalue_map.size(); }

-  // Get a matrix element, uses binary search for look up Return NaN if missing
+  // Given a row index and a feature index, returns the corresponding cut value
   __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
-    auto row_begin = row_stride * ridx;
-    auto row_end = row_begin + row_stride;
+    auto row_begin = info.row_stride * ridx;
+    auto row_end = row_begin + info.row_stride;
     auto gidx = -1;
-    if (is_dense) {
+    if (info.is_dense) {
       gidx = gidx_iter[row_begin + fidx];
     } else {
-      gidx =
-          BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
-                          feature_segments[fidx + 1]);
+      gidx = BinarySearchRow(row_begin,
+                             row_end,
+                             gidx_iter,
+                             info.feature_segments[fidx],
+                             info.feature_segments[fidx + 1]);
     }
     if (gidx == -1) {
       return nan("");
     }
-    return gidx_fvalue_map[gidx];
+    return info.gidx_fvalue_map[gidx];
   }
-  void Init(common::Span<uint32_t> feature_segments,
-            common::Span<bst_float> min_fvalue,
-            common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
-            common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
-            int null_gidx_value) {
-    this->feature_segments = feature_segments;
-    this->min_fvalue = min_fvalue;
-    this->gidx_fvalue_map = gidx_fvalue_map;
-    this->row_stride = row_stride;
-    this->gidx_iter = gidx_iter;
-    this->is_dense = is_dense;
-    this->null_gidx_value = null_gidx_value;
-  }
-
- private:
-  bool is_dense;
 };

 // Instances of this type are created while creating the histogram bins for the
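As a reminder of the layout GetElement() assumes: every row occupies exactly row_stride slots, so a dense lookup is pure index arithmetic, and only sparse rows need BinarySearchRow. A minimal sketch of the dense case, with hypothetical names:

#include <cstddef>
#include <cstdint>
#include <vector>

// Dense ELLPACK addressing: row `ridx`, feature `fidx` lives at
// ridx * row_stride + fidx in the flat (compressed) index buffer.
uint32_t DenseLookup(const std::vector<uint32_t>& gidx,
                     std::size_t row_stride,
                     std::size_t ridx, std::size_t fidx) {
  return gidx[ridx * row_stride + fidx];
}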
@@ -171,31 +181,93 @@ class DeviceHistogramBuilderState {

 class EllpackPageImpl {
  public:
-  ELLPackMatrix ellpack_matrix;
-  int n_bins{};
+  EllpackMatrix matrix;
   /*! \brief global index of histogram, which is stored in ELLPack format. */
   common::Span<common::CompressedByteT> gidx_buffer;
+  std::vector<common::CompressedByteT> idx_buffer;
+  size_t n_rows{};

-  explicit EllpackPageImpl(DMatrix* dmat);
-  void Init(int device, int max_bin, int gpu_batch_nrows);
-  void InitCompressedData(int device,
-                          const common::HistogramCuts& hmat,
-                          size_t row_stride,
-                          bool is_dense);
+  /*!
+   * \brief Default constructor.
+   *
+   * This is used in the external memory case. An empty ELLPACK page is constructed with its content
+   * set later by the reader.
+   */
+  EllpackPageImpl() = default;
+
+  /*!
+   * \brief Constructor from an existing DMatrix.
+   *
+   * This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
+   * in CSR format.
+   */
+  explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& param);
+
+  /*!
+   * \brief Initialize the EllpackInfo contained in the EllpackMatrix.
+   *
+   * This is used in the in-memory case. The current page owns the BulkAllocator, which in turn owns
+   * the GPU memory used by the EllpackInfo.
+   *
+   * @param device The GPU device to use.
+   * @param is_dense Whether the matrix is dense.
+   * @param row_stride The number of features between starts of consecutive rows.
+   * @param hmat The histogram cuts of all the features.
+   */
+  void InitInfo(int device, bool is_dense, size_t row_stride, const common::HistogramCuts& hmat);
+
+  /*!
+   * \brief Initialize the buffer to store compressed features.
+   *
+   * @param device The GPU device to use.
+   * @param num_rows The number of rows we are storing in the buffer.
+   */
+  void InitCompressedData(int device, size_t num_rows);
+
+  /*!
+   * \brief Compress a single page of CSR data into ELLPACK.
+   *
+   * @param device The GPU device to use.
+   * @param row_batch The CSR page.
+   * @param device_row_state On-device data for maintaining state.
+   */
+  void CreateHistIndices(int device,
+                         const SparsePage& row_batch,
+                         const RowStateOnDevice& device_row_state);

- private:
-  bool initialised_{false};
-  DMatrix* dmat_;
-  common::Monitor monitor_;
-  dh::BulkAllocator ba;
+  /*! \return Number of instances in the page. */
+  size_t Size() const;

-  /*! \brief Cut. */
-  common::Span<bst_float> gidx_fvalue_map;
-  /*! \brief row_ptr form HistogramCuts. */
-  common::Span<uint32_t> feature_segments;
+  /*! \brief Set the base row id for this page. */
+  inline void SetBaseRowId(size_t row_id) {
+    base_rowid_ = row_id;
+  }

+  /*! \brief clear the page. */
+  void Clear();
+
+  /*!
+   * \brief Push a sparse page.
+   * \param batch The row page.
+   */
+  void Push(int device, const SparsePage& batch);
+
+  /*! \return Estimation of memory cost of this page. */
+  size_t MemCostBytes() const;
+
+  /*!
+   * \brief Copy the ELLPACK matrix to GPU.
+   *
+   * @param device The GPU device to use.
+   * @param info The EllpackInfo for the matrix.
+   */
+  void InitDevice(int device, EllpackInfo info);
+
+ private:
+  common::Monitor monitor_;
+  dh::BulkAllocator ba_;
+  size_t base_rowid_{};
+  bool device_initialized_{false};
 };

 }  // namespace xgboost
src/data/ellpack_page_raw_format.cu (new file, 48 lines)
@@ -0,0 +1,48 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+
+#include <xgboost/data.h>
+#include <dmlc/registry.h>
+
+#include "./ellpack_page.cuh"
+#include "./sparse_page_writer.h"
+
+namespace xgboost {
+namespace data {
+
+DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
+
+class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
+ public:
+  bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
+    auto* impl = page->Impl();
+    if (!fi->Read(&impl->n_rows)) return false;
+    return fi->Read(&impl->idx_buffer);
+  }
+
+  bool Read(EllpackPage* page,
+            dmlc::SeekStream* fi,
+            const std::vector<bst_uint>& sorted_index_set) override {
+    auto* impl = page->Impl();
+    if (!fi->Read(&impl->n_rows)) return false;
+    return fi->Read(&page->Impl()->idx_buffer);
+  }
+
+  void Write(const EllpackPage& page, dmlc::Stream* fo) override {
+    auto* impl = page.Impl();
+    fo->Write(impl->n_rows);
+    auto buffer = impl->idx_buffer;
+    CHECK(!buffer.empty());
+    fo->Write(buffer);
+  }
+};
+
+XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw)
+.describe("Raw ELLPACK binary data format.")
+.set_body([]() {
+    return new EllpackPageRawFormat();
+  });
+
+}  // namespace data
+}  // namespace xgboost
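The on-disk record is just n_rows followed by the host-side compressed buffer; the EllpackInfo (cuts, row stride) is not serialized, which is why the page source below re-attaches a shared EllpackInfo via InitDevice(). A hypothetical in-memory round trip, assuming dmlc's MemoryStringStream from dmlc/memory_io.h; none of this is part of the commit:

// Sketch only: `page` is an EllpackPage whose impl already has n_rows
// and a non-empty idx_buffer (Write CHECKs for that).
std::string blob;
dmlc::MemoryStringStream stream(&blob);
EllpackPageRawFormat format;
format.Write(page, &stream);   // writes n_rows, then idx_buffer
stream.Seek(0);
EllpackPage restored;
CHECK(format.Read(&restored, &stream));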
src/data/ellpack_page_source.cc (new file, 46 lines)
@@ -0,0 +1,46 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+#ifndef XGBOOST_USE_CUDA
+
+#include "ellpack_page_source.h"
+
+namespace xgboost {
+namespace data {
+
+EllpackPageSource::EllpackPageSource(DMatrix* dmat,
+                                     const std::string& cache_info,
+                                     const BatchParam& param) noexcept(false) {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+}
+
+void EllpackPageSource::BeforeFirst() {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+}
+
+bool EllpackPageSource::Next() {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+  return false;
+}
+
+EllpackPage& EllpackPageSource::Value() {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+  EllpackPage* page;
+  return *page;
+}
+
+const EllpackPage& EllpackPageSource::Value() const {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+  EllpackPage* page;
+  return *page;
+}
+
+}  // namespace data
+}  // namespace xgboost
+
+#endif  // XGBOOST_USE_CUDA
src/data/ellpack_page_source.cu (new file, 155 lines)
@@ -0,0 +1,155 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+
+#include "ellpack_page_source.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "../common/hist_util.h"
+#include "ellpack_page.cuh"
+
+namespace xgboost {
+namespace data {
+
+class EllpackPageSourceImpl : public DataSource<EllpackPage> {
+ public:
+  /*!
+   * \brief Create source from cache files the cache_prefix.
+   * \param cache_prefix The prefix of cache we want to solve.
+   */
+  explicit EllpackPageSourceImpl(DMatrix* dmat,
+                                 const std::string& cache_info,
+                                 const BatchParam& param) noexcept(false);
+
+  /*! \brief destructor */
+  ~EllpackPageSourceImpl() override = default;
+
+  void BeforeFirst() override;
+  bool Next() override;
+  EllpackPage& Value();
+  const EllpackPage& Value() const override;
+
+ private:
+  /*! \brief Write Ellpack pages after accumulating them in memory. */
+  void WriteEllpackPages(DMatrix* dmat, const std::string& cache_info) const;
+
+  /*! \brief The page type string for ELLPACK. */
+  const std::string kPageType_{".ellpack.page"};
+
+  int device_{-1};
+  common::Monitor monitor_;
+  dh::BulkAllocator ba_;
+  /*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
+  EllpackInfo ellpack_info_;
+  std::unique_ptr<SparsePageSource<EllpackPage>> source_;
+};
+
+EllpackPageSource::EllpackPageSource(DMatrix* dmat,
+                                     const std::string& cache_info,
+                                     const BatchParam& param) noexcept(false)
+    : impl_{new EllpackPageSourceImpl(dmat, cache_info, param)} {}
+
+void EllpackPageSource::BeforeFirst() {
+  impl_->BeforeFirst();
+}
+
+bool EllpackPageSource::Next() {
+  return impl_->Next();
+}
+
+EllpackPage& EllpackPageSource::Value() {
+  return impl_->Value();
+}
+
+const EllpackPage& EllpackPageSource::Value() const {
+  return impl_->Value();
+}
+
+// Build the quantile sketch across the whole input data, then use the histogram cuts to compress
+// each CSR page, and write the accumulated ELLPACK pages to disk.
+EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
+                                             const std::string& cache_info,
+                                             const BatchParam& param) noexcept(false) {
+  device_ = param.gpu_id;
+
+  monitor_.Init("ellpack_page_source");
+  dh::safe_cuda(cudaSetDevice(device_));
+
+  monitor_.StartCuda("Quantiles");
+  common::HistogramCuts hmat;
+  size_t row_stride =
+      common::DeviceSketch(device_, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
+  monitor_.StopCuda("Quantiles");
+
+  monitor_.StartCuda("CreateEllpackInfo");
+  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, ba_);
+  monitor_.StopCuda("CreateEllpackInfo");
+
+  monitor_.StartCuda("WriteEllpackPages");
+  WriteEllpackPages(dmat, cache_info);
+  monitor_.StopCuda("WriteEllpackPages");
+
+  source_.reset(new SparsePageSource<EllpackPage>(cache_info, kPageType_));
+}
+
+void EllpackPageSourceImpl::BeforeFirst() {
+  source_->BeforeFirst();
+}
+
+bool EllpackPageSourceImpl::Next() {
+  return source_->Next();
+}
+
+EllpackPage& EllpackPageSourceImpl::Value() {
+  EllpackPage& page = source_->Value();
+  page.Impl()->InitDevice(device_, ellpack_info_);
+  return page;
+}
+
+const EllpackPage& EllpackPageSourceImpl::Value() const {
+  EllpackPage& page = source_->Value();
+  page.Impl()->InitDevice(device_, ellpack_info_);
+  return page;
+}
+
+// Compress each CSR page to ELLPACK, and write the accumulated pages to disk.
+void EllpackPageSourceImpl::WriteEllpackPages(DMatrix* dmat, const std::string& cache_info) const {
+  auto cinfo = ParseCacheInfo(cache_info, kPageType_);
+  const size_t extra_buffer_capacity = 6;
+  SparsePageWriter<EllpackPage> writer(
+      cinfo.name_shards, cinfo.format_shards, extra_buffer_capacity);
+  std::shared_ptr<EllpackPage> page;
+  writer.Alloc(&page);
+  auto* impl = page->Impl();
+  impl->matrix.info = ellpack_info_;
+  impl->Clear();
+
+  const MetaInfo& info = dmat->Info();
+  size_t bytes_write = 0;
+  double tstart = dmlc::GetTime();
+  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
+    impl->Push(device_, batch);
+
+    if (impl->MemCostBytes() >= DMatrix::kPageSize) {
+      bytes_write += impl->MemCostBytes();
+      writer.PushWrite(std::move(page));
+      writer.Alloc(&page);
+      impl = page->Impl();
+      impl->matrix.info = ellpack_info_;
+      impl->Clear();
+      double tdiff = dmlc::GetTime() - tstart;
+      LOG(INFO) << "Writing to " << cache_info << " in "
+                << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
+                << (bytes_write >> 20UL) << " written";
+    }
+  }
+  if (impl->Size() != 0) {
+    writer.PushWrite(std::move(page));
+  }
+}
+
+}  // namespace data
+}  // namespace xgboost
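WriteEllpackPages() reuses the accumulate-and-flush pattern of the existing SparsePage writer: keep pushing batches into the current page and hand it off once it crosses the page-size threshold. The same control flow stripped of XGBoost types (strings stand in for pages; the threshold is illustrative):

#include <cstddef>
#include <string>
#include <vector>

int main() {
  const std::size_t kPageSize = 8;   // stands in for DMatrix::kPageSize
  std::vector<std::string> written;  // stands in for SparsePageWriter
  std::string page;                  // stands in for the ELLPACK page
  for (const std::string& batch : {"aaa", "bbbb", "cc", "ddddd"}) {
    page += batch;                   // impl->Push(device_, batch)
    if (page.size() >= kPageSize) {  // impl->MemCostBytes() >= kPageSize
      written.push_back(page);       // writer.PushWrite(std::move(page))
      page.clear();                  // writer.Alloc(&page); impl->Clear()
    }
  }
  if (!page.empty()) written.push_back(page);  // flush the non-empty tail
}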
src/data/ellpack_page_source.h (new file, 54 lines)
@@ -0,0 +1,54 @@
+/*!
+ * Copyright 2019 by XGBoost Contributors
+ */
+
+#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
+#define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
+
+#include <xgboost/data.h>
+#include <memory>
+#include <string>
+
+#include "sparse_page_source.h"
+#include "../common/timer.h"
+
+namespace xgboost {
+namespace data {
+
+class EllpackPageSourceImpl;
+
+/*!
+ * \brief External memory data source for ELLPACK format.
+ *
+ * This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
+ * including CUDA-specific implementation details in the header.
+ */
+class EllpackPageSource : public DataSource<EllpackPage> {
+ public:
+  /*!
+   * \brief Create source from cache files the cache_prefix.
+   * \param cache_prefix The prefix of cache we want to solve.
+   */
+  explicit EllpackPageSource(DMatrix* dmat,
+                             const std::string& cache_info,
+                             const BatchParam& param) noexcept(false);
+
+  /*! \brief destructor */
+  ~EllpackPageSource() override = default;
+
+  void BeforeFirst() override;
+  bool Next() override;
+  EllpackPage& Value();
+  const EllpackPage& Value() const override;
+
+  const EllpackPageSourceImpl* Impl() const { return impl_.get(); }
+  EllpackPageSourceImpl* Impl() { return impl_.get(); }
+
+ private:
+  std::shared_ptr<EllpackPageSourceImpl> impl_;
+};
+
+}  // namespace data
+}  // namespace xgboost
+
+#endif  // XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
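The PImpl split mentioned in the header comment is what keeps thrust/CUDA types out of translation units compiled without nvcc. A compilable single-file sketch of the pattern, with hypothetical names:

#include <memory>

class WidgetImpl;  // only a forward declaration in the header

class Widget {
 public:
  Widget();
  ~Widget();  // defined out of line, where WidgetImpl is complete
 private:
  std::unique_ptr<WidgetImpl> impl_;
};

// --- normally in the .cc/.cu file ---
class WidgetImpl { /* device-specific members would live here */ };

Widget::Widget() : impl_{new WidgetImpl()} {}
Widget::~Widget() = default;

int main() { Widget w; }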
@@ -62,10 +62,12 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
   return BatchSet<SortedCSCPage>(begin_iter);
 }

-BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
+BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
+  CHECK_GE(param.gpu_id, 0);
+  CHECK_GE(param.max_bin, 2);
   // ELLPACK page doesn't exist, generate it
   if (!ellpack_page_) {
-    ellpack_page_.reset(new EllpackPage(this));
+    ellpack_page_.reset(new EllpackPage(this, param));
   }
   auto begin_iter =
       BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
@@ -38,7 +38,7 @@ class SimpleDMatrix : public DMatrix {
   BatchSet<SparsePage> GetRowBatches() override;
   BatchSet<CSCPage> GetColumnBatches() override;
   BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
-  BatchSet<EllpackPage> GetEllpackBatches() override;
+  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;

   // source data pointer.
   std::unique_ptr<DataSource<SparsePage>> source_;
@@ -23,10 +23,10 @@ const MetaInfo& SparsePageDMatrix::Info() const {
   return row_source_->info;
 }

-template<typename T>
+template<typename S, typename T>
 class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
  public:
-  explicit SparseBatchIteratorImpl(SparsePageSource<T>* source) : source_(source) {
+  explicit SparseBatchIteratorImpl(S* source) : source_(source) {
     CHECK(source_ != nullptr);
   }
   T& operator*() override { return source_->Value(); }
@@ -35,7 +35,7 @@ class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
   bool AtEnd() const override { return at_end_; }

  private:
-  SparsePageSource<T>* source_{nullptr};
+  S* source_{nullptr};
   bool at_end_{ false };
 };

@@ -43,7 +43,8 @@ BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
   auto cast = dynamic_cast<SparsePageSource<SparsePage>*>(row_source_.get());
   cast->BeforeFirst();
   cast->Next();
-  auto begin_iter = BatchIterator<SparsePage>(new SparseBatchIteratorImpl<SparsePage>(cast));
+  auto begin_iter = BatchIterator<SparsePage>(
+      new SparseBatchIteratorImpl<SparsePageSource<SparsePage>, SparsePage>(cast));
   return BatchSet<SparsePage>(begin_iter);
 }

@@ -55,8 +56,8 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
   }
   column_source_->BeforeFirst();
   column_source_->Next();
-  auto begin_iter =
-      BatchIterator<CSCPage>(new SparseBatchIteratorImpl<CSCPage>(column_source_.get()));
+  auto begin_iter = BatchIterator<CSCPage>(
+      new SparseBatchIteratorImpl<SparsePageSource<CSCPage>, CSCPage>(column_source_.get()));
   return BatchSet<CSCPage>(begin_iter);
 }

@@ -70,17 +71,26 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
   sorted_column_source_->BeforeFirst();
   sorted_column_source_->Next();
   auto begin_iter = BatchIterator<SortedCSCPage>(
-      new SparseBatchIteratorImpl<SortedCSCPage>(sorted_column_source_.get()));
+      new SparseBatchIteratorImpl<SparsePageSource<SortedCSCPage>, SortedCSCPage>(
+          sorted_column_source_.get()));
   return BatchSet<SortedCSCPage>(begin_iter);
 }

-BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches() {
-  // ELLPACK page doesn't exist, generate it
-  if (!ellpack_page_) {
-    ellpack_page_.reset(new EllpackPage(this));
+BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
+  CHECK_GE(param.gpu_id, 0);
+  CHECK_GE(param.max_bin, 2);
+  // Lazily instantiate
+  if (!ellpack_source_ ||
+      batch_param_.gpu_id != param.gpu_id ||
+      batch_param_.max_bin != param.max_bin ||
+      batch_param_.gpu_batch_nrows != param.gpu_batch_nrows) {
+    ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
+    batch_param_ = param;
   }
-  auto begin_iter =
-      BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
+  ellpack_source_->BeforeFirst();
+  ellpack_source_->Next();
+  auto begin_iter = BatchIterator<EllpackPage>(
+      new SparseBatchIteratorImpl<EllpackPageSource, EllpackPage>(ellpack_source_.get()));
   return BatchSet<EllpackPage>(begin_iter);
 }

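The regeneration condition above is a field-by-field inequality on BatchParam; a hypothetical operator!= (not part of this commit) would state the intent directly:

// Hypothetical helper: regenerate the ELLPACK cache whenever any field
// of the requested BatchParam differs from the cached one.
inline bool operator!=(const BatchParam& a, const BatchParam& b) {
  return a.gpu_id != b.gpu_id ||
         a.max_bin != b.max_bin ||
         a.gpu_batch_nrows != b.gpu_batch_nrows;
}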
@@ -14,6 +14,7 @@
 #include <utility>
 #include <vector>

+#include "ellpack_page_source.h"
 #include "sparse_page_source.h"

 namespace xgboost {
@@ -38,13 +39,15 @@ class SparsePageDMatrix : public DMatrix {
   BatchSet<SparsePage> GetRowBatches() override;
   BatchSet<CSCPage> GetColumnBatches() override;
   BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
-  BatchSet<EllpackPage> GetEllpackBatches() override;
+  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;

   // source data pointers.
   std::unique_ptr<DataSource<SparsePage>> row_source_;
   std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
   std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
-  std::unique_ptr<EllpackPage> ellpack_page_;
+  std::unique_ptr<EllpackPageSource> ellpack_source_;
+  // saved batch param
+  BatchParam batch_param_;
   // the cache prefix
   std::string cache_info_;
   // Store column densities to avoid recalculating
@@ -12,9 +12,10 @@ namespace data {

 DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);

-class SparsePageRawFormat : public SparsePageFormat {
+template<typename T>
+class SparsePageRawFormat : public SparsePageFormat<T> {
  public:
-  bool Read(SparsePage* page, dmlc::SeekStream* fi) override {
+  bool Read(T* page, dmlc::SeekStream* fi) override {
     auto& offset_vec = page->offset.HostVector();
     if (!fi->Read(&offset_vec)) return false;
     auto& data_vec = page->data.HostVector();
@@ -29,7 +30,7 @@ class SparsePageRawFormat : public SparsePageFormat {
     return true;
   }

-  bool Read(SparsePage* page,
+  bool Read(T* page,
             dmlc::SeekStream* fi,
             const std::vector<bst_uint>& sorted_index_set) override {
     if (!fi->Read(&disk_offset_)) return false;
@@ -79,7 +80,7 @@ class SparsePageRawFormat : public SparsePageFormat {
     return true;
   }

-  void Write(const SparsePage& page, dmlc::Stream* fo) override {
+  void Write(const T& page, dmlc::Stream* fo) override {
     const auto& offset_vec = page.offset.HostVector();
     const auto& data_vec = page.data.HostVector();
     CHECK(page.offset.Size() != 0 && offset_vec[0] == 0);
@@ -98,7 +99,20 @@ class SparsePageRawFormat : public SparsePageFormat {
 XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
 .describe("Raw binary data format.")
 .set_body([]() {
-    return new SparsePageRawFormat();
+    return new SparsePageRawFormat<SparsePage>();
   });

+XGBOOST_REGISTER_CSC_PAGE_FORMAT(raw)
+.describe("Raw binary data format.")
+.set_body([]() {
+    return new SparsePageRawFormat<CSCPage>();
+  });
+
+XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(raw)
+.describe("Raw binary data format.")
+.set_body([]() {
+    return new SparsePageRawFormat<SortedCSCPage>();
+  });
+
 }  // namespace data
 }  // namespace xgboost
@@ -46,6 +46,47 @@ GetCacheShards(const std::string& cache_info) {

 namespace xgboost {
 namespace data {

+/*!
+ * \brief decide the format from cache prefix.
+ * \return pair of row format, column format type of the cache prefix.
+ */
+inline std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix) {
+  size_t pos = cache_prefix.rfind(".fmt-");
+
+  if (pos != std::string::npos) {
+    std::string fmt = cache_prefix.substr(pos + 5, cache_prefix.length());
+    size_t cpos = fmt.rfind('-');
+    if (cpos != std::string::npos) {
+      return std::make_pair(fmt.substr(0, cpos), fmt.substr(cpos + 1, fmt.length()));
+    } else {
+      return std::make_pair(fmt, fmt);
+    }
+  } else {
+    std::string raw = "raw";
+    return std::make_pair(raw, raw);
+  }
+}
+
+struct CacheInfo {
+  std::string name_info;
+  std::vector<std::string> format_shards;
+  std::vector<std::string> name_shards;
+};
+
+inline CacheInfo ParseCacheInfo(const std::string& cache_info, const std::string& page_type) {
+  CacheInfo info;
+  std::vector<std::string> cache_shards = GetCacheShards(cache_info);
+  CHECK_NE(cache_shards.size(), 0U);
+  // read in the info files.
+  info.name_info = cache_shards[0];
+  for (const std::string& prefix : cache_shards) {
+    info.name_shards.push_back(prefix + page_type);
+    info.format_shards.push_back(DecideFormat(prefix).first);
+  }
+  return info;
+}
+
 /*!
  * \brief External memory data source.
  * \code
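A worked example of the DecideFormat() rules above, assuming the function is in scope (suffix values illustrative):

#include <cassert>
#include <string>
#include <utility>

int main() {
  // ".fmt-<row>-<col>" picks distinct row/column formats...
  assert(DecideFormat("train.cache.fmt-raw-lz4") ==
         std::make_pair(std::string("raw"), std::string("lz4")));
  // ...".fmt-<fmt>" uses one format for both...
  assert(DecideFormat("train.cache.fmt-raw") ==
         std::make_pair(std::string("raw"), std::string("raw")));
  // ...and no suffix falls back to "raw".
  assert(DecideFormat("train.cache") ==
         std::make_pair(std::string("raw"), std::string("raw")));
}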
@@ -72,6 +113,7 @@ class SparsePageSource : public DataSource<T> {
     std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
     int tmagic;
     CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
     CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
     this->info.LoadBinary(finfo.get());
   }
   files_.resize(cache_shards.size());
@@ -85,8 +127,8 @@ class SparsePageSource : public DataSource<T> {
     std::unique_ptr<dmlc::SeekStream>& fi = files_[i];
     std::string format;
     CHECK(fi->Read(&format)) << "Invalid page format";
-    formats_[i].reset(SparsePageFormat::Create(format));
-    std::unique_ptr<SparsePageFormat>& fmt = formats_[i];
+    formats_[i].reset(CreatePageFormat<T>(format));
+    std::unique_ptr<SparsePageFormat<T>>& fmt = formats_[i];
     size_t fbegin = fi->Tell();
     prefetchers_[i].reset(new dmlc::ThreadedIter<T>(4));
     prefetchers_[i]->Init([&fi, &fmt] (T** dptr) {
@@ -111,7 +153,7 @@ class SparsePageSource : public DataSource<T> {
       prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
     }
     if (prefetchers_[clock_ptr_]->Next(&page_)) {
-      page_->base_rowid = base_rowid_;
+      page_->SetBaseRowId(base_rowid_);
       base_rowid_ += page_->Size();
       // advance clock
       clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
@@ -149,17 +191,9 @@ class SparsePageSource : public DataSource<T> {
                         const std::string& cache_info,
                         const size_t page_size = DMatrix::kPageSize) {
     const std::string page_type = ".row.page";
-    std::vector<std::string> cache_shards = GetCacheShards(cache_info);
-    CHECK_NE(cache_shards.size(), 0U);
-    // read in the info files.
-    std::string name_info = cache_shards[0];
-    std::vector<std::string> name_shards, format_shards;
-    for (const std::string& prefix : cache_shards) {
-      name_shards.push_back(prefix + page_type);
-      format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
-    }
+    auto cinfo = ParseCacheInfo(cache_info, page_type);
     {
-      SparsePageWriter writer(name_shards, format_shards, 6);
+      SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
       std::shared_ptr<SparsePage> page;
       writer.Alloc(&page); page->Clear();

@@ -230,30 +264,19 @@ class SparsePageSource : public DataSource<T> {
       writer.PushWrite(std::move(page));
     }

-    std::unique_ptr<dmlc::Stream> fo(
-        dmlc::Stream::Create(name_info.c_str(), "w"));
+    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
     int tmagic = kMagic;
     fo->Write(&tmagic, sizeof(tmagic));
     // Either every row has query ID or none at all
     CHECK(qids.empty() || qids.size() == info.num_row_);
     info.SaveBinary(fo.get());
   }
-  LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
-            << name_info;
+  LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to " << cinfo.name_info;
 }

-  /*!
-   * \brief Create source cache by copy content from DMatrix.
-   * \param cache_info The cache_info of cache file location.
-   */
-  static void CreateRowPage(DMatrix* src,
-                            const std::string& cache_info) {
-    const std::string page_type = ".row.page";
-    CreatePageFromDMatrix(src, cache_info, page_type);
-  }
-
   /*!
-   * \brief Create source cache by copy content from DMatrix. Creates transposed column page, may be sorted or not.
+   * \brief Create source cache by copy content from DMatrix.
+   *  Creates transposed column page, may be sorted or not.
    * \param cache_info The cache_info of cache file location.
    * \param sorted Whether columns should be pre-sorted
    */
@@ -293,17 +316,9 @@ class SparsePageSource : public DataSource<T> {
   static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info,
                                     const std::string& page_type,
                                     const size_t page_size = DMatrix::kPageSize) {
-    std::vector<std::string> cache_shards = GetCacheShards(cache_info);
-    CHECK_NE(cache_shards.size(), 0U);
-    // read in the info files.
-    std::string name_info = cache_shards[0];
-    std::vector<std::string> name_shards, format_shards;
-    for (const std::string& prefix : cache_shards) {
-      name_shards.push_back(prefix + page_type);
-      format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
-    }
+    auto cinfo = ParseCacheInfo(cache_info, page_type);
     {
-      SparsePageWriter writer(name_shards, format_shards, 6);
+      SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
       std::shared_ptr<SparsePage> page;
       writer.Alloc(&page);
       page->Clear();
@@ -312,9 +327,7 @@ class SparsePageSource : public DataSource<T> {
       size_t bytes_write = 0;
       double tstart = dmlc::GetTime();
       for (auto& batch : src->GetBatches<SparsePage>()) {
-        if (page_type == ".row.page") {
-          page->Push(batch);
-        } else if (page_type == ".col.page") {
+        if (page_type == ".col.page") {
           page->PushCSC(batch.GetTranspose(src->Info().num_col_));
         } else if (page_type == ".sorted.col.page") {
           SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
@@ -338,28 +351,22 @@ class SparsePageSource : public DataSource<T> {
       if (page->data.Size() != 0) {
         writer.PushWrite(std::move(page));
       }

-      std::unique_ptr<dmlc::Stream> fo(
-          dmlc::Stream::Create(name_info.c_str(), "w"));
+      std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
       int tmagic = kMagic;
       fo->Write(&tmagic, sizeof(tmagic));
       info.SaveBinary(fo.get());
     }
-    LOG(INFO) << "SparsePageSource: Finished writing to " << name_info;
+    LOG(INFO) << "SparsePageSource: Finished writing to " << cinfo.name_info;
   }

   /*! \brief number of rows */
   size_t base_rowid_;
   /*! \brief page currently on hold. */
-  T *page_;
+  T* page_;
   /*! \brief internal clock ptr */
   size_t clock_ptr_;
   /*! \brief file pointer to the row blob file. */
-  std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
+  std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
   /*! \brief Sparse page format file. */
-  std::vector<std::unique_ptr<SparsePageFormat> > formats_;
+  std::vector<std::unique_ptr<SparsePageFormat<T>>> formats_;
   /*! \brief internal prefetcher. */
-  std::vector<std::unique_ptr<dmlc::ThreadedIter<T> > > prefetchers_;
+  std::vector<std::unique_ptr<dmlc::ThreadedIter<T>>> prefetchers_;
 };
 }  // namespace data
 }  // namespace xgboost
@@ -1,75 +0,0 @@
-/*!
- * Copyright (c) 2015 by Contributors
- * \file sparse_batch_writer.cc
- * \param Writer class sparse page.
- */
-#include <xgboost/base.h>
-#include <xgboost/logging.h>
-#include "./sparse_page_writer.h"
-
-#if DMLC_ENABLE_STD_THREAD
-namespace xgboost {
-namespace data {
-
-SparsePageWriter::SparsePageWriter(
-    const std::vector<std::string>& name_shards,
-    const std::vector<std::string>& format_shards,
-    size_t extra_buffer_capacity)
-    : num_free_buffer_(extra_buffer_capacity + name_shards.size()),
-      clock_ptr_(0),
-      workers_(name_shards.size()),
-      qworkers_(name_shards.size()) {
-  CHECK_EQ(name_shards.size(), format_shards.size());
-  // start writer threads
-  for (size_t i = 0; i < name_shards.size(); ++i) {
-    std::string name_shard = name_shards[i];
-    std::string format_shard = format_shards[i];
-    auto* wqueue = &qworkers_[i];
-    workers_[i].reset(new std::thread(
-        [this, name_shard, format_shard, wqueue] () {
-          std::unique_ptr<dmlc::Stream> fo(
-              dmlc::Stream::Create(name_shard.c_str(), "w"));
-          std::unique_ptr<SparsePageFormat> fmt(
-              SparsePageFormat::Create(format_shard));
-          fo->Write(format_shard);
-          std::shared_ptr<SparsePage> page;
-          while (wqueue->Pop(&page)) {
-            if (page == nullptr) break;
-            fmt->Write(*page, fo.get());
-            qrecycle_.Push(std::move(page));
-          }
-          fo.reset(nullptr);
-          LOG(INFO) << "SparsePage::Writer Finished writing to " << name_shard;
-        }));
-  }
-}
-
-SparsePageWriter::~SparsePageWriter() {
-  for (auto& queue : qworkers_) {
-    // use nullptr to signal termination.
-    std::shared_ptr<SparsePage> sig(nullptr);
-    queue.Push(std::move(sig));
-  }
-  for (auto& thread : workers_) {
-    thread->join();
-  }
-}
-
-void SparsePageWriter::PushWrite(std::shared_ptr<SparsePage>&& page) {
-  qworkers_[clock_ptr_].Push(std::move(page));
-  clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
-}
-
-void SparsePageWriter::Alloc(std::shared_ptr<SparsePage>* out_page) {
-  CHECK(*out_page == nullptr);
-  if (num_free_buffer_ != 0) {
-    out_page->reset(new SparsePage());
-    --num_free_buffer_;
-  } else {
-    CHECK(qrecycle_.Pop(out_page));
-  }
-}
-}  // namespace data
-}  // namespace xgboost
-
-#endif  // DMLC_ENABLE_STD_THREAD
@ -23,9 +23,14 @@

namespace xgboost {
namespace data {

template<typename T>
struct SparsePageFormatReg;

/*!
 * \brief Format specification of SparsePage.
 */
template<typename T>
class SparsePageFormat {
 public:
  /*! \brief virtual destructor */
@ -36,7 +41,8 @@ class SparsePageFormat {
   * \param fi the input stream of the file
   * \return true if the loading was successful, false if end of file was reached
   */
  virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0;
  virtual bool Read(T* page, dmlc::SeekStream* fi) = 0;

  /*!
   * \brief read only the segments we are interested in, advance fi to end of the block.
   * \param page The page to load the data into.
@ -44,30 +50,35 @@ class SparsePageFormat {
   * \param sorted_index_set sorted index of segments we are interested in
   * \return true if the loading was successful, false if end of file was reached
   */
  virtual bool Read(SparsePage* page,
  virtual bool Read(T* page,
                    dmlc::SeekStream* fi,
                    const std::vector<bst_uint>& sorted_index_set) = 0;
  /*!
   * \brief save the data to fo, when a page was written.
   * \param fo output stream
   */
  virtual void Write(const SparsePage& page, dmlc::Stream* fo) = 0;
  /*!
   * \brief Create sparse page of format.
   * \return The created format functors.
   */
  static SparsePageFormat* Create(const std::string& name);
  /*!
   * \brief decide the format from cache prefix.
   * \return pair of row format, column format type of the cache prefix.
   */
  static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
  virtual void Write(const T& page, dmlc::Stream* fo) = 0;
};

/*!
 * \brief Create sparse page of format.
 * \return The created format functors.
 */
template<typename T>
inline SparsePageFormat<T>* CreatePageFormat(const std::string& name) {
  auto *e = ::dmlc::Registry<SparsePageFormatReg<T>>::Get()->Find(name);
  if (e == nullptr) {
    LOG(FATAL) << "Unknown format type " << name;
  }
  return (e->body)();
}

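A brief usage sketch of the new registry lookup. The format name "raw" is a placeholder; actual names depend on what the individual translation units register:

// Sketch: obtain a format implementation for SparsePage by registered name.
std::unique_ptr<SparsePageFormat<SparsePage>> fmt(
    CreatePageFormat<SparsePage>("raw"));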
#if DMLC_ENABLE_STD_THREAD
/*!
 * \brief A threaded writer to write sparse batch page to sharded files.
 * @tparam T Type of the page.
 */
template<typename T>
class SparsePageWriter {
 public:
  /*!
@ -76,26 +87,74 @@ class SparsePageWriter {
   * \param format_shards format of each shard.
   * \param extra_buffer_capacity Extra buffer capacity before block.
   */
  explicit SparsePageWriter(
      const std::vector<std::string>& name_shards,
      const std::vector<std::string>& format_shards,
      size_t extra_buffer_capacity);
  explicit SparsePageWriter(const std::vector<std::string>& name_shards,
                            const std::vector<std::string>& format_shards,
                            size_t extra_buffer_capacity)
      : num_free_buffer_(extra_buffer_capacity + name_shards.size()),
        clock_ptr_(0),
        workers_(name_shards.size()),
        qworkers_(name_shards.size()) {
    CHECK_EQ(name_shards.size(), format_shards.size());
    // start writer threads
    for (size_t i = 0; i < name_shards.size(); ++i) {
      std::string name_shard = name_shards[i];
      std::string format_shard = format_shards[i];
      auto* wqueue = &qworkers_[i];
      workers_[i].reset(new std::thread(
          [this, name_shard, format_shard, wqueue]() {
            std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(name_shard.c_str(), "w"));
            std::unique_ptr<SparsePageFormat<T>> fmt(CreatePageFormat<T>(format_shard));
            fo->Write(format_shard);
            std::shared_ptr<T> page;
            while (wqueue->Pop(&page)) {
              if (page == nullptr) break;
              fmt->Write(*page, fo.get());
              qrecycle_.Push(std::move(page));
            }
            fo.reset(nullptr);
            LOG(INFO) << "SparsePageWriter Finished writing to " << name_shard;
          }));
    }
  }

  /*! \brief destructor, will close the files automatically */
  ~SparsePageWriter();
  ~SparsePageWriter() {
    for (auto& queue : qworkers_) {
      // use nullptr to signal termination.
      std::shared_ptr<T> sig(nullptr);
      queue.Push(std::move(sig));
    }
    for (auto& thread : workers_) {
      thread->join();
    }
  }

  /*!
   * \brief Push a write job to the writer.
   * This function won't block,
   * writing is done by another thread inside writer.
   * \param page The page to be written
   */
  void PushWrite(std::shared_ptr<SparsePage>&& page);
  void PushWrite(std::shared_ptr<T>&& page) {
    qworkers_[clock_ptr_].Push(std::move(page));
    clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
  }

  /*!
   * \brief Allocate a page to store results.
   * This function can block when the writer is too slow and buffer pages
   * have not yet been recycled.
   * \param out_page Used to store the allocated pages.
   */
  void Alloc(std::shared_ptr<SparsePage>* out_page);
  void Alloc(std::shared_ptr<T>* out_page) {
    CHECK(*out_page == nullptr);
    if (num_free_buffer_ != 0) {
      out_page->reset(new T());
      --num_free_buffer_;
    } else {
      CHECK(qrecycle_.Pop(out_page));
    }
  }

 private:
  /*! \brief number of allocated pages */
@ -103,20 +162,21 @@ class SparsePageWriter {
  /*! \brief clock_pointer */
  size_t clock_ptr_;
  /*! \brief writer threads */
  std::vector<std::unique_ptr<std::thread> > workers_;
  std::vector<std::unique_ptr<std::thread>> workers_;
  /*! \brief recycler queue */
  dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > qrecycle_;
  dmlc::ConcurrentBlockingQueue<std::shared_ptr<T>> qrecycle_;
  /*! \brief worker threads */
  std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > > qworkers_;
  std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<T>>> qworkers_;
};
#endif  // DMLC_ENABLE_STD_THREAD

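To make the writer's buffer-recycling protocol concrete, here is a minimal usage sketch mirroring the CreatePageFromDMatrix loop earlier in this commit. The shard name and format string are placeholders:

// Minimal sketch of the Alloc -> fill -> PushWrite protocol; "raw" is a placeholder format.
std::vector<std::string> name_shards{"train.cache.row.page"};
std::vector<std::string> format_shards{"raw"};
SparsePageWriter<SparsePage> writer(name_shards, format_shards, 6);

std::shared_ptr<SparsePage> page;
writer.Alloc(&page);  // may block until a buffer page is recycled
page->Clear();
// ... fill *page with batches ...
if (page->data.Size() != 0) {
  writer.PushWrite(std::move(page));  // non-blocking; a worker thread does the I/O
}
// The destructor pushes nullptr sentinels and joins the worker threads.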
/*!
 * \brief Registry entry for sparse page format.
 */
template<typename T>
struct SparsePageFormatReg
    : public dmlc::FunctionRegEntryBase<SparsePageFormatReg,
                                        std::function<SparsePageFormat* ()> > {
    : public dmlc::FunctionRegEntryBase<SparsePageFormatReg<T>,
                                        std::function<SparsePageFormat<T>* ()>> {
};

/*!
@ -131,8 +191,21 @@ struct SparsePageFormatReg
 * });
 * \endcode
 */
#define SparsePageFmt SparsePageFormat<SparsePage>
#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \
  DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name)
  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<SparsePage>, SparsePageFmt, Name)

#define CSCPageFmt SparsePageFormat<CSCPage>
#define XGBOOST_REGISTER_CSC_PAGE_FORMAT(Name) \
  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<CSCPage>, CSCPageFmt, Name)

#define SortedCSCPageFmt SparsePageFormat<SortedCSCPage>
#define XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(Name) \
  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<SortedCSCPage>, SortedCSCPageFmt, Name)

#define EllpackPageFmt SparsePageFormat<EllpackPage>
#define XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(Name) \
  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<EllpackPage>, EllpackPageFmt, Name)

}  // namespace data
}  // namespace xgboost

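A hedged registration sketch using the corrected ELLPACK macro, following the standard dmlc pattern shown in the \code example of the header. The name "raw" and the EllpackPageRawFormat class are hypothetical, standing in for a real SparsePageFormat<EllpackPage> implementation:

// Hypothetical registration site; EllpackPageRawFormat is an assumed implementation.
XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw)
.describe("Raw binary ELLPACK page format.")
.set_body([]() {
  return new EllpackPageRawFormat();
});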
@ -174,16 +174,15 @@ template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
          typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
    int fidx, common::Span<const GradientSumT> node_histogram,
    const xgboost::ELLPackMatrix& matrix,
    const xgboost::EllpackMatrix& matrix,
    DeviceSplitCandidate* best_split,  // shared memory storing best split
    const DeviceNodeStats& node, const GPUTrainingParam& param,
    TempStorageT* temp_storage,  // temp memory for cub operations
    int constraint,  // monotonic_constraints
    const ValueConstraint& value_constraint) {
  // Use pointer from cut to indicate begin and end of bins for each feature.
  uint32_t gidx_begin = matrix.feature_segments[fidx];  // beginning bin
  uint32_t gidx_end =
      matrix.feature_segments[fidx + 1];  // end bin for i^th feature
  uint32_t gidx_begin = matrix.info.feature_segments[fidx];  // beginning bin
  uint32_t gidx_end = matrix.info.feature_segments[fidx + 1];  // end bin for i^th feature

  // Sum histogram bins for current feature
  GradientSumT const feature_sum = ReduceFeature<BLOCK_THREADS, ReduceT>(
@ -231,9 +230,9 @@ __device__ void EvaluateFeature(
  int split_gidx = (scan_begin + threadIdx.x) - 1;
  float fvalue;
  if (split_gidx < static_cast<int>(gidx_begin)) {
    fvalue = matrix.min_fvalue[fidx];
    fvalue = matrix.info.min_fvalue[fidx];
  } else {
    fvalue = matrix.gidx_fvalue_map[split_gidx];
    fvalue = matrix.info.gidx_fvalue_map[split_gidx];
  }
  GradientSumT left = missing_left ? bin + missing : bin;
  GradientSumT right = parent_sum - left;
@ -249,7 +248,7 @@ __global__ void EvaluateSplitKernel(
    common::Span<const GradientSumT> node_histogram,  // histogram for gradients
    common::Span<const int> feature_set,  // Selected features
    DeviceNodeStats node,
    xgboost::ELLPackMatrix matrix,
    xgboost::EllpackMatrix matrix,
    GPUTrainingParam gpu_param,
    common::Span<DeviceSplitCandidate> split_candidates,  // resulting split
    ValueConstraint value_constraint,
@ -401,7 +400,7 @@ struct CalcWeightTrainParam {
};

template <typename GradientSumT>
__global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
                                    common::Span<const RowPartitioner::RowIndexT> d_ridx,
                                    GradientSumT* d_node_hist,
                                    const GradientPair* d_gpair, size_t n_elements,
@ -413,10 +412,10 @@ __global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
    __syncthreads();
  }
  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
    int ridx = d_ridx[idx / matrix.row_stride];
    int ridx = d_ridx[idx / matrix.info.row_stride];
    int gidx =
        matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
    if (gidx != matrix.null_gidx_value) {
        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
    if (gidx != matrix.info.n_bins) {
      // If we are not using shared memory, accumulate the values directly into
      // global memory
      GradientSumT* atomic_add_ptr =
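The histogram kernel above relies on the fixed-stride ELLPACK layout: each row occupies exactly row_stride slots in gidx_iter, and short rows are padded with a null value (after this commit, matrix.info.n_bins, which no real bin uses). For intuition, a CPU-side analogue of the same indexing; this is illustrative only, and n_rows, row_stride, gidx_iter, histogram, and gpair are assumed to be in scope:

// Illustrative CPU-side analogue of the kernel's ELLPACK indexing (not library code).
for (size_t idx = 0; idx < n_rows * row_stride; ++idx) {
  size_t row = idx / row_stride;   // owning row of this slot
  uint32_t gidx = gidx_iter[idx];  // compressed bin index
  if (gidx != n_bins) {            // skip padding entries
    histogram[gidx] += gpair[row]; // accumulate the row's gradient pair into its bin
  }
}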
@ -606,7 +605,7 @@ struct GPUHistMakerDevice {
    int constexpr kBlockThreads = 256;
    EvaluateSplitKernel<kBlockThreads, GradientSumT>
        <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
            hist.GetNodeHistogram(nidx), d_feature_set, node, page->ellpack_matrix,
            hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix,
            gpu_param, d_split_candidates, node_value_constraints[nidx],
            monotone_constraints);

@ -632,11 +631,11 @@ struct GPUHistMakerDevice {
    auto d_ridx = row_partitioner->GetRows(nidx);
    auto d_gpair = gpair.data();

    auto n_elements = d_ridx.size() * page->ellpack_matrix.row_stride;
    auto n_elements = d_ridx.size() * page->matrix.info.row_stride;

    const size_t smem_size =
        use_shared_memory_histograms
            ? sizeof(GradientSumT) * page->ellpack_matrix.BinCount()
            ? sizeof(GradientSumT) * page->matrix.BinCount()
            : 0;
    const int items_per_thread = 8;
    const int block_threads = 256;
@ -646,7 +645,7 @@ struct GPUHistMakerDevice {
      return;
    }
    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
        page->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
        page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
        use_shared_memory_histograms);
  }

@ -656,7 +655,7 @@ struct GPUHistMakerDevice {
    auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
    auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);

    dh::LaunchN(device_id, page->n_bins, [=] __device__(size_t idx) {
    dh::LaunchN(device_id, page->matrix.info.n_bins, [=] __device__(size_t idx) {
      d_node_hist_subtraction[idx] =
          d_node_hist_parent[idx] - d_node_hist_histogram[idx];
    });
@ -671,7 +670,7 @@ struct GPUHistMakerDevice {
  }

  void UpdatePosition(int nidx, RegTree::Node split_node) {
    auto d_matrix = page->ellpack_matrix;
    auto d_matrix = page->matrix;

    row_partitioner->UpdatePosition(
        nidx, split_node.LeftChild(), split_node.RightChild(),
@ -703,7 +702,7 @@ struct GPUHistMakerDevice {
    dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
                             d_nodes.size() * sizeof(RegTree::Node),
                             cudaMemcpyHostToDevice));
    auto d_matrix = page->ellpack_matrix;
    auto d_matrix = page->matrix;
    row_partitioner->FinalisePosition(
        [=] __device__(bst_uint ridx, int position) {
          auto node = d_nodes[position];
@ -766,8 +765,7 @@ struct GPUHistMakerDevice {
    reducer->AllReduceSum(
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
        page->ellpack_matrix.BinCount() *
            (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
        page->matrix.BinCount() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
    reducer->Synchronize();

    monitor.StopCuda("AllReduce");
@ -956,14 +954,14 @@ inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
  // check if we can use shared memory for building histograms
  // (assuming at least we need 2 CTAs per SM to maintain decent latency
  // hiding)
  auto histogram_size = sizeof(GradientSumT) * page->n_bins;
  auto histogram_size = sizeof(GradientSumT) * page->matrix.info.n_bins;
  auto max_smem = dh::MaxSharedMemory(device_id);
  if (histogram_size <= max_smem) {
    use_shared_memory_histograms = true;
  }

  // Init histogram
  hist.Init(device_id, page->n_bins);
  hist.Init(device_id, page->matrix.info.n_bins);
}

template <typename GradientSumT>
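A worked example of the feasibility check, under stated assumptions: if GradientSumT is a pair of doubles (16 bytes per bin) and the matrix has, say, 24 features at 256 bins each (6,144 bins), histogram_size is about 96 KiB, which exceeds the roughly 48 KiB of shared memory per block typical of many GPUs, so the kernel would fall back to global-memory atomics; a much smaller matrix, e.g. 256 bins total (4 KiB), fits comfortably and enables the shared-memory path.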
@ -1017,22 +1015,23 @@ class GPUHistMakerSpecialised {

    // TODO(rongou): support multiple Ellpack pages.
    EllpackPageImpl* page{};
    for (auto& batch : dmat->GetBatches<EllpackPage>()) {
    for (auto& batch : dmat->GetBatches<EllpackPage>({device_,
                                                      param_.max_bin,
                                                      hist_maker_param_.gpu_batch_nrows})) {
      page = batch.Impl();
      page->Init(device_, param_.max_bin, hist_maker_param_.gpu_batch_nrows);
    }

    dh::safe_cuda(cudaSetDevice(device_));
    maker_.reset(new GPUHistMakerDevice<GradientSumT>(device_,
                                                      page,
                                                      info_->num_row_,
                                                      param_,
                                                      column_sampling_seed,
                                                      info_->num_col_));
    maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
                                                     page,
                                                     info_->num_row_,
                                                     param_,
                                                     column_sampling_seed,
                                                     info_->num_col_));

    monitor_.StartCuda("InitHistogram");
    dh::safe_cuda(cudaSetDevice(device_));
    maker_->InitHistogram();
    maker->InitHistogram();
    monitor_.StopCuda("InitHistogram");

    p_last_fmat_ = dmat;
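The brace-initialized argument above is the new batch parameter triple (GPU device, maximum bins per feature, GPU batch rows), which replaces the separate page->Init() call. A hedged sketch of requesting ELLPACK batches outside the updater, where dmat is assumed to be an existing DMatrix* and the values are illustrative:

// Sketch: request ELLPACK pages with explicit batch parameters {gpu_id, max_bin, gpu_batch_nrows}.
for (auto& batch : dmat->GetBatches<EllpackPage>({0, 256, 64})) {
  const EllpackPageImpl* impl = batch.Impl();  // device-side representation
  // ... use impl->matrix ...
}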
@ -1071,17 +1070,17 @@ class GPUHistMakerSpecialised {
    monitor_.StopCuda("InitData");

    gpair->SetDevice(device_);
    maker_->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
    maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
  }

  bool UpdatePredictionCache(
      const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
    if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
    if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
      return false;
    }
    monitor_.StartCuda("UpdatePredictionCache");
    p_out_preds->SetDevice(device_);
    maker_->UpdatePredictionCache(p_out_preds->DevicePointer());
    maker->UpdatePredictionCache(p_out_preds->DevicePointer());
    monitor_.StopCuda("UpdatePredictionCache");
    return true;
  }
@ -1089,7 +1088,7 @@ class GPUHistMakerSpecialised {
  TrainParam param_;   // NOLINT
  MetaInfo* info_{};   // NOLINT

  std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker_;  // NOLINT
  std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker;  // NOLINT

 private:
  bool initialised_;

@ -17,15 +17,13 @@ TEST(EllpackPage, EmptyDMatrix) {
  constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256, kGpuBatchNRows = 64;
  constexpr float kSparsity = 0;
  auto dmat = *CreateDMatrix(kNRows, kNCols, kSparsity);
  auto& page = *dmat->GetBatches<EllpackPage>().begin();
  auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin, kGpuBatchNRows}).begin();
  auto impl = page.Impl();
  impl->Init(0, kMaxBin, kGpuBatchNRows);
  ASSERT_EQ(impl->ellpack_matrix.feature_segments.size(), 1);
  ASSERT_EQ(impl->ellpack_matrix.min_fvalue.size(), 0);
  ASSERT_EQ(impl->ellpack_matrix.gidx_fvalue_map.size(), 0);
  ASSERT_EQ(impl->ellpack_matrix.row_stride, 0);
  ASSERT_EQ(impl->ellpack_matrix.null_gidx_value, 0);
  ASSERT_EQ(impl->n_bins, 0);
  ASSERT_EQ(impl->matrix.info.feature_segments.size(), 1);
  ASSERT_EQ(impl->matrix.info.min_fvalue.size(), 0);
  ASSERT_EQ(impl->matrix.info.gidx_fvalue_map.size(), 0);
  ASSERT_EQ(impl->matrix.info.row_stride, 0);
  ASSERT_EQ(impl->matrix.info.n_bins, 0);
  ASSERT_EQ(impl->gidx_buffer.size(), 4);
}

@ -37,7 +35,7 @@ TEST(EllpackPage, BuildGidxDense) {
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

  ASSERT_EQ(page->ellpack_matrix.row_stride, kNCols);
  ASSERT_EQ(page->matrix.info.row_stride, kNCols);

  std::vector<uint32_t> solution = {
    0, 3, 8, 9, 14, 17, 20, 21,
@ -70,7 +68,7 @@ TEST(EllpackPage, BuildGidxSparse) {
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

  ASSERT_LE(page->ellpack_matrix.row_stride, 3);
  ASSERT_LE(page->matrix.info.row_stride, 3);

  // row_stride = 3, 16 rows, 48 entries for ELLPACK
  std::vector<uint32_t> solution = {
@ -78,7 +76,7 @@ TEST(EllpackPage, BuildGidxSparse) {
    24, 24, 24, 24, 24, 5, 24, 24, 0, 16, 24, 15, 24, 24, 24, 24,
    24, 7, 14, 16, 4, 24, 24, 24, 24, 24, 9, 24, 24, 1, 24, 24
  };
  for (size_t i = 0; i < kNRows * page->ellpack_matrix.row_stride; ++i) {
  for (size_t i = 0; i < kNRows * page->matrix.info.row_stride; ++i) {
    ASSERT_EQ(solution[i], gidx[i]);
  }
}

tests/cpp/data/test_sparse_page_dmatrix.cu (new file, 26 lines)
@ -0,0 +1,26 @@
// Copyright by Contributors

#include <dmlc/filesystem.h>
#include "../helpers.h"

namespace xgboost {

TEST(GPUSparsePageDMatrix, EllpackPage) {
  dmlc::TemporaryDirectory tempdir;
  const std::string tmp_file = tempdir.path + "/simple.libsvm";
  CreateSimpleTestData(tmp_file);
  DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);

  // Loop over the batches and assert the data is as expected
  for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256, 64})) {
    EXPECT_EQ(batch.Size(), dmat->Info().num_row_);
  }

  EXPECT_TRUE(FileExists(tmp_file + ".cache"));
  EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
  EXPECT_TRUE(FileExists(tmp_file + ".cache.ellpack.page"));

  delete dmat;
}

}  // namespace xgboost
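For orientation, the on-disk layout this test exercises, derived from its assertions: loading a URI of the form data.libsvm#data.libsvm.cache produces one metadata file plus one file per cached page type:

simple.libsvm.cache               // magic number + MetaInfo
simple.libsvm.cache.row.page      // CSR row pages
simple.libsvm.cache.ellpack.page  // ELLPACK pages (new in this commit)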
@ -192,14 +192,14 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
  return dmat;
}

std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
                                                       size_t page_size, bool deterministic) {
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
    size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
    const dmlc::TemporaryDirectory& tempdir) {
  if (!n_rows || !n_cols) {
    return nullptr;
  }

  // Create the svm file in a temp dir
  dmlc::TemporaryDirectory tempdir;
  const std::string tmp_file = tempdir.path + "/big.libsvm";

  std::ofstream fo(tmp_file.c_str());

@ -14,6 +14,7 @@

#include <gtest/gtest.h>

#include <dmlc/filesystem.h>
#include <xgboost/base.h>
#include <xgboost/objective.h>
#include <xgboost/metric.h>
@ -199,8 +200,9 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
 *
 * \return The new dmatrix.
 */
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
                                                       size_t page_size, bool deterministic);
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
    size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
    const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());

gbm::GBTreeModel CreateTestModel();

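A hedged usage sketch of the new overload, with illustrative sizes (mirroring TestHistogramIndexImpl further below):

// Caller-owned temporary directory; the backing cache files live until it is destroyed.
dmlc::TemporaryDirectory tempdir;
std::unique_ptr<DMatrix> dmat(
    CreateSparsePageDMatrixWithRC(2000, 8, 128UL, true, tempdir));

Passing the directory explicitly matters: with the defaulted argument, the temporary dmlc::TemporaryDirectory is destroyed as soon as the call returns, while the returned external-memory DMatrix may still reference its files on disk.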
@ -247,16 +249,15 @@ inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
                 0.26f, 0.71f, 1.83f});
  cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});

  auto is_dense = (*dmat)->Info().num_nonzero_ ==
                  (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
  size_t row_stride = 0;
  const auto &offset_vec = batch.offset.ConstHostVector();
  for (size_t i = 1; i < offset_vec.size(); ++i) {
    row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
  }

  auto page = std::unique_ptr<EllpackPageImpl>(new EllpackPageImpl(dmat->get()));
  page->InitCompressedData(0, cmat, row_stride, is_dense);
  auto page = std::unique_ptr<EllpackPageImpl>(new EllpackPageImpl(dmat->get(), {0, 256, 0}));
  page->InitInfo(0, (*dmat)->IsDense(), row_stride, cmat);
  page->InitCompressedData(0, n_rows);
  page->CreateHistIndices(0, batch, RowStateOnDevice(batch.Size(), batch.Size()));

  delete dmat;

@ -2,6 +2,7 @@
 * Copyright 2017-2019 XGBoost contributors
 */
#include <thrust/device_vector.h>
#include <dmlc/filesystem.h>
#include <xgboost/base.h>
#include <random>
#include <string>
@ -207,14 +208,14 @@ TEST(GpuHist, EvaluateSplits) {

  // Copy cut matrix to device.
  maker.ba.Allocate(0,
                    &(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(),
                    &(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(),
                    &(page->ellpack_matrix.gidx_fvalue_map), 24,
                    &(page->matrix.info.feature_segments), cmat.Ptrs().size(),
                    &(page->matrix.info.min_fvalue), cmat.MinValues().size(),
                    &(page->matrix.info.gidx_fvalue_map), 24,
                    &(maker.monotone_constraints), kNCols);
  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.feature_segments, cmat.Ptrs());
  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.gidx_fvalue_map, cmat.Values());
  dh::CopyVectorToDeviceSpan(page->matrix.info.feature_segments, cmat.Ptrs());
  dh::CopyVectorToDeviceSpan(page->matrix.info.gidx_fvalue_map, cmat.Values());
  dh::CopyVectorToDeviceSpan(maker.monotone_constraints, param.monotone_constraints);
  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.min_fvalue, cmat.MinValues());
  dh::CopyVectorToDeviceSpan(page->matrix.info.min_fvalue, cmat.MinValues());

  // Initialize GPUHistMakerDevice::hist
  maker.hist.Init(0, (max_bins - 1) * kNCols);
@ -265,8 +266,10 @@ void TestHistogramIndexImpl() {
  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
  std::unique_ptr<DMatrix> hist_maker_dmat(
      CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));

  dmlc::TemporaryDirectory tempdir;
  std::unique_ptr<DMatrix> hist_maker_ext_dmat(
      CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true));
      CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir));

  std::vector<std::pair<std::string, std::string>> training_params = {
      {"max_depth", "10"},
@ -275,22 +278,21 @@ void TestHistogramIndexImpl() {

  GenericParameter generic_param(CreateEmptyGenericParam(0));
  hist_maker.Configure(training_params, &generic_param);

  hist_maker.InitDataOnce(hist_maker_dmat.get());
  hist_maker_ext.Configure(training_params, &generic_param);
  hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());

  // Extract the device maker from the histogram makers and from that its compressed
  // histogram index
  const auto &maker = hist_maker.maker_;
  const auto &maker = hist_maker.maker;
  std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, maker->page->gidx_buffer);

  const auto &maker_ext = hist_maker_ext.maker_;
  const auto &maker_ext = hist_maker_ext.maker;
  std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, maker_ext->page->gidx_buffer);

  ASSERT_EQ(maker->page->n_bins, maker_ext->page->n_bins);
  ASSERT_EQ(maker->page->matrix.info.n_bins, maker_ext->page->matrix.info.n_bins);
  ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());

  ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);