Move ellpack page construction into DMatrix (#4833)

This commit is contained in:
Rong Ou
2019-09-16 20:50:55 -07:00
committed by Jiaming Yuan
parent 512f037e55
commit 125bcec62e
17 changed files with 761 additions and 513 deletions

View File

@@ -26,6 +26,8 @@
namespace xgboost {
// forward declare learner.
class LearnerImpl;
// forward declare dmatrix.
class DMatrix;
/*! \brief data type accepted by xgboost interface */
enum DataType {
@@ -86,7 +88,7 @@ class MetaInfo {
* \return The pre-defined root index of i-th instance.
*/
inline unsigned GetRoot(size_t i) const {
return root_index_.size() != 0 ? root_index_[i] : 0U;
return !root_index_.empty() ? root_index_[i] : 0U;
}
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
inline const std::vector<size_t>& LabelAbsSort() const {
@@ -166,7 +168,7 @@ class SparsePage {
/*! \brief the data of the segments */
HostDeviceVector<Entry> data;
size_t base_rowid;
size_t base_rowid{};
/*! \brief an instance of sparse vector in the batch */
using Inst = common::Span<Entry const>;
@@ -215,23 +217,23 @@ class SparsePage {
const int nthread = omp_get_max_threads();
builder.InitBudget(num_columns, nthread);
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = (*this)[i];
for (bst_uint j = 0; j < inst.size(); ++j) {
builder.AddBudget(inst[j].index, tid);
for (const auto& entry : inst) {
builder.AddBudget(entry.index, tid);
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static)
#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = (*this)[i];
for (bst_uint j = 0; j < inst.size(); ++j) {
for (const auto& entry : inst) {
builder.Push(
inst[j].index,
Entry(static_cast<bst_uint>(this->base_rowid + i), inst[j].fvalue),
entry.index,
Entry(static_cast<bst_uint>(this->base_rowid + i), entry.fvalue),
tid);
}
}
@@ -240,7 +242,7 @@ class SparsePage {
void SortRows() {
auto ncol = static_cast<bst_omp_uint>(this->Size());
#pragma omp parallel for schedule(dynamic, 1)
#pragma omp parallel for default(none) shared(ncol) schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
std::sort(
@@ -287,10 +289,29 @@ class SortedCSCPage : public SparsePage {
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
class EllpackPageImpl;
/*!
* \brief A page stored in ELLPACK format.
*
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
* including CUDA-specific implementation details in the header.
*/
class EllpackPage {
public:
explicit EllpackPage(DMatrix* dmat);
~EllpackPage();
const EllpackPageImpl* Impl() const { return impl_.get(); }
EllpackPageImpl* Impl() { return impl_.get(); }
private:
std::unique_ptr<EllpackPageImpl> impl_;
};
template<typename T>
class BatchIteratorImpl {
public:
virtual ~BatchIteratorImpl() {}
virtual ~BatchIteratorImpl() = default;
virtual T& operator*() = 0;
virtual const T& operator*() const = 0;
virtual void operator++() = 0;
@@ -412,7 +433,7 @@ class DMatrix {
bool silent,
bool load_row_split,
const std::string& file_format = "auto",
const size_t page_size = kPageSize);
size_t page_size = kPageSize);
/*!
* \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
@@ -438,7 +459,7 @@ class DMatrix {
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
const std::string& cache_prefix = "",
const size_t page_size = kPageSize);
size_t page_size = kPageSize);
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;
@@ -447,6 +468,7 @@ class DMatrix {
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
};
template<>
@@ -463,6 +485,11 @@ template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}
template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches() {
return GetEllpackBatches();
}
} // namespace xgboost
namespace dmlc {