Write ELLPACK pages to disk (#4879)

* add ellpack source
* add batch param
* extract function to parse cache info
* construct ellpack info separately
* push batch to ellpack page
* write ellpack page.
* make sparse page source reusable
This commit is contained in:
Rong Ou
2019-10-22 20:44:32 -07:00
committed by Jiaming Yuan
parent 310fe60b35
commit 5b1715d97c
25 changed files with 935 additions and 408 deletions

View File

@@ -156,6 +156,18 @@ struct Entry {
}
};
/*!
* \brief Parameters for constructing batches.
*/
struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id;
/*! \brief Maximum number of bins per feature for histograms. */
int max_bin;
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows;
};
/*!
* \brief In-memory storage unit of sparse batch, stored in CSR format.
*/
@@ -191,14 +203,17 @@ class SparsePage {
SparsePage() {
this->Clear();
}
/*! \return number of instance in the page */
/*! \return Number of instances in the page. */
inline size_t Size() const {
return offset.Size() - 1;
}
/*! \return estimation of memory cost of this page */
inline size_t MemCostBytes() const {
return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
}
/*! \brief clear the page */
inline void Clear() {
base_rowid = 0;
@@ -208,6 +223,11 @@ class SparsePage {
data.HostVector().clear();
}
/*! \brief Set the base row id for this page. */
inline void SetBaseRowId(size_t row_id) {
base_rowid = row_id;
}
SparsePage GetTranspose(int num_columns) const;
void SortRows() {
@@ -238,13 +258,6 @@ class SparsePage {
* \param batch The row batch to be pushed
*/
void PushCSC(const SparsePage& batch);
/*!
* \brief Push one instance into page
* \param inst an instance row
*/
void Push(const Inst &inst);
size_t Size() { return offset.Size() - 1; }
};
class CSCPage: public SparsePage {
@@ -268,9 +281,31 @@ class EllpackPageImpl;
*/
class EllpackPage {
public:
explicit EllpackPage(DMatrix* dmat);
/*!
* \brief Default constructor.
*
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
* set later by the reader.
*/
EllpackPage();
/*!
* \brief Constructor from an existing DMatrix.
*
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
/*! \brief Destructor. */
~EllpackPage();
/*! \return Number of instances in the page. */
size_t Size() const;
/*! \brief Set the base row id for this page. */
void SetBaseRowId(size_t row_id);
const EllpackPageImpl* Impl() const { return impl_.get(); }
EllpackPageImpl* Impl() { return impl_.get(); }
@@ -356,7 +391,8 @@ class DataSource : public dmlc::DataIter<T> {
* There are two ways to create a customized DMatrix that reads in user defined-format.
*
* - Provide a dmlc::Parser and pass into the DMatrix::Create
* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER;
* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by
* DMLC_REGISTER_DATA_PARSER;
* - This works best for user defined data input source, such as data-base, filesystem.
* - Provide a DataSource, that can be passed to DMatrix::Create
* This can be used to re-use inmemory data structure into DMatrix.
@@ -373,7 +409,7 @@ class DMatrix {
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
template<typename T>
BatchSet<T> GetBatches();
BatchSet<T> GetBatches(const BatchParam& param = {});
// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
@@ -389,6 +425,12 @@ class DMatrix {
* \return The created DMatrix.
*/
virtual void SaveToLocalFile(const std::string& fname);
/*! \brief Whether the matrix is dense. */
bool IsDense() const {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}
/*!
* \brief Load DMatrix from URI.
* \param uri The URI of input.
@@ -438,27 +480,27 @@ class DMatrix {
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
};
template<>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
return GetRowBatches();
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches() {
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
return GetColumnBatches();
}
template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(const BatchParam&) {
return GetSortedColumnBatches();
}
template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches() {
return GetEllpackBatches();
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
}
} // namespace xgboost