'hist': Monotonic Constraints (#3085)

* Extended monotonic constraints support to 'hist' tree method.

* Added monotonic constraints tests.

* Fix the signature of NoConstraint::CalcSplitGain()

* Document monotonic constraint support in 'hist'

* Update signature of Update to account for latest refactor
This commit is contained in:
redditur
2018-03-06 00:45:49 +00:00
committed by Philip Hyunsu Cho
parent 8937134015
commit d5f1b74ef5
4 changed files with 139 additions and 6 deletions

View File

@@ -376,7 +376,7 @@ struct NoConstraint {
/*!
 * \brief Size the per-feature monotone constraint list to match the feature count.
 * \param param training parameters whose monotone_constraints member is resized
 * \param num_feature number of entries the list should hold after the call
 *
 * Entries added by the resize are filled with 0.
 * NOTE(review): monotone_constraints is presumably a std::vector<int>; confirm
 * against the TrainParam declaration.
 */
inline static void Init(TrainParam *param, unsigned num_feature) {
  auto &constraints = param->monotone_constraints;
  constraints.resize(num_feature, 0);
}
inline double CalcSplitGain(const TrainParam &param, bst_uint split_index,
inline double CalcSplitGain(const TrainParam &param, int constraint,
GradStats left, GradStats right) const {
return left.CalcGain(param) + right.CalcGain(param);
}
@@ -421,6 +421,7 @@ template <typename param_t>
template <typename param_t>
XGBOOST_DEVICE inline double CalcSplitGain(const param_t &param, int constraint,
GradStats left, GradStats right) const {
const double negative_infinity = -std::numeric_limits<double>::infinity();
double wleft = CalcWeight(param, left);
double wright = CalcWeight(param, right);
double gain =
@@ -429,9 +430,9 @@ template <typename param_t>
if (constraint == 0) {
return gain;
} else if (constraint > 0) {
return wleft < wright ? gain : 0.0;
return wleft <= wright ? gain : negative_infinity;
} else {
return wleft > wright ? gain : 0.0;
return wleft >= wright ? gain : negative_infinity;
}
}

View File

@@ -870,13 +870,13 @@ class FastHistMaker: public TreeUpdater {
if (d_step > 0) {
// forward enumeration: split at right bound of each bin
loss_chg = static_cast<bst_float>(
constraint.CalcSplitGain(param, fid, e, c) -
constraint.CalcSplitGain(param, param.monotone_constraints[fid], e, c) -
snode.root_gain);
split_pt = cut_val[i];
} else {
// backward enumeration: split at left bound of each bin
loss_chg = static_cast<bst_float>(
constraint.CalcSplitGain(param, fid, c, e) -
constraint.CalcSplitGain(param, param.monotone_constraints[fid], c, e) -
snode.root_gain);
if (i == imin) {
// for leftmost bin, left bound is the smallest feature value
@@ -961,10 +961,45 @@ class FastHistMaker: public TreeUpdater {
std::unique_ptr<TreeUpdater> pruner_;
};
/*!
 * \brief Thin forwarding updater that defers the choice of the concrete
 *        FastHistMaker instantiation until Init() has seen the arguments.
 *
 * If any argument named "monotone_constraints" carries a non-empty value, the
 * wrapped updater is FastHistMaker<GradStats, ValueConstraint>; otherwise it is
 * FastHistMaker<GradStats, NoConstraint>. The wrapped updater is created only
 * on the first Init() call and reused afterwards.
 */
class FastHistTreeUpdaterSwitch : public TreeUpdater {
 public:
  FastHistTreeUpdaterSwitch() : monotone_(false) {}

  /*!
   * \brief Scan the arguments, create the wrapped updater if needed, and
   *        forward the arguments to it.
   * \param args key/value configuration pairs
   */
  void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
    // The flag is only ever raised here, never lowered, so it stays set
    // across repeated Init() calls.
    for (const auto& entry : args) {
      const bool names_constraints = (entry.first == "monotone_constraints");
      if (names_constraints && !entry.second.empty()) {
        monotone_ = true;
      }
    }

    // Build the implementation once; later calls only re-forward the args.
    if (!inner_) {
      if (!monotone_) {
        inner_.reset(new FastHistMaker<GradStats, NoConstraint>());
      } else {
        inner_.reset(new FastHistMaker<GradStats, ValueConstraint>());
      }
    }

    inner_->Init(args);
  }

  /*!
   * \brief Forward the update request to the wrapped updater.
   *
   * Init() must have been called first; this is enforced with a CHECK.
   */
  void Update(HostDeviceVector<bst_gpair>* gpair,
              DMatrix* data,
              const std::vector<RegTree*>& trees) override {
    CHECK(inner_ != nullptr);
    inner_->Update(gpair, data, trees);
  }

 private:
  /*! \brief true once a non-empty "monotone_constraints" argument was seen */
  bool monotone_;
  /*! \brief concrete updater that performs the actual work */
  std::unique_ptr<TreeUpdater> inner_;
};
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([]() {
return new FastHistMaker<GradStats, NoConstraint>();
return new FastHistTreeUpdaterSwitch();
});
} // namespace tree