xgboost/src/data/sparse_page_dmatrix.h
2021-07-01 00:44:49 +08:00

77 lines
2.4 KiB
C++

/*!
* Copyright 2015 by Contributors
* \file sparse_page_dmatrix.h
* \brief External-memory version of DMatrix.
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#include <xgboost/data.h>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ellpack_page_source.h"
#include "sparse_page_source.h"
namespace xgboost {
namespace data {
/*!
 * \brief DMatrix implementation used for external memory.
 *
 * The object owns one page source per page format (row, column, sorted column
 * and ellpack).  Only the row source is created by the constructor in this
 * header; the remaining sources start out empty and are presumably populated
 * by the Get*Batches() implementations, which are defined outside this header.
 */
class SparsePageDMatrix : public DMatrix {
 public:
  /*!
   * \brief Construct the matrix from an adapter and create the row page source.
   *
   * \param adapter      Input adapter, forwarded to SparsePageSource.
   * \param missing      Forwarded to SparsePageSource (presumably the value
   *                     treated as missing).
   * \param nthread      Forwarded to SparsePageSource (presumably the number
   *                     of threads to use).
   * \param cache_prefix Cache prefix; a copy is kept in cache_info_ and the
   *                     same value is forwarded to SparsePageSource.
   * \param page_size    Forwarded to SparsePageSource, defaults to kPageSize.
   */
  template <typename AdapterT>
  explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread,
                             const std::string& cache_prefix,
                             size_t page_size = kPageSize)
      // `cache_prefix` is a const reference, so it can only be copied.  It is
      // also read again below, hence no std::move here.
      : cache_info_(cache_prefix) {
    row_source_.reset(new data::SparsePageSource(adapter, missing, nthread,
                                                 cache_prefix, page_size));
  }

  ~SparsePageDMatrix() override = default;

  /*! \brief Mutable access to the meta information (defined out of line). */
  MetaInfo& Info() override;
  /*! \brief Read-only access to the meta information (defined out of line). */
  const MetaInfo& Info() const override;

  /*! \brief Always false for this DMatrix implementation. */
  bool SingleColBlock() const override { return false; }

  /*!
   * \brief Slicing is not supported: a fatal error is logged on every call.
   * \return nullptr, should control ever get past the fatal log statement.
   */
  DMatrix *Slice(common::Span<int32_t const>) override {
    LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
    return nullptr;
  }

 private:
  // Batch accessors for each page format; implemented outside this header.
  BatchSet<SparsePage> GetRowBatches() override;
  BatchSet<CSCPage> GetColumnBatches() override;
  BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;

  /*!
   * \brief Gradient index batches are not implemented: a fatal error is logged
   *        on every call.
   * \return A batch set built from a null iterator implementation, should
   *         control ever get past the fatal log statement.
   */
  BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
  }

  // Source data pointers, one per page format.  row_source_ is set by the
  // constructor; the others are null until assigned elsewhere.
  std::unique_ptr<SparsePageSource> row_source_;
  std::unique_ptr<CSCPageSource> column_source_;
  std::unique_ptr<SortedCSCPageSource> sorted_column_source_;
  std::unique_ptr<EllpackPageSource> ellpack_source_;
  // Saved batch param.
  BatchParam batch_param_;
  // The cache prefix handed to the constructor.
  std::string cache_info_;
  // Store column densities to avoid recalculating.
  std::vector<float> col_density_;

  /*! \brief True when an ellpack page source is currently held. */
  bool EllpackExists() const override {
    return static_cast<bool>(ellpack_source_);
  }

  /*! \brief True when a row (sparse page) source is currently held. */
  bool SparsePageExists() const override {
    return static_cast<bool>(row_source_);
  }
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_