* Allow multiple split constraints

* Replace RidgePenalty with ElasticNet

* Add tests checking that Ridge, LASSO, and Elastic Net are implemented
Henry Gouk, 2018-07-15 16:38:26 +12:00 (committed by Rory Mitchell)
parent 2f8764955c
commit a13e29ece1
6 changed files with 188 additions and 85 deletions
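Taken together, these changes mean the split_evaluator training parameter now accepts a comma-separated chain of evaluator names rather than a single name; SplitEvaluator::Create resolves the list left to right, so elastic_net sits innermost (supplying weights and scores) and monotonic wraps it (vetoing splits that violate the constraints). A minimal usage sketch, assuming a Python build of XGBoost that includes this commit; the data below is synthetic and purely illustrative:

import numpy as np
import xgboost as xgb

# Synthetic regression data, for illustration only.
X = np.random.randn(100, 4)
y = np.random.randn(100)
dtrain = xgb.DMatrix(X, label=y)

params = {
    # Comma-separated evaluator chain, resolved left-to-right by
    # SplitEvaluator::Create; "elastic_net,monotonic" is the new default.
    'split_evaluator': 'elastic_net,monotonic',
    'alpha': 0.1,    # L1 term (ElasticNetParams::reg_alpha)
    'lambda': 1.0,   # L2 term (ElasticNetParams::reg_lambda)
    # Constrain the first feature to be monotonically increasing.
    'monotone_constraints': '(1,0,0,0)',
}
bst = xgb.train(params, dtrain, num_boost_round=10)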

View File

@@ -187,7 +187,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .set_default(1)
         .describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
     DMLC_DECLARE_FIELD(split_evaluator)
-        .set_default("monotonic")
+        .set_default("elastic_net,monotonic")
         .describe("The criteria to use for ranking splits");
     // add alias of parameters
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);

View File

@@ -8,13 +8,12 @@
 #include <algorithm>
 #include <limits>
 #include <string>
+#include <sstream>
 #include <utility>
 #include "param.h"
 #include "../common/common.h"
 #include "../common/host_device_vector.h"
 
-#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
-
 namespace dmlc {
 DMLC_REGISTRY_ENABLE(::xgboost::tree::SplitEvaluatorReg);
 }  // namespace dmlc
@@ -23,12 +22,19 @@ namespace xgboost {
 namespace tree {
 
 SplitEvaluator* SplitEvaluator::Create(const std::string& name) {
-  auto* e = ::dmlc::Registry< ::xgboost::tree::SplitEvaluatorReg>
-    ::Get()->Find(name);
-  if (e == nullptr) {
-    LOG(FATAL) << "Unknown SplitEvaluator " << name;
+  std::stringstream ss(name);
+  std::string item;
+  SplitEvaluator* eval = nullptr;
+  // Construct a chain of SplitEvaluators. This allows one to specify multiple constraints.
+  while (std::getline(ss, item, ',')) {
+    auto* e = ::dmlc::Registry< ::xgboost::tree::SplitEvaluatorReg>
+      ::Get()->Find(item);
+    if (e == nullptr) {
+      LOG(FATAL) << "Unknown SplitEvaluator " << name;
+    }
+    eval = (e->body)(std::unique_ptr<SplitEvaluator>(eval));
   }
-  return (e->body)();
+  return eval;
 }
 
 // Default implementations of some virtual methods that aren't always needed
@@ -41,38 +47,57 @@ void SplitEvaluator::AddSplit(bst_uint nodeid,
                               bst_uint featureid,
                               bst_float leftweight,
                               bst_float rightweight) {}
 
+bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
+                                            bst_uint featureid,
+                                            const GradStats& left_stats,
+                                            const GradStats& right_stats) const {
+  bst_float left_weight = ComputeWeight(nodeid, left_stats);
+  bst_float right_weight = ComputeWeight(nodeid, right_stats);
+  return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
+}
+
-//! \brief Encapsulates the parameters for by the RidgePenalty
-struct RidgePenaltyParams : public dmlc::Parameter<RidgePenaltyParams> {
-  float reg_lambda;
-  float reg_gamma;
+//! \brief Encapsulates the parameters for ElasticNet
+struct ElasticNetParams : public dmlc::Parameter<ElasticNetParams> {
+  bst_float reg_lambda;
+  bst_float reg_alpha;
+  bst_float reg_gamma;
 
-  DMLC_DECLARE_PARAMETER(RidgePenaltyParams) {
+  DMLC_DECLARE_PARAMETER(ElasticNetParams) {
     DMLC_DECLARE_FIELD(reg_lambda)
       .set_lower_bound(0.0)
       .set_default(1.0)
       .describe("L2 regularization on leaf weight");
+    DMLC_DECLARE_FIELD(reg_alpha)
+      .set_lower_bound(0.0)
+      .set_default(0.0)
+      .describe("L1 regularization on leaf weight");
     DMLC_DECLARE_FIELD(reg_gamma)
-      .set_lower_bound(0.0f)
-      .set_default(0.0f)
+      .set_lower_bound(0.0)
+      .set_default(0.0)
       .describe("Cost incurred by adding a new leaf node to the tree");
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
+    DMLC_DECLARE_ALIAS(reg_alpha, alpha);
     DMLC_DECLARE_ALIAS(reg_gamma, gamma);
   }
 };
 
-DMLC_REGISTER_PARAMETER(RidgePenaltyParams);
+DMLC_REGISTER_PARAMETER(ElasticNetParams);
 
-/*! \brief Applies an L2 penalty and per-leaf penalty. */
-class RidgePenalty final : public SplitEvaluator {
+/*! \brief Applies an elastic net penalty and per-leaf penalty. */
+class ElasticNet final : public SplitEvaluator {
  public:
+  explicit ElasticNet(std::unique_ptr<SplitEvaluator> inner) {
+    if (inner) {
+      LOG(FATAL) << "ElasticNet does not accept an inner SplitEvaluator";
+    }
+  }
   void Init(
       const std::vector<std::pair<std::string, std::string> >& args) override {
     params_.InitAllowUnknown(args);
   }
 
   SplitEvaluator* GetHostClone() const override {
-    auto r = new RidgePenalty();
+    auto r = new ElasticNet(nullptr);
     r->params_ = this->params_;
     return r;
@@ -80,31 +105,55 @@ class RidgePenalty final : public SplitEvaluator {
   bst_float ComputeSplitScore(bst_uint nodeid,
                               bst_uint featureid,
-                              const GradStats& left,
-                              const GradStats& right) const override {
-    // parentID is not needed for this split evaluator. Just use 0.
-    return ComputeScore(0, left) + ComputeScore(0, right);
+                              const GradStats& left_stats,
+                              const GradStats& right_stats,
+                              bst_float left_weight,
+                              bst_float right_weight) const override {
+    return ComputeScore(nodeid, left_stats, left_weight) +
+      ComputeScore(nodeid, right_stats, right_weight);
   }
 
-  bst_float ComputeScore(bst_uint parentID, const GradStats& stats)
+  bst_float ComputeSplitScore(bst_uint nodeid,
+                              bst_uint featureid,
+                              const GradStats& left_stats,
+                              const GradStats& right_stats) const override {
+    return ComputeScore(nodeid, left_stats) + ComputeScore(nodeid, right_stats);
+  }
+
+  bst_float ComputeScore(bst_uint parentID, const GradStats &stats, bst_float weight)
       const override {
-    return (stats.sum_grad * stats.sum_grad)
-        / (stats.sum_hess + params_.reg_lambda) - params_.reg_gamma;
+    auto loss = weight * (2.0 * stats.sum_grad + stats.sum_hess * weight
+      + params_.reg_lambda * weight)
+      + params_.reg_alpha * std::abs(weight);
+    return -loss;
+  }
+
+  bst_float ComputeScore(bst_uint parentID, const GradStats &stats) const {
+    return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_.reg_lambda);
   }
 
   bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
       const override {
-    return -stats.sum_grad / (stats.sum_hess + params_.reg_lambda);
+    return -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_.reg_lambda);
   }
 
  private:
-  RidgePenaltyParams params_;
+  ElasticNetParams params_;
+
+  inline double ThresholdL1(double g) const {
+    if (g > params_.reg_alpha) {
+      g = g - params_.reg_alpha;
+    } else if (g < -params_.reg_alpha) {
+      g = g + params_.reg_alpha;
+    }
+    return g;
+  }
 };
 
-XGBOOST_REGISTER_SPLIT_EVALUATOR(RidgePenalty, "ridge")
-.describe("Use an L2 penalty term for the weights and a cost per leaf node")
-.set_body([]() {
-    return new RidgePenalty();
+XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
+.describe("Use an elastic net regulariser and a cost per leaf node")
+.set_body([](std::unique_ptr<SplitEvaluator> inner) {
+    return new ElasticNet(std::move(inner));
   });
 
 /*! \brief Encapsulates the parameters required by the MonotonicConstraint
@@ -113,23 +162,11 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(RidgePenalty, "ridge")
 struct MonotonicConstraintParams
     : public dmlc::Parameter<MonotonicConstraintParams> {
   std::vector<bst_int> monotone_constraints;
-  float reg_lambda;
-  float reg_gamma;
 
   DMLC_DECLARE_PARAMETER(MonotonicConstraintParams) {
-    DMLC_DECLARE_FIELD(reg_lambda)
-      .set_lower_bound(0.0)
-      .set_default(1.0)
-      .describe("L2 regularization on leaf weight");
-    DMLC_DECLARE_FIELD(reg_gamma)
-      .set_lower_bound(0.0f)
-      .set_default(0.0f)
-      .describe("Cost incurred by adding a new leaf node to the tree");
     DMLC_DECLARE_FIELD(monotone_constraints)
       .set_default(std::vector<bst_int>())
      .describe("Constraint of variable monotonicity");
-    DMLC_DECLARE_ALIAS(reg_lambda, lambda);
-    DMLC_DECLARE_ALIAS(reg_gamma, gamma);
   }
 };
@@ -140,8 +177,16 @@ DMLC_REGISTER_PARAMETER(MonotonicConstraintParams);
  */
 class MonotonicConstraint final : public SplitEvaluator {
  public:
+  explicit MonotonicConstraint(std::unique_ptr<SplitEvaluator> inner) {
+    if (!inner) {
+      LOG(FATAL) << "MonotonicConstraint must be given an inner evaluator";
+    }
+    inner_ = std::move(inner);
+  }
+
   void Init(const std::vector<std::pair<std::string, std::string> >& args)
       override {
+    inner_->Init(args);
     params_.InitAllowUnknown(args);
     Reset();
   }
@@ -153,22 +198,11 @@ class MonotonicConstraint final : public SplitEvaluator {
 
   SplitEvaluator* GetHostClone() const override {
     if (params_.monotone_constraints.size() == 0) {
-      // No monotone constraints specified, make a RidgePenalty evaluator
-      using std::pair;
-      using std::string;
-      using std::to_string;
-      using std::vector;
-      auto c = new RidgePenalty();
-      vector<pair<string, string> > args;
-      args.emplace_back(
-        pair<string, string>("reg_lambda", to_string(params_.reg_lambda)));
-      args.emplace_back(
-        pair<string, string>("reg_gamma", to_string(params_.reg_gamma)));
-      c->Init(args);
-      c->Reset();
-      return c;
+      // No monotone constraints specified, just return a clone of inner to speed things up
+      return inner_->GetHostClone();
     } else {
-      auto c = new MonotonicConstraint();
+      auto c = new MonotonicConstraint(
+        std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
       c->params_ = this->params_;
       c->Reset();
       return c;
@@ -177,35 +211,32 @@ class MonotonicConstraint final : public SplitEvaluator {
   bst_float ComputeSplitScore(bst_uint nodeid,
                               bst_uint featureid,
-                              const GradStats& left,
-                              const GradStats& right) const override {
+                              const GradStats& left_stats,
+                              const GradStats& right_stats,
+                              bst_float left_weight,
+                              bst_float right_weight) const override {
     bst_float infinity = std::numeric_limits<bst_float>::infinity();
     bst_int constraint = GetConstraint(featureid);
-
-    bst_float score = ComputeScore(nodeid, left) + ComputeScore(nodeid, right);
-    bst_float leftweight = ComputeWeight(nodeid, left);
-    bst_float rightweight = ComputeWeight(nodeid, right);
+    bst_float score = inner_->ComputeSplitScore(
+      nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
 
     if (constraint == 0) {
       return score;
     } else if (constraint > 0) {
-      return leftweight <= rightweight ? score : -infinity;
+      return left_weight <= right_weight ? score : -infinity;
     } else {
-      return leftweight >= rightweight ? score : -infinity;
+      return left_weight >= right_weight ? score : -infinity;
     }
   }
 
-  bst_float ComputeScore(bst_uint parentID, const GradStats& stats)
+  bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight)
       const override {
-    bst_float w = ComputeWeight(parentID, stats);
-    return -(2.0 * stats.sum_grad * w + (stats.sum_hess + params_.reg_lambda)
-      * w * w);
+    return inner_->ComputeScore(parentID, stats, weight);
   }
 
   bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
       const override {
-    bst_float weight = -stats.sum_grad / (stats.sum_hess + params_.reg_lambda);
+    bst_float weight = inner_->ComputeWeight(parentID, stats);
 
     if (parentID == ROOT_PARENT_ID) {
       // This is the root node
@@ -225,6 +256,7 @@ class MonotonicConstraint final : public SplitEvaluator {
                 bst_uint featureid,
                 bst_float leftweight,
                 bst_float rightweight) override {
+    inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
     bst_uint newsize = std::max(leftid, rightid) + 1;
     lower_.resize(newsize);
     upper_.resize(newsize);
@@ -250,6 +282,7 @@ class MonotonicConstraint final : public SplitEvaluator {
  private:
   MonotonicConstraintParams params_;
+  std::unique_ptr<SplitEvaluator> inner_;
   std::vector<bst_float> lower_;
   std::vector<bst_float> upper_;
@@ -265,8 +298,8 @@ class MonotonicConstraint final : public SplitEvaluator {
 XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
 .describe("Enforces that the tree is monotonically increasing/decreasing "
           "w.r.t. specified features")
-.set_body([]() {
-    return new MonotonicConstraint();
+.set_body([](std::unique_ptr<SplitEvaluator> inner) {
+    return new MonotonicConstraint(std::move(inner));
   });
 
 }  // namespace tree
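For reference, the objective the new ElasticNet evaluator minimises can be written out explicitly (a derivation sketch added for this write-up, not part of the commit; G and H denote the summed gradient and hessian statistics of a leaf):

\[
  \mathcal{L}(w) = G w + \tfrac{1}{2}(H + \lambda)\, w^2 + \alpha \lvert w \rvert
\]

Minimising over the leaf weight w gives the soft-thresholded solution implemented by ComputeWeight,

\[
  w^* = -\frac{T_\alpha(G)}{H + \lambda},
  \qquad
  T_\alpha(g) = \operatorname{sign}(g)\, \max(\lvert g \rvert - \alpha,\ 0),
\]

and the minimised loss is \( \mathcal{L}(w^*) = -\, T_\alpha(G)^2 / \bigl(2 (H + \lambda)\bigr) \), so the stats-only ComputeScore above returns \( T_\alpha(G)^2 / (H + \lambda) = -2\, \mathcal{L}(w^*) \), i.e. twice the negative loss at the optimal weight. With \( \alpha = 0 \) this reduces to the familiar ridge gain \( G^2 / (H + \lambda) \) that RidgePenalty used to compute.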

View File

@@ -15,6 +15,8 @@
 #include <utility>
 #include <vector>
 
+#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
+
 namespace xgboost {
 namespace tree {
@@ -40,13 +42,21 @@ class SplitEvaluator {
   // Computes the score (negative loss) resulting from performing this split
   virtual bst_float ComputeSplitScore(bst_uint nodeid,
                                       bst_uint featureid,
-                                      const GradStats& left,
-                                      const GradStats& right) const = 0;
+                                      const GradStats& left_stats,
+                                      const GradStats& right_stats,
+                                      bst_float left_weight,
+                                      bst_float right_weight) const = 0;
+
+  virtual bst_float ComputeSplitScore(bst_uint nodeid,
+                                      bst_uint featureid,
+                                      const GradStats& left_stats,
+                                      const GradStats& right_stats) const;
 
   // Compute the Score for a node with the given stats
-  virtual bst_float ComputeScore(bst_uint parentid, const GradStats& stats)
-      const = 0;
+  virtual bst_float ComputeScore(bst_uint parentid,
+                                 const GradStats &stats,
+                                 bst_float weight) const = 0;
 
   // Compute the weight for a node with the given stats
   virtual bst_float ComputeWeight(bst_uint parentid, const GradStats& stats)
@@ -62,7 +72,7 @@ class SplitEvaluator {
 struct SplitEvaluatorReg
     : public dmlc::FunctionRegEntryBase<SplitEvaluatorReg,
-        std::function<SplitEvaluator* ()> > {};
+        std::function<SplitEvaluator* (std::unique_ptr<SplitEvaluator>)> > {};
 
 /*!
  * \brief Macro to register tree split evaluator.

View File

@@ -243,10 +243,10 @@ class ColMaker: public TreeUpdater {
       // calculating the weights
       for (int nid : qexpand) {
         bst_uint parentid = tree[nid].Parent();
-        snode_[nid].root_gain = static_cast<float>(
-            spliteval_->ComputeScore(parentid, snode_[nid].stats));
         snode_[nid].weight = static_cast<float>(
             spliteval_->ComputeWeight(parentid, snode_[nid].stats));
+        snode_[nid].root_gain = static_cast<float>(
+            spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
       }
     }
     /*! \brief update queue expand add in new leaves */

View File

@@ -752,10 +752,10 @@ class FastHistMaker: public TreeUpdater {
     // calculating the weights
     {
       bst_uint parentid = tree[nid].Parent();
-      snode_[nid].root_gain = static_cast<float>(
-          spliteval_->ComputeScore(parentid, snode_[nid].stats));
       snode_[nid].weight = static_cast<float>(
           spliteval_->ComputeWeight(parentid, snode_[nid].stats));
+      snode_[nid].root_gain = static_cast<float>(
+          spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
     }
   }

View File

@@ -0,0 +1,60 @@
+import numpy as np
+import unittest
+import xgboost as xgb
+
+from numpy.testing import assert_approx_equal
+
+train_data = xgb.DMatrix(np.array([[1]]), label=np.array([1]))
+
+
+class TestTreeRegularization(unittest.TestCase):
+    def test_alpha(self):
+        params = {
+            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'eta': 1,
+            'lambda': 0,
+            'alpha': 0.1
+        }
+
+        model = xgb.train(params, train_data, 1)
+        preds = model.predict(train_data)
+
+        # Default prediction (with no trees) is 0.5
+        # sum_grad = (0.5 - 1.0)
+        # sum_hess = 1.0
+        # 0.9 = 0.5 - (sum_grad - alpha * sgn(sum_grad)) / sum_hess
+        assert_approx_equal(preds[0], 0.9)
+
+    def test_lambda(self):
+        params = {
+            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'eta': 1,
+            'lambda': 1,
+            'alpha': 0
+        }
+
+        model = xgb.train(params, train_data, 1)
+        preds = model.predict(train_data)
+
+        # Default prediction (with no trees) is 0.5
+        # sum_grad = (0.5 - 1.0)
+        # sum_hess = 1.0
+        # 0.75 = 0.5 - sum_grad / (sum_hess + lambda)
+        assert_approx_equal(preds[0], 0.75)
+
+    def test_alpha_and_lambda(self):
+        params = {
+            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'eta': 1,
+            'lambda': 1,
+            'alpha': 0.1
+        }
+
+        model = xgb.train(params, train_data, 1)
+        preds = model.predict(train_data)
+
+        # Default prediction (with no trees) is 0.5
+        # sum_grad = (0.5 - 1.0)
+        # sum_hess = 1.0
+        # 0.7 = 0.5 - (sum_grad - alpha * sgn(sum_grad)) / (sum_hess + lambda)
+        assert_approx_equal(preds[0], 0.7)
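The expected values in these assertions can be sanity-checked outside XGBoost. A standalone sketch reproducing the soft-thresholding arithmetic from the test comments above (base_score 0.5, squared-error gradients, eta=1 so the leaf value is applied in full):

# Standalone check of the expected leaf values used in the assertions above.
# For reg:linear with base_score=0.5 and a single row with label 1:
#   sum_grad = 0.5 - 1.0 = -0.5, sum_hess = 1.0, prediction = 0.5 + leaf value.
def soft_threshold(g, alpha):
    # Shrink g towards zero by alpha (the ThresholdL1 helper in the diff).
    if g > alpha:
        return g - alpha
    if g < -alpha:
        return g + alpha
    return 0.0

def leaf_value(sum_grad, sum_hess, lam, alpha):
    # Matches ElasticNet::ComputeWeight: -T_alpha(G) / (H + lambda).
    return -soft_threshold(sum_grad, alpha) / (sum_hess + lam)

g, h = -0.5, 1.0
print(0.5 + leaf_value(g, h, lam=0, alpha=0.1))  # 0.9  (test_alpha)
print(0.5 + leaf_value(g, h, lam=1, alpha=0))    # 0.75 (test_lambda)
print(0.5 + leaf_value(g, h, lam=1, alpha=0.1))  # 0.7  (test_alpha_and_lambda)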