Add LASSO (#3429)
* Allow multiple split constraints
* Replace RidgePenalty with ElasticNet
* Add test checking that Ridge, LASSO, and Elastic Net are implemented

parent 2f8764955c
commit a13e29ece1
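For context, the new evaluator is driven by the existing `alpha`/`lambda` learner parameters. A minimal sketch of enabling the L1 (LASSO) term from Python, on synthetic data (illustrative only, not part of this commit):

import numpy as np
import xgboost as xgb

# Synthetic regression data, for illustration.
X = np.random.rand(100, 5)
y = np.random.rand(100)
dtrain = xgb.DMatrix(X, label=y)

# alpha > 0 turns on the L1 term in the elastic net split evaluator;
# lambda is the existing L2 term. alpha == 0 recovers the old ridge behaviour.
params = {'alpha': 0.5, 'lambda': 1.0, 'objective': 'reg:linear'}
model = xgb.train(params, dtrain, num_boost_round=10)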
@@ -187,7 +187,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .set_default(1)
         .describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
     DMLC_DECLARE_FIELD(split_evaluator)
-        .set_default("monotonic")
+        .set_default("elastic_net,monotonic")
         .describe("The criteria to use for ranking splits");
     // add alias of parameters
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);

@@ -8,13 +8,12 @@
 #include <algorithm>
 #include <limits>
 #include <string>
+#include <sstream>
 #include <utility>
 #include "param.h"
 #include "../common/common.h"
 #include "../common/host_device_vector.h"

-#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
-
 namespace dmlc {
 DMLC_REGISTRY_ENABLE(::xgboost::tree::SplitEvaluatorReg);
 }  // namespace dmlc

@@ -23,12 +22,19 @@ namespace xgboost {
 namespace tree {

 SplitEvaluator* SplitEvaluator::Create(const std::string& name) {
-  auto* e = ::dmlc::Registry< ::xgboost::tree::SplitEvaluatorReg>
-      ::Get()->Find(name);
-  if (e == nullptr) {
-    LOG(FATAL) << "Unknown SplitEvaluator " << name;
+  std::stringstream ss(name);
+  std::string item;
+  SplitEvaluator* eval = nullptr;
+  // Construct a chain of SplitEvaluators. This allows one to specify multiple constraints.
+  while (std::getline(ss, item, ',')) {
+    auto* e = ::dmlc::Registry< ::xgboost::tree::SplitEvaluatorReg>
+        ::Get()->Find(item);
+    if (e == nullptr) {
+      LOG(FATAL) << "Unknown SplitEvaluator " << name;
+    }
+    eval = (e->body)(std::unique_ptr<SplitEvaluator>(eval));
   }
-  return (e->body)();
+  return eval;
 }

 // Default implementations of some virtual methods that aren't always needed
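The parsing loop above wraps each newly created evaluator around the previous one, so the last name in the comma-separated list becomes the outermost decorator. A hedged Python sketch of the same pattern (class and registry names are illustrative stand-ins for the C++ types, not real xgboost API):

class ElasticNet:
    def __init__(self, inner):
        # Mirrors the C++ constructor: ElasticNet must sit at the bottom of the chain.
        assert inner is None, "ElasticNet does not accept an inner evaluator"

class MonotonicConstraint:
    def __init__(self, inner):
        # Mirrors the C++ constructor: MonotonicConstraint delegates to an inner evaluator.
        assert inner is not None, "MonotonicConstraint must be given an inner evaluator"
        self.inner = inner

REGISTRY = {'elastic_net': ElasticNet, 'monotonic': MonotonicConstraint}

def create(name):
    evaluator = None
    for item in name.split(','):
        evaluator = REGISTRY[item](evaluator)  # each evaluator wraps the previous one
    return evaluator

chain = create('elastic_net,monotonic')  # MonotonicConstraint wrapping ElasticNet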
@@ -41,38 +47,57 @@ void SplitEvaluator::AddSplit(bst_uint nodeid,
                               bst_uint featureid,
                               bst_float leftweight,
                               bst_float rightweight) {}

+bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
+                                            bst_uint featureid,
+                                            const GradStats& left_stats,
+                                            const GradStats& right_stats) const {
+  bst_float left_weight = ComputeWeight(nodeid, left_stats);
+  bst_float right_weight = ComputeWeight(nodeid, right_stats);
+  return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
+}
+
-//! \brief Encapsulates the parameters for by the RidgePenalty
-struct RidgePenaltyParams : public dmlc::Parameter<RidgePenaltyParams> {
-  float reg_lambda;
-  float reg_gamma;
+//! \brief Encapsulates the parameters for ElasticNet
+struct ElasticNetParams : public dmlc::Parameter<ElasticNetParams> {
+  bst_float reg_lambda;
+  bst_float reg_alpha;
+  bst_float reg_gamma;

-  DMLC_DECLARE_PARAMETER(RidgePenaltyParams) {
+  DMLC_DECLARE_PARAMETER(ElasticNetParams) {
     DMLC_DECLARE_FIELD(reg_lambda)
       .set_lower_bound(0.0)
       .set_default(1.0)
       .describe("L2 regularization on leaf weight");
+    DMLC_DECLARE_FIELD(reg_alpha)
+      .set_lower_bound(0.0)
+      .set_default(0.0)
+      .describe("L1 regularization on leaf weight");
     DMLC_DECLARE_FIELD(reg_gamma)
-      .set_lower_bound(0.0f)
-      .set_default(0.0f)
+      .set_lower_bound(0.0)
+      .set_default(0.0)
       .describe("Cost incurred by adding a new leaf node to the tree");
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
+    DMLC_DECLARE_ALIAS(reg_alpha, alpha);
     DMLC_DECLARE_ALIAS(reg_gamma, gamma);
   }
 };

-DMLC_REGISTER_PARAMETER(RidgePenaltyParams);
+DMLC_REGISTER_PARAMETER(ElasticNetParams);

-/*! \brief Applies an L2 penalty and per-leaf penalty. */
-class RidgePenalty final : public SplitEvaluator {
+/*! \brief Applies an elastic net penalty and per-leaf penalty. */
+class ElasticNet final : public SplitEvaluator {
  public:
+  explicit ElasticNet(std::unique_ptr<SplitEvaluator> inner) {
+    if (inner) {
+      LOG(FATAL) << "ElasticNet does not accept an inner SplitEvaluator";
+    }
+  }
   void Init(
       const std::vector<std::pair<std::string, std::string> >& args) override {
     params_.InitAllowUnknown(args);
   }

   SplitEvaluator* GetHostClone() const override {
-    auto r = new RidgePenalty();
+    auto r = new ElasticNet(nullptr);
     r->params_ = this->params_;

     return r;
@@ -80,31 +105,55 @@ class RidgePenalty final : public SplitEvaluator {

   bst_float ComputeSplitScore(bst_uint nodeid,
                               bst_uint featureid,
-                              const GradStats& left,
-                              const GradStats& right) const override {
-    // parentID is not needed for this split evaluator. Just use 0.
-    return ComputeScore(0, left) + ComputeScore(0, right);
+                              const GradStats& left_stats,
+                              const GradStats& right_stats,
+                              bst_float left_weight,
+                              bst_float right_weight) const override {
+    return ComputeScore(nodeid, left_stats, left_weight) +
+      ComputeScore(nodeid, right_stats, right_weight);
   }

-  bst_float ComputeScore(bst_uint parentID, const GradStats& stats)
+  bst_float ComputeSplitScore(bst_uint nodeid,
+                              bst_uint featureid,
+                              const GradStats& left_stats,
+                              const GradStats& right_stats) const override {
+    return ComputeScore(nodeid, left_stats) + ComputeScore(nodeid, right_stats);
+  }
+
+  bst_float ComputeScore(bst_uint parentID, const GradStats &stats, bst_float weight)
       const override {
-    return (stats.sum_grad * stats.sum_grad)
-        / (stats.sum_hess + params_.reg_lambda) - params_.reg_gamma;
+    auto loss = weight * (2.0 * stats.sum_grad + stats.sum_hess * weight
+        + params_.reg_lambda * weight)
+        + params_.reg_alpha * std::abs(weight);
+    return -loss;
+  }
+
+  bst_float ComputeScore(bst_uint parentID, const GradStats &stats) const {
+    return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_.reg_lambda);
   }

   bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
       const override {
-    return -stats.sum_grad / (stats.sum_hess + params_.reg_lambda);
+    return -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_.reg_lambda);
   }

  private:
-  RidgePenaltyParams params_;
+  ElasticNetParams params_;
+
+  inline double ThresholdL1(double g) const {
+    if (g > params_.reg_alpha) {
+      g = g - params_.reg_alpha;
+    } else if (g < -params_.reg_alpha) {
+      g = g + params_.reg_alpha;
+    }
+    return g;
+  }
 };

-XGBOOST_REGISTER_SPLIT_EVALUATOR(RidgePenalty, "ridge")
-  .describe("Use an L2 penalty term for the weights and a cost per leaf node")
-  .set_body([]() {
-      return new RidgePenalty();
+XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
+  .describe("Use an elastic net regulariser and a cost per leaf node")
+  .set_body([](std::unique_ptr<SplitEvaluator> inner) {
+      return new ElasticNet(std::move(inner));
     });

 /*! \brief Encapsulates the parameters required by the MonotonicConstraint
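For reference, the closed forms implemented by ComputeWeight and the two-argument ComputeScore above follow from the standard per-leaf objective; this derivation is background, not part of the diff. With G = sum_grad, H = sum_hess, λ = reg_lambda, α = reg_alpha:

\[
L(w) = G\,w + \tfrac{1}{2}(H + \lambda)\,w^{2} + \alpha\,|w|,
\qquad
w^{\ast} = -\frac{T_{\alpha}(G)}{H + \lambda},
\qquad
T_{\alpha}(G) =
\begin{cases}
G - \alpha, & G > \alpha,\\
G + \alpha, & G < -\alpha,\\
0, & \text{otherwise.}
\end{cases}
\]

Plugging the minimiser back in gives -L(w^{\ast}) = T_{\alpha}(G)^{2} / (2(H + \lambda)); the two-argument ComputeScore drops the constant 1/2, as XGBoost gain formulas conventionally do, yielding Sqr(ThresholdL1(G)) / (H + λ) as in the code.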
@@ -113,23 +162,11 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(RidgePenalty, "ridge")
 struct MonotonicConstraintParams
     : public dmlc::Parameter<MonotonicConstraintParams> {
   std::vector<bst_int> monotone_constraints;
-  float reg_lambda;
-  float reg_gamma;

   DMLC_DECLARE_PARAMETER(MonotonicConstraintParams) {
-    DMLC_DECLARE_FIELD(reg_lambda)
-      .set_lower_bound(0.0)
-      .set_default(1.0)
-      .describe("L2 regularization on leaf weight");
-    DMLC_DECLARE_FIELD(reg_gamma)
-      .set_lower_bound(0.0f)
-      .set_default(0.0f)
-      .describe("Cost incurred by adding a new leaf node to the tree");
     DMLC_DECLARE_FIELD(monotone_constraints)
       .set_default(std::vector<bst_int>())
       .describe("Constraint of variable monotonicity");
-    DMLC_DECLARE_ALIAS(reg_lambda, lambda);
-    DMLC_DECLARE_ALIAS(reg_gamma, gamma);
   }
 };

@@ -140,8 +177,16 @@ DMLC_REGISTER_PARAMETER(MonotonicConstraintParams);
 */
 class MonotonicConstraint final : public SplitEvaluator {
  public:
+  explicit MonotonicConstraint(std::unique_ptr<SplitEvaluator> inner) {
+    if (!inner) {
+      LOG(FATAL) << "MonotonicConstraint must be given an inner evaluator";
+    }
+    inner_ = std::move(inner);
+  }
+
   void Init(const std::vector<std::pair<std::string, std::string> >& args)
       override {
+    inner_->Init(args);
     params_.InitAllowUnknown(args);
     Reset();
   }
@@ -153,22 +198,11 @@ class MonotonicConstraint final : public SplitEvaluator {

   SplitEvaluator* GetHostClone() const override {
     if (params_.monotone_constraints.size() == 0) {
-      // No monotone constraints specified, make a RidgePenalty evaluator
-      using std::pair;
-      using std::string;
-      using std::to_string;
-      using std::vector;
-      auto c = new RidgePenalty();
-      vector<pair<string, string> > args;
-      args.emplace_back(
-          pair<string, string>("reg_lambda", to_string(params_.reg_lambda)));
-      args.emplace_back(
-          pair<string, string>("reg_gamma", to_string(params_.reg_gamma)));
-      c->Init(args);
-      c->Reset();
-      return c;
+      // No monotone constraints specified, just return a clone of inner to speed things up
+      return inner_->GetHostClone();
     } else {
-      auto c = new MonotonicConstraint();
+      auto c = new MonotonicConstraint(
+          std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
       c->params_ = this->params_;
       c->Reset();
       return c;
@@ -177,35 +211,32 @@ class MonotonicConstraint final : public SplitEvaluator {

   bst_float ComputeSplitScore(bst_uint nodeid,
                               bst_uint featureid,
-                              const GradStats& left,
-                              const GradStats& right) const override {
+                              const GradStats& left_stats,
+                              const GradStats& right_stats,
+                              bst_float left_weight,
+                              bst_float right_weight) const override {
     bst_float infinity = std::numeric_limits<bst_float>::infinity();
     bst_int constraint = GetConstraint(featureid);

-    bst_float score = ComputeScore(nodeid, left) + ComputeScore(nodeid, right);
-    bst_float leftweight = ComputeWeight(nodeid, left);
-    bst_float rightweight = ComputeWeight(nodeid, right);
+    bst_float score = inner_->ComputeSplitScore(
+        nodeid, featureid, left_stats, right_stats, left_weight, right_weight);

     if (constraint == 0) {
       return score;
     } else if (constraint > 0) {
-      return leftweight <= rightweight ? score : -infinity;
+      return left_weight <= right_weight ? score : -infinity;
     } else {
-      return leftweight >= rightweight ? score : -infinity;
+      return left_weight >= right_weight ? score : -infinity;
     }
   }

-  bst_float ComputeScore(bst_uint parentID, const GradStats& stats)
+  bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight)
       const override {
-    bst_float w = ComputeWeight(parentID, stats);
-
-    return -(2.0 * stats.sum_grad * w + (stats.sum_hess + params_.reg_lambda)
-        * w * w);
+    return inner_->ComputeScore(parentID, stats, weight);
   }

   bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
       const override {
-    bst_float weight = -stats.sum_grad / (stats.sum_hess + params_.reg_lambda);
+    bst_float weight = inner_->ComputeWeight(parentID, stats);

     if (parentID == ROOT_PARENT_ID) {
       // This is the root node
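The gating above is now all that MonotonicConstraint does to the score itself: scoring is delegated to the inner evaluator and a violated constraint vetoes the split. A hedged Python rendering of just the comparison (stand-alone function, hypothetical name):

import math

def gated_score(score, constraint, left_weight, right_weight):
    # Mirrors the C++ above: constraint is -1/0/+1 per feature.
    if constraint == 0:
        return score                      # feature unconstrained
    if constraint > 0:                    # increasing: left leaf must not exceed right
        return score if left_weight <= right_weight else -math.inf
    return score if left_weight >= right_weight else -math.inf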
@@ -225,6 +256,7 @@ class MonotonicConstraint final : public SplitEvaluator {
                 bst_uint featureid,
                 bst_float leftweight,
                 bst_float rightweight) override {
+    inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
     bst_uint newsize = std::max(leftid, rightid) + 1;
     lower_.resize(newsize);
     upper_.resize(newsize);
@@ -250,6 +282,7 @@ class MonotonicConstraint final : public SplitEvaluator {

  private:
   MonotonicConstraintParams params_;
+  std::unique_ptr<SplitEvaluator> inner_;
   std::vector<bst_float> lower_;
   std::vector<bst_float> upper_;

@@ -265,8 +298,8 @@ class MonotonicConstraint final : public SplitEvaluator {
 XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
   .describe("Enforces that the tree is monotonically increasing/decreasing "
             "w.r.t. specified features")
-  .set_body([]() {
-      return new MonotonicConstraint();
+  .set_body([](std::unique_ptr<SplitEvaluator> inner) {
+      return new MonotonicConstraint(std::move(inner));
     });

 }  // namespace tree

@@ -15,6 +15,8 @@
 #include <utility>
 #include <vector>

+#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
+
 namespace xgboost {
 namespace tree {

@@ -40,13 +42,21 @@ class SplitEvaluator {

   // Computes the score (negative loss) resulting from performing this split
   virtual bst_float ComputeSplitScore(bst_uint nodeid,
-                                      bst_uint featureid,
-                                      const GradStats& left,
-                                      const GradStats& right) const = 0;
+                                      bst_uint featureid,
+                                      const GradStats& left_stats,
+                                      const GradStats& right_stats,
+                                      bst_float left_weight,
+                                      bst_float right_weight) const = 0;
+
+  virtual bst_float ComputeSplitScore(bst_uint nodeid,
+                                      bst_uint featureid,
+                                      const GradStats& left_stats,
+                                      const GradStats& right_stats) const;

   // Compute the Score for a node with the given stats
-  virtual bst_float ComputeScore(bst_uint parentid, const GradStats& stats)
-      const = 0;
+  virtual bst_float ComputeScore(bst_uint parentid,
+                                 const GradStats &stats,
+                                 bst_float weight) const = 0;

   // Compute the weight for a node with the given stats
   virtual bst_float ComputeWeight(bst_uint parentid, const GradStats& stats)
@@ -62,7 +72,7 @@ class SplitEvaluator {

 struct SplitEvaluatorReg
     : public dmlc::FunctionRegEntryBase<SplitEvaluatorReg,
-                                        std::function<SplitEvaluator* ()> > {};
+                                        std::function<SplitEvaluator* (std::unique_ptr<SplitEvaluator>)> > {};

 /*!
  * \brief Macro to register tree split evaluator.

@@ -243,10 +243,10 @@ class ColMaker: public TreeUpdater {
       // calculating the weights
       for (int nid : qexpand) {
         bst_uint parentid = tree[nid].Parent();
-        snode_[nid].root_gain = static_cast<float>(
-            spliteval_->ComputeScore(parentid, snode_[nid].stats));
         snode_[nid].weight = static_cast<float>(
             spliteval_->ComputeWeight(parentid, snode_[nid].stats));
+        snode_[nid].root_gain = static_cast<float>(
+            spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
       }
     }
     /*! \brief update queue expand add in new leaves */

@@ -752,10 +752,10 @@ class FastHistMaker: public TreeUpdater {
       // calculating the weights
       {
         bst_uint parentid = tree[nid].Parent();
-        snode_[nid].root_gain = static_cast<float>(
-            spliteval_->ComputeScore(parentid, snode_[nid].stats));
         snode_[nid].weight = static_cast<float>(
             spliteval_->ComputeWeight(parentid, snode_[nid].stats));
+        snode_[nid].root_gain = static_cast<float>(
+            spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
       }
     }

tests/python/test_tree_regularization.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import numpy as np
+import unittest
+import xgboost as xgb
+
+from numpy.testing import assert_approx_equal
+
+train_data = xgb.DMatrix(np.array([[1]]), label=np.array([1]))
+
+
+class TestTreeRegularization(unittest.TestCase):
+    def test_alpha(self):
+        params = {
+            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'eta': 1,
+            'lambda': 0,
+            'alpha': 0.1
+        }
+
+        model = xgb.train(params, train_data, 1)
+        preds = model.predict(train_data)
+
+        # Default prediction (with no trees) is 0.5
+        # sum_grad = (0.5 - 1.0)
+        # sum_hess = 1.0
+        # 0.9 = 0.5 - (sum_grad - alpha * sgn(sum_grad)) / sum_hess
+        assert_approx_equal(preds[0], 0.9)
+
+    def test_lambda(self):
+        params = {
+            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'eta': 1,
+            'lambda': 1,
+            'alpha': 0
+        }
+
+        model = xgb.train(params, train_data, 1)
+        preds = model.predict(train_data)
+
+        # Default prediction (with no trees) is 0.5
+        # sum_grad = (0.5 - 1.0)
+        # sum_hess = 1.0
+        # 0.75 = 0.5 - sum_grad / (sum_hess + lambda)
+        assert_approx_equal(preds[0], 0.75)
+
+    def test_alpha_and_lambda(self):
+        params = {
+            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'eta': 1,
+            'lambda': 1,
+            'alpha': 0.1
+        }
+
+        model = xgb.train(params, train_data, 1)
+        preds = model.predict(train_data)
+
+        # Default prediction (with no trees) is 0.5
+        # sum_grad = (0.5 - 1.0)
+        # sum_hess = 1.0
+        # 0.7 = 0.5 - (sum_grad - alpha * sgn(sum_grad)) / (sum_hess + lambda)
+        assert_approx_equal(preds[0], 0.7)
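The expected values in these tests can be reproduced directly from the soft-threshold formula. A small stand-alone check (the helper names are hypothetical; soft_threshold and leaf_weight mirror ThresholdL1 and ComputeWeight from the C++ side):

def soft_threshold(g, alpha):
    # L1 soft-thresholding of the gradient sum, as in ThresholdL1.
    if g > alpha:
        return g - alpha
    if g < -alpha:
        return g + alpha
    return 0.0

def leaf_weight(sum_grad, sum_hess, lam, alpha):
    return -soft_threshold(sum_grad, alpha) / (sum_hess + lam)

base = 0.5                      # default prediction with no trees
g, h = base - 1.0, 1.0          # squared-error gradient/hessian for label 1, eta = 1
print(base + leaf_weight(g, h, lam=0, alpha=0.1))  # 0.9  (test_alpha)
print(base + leaf_weight(g, h, lam=1, alpha=0))    # 0.75 (test_lambda)
print(base + leaf_weight(g, h, lam=1, alpha=0.1))  # 0.7  (test_alpha_and_lambda)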