[TREE] finish move of updater
This commit is contained in:
@@ -55,7 +55,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
// number of threads to be used for tree construction,
|
||||
// if OpenMP is enabled, if equals 0, use system default
|
||||
int nthread;
|
||||
|
||||
// whether to not print info during training.
|
||||
bool silent;
|
||||
// declare the parameters
|
||||
DMLC_DECLARE_PARAMETER(TrainParam) {
|
||||
DMLC_DECLARE_FIELD(eta).set_lower_bound(0.0f).set_default(0.3f)
|
||||
@@ -98,6 +99,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
.describe("EXP Param: Cache aware optimization.");
|
||||
DMLC_DECLARE_FIELD(nthread).set_default(0)
|
||||
.describe("Number of threads used for training.");
|
||||
DMLC_DECLARE_FIELD(silent).set_default(false)
|
||||
.describe("Not print information during trainig.");
|
||||
}
|
||||
|
||||
// calculate the cost of loss function
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file updater_basemaker-inl.hpp
|
||||
* \file updater_basemaker-inl.h
|
||||
* \brief implement a common tree constructor
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_UPDATER_BASEMAKER_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_BASEMAKER_INL_HPP_
|
||||
#ifndef XGBOOST_TREE_UPDATER_BASEMAKER_INL_H_
|
||||
#define XGBOOST_TREE_UPDATER_BASEMAKER_INL_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/random.h"
|
||||
#include "../utils/quantile.h"
|
||||
#include <utility>
|
||||
#include "./param.h"
|
||||
#include "../common/sync.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/quantile.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -20,13 +26,10 @@ namespace tree {
|
||||
* \brief base tree maker class that defines common operation
|
||||
* needed in tree making
|
||||
*/
|
||||
class BaseMaker: public IUpdater {
|
||||
class BaseMaker: public TreeUpdater {
|
||||
public:
|
||||
// destructor
|
||||
virtual ~BaseMaker(void) {}
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
param.SetParam(name, val);
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param.Init(args);
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -34,31 +37,31 @@ class BaseMaker: public IUpdater {
|
||||
struct FMetaHelper {
|
||||
public:
|
||||
/*! \brief find type of each feature, use column format */
|
||||
inline void InitByCol(IFMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
inline void InitByCol(DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
fminmax.resize(tree.param.num_feature * 2);
|
||||
std::fill(fminmax.begin(), fminmax.end(),
|
||||
-std::numeric_limits<bst_float>::max());
|
||||
// start accumulating statistics
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
||||
dmlc::DataIter<ColBatch>* iter = p_fmat->ColIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
const ColBatch& batch = iter->Value();
|
||||
for (bst_uint i = 0; i < batch.size; ++i) {
|
||||
const bst_uint fid = batch.col_index[i];
|
||||
const ColBatch::Inst &c = batch[i];
|
||||
const ColBatch::Inst& c = batch[i];
|
||||
if (c.length != 0) {
|
||||
fminmax[fid * 2 + 0] = std::max(-c[0].fvalue, fminmax[fid * 2 + 0]);
|
||||
fminmax[fid * 2 + 1] = std::max(c[c.length - 1].fvalue, fminmax[fid * 2 + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
rabit::Allreduce<rabit::op::Max>(BeginPtr(fminmax), fminmax.size());
|
||||
rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
|
||||
}
|
||||
// get feature type, 0:empty 1:binary 2:real
|
||||
inline int Type(bst_uint fid) const {
|
||||
utils::Assert(fid * 2 + 1 < fminmax.size(),
|
||||
"FeatHelper fid exceed query bound ");
|
||||
CHECK_LT(fid * 2 + 1, fminmax.size())
|
||||
<< "FeatHelper fid exceed query bound ";
|
||||
bst_float a = fminmax[fid * 2];
|
||||
bst_float b = fminmax[fid * 2 + 1];
|
||||
if (a == -std::numeric_limits<bst_float>::max()) return 0;
|
||||
@@ -79,12 +82,12 @@ class BaseMaker: public IUpdater {
|
||||
if (this->Type(fid) != 0) findex.push_back(fid);
|
||||
}
|
||||
unsigned n = static_cast<unsigned>(p * findex.size());
|
||||
random::Shuffle(findex);
|
||||
std::shuffle(findex.begin(), findex.end(), common::GlobalRandom());
|
||||
findex.resize(n);
|
||||
// sync the findex if it is subsample
|
||||
std::string s_cache;
|
||||
utils::MemoryBufferStream fc(&s_cache);
|
||||
utils::IStream &fs = fc;
|
||||
common::MemoryBufferStream fc(&s_cache);
|
||||
dmlc::Stream& fs = fc;
|
||||
if (rabit::GetRank() == 0) {
|
||||
fs.Write(findex);
|
||||
}
|
||||
@@ -113,7 +116,7 @@ class BaseMaker: public IUpdater {
|
||||
return n.cdefault();
|
||||
}
|
||||
/*! \brief get number of omp thread in current context */
|
||||
inline static int get_nthread(void) {
|
||||
inline static int get_nthread() {
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
@@ -124,11 +127,11 @@ class BaseMaker: public IUpdater {
|
||||
// ------class member helpers---------
|
||||
/*! \brief initialize temp data structure */
|
||||
inline void InitData(const std::vector<bst_gpair> &gpair,
|
||||
const IFMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
const DMatrix &fmat,
|
||||
const RegTree &tree) {
|
||||
utils::Assert(tree.param.num_nodes == tree.param.num_roots,
|
||||
"TreeMaker: can only grow new tree");
|
||||
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
||||
<< "TreeMaker: can only grow new tree";
|
||||
const std::vector<unsigned> &root_index = fmat.info().root_index;
|
||||
{
|
||||
// setup position
|
||||
position.resize(gpair.size());
|
||||
@@ -137,8 +140,8 @@ class BaseMaker: public IUpdater {
|
||||
} else {
|
||||
for (size_t i = 0; i < position.size(); ++i) {
|
||||
position[i] = root_index[i];
|
||||
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots,
|
||||
"root index exceed setting");
|
||||
CHECK_LT(root_index[i], (unsigned)tree.param.num_roots)
|
||||
<< "root index exceed setting";
|
||||
}
|
||||
}
|
||||
// mark delete for the deleted datas
|
||||
@@ -147,9 +150,11 @@ class BaseMaker: public IUpdater {
|
||||
}
|
||||
// mark subsample
|
||||
if (param.subsample < 1.0f) {
|
||||
std::bernoulli_distribution coin_flip(param.subsample);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
for (size_t i = 0; i < position.size(); ++i) {
|
||||
if (gpair[i].hess < 0.0f) continue;
|
||||
if (random::SampleBinary(param.subsample) == 0) position[i] = ~position[i];
|
||||
if (!coin_flip(rnd)) position[i] = ~position[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,7 +202,8 @@ class BaseMaker: public IUpdater {
|
||||
* \param tree the regression tree structure
|
||||
*/
|
||||
inline void ResetPositionCol(const std::vector<int> &nodes,
|
||||
IFMatrix *p_fmat, const RegTree &tree) {
|
||||
DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
// set the positions in the nondefault
|
||||
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
||||
// set rest of instances to default position
|
||||
@@ -234,7 +240,8 @@ class BaseMaker: public IUpdater {
|
||||
* \param tree the regression tree structure
|
||||
*/
|
||||
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
||||
IFMatrix *p_fmat, const RegTree &tree) {
|
||||
DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
// step 1, classify the non-default data into right places
|
||||
std::vector<unsigned> fsplits;
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
@@ -246,7 +253,7 @@ class BaseMaker: public IUpdater {
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
@@ -273,12 +280,12 @@ class BaseMaker: public IUpdater {
|
||||
/*! \brief helper function to get statistics from a tree */
|
||||
template<typename TStats>
|
||||
inline void GetNodeStats(const std::vector<bst_gpair> &gpair,
|
||||
const IFMatrix &fmat,
|
||||
const DMatrix &fmat,
|
||||
const RegTree &tree,
|
||||
const BoosterInfo &info,
|
||||
std::vector< std::vector<TStats> > *p_thread_temp,
|
||||
std::vector<TStats> *p_node_stats) {
|
||||
std::vector< std::vector<TStats> > &thread_temp = *p_thread_temp;
|
||||
const MetaInfo &info = fmat.info();
|
||||
thread_temp.resize(this->get_nthread());
|
||||
p_node_stats->resize(tree.param.num_nodes);
|
||||
#pragma omp parallel
|
||||
@@ -323,7 +330,7 @@ class BaseMaker: public IUpdater {
|
||||
/*! \brief current size of sketch */
|
||||
double next_goal;
|
||||
// pointer to the sketch to put things in
|
||||
utils::WXQuantileSketch<bst_float, bst_float> *sketch;
|
||||
common::WXQuantileSketch<bst_float, bst_float> *sketch;
|
||||
// initialize the space
|
||||
inline void Init(unsigned max_size) {
|
||||
next_goal = -1.0f;
|
||||
@@ -351,13 +358,13 @@ class BaseMaker: public IUpdater {
|
||||
last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
|
||||
// push to sketch
|
||||
sketch->temp.data[sketch->temp.size] =
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::
|
||||
common::WXQuantileSketch<bst_float, bst_float>::
|
||||
Entry(static_cast<bst_float>(rmin),
|
||||
static_cast<bst_float>(rmax),
|
||||
static_cast<bst_float>(wmin), last_fvalue);
|
||||
utils::Assert(sketch->temp.size < max_size,
|
||||
"invalid maximum size max_size=%u, stemp.size=%lu\n",
|
||||
max_size, sketch->temp.size);
|
||||
CHECK_LT(sketch->temp.size, max_size)
|
||||
<< "invalid maximum size max_size=" << max_size
|
||||
<< ", stemp.size" << sketch->temp.size;
|
||||
++sketch->temp.size;
|
||||
}
|
||||
if (sketch->temp.size == max_size) {
|
||||
@@ -382,12 +389,12 @@ class BaseMaker: public IUpdater {
|
||||
inline void Finalize(unsigned max_size) {
|
||||
double rmax = rmin + wmin;
|
||||
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
|
||||
utils::Assert(sketch->temp.size <= max_size,
|
||||
"Finalize: invalid maximum size, max_size=%u, stemp.size=%lu",
|
||||
sketch->temp.size, max_size);
|
||||
CHECK_LE(sketch->temp.size, max_size)
|
||||
<< "Finalize: invalid maximum size, max_size=" << max_size
|
||||
<< ", stemp.size=" << sketch->temp.size;
|
||||
// push to sketch
|
||||
sketch->temp.data[sketch->temp.size] =
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::
|
||||
common::WXQuantileSketch<bst_float, bst_float>::
|
||||
Entry(static_cast<bst_float>(rmin),
|
||||
static_cast<bst_float>(rmax),
|
||||
static_cast<bst_float>(wmin), last_fvalue);
|
||||
@@ -424,4 +431,4 @@ class BaseMaker: public IUpdater {
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_BASEMAKER_INL_HPP_
|
||||
#endif // XGBOOST_TREE_UPDATER_BASEMAKER_INL_H_
|
||||
|
||||
@@ -747,7 +747,7 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
// update position after the tree is pruned
|
||||
builder.UpdatePosition(dmat, *trees[0]);
|
||||
}
|
||||
virtual const int* GetLeafPosition() const {
|
||||
const int* GetLeafPosition() const override {
|
||||
return builder.GetLeafPosition();
|
||||
}
|
||||
|
||||
@@ -771,7 +771,7 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
this->position[ridx] = nid;
|
||||
}
|
||||
}
|
||||
const int* GetLeafPosition() const override {
|
||||
inline const int* GetLeafPosition() const {
|
||||
return dmlc::BeginPtr(this->position);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,38 +1,35 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file updater_histmaker-inl.hpp
|
||||
* \file updater_histmaker.cc
|
||||
* \brief use histogram counting to construct a tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/quantile.h"
|
||||
#include "../utils/group_data.h"
|
||||
#include "./updater_basemaker-inl.hpp"
|
||||
#include "../common/sync.h"
|
||||
#include "../common/quantile.h"
|
||||
#include "../common/group_data.h"
|
||||
#include "./updater_basemaker-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
template<typename TStats>
|
||||
class HistMaker: public BaseMaker {
|
||||
public:
|
||||
virtual ~HistMaker(void) {}
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
TStats::CheckInfo(info);
|
||||
void Update(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
TStats::CheckInfo(p_fmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
float lr = param.eta;
|
||||
param.eta = lr / trees.size();
|
||||
// build tree
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
this->Update(gpair, p_fmat, info, trees[i]);
|
||||
this->Update(gpair, p_fmat, trees[i]);
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
param.eta = lr;
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -45,19 +42,18 @@ class HistMaker: public BaseMaker {
|
||||
/*! \brief size of histogram */
|
||||
unsigned size;
|
||||
// default constructor
|
||||
HistUnit(void) {}
|
||||
HistUnit() {}
|
||||
// constructor
|
||||
HistUnit(const bst_float *cut, TStats *data, unsigned size)
|
||||
: cut(cut), data(data), size(size) {}
|
||||
/*! \brief add a histogram to data */
|
||||
inline void Add(bst_float fv,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
const BoosterInfo &info,
|
||||
const MetaInfo &info,
|
||||
const bst_uint ridx) {
|
||||
unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
|
||||
utils::Assert(size != 0, "try insert into size=0");
|
||||
utils::Assert(i < size,
|
||||
"maximum value must be in cut, fv = %g, cutmax=%g", fv, cut[size-1]);
|
||||
CHECK_NE(size, 0) << "try insert into size=0";
|
||||
CHECK_LT(i, size);
|
||||
data[i].Add(gpair, info, ridx);
|
||||
}
|
||||
};
|
||||
@@ -92,13 +88,13 @@ class HistMaker: public BaseMaker {
|
||||
for (size_t i = 0; i < hset[tid].data.size(); ++i) {
|
||||
hset[tid].data[i].Clear();
|
||||
}
|
||||
hset[tid].rptr = BeginPtr(rptr);
|
||||
hset[tid].cut = BeginPtr(cut);
|
||||
hset[tid].rptr = dmlc::BeginPtr(rptr);
|
||||
hset[tid].cut = dmlc::BeginPtr(cut);
|
||||
hset[tid].data.resize(cut.size(), TStats(param));
|
||||
}
|
||||
}
|
||||
// aggregate all statistics to hset[0]
|
||||
inline void Aggregate(void) {
|
||||
inline void Aggregate() {
|
||||
bst_omp_uint nsize = static_cast<bst_omp_uint>(cut.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
@@ -108,11 +104,11 @@ class HistMaker: public BaseMaker {
|
||||
}
|
||||
}
|
||||
/*! \brief clear the workspace */
|
||||
inline void Clear(void) {
|
||||
inline void Clear() {
|
||||
cut.clear(); rptr.resize(1); rptr[0] = 0;
|
||||
}
|
||||
/*! \brief total size */
|
||||
inline size_t Size(void) const {
|
||||
inline size_t Size() const {
|
||||
return rptr.size() - 1;
|
||||
}
|
||||
};
|
||||
@@ -124,18 +120,17 @@ class HistMaker: public BaseMaker {
|
||||
std::vector<bst_uint> fwork_set;
|
||||
// update function implementation
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
DMatrix *p_fmat,
|
||||
RegTree *p_tree) {
|
||||
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
|
||||
this->InitData(gpair, *p_fmat, *p_tree);
|
||||
this->InitWorkSet(p_fmat, *p_tree, &fwork_set);
|
||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||
// reset and propose candidate split
|
||||
this->ResetPosAndPropose(gpair, p_fmat, info, fwork_set, *p_tree);
|
||||
this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree);
|
||||
// create histogram
|
||||
this->CreateHist(gpair, p_fmat, info, fwork_set, *p_tree);
|
||||
this->CreateHist(gpair, p_fmat, fwork_set, *p_tree);
|
||||
// find split based on histogram statistics
|
||||
this->FindSplit(depth, gpair, p_fmat, info, fwork_set, p_tree);
|
||||
this->FindSplit(depth, gpair, p_fmat, fwork_set, p_tree);
|
||||
// reset position after split
|
||||
this->ResetPositionAfterSplit(p_fmat, *p_tree);
|
||||
this->UpdateQueueExpand(*p_tree);
|
||||
@@ -144,19 +139,18 @@ class HistMaker: public BaseMaker {
|
||||
}
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
|
||||
}
|
||||
}
|
||||
// this function does two jobs
|
||||
// (1) reset the position in array position, to be the latest leaf id
|
||||
// (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
|
||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector <bst_uint> &fset,
|
||||
const RegTree &tree) = 0;
|
||||
// initialize the current working set of features in this round
|
||||
virtual void InitWorkSet(IFMatrix *p_fmat,
|
||||
virtual void InitWorkSet(DMatrix *p_fmat,
|
||||
const RegTree &tree,
|
||||
std::vector<bst_uint> *p_fset) {
|
||||
p_fset->resize(tree.param.num_feature);
|
||||
@@ -165,12 +159,11 @@ class HistMaker: public BaseMaker {
|
||||
}
|
||||
}
|
||||
// reset position after split, this is not a must, depending on implementation
|
||||
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
||||
virtual void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
}
|
||||
virtual void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector <bst_uint> &fset,
|
||||
const RegTree &tree) = 0;
|
||||
|
||||
@@ -212,8 +205,7 @@ class HistMaker: public BaseMaker {
|
||||
}
|
||||
inline void FindSplit(int depth,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector <bst_uint> &fset,
|
||||
RegTree *p_tree) {
|
||||
const size_t num_feature = fset.size();
|
||||
@@ -224,8 +216,7 @@ class HistMaker: public BaseMaker {
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
|
||||
const int nid = qexpand[wid];
|
||||
utils::Assert(node2workindex[nid] == static_cast<int>(wid),
|
||||
"node2workindex inconsistent");
|
||||
CHECK_EQ(node2workindex[nid], static_cast<int>(wid));
|
||||
SplitEntry &best = sol[wid];
|
||||
TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
for (size_t i = 0; i < fset.size(); ++i) {
|
||||
@@ -255,7 +246,7 @@ class HistMaker: public BaseMaker {
|
||||
this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]);
|
||||
this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum);
|
||||
} else {
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -279,10 +270,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
*/
|
||||
inline void Add(bst_float fv,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
const BoosterInfo &info,
|
||||
const MetaInfo &info,
|
||||
const bst_uint ridx) {
|
||||
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
|
||||
utils::Assert(istart != hist.size, "the bound variable must be max");
|
||||
CHECK_NE(istart, hist.size);
|
||||
hist.data[istart].Add(gpair, info, ridx);
|
||||
}
|
||||
/*!
|
||||
@@ -292,25 +283,25 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
inline void Add(bst_float fv,
|
||||
bst_gpair gstats) {
|
||||
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
|
||||
utils::Assert(istart != hist.size, "the bound variable must be max");
|
||||
CHECK_NE(istart, hist.size);
|
||||
hist.data[istart].Add(gstats);
|
||||
}
|
||||
};
|
||||
// sketch type used for this
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
// initialize the work set of tree
|
||||
virtual void InitWorkSet(IFMatrix *p_fmat,
|
||||
const RegTree &tree,
|
||||
std::vector<bst_uint> *p_fset) {
|
||||
void InitWorkSet(DMatrix *p_fmat,
|
||||
const RegTree &tree,
|
||||
std::vector<bst_uint> *p_fset) override {
|
||||
feat_helper.InitByCol(p_fmat, tree);
|
||||
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
||||
}
|
||||
// code to create histogram
|
||||
virtual void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<bst_uint> &fset,
|
||||
const RegTree &tree) {
|
||||
void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<bst_uint> &fset,
|
||||
const RegTree &tree) override {
|
||||
const MetaInfo &info = p_fmat->info();
|
||||
// fill in reverse map
|
||||
feat2workindex.resize(tree.param.num_feature);
|
||||
std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
|
||||
@@ -327,7 +318,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
{
|
||||
thread_hist.resize(this->get_nthread());
|
||||
// start accumulating statistics
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fset);
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
@@ -353,21 +344,22 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// sync the histogram
|
||||
// if it is C++11, use lazy evaluation for Allreduce
|
||||
#if __cplusplus >= 201103L
|
||||
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data),
|
||||
this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
|
||||
this->wspace.hset[0].data.size(), lazy_get_hist);
|
||||
#else
|
||||
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
||||
this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
|
||||
this->wspace.hset[0].data.size());
|
||||
#endif
|
||||
}
|
||||
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||
const RegTree &tree) override {
|
||||
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
||||
}
|
||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<bst_uint> &fset,
|
||||
const RegTree &tree) {
|
||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<bst_uint> &fset,
|
||||
const RegTree &tree) override {
|
||||
const MetaInfo &info = p_fmat->info();
|
||||
// fill in reverse map
|
||||
feat2workindex.resize(tree.param.num_feature);
|
||||
std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
|
||||
@@ -380,7 +372,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
feat2workindex[fset[i]] = -2;
|
||||
}
|
||||
}
|
||||
this->GetNodeStats(gpair, *p_fmat, tree, info,
|
||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||
&thread_stats, &node_stats);
|
||||
sketchs.resize(this->qexpand.size() * freal_set.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
@@ -403,7 +395,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// number of rows in
|
||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
||||
// start accumulating statistics
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
@@ -422,18 +414,19 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
common::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
utils::Assert(summary_array.size() == sketchs.size(), "shape mismatch");
|
||||
CHECK_EQ(summary_array.size(), sketchs.size());
|
||||
};
|
||||
if (summary_array.size() != 0) {
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
#if __cplusplus >= 201103L
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size(), lazy_get_summary);
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array),
|
||||
nbytes, summary_array.size(), lazy_get_summary);
|
||||
#else
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
#endif
|
||||
}
|
||||
// now we get the final result of sketch, setup the cut
|
||||
@@ -460,7 +453,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size()));
|
||||
} else {
|
||||
utils::Assert(offset == -2, "BUG in mark");
|
||||
CHECK_EQ(offset, -2);
|
||||
bst_float cpt = feat_helper.MaxValue(fset[i]);
|
||||
this->wspace.cut.push_back(cpt + fabs(cpt) + rt_eps);
|
||||
this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size()));
|
||||
@@ -470,15 +463,14 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
this->wspace.cut.push_back(0.0f);
|
||||
this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size()));
|
||||
}
|
||||
utils::Assert(this->wspace.rptr.size() ==
|
||||
(fset.size() + 1) * this->qexpand.size() + 1,
|
||||
"cut space inconsistent");
|
||||
CHECK_EQ(this->wspace.rptr.size(),
|
||||
(fset.size() + 1) * this->qexpand.size() + 1);
|
||||
}
|
||||
|
||||
private:
|
||||
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
|
||||
const ColBatch::Inst &c,
|
||||
const BoosterInfo &info,
|
||||
const MetaInfo &info,
|
||||
const RegTree &tree,
|
||||
const std::vector<bst_uint> &fset,
|
||||
bst_uint fid_offset,
|
||||
@@ -623,11 +615,11 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// set of index from fset that are real
|
||||
std::vector<bst_uint> freal_set;
|
||||
// thread temp data
|
||||
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||
// used to hold statistics
|
||||
std::vector< std::vector<TStats> > thread_stats;
|
||||
std::vector<std::vector<TStats> > thread_stats;
|
||||
// used to hold start pointer
|
||||
std::vector< std::vector<HistEntry> > thread_hist;
|
||||
std::vector<std::vector<HistEntry> > thread_hist;
|
||||
// node statistics
|
||||
std::vector<TStats> node_stats;
|
||||
// summary array
|
||||
@@ -635,18 +627,18 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// reducer for summary
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
protected:
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector <bst_uint> &fset,
|
||||
const RegTree &tree) {
|
||||
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector <bst_uint> &fset,
|
||||
const RegTree &tree) override {
|
||||
const MetaInfo &info = p_fmat->info();
|
||||
// initialize the data structure
|
||||
int nthread = BaseMaker::get_nthread();
|
||||
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
|
||||
@@ -654,12 +646,13 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||
}
|
||||
// start accumulating statistics
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
// parallel convert to column major format
|
||||
utils::ParallelGroupBuilder<SparseBatch::Entry> builder(&col_ptr, &col_data, &thread_col_ptr);
|
||||
common::ParallelGroupBuilder<SparseBatch::Entry>
|
||||
builder(&col_ptr, &col_data, &thread_col_ptr);
|
||||
builder.InitBudget(tree.param.num_feature, nthread);
|
||||
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
|
||||
@@ -711,14 +704,14 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
// synchronize sketch
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
common::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
@@ -745,9 +738,8 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
this->wspace.cut.push_back(0.0f);
|
||||
this->wspace.rptr.push_back(this->wspace.cut.size());
|
||||
}
|
||||
utils::Assert(this->wspace.rptr.size() ==
|
||||
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
||||
"cut space inconsistent");
|
||||
CHECK_EQ(this->wspace.rptr.size(),
|
||||
(tree.param.num_feature + 1) * this->qexpand.size() + 1);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -759,11 +751,15 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
std::vector<size_t> col_ptr;
|
||||
// local storage of column data
|
||||
std::vector<SparseBatch::Entry> col_data;
|
||||
std::vector< std::vector<size_t> > thread_col_ptr;
|
||||
std::vector<std::vector<size_t> > thread_col_ptr;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
||||
.describe("Tree constructor that uses approximate histogram construction.")
|
||||
.set_body([]() {
|
||||
return new CQHistMaker<GradStats>();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_
|
||||
|
||||
@@ -1,43 +1,42 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file updater_prune-inl.hpp
|
||||
* \file updater_prune.cc
|
||||
* \brief prune a tree given the statistics
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
|
||||
|
||||
#include <vector>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "./param.h"
|
||||
#include "./updater.h"
|
||||
#include "./updater_sync-inl.hpp"
|
||||
#include "../common/sync.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*! \brief pruner that prunes a tree after growing finishes */
|
||||
class TreePruner: public IUpdater {
|
||||
class TreePruner: public TreeUpdater {
|
||||
public:
|
||||
virtual ~TreePruner(void) {}
|
||||
TreePruner() {
|
||||
syncher.reset(TreeUpdater::Create("sync"));
|
||||
}
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
using namespace std;
|
||||
param.SetParam(name, val);
|
||||
syncher.SetParam(name, val);
|
||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param.Init(args);
|
||||
syncher->Init(args);
|
||||
}
|
||||
// update the tree, do pruning
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
void Update(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
float lr = param.eta;
|
||||
param.eta = lr / trees.size();
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
this->DoPrune(*trees[i]);
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
syncher.Update(gpair, p_fmat, info, trees);
|
||||
param.eta = lr;
|
||||
syncher->Update(gpair, p_fmat, trees);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -49,9 +48,9 @@ class TreePruner: public IUpdater {
|
||||
++s.leaf_child_cnt;
|
||||
if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
|
||||
// need to be pruned
|
||||
tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight);
|
||||
tree.ChangeToLeaf(pid, param.eta * s.base_weight);
|
||||
// tail recursion
|
||||
return this->TryPruneLeaf(tree, pid, depth - 1, npruned+2);
|
||||
return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
|
||||
} else {
|
||||
return npruned;
|
||||
}
|
||||
@@ -68,20 +67,24 @@ class TreePruner: public IUpdater {
|
||||
npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
|
||||
}
|
||||
}
|
||||
if (silent == 0) {
|
||||
utils::Printf("tree pruning end, %d roots, %d extra nodes, %d pruned nodes, max_depth=%d\n",
|
||||
tree.param.num_roots, tree.num_extra_nodes(), npruned, tree.MaxDepth());
|
||||
if (!param.silent) {
|
||||
LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
|
||||
<< tree.num_extra_nodes() << " extra nodes, " << npruned
|
||||
<< " pruned nodes, max_depth=" << tree.MaxDepth();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// synchronizer
|
||||
TreeSyncher syncher;
|
||||
// shutup
|
||||
int silent;
|
||||
std::unique_ptr<TreeUpdater> syncher;
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||
.describe("Pruner that prune the tree according to statistics.")
|
||||
.set_body([]() {
|
||||
return new TreePruner();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
|
||||
|
||||
@@ -1,39 +1,34 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file updater_refresh-inl.hpp
|
||||
* \file updater_refresh.cc
|
||||
* \brief refresh the statistics and leaf value on the tree on the dataset
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
|
||||
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "../sync/sync.h"
|
||||
#include "./param.h"
|
||||
#include "./updater.h"
|
||||
#include "../utils/omp.h"
|
||||
#include "../common/sync.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*! \brief pruner that prunes a tree after growing finishs */
|
||||
template<typename TStats>
|
||||
class TreeRefresher: public IUpdater {
|
||||
class TreeRefresher: public TreeUpdater {
|
||||
public:
|
||||
virtual ~TreeRefresher(void) {}
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
param.SetParam(name, val);
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param.Init(args);
|
||||
}
|
||||
// update the tree, do pruning
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
void Update(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
if (trees.size() == 0) return;
|
||||
// number of threads
|
||||
// thread temporal space
|
||||
std::vector< std::vector<TStats> > stemp;
|
||||
std::vector<std::vector<TStats> > stemp;
|
||||
std::vector<RegTree::FVec> fvec_temp;
|
||||
// setup temp space for each thread
|
||||
int nthread;
|
||||
@@ -60,13 +55,13 @@ class TreeRefresher: public IUpdater {
|
||||
auto lazy_get_stats = [&]()
|
||||
#endif
|
||||
{
|
||||
const MetaInfo &info = p_fmat->info();
|
||||
// start accumulating statistics
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
|
||||
"too large batch size ");
|
||||
CHECK_LT(batch.size, std::numeric_limits<unsigned>::max());
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
@@ -78,7 +73,7 @@ class TreeRefresher: public IUpdater {
|
||||
int offset = 0;
|
||||
for (size_t j = 0; j < trees.size(); ++j) {
|
||||
AddStats(*trees[j], feats, gpair, info, ridx,
|
||||
BeginPtr(stemp[tid]) + offset);
|
||||
dmlc::BeginPtr(stemp[tid]) + offset);
|
||||
offset += trees[j]->param.num_nodes;
|
||||
}
|
||||
feats.Drop(inst);
|
||||
@@ -94,29 +89,29 @@ class TreeRefresher: public IUpdater {
|
||||
}
|
||||
};
|
||||
#if __cplusplus >= 201103L
|
||||
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
|
||||
reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
|
||||
#else
|
||||
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||
reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
|
||||
#endif
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
float lr = param.eta;
|
||||
param.eta = lr / trees.size();
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
|
||||
this->Refresh(BeginPtr(stemp[0]) + offset, rid, trees[i]);
|
||||
this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, rid, trees[i]);
|
||||
}
|
||||
offset += trees[i]->param.num_nodes;
|
||||
}
|
||||
// set learning rate back
|
||||
param.learning_rate = lr;
|
||||
param.eta = lr;
|
||||
}
|
||||
|
||||
private:
|
||||
inline static void AddStats(const RegTree &tree,
|
||||
const RegTree::FVec &feat,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
const BoosterInfo &info,
|
||||
const MetaInfo &info,
|
||||
const bst_uint ridx,
|
||||
TStats *gstats) {
|
||||
// start from groups that belongs to current data
|
||||
@@ -136,7 +131,7 @@ class TreeRefresher: public IUpdater {
|
||||
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
|
||||
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
|
||||
if (tree[nid].is_leaf()) {
|
||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.eta);
|
||||
} else {
|
||||
tree.stat(nid).loss_chg = static_cast<float>(
|
||||
gstats[tree[nid].cleft()].CalcGain(param) +
|
||||
@@ -152,6 +147,10 @@ class TreeRefresher: public IUpdater {
|
||||
rabit::Reducer<TStats, TStats::Reduce> reducer;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||
.set_body([]() {
|
||||
return new TreeRefresher<GradStats>();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
|
||||
|
||||
@@ -1,57 +1,56 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file updater_skmaker-inl.hpp
|
||||
* \file updater_skmaker.cc
|
||||
* \brief use approximation sketch to construct a tree,
|
||||
a refresh is needed to make the statistics exactly correct
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_UPDATER_SKMAKER_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_SKMAKER_INL_HPP_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/quantile.h"
|
||||
#include "./updater_basemaker-inl.hpp"
|
||||
#include "../common/sync.h"
|
||||
#include "../common/quantile.h"
|
||||
#include "../common/group_data.h"
|
||||
#include "./updater_basemaker-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
class SketchMaker: public BaseMaker {
|
||||
public:
|
||||
virtual ~SketchMaker(void) {}
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
void Update(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
float lr = param.eta;
|
||||
param.eta = lr / trees.size();
|
||||
// build tree
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
this->Update(gpair, p_fmat, info, trees[i]);
|
||||
this->Update(gpair, p_fmat, trees[i]);
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
param.eta = lr;
|
||||
}
|
||||
|
||||
protected:
|
||||
inline void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
RegTree *p_tree) {
|
||||
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
|
||||
DMatrix *p_fmat,
|
||||
RegTree *p_tree) {
|
||||
this->InitData(gpair, *p_fmat, *p_tree);
|
||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||
this->GetNodeStats(gpair, *p_fmat, *p_tree, info,
|
||||
this->GetNodeStats(gpair, *p_fmat, *p_tree,
|
||||
&thread_stats, &node_stats);
|
||||
this->BuildSketch(gpair, p_fmat, info, *p_tree);
|
||||
this->BuildSketch(gpair, p_fmat, *p_tree);
|
||||
this->SyncNodeStats();
|
||||
this->FindSplit(depth, gpair, p_fmat, info, p_tree);
|
||||
this->FindSplit(depth, gpair, p_fmat, p_tree);
|
||||
this->ResetPositionCol(qexpand, p_fmat, *p_tree);
|
||||
this->UpdateQueueExpand(*p_tree);
|
||||
// if nothing left to be expand, break
|
||||
if (qexpand.size() == 0) break;
|
||||
}
|
||||
if (qexpand.size() != 0) {
|
||||
this->GetNodeStats(gpair, *p_fmat, *p_tree, info,
|
||||
this->GetNodeStats(gpair, *p_fmat, *p_tree,
|
||||
&thread_stats, &node_stats);
|
||||
this->SyncNodeStats();
|
||||
}
|
||||
@@ -68,11 +67,11 @@ class SketchMaker: public BaseMaker {
|
||||
// set left leaves
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
|
||||
}
|
||||
}
|
||||
// define the sketch we want to use
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
|
||||
private:
|
||||
// statistics needed in the gradient calculation
|
||||
@@ -94,7 +93,7 @@ class SketchMaker: public BaseMaker {
|
||||
}
|
||||
// accumulate statistics
|
||||
inline void Add(const std::vector<bst_gpair> &gpair,
|
||||
const BoosterInfo &info,
|
||||
const MetaInfo &info,
|
||||
bst_uint ridx) {
|
||||
const bst_gpair &b = gpair[ridx];
|
||||
if (b.grad >= 0.0f) {
|
||||
@@ -133,9 +132,9 @@ class SketchMaker: public BaseMaker {
|
||||
}
|
||||
};
|
||||
inline void BuildSketch(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
const MetaInfo& info = p_fmat->info();
|
||||
sketchs.resize(this->qexpand.size() * tree.param.num_feature * 3);
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||
@@ -144,7 +143,7 @@ class SketchMaker: public BaseMaker {
|
||||
// number of rows in
|
||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
||||
// start accumulating statistics
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
@@ -164,13 +163,13 @@ class SketchMaker: public BaseMaker {
|
||||
// synchronize sketch
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
common::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
sketch_reducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
}
|
||||
// update sketch information in column fid
|
||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||
@@ -256,20 +255,19 @@ class SketchMaker: public BaseMaker {
|
||||
}
|
||||
}
|
||||
inline void SyncNodeStats(void) {
|
||||
utils::Assert(qexpand.size() != 0, "qexpand must not be empty");
|
||||
CHECK_NE(qexpand.size(), 0);
|
||||
std::vector<SKStats> tmp(qexpand.size());
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
tmp[i] = node_stats[qexpand[i]];
|
||||
}
|
||||
stats_reducer.Allreduce(BeginPtr(tmp), tmp.size());
|
||||
stats_reducer.Allreduce(dmlc::BeginPtr(tmp), tmp.size());
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
node_stats[qexpand[i]] = tmp[i];
|
||||
}
|
||||
}
|
||||
inline void FindSplit(int depth,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
DMatrix *p_fmat,
|
||||
RegTree *p_tree) {
|
||||
const bst_uint num_feature = p_tree->param.num_feature;
|
||||
// get the best split condition for each node
|
||||
@@ -278,8 +276,7 @@ class SketchMaker: public BaseMaker {
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
|
||||
const int nid = qexpand[wid];
|
||||
utils::Assert(node2workindex[nid] == static_cast<int>(wid),
|
||||
"node2workindex inconsistent");
|
||||
CHECK_EQ(node2workindex[nid], static_cast<int>(wid));
|
||||
SplitEntry &best = sol[wid];
|
||||
for (bst_uint fid = 0; fid < num_feature; ++fid) {
|
||||
unsigned base = (wid * p_tree->param.num_feature + fid) * 3;
|
||||
@@ -305,7 +302,7 @@ class SketchMaker: public BaseMaker {
|
||||
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
|
||||
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
|
||||
} else {
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
|
||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -380,9 +377,9 @@ class SketchMaker: public BaseMaker {
|
||||
|
||||
// thread temp data
|
||||
// used to hold temporal sketch
|
||||
std::vector< std::vector<SketchEntry> > thread_sketch;
|
||||
std::vector<std::vector<SketchEntry> > thread_sketch;
|
||||
// used to hold statistics
|
||||
std::vector< std::vector<SKStats> > thread_stats;
|
||||
std::vector<std::vector<SKStats> > thread_stats;
|
||||
// node statistics
|
||||
std::vector<SKStats> node_stats;
|
||||
// summary array
|
||||
@@ -392,8 +389,14 @@ class SketchMaker: public BaseMaker {
|
||||
// reducer for summary
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker")
|
||||
.describe("Approximate sketching maker.")
|
||||
.set_body([]() {
|
||||
return new SketchMaker();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_SKMAKER_INL_HPP_
|
||||
|
||||
|
||||
Reference in New Issue
Block a user