fix sklearner
This commit is contained in:
parent
c639efc71b
commit
c40afa2023
@ -10,6 +10,7 @@
|
|||||||
#include "./updater_sync-inl.hpp"
|
#include "./updater_sync-inl.hpp"
|
||||||
#include "./updater_distcol-inl.hpp"
|
#include "./updater_distcol-inl.hpp"
|
||||||
#include "./updater_histmaker-inl.hpp"
|
#include "./updater_histmaker-inl.hpp"
|
||||||
|
#include "./updater_skmaker-inl.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -22,6 +23,7 @@ IUpdater* CreateUpdater(const char *name) {
|
|||||||
#ifndef XGBOOST_STRICT_CXX98_
|
#ifndef XGBOOST_STRICT_CXX98_
|
||||||
if (!strcmp(name, "sync")) return new TreeSyncher();
|
if (!strcmp(name, "sync")) return new TreeSyncher();
|
||||||
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
||||||
|
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
|
||||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||||
#endif
|
#endif
|
||||||
utils::Error("unknown updater:%s", name);
|
utils::Error("unknown updater:%s", name);
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
*/
|
*/
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <rabit.h>
|
#include "../sync/sync.h"
|
||||||
#include "../utils/quantile.h"
|
#include "../utils/quantile.h"
|
||||||
#include "./updater_basemaker-inl.hpp"
|
#include "./updater_basemaker-inl.hpp"
|
||||||
|
|
||||||
@ -123,8 +123,8 @@ class SketchMaker: public BaseMaker {
|
|||||||
sum_hess += b.sum_hess;
|
sum_hess += b.sum_hess;
|
||||||
}
|
}
|
||||||
/*! \brief same as add, reduce is used in All Reduce */
|
/*! \brief same as add, reduce is used in All Reduce */
|
||||||
inline void Reduce(const SKStats &b) {
|
inline static void Reduce(SKStats &a, const SKStats &b) {
|
||||||
this->Add(b);
|
a.Add(b);
|
||||||
}
|
}
|
||||||
/*! \brief set leaf vector value based on statistics */
|
/*! \brief set leaf vector value based on statistics */
|
||||||
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
|
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
|
||||||
@ -156,18 +156,19 @@ class SketchMaker: public BaseMaker {
|
|||||||
batch[i].length == nrows,
|
batch[i].length == nrows,
|
||||||
&thread_sketch[omp_get_thread_num()]);
|
&thread_sketch[omp_get_thread_num()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// setup maximum size
|
// setup maximum size
|
||||||
unsigned max_size = param.max_sketch_size();
|
unsigned max_size = param.max_sketch_size();
|
||||||
// synchronize sketch
|
// synchronize sketch
|
||||||
summary_array.Init(sketchs.size(), max_size);
|
summary_array.resize(sketchs.size());
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||||
sketchs[i].GetSummary(&out);
|
sketchs[i].GetSummary(&out);
|
||||||
summary_array.Set(i, out);
|
summary_array[i].Reserve(max_size);
|
||||||
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
size_t nbytes = summary_array.MemSize();;
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
sketch_reducer.Allreduce(&summary_array, nbytes);
|
sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
}
|
}
|
||||||
// update sketch information in column fid
|
// update sketch information in column fid
|
||||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||||
@ -186,7 +187,7 @@ class SketchMaker: public BaseMaker {
|
|||||||
const unsigned wid = this->node2workindex[nid];
|
const unsigned wid = this->node2workindex[nid];
|
||||||
for (int k = 0; k < 3; ++k) {
|
for (int k = 0; k < 3; ++k) {
|
||||||
sbuilder[3 * nid + k].sum_total = 0.0f;
|
sbuilder[3 * nid + k].sum_total = 0.0f;
|
||||||
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
|
sbuilder[3 * nid + k].sketch = &sketchs[(wid * tree.param.num_feature + fid) * 3 + k];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!col_full) {
|
if (!col_full) {
|
||||||
@ -367,7 +368,7 @@ class SketchMaker: public BaseMaker {
|
|||||||
c.sum_hess >= param.min_child_weight) {
|
c.sum_hess >= param.min_child_weight) {
|
||||||
bst_float cpt = fsplits.back();
|
bst_float cpt = fsplits.back();
|
||||||
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
|
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
|
||||||
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, true);
|
best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -380,11 +381,11 @@ class SketchMaker: public BaseMaker {
|
|||||||
// node statistics
|
// node statistics
|
||||||
std::vector<SKStats> node_stats;
|
std::vector<SKStats> node_stats;
|
||||||
// summary array
|
// summary array
|
||||||
WXQSketch::SummaryArray summary_array;
|
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||||
// reducer for node statistics
|
// reducer for node statistics
|
||||||
rabit::Reducer<SKStats> stats_reducer;
|
rabit::Reducer<SKStats, SKStats::Reduce> stats_reducer;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
rabit::SerializeReducer<WXQSketch::SummaryArray> sketch_reducer;
|
rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer;
|
||||||
// per node, per feature sketch
|
// per node, per feature sketch
|
||||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user