Refactor linear modelling and add new coordinate descent updater (#3103)

* Refactor linear modelling and add new coordinate descent updater

* Allow unsorted column iterator

* Add prediction cacheing to gblinear
This commit is contained in:
Rory Mitchell
2018-02-17 09:17:01 +13:00
committed by GitHub
parent 9ffe8596f2
commit 10eb05a63a
23 changed files with 1252 additions and 271 deletions

View File

@@ -9,92 +9,66 @@
#include <dmlc/parameter.h>
#include <xgboost/gbm.h>
#include <xgboost/logging.h>
#include <xgboost/linear_updater.h>
#include <vector>
#include <string>
#include <sstream>
#include <cstring>
#include <algorithm>
#include "../common/timer.h"
namespace xgboost {
namespace gbm {
DMLC_REGISTRY_FILE_TAG(gblinear);
// model parameter
struct GBLinearModelParam :public dmlc::Parameter<GBLinearModelParam> {
// number of feature dimension
unsigned num_feature;
// number of output group
int num_output_group;
// reserved field
int reserved[32];
// constructor
GBLinearModelParam() {
std::memset(this, 0, sizeof(GBLinearModelParam));
}
DMLC_DECLARE_PARAMETER(GBLinearModelParam) {
DMLC_DECLARE_FIELD(num_feature).set_lower_bound(0)
.describe("Number of features used in classification.");
DMLC_DECLARE_FIELD(num_output_group).set_lower_bound(1).set_default(1)
.describe("Number of output groups in the setting.");
}
};
// training parameter
struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
/*! \brief learning_rate */
float learning_rate;
/*! \brief regularization weight for L2 norm */
float reg_lambda;
/*! \brief regularization weight for L1 norm */
float reg_alpha;
/*! \brief regularization weight for L2 norm in bias */
float reg_lambda_bias;
std::string updater;
// flag to print out detailed breakdown of runtime
int debug_verbose;
float tolerance;
// declare parameters
DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(1.0f)
.describe("Learning rate of each update.");
DMLC_DECLARE_FIELD(reg_lambda).set_lower_bound(0.0f).set_default(0.0f)
.describe("L2 regularization on weights.");
DMLC_DECLARE_FIELD(reg_alpha).set_lower_bound(0.0f).set_default(0.0f)
.describe("L1 regularization on weights.");
DMLC_DECLARE_FIELD(reg_lambda_bias).set_lower_bound(0.0f).set_default(0.0f)
.describe("L2 regularization on bias.");
// alias of parameters
DMLC_DECLARE_ALIAS(learning_rate, eta);
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
DMLC_DECLARE_ALIAS(reg_lambda_bias, lambda_bias);
}
// given original weight calculate delta
inline double CalcDelta(double sum_grad, double sum_hess, double w) const {
if (sum_hess < 1e-5f) return 0.0f;
double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda);
if (tmp >=0) {
return std::max(-(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w);
} else {
return std::min(-(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w);
}
}
// given original weight calculate delta bias
inline double CalcDeltaBias(double sum_grad, double sum_hess, double w) const {
return - (sum_grad + reg_lambda_bias * w) / (sum_hess + reg_lambda_bias);
DMLC_DECLARE_FIELD(updater)
.set_default("shotgun")
.describe("Update algorithm for linear model. One of shotgun/coord_descent");
DMLC_DECLARE_FIELD(tolerance)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("Stop if largest weight update is smaller than this number.");
DMLC_DECLARE_FIELD(debug_verbose)
.set_lower_bound(0)
.set_default(0)
.describe("flag to print out detailed breakdown of runtime");
}
};
/*!
* \brief gradient boosted linear model
*/
class GBLinear : public GradientBooster {
public:
explicit GBLinear(bst_float base_margin)
: base_margin_(base_margin) {
explicit GBLinear(const std::vector<std::shared_ptr<DMatrix> > &cache,
bst_float base_margin)
: base_margin_(base_margin),
sum_instance_weight(0),
sum_weight_complete(false),
is_converged(false) {
// Add matrices to the prediction cache
for (auto &d : cache) {
PredictionCacheEntry e;
e.data = d;
cache_[d.get()] = std::move(e);
}
}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
if (model.weight.size() == 0) {
model.param.InitAllowUnknown(cfg);
}
param.InitAllowUnknown(cfg);
updater.reset(LinearUpdater::Create(param.updater));
updater->Init(cfg);
monitor.Init("GBLinear ", param.debug_verbose);
}
void Load(dmlc::Stream* fi) override {
model.Load(fi);
@@ -102,108 +76,44 @@ class GBLinear : public GradientBooster {
void Save(dmlc::Stream* fo) const override {
model.Save(fo);
}
void DoBoost(DMatrix *p_fmat,
std::vector<bst_gpair> *in_gpair,
ObjFunction* obj) override {
// lazily initialize the model when not ready.
if (model.weight.size() == 0) {
model.InitModel();
void DoBoost(DMatrix *p_fmat, std::vector<bst_gpair> *in_gpair,
ObjFunction *obj) override {
monitor.Start("DoBoost");
if (!p_fmat->HaveColAccess(false)) {
std::vector<bool> enabled(p_fmat->info().num_col, true);
p_fmat->InitColAccess(enabled, 1.0f, std::numeric_limits<size_t>::max(),
false);
}
std::vector<bst_gpair> &gpair = *in_gpair;
const int ngroup = model.param.num_output_group;
const RowSet &rowset = p_fmat->buffered_rowset();
// for all the output group
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess)
for (bst_omp_uint i = 0; i < ndata; ++i) {
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
if (p.GetHess() >= 0.0f) {
sum_grad += p.GetGrad();
sum_hess += p.GetHess();
}
}
// remove bias effect
bst_float dw = static_cast<bst_float>(
param.learning_rate * param.CalcDeltaBias(sum_grad, sum_hess, model.bias()[gid]));
model.bias()[gid] += dw;
// update grad value
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
if (p.GetHess() >= 0.0f) {
p += bst_gpair(p.GetHess() * dw, 0);
}
}
}
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
while (iter->Next()) {
// number of features
const ColBatch &batch = iter->Value();
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const bst_uint fid = batch.col_index[i];
ColBatch::Inst col = batch[i];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {
const bst_float v = col[j].fvalue;
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.0f) continue;
sum_grad += p.GetGrad() * v;
sum_hess += p.GetHess() * v * v;
}
bst_float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate *
param.CalcDelta(sum_grad, sum_hess, w));
w += dw;
// update grad value
for (bst_uint j = 0; j < col.length; ++j) {
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.0f) continue;
p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0);
}
}
}
model.LazyInitModel();
this->LazySumWeights(p_fmat);
if (!this->CheckConvergence()) {
updater->Update(in_gpair, p_fmat, &model, sum_instance_weight);
}
this->UpdatePredictionCache();
monitor.Stop("DoBoost");
}
void PredictBatch(DMatrix *p_fmat,
std::vector<bst_float> *out_preds,
unsigned ntree_limit) override {
if (model.weight.size() == 0) {
model.InitModel();
}
void PredictBatch(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
unsigned ntree_limit) override {
monitor.Start("PredictBatch");
CHECK_EQ(ntree_limit, 0U)
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
std::vector<bst_float> &preds = *out_preds;
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
preds.resize(0);
// start collecting the prediction
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
const int ngroup = model.param.num_output_group;
while (iter->Next()) {
const RowBatch &batch = iter->Value();
CHECK_EQ(batch.base_rowid * ngroup, preds.size());
// output convention: nrow * k, where nrow is number of rows
// k is number of group
preds.resize(preds.size() + batch.size * ngroup);
// parallel over local batch
const omp_ulong nsize = static_cast<omp_ulong>(batch.size);
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < nsize; ++i) {
const size_t ridx = batch.base_rowid + i;
// loop over output groups
for (int gid = 0; gid < ngroup; ++gid) {
bst_float margin = (base_margin.size() != 0) ?
base_margin[ridx * ngroup + gid] : base_margin_;
this->Pred(batch[i], &preds[ridx * ngroup], gid, margin);
}
}
// Try to predict from cache
auto it = cache_.find(p_fmat);
if (it != cache_.end() && it->second.predictions.size() != 0) {
std::vector<bst_float> &y = it->second.predictions;
out_preds->resize(y.size());
std::copy(y.begin(), y.end(), out_preds->begin());
} else {
this->PredictBatchInternal(p_fmat, out_preds);
}
monitor.Stop("PredictBatch");
}
// add base margin
void PredictInstance(const SparseBatch::Inst &inst,
@@ -226,9 +136,7 @@ class GBLinear : public GradientBooster {
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition = 0,
unsigned condition_feature = 0) override {
if (model.weight.size() == 0) {
model.InitModel();
}
model.LazyInitModel();
CHECK_EQ(ntree_limit, 0U)
<< "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor";
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
@@ -317,7 +225,74 @@ class GBLinear : public GradientBooster {
}
protected:
inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid, bst_float base) {
void PredictBatchInternal(DMatrix *p_fmat,
std::vector<bst_float> *out_preds) {
monitor.Start("PredictBatchInternal");
model.LazyInitModel();
std::vector<bst_float> &preds = *out_preds;
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
// start collecting the prediction
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
const int ngroup = model.param.num_output_group;
preds.resize(p_fmat->info().num_row * ngroup);
while (iter->Next()) {
const RowBatch &batch = iter->Value();
// output convention: nrow * k, where nrow is number of rows
// k is number of group
// parallel over local batch
const omp_ulong nsize = static_cast<omp_ulong>(batch.size);
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < nsize; ++i) {
const size_t ridx = batch.base_rowid + i;
// loop over output groups
for (int gid = 0; gid < ngroup; ++gid) {
bst_float margin = (base_margin.size() != 0) ?
base_margin[ridx * ngroup + gid] : base_margin_;
this->Pred(batch[i], &preds[ridx * ngroup], gid, margin);
}
}
}
monitor.Stop("PredictBatchInternal");
}
void UpdatePredictionCache() {
// update cache entry
for (auto &kv : cache_) {
PredictionCacheEntry &e = kv.second;
if (e.predictions.size() == 0) {
size_t n = model.param.num_output_group * e.data->info().num_row;
e.predictions.resize(n);
}
this->PredictBatchInternal(e.data.get(), &e.predictions);
}
}
bool CheckConvergence() {
if (param.tolerance == 0.0f) return false;
if (is_converged) return true;
if (previous_model.weight.size() != model.weight.size()) return false;
float largest_dw = 0.0;
for (auto i = 0; i < model.weight.size(); i++) {
largest_dw = std::max(
largest_dw, std::abs(model.weight[i] - previous_model.weight[i]));
}
previous_model = model;
is_converged = largest_dw <= param.tolerance;
return is_converged;
}
void LazySumWeights(DMatrix *p_fmat) {
if (!sum_weight_complete) {
auto &info = p_fmat->info();
for (int i = 0; i < info.num_row; i++) {
sum_instance_weight += info.GetWeight(i);
}
sum_weight_complete = true;
}
}
inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid,
bst_float base) {
bst_float psum = model.bias()[gid] + base;
for (bst_uint i = 0; i < inst.length; ++i) {
if (inst[i].index >= model.param.num_feature) continue;
@@ -325,52 +300,33 @@ class GBLinear : public GradientBooster {
}
preds[gid] = psum;
}
// model for linear booster
class Model {
public:
// parameter
GBLinearModelParam param;
// weight for each of feature, bias is the last one
std::vector<bst_float> weight;
// initialize the model parameter
inline void InitModel(void) {
// bias is the last weight
weight.resize((param.num_feature + 1) * param.num_output_group);
std::fill(weight.begin(), weight.end(), 0.0f);
}
// save the model to file
inline void Save(dmlc::Stream* fo) const {
fo->Write(&param, sizeof(param));
fo->Write(weight);
}
// load model from file
inline void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param));
fi->Read(&weight);
}
// model bias
inline bst_float* bias() {
return &weight[param.num_feature * param.num_output_group];
}
inline const bst_float* bias() const {
return &weight[param.num_feature * param.num_output_group];
}
// get i-th weight
inline bst_float* operator[](size_t i) {
return &weight[i * param.num_output_group];
}
inline const bst_float* operator[](size_t i) const {
return &weight[i * param.num_output_group];
}
};
// biase margin score
bst_float base_margin_;
// model field
Model model;
// training parameter
GBLinearModel model;
GBLinearModel previous_model;
GBLinearTrainParam param;
// Per feature: shuffle index of each feature index
std::vector<bst_uint> feat_index;
std::unique_ptr<LinearUpdater> updater;
double sum_instance_weight;
bool sum_weight_complete;
common::Monitor monitor;
bool is_converged;
/**
* \struct PredictionCacheEntry
*
* \brief Contains pointer to input matrix and associated cached predictions.
*/
struct PredictionCacheEntry {
std::shared_ptr<DMatrix> data;
std::vector<bst_float> predictions;
};
/**
* \brief Map of matrices and associated cached predictions to facilitate
* storing and looking up predictions.
*/
std::unordered_map<DMatrix*, PredictionCacheEntry> cache_;
};
// register the objective functions
@@ -378,9 +334,10 @@ DMLC_REGISTER_PARAMETER(GBLinearModelParam);
DMLC_REGISTER_PARAMETER(GBLinearTrainParam);
XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
.describe("Linear booster, implement generalized linear model.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> >&cache, bst_float base_margin) {
return new GBLinear(base_margin);
});
.describe("Linear booster, implement generalized linear model.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> > &cache,
bst_float base_margin) {
return new GBLinear(cache, base_margin);
});
} // namespace gbm
} // namespace xgboost

