compile
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include "./updater_refresh-inl.hpp"
|
||||
#include "./updater_colmaker-inl.hpp"
|
||||
#include "./updater_distcol-inl.hpp"
|
||||
#include "./updater_skmaker-inl.hpp"
|
||||
//#include "./updater_skmaker-inl.hpp"
|
||||
#include "./updater_histmaker-inl.hpp"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -18,8 +18,8 @@ IUpdater* CreateUpdater(const char *name) {
|
||||
if (!strcmp(name, "sync")) return new TreeSyncher();
|
||||
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
|
||||
//if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
||||
//if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
|
||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||
|
||||
utils::Error("unknown updater:%s", name);
|
||||
|
||||
@@ -306,6 +306,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
hist.data[istart].Add(gpair, info, ridx);
|
||||
}
|
||||
};
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
virtual void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
@@ -371,21 +372,22 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// setup maximum size
|
||||
unsigned max_size = this->param.max_sketch_size();
|
||||
// synchronize sketch
|
||||
summary_array.Init(sketchs.size(), max_size);
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array.Set(i, out);
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
||||
sreducer.AllReduce(&summary_array, n4bytes);
|
||||
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
||||
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
this->wspace.rptr.push_back(0);
|
||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
||||
const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
|
||||
for (size_t i = 1; i < a.size; ++i) {
|
||||
bst_float cpt = a.data[i].value - rt_eps;
|
||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||
@@ -407,7 +409,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
utils::Assert(this->wspace.rptr.size() ==
|
||||
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
||||
"cut space inconsistent");
|
||||
"cut space inconsistent");
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -496,7 +498,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
}
|
||||
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
// thread temp data
|
||||
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||
// used to hold statistics
|
||||
@@ -506,9 +507,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// node statistics
|
||||
std::vector<TStats> node_stats;
|
||||
// summary array
|
||||
WXQSketch::SummaryArray summary_array;
|
||||
std::vector< WXQSketch::SummaryContainer> summary_array;
|
||||
// reducer for summary
|
||||
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
||||
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
@@ -580,23 +581,24 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
}
|
||||
// setup maximum size
|
||||
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
||||
unsigned max_size = this->param.max_sketch_size();
|
||||
// synchronize sketch
|
||||
summary_array.Init(sketchs.size(), max_size);
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array.Set(i, out);
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
||||
sreducer.AllReduce(&summary_array, n4bytes);
|
||||
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
||||
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
this->wspace.rptr.push_back(0);
|
||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
||||
const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
|
||||
for (size_t i = 1; i < a.size; ++i) {
|
||||
bst_float cpt = a.data[i].value - rt_eps;
|
||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||
@@ -624,9 +626,9 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
private:
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
// summary array
|
||||
WXQSketch::SummaryArray summary_array;
|
||||
std::vector< WXQSketch::SummaryContainer> summary_array;
|
||||
// reducer for summary
|
||||
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
||||
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
// local temp column data structure
|
||||
std::vector<size_t> col_ptr;
|
||||
// local storage of column data
|
||||
|
||||
Reference in New Issue
Block a user