first ver

This commit is contained in:
tqchen 2014-11-15 09:46:30 -08:00
parent 698c010247
commit c1f1bb9206
2 changed files with 72 additions and 10 deletions

View File

@ -36,6 +36,8 @@ struct TrainParam{
float colsample_bytree; float colsample_bytree;
// speed optimization for dense column // speed optimization for dense column
float opt_dense_col; float opt_dense_col;
// accuracy of sketch
float sketch_eps;
// leaf vector size // leaf vector size
int size_leaf_vector; int size_leaf_vector;
// option for parallelization // option for parallelization
@ -58,6 +60,7 @@ struct TrainParam{
nthread = 0; nthread = 0;
size_leaf_vector = 0; size_leaf_vector = 0;
parallel_option = 2; parallel_option = 2;
sketch_eps = 0.1f;
} }
/*! /*!
* \brief set parameters from outside * \brief set parameters from outside
@ -79,6 +82,7 @@ struct TrainParam{
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val)); if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val)); if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val)); if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
if (!strcmp(name, "sketch_eps")) sketch_eps = static_cast<float>(atof(val));
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val)); if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val); if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
if (!strcmp(name, "max_depth")) max_depth = atoi(val); if (!strcmp(name, "max_depth")) max_depth = atoi(val);

View File

@ -9,6 +9,7 @@
#include <algorithm> #include <algorithm>
#include "../sync/sync.h" #include "../sync/sync.h"
#include "../utils/quantile.h" #include "../utils/quantile.h"
#include "../utils/group_data.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -145,7 +146,8 @@ class HistMaker: public IUpdater {
// this function does two jobs // this function does two jobs
// (1) reset the position in array position, to be the latest leaf id // (1) reset the position in array position, to be the latest leaf id
// (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly // (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
virtual void ResetPosAndPropose(IFMatrix *p_fmat, virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info, const BoosterInfo &info,
const RegTree &tree) = 0; const RegTree &tree) = 0;
private: private:
@ -249,7 +251,7 @@ class HistMaker: public IUpdater {
const int nid = position[ridx]; const int nid = position[ridx];
if (nid >= 0) { if (nid >= 0) {
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf"); utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
const int wid = node2workindex[nid]; const int wid = this->node2workindex[nid];
for (bst_uint i = 0; i < inst.length; ++i) { for (bst_uint i = 0; i < inst.length; ++i) {
utils::Assert(inst[i].index < num_feature, "feature index exceed bound"); utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
// feature histogram // feature histogram
@ -302,7 +304,7 @@ class HistMaker: public IUpdater {
RegTree *p_tree) { RegTree *p_tree) {
const bst_uint num_feature = p_tree->param.num_feature; const bst_uint num_feature = p_tree->param.num_feature;
// reset and propose candidate split // reset and propose candidate split
this->ResetPosAndPropose(p_fmat, info, *p_tree); this->ResetPosAndPropose(gpair, p_fmat, info, *p_tree);
// create histogram // create histogram
this->CreateHist(gpair, p_fmat, info, *p_tree); this->CreateHist(gpair, p_fmat, info, *p_tree);
// get the best split condition for each node // get the best split condition for each node
@ -349,14 +351,29 @@ class HistMaker: public IUpdater {
template<typename TStats> template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> { class QuantileHistMaker: public HistMaker<TStats> {
protected: protected:
virtual void ResetPosAndPropose(IFMatrix *p_fmat, virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info, const BoosterInfo &info,
const RegTree &tree) { const RegTree &tree) {
// initialize the data structure
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
for (size_t i = 0; i < sketchs.size(); ++i) {
sketchs[i].Init(info.num_row, this->param.sketch_eps);
}
// start accumulating statistics // start accumulating statistics
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator(); utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const RowBatch &batch = iter->Value(); const RowBatch &batch = iter->Value();
// parallel convert to column major format
utils::ParallelGroupBuilder<SparseBatch::Entry> builder(&col_ptr, &col_data, &thread_col_ptr);
builder.InitBudget(tree.param.num_feature, nthread);
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size); const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) { for (bst_omp_uint i = 0; i < nbatch; ++i) {
@ -368,12 +385,53 @@ class QuantileHistMaker: public HistMaker<TStats> {
this->position[ridx] = ~nid; this->position[ridx] = ~nid;
} else { } else {
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid); this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
// todo add the cut point setup for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].index, omp_get_thread_num());
}
} }
} }
} }
builder.InitStorage();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
RowBatch::Inst inst = batch[i];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
const int nid = this->position[ridx];
if (nid >= 0) {
for (bst_uint j = 0; j < inst.length; ++j) {
builder.Push(inst[j].index,
SparseBatch::Entry(nid, inst[j].fvalue),
omp_get_thread_num());
}
}
}
// start putting things into sketch
const bst_omp_uint nfeat = tree.param.num_feature;
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint k = 0; k < nfeat; ++k) {
for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
const SparseBatch::Entry &e = col_data[i];
const int wid = this->node2workindex[e.index];
sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].hess);
}
}
} }
// synchronize sketch
// now we have all the results in the sketchs, try to setup the cut point
} }
private:
//
// local temp column data structure
std::vector<size_t> col_ptr;
// local storage of column data
std::vector<SparseBatch::Entry> col_data;
std::vector< std::vector<size_t> > thread_col_ptr;
// per node, per feature sketch
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
}; };
} // namespace tree } // namespace tree