Dmatrix refactor stage 1 (#3301)
* Use sparse page as singular CSR matrix representation * Simplify dmatrix methods * Reduce statefullness of batch iterators * BREAKING CHANGE: Remove prob_buffer_row parameter. Users are instead recommended to sample their dataset as a preprocessing step before using XGBoost.
This commit is contained in:
@@ -86,7 +86,7 @@ class GBLinear : public GradientBooster {
|
||||
if (!p_fmat->HaveColAccess(false)) {
|
||||
monitor_.Start("InitColAccess");
|
||||
std::vector<bool> enabled(p_fmat->Info().num_col_, true);
|
||||
p_fmat->InitColAccess(enabled, 1.0f, param_.max_row_perbatch, false);
|
||||
p_fmat->InitColAccess(param_.max_row_perbatch, false);
|
||||
monitor_.Stop("InitColAccess");
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ class GBLinear : public GradientBooster {
|
||||
monitor_.Stop("PredictBatch");
|
||||
}
|
||||
// add base margin
|
||||
void PredictInstance(const SparseBatch::Inst &inst,
|
||||
void PredictInstance(const SparsePage::Inst &inst,
|
||||
std::vector<bst_float> *out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
@@ -152,15 +152,15 @@ class GBLinear : public GradientBooster {
|
||||
// make sure contributions is zeroed, we could be reusing a previously allocated one
|
||||
std::fill(contribs.begin(), contribs.end(), 0);
|
||||
// start collecting the contributions
|
||||
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
|
||||
auto iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch& batch = iter->Value();
|
||||
auto batch = iter->Value();
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const RowBatch::Inst &inst = batch[i];
|
||||
auto inst = batch[i];
|
||||
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
|
||||
// loop over output groups
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
@@ -203,15 +203,15 @@ class GBLinear : public GradientBooster {
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
const std::vector<bst_float>& base_margin = p_fmat->Info().base_margin_;
|
||||
// start collecting the prediction
|
||||
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
|
||||
auto iter = p_fmat->RowIterator();
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
preds.resize(p_fmat->Info().num_row_ * ngroup);
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
auto batch = iter->Value();
|
||||
// output convention: nrow * k, where nrow is number of rows
|
||||
// k is number of group
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<omp_ulong>(batch.size);
|
||||
const auto nsize = static_cast<omp_ulong>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < nsize; ++i) {
|
||||
const size_t ridx = batch.base_rowid + i;
|
||||
@@ -265,7 +265,7 @@ class GBLinear : public GradientBooster {
|
||||
}
|
||||
}
|
||||
|
||||
inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid,
|
||||
inline void Pred(const SparsePage::Inst &inst, bst_float *preds, int gid,
|
||||
bst_float base) {
|
||||
bst_float psum = model_.bias()[gid] + base;
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
|
||||
@@ -221,7 +221,7 @@ class GBTree : public GradientBooster {
|
||||
predictor_->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparseBatch::Inst& inst,
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
@@ -361,7 +361,7 @@ class Dart : public GBTree {
|
||||
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparseBatch::Inst& inst,
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
@@ -437,21 +437,21 @@ class Dart : public GBTree {
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
|
||||
auto iter = p_fmat->RowIterator();
|
||||
auto* self = static_cast<Derived*>(this);
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
auto batch = iter->Value();
|
||||
// parallel over local batch
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp_[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
RowBatch::Inst inst[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
@@ -470,7 +470,7 @@ class Dart : public GBTree {
|
||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||
RegTree::FVec& feats = thread_temp_[0];
|
||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
const RowBatch::Inst inst = batch[i];
|
||||
const SparsePage::Inst inst = batch[i];
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] +=
|
||||
@@ -497,7 +497,7 @@ class Dart : public GBTree {
|
||||
}
|
||||
|
||||
// predict the leaf scores without dropped trees
|
||||
inline bst_float PredValue(const RowBatch::Inst &inst,
|
||||
inline bst_float PredValue(const SparsePage::Inst &inst,
|
||||
int bst_group,
|
||||
unsigned root_index,
|
||||
RegTree::FVec *p_feats,
|
||||
|
||||
Reference in New Issue
Block a user