From 10eb05a63a8c0e2dd2d8b8693dfff43481477d9f Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Sat, 17 Feb 2018 09:17:01 +1300
Subject: [PATCH] Refactor linear modelling and add new coordinate descent updater (#3103)

* Refactor linear modelling and add new coordinate descent updater
* Allow unsorted column iterator
* Add prediction caching to gblinear
---
 NEWS.md                                    |   2 +
 amalgamation/xgboost-all0.cc               |   5 +
 doc/parameter.md                           |  11 +-
 include/xgboost/data.h                     |   6 +-
 include/xgboost/linear_updater.h           |  66 ++++
 src/common/timer.h                         |   4 +-
 src/data/simple_dmatrix.cc                 |  58 +--
 src/data/simple_dmatrix.h                  |  16 +-
 src/data/sparse_page_dmatrix.cc            |  22 +-
 src/data/sparse_page_dmatrix.h             |   8 +-
 src/gbm/gblinear.cc                        | 361 ++++++++----------
 src/gbm/gblinear_model.h                   |  73 ++++
 src/learner.cc                             |   8 +-
 src/linear/coordinate_common.h             | 321 ++++++++++++++++
 src/linear/linear_updater.cc               |  29 ++
 src/linear/updater_coordinate.cc           | 124 ++++++
 src/linear/updater_shotgun.cc              | 127 ++++++
 tests/benchmark/benchmark_linear.py        |  69 ++++
 .../{benchmark.py => benchmark_tree.py}    |   0
 tests/cpp/data/test_simple_dmatrix.cc      |  30 +-
 tests/cpp/data/test_sparse_page_dmatrix.cc |   6 +-
 tests/cpp/linear/test_linear.cc            |  44 +++
 tests/python/test_linear.py                | 133 +++++++
 23 files changed, 1252 insertions(+), 271 deletions(-)
 create mode 100644 include/xgboost/linear_updater.h
 create mode 100644 src/gbm/gblinear_model.h
 create mode 100644 src/linear/coordinate_common.h
 create mode 100644 src/linear/linear_updater.cc
 create mode 100644 src/linear/updater_coordinate.cc
 create mode 100644 src/linear/updater_shotgun.cc
 create mode 100644 tests/benchmark/benchmark_linear.py
 rename tests/benchmark/{benchmark.py => benchmark_tree.py} (100%)
 create mode 100644 tests/cpp/linear/test_linear.cc
 create mode 100644 tests/python/test_linear.py

diff --git a/NEWS.md b/NEWS.md
index 97b8fe375..122ddaf24 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -3,6 +3,8 @@ XGBoost Change Log
 
 This file records the changes in xgboost library in reverse chronological order.
 
+* BREAKING CHANGES: Updated linear modelling algorithms. In particular, L1/L2 regularisation penalties are now normalised to the number of training examples. This makes the implementation consistent with sklearn/glmnet. L2 regularisation has also been removed from the intercept. To produce linear models with the old regularisation behaviour, divide the alpha/lambda regularisation parameters by the number of training examples.
+
 ## v0.7 (2017.12.30)
 * **This version represents a major change from the last release (v0.6), which was released one year and half ago.**
 * Updated Sklearn API
diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index 4ad5fe96a..d6082e03a 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -53,6 +53,11 @@
 #include "../src/tree/updater_histmaker.cc"
 #include "../src/tree/updater_skmaker.cc"
 
+// linear
+#include "../src/linear/linear_updater.cc"
+#include "../src/linear/updater_coordinate.cc"
+#include "../src/linear/updater_shotgun.cc"
+
 // global
 #include "../src/learner.cc"
 #include "../src/logging.cc"
diff --git a/doc/parameter.md b/doc/parameter.md
index 784f1209f..1aa2176a6 100644
--- a/doc/parameter.md
+++ b/doc/parameter.md
@@ -142,11 +142,14 @@ Additional parameters for Dart Booster
 Parameters for Linear Booster
 -----------------------------
 * lambda [default=0, alias: reg_lambda]
-  - L2 regularization term on weights, increase this value will make model more conservative.
+ - L2 regularization term on weights, increase this value will make model more conservative. Normalised to number of training examples. * alpha [default=0, alias: reg_alpha] - - L1 regularization term on weights, increase this value will make model more conservative. -* lambda_bias [default=0, alias: reg_lambda_bias] - - L2 regularization term on bias (no L1 reg on bias because it is not important) + - L1 regularization term on weights, increase this value will make model more conservative. Normalised to number of training examples. +* updater [default='shotgun'] + - Linear model algorithm + - 'shotgun': Parallel coordinate descent algorithm based on shotgun algorithm. Uses 'hogwild' parallelism and therefore produces a nondeterministic solution on each run. + - 'coord_descent': Ordinary coordinate descent algorithm. Also multithreaded but still produces a deterministic solution. + Parameters for Tweedie Regression --------------------------------- diff --git a/include/xgboost/data.h b/include/xgboost/data.h index f7d9812ab..24a3a1f3f 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -274,14 +274,16 @@ class DMatrix { * \param subsample subsample ratio when generating column access. * \param max_row_perbatch auxiliary information, maximum row used in each column batch. * this is a hint information that can be ignored by the implementation. + * \param sorted If column features should be in sorted order * \return Number of column blocks in the column access. */ + virtual void InitColAccess(const std::vector& enabled, float subsample, - size_t max_row_perbatch) = 0; + size_t max_row_perbatch, bool sorted) = 0; // the following are column meta data, should be able to answer them fast. /*! \return whether column access is enabled */ - virtual bool HaveColAccess() const = 0; + virtual bool HaveColAccess(bool sorted) const = 0; /*! \return Whether the data columns single column block. */ virtual bool SingleColBlock() const = 0; /*! \brief get number of non-missing entries in column */ diff --git a/include/xgboost/linear_updater.h b/include/xgboost/linear_updater.h new file mode 100644 index 000000000..b91d598ee --- /dev/null +++ b/include/xgboost/linear_updater.h @@ -0,0 +1,66 @@ +/* + * Copyright 2018 by Contributors + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "../../src/gbm/gblinear_model.h" + +namespace xgboost { +/*! + * \brief interface of linear updater + */ +class LinearUpdater { + public: + /*! \brief virtual destructor */ + virtual ~LinearUpdater() {} + /*! + * \brief Initialize the updater with given arguments. + * \param args arguments to the objective function. + */ + virtual void Init( + const std::vector >& args) = 0; + + /** + * \brief Updates linear model given gradients. + * + * \param in_gpair The gradient pair statistics of the data. + * \param data Input data matrix. + * \param model Model to be updated. + * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty. + */ + + virtual void Update(std::vector* in_gpair, DMatrix* data, + gbm::GBLinearModel* model, + double sum_instance_weight) = 0; + + /*! + * \brief Create a linear updater given name + * \param name Name of the linear updater. + */ + static LinearUpdater* Create(const std::string& name); +}; + +/*! + * \brief Registry entry for linear updater. + */ +struct LinearUpdaterReg + : public dmlc::FunctionRegEntryBase > {}; + +/*! + * \brief Macro to register linear updater. 
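+ *
+ * Illustrative usage only (this mirrors how the shotgun and coordinate
+ * updaters added in this patch register themselves; "MyUpdater" and
+ * "my_updater" are placeholder names, not part of the patch):
+ *
+ *   XGBOOST_REGISTER_LINEAR_UPDATER(MyUpdater, "my_updater")
+ *       .describe("Example linear updater registration.")
+ *       .set_body([]() { return new MyUpdater(); });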
+ */ +#define XGBOOST_REGISTER_LINEAR_UPDATER(UniqueId, Name) \ + static DMLC_ATTRIBUTE_UNUSED ::xgboost::LinearUpdaterReg& \ + __make_##LinearUpdaterReg##_##UniqueId##__ = \ + ::dmlc::Registry< ::xgboost::LinearUpdaterReg>::Get()->__REGISTER__( \ + Name) + +} // namespace xgboost diff --git a/src/common/timer.h b/src/common/timer.h index 418119379..d1cd53ac7 100644 --- a/src/common/timer.h +++ b/src/common/timer.h @@ -2,6 +2,7 @@ * Copyright by Contributors 2017 */ #pragma once +#include #include #include #include @@ -28,7 +29,8 @@ struct Timer { double ElapsedSeconds() const { return SecondsT(elapsed).count(); } void PrintElapsed(std::string label) { char buffer[255]; - snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(), SecondsT(elapsed).count()); + snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(), + SecondsT(elapsed).count()); LOG(CONSOLE) << buffer; Reset(); } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index c7c0b3d1b..42c836545 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -54,16 +54,16 @@ dmlc::DataIter* SimpleDMatrix::ColIterator(const std::vector void SimpleDMatrix::InitColAccess(const std::vector &enabled, float pkeep, - size_t max_row_perbatch) { - if (this->HaveColAccess()) return; - + size_t max_row_perbatch, bool sorted) { + if (this->HaveColAccess(sorted)) return; + col_iter_.sorted = sorted; col_iter_.cpages_.clear(); if (info().num_row < max_row_perbatch) { std::unique_ptr page(new SparsePage()); - this->MakeOneBatch(enabled, pkeep, page.get()); + this->MakeOneBatch(enabled, pkeep, page.get(), sorted); col_iter_.cpages_.push_back(std::move(page)); } else { - this->MakeManyBatch(enabled, pkeep, max_row_perbatch); + this->MakeManyBatch(enabled, pkeep, max_row_perbatch, sorted); } // setup col-size col_size_.resize(info().num_col); @@ -77,9 +77,8 @@ void SimpleDMatrix::InitColAccess(const std::vector &enabled, } // internal function to make one batch from row iter. 
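+// Note on the new 'sorted' flag: when it is false the entries of each column
+// are left in row order and the per-column value sort below is skipped. The
+// gblinear updaters only need column-wise gradient sums, so they request
+// unsorted column access; the tree updaters still request sorted columns.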
-void SimpleDMatrix::MakeOneBatch(const std::vector& enabled, - float pkeep, - SparsePage *pcol) { +void SimpleDMatrix::MakeOneBatch(const std::vector& enabled, float pkeep, + SparsePage* pcol, bool sorted) { // clear rowset buffered_rowset_.clear(); // bit map @@ -144,21 +143,24 @@ void SimpleDMatrix::MakeOneBatch(const std::vector& enabled, } CHECK_EQ(pcol->Size(), info().num_col); - // sort columns - bst_omp_uint ncol = static_cast(pcol->Size()); - #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) - for (bst_omp_uint i = 0; i < ncol; ++i) { - if (pcol->offset[i] < pcol->offset[i + 1]) { - std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], - dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], - SparseBatch::Entry::CmpValue); + + if (sorted) { + // sort columns + bst_omp_uint ncol = static_cast(pcol->Size()); +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) + for (bst_omp_uint i = 0; i < ncol; ++i) { + if (pcol->offset[i] < pcol->offset[i + 1]) { + std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], + dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], + SparseBatch::Entry::CmpValue); + } } } } void SimpleDMatrix::MakeManyBatch(const std::vector& enabled, float pkeep, - size_t max_row_perbatch) { + size_t max_row_perbatch, bool sorted) { size_t btop = 0; std::bernoulli_distribution coin_flip(pkeep); auto& rnd = common::GlobalRandom(); @@ -179,7 +181,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector& enabled, } if (tmp.Size() >= max_row_perbatch) { std::unique_ptr page(new SparsePage()); - this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get()); + this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted); col_iter_.cpages_.push_back(std::move(page)); btop = buffered_rowset_.size(); tmp.Clear(); @@ -189,7 +191,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector& enabled, if (tmp.Size() != 0) { std::unique_ptr page(new SparsePage()); - this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get()); + this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted); col_iter_.cpages_.push_back(std::move(page)); } } @@ -198,7 +200,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector& enabled, void SimpleDMatrix::MakeColPage(const RowBatch& batch, size_t buffer_begin, const std::vector& enabled, - SparsePage* pcol) { + SparsePage* pcol, bool sorted) { const int nthread = std::min(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 2, 1)); pcol->Clear(); common::ParallelGroupBuilder @@ -231,13 +233,15 @@ void SimpleDMatrix::MakeColPage(const RowBatch& batch, } CHECK_EQ(pcol->Size(), info().num_col); // sort columns - bst_omp_uint ncol = static_cast(pcol->Size()); - #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) - for (bst_omp_uint i = 0; i < ncol; ++i) { - if (pcol->offset[i] < pcol->offset[i + 1]) { - std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], - dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], - SparseBatch::Entry::CmpValue); + if (sorted) { + bst_omp_uint ncol = static_cast(pcol->Size()); +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) + for (bst_omp_uint i = 0; i < ncol; ++i) { + if (pcol->offset[i] < pcol->offset[i + 1]) { + std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], + dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], + SparseBatch::Entry::CmpValue); + } } } } diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 81454dc7f..58d60c444 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -36,8 
+36,8 @@ class SimpleDMatrix : public DMatrix { return iter; } - bool HaveColAccess() const override { - return col_size_.size() != 0; + bool HaveColAccess(bool sorted) const override { + return col_size_.size() != 0 && col_iter_.sorted == sorted; } const RowSet& buffered_rowset() const override { @@ -59,7 +59,7 @@ class SimpleDMatrix : public DMatrix { void InitColAccess(const std::vector& enabled, float subsample, - size_t max_row_perbatch) override; + size_t max_row_perbatch, bool sorted) override; bool SingleColBlock() const override; @@ -67,7 +67,7 @@ class SimpleDMatrix : public DMatrix { // in-memory column batch iterator. struct ColBatchIter: dmlc::DataIter { public: - ColBatchIter() : data_ptr_(0) {} + ColBatchIter() : data_ptr_(0), sorted(false) {} void BeforeFirst() override { data_ptr_ = 0; } @@ -89,6 +89,8 @@ class SimpleDMatrix : public DMatrix { size_t data_ptr_; // temporal space for batch ColBatch batch_; + // Is column sorted? + bool sorted; }; // source data pointer. @@ -103,16 +105,16 @@ class SimpleDMatrix : public DMatrix { // internal function to make one batch from row iter. void MakeOneBatch(const std::vector& enabled, float pkeep, - SparsePage *pcol); + SparsePage *pcol, bool sorted); void MakeManyBatch(const std::vector& enabled, float pkeep, - size_t max_row_perbatch); + size_t max_row_perbatch, bool sorted); void MakeColPage(const RowBatch& batch, size_t buffer_begin, const std::vector& enabled, - SparsePage* pcol); + SparsePage* pcol, bool sorted); }; } // namespace data } // namespace xgboost diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 676d0dcd5..8ad2edc6a 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -145,10 +145,9 @@ bool SparsePageDMatrix::TryInitColData() { void SparsePageDMatrix::InitColAccess(const std::vector& enabled, float pkeep, - size_t max_row_perbatch) { - if (HaveColAccess()) return; + size_t max_row_perbatch, bool sorted) { + if (HaveColAccess(sorted)) return; if (TryInitColData()) return; - const MetaInfo& info = this->info(); if (max_row_perbatch == std::numeric_limits::max()) { max_row_perbatch = kMaxRowPerBatch; @@ -197,13 +196,15 @@ void SparsePageDMatrix::InitColAccess(const std::vector& enabled, } CHECK_EQ(pcol->Size(), info.num_col); // sort columns - bst_omp_uint ncol = static_cast(pcol->Size()); - #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) - for (bst_omp_uint i = 0; i < ncol; ++i) { - if (pcol->offset[i] < pcol->offset[i + 1]) { - std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], - dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], - SparseBatch::Entry::CmpValue); + if (sorted) { + bst_omp_uint ncol = static_cast(pcol->Size()); +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) + for (bst_omp_uint i = 0; i < ncol; ++i) { + if (pcol->offset[i] < pcol->offset[i + 1]) { + std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], + dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], + SparseBatch::Entry::CmpValue); + } } } }; @@ -291,6 +292,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector& enabled, } // initialize column data CHECK(TryInitColData()); + col_iter_->sorted = sorted; } } // namespace data diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 4bd750a20..4c99e72cc 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -40,8 +40,8 @@ class SparsePageDMatrix : public DMatrix { return iter; } - bool HaveColAccess() const override { - return 
col_iter_.get() != nullptr; + bool HaveColAccess(bool sorted) const override { + return col_iter_.get() != nullptr && col_iter_->sorted == sorted; } const RowSet& buffered_rowset() const override { @@ -67,7 +67,7 @@ class SparsePageDMatrix : public DMatrix { void InitColAccess(const std::vector& enabled, float subsample, - size_t max_row_perbatch) override; + size_t max_row_perbatch, bool sorted) override; /*! \brief page size 256 MB */ static const size_t kPageSize = 256UL << 20UL; @@ -87,6 +87,8 @@ class SparsePageDMatrix : public DMatrix { bool Next() override; // initialize the column iterator with the specified index set. void Init(const std::vector& index_set, bool load_all); + // If the column features are sorted + bool sorted; private: // the temp page. diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 4fa6ad230..adccf6239 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -9,92 +9,66 @@ #include #include #include +#include #include #include #include -#include #include +#include "../common/timer.h" namespace xgboost { namespace gbm { DMLC_REGISTRY_FILE_TAG(gblinear); -// model parameter -struct GBLinearModelParam :public dmlc::Parameter { - // number of feature dimension - unsigned num_feature; - // number of output group - int num_output_group; - // reserved field - int reserved[32]; - // constructor - GBLinearModelParam() { - std::memset(this, 0, sizeof(GBLinearModelParam)); - } - DMLC_DECLARE_PARAMETER(GBLinearModelParam) { - DMLC_DECLARE_FIELD(num_feature).set_lower_bound(0) - .describe("Number of features used in classification."); - DMLC_DECLARE_FIELD(num_output_group).set_lower_bound(1).set_default(1) - .describe("Number of output groups in the setting."); - } -}; - // training parameter struct GBLinearTrainParam : public dmlc::Parameter { /*! \brief learning_rate */ - float learning_rate; - /*! \brief regularization weight for L2 norm */ - float reg_lambda; - /*! \brief regularization weight for L1 norm */ - float reg_alpha; - /*! 
\brief regularization weight for L2 norm in bias */ - float reg_lambda_bias; + std::string updater; + // flag to print out detailed breakdown of runtime + int debug_verbose; + float tolerance; // declare parameters DMLC_DECLARE_PARAMETER(GBLinearTrainParam) { - DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(1.0f) - .describe("Learning rate of each update."); - DMLC_DECLARE_FIELD(reg_lambda).set_lower_bound(0.0f).set_default(0.0f) - .describe("L2 regularization on weights."); - DMLC_DECLARE_FIELD(reg_alpha).set_lower_bound(0.0f).set_default(0.0f) - .describe("L1 regularization on weights."); - DMLC_DECLARE_FIELD(reg_lambda_bias).set_lower_bound(0.0f).set_default(0.0f) - .describe("L2 regularization on bias."); - // alias of parameters - DMLC_DECLARE_ALIAS(learning_rate, eta); - DMLC_DECLARE_ALIAS(reg_lambda, lambda); - DMLC_DECLARE_ALIAS(reg_alpha, alpha); - DMLC_DECLARE_ALIAS(reg_lambda_bias, lambda_bias); - } - // given original weight calculate delta - inline double CalcDelta(double sum_grad, double sum_hess, double w) const { - if (sum_hess < 1e-5f) return 0.0f; - double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda); - if (tmp >=0) { - return std::max(-(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w); - } else { - return std::min(-(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w); - } - } - // given original weight calculate delta bias - inline double CalcDeltaBias(double sum_grad, double sum_hess, double w) const { - return - (sum_grad + reg_lambda_bias * w) / (sum_hess + reg_lambda_bias); + DMLC_DECLARE_FIELD(updater) + .set_default("shotgun") + .describe("Update algorithm for linear model. One of shotgun/coord_descent"); + DMLC_DECLARE_FIELD(tolerance) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("Stop if largest weight update is smaller than this number."); + DMLC_DECLARE_FIELD(debug_verbose) + .set_lower_bound(0) + .set_default(0) + .describe("flag to print out detailed breakdown of runtime"); } }; - /*! * \brief gradient boosted linear model */ class GBLinear : public GradientBooster { public: - explicit GBLinear(bst_float base_margin) - : base_margin_(base_margin) { + explicit GBLinear(const std::vector > &cache, + bst_float base_margin) + : base_margin_(base_margin), + sum_instance_weight(0), + sum_weight_complete(false), + is_converged(false) { + // Add matrices to the prediction cache + for (auto &d : cache) { + PredictionCacheEntry e; + e.data = d; + cache_[d.get()] = std::move(e); + } } void Configure(const std::vector >& cfg) override { if (model.weight.size() == 0) { model.param.InitAllowUnknown(cfg); } param.InitAllowUnknown(cfg); + updater.reset(LinearUpdater::Create(param.updater)); + updater->Init(cfg); + monitor.Init("GBLinear ", param.debug_verbose); } void Load(dmlc::Stream* fi) override { model.Load(fi); @@ -102,108 +76,44 @@ class GBLinear : public GradientBooster { void Save(dmlc::Stream* fo) const override { model.Save(fo); } - void DoBoost(DMatrix *p_fmat, - std::vector *in_gpair, - ObjFunction* obj) override { - // lazily initialize the model when not ready. 
- if (model.weight.size() == 0) { - model.InitModel(); + void DoBoost(DMatrix *p_fmat, std::vector *in_gpair, + ObjFunction *obj) override { + monitor.Start("DoBoost"); + + if (!p_fmat->HaveColAccess(false)) { + std::vector enabled(p_fmat->info().num_col, true); + p_fmat->InitColAccess(enabled, 1.0f, std::numeric_limits::max(), + false); } - std::vector &gpair = *in_gpair; - const int ngroup = model.param.num_output_group; - const RowSet &rowset = p_fmat->buffered_rowset(); - // for all the output group - for (int gid = 0; gid < ngroup; ++gid) { - double sum_grad = 0.0, sum_hess = 0.0; - const bst_omp_uint ndata = static_cast(rowset.size()); - #pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess) - for (bst_omp_uint i = 0; i < ndata; ++i) { - bst_gpair &p = gpair[rowset[i] * ngroup + gid]; - if (p.GetHess() >= 0.0f) { - sum_grad += p.GetGrad(); - sum_hess += p.GetHess(); - } - } - // remove bias effect - bst_float dw = static_cast( - param.learning_rate * param.CalcDeltaBias(sum_grad, sum_hess, model.bias()[gid])); - model.bias()[gid] += dw; - // update grad value - #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < ndata; ++i) { - bst_gpair &p = gpair[rowset[i] * ngroup + gid]; - if (p.GetHess() >= 0.0f) { - p += bst_gpair(p.GetHess() * dw, 0); - } - } - } - dmlc::DataIter *iter = p_fmat->ColIterator(); - while (iter->Next()) { - // number of features - const ColBatch &batch = iter->Value(); - const bst_omp_uint nfeat = static_cast(batch.size); - #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nfeat; ++i) { - const bst_uint fid = batch.col_index[i]; - ColBatch::Inst col = batch[i]; - for (int gid = 0; gid < ngroup; ++gid) { - double sum_grad = 0.0, sum_hess = 0.0; - for (bst_uint j = 0; j < col.length; ++j) { - const bst_float v = col[j].fvalue; - bst_gpair &p = gpair[col[j].index * ngroup + gid]; - if (p.GetHess() < 0.0f) continue; - sum_grad += p.GetGrad() * v; - sum_hess += p.GetHess() * v * v; - } - bst_float &w = model[fid][gid]; - bst_float dw = static_cast(param.learning_rate * - param.CalcDelta(sum_grad, sum_hess, w)); - w += dw; - // update grad value - for (bst_uint j = 0; j < col.length; ++j) { - bst_gpair &p = gpair[col[j].index * ngroup + gid]; - if (p.GetHess() < 0.0f) continue; - p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0); - } - } - } + model.LazyInitModel(); + + this->LazySumWeights(p_fmat); + + if (!this->CheckConvergence()) { + updater->Update(in_gpair, p_fmat, &model, sum_instance_weight); } + this->UpdatePredictionCache(); + + monitor.Stop("DoBoost"); } - void PredictBatch(DMatrix *p_fmat, - std::vector *out_preds, - unsigned ntree_limit) override { - if (model.weight.size() == 0) { - model.InitModel(); - } + void PredictBatch(DMatrix *p_fmat, std::vector *out_preds, + unsigned ntree_limit) override { + monitor.Start("PredictBatch"); CHECK_EQ(ntree_limit, 0U) << "GBLinear::Predict ntrees is only valid for gbtree predictor"; - std::vector &preds = *out_preds; - const std::vector& base_margin = p_fmat->info().base_margin; - preds.resize(0); - // start collecting the prediction - dmlc::DataIter *iter = p_fmat->RowIterator(); - const int ngroup = model.param.num_output_group; - while (iter->Next()) { - const RowBatch &batch = iter->Value(); - CHECK_EQ(batch.base_rowid * ngroup, preds.size()); - // output convention: nrow * k, where nrow is number of rows - // k is number of group - preds.resize(preds.size() + batch.size * ngroup); - // parallel over local batch - const omp_ulong nsize = 
static_cast(batch.size); - #pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < nsize; ++i) { - const size_t ridx = batch.base_rowid + i; - // loop over output groups - for (int gid = 0; gid < ngroup; ++gid) { - bst_float margin = (base_margin.size() != 0) ? - base_margin[ridx * ngroup + gid] : base_margin_; - this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); - } - } + + // Try to predict from cache + auto it = cache_.find(p_fmat); + if (it != cache_.end() && it->second.predictions.size() != 0) { + std::vector &y = it->second.predictions; + out_preds->resize(y.size()); + std::copy(y.begin(), y.end(), out_preds->begin()); + } else { + this->PredictBatchInternal(p_fmat, out_preds); } + monitor.Stop("PredictBatch"); } // add base margin void PredictInstance(const SparseBatch::Inst &inst, @@ -226,9 +136,7 @@ class GBLinear : public GradientBooster { std::vector* out_contribs, unsigned ntree_limit, bool approximate, int condition = 0, unsigned condition_feature = 0) override { - if (model.weight.size() == 0) { - model.InitModel(); - } + model.LazyInitModel(); CHECK_EQ(ntree_limit, 0U) << "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor"; const std::vector& base_margin = p_fmat->info().base_margin; @@ -317,7 +225,74 @@ class GBLinear : public GradientBooster { } protected: - inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid, bst_float base) { + void PredictBatchInternal(DMatrix *p_fmat, + std::vector *out_preds) { + monitor.Start("PredictBatchInternal"); + model.LazyInitModel(); + std::vector &preds = *out_preds; + const std::vector& base_margin = p_fmat->info().base_margin; + // start collecting the prediction + dmlc::DataIter *iter = p_fmat->RowIterator(); + const int ngroup = model.param.num_output_group; + preds.resize(p_fmat->info().num_row * ngroup); + while (iter->Next()) { + const RowBatch &batch = iter->Value(); + // output convention: nrow * k, where nrow is number of rows + // k is number of group + // parallel over local batch + const omp_ulong nsize = static_cast(batch.size); + #pragma omp parallel for schedule(static) + for (omp_ulong i = 0; i < nsize; ++i) { + const size_t ridx = batch.base_rowid + i; + // loop over output groups + for (int gid = 0; gid < ngroup; ++gid) { + bst_float margin = (base_margin.size() != 0) ? 
+ base_margin[ridx * ngroup + gid] : base_margin_; + this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); + } + } + } + monitor.Stop("PredictBatchInternal"); + } + void UpdatePredictionCache() { + // update cache entry + for (auto &kv : cache_) { + PredictionCacheEntry &e = kv.second; + if (e.predictions.size() == 0) { + size_t n = model.param.num_output_group * e.data->info().num_row; + e.predictions.resize(n); + } + this->PredictBatchInternal(e.data.get(), &e.predictions); + } + } + + bool CheckConvergence() { + if (param.tolerance == 0.0f) return false; + if (is_converged) return true; + if (previous_model.weight.size() != model.weight.size()) return false; + float largest_dw = 0.0; + for (auto i = 0; i < model.weight.size(); i++) { + largest_dw = std::max( + largest_dw, std::abs(model.weight[i] - previous_model.weight[i])); + } + previous_model = model; + + is_converged = largest_dw <= param.tolerance; + return is_converged; + } + + void LazySumWeights(DMatrix *p_fmat) { + if (!sum_weight_complete) { + auto &info = p_fmat->info(); + for (int i = 0; i < info.num_row; i++) { + sum_instance_weight += info.GetWeight(i); + } + sum_weight_complete = true; + } + } + + inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid, + bst_float base) { bst_float psum = model.bias()[gid] + base; for (bst_uint i = 0; i < inst.length; ++i) { if (inst[i].index >= model.param.num_feature) continue; @@ -325,52 +300,33 @@ class GBLinear : public GradientBooster { } preds[gid] = psum; } - // model for linear booster - class Model { - public: - // parameter - GBLinearModelParam param; - // weight for each of feature, bias is the last one - std::vector weight; - // initialize the model parameter - inline void InitModel(void) { - // bias is the last weight - weight.resize((param.num_feature + 1) * param.num_output_group); - std::fill(weight.begin(), weight.end(), 0.0f); - } - // save the model to file - inline void Save(dmlc::Stream* fo) const { - fo->Write(¶m, sizeof(param)); - fo->Write(weight); - } - // load model from file - inline void Load(dmlc::Stream* fi) { - CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param)); - fi->Read(&weight); - } - // model bias - inline bst_float* bias() { - return &weight[param.num_feature * param.num_output_group]; - } - inline const bst_float* bias() const { - return &weight[param.num_feature * param.num_output_group]; - } - // get i-th weight - inline bst_float* operator[](size_t i) { - return &weight[i * param.num_output_group]; - } - inline const bst_float* operator[](size_t i) const { - return &weight[i * param.num_output_group]; - } - }; // biase margin score bst_float base_margin_; // model field - Model model; - // training parameter + GBLinearModel model; + GBLinearModel previous_model; GBLinearTrainParam param; - // Per feature: shuffle index of each feature index - std::vector feat_index; + std::unique_ptr updater; + double sum_instance_weight; + bool sum_weight_complete; + common::Monitor monitor; + bool is_converged; + + /** + * \struct PredictionCacheEntry + * + * \brief Contains pointer to input matrix and associated cached predictions. + */ + struct PredictionCacheEntry { + std::shared_ptr data; + std::vector predictions; + }; + + /** + * \brief Map of matrices and associated cached predictions to facilitate + * storing and looking up predictions. 
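+   *
+   * Keyed by the raw DMatrix pointer of each matrix handed to the booster at
+   * construction. Entries are refreshed by UpdatePredictionCache() at the end
+   * of DoBoost(), which lets PredictBatch() serve those matrices directly
+   * from the cache instead of recomputing their predictions.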
+ */ + std::unordered_map cache_; }; // register the objective functions @@ -378,9 +334,10 @@ DMLC_REGISTER_PARAMETER(GBLinearModelParam); DMLC_REGISTER_PARAMETER(GBLinearTrainParam); XGBOOST_REGISTER_GBM(GBLinear, "gblinear") -.describe("Linear booster, implement generalized linear model.") -.set_body([](const std::vector >&cache, bst_float base_margin) { - return new GBLinear(base_margin); - }); + .describe("Linear booster, implement generalized linear model.") + .set_body([](const std::vector > &cache, + bst_float base_margin) { + return new GBLinear(cache, base_margin); + }); } // namespace gbm } // namespace xgboost diff --git a/src/gbm/gblinear_model.h b/src/gbm/gblinear_model.h new file mode 100644 index 000000000..72fcedb80 --- /dev/null +++ b/src/gbm/gblinear_model.h @@ -0,0 +1,73 @@ +/*! + * Copyright by Contributors 2018 + */ +#pragma once +#include +#include +#include +#include + +namespace xgboost { +namespace gbm { +// model parameter +struct GBLinearModelParam : public dmlc::Parameter { + // number of feature dimension + unsigned num_feature; + // number of output group + int num_output_group; + // reserved field + int reserved[32]; + // constructor + GBLinearModelParam() { std::memset(this, 0, sizeof(GBLinearModelParam)); } + DMLC_DECLARE_PARAMETER(GBLinearModelParam) { + DMLC_DECLARE_FIELD(num_feature) + .set_lower_bound(0) + .describe("Number of features used in classification."); + DMLC_DECLARE_FIELD(num_output_group) + .set_lower_bound(1) + .set_default(1) + .describe("Number of output groups in the setting."); + } +}; + +// model for linear booster +class GBLinearModel { + public: + // parameter + GBLinearModelParam param; + // weight for each of feature, bias is the last one + std::vector weight; + // initialize the model parameter + inline void LazyInitModel(void) { + if (!weight.empty()) return; + // bias is the last weight + weight.resize((param.num_feature + 1) * param.num_output_group); + std::fill(weight.begin(), weight.end(), 0.0f); + } + // save the model to file + inline void Save(dmlc::Stream* fo) const { + fo->Write(¶m, sizeof(param)); + fo->Write(weight); + } + // load model from file + inline void Load(dmlc::Stream* fi) { + CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param)); + fi->Read(&weight); + } + // model bias + inline bst_float* bias() { + return &weight[param.num_feature * param.num_output_group]; + } + inline const bst_float* bias() const { + return &weight[param.num_feature * param.num_output_group]; + } + // get i-th weight + inline bst_float* operator[](size_t i) { + return &weight[i * param.num_output_group]; + } + inline const bst_float* operator[](size_t i) const { + return &weight[i * param.num_output_group]; + } +}; +} // namespace gbm +} // namespace xgboost diff --git a/src/learner.cc b/src/learner.cc index cc93c34ff..3f13ffba8 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -464,18 +464,18 @@ class LearnerImpl : public Learner { // if not, initialize the column access. 
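+  // gblinear is excluded here on purpose: GBLinear::DoBoost() lazily builds
+  // its own unsorted column access (InitColAccess with sorted=false), so the
+  // sorted column initialisation below is not needed for it.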
inline void LazyInitDMatrix(DMatrix* p_train) { if (tparam.tree_method == 3 || tparam.tree_method == 4 || - tparam.tree_method == 5) { + tparam.tree_method == 5 || name_gbm_ == "gblinear") { return; } monitor.Start("LazyInitDMatrix"); - if (!p_train->HaveColAccess()) { + if (!p_train->HaveColAccess(true)) { int ncol = static_cast(p_train->info().num_col); std::vector enabled(ncol, true); // set max row per batch to limited value // in distributed mode, use safe choice otherwise size_t max_row_perbatch = tparam.max_row_perbatch; - const size_t safe_max_row = static_cast(32UL << 10UL); + const size_t safe_max_row = static_cast(32ul << 10ul); if (tparam.tree_method == 0 && p_train->info().num_row >= (4UL << 20UL)) { LOG(CONSOLE) @@ -495,7 +495,7 @@ class LearnerImpl : public Learner { max_row_perbatch = std::min(max_row_perbatch, safe_max_row); } // initialize column access - p_train->InitColAccess(enabled, tparam.prob_buffer_row, max_row_perbatch); + p_train->InitColAccess(enabled, tparam.prob_buffer_row, max_row_perbatch, true); } if (!p_train->SingleColBlock() && cfg_.count("updater") == 0) { diff --git a/src/linear/coordinate_common.h b/src/linear/coordinate_common.h new file mode 100644 index 000000000..41955e4c7 --- /dev/null +++ b/src/linear/coordinate_common.h @@ -0,0 +1,321 @@ +/*! + * Copyright 2018 by Contributors + * \author Rory Mitchell + */ +#pragma once +#include +#include +#include +#include +#include "../common/random.h" + +namespace xgboost { +namespace linear { + +/** + * \brief Calculate change in weight for a given feature. Applies l1/l2 penalty normalised by the + * number of training instances. + * + * \param sum_grad The sum gradient. + * \param sum_hess The sum hess. + * \param w The weight. + * \param reg_lambda Unnormalised L2 penalty. + * \param reg_alpha Unnormalised L1 penalty. + * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty. + * + * \return The weight update. + */ + +inline double CoordinateDelta(double sum_grad, double sum_hess, double w, + double reg_lambda, double reg_alpha, + double sum_instance_weight) { + reg_alpha *= sum_instance_weight; + reg_lambda *= sum_instance_weight; + if (sum_hess < 1e-5f) return 0.0f; + double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda); + if (tmp >= 0) { + return std::max( + -(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w); + } else { + return std::min( + -(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w); + } +} + +/** + * \brief Calculate update to bias. + * + * \param sum_grad The sum gradient. + * \param sum_hess The sum hess. + * + * \return The weight update. + */ + +inline double CoordinateDeltaBias(double sum_grad, double sum_hess) { + return -sum_grad / sum_hess; +} + +/** + * \brief Get the gradient with respect to a single feature. + * + * \param group_idx Zero-based index of the group. + * \param num_group Number of groups. + * \param fidx The target feature. + * \param gpair Gradients. + * \param p_fmat The feature matrix. + * + * \return The gradient and diagonal Hessian entry for a given feature. 
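+ *
+ * Concretely, for feature j and output group k this accumulates
+ *   G = sum_i g_ik * x_ij   and   H = sum_i h_ik * x_ij^2
+ * over the stored (non-zero) entries x_ij of column j, skipping rows whose
+ * hessian is negative. The pair (G, H) is what CoordinateDelta() consumes.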
+ */ + +inline std::pair GetGradient( + int group_idx, int num_group, int fidx, const std::vector &gpair, + DMatrix *p_fmat) { + double sum_grad = 0.0, sum_hess = 0.0; + dmlc::DataIter *iter = p_fmat->ColIterator(); + while (iter->Next()) { + const ColBatch &batch = iter->Value(); + ColBatch::Inst col = batch[fidx]; + const bst_omp_uint ndata = static_cast(col.length); + for (bst_omp_uint j = 0; j < ndata; ++j) { + const bst_float v = col[j].fvalue; + auto &p = gpair[col[j].index * num_group + group_idx]; + if (p.GetHess() < 0.0f) continue; + sum_grad += p.GetGrad() * v; + sum_hess += p.GetHess() * v * v; + } + } + return std::make_pair(sum_grad, sum_hess); +} + +/** + * \brief Get the gradient with respect to a single feature. Multithreaded. + * + * \param group_idx Zero-based index of the group. + * \param num_group Number of groups. + * \param fidx The target feature. + * \param gpair Gradients. + * \param p_fmat The feature matrix. + * + * \return The gradient and diagonal Hessian entry for a given feature. + */ + +inline std::pair GetGradientParallel( + int group_idx, int num_group, int fidx, + + const std::vector &gpair, DMatrix *p_fmat) { + double sum_grad = 0.0, sum_hess = 0.0; + dmlc::DataIter *iter = p_fmat->ColIterator(); + while (iter->Next()) { + const ColBatch &batch = iter->Value(); + ColBatch::Inst col = batch[fidx]; + const bst_omp_uint ndata = static_cast(col.length); +#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess) + for (bst_omp_uint j = 0; j < ndata; ++j) { + const bst_float v = col[j].fvalue; + auto &p = gpair[col[j].index * num_group + group_idx]; + if (p.GetHess() < 0.0f) continue; + sum_grad += p.GetGrad() * v; + sum_hess += p.GetHess() * v * v; + } + } + return std::make_pair(sum_grad, sum_hess); +} + +/** + * \brief Get the gradient with respect to the bias. Multithreaded. + * + * \param group_idx Zero-based index of the group. + * \param num_group Number of groups. + * \param gpair Gradients. + * \param p_fmat The feature matrix. + * + * \return The gradient and diagonal Hessian entry for the bias. + */ + +inline std::pair GetBiasGradientParallel( + int group_idx, int num_group, const std::vector &gpair, + DMatrix *p_fmat) { + const RowSet &rowset = p_fmat->buffered_rowset(); + double sum_grad = 0.0, sum_hess = 0.0; + const bst_omp_uint ndata = static_cast(rowset.size()); +#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess) + for (bst_omp_uint i = 0; i < ndata; ++i) { + auto &p = gpair[rowset[i] * num_group + group_idx]; + if (p.GetHess() >= 0.0f) { + sum_grad += p.GetGrad(); + sum_hess += p.GetHess(); + } + } + return std::make_pair(sum_grad, sum_hess); +} + +/** + * \brief Updates the gradient vector with respect to a change in weight. + * + * \param fidx The feature index. + * \param group_idx Zero-based index of the group. + * \param num_group Number of groups. + * \param dw The change in weight. + * \param in_gpair The gradient vector to be updated. + * \param p_fmat The input feature matrix. 
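+ *
+ * In symbols, after weight j of group k changes by dw, every row i with a
+ * stored entry x_ij has its gradient updated in place as
+ *   g_ik <- g_ik + h_ik * x_ij * dw,
+ * while the hessian h_ik is left unchanged (rows with negative hessian are
+ * skipped).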
+ */ + +inline void UpdateResidualParallel(int fidx, int group_idx, int num_group, + float dw, std::vector *in_gpair, + DMatrix *p_fmat) { + if (dw == 0.0f) return; + dmlc::DataIter *iter = p_fmat->ColIterator(); + while (iter->Next()) { + const ColBatch &batch = iter->Value(); + ColBatch::Inst col = batch[fidx]; + // update grad value + const bst_omp_uint num_row = static_cast(col.length); +#pragma omp parallel for schedule(static) + for (bst_omp_uint j = 0; j < num_row; ++j) { + bst_gpair &p = (*in_gpair)[col[j].index * num_group + group_idx]; + if (p.GetHess() < 0.0f) continue; + p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0); + } + } +} + +/** + * \brief Updates the gradient vector based on a change in the bias. + * + * \param group_idx Zero-based index of the group. + * \param num_group Number of groups. + * \param dbias The change in bias. + * \param in_gpair The gradient vector to be updated. + * \param p_fmat The input feature matrix. + */ + +inline void UpdateBiasResidualParallel(int group_idx, int num_group, + float dbias, + std::vector *in_gpair, + DMatrix *p_fmat) { + if (dbias == 0.0f) return; + const RowSet &rowset = p_fmat->buffered_rowset(); + const bst_omp_uint ndata = static_cast(p_fmat->info().num_row); +#pragma omp parallel for schedule(static) + for (bst_omp_uint i = 0; i < ndata; ++i) { + bst_gpair &g = (*in_gpair)[rowset[i] * num_group + group_idx]; + if (g.GetHess() < 0.0f) continue; + g += bst_gpair(g.GetHess() * dbias, 0); + } +} + +/** + * \class FeatureSelector + * + * \brief Abstract class for stateful feature selection in coordinate descent + * algorithms. + */ + +class FeatureSelector { + public: + static FeatureSelector *Create(std::string name); + /*! \brief virtual destructor */ + virtual ~FeatureSelector() {} + + /** + * \brief Select next coordinate to update. + * + * \param iteration The iteration. + * \param model The model. + * \param group_idx Zero-based index of the group. + * \param gpair The gpair. + * \param p_fmat The feature matrix. + * \param alpha Regularisation alpha. + * \param lambda Regularisation lambda. + * \param sum_instance_weight The sum instance weight. + * + * \return The index of the selected feature. -1 indicates the bias term. + */ + + virtual int SelectNextFeature(int iteration, + const gbm::GBLinearModel &model, + int group_idx, + const std::vector &gpair, + DMatrix *p_fmat, float alpha, float lambda, + double sum_instance_weight) = 0; +}; + +/** + * \class CyclicFeatureSelector + * + * \brief Deterministic selection by cycling through coordinates one at a time. + */ + +class CyclicFeatureSelector : public FeatureSelector { + public: + int SelectNextFeature(int iteration, const gbm::GBLinearModel &model, + int group_idx, const std::vector &gpair, + DMatrix *p_fmat, float alpha, float lambda, + double sum_instance_weight) override { + return iteration % model.param.num_feature; + } +}; + +/** + * \class RandomFeatureSelector + * + * \brief A random coordinate selector. + */ + +class RandomFeatureSelector : public FeatureSelector { + public: + int SelectNextFeature(int iteration, const gbm::GBLinearModel &model, + int group_idx, const std::vector &gpair, + DMatrix *p_fmat, float alpha, float lambda, + double sum_instance_weight) override { + return common::GlobalRandom()() % model.param.num_feature; + } +}; + +/** + * \class GreedyFeatureSelector + * + * \brief Select coordinate with the greatest gradient magnitude. 
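+ *
+ * Each call evaluates the proposed CoordinateDelta() update for every feature
+ * and returns the index with the largest absolute update, so a single
+ * selection costs a full pass over all columns; cyclic and random selection
+ * are O(1) by comparison.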
+ */ + +class GreedyFeatureSelector : public FeatureSelector { + public: + int SelectNextFeature(int iteration, const gbm::GBLinearModel &model, + int group_idx, const std::vector &gpair, + DMatrix *p_fmat, float alpha, float lambda, + double sum_instance_weight) override { + // Find best + int best_fidx = 0; + double best_weight_update = 0.0f; + + for (auto fidx = 0U; fidx < model.param.num_feature; fidx++) { + const float w = model[fidx][group_idx]; + auto gradient = GetGradientParallel( + group_idx, model.param.num_output_group, fidx, gpair, p_fmat); + float dw = static_cast( + CoordinateDelta(gradient.first, gradient.second, w, lambda, alpha, + sum_instance_weight)); + if (std::abs(dw) > std::abs(best_weight_update)) { + best_weight_update = dw; + best_fidx = fidx; + } + } + return best_fidx; + } +}; + +inline FeatureSelector *FeatureSelector::Create(std::string name) { + if (name == "cyclic") { + return new CyclicFeatureSelector(); + } else if (name == "random") { + return new RandomFeatureSelector(); + } else if (name == "greedy") { + return new GreedyFeatureSelector(); + } else { + LOG(FATAL) << name << ": unknown coordinate selector"; + } + return nullptr; +} + +} // namespace linear +} // namespace xgboost diff --git a/src/linear/linear_updater.cc b/src/linear/linear_updater.cc new file mode 100644 index 000000000..9041a57f3 --- /dev/null +++ b/src/linear/linear_updater.cc @@ -0,0 +1,29 @@ +/*! + * Copyright 2018 + */ +#include +#include + +namespace dmlc { +DMLC_REGISTRY_ENABLE(::xgboost::LinearUpdaterReg); +} // namespace dmlc + +namespace xgboost { + +LinearUpdater* LinearUpdater::Create(const std::string& name) { + auto *e = ::dmlc::Registry< ::xgboost::LinearUpdaterReg>::Get()->Find(name); + if (e == nullptr) { + LOG(FATAL) << "Unknown linear updater " << name; + } + return (e->body)(); +} + +} // namespace xgboost + +namespace xgboost { +namespace linear { +// List of files that will be force linked in static links. +DMLC_REGISTRY_LINK_TAG(updater_shotgun); +DMLC_REGISTRY_LINK_TAG(updater_coordinate); +} // namespace linear +} // namespace xgboost diff --git a/src/linear/updater_coordinate.cc b/src/linear/updater_coordinate.cc new file mode 100644 index 000000000..4f8a58b55 --- /dev/null +++ b/src/linear/updater_coordinate.cc @@ -0,0 +1,124 @@ +/*! + * Copyright 2018 by Contributors + * \author Rory Mitchell + */ + +#include +#include "../common/timer.h" +#include "coordinate_common.h" + +namespace xgboost { +namespace linear { + +DMLC_REGISTRY_FILE_TAG(updater_coordinate); + +// training parameter +struct CoordinateTrainParam : public dmlc::Parameter { + /*! \brief learning_rate */ + float learning_rate; + /*! \brief regularization weight for L2 norm */ + float reg_lambda; + /*! 
\brief regularization weight for L1 norm */ + float reg_alpha; + std::string feature_selector; + float maximum_weight; + int debug_verbose; + // declare parameters + DMLC_DECLARE_PARAMETER(CoordinateTrainParam) { + DMLC_DECLARE_FIELD(learning_rate) + .set_lower_bound(0.0f) + .set_default(1.0f) + .describe("Learning rate of each update."); + DMLC_DECLARE_FIELD(reg_lambda) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("L2 regularization on weights."); + DMLC_DECLARE_FIELD(reg_alpha) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("L1 regularization on weights."); + DMLC_DECLARE_FIELD(feature_selector) + .set_default("cyclic") + .describe( + "Feature selection algorithm, one of cyclic/random/greedy"); + DMLC_DECLARE_FIELD(debug_verbose) + .set_lower_bound(0) + .set_default(0) + .describe("flag to print out detailed breakdown of runtime"); + // alias of parameters + DMLC_DECLARE_ALIAS(reg_lambda, lambda); + DMLC_DECLARE_ALIAS(reg_alpha, alpha); + } +}; + +/** + * \class CoordinateUpdater + * + * \brief Coordinate descent algorithm that updates one feature per iteration + */ + +class CoordinateUpdater : public LinearUpdater { + public: + // set training parameter + void Init( + const std::vector > &args) override { + param.InitAllowUnknown(args); + selector.reset(FeatureSelector::Create(param.feature_selector)); + monitor.Init("CoordinateUpdater", param.debug_verbose); + } + void Update(std::vector *in_gpair, DMatrix *p_fmat, + gbm::GBLinearModel *model, double sum_instance_weight) override { + // Calculate bias + for (int group_idx = 0; group_idx < model->param.num_output_group; + ++group_idx) { + auto grad = GetBiasGradientParallel( + group_idx, model->param.num_output_group, *in_gpair, p_fmat); + auto dbias = static_cast( + param.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); + model->bias()[group_idx] += dbias; + UpdateBiasResidualParallel(group_idx, model->param.num_output_group, + dbias, in_gpair, p_fmat); + } + for (int group_idx = 0; group_idx < model->param.num_output_group; + ++group_idx) { + for (auto i = 0U; i < model->param.num_feature; i++) { + int fidx = selector->SelectNextFeature( + i, *model, group_idx, *in_gpair, p_fmat, param.reg_alpha, + param.reg_lambda, sum_instance_weight); + this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model, + sum_instance_weight); + } + } + } + + void UpdateFeature(int fidx, int group_idx, std::vector *in_gpair, + DMatrix *p_fmat, gbm::GBLinearModel *model, + double sum_instance_weight) { + bst_float &w = (*model)[fidx][group_idx]; + monitor.Start("GetGradientParallel"); + auto gradient = GetGradientParallel( + group_idx, model->param.num_output_group, fidx, *in_gpair, p_fmat); + monitor.Stop("GetGradientParallel"); + auto dw = static_cast( + param.learning_rate * + CoordinateDelta(gradient.first, gradient.second, w, param.reg_lambda, + param.reg_alpha, sum_instance_weight)); + w += dw; + monitor.Start("UpdateResidualParallel"); + UpdateResidualParallel(fidx, group_idx, model->param.num_output_group, dw, + in_gpair, p_fmat); + monitor.Stop("UpdateResidualParallel"); + } + + // training parameter + CoordinateTrainParam param; + std::unique_ptr selector; + common::Monitor monitor; +}; + +DMLC_REGISTER_PARAMETER(CoordinateTrainParam); +XGBOOST_REGISTER_LINEAR_UPDATER(CoordinateUpdater, "coord_descent") + .describe("Update linear model according to coordinate descent algorithm.") + .set_body([]() { return new CoordinateUpdater(); }); +} // namespace linear +} // namespace xgboost diff --git 
a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc new file mode 100644 index 000000000..02d740031 --- /dev/null +++ b/src/linear/updater_shotgun.cc @@ -0,0 +1,127 @@ +/*! + * Copyright 2018 by Contributors + * \author Tianqi Chen, Rory Mitchell + */ + +#include +#include "coordinate_common.h" + +namespace xgboost { +namespace linear { + +DMLC_REGISTRY_FILE_TAG(updater_shotgun); + +// training parameter +struct ShotgunTrainParam : public dmlc::Parameter { + /*! \brief learning_rate */ + float learning_rate; + /*! \brief regularization weight for L2 norm */ + float reg_lambda; + /*! \brief regularization weight for L1 norm */ + float reg_alpha; + // declare parameters + DMLC_DECLARE_PARAMETER(ShotgunTrainParam) { + DMLC_DECLARE_FIELD(learning_rate) + .set_lower_bound(0.0f) + .set_default(1.0f) + .describe("Learning rate of each update."); + DMLC_DECLARE_FIELD(reg_lambda) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("L2 regularization on weights."); + DMLC_DECLARE_FIELD(reg_alpha) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("L1 regularization on weights."); + // alias of parameters + DMLC_DECLARE_ALIAS(learning_rate, eta); + DMLC_DECLARE_ALIAS(reg_lambda, lambda); + DMLC_DECLARE_ALIAS(reg_alpha, alpha); + } +}; + +class ShotgunUpdater : public LinearUpdater { + public: + // set training parameter + void Init( + const std::vector > &args) override { + param.InitAllowUnknown(args); + } + void Update(std::vector *in_gpair, DMatrix *p_fmat, + gbm::GBLinearModel *model, double sum_instance_weight) override { + std::vector &gpair = *in_gpair; + const int ngroup = model->param.num_output_group; + const RowSet &rowset = p_fmat->buffered_rowset(); + // for all the output group + for (int gid = 0; gid < ngroup; ++gid) { + double sum_grad = 0.0, sum_hess = 0.0; + const bst_omp_uint ndata = static_cast(rowset.size()); +#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess) + for (bst_omp_uint i = 0; i < ndata; ++i) { + bst_gpair &p = gpair[rowset[i] * ngroup + gid]; + if (p.GetHess() >= 0.0f) { + sum_grad += p.GetGrad(); + sum_hess += p.GetHess(); + } + } + // remove bias effect + bst_float dw = static_cast( + param.learning_rate * CoordinateDeltaBias(sum_grad, sum_hess)); + model->bias()[gid] += dw; +// update grad value +#pragma omp parallel for schedule(static) + for (bst_omp_uint i = 0; i < ndata; ++i) { + bst_gpair &p = gpair[rowset[i] * ngroup + gid]; + if (p.GetHess() >= 0.0f) { + p += bst_gpair(p.GetHess() * dw, 0); + } + } + } + dmlc::DataIter *iter = p_fmat->ColIterator(); + while (iter->Next()) { + // number of features + const ColBatch &batch = iter->Value(); + const bst_omp_uint nfeat = static_cast(batch.size); +#pragma omp parallel for schedule(static) + for (bst_omp_uint i = 0; i < nfeat; ++i) { + const bst_uint fid = batch.col_index[i]; + ColBatch::Inst col = batch[i]; + for (int gid = 0; gid < ngroup; ++gid) { + double sum_grad = 0.0, sum_hess = 0.0; + for (bst_uint j = 0; j < col.length; ++j) { + const bst_float v = col[j].fvalue; + bst_gpair &p = gpair[col[j].index * ngroup + gid]; + if (p.GetHess() < 0.0f) continue; + sum_grad += p.GetGrad() * v; + sum_hess += p.GetHess() * v * v; + } + bst_float &w = (*model)[fid][gid]; + bst_float dw = static_cast( + param.learning_rate * + CoordinateDelta(sum_grad, sum_hess, w, param.reg_lambda, + param.reg_alpha, sum_instance_weight)); + w += dw; + // update grad value + for (bst_uint j = 0; j < col.length; ++j) { + bst_gpair &p = gpair[col[j].index * ngroup + gid]; + if (p.GetHess() < 
0.0f) continue;
+            p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0);
+          }
+        }
+      }
+    }
+  }
+
+  // training parameter
+  ShotgunTrainParam param;
+};
+
+DMLC_REGISTER_PARAMETER(ShotgunTrainParam);
+
+XGBOOST_REGISTER_LINEAR_UPDATER(ShotgunUpdater, "shotgun")
+    .describe(
+        "Update linear model according to shotgun coordinate descent "
+        "algorithm.")
+    .set_body([]() { return new ShotgunUpdater(); });
+}  // namespace linear
+}  // namespace xgboost
diff --git a/tests/benchmark/benchmark_linear.py b/tests/benchmark/benchmark_linear.py
new file mode 100644
index 000000000..561a531d8
--- /dev/null
+++ b/tests/benchmark/benchmark_linear.py
@@ -0,0 +1,69 @@
+#pylint: skip-file
+import sys, argparse
+import xgboost as xgb
+import numpy as np
+from sklearn.datasets import make_classification
+from sklearn.model_selection import train_test_split
+import time
+import ast
+
+rng = np.random.RandomState(1994)
+
+
+def run_benchmark(args):
+
+    try:
+        dtest = xgb.DMatrix('dtest.dm')
+        dtrain = xgb.DMatrix('dtrain.dm')
+
+        if not (dtest.num_col() == args.columns \
+                and dtrain.num_col() == args.columns):
+            raise ValueError("Wrong cols")
+        if not (dtest.num_row() == args.rows * args.test_size \
+                and dtrain.num_row() == args.rows * (1-args.test_size)):
+            raise ValueError("Wrong rows")
+    except:
+
+        print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
+        print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
+        tmp = time.time()
+        X, y = make_classification(args.rows, n_features=args.columns, n_redundant=0, n_informative=args.columns, n_repeated=0, random_state=7)
+        if args.sparsity < 1.0:
+            X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X])
+
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7)
+        print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
+        tmp = time.time()
+        print ("DMatrix Start")
+        dtrain = xgb.DMatrix(X_train, y_train)
+        dtest = xgb.DMatrix(X_test, y_test, nthread=-1)
+        print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
+
+        dtest.save_binary('dtest.dm')
+        dtrain.save_binary('dtrain.dm')
+
+    param = {'objective': 'binary:logistic','booster':'gblinear'}
+    if args.params != '':
+        param.update(ast.literal_eval(args.params))
+
+    param['updater'] = args.updater
+    print("Training with '%s'" % param['updater'])
+    tmp = time.time()
+    xgb.train(param, dtrain, args.iterations, evals=[(dtrain,"train")], early_stopping_rounds = args.columns)
+    print ("Train Time: %s seconds" % (str(time.time() - tmp)))
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--updater', default='coord_descent')
+parser.add_argument('--sparsity', type=float, default=0.0)
+parser.add_argument('--lambda', type=float, default=1.0)
+parser.add_argument('--tol', type=float, default=1e-5)
+parser.add_argument('--alpha', type=float, default=1.0)
+parser.add_argument('--rows', type=int, default=1000000)
+parser.add_argument('--iterations', type=int, default=10000)
+parser.add_argument('--columns', type=int, default=50)
+parser.add_argument('--test_size', type=float, default=0.25)
+parser.add_argument('--standardise', type=bool, default=False)
+parser.add_argument('--params', default='', help='Provide additional parameters as a Python dict string, e.g. --params \"{\'max_depth\':2}\"')
+args = parser.parse_args()
+
+run_benchmark(args)
diff --git a/tests/benchmark/benchmark.py b/tests/benchmark/benchmark_tree.py
similarity index 100%
rename from tests/benchmark/benchmark.py
rename to tests/benchmark/benchmark_tree.py
diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc
index be42019ea..f13d7b2f9 100644
--- a/tests/cpp/data/test_simple_dmatrix.cc
+++ b/tests/cpp/data/test_simple_dmatrix.cc
@@ -42,11 +42,18 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
   xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
   std::remove(tmp_file.c_str());
 
-  EXPECT_EQ(dmat->HaveColAccess(), false);
+  // Unsorted column access
   const std::vector<bool> enable(dmat->info().num_col, true);
-  dmat->InitColAccess(enable, 1, dmat->info().num_row);
-  dmat->InitColAccess(enable, 0, 0);  // Calling it again should not change it
-  ASSERT_EQ(dmat->HaveColAccess(), true);
+  EXPECT_EQ(dmat->HaveColAccess(false), false);
+  dmat->InitColAccess(enable, 1, dmat->info().num_row, false);
+  dmat->InitColAccess(enable, 0, 0, false);  // Calling it again should not change it
+  ASSERT_EQ(dmat->HaveColAccess(false), true);
+
+  // Sorted column access
+  EXPECT_EQ(dmat->HaveColAccess(true), false);
+  dmat->InitColAccess(enable, 1, dmat->info().num_row, true);
+  dmat->InitColAccess(enable, 0, 0, true);  // Calling it again should not change it
+  ASSERT_EQ(dmat->HaveColAccess(true), true);
 
   EXPECT_EQ(dmat->GetColSize(0), 2);
   EXPECT_EQ(dmat->GetColSize(1), 1);
@@ -86,11 +93,18 @@ TEST(SimpleDMatrix, ColAccessWithBatches) {
   xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
   std::remove(tmp_file.c_str());
 
-  EXPECT_EQ(dmat->HaveColAccess(), false);
+  // Unsorted column access
   const std::vector<bool> enable(dmat->info().num_col, true);
-  dmat->InitColAccess(enable, 1, 1);  // Max 1 row per patch
-  dmat->InitColAccess(enable, 0, 0);  // Calling it again should not change it
-  ASSERT_EQ(dmat->HaveColAccess(), true);
+  EXPECT_EQ(dmat->HaveColAccess(false), false);
+  dmat->InitColAccess(enable, 1, 1, false);
+  dmat->InitColAccess(enable, 0, 0, false);  // Calling it again should not change it
+  ASSERT_EQ(dmat->HaveColAccess(false), true);
+
+  // Sorted column access
+  EXPECT_EQ(dmat->HaveColAccess(true), false);
+  dmat->InitColAccess(enable, 1, 1, true);  // Max 1 row per patch
+  dmat->InitColAccess(enable, 0, 0, true);  // Calling it again should not change it
+  ASSERT_EQ(dmat->HaveColAccess(true), true);
 
   EXPECT_EQ(dmat->GetColSize(0), 2);
   EXPECT_EQ(dmat->GetColSize(1), 1);
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc
index 82957dcee..6d826f0e8 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cc
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cc
@@ -56,10 +56,10 @@ TEST(SparsePageDMatrix, ColAcess) {
   std::remove(tmp_file.c_str());
   EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page"));
 
-  EXPECT_EQ(dmat->HaveColAccess(), false);
+  EXPECT_EQ(dmat->HaveColAccess(true), false);
   const std::vector<bool> enable(dmat->info().num_col, true);
-  dmat->InitColAccess(enable, 1, 1);  // Max 1 row per patch
-  ASSERT_EQ(dmat->HaveColAccess(), true);
+  dmat->InitColAccess(enable, 1, 1, true);  // Max 1 row per patch
+  ASSERT_EQ(dmat->HaveColAccess(true), true);
   EXPECT_TRUE(FileExists(tmp_file + ".cache.col.page"));
 
   EXPECT_EQ(dmat->GetColSize(0), 2);
diff --git a/tests/cpp/linear/test_linear.cc b/tests/cpp/linear/test_linear.cc
new file mode 100644
index 000000000..a5f7756c0
--- /dev/null
+++ b/tests/cpp/linear/test_linear.cc
@@ -0,0 +1,44 @@
+// Copyright by Contributors
+#include <xgboost/linear_updater.h>
+#include "../helpers.h"
+#include "xgboost/gbm.h"
+
+typedef std::pair<std::string, std::string> arg;
+
+TEST(Linear, shotgun) {
+  typedef std::pair<std::string, std::string> arg;
+  auto mat = CreateDMatrix(10, 10, 0);
+  std::vector<bool> enabled(mat->info().num_col, true);
+  mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
+  auto updater = std::unique_ptr<xgboost::LinearUpdater>(
+      xgboost::LinearUpdater::Create("shotgun"));
+  updater->Init({});
+  std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
+                                        xgboost::bst_gpair(-5, 1.0));
+  xgboost::gbm::GBLinearModel model;
+  model.param.num_feature = mat->info().num_col;
+  model.param.num_output_group = 1;
+  model.LazyInitModel();
+  updater->Update(&gpair, mat.get(), &model, gpair.size());
+
+  ASSERT_EQ(model.bias()[0], 5.0f);
+}
+
+TEST(Linear, coordinate) {
+  typedef std::pair<std::string, std::string> arg;
+  auto mat = CreateDMatrix(10, 10, 0);
+  std::vector<bool> enabled(mat->info().num_col, true);
+  mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
+  auto updater = std::unique_ptr<xgboost::LinearUpdater>(
+      xgboost::LinearUpdater::Create("coord_descent"));
+  updater->Init({});
+  std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
+                                        xgboost::bst_gpair(-5, 1.0));
+  xgboost::gbm::GBLinearModel model;
+  model.param.num_feature = mat->info().num_col;
+  model.param.num_output_group = 1;
+  model.LazyInitModel();
+  updater->Update(&gpair, mat.get(), &model, gpair.size());
+
+  ASSERT_EQ(model.bias()[0], 5.0f);
+}
\ No newline at end of file
diff --git a/tests/python/test_linear.py b/tests/python/test_linear.py
new file mode 100644
index 000000000..fd85441e4
--- /dev/null
+++ b/tests/python/test_linear.py
@@ -0,0 +1,133 @@
+from __future__ import print_function
+
+import itertools as it
+import numpy as np
+import sys
+import testing as tm
+import unittest
+import xgboost as xgb
+
+rng = np.random.RandomState(199)
+
+num_rounds = 1000
+
+
+def is_float(s):
+    try:
+        float(s)
+        return 1
+    except ValueError:
+        return 0
+
+
+def xgb_get_weights(bst):
+    return [float(s) for s in bst.get_dump()[0].split() if is_float(s)]
+
+
+# Check gradient/subgradient = 0
+def check_least_squares_solution(X, y, pred, tol, reg_alpha, reg_lambda, weights):
+    reg_alpha = reg_alpha * len(y)
+    reg_lambda = reg_lambda * len(y)
+    r = np.subtract(y, pred)
+    g = X.T.dot(r)
+    g = np.subtract(g, np.multiply(reg_lambda, weights))
+    for i in range(0, len(weights)):
+        if weights[i] == 0.0:
+            assert abs(g[i]) <= reg_alpha
+        else:
+            assert np.isclose(g[i], np.sign(weights[i]) * reg_alpha, rtol=tol, atol=tol)
+
+
+def train_diabetes(param_in):
+    from sklearn import datasets
+    data = datasets.load_diabetes()
+    dtrain = xgb.DMatrix(data.data, label=data.target)
+    param = {}
+    param.update(param_in)
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred = bst.predict(dtrain)
+    check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
+                                 xgb_get_weights(bst)[1:])
+
+
+def train_breast_cancer(param_in):
+    from sklearn import metrics, datasets
+    data = datasets.load_breast_cancer()
+    dtrain = xgb.DMatrix(data.data, label=data.target)
+    param = {'objective': 'binary:logistic'}
+    param.update(param_in)
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred = bst.predict(dtrain)
+    xgb_score = metrics.accuracy_score(data.target, np.round(xgb_pred))
+    assert xgb_score >= 0.8
+
+
+def train_classification(param_in):
+    from sklearn import metrics, datasets
+    X, y = datasets.make_classification(random_state=rng,
+                                        scale=100)  # Scale is necessary otherwise regularisation parameters will force all coefficients to 0
+    dtrain = xgb.DMatrix(X, label=y)
+    param = {'objective': 'binary:logistic'}
+    param.update(param_in)
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred = bst.predict(dtrain)
+    xgb_score = metrics.accuracy_score(y, np.round(xgb_pred))
+    assert xgb_score >= 0.8
+
+
+def train_classification_multi(param_in):
+    from sklearn import metrics, datasets
+    num_class = 3
+    X, y = datasets.make_classification(n_samples=10, random_state=rng, scale=100, n_classes=num_class, n_informative=4,
+                                        n_features=4, n_redundant=0)
+    dtrain = xgb.DMatrix(X, label=y)
+    param = {'objective': 'multi:softmax', 'num_class': num_class}
+    param.update(param_in)
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred = bst.predict(dtrain)
+    xgb_score = metrics.accuracy_score(y, np.round(xgb_pred))
+    assert xgb_score >= 0.50
+
+
+def train_boston(param_in):
+    from sklearn import datasets
+    data = datasets.load_boston()
+    dtrain = xgb.DMatrix(data.data, label=data.target)
+    param = {}
+    param.update(param_in)
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred = bst.predict(dtrain)
+    check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
+                                 xgb_get_weights(bst)[1:])
+
+
+# Enumerates all permutations of variable parameters
+def assert_updater_accuracy(linear_updater, variable_param):
+    param = {'booster': 'gblinear', 'updater': linear_updater, 'tolerance': 1e-8}
+    names = sorted(variable_param)
+    combinations = it.product(*(variable_param[Name] for Name in names))
+
+    for set in combinations:
+        param_tmp = param.copy()
+        for i, name in enumerate(names):
+            param_tmp[name] = set[i]
+
+        print(param_tmp, file=sys.stderr)
+        train_boston(param_tmp)
+        train_diabetes(param_tmp)
+        train_classification(param_tmp)
+        train_classification_multi(param_tmp)
+        train_breast_cancer(param_tmp)
+
+
+class TestLinear(unittest.TestCase):
+    def test_coordinate(self):
+        tm._skip_if_no_sklearn()
+        variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0],
+                          'coordinate_selection': ['cyclic', 'random', 'greedy']}
+        assert_updater_accuracy('coord_descent', variable_param)

    def test_shotgun(self):
        tm._skip_if_no_sklearn()
        variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0]}
        assert_updater_accuracy('shotgun', variable_param)
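
Usage sketch: the patch selects the new linear updaters through the 'updater' parameter together with booster='gblinear', as the tests and benchmark above do. The snippet below is a minimal, illustrative example only; the synthetic dataset, penalty values and round count are assumptions chosen for demonstration and are not part of the patch.

# Illustrative only: train a gblinear model with the new coordinate descent updater.
import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X = rng.randn(500, 10)
y = X.dot(rng.randn(10)) + 0.1 * rng.randn(500)
dtrain = xgb.DMatrix(X, label=y)

params = {
    'booster': 'gblinear',
    'updater': 'coord_descent',  # or 'shotgun' for the hogwild-style parallel updater
    'alpha': 1.0,                # L1 penalty; scaled by the number of training rows inside the booster
    'lambda': 1.0,               # L2 penalty; scaled likewise
}
bst = xgb.train(params, dtrain, num_boost_round=50, evals=[(dtrain, 'train')])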