first ver

This commit is contained in:
tqchen 2014-11-15 09:46:30 -08:00
parent 698c010247
commit c1f1bb9206
2 changed files with 72 additions and 10 deletions

View File

@ -36,6 +36,8 @@ struct TrainParam{
float colsample_bytree;
// speed optimization for dense column
float opt_dense_col;
// accuracy of sketch
float sketch_eps;
// leaf vector size
int size_leaf_vector;
// option for parallelization
@ -58,6 +60,7 @@ struct TrainParam{
nthread = 0;
size_leaf_vector = 0;
parallel_option = 2;
sketch_eps = 0.1f;
}
/*!
* \brief set parameters from outside
@ -79,6 +82,7 @@ struct TrainParam{
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
if (!strcmp(name, "sketch_eps")) sketch_eps = static_cast<float>(atof(val));
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
if (!strcmp(name, "max_depth")) max_depth = atoi(val);

View File

@ -9,6 +9,7 @@
#include <algorithm>
#include "../sync/sync.h"
#include "../utils/quantile.h"
#include "../utils/group_data.h"
namespace xgboost {
namespace tree {
@ -145,7 +146,8 @@ class HistMaker: public IUpdater {
// this function does two jobs
// (1) reset the position in array position, to be the latest leaf id
// (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info,
const RegTree &tree) = 0;
private:
@ -249,7 +251,7 @@ class HistMaker: public IUpdater {
const int nid = position[ridx];
if (nid >= 0) {
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
const int wid = node2workindex[nid];
const int wid = this->node2workindex[nid];
for (bst_uint i = 0; i < inst.length; ++i) {
utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
// feature histogram
@ -302,7 +304,7 @@ class HistMaker: public IUpdater {
RegTree *p_tree) {
const bst_uint num_feature = p_tree->param.num_feature;
// reset and propose candidate split
this->ResetPosAndPropose(p_fmat, info, *p_tree);
this->ResetPosAndPropose(gpair, p_fmat, info, *p_tree);
// create histogram
this->CreateHist(gpair, p_fmat, info, *p_tree);
// get the best split condition for each node
@ -347,17 +349,32 @@ class HistMaker: public IUpdater {
// hist maker that propose using quantile sketch
template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> {
class QuantileHistMaker: public HistMaker<TStats> {
protected:
virtual void ResetPosAndPropose(IFMatrix *p_fmat,
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info,
const RegTree &tree) {
// initialize the data structure
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
for (size_t i = 0; i < sketchs.size(); ++i) {
sketchs[i].Init(info.num_row, this->param.sketch_eps);
}
// start accumulating statistics
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
// parallel convert to column major format
utils::ParallelGroupBuilder<SparseBatch::Entry> builder(&col_ptr, &col_data, &thread_col_ptr);
builder.InitBudget(tree.param.num_feature, nthread);
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
RowBatch::Inst inst = batch[i];
@ -367,13 +384,54 @@ class QuantileHistMaker: public HistMaker<TStats> {
if (tree[nid].is_leaf()) {
this->position[ridx] = ~nid;
} else {
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
// todo add the cut point setup
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].index, omp_get_thread_num());
}
}
}
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
RowBatch::Inst inst = batch[i];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
const int nid = this->position[ridx];
if (nid >= 0) {
for (bst_uint j = 0; j < inst.length; ++j) {
builder.Push(inst[j].index,
SparseBatch::Entry(nid, inst[j].fvalue),
omp_get_thread_num());
}
}
}
// start putting things into sketch
const bst_omp_uint nfeat = tree.param.num_feature;
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint k = 0; k < nfeat; ++k) {
for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
const SparseBatch::Entry &e = col_data[i];
const int wid = this->node2workindex[e.index];
sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].hess);
}
}
}
// synchronize sketch
// now we have all the results in the sketchs, try to setup the cut point
}
private:
//
// local temp column data structure
std::vector<size_t> col_ptr;
// local storage of column data
std::vector<SparseBatch::Entry> col_data;
std::vector< std::vector<size_t> > thread_col_ptr;
// per node, per feature sketch
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
};
} // namespace tree