73
src/gbm/gblinear_model.h Normal file
View File

@@ -0,0 +1,73 @@
/*!
* Copyright by Contributors 2018
*/
#pragma once
#include <dmlc/io.h>
#include <dmlc/parameter.h>
#include <vector>
#include <cstring>
namespace xgboost {
namespace gbm {
// model parameter
struct GBLinearModelParam : public dmlc::Parameter<GBLinearModelParam> {
// number of feature dimension
unsigned num_feature;
// number of output group
int num_output_group;
// reserved field
int reserved[32];
// constructor
GBLinearModelParam() { std::memset(this, 0, sizeof(GBLinearModelParam)); }
DMLC_DECLARE_PARAMETER(GBLinearModelParam) {
DMLC_DECLARE_FIELD(num_feature)
.set_lower_bound(0)
.describe("Number of features used in classification.");
DMLC_DECLARE_FIELD(num_output_group)
.set_lower_bound(1)
.set_default(1)
.describe("Number of output groups in the setting.");
}
};
// model for linear booster
class GBLinearModel {
public:
// parameter
GBLinearModelParam param;
// weight for each of feature, bias is the last one
std::vector<bst_float> weight;
// initialize the model parameter
inline void LazyInitModel(void) {
if (!weight.empty()) return;
// bias is the last weight
weight.resize((param.num_feature + 1) * param.num_output_group);
std::fill(weight.begin(), weight.end(), 0.0f);
}
// save the model to file
inline void Save(dmlc::Stream* fo) const {
fo->Write(&param, sizeof(param));
fo->Write(weight);
}
// load model from file
inline void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param));
fi->Read(&weight);
}
// model bias
inline bst_float* bias() {
return &weight[param.num_feature * param.num_output_group];
}
inline const bst_float* bias() const {
return &weight[param.num_feature * param.num_output_group];
}
// get i-th weight
inline bst_float* operator[](size_t i) {
return &weight[i * param.num_output_group];
}
inline const bst_float* operator[](size_t i) const {
return &weight[i * param.num_output_group];
}
};
} // namespace gbm
} // namespace xgboost