diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index 2c528d13a..6dea48d5e 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -40,7 +40,6 @@
 #if DMLC_ENABLE_STD_THREAD
 #include "../src/data/sparse_page_dmatrix.cc"
-#include "../src/data/sparse_page_writer.cc"
 #endif
 
 // trees
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 44d9ea841..7fa83472b 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -156,6 +156,18 @@ struct Entry {
   }
 };
 
+/*!
+ * \brief Parameters for constructing batches.
+ */
+struct BatchParam {
+  /*! \brief The GPU device to use. */
+  int gpu_id;
+  /*! \brief Maximum number of bins per feature for histograms. */
+  int max_bin;
+  /*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
+  int gpu_batch_nrows;
+};
+
 /*!
  * \brief In-memory storage unit of sparse batch, stored in CSR format.
  */
@@ -191,14 +203,17 @@ class SparsePage {
   SparsePage() {
     this->Clear();
   }
-  /*! \return number of instance in the page */
+
+  /*! \return Number of instances in the page. */
   inline size_t Size() const {
     return offset.Size() - 1;
   }
+
   /*! \return estimation of memory cost of this page */
   inline size_t MemCostBytes() const {
     return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
   }
+
   /*! \brief clear the page */
   inline void Clear() {
     base_rowid = 0;
@@ -208,6 +223,11 @@ class SparsePage {
     data.HostVector().clear();
   }
 
+  /*! \brief Set the base row id for this page. */
+  inline void SetBaseRowId(size_t row_id) {
+    base_rowid = row_id;
+  }
+
   SparsePage GetTranspose(int num_columns) const;
 
   void SortRows() {
@@ -238,13 +258,6 @@ class SparsePage {
    * \param batch The row batch to be pushed
    */
   void PushCSC(const SparsePage& batch);
-  /*!
-   * \brief Push one instance into page
-   * \param inst an instance row
-   */
-  void Push(const Inst &inst);
-
-  size_t Size() { return offset.Size() - 1; }
 };
 
 class CSCPage: public SparsePage {
@@ -268,9 +281,31 @@ class EllpackPageImpl;
  */
 class EllpackPage {
  public:
-  explicit EllpackPage(DMatrix* dmat);
+  /*!
+   * \brief Default constructor.
+   *
+   * This is used in the external memory case. An empty ELLPACK page is constructed with its
+   * content set later by the reader.
+   */
+  EllpackPage();
+
+  /*!
+   * \brief Constructor from an existing DMatrix.
+   *
+   * This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
+   * in CSR format.
+   */
+  explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
+
+  /*! \brief Destructor. */
   ~EllpackPage();
 
+  /*! \return Number of instances in the page. */
+  size_t Size() const;
+
+  /*! \brief Set the base row id for this page. */
+  void SetBaseRowId(size_t row_id);
+
   const EllpackPageImpl* Impl() const { return impl_.get(); }
   EllpackPageImpl* Impl() { return impl_.get(); }
 
@@ -356,7 +391,8 @@ class DataSource : public dmlc::DataIter<T> {
 * There are two ways to create a customized DMatrix that reads in user defined-format.
 *
 * - Provide a dmlc::Parser and pass into the DMatrix::Create
-* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER;
+* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by
+*   DMLC_REGISTER_DATA_PARSER;
 *   - This works best for user defined data input source, such as data-base, filesystem.
 * - Provide a DataSource, that can be passed to DMatrix::Create
 *   This can be used to re-use inmemory data structure into DMatrix.
@@ -373,7 +409,7 @@ class DMatrix {
   /*!
    * \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
    */
   template <typename T>
-  BatchSet<T> GetBatches();
+  BatchSet<T> GetBatches(const BatchParam& param = {});
   // the following are column meta data, should be able to answer them fast.
   /*! \return Whether the data columns single column block. */
   virtual bool SingleColBlock() const = 0;
@@ -389,6 +425,12 @@ class DMatrix {
    * \return The created DMatrix.
    */
   virtual void SaveToLocalFile(const std::string& fname);
+
+  /*! \brief Whether the matrix is dense. */
+  bool IsDense() const {
+    return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
+  }
+
   /*!
    * \brief Load DMatrix from URI.
    * \param uri The URI of input.
@@ -438,27 +480,27 @@ class DMatrix {
   virtual BatchSet<SparsePage> GetRowBatches() = 0;
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
   virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
-  virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
+  virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
 };
 
 template<>
-inline BatchSet<SparsePage> DMatrix::GetBatches() {
+inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
   return GetRowBatches();
 }
 
 template<>
-inline BatchSet<CSCPage> DMatrix::GetBatches() {
+inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
   return GetColumnBatches();
 }
 
 template<>
-inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
+inline BatchSet<SortedCSCPage> DMatrix::GetBatches(const BatchParam&) {
   return GetSortedColumnBatches();
 }
 
 template<>
-inline BatchSet<EllpackPage> DMatrix::GetBatches() {
-  return GetEllpackBatches();
+inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
+  return GetEllpackBatches(param);
 }
 }  // namespace xgboost
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index e5bc74359..1e87c9273 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -540,16 +540,21 @@ class BulkAllocator {
   }
 
  public:
-  BulkAllocator() = default;
+  BulkAllocator() = default;
+
+  // Prevent accidental copying, moving or assignment of this object.
   BulkAllocator(const BulkAllocator&) = delete;
   BulkAllocator(BulkAllocator&&) = delete;
   void operator=(const BulkAllocator&) = delete;
   void operator=(BulkAllocator&&) = delete;
 
-  ~BulkAllocator() {
-    for (size_t i = 0; i < d_ptr_.size(); i++) {
-      if (!(d_ptr_[i] == nullptr)) {
+  /*!
+   * \brief Clear the bulk allocator.
+   *
+   * This frees the GPU memory managed by this allocator.
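+   *
+   * It is also called from the destructor, so the memory is released automatically when the
+   * allocator goes out of scope.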
+   */
+  void Clear() {
+    for (size_t i = 0; i < d_ptr_.size(); i++) {  // NOLINT(modernize-loop-convert)
+      if (d_ptr_[i] != nullptr) {
         safe_cuda(cudaSetDevice(device_idx_[i]));
         XGBDeviceAllocator<char> allocator;
         allocator.deallocate(thrust::device_ptr<char>(d_ptr_[i]), size_[i]);
@@ -558,6 +563,10 @@ class BulkAllocator {
     }
   }
 
+  ~BulkAllocator() {
+    Clear();
+  }
+
   // returns sum of bytes for all allocations
   size_t Size() {
     return std::accumulate(size_.begin(), size_.end(), static_cast<size_t>(0));
diff --git a/src/data/data.cc b/src/data/data.cc
index 72cc67383..0ca76c315 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -21,7 +21,10 @@
 #endif  // DMLC_ENABLE_STD_THREAD
 
 namespace dmlc {
-DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::CSCPage>);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPage>);
+DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
 }  // namespace dmlc
 
 namespace xgboost {
@@ -329,31 +332,6 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
 }  // namespace xgboost
 
 namespace xgboost {
-data::SparsePageFormat* data::SparsePageFormat::Create(const std::string& name) {
-  auto *e = ::dmlc::Registry< ::xgboost::data::SparsePageFormatReg>::Get()->Find(name);
-  if (e == nullptr) {
-    LOG(FATAL) << "Unknown format type " << name;
-  }
-  return (e->body)();
-}
-
-std::pair<std::string, std::string>
-data::SparsePageFormat::DecideFormat(const std::string& cache_prefix) {
-  size_t pos = cache_prefix.rfind(".fmt-");
-
-  if (pos != std::string::npos) {
-    std::string fmt = cache_prefix.substr(pos + 5, cache_prefix.length());
-    size_t cpos = fmt.rfind('-');
-    if (cpos != std::string::npos) {
-      return std::make_pair(fmt.substr(0, cpos), fmt.substr(cpos + 1, fmt.length()));
-    } else {
-      return std::make_pair(fmt, fmt);
-    }
-  } else {
-    std::string raw = "raw";
-    return std::make_pair(raw, raw);
-  }
-}
 
 SparsePage SparsePage::GetTranspose(int num_columns) const {
   SparsePage transpose;
   common::ParallelGroupBuilder<Entry> builder(&transpose.offset.HostVector(),
@@ -476,18 +454,6 @@ void SparsePage::PushCSC(const SparsePage &batch) {
   self_offset = std::move(offset);
 }
 
-void SparsePage::Push(const Inst &inst) {
-  auto& data_vec = data.HostVector();
-  auto& offset_vec = offset.HostVector();
-  offset_vec.push_back(offset_vec.back() + inst.size());
-  size_t begin = data_vec.size();
-  data_vec.resize(begin + inst.size());
-  if (inst.size() != 0) {
-    std::memcpy(dmlc::BeginPtr(data_vec) + begin, inst.data(),
-                sizeof(Entry) * inst.size());
-  }
-}
-
 namespace data {
 // List of files that will be force linked in static links.
 DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
diff --git a/src/data/ellpack_page.cc b/src/data/ellpack_page.cc
index 333b966cc..282b77a2e 100644
--- a/src/data/ellpack_page.cc
+++ b/src/data/ellpack_page.cc
@@ -1,18 +1,16 @@
 /*!
  * Copyright 2019 XGBoost contributors
- *
- * \file ellpack_page.cc
  */
 #ifndef XGBOOST_USE_CUDA
 
 #include <xgboost/data.h>
 
-// dummy implementation of ELlpackPage in case CUDA is not used
+// dummy implementation of EllpackPage in case CUDA is not used
 namespace xgboost {
 
 class EllpackPageImpl {};
 
-EllpackPage::EllpackPage(DMatrix* dmat) {
+EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param) {
   LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
 }
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index cfacec0d6..3fd12deb3 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -1,7 +1,5 @@
 /*!
  * Copyright 2019 XGBoost contributors
- *
- * \file ellpack_page.cu
  */
 
 #include <xgboost/data.h>
 
@@ -12,14 +10,22 @@
 namespace xgboost {
 
-EllpackPage::EllpackPage(DMatrix* dmat) : impl_{new EllpackPageImpl(dmat)} {}
+EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
+
+EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
+    : impl_{new EllpackPageImpl(dmat, param)} {}
 
 EllpackPage::~EllpackPage() = default;
 
-EllpackPageImpl::EllpackPageImpl(DMatrix* dmat) : dmat_{dmat} {}
+size_t EllpackPage::Size() const {
+  return impl_->Size();
+}
+
+void EllpackPage::SetBaseRowId(size_t row_id) {
+  impl_->SetBaseRowId(row_id);
+}
 
 // Bin each input data entry, store the bin indices in compressed form.
-template <typename std::enable_if<true, int>::type = 0>
 __global__ void CompressBinEllpackKernel(
     common::CompressedBufferWriter wr,
     common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
@@ -43,7 +49,7 @@ __global__ void CompressBinEllpackKernel(
   int feature = entry.index;
   float fvalue = entry.fvalue;
   // {feature_cuts, ncuts} forms the array of cuts of `feature'.
-  const float *feature_cuts = &cuts[cut_rows[feature]];
+  const float* feature_cuts = &cuts[cut_rows[feature]];
   int ncuts = cut_rows[feature + 1] - cut_rows[feature];
   // Assigning the bin in current entry.
   // S.t.: fvalue < feature_cuts[bin]
@@ -58,87 +64,90 @@ __global__ void CompressBinEllpackKernel(
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }
 
-void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) {
-  if (initialised_) return;
-
+// Construct an ELLPACK matrix in memory.
+EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
   monitor_.Init("ellpack_page");
-  dh::safe_cuda(cudaSetDevice(device));
+  dh::safe_cuda(cudaSetDevice(param.gpu_id));
 
   monitor_.StartCuda("Quantiles");
   // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
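+  // DeviceSketch also computes the row stride (the maximum number of entries in any row), which
+  // becomes the fixed row length of the ELLPACK format.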
   common::HistogramCuts hmat;
-  size_t row_stride = common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat_, &hmat);
+  size_t row_stride =
+      common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
   monitor_.StopCuda("Quantiles");
 
-  const auto& info = dmat_->Info();
-  auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_;
+  monitor_.StartCuda("InitEllpackInfo");
+  InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat);
+  monitor_.StopCuda("InitEllpackInfo");
 
-  // Init global data
   monitor_.StartCuda("InitCompressedData");
-  InitCompressedData(device, hmat, row_stride, is_dense);
+  InitCompressedData(param.gpu_id, dmat->Info().num_row_);
   monitor_.StopCuda("InitCompressedData");
 
   monitor_.StartCuda("BinningCompression");
-  DeviceHistogramBuilderState hist_builder_row_state(info.num_row_);
-  for (const auto& batch : dmat_->GetBatches<SparsePage>()) {
+  DeviceHistogramBuilderState hist_builder_row_state(dmat->Info().num_row_);
+  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
     hist_builder_row_state.BeginBatch(batch);
-    CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
+    CreateHistIndices(param.gpu_id, batch, hist_builder_row_state.GetRowStateOnDevice());
     hist_builder_row_state.EndBatch();
   }
   monitor_.StopCuda("BinningCompression");
-
-  initialised_ = true;
 }
 
-void EllpackPageImpl::InitCompressedData(int device,
-                                         const common::HistogramCuts& hmat,
-                                         size_t row_stride,
-                                         bool is_dense) {
-  n_bins = hmat.Ptrs().back();
-  int null_gidx_value = hmat.Ptrs().back();
-  int num_symbols = n_bins + 1;
-
-  // minimum value for each feature.
-  common::Span<bst_float> min_fvalue;
-
-  // Required buffer size for storing data matrix in ELLPack format.
-  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
-      row_stride * dmat_->Info().num_row_, num_symbols);
+// Construct an EllpackInfo based on histogram cuts of features.
+EllpackInfo::EllpackInfo(int device,
+                         bool is_dense,
+                         size_t row_stride,
+                         const common::HistogramCuts& hmat,
+                         dh::BulkAllocator& ba)
+    : is_dense(is_dense), row_stride(row_stride), n_bins(hmat.Ptrs().back()) {
   ba.Allocate(device,
               &feature_segments, hmat.Ptrs().size(),
               &gidx_fvalue_map, hmat.Values().size(),
-              &min_fvalue, hmat.MinValues().size(),
-              &gidx_buffer, compressed_size_bytes);
-
+              &min_fvalue, hmat.MinValues().size());
   dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
   dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
   dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
+}
+
+// Initialize the EllpackInfo for this page.
+void EllpackPageImpl::InitInfo(int device,
+                               bool is_dense,
+                               size_t row_stride,
+                               const common::HistogramCuts& hmat) {
+  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, ba_);
+}
+
+// Initialize the buffer to store compressed features.
+void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) {
+  int num_symbols = matrix.info.n_bins + 1;
+
+  // Required buffer size for storing data matrix in ELLPack format.
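+  // One extra symbol is reserved: n_bins is used as the null value that marks missing entries.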
+  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
+      matrix.info.row_stride * num_rows, num_symbols);
+  ba_.Allocate(device, &gidx_buffer, compressed_size_bytes);
+
   thrust::fill(
       thrust::device_pointer_cast(gidx_buffer.data()),
       thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
 
-  ellpack_matrix.Init(feature_segments,
-                      min_fvalue,
-                      gidx_fvalue_map,
-                      row_stride,
-                      common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
-                      is_dense,
-                      null_gidx_value);
+  matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols);
 }
 
+// Compress a CSR page into ELLPACK.
 void EllpackPageImpl::CreateHistIndices(int device,
                                         const SparsePage& row_batch,
                                         const RowStateOnDevice& device_row_state) {
   // Has any been allocated for me in this batch?
   if (!device_row_state.rows_to_process_from_batch) return;
 
-  unsigned int null_gidx_value = n_bins;
-  size_t row_stride = this->ellpack_matrix.row_stride;
+  unsigned int null_gidx_value = matrix.info.n_bins;
+  size_t row_stride = matrix.info.row_stride;
 
-  const auto &offset_vec = row_batch.offset.ConstHostVector();
+  const auto& offset_vec = row_batch.offset.ConstHostVector();
 
-  int num_symbols = n_bins + 1;
+  int num_symbols = matrix.info.n_bins + 1;
   // bin and compress entries in batches of rows
   size_t gpu_batch_nrows = std::min(
       dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
      static_cast<size_t>(device_row_state.rows_to_process_from_batch));
@@ -162,7 +171,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
         offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
 
     /*! \brief row offset in SparsePage (the input data). */
-    dh::device_vector<size_t> row_ptrs(batch_nrows+1);
+    dh::device_vector<size_t> row_ptrs(batch_nrows + 1);
     thrust::copy(
         offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
         offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
@@ -185,8 +194,8 @@ void EllpackPageImpl::CreateHistIndices(int device,
         gidx_buffer.data(),
         row_ptrs.data().get(),
         entries_d.data().get(),
-        gidx_fvalue_map.data(),
-        feature_segments.data(),
+        matrix.info.gidx_fvalue_map.data(),
+        matrix.info.feature_segments.data(),
         device_row_state.total_rows_processed + batch_row_begin,
         batch_nrows,
         row_stride,
@@ -194,4 +203,73 @@ void EllpackPageImpl::CreateHistIndices(int device,
   }
 }
 
+// Return the number of rows contained in this page.
+size_t EllpackPageImpl::Size() const {
+  return n_rows;
+}
+
+// Clear the current page.
+void EllpackPageImpl::Clear() {
+  ba_.Clear();
+  gidx_buffer = {};
+  idx_buffer.clear();
+  n_rows = 0;
+}
+
+// Push a CSR page into the current page.
+//
+// The CSR page is first compressed into ELLPACK on the device; the compressed buffer is then
+// copied to host and appended to the existing host vector.
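+//
+// Each compressed buffer starts with a few padding bytes (common::detail::kPadding, needed for
+// safe atomic writes); when appending to a non-empty host buffer those bytes are skipped so that
+// only a single padding region remains at the front.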
+void EllpackPageImpl::Push(int device, const SparsePage& batch) {
+  monitor_.StartCuda("InitCompressedData");
+  InitCompressedData(device, batch.Size());
+  monitor_.StopCuda("InitCompressedData");
+
+  monitor_.StartCuda("BinningCompression");
+  DeviceHistogramBuilderState hist_builder_row_state(batch.Size());
+  hist_builder_row_state.BeginBatch(batch);
+  CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
+  hist_builder_row_state.EndBatch();
+  monitor_.StopCuda("BinningCompression");
+
+  monitor_.StartCuda("CopyDeviceToHost");
+  std::vector<common::CompressedByteT> buffer(gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&buffer, gidx_buffer);
+  int offset = 0;
+  if (!idx_buffer.empty()) {
+    offset = ::xgboost::common::detail::kPadding;
+  }
+  idx_buffer.reserve(idx_buffer.size() + buffer.size() - offset);
+  idx_buffer.insert(idx_buffer.end(), buffer.begin() + offset, buffer.end());
+  ba_.Clear();
+  gidx_buffer = {};
+  monitor_.StopCuda("CopyDeviceToHost");
+
+  n_rows += batch.Size();
+}
+
+// Return the memory cost for storing the compressed features.
+size_t EllpackPageImpl::MemCostBytes() const {
+  return idx_buffer.size() * sizeof(common::CompressedByteT);
+}
+
+// Copy the compressed features to GPU.
+void EllpackPageImpl::InitDevice(int device, EllpackInfo info) {
+  if (device_initialized_) return;
+
+  monitor_.StartCuda("CopyPageToDevice");
+  dh::safe_cuda(cudaSetDevice(device));
+
+  gidx_buffer = {};
+  ba_.Allocate(device, &gidx_buffer, idx_buffer.size());
+  dh::CopyVectorToDeviceSpan(gidx_buffer, idx_buffer);
+
+  matrix.info = info;
+  matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), info.n_bins + 1);
+
+  monitor_.StopCuda("CopyPageToDevice");
+
+  device_initialized_ = true;
+}
+
 }  // namespace xgboost
diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh
index daa90ff73..6b471049e 100644
--- a/src/data/ellpack_page.cuh
+++ b/src/data/ellpack_page.cuh
@@ -1,7 +1,5 @@
 /*!
  * Copyright 2019 by XGBoost Contributors
- *
- * \file ellpack_page.cuh
  */
 
 #ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
@@ -42,56 +40,68 @@ __forceinline__ __device__ int BinarySearchRow(
   return -1;
 }
 
+/** \brief Meta information about the ELLPACK matrix. */
+struct EllpackInfo {
+  /*! \brief Whether the matrix is dense. */
+  bool is_dense;
+  /*! \brief Row length for ELLPACK; for dense data this equals the number of features. */
+  size_t row_stride;
+  /*! \brief Total number of bins, also used as the null index value. */
+  size_t n_bins;
+  /*! \brief Minimum value for each feature. Size equals the number of features. */
+  common::Span<bst_float> min_fvalue;
+  /*! \brief Histogram cut pointers. Size equals (number of features + 1). */
+  common::Span<uint32_t> feature_segments;
+  /*! \brief Histogram cut values. Size equals (bins per feature * number of features). */
+  common::Span<bst_float> gidx_fvalue_map;
+
+  EllpackInfo() = default;
+
+  /*!
+   * \brief Constructor.
+   *
+   * @param device The GPU device to use.
+   * @param is_dense Whether the matrix is dense.
+   * @param row_stride The number of features between starts of consecutive rows.
+   * @param hmat The histogram cuts of all the features.
+   * @param ba The BulkAllocator that owns the GPU memory.
+   */
+  explicit EllpackInfo(int device,
+                       bool is_dense,
+                       size_t row_stride,
+                       const common::HistogramCuts& hmat,
+                       dh::BulkAllocator& ba);
+};
+
 /** \brief Struct for accessing and manipulating an ellpack matrix on the
  * device. Does not own underlying memory and may be trivially copied into
  * kernels.*/
-struct ELLPackMatrix {
-  common::Span<uint32_t> feature_segments;
-  /*! \brief minimum value for each feature. */
-  common::Span<bst_float> min_fvalue;
-  /*! \brief Cut. */
-  common::Span<bst_float> gidx_fvalue_map;
-  /*! \brief row length for ELLPack. */
-  size_t row_stride{0};
+struct EllpackMatrix {
+  EllpackInfo info;
   common::CompressedIterator<uint32_t> gidx_iter;
-  int null_gidx_value;
 
-  XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
+  XGBOOST_DEVICE size_t BinCount() const { return info.gidx_fvalue_map.size(); }
 
-  // Get a matrix element, uses binary search for look up Return NaN if missing
+  // Get a matrix element; uses binary search for lookup. Returns NaN if missing.
   // Given a row index and a feature index, returns the corresponding cut value
   __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
-    auto row_begin = row_stride * ridx;
-    auto row_end = row_begin + row_stride;
+    auto row_begin = info.row_stride * ridx;
+    auto row_end = row_begin + info.row_stride;
     auto gidx = -1;
-    if (is_dense) {
+    if (info.is_dense) {
       gidx = gidx_iter[row_begin + fidx];
     } else {
-      gidx =
-          BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
-                          feature_segments[fidx + 1]);
+      gidx = BinarySearchRow(row_begin,
+                             row_end,
+                             gidx_iter,
+                             info.feature_segments[fidx],
+                             info.feature_segments[fidx + 1]);
     }
     if (gidx == -1) {
       return nan("");
     }
-    return gidx_fvalue_map[gidx];
+    return info.gidx_fvalue_map[gidx];
   }
-
-  void Init(common::Span<uint32_t> feature_segments,
-            common::Span<bst_float> min_fvalue,
-            common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
-            common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
-            int null_gidx_value) {
-    this->feature_segments = feature_segments;
-    this->min_fvalue = min_fvalue;
-    this->gidx_fvalue_map = gidx_fvalue_map;
-    this->row_stride = row_stride;
-    this->gidx_iter = gidx_iter;
-    this->is_dense = is_dense;
-    this->null_gidx_value = null_gidx_value;
-  }
-
- private:
-  bool is_dense;
 };
 
 // Instances of this type are created while creating the histogram bins for the
@@ -171,31 +181,93 @@ class DeviceHistogramBuilderState {
 class EllpackPageImpl {
  public:
-  ELLPackMatrix ellpack_matrix;
-  int n_bins{};
+  EllpackMatrix matrix;
   /*! \brief global index of histogram, which is stored in ELLPack format. */
   common::Span<common::CompressedByteT> gidx_buffer;
+  std::vector<common::CompressedByteT> idx_buffer;
+  size_t n_rows{};
 
-  explicit EllpackPageImpl(DMatrix* dmat);
-  void Init(int device, int max_bin, int gpu_batch_nrows);
-  void InitCompressedData(int device,
-                          const common::HistogramCuts& hmat,
-                          size_t row_stride,
-                          bool is_dense);
+  /*!
+   * \brief Default constructor.
+   *
+   * This is used in the external memory case. An empty ELLPACK page is constructed with its
+   * content set later by the reader.
+   */
+  EllpackPageImpl() = default;
+
+  /*!
+   * \brief Constructor from an existing DMatrix.
+   *
+   * This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
+   * in CSR format.
+   */
+  explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& param);
+
+  /*!
+   * \brief Initialize the EllpackInfo contained in the EllpackMatrix.
+   *
+   * This is used in the in-memory case. The current page owns the BulkAllocator, which in turn
+   * owns the GPU memory used by the EllpackInfo.
+   *
+   * @param device The GPU device to use.
+   * @param is_dense Whether the matrix is dense.
+   * @param row_stride The number of features between starts of consecutive rows.
+   * @param hmat The histogram cuts of all the features.
+   */
+  void InitInfo(int device, bool is_dense, size_t row_stride, const common::HistogramCuts& hmat);
+
+  /*!
+   * \brief Initialize the buffer to store compressed features.
+   *
+   * @param device The GPU device to use.
+   * @param num_rows The number of rows we are storing in the buffer.
+   */
+  void InitCompressedData(int device, size_t num_rows);
+
+  /*!
+   * \brief Compress a single page of CSR data into ELLPACK.
+   *
+   * @param device The GPU device to use.
+   * @param row_batch The CSR page.
+   * @param device_row_state On-device data for maintaining state.
+   */
   void CreateHistIndices(int device,
                          const SparsePage& row_batch,
                          const RowStateOnDevice& device_row_state);
 
- private:
-  bool initialised_{false};
-  DMatrix* dmat_;
-  common::Monitor monitor_;
-  dh::BulkAllocator ba;
+  /*! \return Number of instances in the page. */
+  size_t Size() const;
 
-  /*! \brief Cut. */
-  common::Span<bst_float> gidx_fvalue_map;
-  /*! \brief row_ptr form HistogramCuts. */
-  common::Span<uint32_t> feature_segments;
+  /*! \brief Set the base row id for this page. */
+  inline void SetBaseRowId(size_t row_id) {
+    base_rowid_ = row_id;
+  }
+
+  /*! \brief Clear the page. */
+  void Clear();
+
+  /*!
+   * \brief Push a sparse page.
+   * \param device The GPU device to use.
+   * \param batch The row page.
+   */
+  void Push(int device, const SparsePage& batch);
+
+  /*! \return Estimation of memory cost of this page. */
+  size_t MemCostBytes() const;
+
+  /*!
+   * \brief Copy the ELLPACK matrix to GPU.
+   *
+   * @param device The GPU device to use.
+   * @param info The EllpackInfo for the matrix.
+   */
+  void InitDevice(int device, EllpackInfo info);
+
+ private:
+  common::Monitor monitor_;
+  dh::BulkAllocator ba_;
+  size_t base_rowid_{};
+  bool device_initialized_{false};
 };
 
 }  // namespace xgboost
diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu
new file mode 100644
index 000000000..fc8dcde62
--- /dev/null
+++ b/src/data/ellpack_page_raw_format.cu
@@ -0,0 +1,48 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+
+#include <xgboost/data.h>
+#include <dmlc/registry.h>
+
+#include "./ellpack_page.cuh"
+#include "./sparse_page_writer.h"
+
+namespace xgboost {
+namespace data {
+
+DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
+
+class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
+ public:
+  bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
+    auto* impl = page->Impl();
+    if (!fi->Read(&impl->n_rows)) return false;
+    return fi->Read(&impl->idx_buffer);
+  }
+
+  bool Read(EllpackPage* page,
+            dmlc::SeekStream* fi,
+            const std::vector<bst_uint>& sorted_index_set) override {
+    auto* impl = page->Impl();
+    if (!fi->Read(&impl->n_rows)) return false;
+    return fi->Read(&impl->idx_buffer);
+  }
+
+  void Write(const EllpackPage& page, dmlc::Stream* fo) override {
+    auto* impl = page.Impl();
+    fo->Write(impl->n_rows);
+    auto buffer = impl->idx_buffer;
+    CHECK(!buffer.empty());
+    fo->Write(buffer);
+  }
+};
+
+XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw)
+    .describe("Raw ELLPACK binary data format.")
+    .set_body([]() {
+      return new EllpackPageRawFormat();
+    });
+
+}  // namespace data
+}  // namespace xgboost
diff --git a/src/data/ellpack_page_source.cc b/src/data/ellpack_page_source.cc
new file mode 100644
index 000000000..130e5d6a6
--- /dev/null
+++ b/src/data/ellpack_page_source.cc
@@ -0,0 +1,46 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+#ifndef XGBOOST_USE_CUDA
+
+#include "ellpack_page_source.h"
+
+namespace xgboost {
+namespace data {
+
+EllpackPageSource::EllpackPageSource(DMatrix* dmat,
+                                     const std::string& cache_info,
+                                     const BatchParam& param) noexcept(false) {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+}
+
+void EllpackPageSource::BeforeFirst() {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+}
+
+bool EllpackPageSource::Next() {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+  return false;
+}
+
+EllpackPage& EllpackPageSource::Value() {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+  EllpackPage* page;
+  return *page;
+}
+
+const EllpackPage& EllpackPageSource::Value() const {
+  LOG(FATAL) << "Internal Error: "
+                "XGBoost is not compiled with CUDA but EllpackPageSource is required";
+  EllpackPage* page;
+  return *page;
+}
+
+}  // namespace data
+}  // namespace xgboost
+
+#endif  // XGBOOST_USE_CUDA
diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu
new file mode 100644
index 000000000..fcb0936b7
--- /dev/null
+++ b/src/data/ellpack_page_source.cu
@@ -0,0 +1,155 @@
+/*!
+ * Copyright 2019 XGBoost contributors
+ */
+
+#include "ellpack_page_source.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "../common/hist_util.h"
+#include "ellpack_page.cuh"
+
+namespace xgboost {
+namespace data {
+
+class EllpackPageSourceImpl : public DataSource<EllpackPage> {
+ public:
+  /*!
+   * \brief Create an ELLPACK page source from a DMatrix.
+   * \param dmat The DMatrix to read from.
+   * \param cache_info The prefix of the cache files to write.
+   * \param param Parameters for constructing the ELLPACK pages.
+   */
+  explicit EllpackPageSourceImpl(DMatrix* dmat,
+                                 const std::string& cache_info,
+                                 const BatchParam& param) noexcept(false);
+
+  /*! \brief destructor */
+  ~EllpackPageSourceImpl() override = default;
+
+  void BeforeFirst() override;
+  bool Next() override;
+  EllpackPage& Value();
+  const EllpackPage& Value() const override;
+
+ private:
+  /*! \brief Write Ellpack pages after accumulating them in memory. */
+  void WriteEllpackPages(DMatrix* dmat, const std::string& cache_info) const;
+
+  /*! \brief The page type string for ELLPACK. */
+  const std::string kPageType_{".ellpack.page"};
+
+  int device_{-1};
+  common::Monitor monitor_;
+  dh::BulkAllocator ba_;
+  /*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
+  EllpackInfo ellpack_info_;
+  std::unique_ptr<SparsePageSource<EllpackPage>> source_;
+};
+
+EllpackPageSource::EllpackPageSource(DMatrix* dmat,
+                                     const std::string& cache_info,
+                                     const BatchParam& param) noexcept(false)
+    : impl_{new EllpackPageSourceImpl(dmat, cache_info, param)} {}
+
+void EllpackPageSource::BeforeFirst() {
+  impl_->BeforeFirst();
+}
+
+bool EllpackPageSource::Next() {
+  return impl_->Next();
+}
+
+EllpackPage& EllpackPageSource::Value() {
+  return impl_->Value();
+}
+
+const EllpackPage& EllpackPageSource::Value() const {
+  return impl_->Value();
+}
+
+// Build the quantile sketch across the whole input data, then use the histogram cuts to compress
+// each CSR page, and write the accumulated ELLPACK pages to disk.
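+//
+// All pages written by this source share a single EllpackInfo; its GPU memory is owned by the
+// source's BulkAllocator and is attached to each page via InitDevice() when the page is read back.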
+EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
+                                             const std::string& cache_info,
+                                             const BatchParam& param) noexcept(false) {
+  device_ = param.gpu_id;
+
+  monitor_.Init("ellpack_page_source");
+  dh::safe_cuda(cudaSetDevice(device_));
+
+  monitor_.StartCuda("Quantiles");
+  common::HistogramCuts hmat;
+  size_t row_stride =
+      common::DeviceSketch(device_, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
+  monitor_.StopCuda("Quantiles");
+
+  monitor_.StartCuda("CreateEllpackInfo");
+  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, ba_);
+  monitor_.StopCuda("CreateEllpackInfo");
+
+  monitor_.StartCuda("WriteEllpackPages");
+  WriteEllpackPages(dmat, cache_info);
+  monitor_.StopCuda("WriteEllpackPages");
+
+  source_.reset(new SparsePageSource<EllpackPage>(cache_info, kPageType_));
+}
+
+void EllpackPageSourceImpl::BeforeFirst() {
+  source_->BeforeFirst();
+}
+
+bool EllpackPageSourceImpl::Next() {
+  return source_->Next();
+}
+
+EllpackPage& EllpackPageSourceImpl::Value() {
+  EllpackPage& page = source_->Value();
+  page.Impl()->InitDevice(device_, ellpack_info_);
+  return page;
+}
+
+const EllpackPage& EllpackPageSourceImpl::Value() const {
+  EllpackPage& page = source_->Value();
+  page.Impl()->InitDevice(device_, ellpack_info_);
+  return page;
+}
+
+// Compress each CSR page to ELLPACK, and write the accumulated pages to disk.
+void EllpackPageSourceImpl::WriteEllpackPages(DMatrix* dmat, const std::string& cache_info) const {
+  auto cinfo = ParseCacheInfo(cache_info, kPageType_);
+  const size_t extra_buffer_capacity = 6;
+  SparsePageWriter<EllpackPage> writer(
+      cinfo.name_shards, cinfo.format_shards, extra_buffer_capacity);
+  std::shared_ptr<EllpackPage> page;
+  writer.Alloc(&page);
+  auto* impl = page->Impl();
+  impl->matrix.info = ellpack_info_;
+  impl->Clear();
+
+  const MetaInfo& info = dmat->Info();
+  size_t bytes_write = 0;
+  double tstart = dmlc::GetTime();
+  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
+    impl->Push(device_, batch);
+
+    if (impl->MemCostBytes() >= DMatrix::kPageSize) {
+      bytes_write += impl->MemCostBytes();
+      writer.PushWrite(std::move(page));
+      writer.Alloc(&page);
+      impl = page->Impl();
+      impl->matrix.info = ellpack_info_;
+      impl->Clear();
+      double tdiff = dmlc::GetTime() - tstart;
+      LOG(INFO) << "Writing to " << cache_info << " in "
+                << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
+                << (bytes_write >> 20UL) << " written";
+    }
+  }
+  if (impl->Size() != 0) {
+    writer.PushWrite(std::move(page));
+  }
+}
+
+}  // namespace data
+}  // namespace xgboost
diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h
new file mode 100644
index 000000000..08be882f5
--- /dev/null
+++ b/src/data/ellpack_page_source.h
@@ -0,0 +1,54 @@
+/*!
+ * Copyright 2019 by XGBoost Contributors
+ */
+
+#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
+#define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
+
+#include <xgboost/data.h>
+#include <memory>
+#include <string>
+
+#include "sparse_page_source.h"
+#include "../common/timer.h"
+
+namespace xgboost {
+namespace data {
+
+class EllpackPageSourceImpl;
+
+/*!
+ * \brief External memory data source for ELLPACK format.
+ *
+ * This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
+ * including CUDA-specific implementation details in the header.
+ */
+class EllpackPageSource : public DataSource<EllpackPage> {
+ public:
+  /*!
+   * \brief Create an ELLPACK page source from a DMatrix.
+   * \param dmat The DMatrix to read from.
+   * \param cache_info The prefix of the cache files to write.
+   * \param param Parameters for constructing the ELLPACK pages.
+   */
+  explicit EllpackPageSource(DMatrix* dmat,
+                             const std::string& cache_info,
+                             const BatchParam& param) noexcept(false);
+
+  /*! \brief destructor */
+  ~EllpackPageSource() override = default;
+
+  void BeforeFirst() override;
+  bool Next() override;
+  EllpackPage& Value();
+  const EllpackPage& Value() const override;
+
+  const EllpackPageSourceImpl* Impl() const { return impl_.get(); }
+  EllpackPageSourceImpl* Impl() { return impl_.get(); }
+
+ private:
+  std::shared_ptr<EllpackPageSourceImpl> impl_;
+};
+
+}  // namespace data
+}  // namespace xgboost
+
+#endif  // XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 65f639a13..c688f1f18 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -62,10 +62,12 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
   return BatchSet<SortedCSCPage>(begin_iter);
 }
 
-BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
+BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
+  CHECK_GE(param.gpu_id, 0);
+  CHECK_GE(param.max_bin, 2);
   // ELLPACK page doesn't exist, generate it
   if (!ellpack_page_) {
-    ellpack_page_.reset(new EllpackPage(this));
+    ellpack_page_.reset(new EllpackPage(this, param));
   }
   auto begin_iter = BatchIterator<EllpackPage>(
       new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index 2c740924d..0479a3577 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -38,7 +38,7 @@ class SimpleDMatrix : public DMatrix {
   BatchSet<SparsePage> GetRowBatches() override;
   BatchSet<CSCPage> GetColumnBatches() override;
   BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
-  BatchSet<EllpackPage> GetEllpackBatches() override;
+  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
 
   // source data pointer.
   std::unique_ptr<DataSource<SparsePage>> source_;
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index 7d07f7426..909e82f45 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -23,10 +23,10 @@ const MetaInfo& SparsePageDMatrix::Info() const {
   return row_source_->info;
 }
 
-template <typename T>
+template <typename S, typename T>
 class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
  public:
-  explicit SparseBatchIteratorImpl(SparsePageSource<T>* source) : source_(source) {
+  explicit SparseBatchIteratorImpl(S* source) : source_(source) {
     CHECK(source_ != nullptr);
   }
   T& operator*() override { return source_->Value(); }
@@ -35,7 +35,7 @@ class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
   bool AtEnd() const override { return at_end_; }
 
  private:
-  SparsePageSource<T>* source_{nullptr};
+  S* source_{nullptr};
   bool at_end_{ false };
 };
 
@@ -43,7 +43,8 @@ BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
   auto cast = dynamic_cast<SparsePageSource<SparsePage>*>(row_source_.get());
   cast->BeforeFirst();
   cast->Next();
-  auto begin_iter = BatchIterator<SparsePage>(new SparseBatchIteratorImpl<SparsePage>(cast));
+  auto begin_iter = BatchIterator<SparsePage>(
+      new SparseBatchIteratorImpl<SparsePageSource<SparsePage>, SparsePage>(cast));
   return BatchSet<SparsePage>(begin_iter);
 }
 
@@ -55,8 +56,8 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
   }
   column_source_->BeforeFirst();
   column_source_->Next();
-  auto begin_iter =
-      BatchIterator<CSCPage>(new SparseBatchIteratorImpl<CSCPage>(column_source_.get()));
+  auto begin_iter = BatchIterator<CSCPage>(
+      new SparseBatchIteratorImpl<SparsePageSource<CSCPage>, CSCPage>(column_source_.get()));
   return BatchSet<CSCPage>(begin_iter);
 }
 
@@ -70,17 +71,26 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
   sorted_column_source_->BeforeFirst();
   sorted_column_source_->Next();
   auto begin_iter = BatchIterator<SortedCSCPage>(
-      new SparseBatchIteratorImpl<SortedCSCPage>(sorted_column_source_.get()));
+      new SparseBatchIteratorImpl<SparsePageSource<SortedCSCPage>, SortedCSCPage>(
+          sorted_column_source_.get()));
   return BatchSet<SortedCSCPage>(begin_iter);
 }
 
-BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches() {
-  // ELLPACK page doesn't exist, generate it
-  if (!ellpack_page_) {
-    ellpack_page_.reset(new EllpackPage(this));
+BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
+  CHECK_GE(param.gpu_id, 0);
+  CHECK_GE(param.max_bin, 2);
+  // Lazily instantiate the ELLPACK page source.
+  if (!ellpack_source_ ||
+      batch_param_.gpu_id != param.gpu_id ||
+      batch_param_.max_bin != param.max_bin ||
+      batch_param_.gpu_batch_nrows != param.gpu_batch_nrows) {
+    ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
+    batch_param_ = param;
   }
-  auto begin_iter =
-      BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
+  ellpack_source_->BeforeFirst();
+  ellpack_source_->Next();
+  auto begin_iter = BatchIterator<EllpackPage>(
+      new SparseBatchIteratorImpl<EllpackPageSource, EllpackPage>(ellpack_source_.get()));
   return BatchSet<EllpackPage>(begin_iter);
 }
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index b8921ba95..eb1634a4b 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -14,6 +14,7 @@
 #include <string>
 #include <vector>
 
+#include "ellpack_page_source.h"
 #include "sparse_page_source.h"
 
 namespace xgboost {
@@ -38,13 +39,15 @@ class SparsePageDMatrix : public DMatrix {
   BatchSet<SparsePage> GetRowBatches() override;
   BatchSet<CSCPage> GetColumnBatches() override;
   BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
-  BatchSet<EllpackPage> GetEllpackBatches() override;
+  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
 
   // source data pointers.
   std::unique_ptr<DataSource<SparsePage>> row_source_;
   std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
   std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
-  std::unique_ptr<EllpackPage> ellpack_page_;
+  std::unique_ptr<EllpackPageSource> ellpack_source_;
+  // saved batch param
+  BatchParam batch_param_;
   // the cache prefix
   std::string cache_info_;
   // Store column densities to avoid recalculating
diff --git a/src/data/sparse_page_raw_format.cc b/src/data/sparse_page_raw_format.cc
index 0cb3b6ebe..b9a82bbdf 100644
--- a/src/data/sparse_page_raw_format.cc
+++ b/src/data/sparse_page_raw_format.cc
@@ -12,9 +12,10 @@ namespace data {
 
 DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);
 
-class SparsePageRawFormat : public SparsePageFormat {
+template <typename T>
+class SparsePageRawFormat : public SparsePageFormat<T> {
  public:
-  bool Read(SparsePage* page, dmlc::SeekStream* fi) override {
+  bool Read(T* page, dmlc::SeekStream* fi) override {
     auto& offset_vec = page->offset.HostVector();
     if (!fi->Read(&offset_vec)) return false;
     auto& data_vec = page->data.HostVector();
@@ -29,7 +30,7 @@ class SparsePageRawFormat : public SparsePageFormat<T> {
     return true;
   }
 
-  bool Read(SparsePage* page,
+  bool Read(T* page,
             dmlc::SeekStream* fi,
             const std::vector<bst_uint>& sorted_index_set) override {
     if (!fi->Read(&disk_offset_)) return false;
@@ -79,7 +80,7 @@ class SparsePageRawFormat : public SparsePageFormat<T> {
     return true;
   }
 
-  void Write(const SparsePage& page, dmlc::Stream* fo) override {
+  void Write(const T& page, dmlc::Stream* fo) override {
     const auto& offset_vec = page.offset.HostVector();
     const auto& data_vec = page.data.HostVector();
     CHECK(page.offset.Size() != 0 && offset_vec[0] == 0);
@@ -98,7 +99,20 @@ class SparsePageRawFormat : public SparsePageFormat<T> {
 XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
     .describe("Raw binary data format.")
     .set_body([]() {
-      return new SparsePageRawFormat();
+      return new SparsePageRawFormat<SparsePage>();
     });
+
+XGBOOST_REGISTER_CSC_PAGE_FORMAT(raw)
+    .describe("Raw binary data format.")
+    .set_body([]() {
+      return new SparsePageRawFormat<CSCPage>();
+    });
+
+XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(raw)
+    .describe("Raw binary data format.")
+    .set_body([]() {
+      return new SparsePageRawFormat<SortedCSCPage>();
+    });
+
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h
index 5112972bf..32eb687b5 100644
--- a/src/data/sparse_page_source.h
+++ b/src/data/sparse_page_source.h
@@ -46,6 +46,47 @@ GetCacheShards(const std::string& cache_info) {
 
 namespace xgboost {
 namespace data {
+
+/*!
+ * \brief Decide the format from the cache prefix.
+ * \return Pair of (row format, column format) for the cache prefix.
+ */
+inline std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix) {
+  size_t pos = cache_prefix.rfind(".fmt-");
+
+  if (pos != std::string::npos) {
+    std::string fmt = cache_prefix.substr(pos + 5, cache_prefix.length());
+    size_t cpos = fmt.rfind('-');
+    if (cpos != std::string::npos) {
+      return std::make_pair(fmt.substr(0, cpos), fmt.substr(cpos + 1, fmt.length()));
+    } else {
+      return std::make_pair(fmt, fmt);
+    }
+  } else {
+    std::string raw = "raw";
+    return std::make_pair(raw, raw);
+  }
+}
+
+struct CacheInfo {
+  std::string name_info;
+  std::vector<std::string> format_shards;
+  std::vector<std::string> name_shards;
+};
+
+inline CacheInfo ParseCacheInfo(const std::string& cache_info, const std::string& page_type) {
+  CacheInfo info;
+  std::vector<std::string> cache_shards = GetCacheShards(cache_info);
+  CHECK_NE(cache_shards.size(), 0U);
+  // read in the info files.
+  info.name_info = cache_shards[0];
+  for (const std::string& prefix : cache_shards) {
+    info.name_shards.push_back(prefix + page_type);
+    info.format_shards.push_back(DecideFormat(prefix).first);
+  }
+  return info;
+}
+
 /*!
  * \brief External memory data source.
  * \code
@@ -72,6 +113,7 @@ class SparsePageSource : public DataSource<T> {
       std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
       int tmagic;
       CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
+      CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
       this->info.LoadBinary(finfo.get());
     }
     files_.resize(cache_shards.size());
@@ -85,8 +127,8 @@ class SparsePageSource : public DataSource<T> {
       std::unique_ptr<dmlc::SeekStream>& fi = files_[i];
       std::string format;
       CHECK(fi->Read(&format)) << "Invalid page format";
-      formats_[i].reset(SparsePageFormat::Create(format));
-      std::unique_ptr<SparsePageFormat>& fmt = formats_[i];
+      formats_[i].reset(CreatePageFormat<T>(format));
+      std::unique_ptr<SparsePageFormat<T>>& fmt = formats_[i];
       size_t fbegin = fi->Tell();
       prefetchers_[i].reset(new dmlc::ThreadedIter<T>(4));
       prefetchers_[i]->Init([&fi, &fmt] (T** dptr) {
@@ -111,7 +153,7 @@ class SparsePageSource : public DataSource<T> {
       prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
     }
     if (prefetchers_[clock_ptr_]->Next(&page_)) {
-      page_->base_rowid = base_rowid_;
+      page_->SetBaseRowId(base_rowid_);
       base_rowid_ += page_->Size();
       // advance clock
       clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
@@ -149,17 +191,9 @@ class SparsePageSource : public DataSource<T> {
                             const std::string& cache_info,
                             const size_t page_size = DMatrix::kPageSize) {
     const std::string page_type = ".row.page";
-    std::vector<std::string> cache_shards = GetCacheShards(cache_info);
-    CHECK_NE(cache_shards.size(), 0U);
-    // read in the info files.
-    std::string name_info = cache_shards[0];
-    std::vector<std::string> name_shards, format_shards;
-    for (const std::string& prefix : cache_shards) {
-      name_shards.push_back(prefix + page_type);
-      format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
-    }
+    auto cinfo = ParseCacheInfo(cache_info, page_type);
     {
-      SparsePageWriter writer(name_shards, format_shards, 6);
+      SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
       std::shared_ptr<SparsePage> page;
       writer.Alloc(&page);
       page->Clear();
@@ -230,30 +264,19 @@ class SparsePageSource : public DataSource<T> {
         writer.PushWrite(std::move(page));
       }
 
-      std::unique_ptr<dmlc::Stream> fo(
-          dmlc::Stream::Create(name_info.c_str(), "w"));
+      std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
       int tmagic = kMagic;
       fo->Write(&tmagic, sizeof(tmagic));
       // Either every row has query ID or none at all
       CHECK(qids.empty() || qids.size() == info.num_row_);
       info.SaveBinary(fo.get());
     }
-    LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
-              << name_info;
+    LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to " << cinfo.name_info;
   }
 
   /*!
    * \brief Create source cache by copy content from DMatrix.
-   * \param cache_info The cache_info of cache file location.
-   */
-  static void CreateRowPage(DMatrix* src,
-                            const std::string& cache_info) {
-    const std::string page_type = ".row.page";
-    CreatePageFromDMatrix(src, cache_info, page_type);
-  }
-
-  /*!
-   * \brief Create source cache by copy content from DMatrix. Creates transposed column page, may be sorted or not.
+   * Creates transposed column page, may be sorted or not.
    * \param cache_info The cache_info of cache file location.
    * \param sorted Whether columns should be pre-sorted
    */
@@ -293,17 +316,9 @@ class SparsePageSource : public DataSource<T> {
   static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info,
                                     const std::string& page_type,
                                     const size_t page_size = DMatrix::kPageSize) {
-    std::vector<std::string> cache_shards = GetCacheShards(cache_info);
-    CHECK_NE(cache_shards.size(), 0U);
-    // read in the info files.
-    std::string name_info = cache_shards[0];
-    std::vector<std::string> name_shards, format_shards;
-    for (const std::string& prefix : cache_shards) {
-      name_shards.push_back(prefix + page_type);
-      format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
-    }
+    auto cinfo = ParseCacheInfo(cache_info, page_type);
     {
-      SparsePageWriter writer(name_shards, format_shards, 6);
+      SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
       std::shared_ptr<SparsePage> page;
       writer.Alloc(&page);
       page->Clear();
@@ -312,9 +327,7 @@ class SparsePageSource : public DataSource<T> {
       size_t bytes_write = 0;
       double tstart = dmlc::GetTime();
       for (auto& batch : src->GetBatches<SparsePage>()) {
-        if (page_type == ".row.page") {
-          page->Push(batch);
-        } else if (page_type == ".col.page") {
+        if (page_type == ".col.page") {
           page->PushCSC(batch.GetTranspose(src->Info().num_col_));
         } else if (page_type == ".sorted.col.page") {
           SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
@@ -338,28 +351,22 @@ class SparsePageSource : public DataSource<T> {
       if (page->data.Size() != 0) {
         writer.PushWrite(std::move(page));
       }
-
-      std::unique_ptr<dmlc::Stream> fo(
-          dmlc::Stream::Create(name_info.c_str(), "w"));
-      int tmagic = kMagic;
-      fo->Write(&tmagic, sizeof(tmagic));
-      info.SaveBinary(fo.get());
     }
-    LOG(INFO) << "SparsePageSource: Finished writing to " << name_info;
+    LOG(INFO) << "SparsePageSource: Finished writing to " << cinfo.name_info;
   }
 
   /*! \brief number of rows */
   size_t base_rowid_;
   /*! \brief page currently on hold.
   */
-  T *page_;
+  T* page_;
   /*! \brief internal clock ptr */
   size_t clock_ptr_;
   /*! \brief file pointer to the row blob file. */
-  std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
+  std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
   /*! \brief Sparse page format file. */
-  std::vector<std::unique_ptr<SparsePageFormat> > formats_;
+  std::vector<std::unique_ptr<SparsePageFormat<T>>> formats_;
   /*! \brief internal prefetcher. */
-  std::vector<std::unique_ptr<dmlc::ThreadedIter<T> > > prefetchers_;
+  std::vector<std::unique_ptr<dmlc::ThreadedIter<T>>> prefetchers_;
 };
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/data/sparse_page_writer.cc b/src/data/sparse_page_writer.cc
deleted file mode 100644
index a78ced8a5..000000000
--- a/src/data/sparse_page_writer.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/*!
- * Copyright (c) 2015 by Contributors
- * \file sparse_batch_writer.cc
- * \param Writer class sparse page.
- */
-#include <dmlc/base.h>
-#include <dmlc/io.h>
-#include "./sparse_page_writer.h"
-
-#if DMLC_ENABLE_STD_THREAD
-namespace xgboost {
-namespace data {
-
-SparsePageWriter::SparsePageWriter(
-    const std::vector<std::string>& name_shards,
-    const std::vector<std::string>& format_shards,
-    size_t extra_buffer_capacity)
-    : num_free_buffer_(extra_buffer_capacity + name_shards.size()),
-      clock_ptr_(0),
-      workers_(name_shards.size()),
-      qworkers_(name_shards.size()) {
-  CHECK_EQ(name_shards.size(), format_shards.size());
-  // start writer threads
-  for (size_t i = 0; i < name_shards.size(); ++i) {
-    std::string name_shard = name_shards[i];
-    std::string format_shard = format_shards[i];
-    auto* wqueue = &qworkers_[i];
-    workers_[i].reset(new std::thread(
-        [this, name_shard, format_shard, wqueue] () {
-          std::unique_ptr<dmlc::Stream> fo(
-              dmlc::Stream::Create(name_shard.c_str(), "w"));
-          std::unique_ptr<SparsePageFormat> fmt(
-              SparsePageFormat::Create(format_shard));
-          fo->Write(format_shard);
-          std::shared_ptr<SparsePage> page;
-          while (wqueue->Pop(&page)) {
-            if (page == nullptr) break;
-            fmt->Write(*page, fo.get());
-            qrecycle_.Push(std::move(page));
-          }
-          fo.reset(nullptr);
-          LOG(INFO) << "SparsePage::Writer Finished writing to " << name_shard;
-        }));
-  }
-}
-
-SparsePageWriter::~SparsePageWriter() {
-  for (auto& queue : qworkers_) {
-    // use nullptr to signal termination.
-    std::shared_ptr<SparsePage> sig(nullptr);
-    queue.Push(std::move(sig));
-  }
-  for (auto& thread : workers_) {
-    thread->join();
-  }
-}
-
-void SparsePageWriter::PushWrite(std::shared_ptr<SparsePage>&& page) {
-  qworkers_[clock_ptr_].Push(std::move(page));
-  clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
-}
-
-void SparsePageWriter::Alloc(std::shared_ptr<SparsePage>* out_page) {
-  CHECK(*out_page == nullptr);
-  if (num_free_buffer_ != 0) {
-    out_page->reset(new SparsePage());
-    --num_free_buffer_;
-  } else {
-    CHECK(qrecycle_.Pop(out_page));
-  }
-}
-}  // namespace data
-}  // namespace xgboost
-
-#endif  // DMLC_ENABLE_STD_THREAD
diff --git a/src/data/sparse_page_writer.h b/src/data/sparse_page_writer.h
index 835663e21..6a6ff4217 100644
--- a/src/data/sparse_page_writer.h
+++ b/src/data/sparse_page_writer.h
@@ -23,9 +23,14 @@
 namespace xgboost {
 namespace data {
+
+template<typename T>
+struct SparsePageFormatReg;
+
 /*!
  * \brief Format specification of SparsePage.
  */
+template<typename T>
 class SparsePageFormat {
  public:
   /*! \brief virtual destructor */
@@ -36,7 +41,8 @@ class SparsePageFormat {
    * \param fi the input stream of the file
    * \return true of the loading as successful, false if end of file was reached
    */
-  virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0;
+  virtual bool Read(T* page, dmlc::SeekStream* fi) = 0;
+
   /*!
    * \brief read only the segments we are interested in, advance fi to end of the block.
    * \param page The page to load the data into.
@@ -44,30 +50,35 @@ class SparsePageFormat {
    * \param sorted_index_set sorted index of segments we are interested in
    * \return true of the loading as successful, false if end of file was reached
    */
-  virtual bool Read(SparsePage* page,
+  virtual bool Read(T* page,
                     dmlc::SeekStream* fi,
                     const std::vector<bst_uint>& sorted_index_set) = 0;
+
   /*!
    * \brief save the data to fo, when a page was written.
    * \param fo output stream
   */
-  virtual void Write(const SparsePage& page, dmlc::Stream* fo) = 0;
-  /*!
-   * \brief Create sparse page of format.
-   * \return The created format functors.
-   */
-  static SparsePageFormat* Create(const std::string& name);
-  /*!
-   * \brief decide the format from cache prefix.
-   * \return pair of row format, column format type of the cache prefix.
-   */
-  static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
+  virtual void Write(const T& page, dmlc::Stream* fo) = 0;
 };
 
+/*!
+ * \brief Create a sparse page format of the given name.
+ * \return The created format functor.
+ */
+template<typename T>
+inline SparsePageFormat<T>* CreatePageFormat(const std::string& name) {
+  auto *e = ::dmlc::Registry<SparsePageFormatReg<T>>::Get()->Find(name);
+  if (e == nullptr) {
+    LOG(FATAL) << "Unknown format type " << name;
+  }
+  return (e->body)();
+}
+
 #if DMLC_ENABLE_STD_THREAD
 /*!
  * \brief A threaded writer to write sparse batch page to sharded files.
+ * @tparam T Type of the page.
 */
+template<typename T>
 class SparsePageWriter {
  public:
   /*!
@@ -76,26 +87,74 @@ class SparsePageWriter {
    * \param format_shards format of each shard.
    * \param extra_buffer_capacity Extra buffer capacity before block.
    */
-  explicit SparsePageWriter(
-      const std::vector<std::string>& name_shards,
-      const std::vector<std::string>& format_shards,
-      size_t extra_buffer_capacity);
+  explicit SparsePageWriter(const std::vector<std::string>& name_shards,
+                            const std::vector<std::string>& format_shards,
+                            size_t extra_buffer_capacity)
+      : num_free_buffer_(extra_buffer_capacity + name_shards.size()),
+        clock_ptr_(0),
+        workers_(name_shards.size()),
+        qworkers_(name_shards.size()) {
+    CHECK_EQ(name_shards.size(), format_shards.size());
+    // start writer threads
+    for (size_t i = 0; i < name_shards.size(); ++i) {
+      std::string name_shard = name_shards[i];
+      std::string format_shard = format_shards[i];
+      auto* wqueue = &qworkers_[i];
+      workers_[i].reset(new std::thread(
+          [this, name_shard, format_shard, wqueue]() {
+            std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(name_shard.c_str(), "w"));
+            std::unique_ptr<SparsePageFormat<T>> fmt(CreatePageFormat<T>(format_shard));
+            fo->Write(format_shard);
+            std::shared_ptr<T> page;
+            while (wqueue->Pop(&page)) {
+              if (page == nullptr) break;
+              fmt->Write(*page, fo.get());
+              qrecycle_.Push(std::move(page));
+            }
+            fo.reset(nullptr);
+            LOG(INFO) << "SparsePageWriter Finished writing to " << name_shard;
+          }));
+    }
+  }
+
   /*! \brief destructor, will close the files automatically */
-  ~SparsePageWriter();
+  ~SparsePageWriter() {
+    for (auto& queue : qworkers_) {
+      // use nullptr to signal termination.
+      std::shared_ptr<T> sig(nullptr);
+      queue.Push(std::move(sig));
+    }
+    for (auto& thread : workers_) {
+      thread->join();
+    }
+  }
+
   /*!
   * \brief Push a write job to the writer.
   * This function won't block,
   * writing is done by another thread inside writer.
   * \param page The page to be written
   */
-  void PushWrite(std::shared_ptr<SparsePage>&& page);
+  void PushWrite(std::shared_ptr<T>&& page) {
+    qworkers_[clock_ptr_].Push(std::move(page));
+    clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
+  }
+
   /*!
   * \brief Allocate a page to store results.
   * This function can block when the writer is too slow and buffer pages
   * have not yet been recycled.
   * \param out_page Used to store the allocated pages.
   */
-  void Alloc(std::shared_ptr<SparsePage>* out_page);
+  void Alloc(std::shared_ptr<T>* out_page) {
+    CHECK(*out_page == nullptr);
+    if (num_free_buffer_ != 0) {
+      out_page->reset(new T());
+      --num_free_buffer_;
+    } else {
+      CHECK(qrecycle_.Pop(out_page));
+    }
+  }
 
 private:
  /*! \brief number of allocated pages */
@@ -103,20 +162,21 @@ class SparsePageWriter {
   /*! \brief clock_pointer */
   size_t clock_ptr_;
   /*! \brief writer threads */
-  std::vector<std::unique_ptr<std::thread> > workers_;
+  std::vector<std::unique_ptr<std::thread>> workers_;
   /*! \brief recycler queue */
-  dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > qrecycle_;
+  dmlc::ConcurrentBlockingQueue<std::shared_ptr<T>> qrecycle_;
   /*! \brief worker threads */
-  std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > > qworkers_;
+  std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<T>>> qworkers_;
 };
 #endif  // DMLC_ENABLE_STD_THREAD
 
 /*!
  * \brief Registry entry for sparse page format.
  */
+template<typename T>
 struct SparsePageFormatReg
-    : public dmlc::FunctionRegEntryBase<SparsePageFormatReg,
-                                        std::function<SparsePageFormat* ()>> {
+    : public dmlc::FunctionRegEntryBase<SparsePageFormatReg<T>,
+                                        std::function<SparsePageFormat<T>* ()>> {
 };
 
 /*!
 * \brief Macro to register sparse page format.
 *
 * \code
 * // example of registering a objective
 * XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
 * .describe("Raw binary data format.")
 * .set_body([]() {
 *     return new RawFormat();
 *   });
 * \endcode
 */
+#define SparsePageFmt SparsePageFormat<SparsePage>
 #define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \
-  DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name)
+  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<SparsePage>, SparsePageFmt, Name)
+
+#define CSCPageFmt SparsePageFormat<CSCPage>
+#define XGBOOST_REGISTER_CSC_PAGE_FORMAT(Name) \
+  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<CSCPage>, CSCPageFmt, Name)
+
+#define SortedCSCPageFmt SparsePageFormat<SortedCSCPage>
+#define XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(Name) \
+  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<SortedCSCPage>, SortedCSCPageFmt, Name)
+
+#define EllpackPageFmt SparsePageFormat<EllpackPage>
+#define XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(Name) \
+  DMLC_REGISTRY_REGISTER(SparsePageFormatReg<EllpackPage>, EllpackPageFmt, Name)
 
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index a65d05ee4..4c6650ba4 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -174,16 +174,15 @@ template <int BLOCK_THREADS, typename GradientSumT, typename TempStorageT>
 __device__ void EvaluateFeature(
     int fidx, common::Span<const GradientSumT> node_histogram,
-    const xgboost::ELLPackMatrix& matrix,
+    const xgboost::EllpackMatrix& matrix,
     DeviceSplitCandidate* best_split,  // shared memory storing best split
     const DeviceNodeStats& node, const GPUTrainingParam& param,
     TempStorageT* temp_storage,  // temp memory for cub operations
     int constraint,              // monotonic_constraints
     const ValueConstraint& value_constraint) {
   // Use pointer from cut to indicate begin and end of bins for each feature.
-  uint32_t gidx_begin = matrix.feature_segments[fidx];  // begining bin
-  uint32_t gidx_end =
-      matrix.feature_segments[fidx + 1];  // end bin for i^th feature
+  uint32_t gidx_begin = matrix.info.feature_segments[fidx];  // beginning bin
+  uint32_t gidx_end = matrix.info.feature_segments[fidx + 1];  // end bin for i^th feature
 
   // Sum histogram bins for current feature
   GradientSumT const feature_sum = ReduceFeature<BLOCK_THREADS, GradientSumT>(
@@ -231,9 +230,9 @@ __device__ void EvaluateFeature(
       int split_gidx = (scan_begin + threadIdx.x) - 1;
       float fvalue;
       if (split_gidx < static_cast<int>(gidx_begin)) {
-        fvalue = matrix.min_fvalue[fidx];
+        fvalue = matrix.info.min_fvalue[fidx];
       } else {
-        fvalue = matrix.gidx_fvalue_map[split_gidx];
+        fvalue = matrix.info.gidx_fvalue_map[split_gidx];
       }
      GradientSumT right = parent_sum - left;
@@ -249,7 +248,7 @@ __global__ void EvaluateSplitKernel(
    common::Span<const GradientSumT> node_histogram,  // histogram for gradients
    common::Span<const int> feature_set,              // Selected features
    DeviceNodeStats node,
-    xgboost::ELLPackMatrix matrix,
+    xgboost::EllpackMatrix matrix,
    GPUTrainingParam gpu_param,
    common::Span<DeviceSplitCandidate> split_candidates,  // resulting split
    ValueConstraint value_constraint,
@@ -401,7 +400,7 @@ struct CalcWeightTrainParam {
};

template <typename GradientSumT>
-__global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
+__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
                                    common::Span<const RowPartitioner::RowIndexT> d_ridx,
                                    GradientSumT* d_node_hist,
                                    const GradientPair* d_gpair, size_t n_elements,
@@ -413,10 +412,10 @@ __global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
    __syncthreads();
  }
  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
-    int ridx = d_ridx[idx / matrix.row_stride ];
-    int gidx =
-        matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
-    if (gidx != matrix.null_gidx_value) {
+    int ridx = d_ridx[idx / matrix.info.row_stride];
+    int gidx =
+        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
+    if (gidx != matrix.info.n_bins) {
      // If we are not using shared memory, accumulate the values directly into
      // global memory
      GradientSumT* atomic_add_ptr =
@@ -606,7 +605,7 @@ struct GPUHistMakerDevice {
    int constexpr kBlockThreads = 256;
    EvaluateSplitKernel<kBlockThreads, GradientSumT>
        <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
-            hist.GetNodeHistogram(nidx), d_feature_set, node, page->ellpack_matrix,
+            hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix,
            gpu_param, d_split_candidates, node_value_constraints[nidx],
            monotone_constraints);

@@ -632,11 +631,11 @@ struct GPUHistMakerDevice {
    auto d_ridx = row_partitioner->GetRows(nidx);
    auto d_gpair = gpair.data();
-    auto n_elements = d_ridx.size() * page->ellpack_matrix.row_stride;
+    auto n_elements = d_ridx.size() * page->matrix.info.row_stride;

    const size_t smem_size =
        use_shared_memory_histograms
-            ? sizeof(GradientSumT) * page->ellpack_matrix.BinCount()
+            ? sizeof(GradientSumT) * page->matrix.BinCount()
            : 0;
    const int items_per_thread = 8;
    const int block_threads = 256;
@@ -646,7 +645,7 @@ struct GPUHistMakerDevice {
      return;
    }
    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
-        page->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
+        page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
        use_shared_memory_histograms);
  }

@@ -656,7 +655,7 @@ struct GPUHistMakerDevice {
    auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
    auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);

-    dh::LaunchN(device_id, page->n_bins, [=] __device__(size_t idx) {
+    dh::LaunchN(device_id, page->matrix.info.n_bins, [=] __device__(size_t idx) {
      d_node_hist_subtraction[idx] =
          d_node_hist_parent[idx] - d_node_hist_histogram[idx];
    });
@@ -671,7 +670,7 @@ struct GPUHistMakerDevice {
  }

  void UpdatePosition(int nidx, RegTree::Node split_node) {
-    auto d_matrix = page->ellpack_matrix;
+    auto d_matrix = page->matrix;

    row_partitioner->UpdatePosition(
        nidx, split_node.LeftChild(), split_node.RightChild(),
@@ -703,7 +702,7 @@ struct GPUHistMakerDevice {
    dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
                             d_nodes.size() * sizeof(RegTree::Node),
                             cudaMemcpyHostToDevice));
-    auto d_matrix = page->ellpack_matrix;
+    auto d_matrix = page->matrix;
    row_partitioner->FinalisePosition(
        [=] __device__(bst_uint ridx, int position) {
          auto node = d_nodes[position];
@@ -766,8 +765,7 @@ struct GPUHistMakerDevice {
    reducer->AllReduceSum(
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
-        page->ellpack_matrix.BinCount() *
-            (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
+        page->matrix.BinCount() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
    reducer->Synchronize();

    monitor.StopCuda("AllReduce");
@@ -956,14 +954,14 @@ inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
  // Check if we can use shared memory for building histograms
  // (assuming we need at least 2 CTAs per SM to maintain decent latency
  // hiding)
-  auto histogram_size = sizeof(GradientSumT) * page->n_bins;
+  auto histogram_size = sizeof(GradientSumT) * page->matrix.info.n_bins;
  auto max_smem = dh::MaxSharedMemory(device_id);
  if (histogram_size <= max_smem) {
    use_shared_memory_histograms = true;
  }

  // Init histogram
-  hist.Init(device_id, page->n_bins);
+  hist.Init(device_id, page->matrix.info.n_bins);
}

template <typename GradientSumT>
@@ -1017,22 +1015,23 @@ class GPUHistMakerSpecialised {
    // TODO(rongou): support multiple Ellpack pages.
    EllpackPageImpl* page{};
-    for (auto& batch : dmat->GetBatches<EllpackPage>()) {
+    for (auto& batch : dmat->GetBatches<EllpackPage>({device_,
+                                                      param_.max_bin,
+                                                      hist_maker_param_.gpu_batch_nrows})) {
      page = batch.Impl();
-      page->Init(device_, param_.max_bin, hist_maker_param_.gpu_batch_nrows);
    }

    dh::safe_cuda(cudaSetDevice(device_));
-    maker_.reset(new GPUHistMakerDevice<GradientSumT>(device_,
-                                                      page,
-                                                      info_->num_row_,
-                                                      param_,
-                                                      column_sampling_seed,
-                                                      info_->num_col_));
+    maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
+                                                     page,
+                                                     info_->num_row_,
+                                                     param_,
+                                                     column_sampling_seed,
+                                                     info_->num_col_));

    monitor_.StartCuda("InitHistogram");
    dh::safe_cuda(cudaSetDevice(device_));
-    maker_->InitHistogram();
+    maker->InitHistogram();
    monitor_.StopCuda("InitHistogram");

    p_last_fmat_ = dmat;
@@ -1071,17 +1070,17 @@ class GPUHistMakerSpecialised {
    monitor_.StopCuda("InitData");

    gpair->SetDevice(device_);
-    maker_->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
+    maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
  }

  bool UpdatePredictionCache(
      const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
-    if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
+    if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
      return false;
    }
    monitor_.StartCuda("UpdatePredictionCache");
    p_out_preds->SetDevice(device_);
-    maker_->UpdatePredictionCache(p_out_preds->DevicePointer());
+    maker->UpdatePredictionCache(p_out_preds->DevicePointer());
    monitor_.StopCuda("UpdatePredictionCache");
    return true;
  }
@@ -1089,7 +1088,7 @@ class GPUHistMakerSpecialised {
  TrainParam param_;   // NOLINT
  MetaInfo* info_{};   // NOLINT

-  std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker_;  // NOLINT
+  std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker;  // NOLINT

 private:
  bool initialised_;
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index 30744e22a..6dd97ef7c 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -17,15 +17,13 @@ TEST(EllpackPage, EmptyDMatrix) {
  constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256, kGpuBatchNRows = 64;
  constexpr float kSparsity = 0;
  auto dmat = *CreateDMatrix(kNRows, kNCols, kSparsity);
-  auto& page = *dmat->GetBatches<EllpackPage>().begin();
+  auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin, kGpuBatchNRows}).begin();
  auto impl = page.Impl();
-  impl->Init(0, kMaxBin, kGpuBatchNRows);
-  ASSERT_EQ(impl->ellpack_matrix.feature_segments.size(), 1);
-  ASSERT_EQ(impl->ellpack_matrix.min_fvalue.size(), 0);
-  ASSERT_EQ(impl->ellpack_matrix.gidx_fvalue_map.size(), 0);
-  ASSERT_EQ(impl->ellpack_matrix.row_stride, 0);
-  ASSERT_EQ(impl->ellpack_matrix.null_gidx_value, 0);
-  ASSERT_EQ(impl->n_bins, 0);
+  ASSERT_EQ(impl->matrix.info.feature_segments.size(), 1);
+  ASSERT_EQ(impl->matrix.info.min_fvalue.size(), 0);
+  ASSERT_EQ(impl->matrix.info.gidx_fvalue_map.size(), 0);
+  ASSERT_EQ(impl->matrix.info.row_stride, 0);
+  ASSERT_EQ(impl->matrix.info.n_bins, 0);
  ASSERT_EQ(impl->gidx_buffer.size(), 4);
}

@@ -37,7 +35,7 @@ TEST(EllpackPage, BuildGidxDense) {
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

-  ASSERT_EQ(page->ellpack_matrix.row_stride, kNCols);
+  ASSERT_EQ(page->matrix.info.row_stride, kNCols);

  std::vector<uint32_t> solution = {
    0, 3, 8, 9, 14, 17, 20, 21,
@@ -70,7 +68,7 @@ TEST(EllpackPage, BuildGidxSparse) {
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, page->gidx_buffer);
  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

-  ASSERT_LE(page->ellpack_matrix.row_stride, 3);
+  ASSERT_LE(page->matrix.info.row_stride, 3);

  // row_stride = 3, 16 rows, 48 entries for ELLPack
  std::vector<uint32_t> solution = {
@@ -78,7 +76,7 @@ TEST(EllpackPage, BuildGidxSparse) {
    24, 24, 24, 24, 24,  5, 24, 24,  0, 16, 24, 15,
    24, 24, 24, 24, 24,  7, 14, 16,  4, 24, 24, 24,
    24, 24,  9, 24, 24,  1, 24, 24
  };
-  for (size_t i = 0; i < kNRows * page->ellpack_matrix.row_stride; ++i) {
+  for (size_t i = 0; i < kNRows * page->matrix.info.row_stride; ++i) {
    ASSERT_EQ(solution[i], gidx[i]);
  }
}
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu
new file mode 100644
index 000000000..c95d86817
--- /dev/null
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cu
@@ -0,0 +1,26 @@
+// Copyright by Contributors
+
+#include <xgboost/data.h>
+#include "../helpers.h"
+
+namespace xgboost {
+
+TEST(GPUSparsePageDMatrix, EllpackPage) {
+  dmlc::TemporaryDirectory tempdir;
+  const std::string tmp_file = tempdir.path + "/simple.libsvm";
+  CreateSimpleTestData(tmp_file);
+  DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);
+
+  // Loop over the batches and assert the data is as expected.
+  for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256, 64})) {
+    EXPECT_EQ(batch.Size(), dmat->Info().num_row_);
+  }
+
+  EXPECT_TRUE(FileExists(tmp_file + ".cache"));
+  EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
+  EXPECT_TRUE(FileExists(tmp_file + ".cache.ellpack.page"));
+
+  delete dmat;
+}
+
+}  // namespace xgboost
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index d16ab4647..f6878339d 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -192,14 +192,14 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
  return dmat;
}

-std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
-                                                       size_t page_size, bool deterministic) {
+std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
+    size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
+    const dmlc::TemporaryDirectory& tempdir) {
  if (!n_rows || !n_cols) {
    return nullptr;
  }

  // Create the svm file in a temp dir
-  dmlc::TemporaryDirectory tempdir;
  const std::string tmp_file = tempdir.path + "/big.libsvm";
  std::ofstream fo(tmp_file.c_str());
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index a3ed85cff..489b45583 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -14,6 +14,7 @@
 #include <gtest/gtest.h>

+#include <dmlc/filesystem.h>
 #include <string>
 #include <vector>
 #include <memory>
@@ -199,8 +200,9 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
 *
 * \return The new dmatrix.
 */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
-                                                       size_t page_size, bool deterministic);
+std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
+    size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
+    const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());

gbm::GBTreeModel CreateTestModel();

@@ -247,16 +249,15 @@ inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
                  0.26f, 0.71f, 1.83f});
  cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});

-  auto is_dense = (*dmat)->Info().num_nonzero_ ==
-                  (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
  size_t row_stride = 0;
  const auto &offset_vec = batch.offset.ConstHostVector();
  for (size_t i = 1; i < offset_vec.size(); ++i) {
    row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
  }

-  auto page = std::unique_ptr<EllpackPageImpl>(new EllpackPageImpl(dmat->get()));
-  page->InitCompressedData(0, cmat, row_stride, is_dense);
+  auto page = std::unique_ptr<EllpackPageImpl>(new EllpackPageImpl(dmat->get(), {0, 256, 0}));
+  page->InitInfo(0, (*dmat)->IsDense(), row_stride, cmat);
+  page->InitCompressedData(0, n_rows);
  page->CreateHistIndices(0, batch, RowStateOnDevice(batch.Size(), batch.Size()));

  delete dmat;
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 261a9b898..1b234d350 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -2,6 +2,7 @@
 * Copyright 2017-2019 XGBoost contributors
 */
 #include <gtest/gtest.h>
+#include <dmlc/filesystem.h>
 #include <thrust/device_vector.h>
 #include <xgboost/base.h>
 #include <random>
@@ -207,14 +208,14 @@ TEST(GpuHist, EvaluateSplits) {
  // Copy cut matrix to device.
  maker.ba.Allocate(0,
-                    &(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(),
-                    &(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(),
-                    &(page->ellpack_matrix.gidx_fvalue_map), 24,
+                    &(page->matrix.info.feature_segments), cmat.Ptrs().size(),
+                    &(page->matrix.info.min_fvalue), cmat.MinValues().size(),
+                    &(page->matrix.info.gidx_fvalue_map), 24,
                    &(maker.monotone_constraints), kNCols);
-  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.feature_segments, cmat.Ptrs());
-  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.gidx_fvalue_map, cmat.Values());
+  dh::CopyVectorToDeviceSpan(page->matrix.info.feature_segments, cmat.Ptrs());
+  dh::CopyVectorToDeviceSpan(page->matrix.info.gidx_fvalue_map, cmat.Values());
  dh::CopyVectorToDeviceSpan(maker.monotone_constraints, param.monotone_constraints);
-  dh::CopyVectorToDeviceSpan(page->ellpack_matrix.min_fvalue, cmat.MinValues());
+  dh::CopyVectorToDeviceSpan(page->matrix.info.min_fvalue, cmat.MinValues());

  // Initialize GPUHistMakerDevice::hist
  maker.hist.Init(0, (max_bins - 1) * kNCols);
@@ -265,8 +266,10 @@ void TestHistogramIndexImpl() {
  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
  std::unique_ptr<DMatrix> hist_maker_dmat(
      CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
+
+  dmlc::TemporaryDirectory tempdir;
  std::unique_ptr<DMatrix> hist_maker_ext_dmat(
-      CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true));
+      CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir));

  std::vector<std::pair<std::string, std::string>> training_params = {
    {"max_depth", "10"},
@@ -275,22 +278,21 @@ void TestHistogramIndexImpl() {
  GenericParameter generic_param(CreateEmptyGenericParam(0));
  hist_maker.Configure(training_params, &generic_param);
-  hist_maker.InitDataOnce(hist_maker_dmat.get());
  hist_maker_ext.Configure(training_params, &generic_param);
  hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());

  // Extract the device maker from each histogram maker, and from it the
  // compressed histogram index.
-  const auto &maker = hist_maker.maker_;
+  const auto &maker = hist_maker.maker;
  std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, maker->page->gidx_buffer);

-  const auto &maker_ext = hist_maker_ext.maker_;
+  const auto &maker_ext = hist_maker_ext.maker;
  std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, maker_ext->page->gidx_buffer);

-  ASSERT_EQ(maker->page->n_bins, maker_ext->page->n_bins);
+  ASSERT_EQ(maker->page->matrix.info.n_bins, maker_ext->page->matrix.info.n_bins);
  ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());

  ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
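
To make the shape of the new batch API concrete, here is a minimal usage sketch (editor's illustration, not part of the patch): the helper name CountEllpackRows and the literal parameter values are hypothetical, while BatchParam, DMatrix::GetBatches<EllpackPage>(), and EllpackPage::Size() are the interfaces introduced above.

#include <xgboost/data.h>

namespace xgboost {

// Hypothetical helper: count rows by iterating device-resident ELLPACK pages.
inline size_t CountEllpackRows(DMatrix* dmat) {
  // BatchParam carries {gpu_id, max_bin, gpu_batch_nrows}; the values here
  // mirror those used in the tests above and are illustrative only.
  BatchParam param{0, 256, 64};
  size_t n_rows = 0;
  // An in-memory DMatrix yields a single page built from its CSR data; an
  // external-memory DMatrix (a "#cache" URI) streams pages back from the
  // *.cache.ellpack.page shards produced by the threaded SparsePageWriter.
  for (const auto& page : dmat->GetBatches<EllpackPage>(param)) {
    n_rows += page.Size();
  }
  return n_rows;
}

}  // namespace xgboost

Because GetBatches now takes a BatchParam, callers no longer initialize the page implementation themselves (the old page->Init call); page construction is driven entirely by the batch request.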