Move ellpack page construction into DMatrix (#4833)
This commit is contained in:
@@ -26,6 +26,8 @@
|
||||
namespace xgboost {
|
||||
// forward declare learner.
|
||||
class LearnerImpl;
|
||||
// forward declare dmatrix.
|
||||
class DMatrix;
|
||||
|
||||
/*! \brief data type accepted by xgboost interface */
|
||||
enum DataType {
|
||||
@@ -86,7 +88,7 @@ class MetaInfo {
|
||||
* \return The pre-defined root index of i-th instance.
|
||||
*/
|
||||
inline unsigned GetRoot(size_t i) const {
|
||||
return root_index_.size() != 0 ? root_index_[i] : 0U;
|
||||
return !root_index_.empty() ? root_index_[i] : 0U;
|
||||
}
|
||||
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
|
||||
inline const std::vector<size_t>& LabelAbsSort() const {
|
||||
@@ -166,7 +168,7 @@ class SparsePage {
|
||||
/*! \brief the data of the segments */
|
||||
HostDeviceVector<Entry> data;
|
||||
|
||||
size_t base_rowid;
|
||||
size_t base_rowid{};
|
||||
|
||||
/*! \brief an instance of sparse vector in the batch */
|
||||
using Inst = common::Span<Entry const>;
|
||||
@@ -215,23 +217,23 @@ class SparsePage {
|
||||
const int nthread = omp_get_max_threads();
|
||||
builder.InitBudget(num_columns, nthread);
|
||||
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
|
||||
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = (*this)[i];
|
||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||
builder.AddBudget(inst[j].index, tid);
|
||||
for (const auto& entry : inst) {
|
||||
builder.AddBudget(entry.index, tid);
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
|
||||
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = (*this)[i];
|
||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||
for (const auto& entry : inst) {
|
||||
builder.Push(
|
||||
inst[j].index,
|
||||
Entry(static_cast<bst_uint>(this->base_rowid + i), inst[j].fvalue),
|
||||
entry.index,
|
||||
Entry(static_cast<bst_uint>(this->base_rowid + i), entry.fvalue),
|
||||
tid);
|
||||
}
|
||||
}
|
||||
@@ -240,7 +242,7 @@ class SparsePage {
|
||||
|
||||
void SortRows() {
|
||||
auto ncol = static_cast<bst_omp_uint>(this->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
#pragma omp parallel for default(none) shared(ncol) schedule(dynamic, 1)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
|
||||
std::sort(
|
||||
@@ -287,10 +289,29 @@ class SortedCSCPage : public SparsePage {
|
||||
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
|
||||
};
|
||||
|
||||
class EllpackPageImpl;
|
||||
/*!
|
||||
* \brief A page stored in ELLPACK format.
|
||||
*
|
||||
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
|
||||
* including CUDA-specific implementation details in the header.
|
||||
*/
|
||||
class EllpackPage {
|
||||
public:
|
||||
explicit EllpackPage(DMatrix* dmat);
|
||||
~EllpackPage();
|
||||
|
||||
const EllpackPageImpl* Impl() const { return impl_.get(); }
|
||||
EllpackPageImpl* Impl() { return impl_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<EllpackPageImpl> impl_;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class BatchIteratorImpl {
|
||||
public:
|
||||
virtual ~BatchIteratorImpl() {}
|
||||
virtual ~BatchIteratorImpl() = default;
|
||||
virtual T& operator*() = 0;
|
||||
virtual const T& operator*() const = 0;
|
||||
virtual void operator++() = 0;
|
||||
@@ -412,7 +433,7 @@ class DMatrix {
|
||||
bool silent,
|
||||
bool load_row_split,
|
||||
const std::string& file_format = "auto",
|
||||
const size_t page_size = kPageSize);
|
||||
size_t page_size = kPageSize);
|
||||
|
||||
/*!
|
||||
* \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
|
||||
@@ -438,7 +459,7 @@ class DMatrix {
|
||||
*/
|
||||
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
|
||||
const std::string& cache_prefix = "",
|
||||
const size_t page_size = kPageSize);
|
||||
size_t page_size = kPageSize);
|
||||
|
||||
/*! \brief page size 32 MB */
|
||||
static const size_t kPageSize = 32UL << 20UL;
|
||||
@@ -447,6 +468,7 @@ class DMatrix {
|
||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches() = 0;
|
||||
};
|
||||
|
||||
template<>
|
||||
@@ -463,6 +485,11 @@ template<>
|
||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
|
||||
return GetSortedColumnBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<EllpackPage> DMatrix::GetBatches() {
|
||||
return GetEllpackBatches();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace dmlc {
|
||||
|
||||
Reference in New Issue
Block a user