[TREE] Experimental version of monotone constraint (#1516)

* [TREE] Experimental version of monotone constraint
* Allow default detection of the monotone option
* Loosen the condition of the strict check
* Update gbtree.cc

parent 8cac37b2b4
commit c93c9b7ed6
@@ -15,6 +15,7 @@
 #include <utility>
 #include <string>
 #include <limits>
+#include <unordered_map>
 #include <algorithm>
 #include "../common/common.h"
 
src/tree/param.h (177 lines changed)
@@ -9,6 +9,7 @@
 
 #include <vector>
 #include <cstring>
+#include <limits>
 #include <cmath>
 
 namespace xgboost {
@@ -55,6 +56,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   bool cache_opt;
   // whether to not print info during training.
   bool silent;
+  // auxiliary data structure
+  std::vector<int> monotone_constraints;
   // declare the parameters
   DMLC_DECLARE_PARAMETER(TrainParam) {
     DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
@@ -97,13 +100,20 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .describe("EXP Param: Cache aware optimization.");
     DMLC_DECLARE_FIELD(silent).set_default(false)
         .describe("Do not print information during training.");
+    DMLC_DECLARE_FIELD(monotone_constraints).set_default(std::vector<int>())
+        .describe("Constraint of variable monotonicity");
     // add alias of parameters
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
     DMLC_DECLARE_ALIAS(reg_alpha, alpha);
     DMLC_DECLARE_ALIAS(min_split_loss, gamma);
     DMLC_DECLARE_ALIAS(learning_rate, eta);
   }
+  // calculate the cost of loss function, given the weight
+  inline double CalcGainGivenWeight(double sum_grad,
+                                    double sum_hess,
+                                    double w) const {
+    return -(2.0 * sum_grad * w + (sum_hess + reg_lambda) * Sqr(w));
+  }
   // calculate the cost of loss function
   inline double CalcGain(double sum_grad, double sum_hess) const {
     if (sum_hess < min_child_weight) return 0.0;
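The added CalcGainGivenWeight evaluates the (negated, doubled) regularized objective at an arbitrary leaf weight w, so a constrained weight can be scored directly. As a quick check (a sketch; writing G = sum_grad, H = sum_hess, lambda = reg_lambda, and ignoring reg_alpha and max_delta_step, which CalcGain handles separately):

    \mathrm{obj}(w) = G w + \tfrac{1}{2} (H + \lambda) w^{2}, \qquad
    -\bigl(2 G w + (H + \lambda) w^{2}\bigr) = -2\,\mathrm{obj}(w).

At the unconstrained optimum w^{*} = -G/(H+\lambda) this evaluates to G^{2}/(H+\lambda), exactly the CalcGain value below; clamping w to a constraint interval can only decrease it.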
@@ -262,6 +272,102 @@ struct GradStats {
   }
 };
+
+struct NoConstraint {
+  inline static void Init(TrainParam* param, unsigned num_feature) {
+  }
+  inline double CalcSplitGain(
+      const TrainParam& param, bst_uint split_index,
+      GradStats left, GradStats right) const {
+    return left.CalcGain(param) + right.CalcGain(param);
+  }
+  inline double CalcWeight(
+      const TrainParam& param,
+      GradStats stats) const {
+    return stats.CalcWeight(param);
+  }
+  inline double CalcGain(const TrainParam& param,
+                         GradStats stats) const {
+    return stats.CalcGain(param);
+  }
+  inline void SetChild(
+      const TrainParam& param, bst_uint split_index,
+      GradStats left, GradStats right,
+      NoConstraint* cleft, NoConstraint* cright) {
+  }
+};
+
+struct ValueConstraint {
+  double lower_bound;
+  double upper_bound;
+  ValueConstraint() :
+      lower_bound(-std::numeric_limits<double>::max()),
+      upper_bound(std::numeric_limits<double>::max()) {
+  }
+  inline static void Init(TrainParam* param, unsigned num_feature) {
+    param->monotone_constraints.resize(num_feature, 1);
+  }
+  inline double CalcWeight(
+      const TrainParam& param,
+      GradStats stats) const {
+    double w = stats.CalcWeight(param);
+    if (w < lower_bound) {
+      return lower_bound;
+    }
+    if (w > upper_bound) {
+      return upper_bound;
+    }
+    return w;
+  }
+
+  inline double CalcGain(const TrainParam& param,
+                         GradStats stats) const {
+    return param.CalcGainGivenWeight(
+        stats.sum_grad, stats.sum_hess,
+        CalcWeight(param, stats));
+  }
+
+  inline double CalcSplitGain(
+      const TrainParam& param,
+      bst_uint split_index,
+      GradStats left, GradStats right) const {
+    double wleft = CalcWeight(param, left);
+    double wright = CalcWeight(param, right);
+    int c = param.monotone_constraints[split_index];
+    double gain =
+        param.CalcGainGivenWeight(left.sum_grad, left.sum_hess, wleft) +
+        param.CalcGainGivenWeight(right.sum_grad, right.sum_hess, wright);
+    if (c == 0) {
+      return gain;
+    } else if (c > 0) {
+      return wleft < wright ? gain : 0.0;
+    } else {
+      return wleft > wright ? gain : 0.0;
+    }
+  }
+
+  inline void SetChild(
+      const TrainParam& param,
+      bst_uint split_index,
+      GradStats left, GradStats right,
+      ValueConstraint* cleft, ValueConstraint* cright) {
+    int c = param.monotone_constraints.at(split_index);
+    *cleft = *this;
+    *cright = *this;
+    if (c == 0) return;
+    double wleft = CalcWeight(param, left);
+    double wright = CalcWeight(param, right);
+    double mid = (wleft + wright) / 2;
+    CHECK(!std::isnan(mid));
+    if (c < 0) {
+      cleft->lower_bound = mid;
+      cright->upper_bound = mid;
+    } else {
+      cleft->upper_bound = mid;
+      cright->lower_bound = mid;
+    }
+  }
+};
+
 /*!
  * \brief statistics that is helpful to store
  * and represent a split solution for the tree
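To make the gate in ValueConstraint::CalcSplitGain concrete: under an increasing (+1) constraint a candidate split is kept only if the clamped left weight stays below the clamped right weight; otherwise its gain collapses to zero and it is never selected. A minimal self-contained sketch (illustrative names, not types from the patch):

    #include <algorithm>
    #include <cstdio>
    #include <limits>

    // Illustration of the clamp-and-gate idea, reduced to one node and one
    // feature. Box plays the role of ValueConstraint's [lower, upper] interval.
    struct Box {
      double lo = -std::numeric_limits<double>::max();
      double hi = std::numeric_limits<double>::max();
      double Clamp(double w) const { return std::min(std::max(w, lo), hi); }
    };

    double SplitGain(const Box& node, int c, double wleft, double wright,
                     double gain) {  // gain stands in for the two
      wleft = node.Clamp(wleft);     // CalcGainGivenWeight terms
      wright = node.Clamp(wright);
      if (c == 0) return gain;                        // unconstrained feature
      if (c > 0) return wleft < wright ? gain : 0.0;  // must be non-decreasing
      return wleft > wright ? gain : 0.0;             // must be non-increasing
    }

    int main() {
      Box root;
      std::printf("kept: %g\n", SplitGain(root, +1, -0.4, 0.7, 1.25));     // 1.25
      std::printf("rejected: %g\n", SplitGain(root, +1, 0.7, -0.4, 1.25)); // 0
    }

Note also that Init above fills unspecified entries with 1, so once this code path is selected, every feature without an explicit value is treated as increasing.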
@@ -340,4 +446,73 @@ struct SplitEntry {
 
 } // namespace tree
 } // namespace xgboost
+
+// define string serializer for vector, to get the arguments
+namespace std {
+inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) {
+  os << '(';
+  for (std::vector<int>::const_iterator
+           it = t.begin(); it != t.end(); ++it) {
+    if (it != t.begin()) os << ',';
+    os << *it;
+  }
+  // python style tuple
+  if (t.size() == 1) os << ',';
+  os << ')';
+  return os;
+}
+
+inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
+  // get (
+  while (true) {
+    char ch = is.peek();
+    if (isdigit(ch)) {
+      int idx;
+      if (is >> idx) {
+        t.assign(&idx, &idx + 1);
+      }
+      return is;
+    }
+    is.get();
+    if (ch == '(') break;
+    if (!isspace(ch)) {
+      is.setstate(std::ios::failbit);
+      return is;
+    }
+  }
+  int idx;
+  std::vector<int> tmp;
+  while (is >> idx) {
+    tmp.push_back(idx);
+    char ch;
+    do {
+      ch = is.get();
+    } while (isspace(ch));
+    if (ch == 'L') {
+      ch = is.get();
+    }
+    if (ch == ',') {
+      while (true) {
+        ch = is.peek();
+        if (isspace(ch)) {
+          is.get(); continue;
+        }
+        if (ch == ')') {
+          is.get(); break;
+        }
+        break;
+      }
+      if (ch == ')') break;
+    } else if (ch == ')') {
+      break;
+    } else {
+      is.setstate(std::ios::failbit);
+      return is;
+    }
+  }
+  t.assign(tmp.begin(), tmp.end());
+  return is;
+}
+} // namespace std
+
 #endif  // XGBOOST_TREE_PARAM_H_
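These operators give std::vector<int> a python-style tuple syntax so dmlc parameter parsing can accept strings like "(1,0,-1)"; the stray 'L' branch even tolerates Python 2 long suffixes such as "(1L,-1L)". A round-trip sketch (assumes the overloads above are in scope, e.g. by including src/tree/param.h):

    #include <iostream>
    #include <sstream>
    #include <vector>

    int main() {
      std::vector<int> c;
      std::istringstream in("(1,0,-1)");  // +1 increasing, 0 free, -1 decreasing
      in >> c;                            // parses the tuple form
      std::cout << c << "\n";             // prints (1,0,-1)

      std::istringstream bare("2");       // a bare integer is also accepted
      bare >> c;
      std::cout << c << "\n";             // prints (2,) -- python-style 1-tuple
    }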
src/tree/updater_colmaker.cc

@@ -19,7 +19,7 @@ namespace tree {
 DMLC_REGISTRY_FILE_TAG(updater_colmaker);
 
 /*! \brief column-wise update to construct a tree */
-template<typename TStats>
+template<typename TStats, typename TConstraint>
 class ColMaker: public TreeUpdater {
  public:
  void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
@@ -33,6 +33,7 @@ class ColMaker: public TreeUpdater {
     // rescale learning rate according to size of trees
     float lr = param.learning_rate;
     param.learning_rate = lr / trees.size();
+    TConstraint::Init(&param, dmat->info().num_col);
     // build tree
     for (size_t i = 0; i < trees.size(); ++i) {
       Builder builder(param);
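TConstraint acts as a compile-time policy here: NoConstraint::Init is a no-op, while ValueConstraint::Init pads monotone_constraints out to num_col. A reduced sketch of the hook (hypothetical names, not the real classes):

    #include <cstdio>
    #include <vector>

    // The updater calls TConstraint::Init once per Update; each policy
    // decides what "setup" means.
    struct Params { std::vector<int> monotone; };

    struct NoPolicy {
      static void Init(Params*, unsigned) {}  // nothing to prepare
    };

    struct MonotonePolicy {
      static void Init(Params* p, unsigned num_feature) {
        p->monotone.resize(num_feature, 1);  // pad, as the patch's Init does
      }
    };

    template <typename TConstraint>
    void Update(Params* p, unsigned num_col) {
      TConstraint::Init(p, num_col);
      std::printf("entries: %zu\n", p->monotone.size());
    }

    int main() {
      Params a, b;
      Update<NoPolicy>(&a, 4);        // entries: 0
      Update<MonotonePolicy>(&b, 4);  // entries: 4
    }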
@@ -199,6 +200,7 @@ class ColMaker: public TreeUpdater {
         stemp[i].resize(tree.param.num_nodes, ThreadEntry(param));
       }
       snode.resize(tree.param.num_nodes, NodeEntry(param));
+      constraints_.resize(tree.param.num_nodes);
     }
     const RowSet &rowset = fmat.buffered_rowset();
     const MetaInfo& info = fmat.info();
@@ -220,8 +222,25 @@ class ColMaker: public TreeUpdater {
       }
       // update node statistics
       snode[nid].stats = stats;
-      snode[nid].root_gain = static_cast<float>(stats.CalcGain(param));
-      snode[nid].weight = static_cast<float>(stats.CalcWeight(param));
+    }
+    // setup constraints before calculating the weight
+    for (size_t j = 0; j < qexpand.size(); ++j) {
+      const int nid = qexpand[j];
+      if (tree[nid].is_root()) continue;
+      const int pid = tree[nid].parent();
+      constraints_[pid].SetChild(param, tree[pid].split_index(),
+                                 snode[tree[pid].cleft()].stats,
+                                 snode[tree[pid].cright()].stats,
+                                 &constraints_[tree[pid].cleft()],
+                                 &constraints_[tree[pid].cright()]);
+    }
+    // calculating the weights
+    for (size_t j = 0; j < qexpand.size(); ++j) {
+      const int nid = qexpand[j];
+      snode[nid].root_gain = static_cast<float>(
+          constraints_[nid].CalcGain(param, snode[nid].stats));
+      snode[nid].weight = static_cast<float>(
+          constraints_[nid].CalcWeight(param, snode[nid].stats));
     }
   }
   /*! \brief update queue expand add in new leaves */
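The order of the two loops is the point: a child's feasible interval depends only on its parent's interval and the parent's realized child weights, so every SetChild call must finish before any node's gain or weight is computed against constraints_[nid]. The interval rule itself splits the parent's box at the midpoint of the two child weights; a concrete run (illustration only, one root split under a +1 constraint):

    #include <cstdio>
    #include <limits>

    struct Box { double lo, hi; };  // stands in for ValueConstraint's bounds

    int main() {
      const double inf = std::numeric_limits<double>::max();
      Box root{-inf, inf};
      double wleft = -0.4, wright = 0.7;  // clamped child weights at the root
      double mid = (wleft + wright) / 2;  // 0.15
      Box left{root.lo, mid};             // left child may not rise above mid
      Box right{mid, root.hi};            // right child may not fall below mid
      std::printf("mid=%g left_hi=%g right_lo=%g\n", mid, left.hi, right.lo);
    }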
@@ -244,6 +263,7 @@ class ColMaker: public TreeUpdater {
                              bst_uint fid,
                              const DMatrix &fmat,
                              const std::vector<bst_gpair> &gpair) {
+    // TODO(tqchen): double check stats order.
     const MetaInfo& info = fmat.info();
     const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
     bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind);
@@ -303,8 +323,8 @@ class ColMaker: public TreeUpdater {
           c.SetSubstract(snode[nid].stats, e.stats);
           if (c.sum_hess >= param.min_child_weight &&
               e.stats.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
             e.best.Update(loss_chg, fid, fsplit, false);
           }
         }
@@ -313,8 +333,8 @@ class ColMaker: public TreeUpdater {
           c.SetSubstract(snode[nid].stats, tmp);
           if (c.sum_hess >= param.min_child_weight &&
               tmp.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(tmp.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, tmp, c) - snode[nid].root_gain);
             e.best.Update(loss_chg, fid, fsplit, true);
           }
         }
@@ -325,8 +345,8 @@ class ColMaker: public TreeUpdater {
           c.SetSubstract(snode[nid].stats, tmp);
           if (c.sum_hess >= param.min_child_weight &&
               tmp.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(tmp.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, tmp, c) - snode[nid].root_gain);
             e.best.Update(loss_chg, fid, e.last_fvalue + rt_eps, true);
           }
         }
@@ -357,9 +377,9 @@ class ColMaker: public TreeUpdater {
             c.SetSubstract(snode[nid].stats, e.stats);
             if (c.sum_hess >= param.min_child_weight &&
                 e.stats.sum_hess >= param.min_child_weight) {
-              bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                          c.CalcGain(param) -
+              bst_float loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, e.stats, c) -
                   snode[nid].root_gain);
               e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, false);
             }
           }
@@ -368,9 +388,9 @@ class ColMaker: public TreeUpdater {
             c.SetSubstract(snode[nid].stats, cright);
             if (c.sum_hess >= param.min_child_weight &&
                 cright.sum_hess >= param.min_child_weight) {
-              bst_float loss_chg = static_cast<bst_float>(cright.CalcGain(param) +
-                                                          c.CalcGain(param) -
+              bst_float loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, c, cright) -
                   snode[nid].root_gain);
               e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
             }
           }
@@ -397,8 +417,14 @@ class ColMaker: public TreeUpdater {
               e.stats.sum_hess >= param.min_child_weight) {
             c.SetSubstract(snode[nid].stats, e.stats);
             if (c.sum_hess >= param.min_child_weight) {
-              bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                          c.CalcGain(param) - snode[nid].root_gain);
+              bst_float loss_chg;
+              if (d_step == -1) {
+                loss_chg = static_cast<bst_float>(
+                    constraints_[nid].CalcSplitGain(param, fid, c, e.stats) - snode[nid].root_gain);
+              } else {
+                loss_chg = static_cast<bst_float>(
+                    constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
+              }
               e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
             }
           }
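The d_step branch encodes an ordering rule that repeats in the hunks below: CalcSplitGain expects the statistics of the lower-fvalue (left) side first. A backward scan (d_step == -1) accumulates e.stats from the high end of the feature, so there the remainder c is the left side. A reduced sketch (hypothetical stand-ins, not the real GradStats):

    #include <cstdio>

    struct S { double g, h; };  // summed gradient / hessian

    // Stand-in gate: an increasing constraint admits the split only when the
    // left-side weight is below the right-side one.
    double CalcSplitGain(S left, S right) {
      double wl = -left.g / left.h, wr = -right.g / right.h;
      return (wl < wr) ? 1.0 : 0.0;
    }

    double LossChg(int d_step, S acc, S rest) {
      return d_step == -1 ? CalcSplitGain(rest, acc)   // acc came from the right
                          : CalcSplitGain(acc, rest);  // acc came from the left
    }

    int main() {
      S acc{1.0, 2.0}, rest{-1.0, 2.0};
      std::printf("forward=%g backward=%g\n",
                  LossChg(+1, acc, rest), LossChg(-1, acc, rest));  // 1 0
    }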
@@ -467,9 +493,16 @@ class ColMaker: public TreeUpdater {
         const int nid = qexpand[i];
         ThreadEntry &e = temp[nid];
         c.SetSubstract(snode[nid].stats, e.stats);
-        if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
-          bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                      c.CalcGain(param) - snode[nid].root_gain);
+        if (e.stats.sum_hess >= param.min_child_weight &&
+            c.sum_hess >= param.min_child_weight) {
+          bst_float loss_chg;
+          if (d_step == -1) {
+            loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, c, e.stats) - snode[nid].root_gain);
+          } else {
+            loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
+          }
           const float gap = std::abs(e.last_fvalue) + rt_eps;
           const float delta = d_step == +1 ? gap: -gap;
           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
@@ -515,8 +548,16 @@ class ColMaker: public TreeUpdater {
               e.stats.sum_hess >= param.min_child_weight) {
             c.SetSubstract(snode[nid].stats, e.stats);
             if (c.sum_hess >= param.min_child_weight) {
-              bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                          c.CalcGain(param) - snode[nid].root_gain);
+              bst_float loss_chg;
+              if (d_step == -1) {
+                loss_chg = static_cast<bst_float>(
+                    constraints_[nid].CalcSplitGain(param, fid, c, e.stats) -
+                    snode[nid].root_gain);
+              } else {
+                loss_chg = static_cast<bst_float>(
+                    constraints_[nid].CalcSplitGain(param, fid, e.stats, c) -
+                    snode[nid].root_gain);
+              }
               e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
             }
           }
@@ -531,8 +572,14 @@ class ColMaker: public TreeUpdater {
         ThreadEntry &e = temp[nid];
         c.SetSubstract(snode[nid].stats, e.stats);
         if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
-          bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                      c.CalcGain(param) - snode[nid].root_gain);
+          bst_float loss_chg;
+          if (d_step == -1) {
+            loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, c, e.stats) - snode[nid].root_gain);
+          } else {
+            loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
+          }
           const float gap = std::abs(e.last_fvalue) + rt_eps;
           const float delta = d_step == +1 ? gap: -gap;
           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
@@ -724,12 +771,14 @@ class ColMaker: public TreeUpdater {
     std::vector<NodeEntry> snode;
     /*! \brief queue of nodes to be expanded */
     std::vector<int> qexpand_;
+    // constraint value
+    std::vector<TConstraint> constraints_;
   };
 };
 
 // distributed column maker
-template<typename TStats>
-class DistColMaker : public ColMaker<TStats> {
+template<typename TStats, typename TConstraint>
+class DistColMaker : public ColMaker<TStats, TConstraint> {
  public:
   DistColMaker() : builder(param) {
     pruner.reset(TreeUpdater::Create("prune"));
@@ -755,10 +804,10 @@ class DistColMaker : public ColMaker<TStats> {
   }
 
  private:
-  struct Builder : public ColMaker<TStats>::Builder {
+  struct Builder : public ColMaker<TStats, TConstraint>::Builder {
    public:
    explicit Builder(const TrainParam &param)
-        : ColMaker<TStats>::Builder(param) {
+        : ColMaker<TStats, TConstraint>::Builder(param) {
     }
     inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
       const RowSet &rowset = p_fmat->buffered_rowset();
@@ -881,16 +930,56 @@ class DistColMaker : public ColMaker<TStats> {
   Builder builder;
 };
 
+// simple switch to defer implementation.
+class TreeUpdaterSwitch : public TreeUpdater {
+ public:
+  TreeUpdaterSwitch() : monotone_(false) {}
+  void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
+    for (auto &kv : args) {
+      if (kv.first == "monotone_constraints" && kv.second.length() != 0) {
+        monotone_ = true;
+      }
+    }
+    if (inner_.get() == nullptr) {
+      if (monotone_) {
+        inner_.reset(new ColMaker<GradStats, ValueConstraint>());
+      } else {
+        inner_.reset(new ColMaker<GradStats, NoConstraint>());
+      }
+    }
+
+    inner_->Init(args);
+  }
+
+  void Update(const std::vector<bst_gpair>& gpair,
+              DMatrix* data,
+              const std::vector<RegTree*>& trees) override {
+    CHECK(inner_ != nullptr);
+    inner_->Update(gpair, data, trees);
+  }
+
+  const int* GetLeafPosition() const override {
+    CHECK(inner_ != nullptr);
+    return inner_->GetLeafPosition();
+  }
+
+ private:
+  // monotone constraints
+  bool monotone_;
+  // internal implementation
+  std::unique_ptr<TreeUpdater> inner_;
+};
+
 XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
 .describe("Grow tree with parallelization over columns.")
 .set_body([]() {
-    return new ColMaker<GradStats>();
+    return new TreeUpdaterSwitch();
   });
 
 XGBOOST_REGISTER_TREE_UPDATER(DistColMaker, "distcol")
 .describe("Distributed column split version of tree maker.")
 .set_body([]() {
-    return new DistColMaker<GradStats>();
+    return new DistColMaker<GradStats, NoConstraint>();
   });
 } // namespace tree
 } // namespace xgboost
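End to end the feature is opt-in: TreeUpdaterSwitch inspects the raw argument list, and only a non-empty monotone_constraints value selects the constrained ColMaker. A driving sketch through the registry (the header path is an assumption; Create and Init are used as the patch itself does):

    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>
    #include <xgboost/tree_updater.h>  // assumed location of TreeUpdater

    int main() {
      std::unique_ptr<xgboost::TreeUpdater> up(
          xgboost::TreeUpdater::Create("grow_colmaker"));  // TreeUpdaterSwitch
      std::vector<std::pair<std::string, std::string> > args;
      args.emplace_back("monotone_constraints", "(1,-1)");
      up->Init(args);  // non-empty value -> ColMaker<GradStats, ValueConstraint>
      // with no such argument it would pick ColMaker<GradStats, NoConstraint>
    }

Note that "distcol" is still registered with NoConstraint, so the distributed maker ignores monotone constraints in this experimental version.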