fix sklearner
This commit is contained in:
parent
c639efc71b
commit
c40afa2023
@ -10,6 +10,7 @@
|
||||
#include "./updater_sync-inl.hpp"
|
||||
#include "./updater_distcol-inl.hpp"
|
||||
#include "./updater_histmaker-inl.hpp"
|
||||
#include "./updater_skmaker-inl.hpp"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
@ -22,6 +23,7 @@ IUpdater* CreateUpdater(const char *name) {
|
||||
#ifndef XGBOOST_STRICT_CXX98_
|
||||
if (!strcmp(name, "sync")) return new TreeSyncher();
|
||||
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
|
||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||
#endif
|
||||
utils::Error("unknown updater:%s", name);
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <rabit.h>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/quantile.h"
|
||||
#include "./updater_basemaker-inl.hpp"
|
||||
|
||||
@ -123,8 +123,8 @@ class SketchMaker: public BaseMaker {
|
||||
sum_hess += b.sum_hess;
|
||||
}
|
||||
/*! \brief same as add, reduce is used in All Reduce */
|
||||
inline void Reduce(const SKStats &b) {
|
||||
this->Add(b);
|
||||
inline static void Reduce(SKStats &a, const SKStats &b) {
|
||||
a.Add(b);
|
||||
}
|
||||
/*! \brief set leaf vector value based on statistics */
|
||||
inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const {
|
||||
@ -156,18 +156,19 @@ class SketchMaker: public BaseMaker {
|
||||
batch[i].length == nrows,
|
||||
&thread_sketch[omp_get_thread_num()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// setup maximum size
|
||||
unsigned max_size = param.max_sketch_size();
|
||||
// synchronize sketch
|
||||
summary_array.Init(sketchs.size(), max_size);
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array.Set(i, out);
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
size_t nbytes = summary_array.MemSize();;
|
||||
sketch_reducer.Allreduce(&summary_array, nbytes);
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
}
|
||||
// update sketch information in column fid
|
||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||
@ -186,7 +187,7 @@ class SketchMaker: public BaseMaker {
|
||||
const unsigned wid = this->node2workindex[nid];
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
sbuilder[3 * nid + k].sum_total = 0.0f;
|
||||
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
|
||||
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
|
||||
}
|
||||
}
|
||||
if (!col_full) {
|
||||
@ -367,7 +368,7 @@ class SketchMaker: public BaseMaker {
|
||||
c.sum_hess >= param.min_child_weight) {
|
||||
bst_float cpt = fsplits.back();
|
||||
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
|
||||
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, true);
|
||||
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -380,11 +381,11 @@ class SketchMaker: public BaseMaker {
|
||||
// node statistics
|
||||
std::vector<SKStats> node_stats;
|
||||
// summary array
|
||||
WXQSketch::SummaryArray summary_array;
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||
// reducer for summary
|
||||
rabit::Reducer<SKStats> stats_reducer;
|
||||
rabit::Reducer<SKStats, SKStats::Reduce> stats_reducer;
|
||||
// reducer for summary
|
||||
rabit::SerializeReducer<WXQSketch::SummaryArray> sketch_reducer;
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user