first ver
This commit is contained in:
parent
698c010247
commit
c1f1bb9206
@ -36,6 +36,8 @@ struct TrainParam{
|
||||
float colsample_bytree;
|
||||
// speed optimization for dense column
|
||||
float opt_dense_col;
|
||||
// accuracy of sketch
|
||||
float sketch_eps;
|
||||
// leaf vector size
|
||||
int size_leaf_vector;
|
||||
// option for parallelization
|
||||
@ -58,6 +60,7 @@ struct TrainParam{
|
||||
nthread = 0;
|
||||
size_leaf_vector = 0;
|
||||
parallel_option = 2;
|
||||
sketch_eps = 0.1f;
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
@ -79,6 +82,7 @@ struct TrainParam{
|
||||
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "sketch_eps")) sketch_eps = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
|
||||
if (!strcmp(name, "max_depth")) max_depth = atoi(val);
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/quantile.h"
|
||||
#include "../utils/group_data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -145,7 +146,8 @@ class HistMaker: public IUpdater {
|
||||
// this function does two jobs
|
||||
// (1) reset the position in array position, to be the latest leaf id
|
||||
// (2) propose a set of candidate cuts and set wspace.rptr and wspace.cut correctly
|
||||
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
|
||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const RegTree &tree) = 0;
|
||||
private:
|
||||
@ -249,7 +251,7 @@ class HistMaker: public IUpdater {
|
||||
const int nid = position[ridx];
|
||||
if (nid >= 0) {
|
||||
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
|
||||
const int wid = node2workindex[nid];
|
||||
const int wid = this->node2workindex[nid];
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
|
||||
// feature histogram
|
||||
@ -302,7 +304,7 @@ class HistMaker: public IUpdater {
|
||||
RegTree *p_tree) {
|
||||
const bst_uint num_feature = p_tree->param.num_feature;
|
||||
// reset and propose candidate split
|
||||
this->ResetPosAndPropose(p_fmat, info, *p_tree);
|
||||
this->ResetPosAndPropose(gpair, p_fmat, info, *p_tree);
|
||||
// create histogram
|
||||
this->CreateHist(gpair, p_fmat, info, *p_tree);
|
||||
// get the best split condition for each node
|
||||
@ -347,17 +349,32 @@ class HistMaker: public IUpdater {
|
||||
|
||||
// hist maker that proposes candidate cuts using quantile sketches
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
protected:
|
||||
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
|
||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const RegTree &tree) {
|
||||
// initialize the data structure
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||
}
|
||||
// start accumulating statistics
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
|
||||
// parallel convert to column major format
|
||||
utils::ParallelGroupBuilder<SparseBatch::Entry> builder(&col_ptr, &col_data, &thread_col_ptr);
|
||||
builder.InitBudget(tree.param.num_feature, nthread);
|
||||
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
RowBatch::Inst inst = batch[i];
|
||||
@ -367,13 +384,54 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
if (tree[nid].is_leaf()) {
|
||||
this->position[ridx] = ~nid;
|
||||
} else {
|
||||
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
||||
// todo add the cut point setup
|
||||
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
builder.AddBudget(inst[j].index, omp_get_thread_num());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
RowBatch::Inst inst = batch[i];
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
const int nid = this->position[ridx];
|
||||
if (nid >= 0) {
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
builder.Push(inst[j].index,
|
||||
SparseBatch::Entry(nid, inst[j].fvalue),
|
||||
omp_get_thread_num());
|
||||
}
|
||||
}
|
||||
}
|
||||
// start putting things into sketch
|
||||
const bst_omp_uint nfeat = tree.param.num_feature;
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint k = 0; k < nfeat; ++k) {
|
||||
for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
|
||||
const SparseBatch::Entry &e = col_data[i];
|
||||
const int wid = this->node2workindex[e.index];
|
||||
sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].hess);
|
||||
}
|
||||
}
|
||||
}
|
||||
// synchronize sketch
|
||||
|
||||
|
||||
// now that all the results are in sketchs, try to set up the cut points
|
||||
}
|
||||
|
||||
private:
|
||||
//
|
||||
|
||||
// local temp column data structure
|
||||
std::vector<size_t> col_ptr;
|
||||
// local storage of column data
|
||||
std::vector<SparseBatch::Entry> col_data;
|
||||
std::vector< std::vector<size_t> > thread_col_ptr;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user