a version that compile

This commit is contained in:
tqchen
2014-11-15 17:41:03 -08:00
parent c1f1bb9206
commit c86b83ea04
5 changed files with 204 additions and 9 deletions

View File

@@ -38,6 +38,8 @@ struct TrainParam{
float opt_dense_col;
// accuracy of sketch
float sketch_eps;
// accuracy of sketch
float sketch_ratio;
// leaf vector size
int size_leaf_vector;
// option for parallelization
@@ -61,6 +63,7 @@ struct TrainParam{
size_leaf_vector = 0;
parallel_option = 2;
sketch_eps = 0.1f;
sketch_ratio = 1.4f;
}
/*!
* \brief set parameters from outside
@@ -83,6 +86,7 @@ struct TrainParam{
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
if (!strcmp(name, "sketch_eps")) sketch_eps = static_cast<float>(atof(val));
if (!strcmp(name, "sketch_ratio")) sketch_ratio = static_cast<float>(atof(val));
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
if (!strcmp(name, "max_depth")) max_depth = atoi(val);

View File

@@ -124,8 +124,7 @@ class HistMaker: public IUpdater {
/*! \brief map active node to is working index offset in qexpand*/
std::vector<int> node2workindex;
// reducer for histogram
sync::Reducer<TStats> histred;
sync::Reducer<TStats> histred;
// helper function to get to next level of the tree
// must work on non-leaf node
inline static int NextLevel(const SparseBatch::Inst &inst, const RegTree &tree, int nid) {
@@ -142,7 +141,6 @@ class HistMaker: public IUpdater {
}
return n.cdefault();
}
// this function does two jobs
// (1) reset the position in array position, to be the latest leaf id
// (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
@@ -416,15 +414,38 @@ class QuantileHistMaker: public HistMaker<TStats> {
}
}
}
// setup maximum size
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
// synchronize sketch
// now we have all the results in the sketchs, try to setup the cut point
summary_array.Init(sketchs.size(), max_size);
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
sreducer.AllReduce(&summary_array, n4bytes);
// now we get the final result of sketch, setup the cut
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
for (size_t fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
for (size_t i = 0; i < a.size; ++i) {
bst_float cpt = a.data[i].value + rt_eps;
if (i == 0 || cpt > this->wspace.cut.back()){
this->wspace.cut.push_back(cpt);
}
}
this->wspace.rptr.push_back(this->wspace.cut.size());
}
// reserve last value for global statistics
this->wspace.cut.push_back(0.0f);
this->wspace.rptr.push_back(this->wspace.cut.size());
}
utils::Assert(this->wspace.rptr.size() ==
(tree.param.num_feature + 1) * this->qexpand.size(), "cut space inconsistent");
}
private:
//
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// summary array
WXQSketch::SummaryArray summary_array;
// reducer for summary
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
// local temp column data structure
std::vector<size_t> col_ptr;
// local storage of column data