Refactor DMatrix to return batches of different page types (#4686)
* Use explicit template parameter for specifying page type.
This commit is contained in:
@@ -53,7 +53,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
@@ -230,7 +230,7 @@ class CPUPredictor : public Predictor {
|
||||
std::vector<bst_float>& preds = *out_preds;
|
||||
preds.resize(info.num_row_ * ntree_limit);
|
||||
// start collecting the prediction
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
@@ -276,7 +276,7 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
|
||||
// start collecting the contributions
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
|
||||
Reference in New Issue
Block a user