ok
This commit is contained in:
@@ -15,7 +15,7 @@ IUpdater* CreateUpdater(const char *name) {
|
||||
if (!strcmp(name, "prune")) return new TreePruner();
|
||||
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_histmaker")) return new HistMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_histmaker")) return new QuantileHistMaker<GradStats>();
|
||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
|
||||
if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >();
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/quantile.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -140,7 +141,13 @@ class HistMaker: public IUpdater {
|
||||
}
|
||||
return n.cdefault();
|
||||
}
|
||||
|
||||
|
||||
// this function does two jobs
|
||||
// (1) reset each entry of the position array to the latest leaf id
|
||||
// (2) propose a set of candidate cuts and set wspace.rptr and wspace.cut correctly
|
||||
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const RegTree &tree) = 0;
|
||||
private:
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
@@ -160,7 +167,8 @@ class HistMaker: public IUpdater {
|
||||
inline void InitData(const std::vector<bst_gpair> &gpair,
|
||||
const IFMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index, const RegTree &tree) {
|
||||
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "HistMaker: can only grow new tree");
|
||||
utils::Assert(tree.param.num_nodes == tree.param.num_roots,
|
||||
"HistMaker: can only grow new tree");
|
||||
{// setup position
|
||||
position.resize(gpair.size());
|
||||
if (root_index.size() == 0) {
|
||||
@@ -212,15 +220,6 @@ class HistMaker: public IUpdater {
|
||||
node2workindex[qexpand[i]] = static_cast<int>(i);
|
||||
}
|
||||
}
|
||||
// this function does two jobs
|
||||
// (1) reset the position in array position, to be the latest leaf id
|
||||
// (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
|
||||
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const RegTree &tree) {
|
||||
|
||||
}
|
||||
// create histogram for a setup histset
|
||||
inline void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
@@ -250,7 +249,7 @@ class HistMaker: public IUpdater {
|
||||
const int nid = position[ridx];
|
||||
if (nid >= 0) {
|
||||
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
|
||||
const int wid = node2workindex[nid];
|
||||
const int wid = node2workindex[nid];
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
|
||||
// feature histogram
|
||||
@@ -312,7 +311,8 @@ class HistMaker: public IUpdater {
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
|
||||
const int nid = qexpand[wid];
|
||||
utils::Assert(node2workindex[nid] == static_cast<int>(wid), "node2workindex inconsistent");
|
||||
utils::Assert(node2workindex[nid] == static_cast<int>(wid),
|
||||
"node2workindex inconsistent");
|
||||
SplitEntry &best = sol[wid];
|
||||
TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
for (bst_uint fid = 0; fid < num_feature; ++ fid) {
|
||||
@@ -345,6 +345,36 @@ class HistMaker: public IUpdater {
|
||||
}
|
||||
};
|
||||
|
||||
// hist maker that propose using quantile sketch
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
protected:
|
||||
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const RegTree &tree) {
|
||||
// start accumulating statistics
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
RowBatch::Inst inst = batch[i];
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
int nid = this->position[ridx];
|
||||
if (nid >= 0) {
|
||||
if (tree[nid].is_leaf()) {
|
||||
this->position[ridx] = ~nid;
|
||||
} else {
|
||||
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
||||
// todo add the cut point setup
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user