'hist': Monotonic Constraints (#3085)
* Extended monotonic constraints support to 'hist' tree method.
* Added monotonic constraints tests.
* Fix the signature of NoConstraint::CalcSplitGain()
* Document monotonic constraint support in 'hist'
* Update signature of Update to account for latest refactor
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
8937134015
commit
d5f1b74ef5
@@ -376,7 +376,7 @@ struct NoConstraint {
|
||||
inline static void Init(TrainParam *param, unsigned num_feature) {
|
||||
param->monotone_constraints.resize(num_feature, 0);
|
||||
}
|
||||
inline double CalcSplitGain(const TrainParam &param, bst_uint split_index,
|
||||
inline double CalcSplitGain(const TrainParam &param, int constraint,
|
||||
GradStats left, GradStats right) const {
|
||||
return left.CalcGain(param) + right.CalcGain(param);
|
||||
}
|
||||
@@ -421,6 +421,7 @@ template <typename param_t>
|
||||
template <typename param_t>
|
||||
XGBOOST_DEVICE inline double CalcSplitGain(const param_t &param, int constraint,
|
||||
GradStats left, GradStats right) const {
|
||||
const double negative_infinity = -std::numeric_limits<double>::infinity();
|
||||
double wleft = CalcWeight(param, left);
|
||||
double wright = CalcWeight(param, right);
|
||||
double gain =
|
||||
@@ -429,9 +430,9 @@ template <typename param_t>
|
||||
if (constraint == 0) {
|
||||
return gain;
|
||||
} else if (constraint > 0) {
|
||||
return wleft < wright ? gain : 0.0;
|
||||
return wleft <= wright ? gain : negative_infinity;
|
||||
} else {
|
||||
return wleft > wright ? gain : 0.0;
|
||||
return wleft >= wright ? gain : negative_infinity;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -870,13 +870,13 @@ class FastHistMaker: public TreeUpdater {
|
||||
if (d_step > 0) {
|
||||
// forward enumeration: split at right bound of each bin
|
||||
loss_chg = static_cast<bst_float>(
|
||||
constraint.CalcSplitGain(param, fid, e, c) -
|
||||
constraint.CalcSplitGain(param, param.monotone_constraints[fid], e, c) -
|
||||
snode.root_gain);
|
||||
split_pt = cut_val[i];
|
||||
} else {
|
||||
// backward enumeration: split at left bound of each bin
|
||||
loss_chg = static_cast<bst_float>(
|
||||
constraint.CalcSplitGain(param, fid, c, e) -
|
||||
constraint.CalcSplitGain(param, param.monotone_constraints[fid], c, e) -
|
||||
snode.root_gain);
|
||||
if (i == imin) {
|
||||
// for leftmost bin, left bound is the smallest feature value
|
||||
@@ -961,10 +961,45 @@ class FastHistMaker: public TreeUpdater {
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
};
|
||||
|
||||
// simple switch to defer implementation.
|
||||
class FastHistTreeUpdaterSwitch : public TreeUpdater {
|
||||
public:
|
||||
FastHistTreeUpdaterSwitch() : monotone_(false) {}
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
for (auto &kv : args) {
|
||||
if (kv.first == "monotone_constraints" && kv.second.length() != 0) {
|
||||
monotone_ = true;
|
||||
}
|
||||
}
|
||||
if (inner_.get() == nullptr) {
|
||||
if (monotone_) {
|
||||
inner_.reset(new FastHistMaker<GradStats, ValueConstraint>());
|
||||
} else {
|
||||
inner_.reset(new FastHistMaker<GradStats, NoConstraint>());
|
||||
}
|
||||
}
|
||||
|
||||
inner_->Init(args);
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<bst_gpair>* gpair,
|
||||
DMatrix* data,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
CHECK(inner_ != nullptr);
|
||||
inner_->Update(gpair, data, trees);
|
||||
}
|
||||
|
||||
private:
|
||||
// monotone constraints
|
||||
bool monotone_;
|
||||
// internal implementation
|
||||
std::unique_ptr<TreeUpdater> inner_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
|
||||
.describe("Grow tree using quantized histogram.")
|
||||
.set_body([]() {
|
||||
return new FastHistMaker<GradStats, NoConstraint>();
|
||||
return new FastHistTreeUpdaterSwitch();
|
||||
});
|
||||
|
||||
} // namespace tree
|
||||
|
||||
Reference in New Issue
Block a user