From c93c9b7ed6305c7ec35f41cc64a02954dfa6abd1 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Wed, 7 Sep 2016 21:28:43 -0700
Subject: [PATCH] [TREE] Experimental version of monotone constraint (#1516)

* [TREE] Experimental version of monotone constraint

* Allow default detection of monotone option

* Loosen the condition of the strict check

* Update gbtree.cc
---
 src/gbm/gbtree.cc            |   1 +
 src/tree/param.h             | 177 ++++++++++++++++++++++++++++++++++-
 src/tree/updater_colmaker.cc | 149 +++++++++++++++++++++++------
 3 files changed, 296 insertions(+), 31 deletions(-)

diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 06139adcf..6d7d79a94 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include
 #include

 #include "../common/common.h"
diff --git a/src/tree/param.h b/src/tree/param.h
index 61ddffe33..d4254d84f 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -9,6 +9,7 @@
 #include
 #include
+#include <limits>
 #include

 namespace xgboost {
@@ -55,6 +56,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   bool cache_opt;
   // whether to not print info during training.
   bool silent;
+  // auxiliary data structure
+  std::vector<int> monotone_constraints;
   // declare the parameters
   DMLC_DECLARE_PARAMETER(TrainParam) {
     DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
@@ -97,13 +100,20 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
       .describe("EXP Param: Cache aware optimization.");
     DMLC_DECLARE_FIELD(silent).set_default(false)
       .describe("Do not print information during training.");
+    DMLC_DECLARE_FIELD(monotone_constraints).set_default(std::vector<int>())
+      .describe("Constraint of variable monotonicity");
     // add alias of parameters
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
     DMLC_DECLARE_ALIAS(reg_alpha, alpha);
     DMLC_DECLARE_ALIAS(min_split_loss, gamma);
     DMLC_DECLARE_ALIAS(learning_rate, eta);
   }
-
+  // calculate the cost of the loss function, given the leaf weight w
+  inline double CalcGainGivenWeight(double sum_grad,
+                                    double sum_hess,
+                                    double w) const {
+    return -(2.0 * sum_grad * w + (sum_hess + reg_lambda) * Sqr(w));
+  }
   // calculate the cost of loss function
   inline double CalcGain(double sum_grad, double sum_hess) const {
     if (sum_hess < min_child_weight) return 0.0;
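As a sanity check on the new CalcGainGivenWeight (this note is not part of the patch): with G = sum_grad, H = sum_hess and λ = reg_lambda, the helper is the usual second-order structure score, scaled by 2 and evaluated at an arbitrary leaf weight w rather than at the optimum (the L1 term reg_alpha is not handled here):

\[
\mathrm{gain}(w) = -2\Bigl(Gw + \tfrac{1}{2}(H+\lambda)w^{2}\Bigr)
                 = -\bigl(2Gw + (H+\lambda)w^{2}\bigr),
\qquad
w^{*} = -\frac{G}{H+\lambda},
\qquad
\mathrm{gain}(w^{*}) = \frac{G^{2}}{H+\lambda}.
\]

At the unconstrained optimum w* this reduces to exactly what CalcGain returns (for reg_alpha = 0), while evaluating it at a clamped weight yields the smaller gain that is actually attainable under a bound constraint, which is precisely how ValueConstraint uses it below.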
@@ -262,6 +272,102 @@ struct GradStats {
   }
 };

+struct NoConstraint {
+  inline static void Init(TrainParam* param, unsigned num_feature) {
+  }
+  inline double CalcSplitGain(const TrainParam& param, bst_uint split_index,
+                              GradStats left, GradStats right) const {
+    return left.CalcGain(param) + right.CalcGain(param);
+  }
+  inline double CalcWeight(const TrainParam& param, GradStats stats) const {
+    return stats.CalcWeight(param);
+  }
+  inline double CalcGain(const TrainParam& param, GradStats stats) const {
+    return stats.CalcGain(param);
+  }
+  inline void SetChild(const TrainParam& param, bst_uint split_index,
+                       GradStats left, GradStats right,
+                       NoConstraint* cleft, NoConstraint* cright) {
+  }
+};
+
+struct ValueConstraint {
+  double lower_bound;
+  double upper_bound;
+  ValueConstraint()
+      : lower_bound(-std::numeric_limits<double>::max()),
+        upper_bound(std::numeric_limits<double>::max()) {
+  }
+  inline static void Init(TrainParam* param, unsigned num_feature) {
+    param->monotone_constraints.resize(num_feature, 1);
+  }
+  inline double CalcWeight(const TrainParam& param, GradStats stats) const {
+    double w = stats.CalcWeight(param);
+    if (w < lower_bound) {
+      return lower_bound;
+    }
+    if (w > upper_bound) {
+      return upper_bound;
+    }
+    return w;
+  }
+
+  inline double CalcGain(const TrainParam& param, GradStats stats) const {
+    return param.CalcGainGivenWeight(stats.sum_grad, stats.sum_hess,
+                                     CalcWeight(param, stats));
+  }
+
+  inline double CalcSplitGain(const TrainParam& param, bst_uint split_index,
+                              GradStats left, GradStats right) const {
+    double wleft = CalcWeight(param, left);
+    double wright = CalcWeight(param, right);
+    int c = param.monotone_constraints[split_index];
+    double gain =
+        param.CalcGainGivenWeight(left.sum_grad, left.sum_hess, wleft) +
+        param.CalcGainGivenWeight(right.sum_grad, right.sum_hess, wright);
+    if (c == 0) {
+      return gain;
+    } else if (c > 0) {
+      return wleft < wright ? gain : 0.0;
+    } else {
+      return wleft > wright ? gain : 0.0;
+    }
+  }
+
+  inline void SetChild(const TrainParam& param, bst_uint split_index,
+                       GradStats left, GradStats right,
+                       ValueConstraint* cleft, ValueConstraint* cright) {
+    int c = param.monotone_constraints.at(split_index);
+    *cleft = *this;
+    *cright = *this;
+    if (c == 0) return;
+    double wleft = CalcWeight(param, left);
+    double wright = CalcWeight(param, right);
+    double mid = (wleft + wright) / 2;
+    CHECK(!std::isnan(mid));
+    if (c < 0) {
+      cleft->lower_bound = mid;
+      cright->upper_bound = mid;
+    } else {
+      cleft->upper_bound = mid;
+      cright->lower_bound = mid;
+    }
+  }
+};
+
 /*!
  * \brief statistics that is helpful to store
  *        and represent a split solution for the tree
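To make the midpoint rule in SetChild concrete, here is a minimal, self-contained sketch (simplified, hypothetical types; the real code threads TrainParam and GradStats through). For an increasing constraint (c > 0) the left child's weight is capped at mid and the right child's is floored at mid; because the children inherit these bounds, every leaf that ends up in the left subtree stays at or below every leaf in the right subtree:

    #include <algorithm>
    #include <cstdio>
    #include <limits>

    // Toy stand-in for a tree node with ValueConstraint-style bounds.
    struct Toy {
      double sum_grad, sum_hess;   // gradient statistics of the node
      double lower, upper;         // allowed interval for the leaf weight
      double Weight(double lambda) const {   // -G / (H + lambda), clamped
        double w = -sum_grad / (sum_hess + lambda);
        return std::min(std::max(w, lower), upper);
      }
    };

    int main() {
      const double kInf = std::numeric_limits<double>::max();
      const double lambda = 1.0;
      Toy left  = {-2.0, 1.0, -kInf, kInf};   // unconstrained weight: 1.0
      Toy right = {-1.0, 3.0, -kInf, kInf};   // unconstrained weight: 0.25
      // c > 0 (increasing): cap the left child, floor the right child.
      double mid = (left.Weight(lambda) + right.Weight(lambda)) / 2;  // 0.625
      left.upper  = mid;
      right.lower = mid;
      std::printf("left=%g right=%g\n",
                  left.Weight(lambda), right.Weight(lambda));  // both 0.625
      return 0;
    }

Note that CalcSplitGain would already score such a reversed pair as 0 under c > 0, so the propagated bounds matter mostly for deeper descendants, which never see the sibling's statistics directly.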
@@ -340,4 +446,73 @@ struct SplitEntry {
 } // namespace tree
 } // namespace xgboost
+
+// define string serializer for vector<int>, used to parse the monotone_constraints argument
+namespace std {
+inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) {
+  os << '(';
+  for (std::vector<int>::const_iterator
+           it = t.begin(); it != t.end(); ++it) {
+    if (it != t.begin()) os << ',';
+    os << *it;
+  }
+  // python style tuple
+  if (t.size() == 1) os << ',';
+  os << ')';
+  return os;
+}
+
+inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
+  // get (
+  while (true) {
+    char ch = is.peek();
+    if (isdigit(ch)) {
+      int idx;
+      if (is >> idx) {
+        t.assign(&idx, &idx + 1);
+      }
+      return is;
+    }
+    is.get();
+    if (ch == '(') break;
+    if (!isspace(ch)) {
+      is.setstate(std::ios::failbit);
+      return is;
+    }
+  }
+  int idx;
+  std::vector<int> tmp;
+  while (is >> idx) {
+    tmp.push_back(idx);
+    char ch;
+    do {
+      ch = is.get();
+    } while (isspace(ch));
+    if (ch == 'L') {
+      ch = is.get();
+    }
+    if (ch == ',') {
+      while (true) {
+        ch = is.peek();
+        if (isspace(ch)) {
+          is.get(); continue;
+        }
+        if (ch == ')') {
+          is.get(); break;
+        }
+        break;
+      }
+      if (ch == ')') break;
+    } else if (ch == ')') {
+      break;
+    } else {
+      is.setstate(std::ios::failbit);
+      return is;
+    }
+  }
+  t.assign(tmp.begin(), tmp.end());
+  return is;
+}
+} // namespace std
+
 #endif  // XGBOOST_TREE_PARAM_H_
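The reader above is deliberately permissive: it accepts python-style tuples, including a trailing comma and a Python-2 'L' suffix, and also a bare non-negative integer. A small usage sketch, assuming this header is included so the operators declared in namespace std are visible:

    #include <iostream>
    #include <sstream>
    #include <vector>
    // assumes the stream operators above (src/tree/param.h) are in scope

    int main() {
      std::vector<int> c;
      std::istringstream tuple("(1,-1,0)");
      tuple >> c;                   // c = {1, -1, 0}
      std::cout << c << std::endl;  // prints "(1,-1,0)"

      std::istringstream single("(1,)");  // python-style 1-tuple
      single >> c;                        // c = {1}
      std::istringstream bare("2");       // bare non-negative integer
      bare >> c;                          // c = {2}
      return 0;
    }

This is what lets dmlc::Parameter parse a setting such as monotone_constraints="(1,-1)" into the std::vector<int> field declared earlier.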
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 79c013e29..0725651f0 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -19,7 +19,7 @@ namespace tree {
 DMLC_REGISTRY_FILE_TAG(updater_colmaker);

 /*! \brief column-wise update to construct a tree */
-template<typename TStats>
+template<typename TStats, typename TConstraint>
 class ColMaker: public TreeUpdater {
  public:
  void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
@@ -33,6 +33,7 @@ class ColMaker: public TreeUpdater {
    // rescale learning rate according to size of trees
    float lr = param.learning_rate;
    param.learning_rate = lr / trees.size();
+    TConstraint::Init(&param, dmat->info().num_col);
    // build tree
    for (size_t i = 0; i < trees.size(); ++i) {
      Builder builder(param);
@@ -199,6 +200,7 @@ class ColMaker: public TreeUpdater {
          stemp[i].resize(tree.param.num_nodes, ThreadEntry(param));
        }
        snode.resize(tree.param.num_nodes, NodeEntry(param));
+        constraints_.resize(tree.param.num_nodes);
      }
      const RowSet &rowset = fmat.buffered_rowset();
      const MetaInfo& info = fmat.info();
@@ -220,8 +222,25 @@
        }
        // update node statistics
        snode[nid].stats = stats;
-        snode[nid].root_gain = static_cast<float>(stats.CalcGain(param));
-        snode[nid].weight = static_cast<float>(stats.CalcWeight(param));
+      }
+      // setup constraints before calculating the weight
+      for (size_t j = 0; j < qexpand.size(); ++j) {
+        const int nid = qexpand[j];
+        if (tree[nid].is_root()) continue;
+        const int pid = tree[nid].parent();
+        constraints_[pid].SetChild(param, tree[pid].split_index(),
+                                   snode[tree[pid].cleft()].stats,
+                                   snode[tree[pid].cright()].stats,
+                                   &constraints_[tree[pid].cleft()],
+                                   &constraints_[tree[pid].cright()]);
+      }
+      // calculating the weights
+      for (size_t j = 0; j < qexpand.size(); ++j) {
+        const int nid = qexpand[j];
+        snode[nid].root_gain = static_cast<float>(
+            constraints_[nid].CalcGain(param, snode[nid].stats));
+        snode[nid].weight = static_cast<float>(
+            constraints_[nid].CalcWeight(param, snode[nid].stats));
      }
    }
    /*! \brief update queue expand add in new leaves */
@@ -244,6 +263,7 @@
                               bst_uint fid,
                               const DMatrix &fmat,
                               const std::vector<bst_gpair> &gpair) {
+      // TODO(tqchen): double check stats order.
      const MetaInfo& info = fmat.info();
      const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
      bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind);
@@ -303,8 +323,8 @@
            c.SetSubstract(snode[nid].stats, e.stats);
            if (c.sum_hess >= param.min_child_weight &&
                e.stats.sum_hess >= param.min_child_weight) {
-              bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                          c.CalcGain(param) - snode[nid].root_gain);
+              bst_float loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
              e.best.Update(loss_chg, fid, fsplit, false);
            }
          }
@@ -313,8 +333,8 @@
          c.SetSubstract(snode[nid].stats, tmp);
          if (c.sum_hess >= param.min_child_weight &&
              tmp.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(tmp.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, tmp, c) - snode[nid].root_gain);
            e.best.Update(loss_chg, fid, fsplit, true);
          }
        }
@@ -325,8 +345,8 @@
          c.SetSubstract(snode[nid].stats, tmp);
          if (c.sum_hess >= param.min_child_weight &&
              tmp.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(tmp.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg = static_cast<bst_float>(
+                constraints_[nid].CalcSplitGain(param, fid, tmp, c) - snode[nid].root_gain);
            e.best.Update(loss_chg, fid, e.last_fvalue + rt_eps, true);
          }
        }
@@ -357,9 +377,9 @@
              c.SetSubstract(snode[nid].stats, e.stats);
              if (c.sum_hess >= param.min_child_weight &&
                  e.stats.sum_hess >= param.min_child_weight) {
-                bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                            c.CalcGain(param) -
-                                                            snode[nid].root_gain);
+                bst_float loss_chg = static_cast<bst_float>(
+                    constraints_[nid].CalcSplitGain(param, fid, e.stats, c) -
+                    snode[nid].root_gain);
                e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, false);
              }
            }
@@ -368,9 +388,9 @@
              c.SetSubstract(snode[nid].stats, cright);
              if (c.sum_hess >= param.min_child_weight &&
                  cright.sum_hess >= param.min_child_weight) {
-                bst_float loss_chg = static_cast<bst_float>(cright.CalcGain(param) +
-                                                            c.CalcGain(param) -
-                                                            snode[nid].root_gain);
+                bst_float loss_chg = static_cast<bst_float>(
+                    constraints_[nid].CalcSplitGain(param, fid, c, cright) -
+                    snode[nid].root_gain);
                e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
              }
            }
@@ -397,8 +417,14 @@
            e.stats.sum_hess >= param.min_child_weight) {
          c.SetSubstract(snode[nid].stats, e.stats);
          if (c.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg;
+            if (d_step == -1) {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, c, e.stats) - snode[nid].root_gain);
+            } else {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
+            }
            e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
          }
        }
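The duplicated d_step branches are the price of the new asymmetry: the old expression e.stats.CalcGain(param) + c.CalcGain(param) was symmetric in the two children, but CalcSplitGain compares the left and right child weights, so argument order now matters. In a backward scan (d_step == -1) the accumulated e.stats describes the physically right-hand child, hence the swapped call CalcSplitGain(param, fid, c, e.stats). A toy version of the order-sensitive check (hypothetical free function mirroring ValueConstraint::CalcSplitGain):

    // With sign == 0 the argument order is irrelevant; otherwise swapping
    // left and right turns an accepted split into a rejected one.
    double SplitGain(int sign, double wleft, double wright, double gain) {
      if (sign == 0) return gain;                        // unconstrained
      if (sign > 0) return wleft < wright ? gain : 0.0;  // increasing
      return wleft > wright ? gain : 0.0;                // decreasing
    }

For example, SplitGain(+1, 0.2, 0.7, g) keeps the split while SplitGain(+1, 0.7, 0.2, g) zeroes it, which is why the two scan directions cannot share one argument order.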
@@ -467,9 +493,16 @@
          const int nid = qexpand[i];
          ThreadEntry &e = temp[nid];
          c.SetSubstract(snode[nid].stats, e.stats);
-          if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+          if (e.stats.sum_hess >= param.min_child_weight &&
+              c.sum_hess >= param.min_child_weight) {
+            bst_float loss_chg;
+            if (d_step == -1) {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, c, e.stats) - snode[nid].root_gain);
+            } else {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
+            }
            const float gap = std::abs(e.last_fvalue) + rt_eps;
            const float delta = d_step == +1 ? gap: -gap;
            e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
@@ -515,8 +548,16 @@
            e.stats.sum_hess >= param.min_child_weight) {
          c.SetSubstract(snode[nid].stats, e.stats);
          if (c.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg;
+            if (d_step == -1) {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, c, e.stats) -
+                  snode[nid].root_gain);
+            } else {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, e.stats, c) -
+                  snode[nid].root_gain);
+            }
            e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
          }
        }
@@ -531,8 +572,14 @@
          ThreadEntry &e = temp[nid];
          c.SetSubstract(snode[nid].stats, e.stats);
          if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
-            bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) +
-                                                        c.CalcGain(param) - snode[nid].root_gain);
+            bst_float loss_chg;
+            if (d_step == -1) {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, c, e.stats) - snode[nid].root_gain);
+            } else {
+              loss_chg = static_cast<bst_float>(
+                  constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
+            }
            const float gap = std::abs(e.last_fvalue) + rt_eps;
            const float delta = d_step == +1 ? gap: -gap;
            e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
@@ -724,12 +771,14 @@
    std::vector<NodeEntry> snode;
    /*! \brief queue of nodes to be expanded */
    std::vector<int> qexpand_;
+    // constraint value
+    std::vector<TConstraint> constraints_;
  };
};

// distributed column maker
-template<typename TStats>
-class DistColMaker : public ColMaker<TStats> {
+template<typename TStats, typename TConstraint>
+class DistColMaker : public ColMaker<TStats, TConstraint> {
 public:
  DistColMaker() : builder(param) {
    pruner.reset(TreeUpdater::Create("prune"));
@@ -755,10 +804,10 @@
  }

 private:
-  struct Builder : public ColMaker<TStats>::Builder {
+  struct Builder : public ColMaker<TStats, TConstraint>::Builder {
   public:
    explicit Builder(const TrainParam &param)
-        : ColMaker<TStats>::Builder(param) {
+        : ColMaker<TStats, TConstraint>::Builder(param) {
    }
    inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
      const RowSet &rowset = p_fmat->buffered_rowset();
@@ -881,16 +930,56 @@
  Builder builder;
};

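ColMaker's second template parameter is a compile-time policy: NoConstraint's hooks are empty or forward straight to GradStats, so the unconstrained instantiation should optimize down to the pre-patch behavior, while ValueConstraint carries real per-node state. A minimal sketch of the pattern (hypothetical names, not from the patch):

    #include <cstdio>

    struct NoClamp {            // empty policy: an inlinable no-op
      static double Apply(double w) { return w; }
    };
    struct ClampNonNegative {   // constraining policy: the same hook does work
      static double Apply(double w) { return w < 0.0 ? 0.0 : w; }
    };

    template <typename Policy>
    double LeafWeight(double grad, double hess, double lambda) {
      return Policy::Apply(-grad / (hess + lambda));
    }

    int main() {
      std::printf("%g %g\n",
                  LeafWeight<NoClamp>(2.0, 3.0, 1.0),            // -0.5
                  LeafWeight<ClampNonNegative>(2.0, 3.0, 1.0));  // 0
      return 0;
    }

The TreeUpdaterSwitch below chooses between the two instantiations at run time, based only on whether a monotone_constraints argument was supplied.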
+// simple switch to defer implementation.
+class TreeUpdaterSwitch : public TreeUpdater {
+ public:
+  TreeUpdaterSwitch() : monotone_(false) {}
+  void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
+    for (auto &kv : args) {
+      if (kv.first == "monotone_constraints" && kv.second.length() != 0) {
+        monotone_ = true;
+      }
+    }
+    if (inner_.get() == nullptr) {
+      if (monotone_) {
+        inner_.reset(new ColMaker<GradStats, ValueConstraint>());
+      } else {
+        inner_.reset(new ColMaker<GradStats, NoConstraint>());
+      }
+    }
+
+    inner_->Init(args);
+  }
+
+  void Update(const std::vector<bst_gpair>& gpair,
+              DMatrix* data,
+              const std::vector<RegTree*>& trees) override {
+    CHECK(inner_ != nullptr);
+    inner_->Update(gpair, data, trees);
+  }
+
+  const int* GetLeafPosition() const override {
+    CHECK(inner_ != nullptr);
+    return inner_->GetLeafPosition();
+  }
+
+ private:
+  // whether monotone constraints were specified
+  bool monotone_;
+  // internal implementation
+  std::unique_ptr<TreeUpdater> inner_;
+};
+
 XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
 .describe("Grow tree with parallelization over columns.")
 .set_body([]() {
-    return new ColMaker<GradStats>();
+    return new TreeUpdaterSwitch();
   });

 XGBOOST_REGISTER_TREE_UPDATER(DistColMaker, "distcol")
 .describe("Distributed column split version of tree maker.")
 .set_body([]() {
-    return new DistColMaker<GradStats>();
+    return new DistColMaker<GradStats, NoConstraint>();
   });
 } // namespace tree
 } // namespace xgboost
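End to end, the feature is driven entirely by parameters. A usage sketch against the updater interface of this era (the Init signature is taken from the patch itself; the header path and namespace are assumptions):

    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>
    #include <xgboost/tree_updater.h>

    int main() {
      // "grow_colmaker" now resolves to TreeUpdaterSwitch.
      std::unique_ptr<xgboost::TreeUpdater> updater(
          xgboost::TreeUpdater::Create("grow_colmaker"));
      std::vector<std::pair<std::string, std::string> > args;
      // A non-empty value selects ColMaker<GradStats, ValueConstraint>:
      // +1 = non-decreasing in feature 0, -1 = non-increasing in feature 1.
      args.push_back(std::make_pair("monotone_constraints", "(1,-1)"));
      updater->Init(args);
      // updater->Update(gpair, dmat, trees) would now grow constrained trees.
      return 0;
    }

From the normal training entry points this is just the string parameter monotone_constraints="(1,-1)", parsed by the std::vector<int> stream reader added in param.h; note that "distcol" is still registered with NoConstraint, so the experimental constraint only applies to the single-machine column maker.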