fix sklearner

This commit is contained in:
tqchen 2015-02-11 11:37:14 -08:00
parent c639efc71b
commit c40afa2023
2 changed files with 16 additions and 13 deletions

View File

@ -10,6 +10,7 @@
#include "./updater_sync-inl.hpp" #include "./updater_sync-inl.hpp"
#include "./updater_distcol-inl.hpp" #include "./updater_distcol-inl.hpp"
#include "./updater_histmaker-inl.hpp" #include "./updater_histmaker-inl.hpp"
#include "./updater_skmaker-inl.hpp"
#endif #endif
namespace xgboost { namespace xgboost {
@ -22,6 +23,7 @@ IUpdater* CreateUpdater(const char *name) {
#ifndef XGBOOST_STRICT_CXX98_ #ifndef XGBOOST_STRICT_CXX98_
if (!strcmp(name, "sync")) return new TreeSyncher(); if (!strcmp(name, "sync")) return new TreeSyncher();
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>(); if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>(); if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
#endif #endif
utils::Error("unknown updater:%s", name); utils::Error("unknown updater:%s", name);

View File

@ -8,7 +8,7 @@
*/ */
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <rabit.h> #include "../sync/sync.h"
#include "../utils/quantile.h" #include "../utils/quantile.h"
#include "./updater_basemaker-inl.hpp" #include "./updater_basemaker-inl.hpp"
@ -123,8 +123,8 @@ class SketchMaker: public BaseMaker {
sum_hess += b.sum_hess; sum_hess += b.sum_hess;
} }
/*! \brief same as add, reduce is used in All Reduce */ /*! \brief same as add, reduce is used in All Reduce */
inline void Reduce(const SKStats &b) { inline static void Reduce(SKStats &a, const SKStats &b) {
this->Add(b); a.Add(b);
} }
/*! \brief set leaf vector value based on statistics */ /*! \brief set leaf vector value based on statistics */
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const { inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
@ -156,18 +156,19 @@ class SketchMaker: public BaseMaker {
batch[i].length == nrows, batch[i].length == nrows,
&thread_sketch[omp_get_thread_num()]); &thread_sketch[omp_get_thread_num()]);
} }
} }
// setup maximum size // setup maximum size
unsigned max_size = param.max_sketch_size(); unsigned max_size = param.max_sketch_size();
// synchronize sketch // synchronize sketch
summary_array.Init(sketchs.size(), max_size); summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out; utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out); sketchs[i].GetSummary(&out);
summary_array.Set(i, out); summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size);
} }
size_t nbytes = summary_array.MemSize();; size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sketch_reducer.Allreduce(&summary_array, nbytes); sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
} }
// update sketch information in column fid // update sketch information in column fid
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair, inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
@ -186,7 +187,7 @@ class SketchMaker: public BaseMaker {
const unsigned wid = this->node2workindex[nid]; const unsigned wid = this->node2workindex[nid];
for (int k = 0; k < 3; ++k) { for (int k = 0; k < 3; ++k) {
sbuilder[3 * nid + k].sum_total = 0.0f; sbuilder[3 * nid + k].sum_total = 0.0f;
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k]; sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
} }
} }
if (!col_full) { if (!col_full) {
@ -367,7 +368,7 @@ class SketchMaker: public BaseMaker {
c.sum_hess >= param.min_child_weight) { c.sum_hess >= param.min_child_weight) {
bst_float cpt = fsplits.back(); bst_float cpt = fsplits.back();
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain; double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, true); best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, false);
} }
} }
} }
@ -380,11 +381,11 @@ class SketchMaker: public BaseMaker {
// node statistics // node statistics
std::vector<SKStats> node_stats; std::vector<SKStats> node_stats;
// summary array // summary array
WXQSketch::SummaryArray summary_array; std::vector<WXQSketch::SummaryContainer> summary_array;
// reducer for summary // reducer for summary
rabit::Reducer<SKStats> stats_reducer; rabit::Reducer<SKStats, SKStats::Reduce> stats_reducer;
// reducer for summary // reducer for summary
rabit::SerializeReducer<WXQSketch::SummaryArray> sketch_reducer; rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer;
// per node, per feature sketch // per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs; std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
}; };