Refactor DMatrix to return batches of different page types (#4686)
* Use explicit template parameter for specifying page type.
This commit is contained in:
@@ -140,7 +140,7 @@ class GBLinear : public GradientBooster {
|
||||
// make sure contributions is zeroed, we could be reusing a previously allocated one
|
||||
std::fill(contribs.begin(), contribs.end(), 0);
|
||||
// start collecting the contributions
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
@@ -198,7 +198,7 @@ class GBLinear : public GradientBooster {
|
||||
// start collecting the prediction
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
preds.resize(p_fmat->Info().num_row_ * ngroup);
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// output convention: nrow * k, where nrow is number of rows
|
||||
// k is number of group
|
||||
// parallel over local batch
|
||||
|
||||
@@ -371,7 +371,7 @@ class Dart : public GBTree {
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
auto* self = static_cast<Derived*>(this);
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
|
||||
Reference in New Issue
Block a user