fix sklearner

This commit is contained in:
tqchen 2015-02-11 11:37:14 -08:00
parent c639efc71b
commit c40afa2023
2 changed files with 16 additions and 13 deletions

View File

@ -10,6 +10,7 @@
#include "./updater_sync-inl.hpp"
#include "./updater_distcol-inl.hpp"
#include "./updater_histmaker-inl.hpp"
#include "./updater_skmaker-inl.hpp"
#endif
namespace xgboost {
@ -22,6 +23,7 @@ IUpdater* CreateUpdater(const char *name) {
#ifndef XGBOOST_STRICT_CXX98_
if (!strcmp(name, "sync")) return new TreeSyncher();
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
#endif
utils::Error("unknown updater:%s", name);

View File

@ -8,7 +8,7 @@
*/
#include <vector>
#include <algorithm>
#include <rabit.h>
#include "../sync/sync.h"
#include "../utils/quantile.h"
#include "./updater_basemaker-inl.hpp"
@ -123,8 +123,8 @@ class SketchMaker: public BaseMaker {
sum_hess += b.sum_hess;
}
/*! \brief same as Add; Reduce is the reduction function used in Allreduce */
inline void Reduce(const SKStats &b) {
this->Add(b);
inline static void Reduce(SKStats &a, const SKStats &b) {
a.Add(b);
}
/*! \brief set leaf vector value based on statistics */
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
@ -156,18 +156,19 @@ class SketchMaker: public BaseMaker {
batch[i].length == nrows,
&thread_sketch[omp_get_thread_num()]);
}
}
}
// setup maximum size
unsigned max_size = param.max_sketch_size();
// synchronize sketch
summary_array.Init(sketchs.size(), max_size);
summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array.Set(i, out);
summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size);
}
size_t nbytes = summary_array.MemSize();;
sketch_reducer.Allreduce(&summary_array, nbytes);
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
}
// update sketch information in column fid
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
@ -186,7 +187,7 @@ class SketchMaker: public BaseMaker {
const unsigned wid = this->node2workindex[nid];
for (int k = 0; k < 3; ++k) {
sbuilder[3 * nid + k].sum_total = 0.0f;
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
}
}
if (!col_full) {
@ -367,7 +368,7 @@ class SketchMaker: public BaseMaker {
c.sum_hess >= param.min_child_weight) {
bst_float cpt = fsplits.back();
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, true);
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, false);
}
}
}
@ -380,11 +381,11 @@ class SketchMaker: public BaseMaker {
// node statistics
std::vector<SKStats> node_stats;
// summary array
WXQSketch::SummaryArray summary_array;
std::vector<WXQSketch::SummaryContainer> summary_array;
// reducer for node statistics (SKStats)
rabit::Reducer<SKStats> stats_reducer;
rabit::Reducer<SKStats, SKStats::Reduce> stats_reducer;
// reducer for summary
rabit::SerializeReducer<WXQSketch::SummaryArray> sketch_reducer;
rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer;
// per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
};