From c40afa202355b66daa02a6aa561cfda7aab1c630 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 11 Feb 2015 11:37:14 -0800 Subject: [PATCH] fix sklearner --- src/tree/updater.cpp | 2 ++ src/tree/updater_skmaker-inl.hpp | 27 ++++++++++++++------------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/tree/updater.cpp b/src/tree/updater.cpp index 53b3d6aa1..5d2e99820 100644 --- a/src/tree/updater.cpp +++ b/src/tree/updater.cpp @@ -10,6 +10,7 @@ #include "./updater_sync-inl.hpp" #include "./updater_distcol-inl.hpp" #include "./updater_histmaker-inl.hpp" +#include "./updater_skmaker-inl.hpp" #endif namespace xgboost { @@ -22,6 +23,7 @@ IUpdater* CreateUpdater(const char *name) { #ifndef XGBOOST_STRICT_CXX98_ if (!strcmp(name, "sync")) return new TreeSyncher(); if (!strcmp(name, "grow_histmaker")) return new CQHistMaker(); + if (!strcmp(name, "grow_skmaker")) return new SketchMaker(); if (!strcmp(name, "distcol")) return new DistColMaker(); #endif utils::Error("unknown updater:%s", name); diff --git a/src/tree/updater_skmaker-inl.hpp b/src/tree/updater_skmaker-inl.hpp index 45202273a..3dee0607c 100644 --- a/src/tree/updater_skmaker-inl.hpp +++ b/src/tree/updater_skmaker-inl.hpp @@ -8,7 +8,7 @@ */ #include #include -#include +#include "../sync/sync.h" #include "../utils/quantile.h" #include "./updater_basemaker-inl.hpp" @@ -123,8 +123,8 @@ class SketchMaker: public BaseMaker { sum_hess += b.sum_hess; } /*! \brief same as add, reduce is used in All Reduce */ - inline void Reduce(const SKStats &b) { - this->Add(b); + inline static void Reduce(SKStats &a, const SKStats &b) { + a.Add(b); } /*! \brief set leaf vector value based on statistics */ inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const { @@ -156,18 +156,19 @@ class SketchMaker: public BaseMaker { batch[i].length == nrows, &thread_sketch[omp_get_thread_num()]); } - } + } // setup maximum size unsigned max_size = param.max_sketch_size(); // synchronize sketch - summary_array.Init(sketchs.size(), max_size); + summary_array.resize(sketchs.size()); for (size_t i = 0; i < sketchs.size(); ++i) { utils::WXQuantileSketch::SummaryContainer out; sketchs[i].GetSummary(&out); - summary_array.Set(i, out); + summary_array[i].Reserve(max_size); + summary_array[i].SetPrune(out, max_size); } - size_t nbytes = summary_array.MemSize();; - sketch_reducer.Allreduce(&summary_array, nbytes); + size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); + sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size()); } // update sketch information in column fid inline void UpdateSketchCol(const std::vector &gpair, @@ -186,7 +187,7 @@ class SketchMaker: public BaseMaker { const unsigned wid = this->node2workindex[nid]; for (int k = 0; k < 3; ++k) { sbuilder[3 * nid + k].sum_total = 0.0f; - sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k]; + sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k]; } } if (!col_full) { @@ -367,7 +368,7 @@ class SketchMaker: public BaseMaker { c.sum_hess >= param.min_child_weight) { bst_float cpt = fsplits.back(); double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain; - best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, true); + best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, false); } } } @@ -380,11 +381,11 @@ class SketchMaker: public BaseMaker { // node statistics std::vector node_stats; // summary array - WXQSketch::SummaryArray summary_array; + std::vector summary_array; // reducer for summary - rabit::Reducer stats_reducer; + rabit::Reducer stats_reducer; // reducer for summary - rabit::SerializeReducer sketch_reducer; + rabit::SerializeReducer sketch_reducer; // per node, per feature sketch std::vector< utils::WXQuantileSketch > sketchs; };