[TREE] switch to two pass
This commit is contained in:
parent
523afcbcd2
commit
a500fbc9b0
@ -361,8 +361,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||||
const RegTree &tree) override {
|
const RegTree &tree) override {
|
||||||
// remove this reset and do two pass reset on ResetPosAndPropose
|
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
|
||||||
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
|
||||||
}
|
}
|
||||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
@ -398,11 +397,20 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
// get smmary
|
// get smmary
|
||||||
thread_sketch.resize(this->get_nthread());
|
thread_sketch.resize(this->get_nthread());
|
||||||
|
|
||||||
|
// TWOPASS: use the real set + split set in the column iteration.
|
||||||
|
this->SetDefaultPostion(p_fmat, tree);
|
||||||
|
work_set.insert(work_set.end(), fsplit_set.begin(), fsplit_set.end());
|
||||||
|
std::sort(work_set.begin(), work_set.end());
|
||||||
|
work_set.resize(std::unique(work_set.begin(), work_set.end()) - work_set.begin());
|
||||||
|
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(work_set);
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(work_set);
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch &batch = iter->Value();
|
||||||
|
// TWOPASS: use the real set + split set in the column iteration.
|
||||||
|
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set, tree);
|
||||||
|
|
||||||
// start enumeration
|
// start enumeration
|
||||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user