This commit is contained in:
tqchen
2014-11-19 15:28:09 -08:00
parent 55e62a7120
commit 7c3a392136
8 changed files with 136 additions and 133 deletions

View File

@@ -7,7 +7,7 @@
#include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp"
#include "./updater_distcol-inl.hpp"
#include "./updater_skmaker-inl.hpp"
//#include "./updater_skmaker-inl.hpp"
#include "./updater_histmaker-inl.hpp"
namespace xgboost {
@@ -18,8 +18,8 @@ IUpdater* CreateUpdater(const char *name) {
if (!strcmp(name, "sync")) return new TreeSyncher();
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
//if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
//if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
utils::Error("unknown updater:%s", name);

View File

@@ -306,6 +306,7 @@ class CQHistMaker: public HistMaker<TStats> {
hist.data[istart].Add(gpair, info, ridx);
}
};
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
virtual void CreateHist(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info,
@@ -371,21 +372,22 @@ class CQHistMaker: public HistMaker<TStats> {
// setup maximum size
unsigned max_size = this->param.max_sketch_size();
// synchronize sketch
summary_array.Init(sketchs.size(), max_size);
summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array.Set(i, out);
summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size);
}
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
sreducer.AllReduce(&summary_array, n4bytes);
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
// now we get the final result of sketch, setup the cut
this->wspace.cut.clear();
this->wspace.rptr.clear();
this->wspace.rptr.push_back(0);
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = a.data[i].value - rt_eps;
if (i == 1 || cpt > this->wspace.cut.back()) {
@@ -407,7 +409,7 @@ class CQHistMaker: public HistMaker<TStats> {
}
utils::Assert(this->wspace.rptr.size() ==
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
"cut space inconsistent");
"cut space inconsistent");
}
private:
@@ -496,7 +498,6 @@ class CQHistMaker: public HistMaker<TStats> {
}
}
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// thread temp data
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch;
// used to hold statistics
@@ -506,9 +507,9 @@ class CQHistMaker: public HistMaker<TStats> {
// node statistics
std::vector<TStats> node_stats;
// summary array
WXQSketch::SummaryArray summary_array;
std::vector< WXQSketch::SummaryContainer> summary_array;
// reducer for summary
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
// per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
};
@@ -580,23 +581,24 @@ class QuantileHistMaker: public HistMaker<TStats> {
}
}
// setup maximum size
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
unsigned max_size = this->param.max_sketch_size();
// synchronize sketch
summary_array.Init(sketchs.size(), max_size);
summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array.Set(i, out);
summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size);
}
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
sreducer.AllReduce(&summary_array, n4bytes);
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
// now we get the final result of sketch, setup the cut
this->wspace.cut.clear();
this->wspace.rptr.clear();
this->wspace.rptr.push_back(0);
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = a.data[i].value - rt_eps;
if (i == 1 || cpt > this->wspace.cut.back()) {
@@ -624,9 +626,9 @@ class QuantileHistMaker: public HistMaker<TStats> {
private:
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// summary array
WXQSketch::SummaryArray summary_array;
std::vector< WXQSketch::SummaryContainer> summary_array;
// reducer for summary
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
// local temp column data structure
std::vector<size_t> col_ptr;
// local storage of column data