Additional improvements for gblinear (#3134)

* fix rebase conflict

* [core] additional gblinear improvements

* [R] callback for gblinear coefficients history

* force eta=1 for gblinear python tests

* add top_k to GreedyFeatureSelector

* set eta=1 in shotgun test

* [core] fix SparsePage processing in gblinear; col-wise multithreading in greedy updater

* set sorted flag within TryInitColData

* gblinear tests: use scale, add external memory test

* fix multiclass for greedy updater

* fix whitespace

* fix typo
This commit is contained in:
Vadim Khotilovich 2018-03-13 01:27:13 -05:00 committed by GitHub
parent a1b48afa41
commit 706be4e5d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 750 additions and 260 deletions

View File

@ -1,8 +1,8 @@
Package: xgboost
Type: Package
Title: Extreme Gradient Boosting
Version: 0.7.0
Date: 2018-01-22
Version: 0.7.0.1
Date: 2018-02-25
Author: Tianqi Chen <tianqi.tchen@gmail.com>, Tong He <hetong007@gmail.com>,
Michael Benesty <michael@benesty.fr>, Vadim Khotilovich <khotilovich@gmail.com>,
Yuan Tang <terrytangyuan@gmail.com>

View File

@ -18,6 +18,7 @@ export("xgb.parameters<-")
export(cb.cv.predict)
export(cb.early.stop)
export(cb.evaluation.log)
export(cb.gblinear.history)
export(cb.print.evaluation)
export(cb.reset.parameters)
export(cb.save.model)
@ -32,6 +33,7 @@ export(xgb.attributes)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
export(xgb.gblinear.history)
export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.importance)
@ -52,7 +54,9 @@ importClassesFrom(Matrix,dgeMatrix)
importFrom(Matrix,cBind)
importFrom(Matrix,colSums)
importFrom(Matrix,sparse.model.matrix)
importFrom(Matrix,sparseMatrix)
importFrom(Matrix,sparseVector)
importFrom(Matrix,t)
importFrom(data.table,":=")
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)

View File

@ -524,6 +524,225 @@ cb.cv.predict <- function(save_models = FALSE) {
}
#' Callback closure for collecting the model coefficients history of a gblinear booster
#' during its training.
#'
#' @param sparse when set to FALSE/TURE, a dense/sparse matrix is used to store the result.
#' Sparse format is useful when one expects only a subset of coefficients to be non-zero,
#' when using the "thrifty" feature selector with fairly small number of top features
#' selected per iteration.
#'
#' @details
#' To keep things fast and simple, gblinear booster does not internally store the history of linear
#' model coefficients at each boosting iteration. This callback provides a workaround for storing
#' the coefficients' path, by extracting them after each training iteration.
#'
#' Callback function expects the following values to be set in its calling frame:
#' \code{bst} (or \code{bst_folds}).
#'
#' @return
#' Results are stored in the \code{coefs} element of the closure.
#' The \code{\link{xgb.gblinear.history}} convenience function provides an easy way to access it.
#' With \code{xgb.train}, it is either a dense of a sparse matrix.
#' While with \code{xgb.cv}, it is a list (an element per each fold) of such matrices.
#'
#' @seealso
#' \code{\link{callbacks}}, \code{\link{xgb.gblinear.history}}.
#'
#' @examples
#' #### Binary classification:
#' #
#' # In the iris dataset, it is hard to linearly separate Versicolor class from the rest
#' # without considering the 2nd order interactions:
#' x <- model.matrix(Species ~ .^2, iris)[,-1]
#' colnames(x)
#' dtrain <- xgb.DMatrix(scale(x), label = 1*(iris$Species == "versicolor"))
#' param <- list(booster = "gblinear", objective = "reg:logistic", eval_metric = "auc",
#' lambda = 0.0003, alpha = 0.0003, nthread = 2)
#' # For 'shotgun', which is a default linear updater, using high eta values may result in
#' # unstable behaviour in some datasets. With this simple dataset, however, the high learning
#' # rate does not break the convergence, but allows us to illustrate the typical pattern of
#' # "stochastic explosion" behaviour of this lock-free algorithm at early boosting iterations.
#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 1.,
#' callbacks = list(cb.gblinear.history()))
#' # Extract the coefficients' path and plot them vs boosting iteration number:
#' coef_path <- xgb.gblinear.history(bst)
#' matplot(coef_path, type = 'l')
#'
#' # With the deterministic coordinate descent updater, it is safer to use higher learning rates.
#' # Will try the classical componentwise boosting which selects a single best feature per round:
#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 0.8,
#' updater = 'coord_descent', feature_selector = 'thrifty', top_k = 1,
#' callbacks = list(cb.gblinear.history()))
#' xgb.gblinear.history(bst) %>% matplot(type = 'l')
#' # Componentwise boosting is known to have similar effect to Lasso regularization.
#' # Try experimenting with various values of top_k, eta, nrounds,
#' # as well as different feature_selectors.
#'
#' # For xgb.cv:
#' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 100, eta = 0.8,
#' callbacks = list(cb.gblinear.history()))
#' # coefficients in the CV fold #3
#' xgb.gblinear.history(bst)[[3]] %>% matplot(type = 'l')
#'
#'
#' #### Multiclass classification:
#' #
#' dtrain <- xgb.DMatrix(scale(x), label = as.numeric(iris$Species) - 1)
#' param <- list(booster = "gblinear", objective = "multi:softprob", num_class = 3,
#' lambda = 0.0003, alpha = 0.0003, nthread = 2)
#' # For the default linear updater 'shotgun' it sometimes is helpful
#' # to use smaller eta to reduce instability
#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 70, eta = 0.5,
#' callbacks = list(cb.gblinear.history()))
#' # Will plot the coefficient paths separately for each class:
#' xgb.gblinear.history(bst, class_index = 0) %>% matplot(type = 'l')
#' xgb.gblinear.history(bst, class_index = 1) %>% matplot(type = 'l')
#' xgb.gblinear.history(bst, class_index = 2) %>% matplot(type = 'l')
#'
#' # CV:
#' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 70, eta = 0.5,
#' callbacks = list(cb.gblinear.history(F)))
#' # 1st forld of 1st class
#' xgb.gblinear.history(bst, class_index = 0)[[1]] %>% matplot(type = 'l')
#'
#' @export
cb.gblinear.history <- function(sparse=FALSE) {
coefs <- NULL
init <- function(env) {
if (!is.null(env$bst)) { # xgb.train:
coef_path <- list()
} else if (!is.null(env$bst_folds)) { # xgb.cv:
coef_path <- rep(list(), length(env$bst_folds))
} else stop("Parent frame has neither 'bst' nor 'bst_folds'")
}
# convert from list to (sparse) matrix
list2mat <- function(coef_list) {
if (sparse) {
coef_mat <- sparseMatrix(x = unlist(lapply(coef_list, slot, "x")),
i = unlist(lapply(coef_list, slot, "i")),
p = c(0, cumsum(sapply(coef_list, function(x) length(x@x)))),
dims = c(length(coef_list[[1]]), length(coef_list)))
return(t(coef_mat))
} else {
return(do.call(rbind, coef_list))
}
}
finalizer <- function(env) {
if (length(coefs) == 0)
return()
if (!is.null(env$bst)) { # # xgb.train:
coefs <<- list2mat(coefs)
} else { # xgb.cv:
# first lapply transposes the list
coefs <<- lapply(seq_along(coefs[[1]]), function(i) lapply(coefs, "[[", i)) %>%
lapply(function(x) list2mat(x))
}
}
extract.coef <- function(env) {
if (!is.null(env$bst)) { # # xgb.train:
cf <- as.numeric(grep('(booster|bias|weigh)', xgb.dump(env$bst), invert = TRUE, value = TRUE))
if (sparse) cf <- as(cf, "sparseVector")
} else { # xgb.cv:
cf <- vector("list", length(env$bst_folds))
for (i in seq_along(env$bst_folds)) {
dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst))
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
}
}
cf
}
callback <- function(env = parent.frame(), finalize = FALSE) {
if (is.null(coefs)) init(env)
if (finalize) return(finalizer(env))
cf <- extract.coef(env)
coefs <<- c(coefs, list(cf))
}
attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.gblinear.history'
callback
}
#' Extract gblinear coefficients history.
#'
#' A helper function to extract the matrix of linear coefficients' history
#' from a gblinear model created while using the \code{cb.gblinear.history()}
#' callback.
#'
#' @param model either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained
#' using the \code{cb.gblinear.history()} callback.
#' @param class_index zero-based class index to extract the coefficients for only that
#' specific class in a multinomial multiclass model. When it is NULL, all the
#' coeffients are returned. Has no effect in non-multiclass models.
#'
#' @return
#' For an \code{xgb.train} result, a matrix (either dense or sparse) with the columns
#' corresponding to iteration's coefficients (in the order as \code{xgb.dump()} would
#' return) and the rows corresponding to boosting iterations.
#'
#' For an \code{xgb.cv} result, a list of such matrices is returned with the elements
#' corresponding to CV folds.
#'
#' @examples
#' See \code{\link{cv.gblinear.history}}
#'
#' @export
xgb.gblinear.history <- function(model, class_index = NULL) {
if (!(inherits(model, "xgb.Booster") ||
inherits(model, "xgb.cv.synchronous")))
stop("model must be an object of either xgb.Booster or xgb.cv.synchronous class")
is_cv <- inherits(model, "xgb.cv.synchronous")
if (is.null(model[["callbacks"]]) || is.null(model$callbacks[["cb.gblinear.history"]]))
stop("model must be trained while using the cb.gblinear.history() callback")
if (!is_cv) {
# extract num_class & num_feat from the internal model
dmp <- xgb.dump(model)
if(length(dmp) < 2 || dmp[2] != "bias:")
stop("It does not appear to be a gblinear model")
dmp <- dmp[-c(1,2)]
n <- which(dmp == 'weight:')
if(length(n) != 1)
stop("It does not appear to be a gblinear model")
num_class <- n - 1
num_feat <- (length(dmp) - 4) / num_class
} else {
# in case of CV, the object is expected to have this info
if (model$params$booster != "gblinear")
stop("It does not appear to be a gblinear model")
num_class <- NVL(model$params$num_class, 1)
num_feat <- model$nfeatures
if (is.null(num_feat))
stop("This xgb.cv result does not have nfeatures info")
}
if (!is.null(class_index) &&
num_class > 1 &&
(class_index[1] < 0 || class_index[1] >= num_class))
stop("class_index has to be within [0,", num_class - 1, "]")
coef_path <- environment(model$callbacks$cb.gblinear.history)[["coefs"]]
if (!is.null(class_index) && num_class > 1) {
coef_path <- if (is.list(coef_path)) {
lapply(coef_path,
function(x) x[, seq(1 + class_index, by=num_class, length.out=num_feat)])
} else {
coef_path <- coef_path[, seq(1 + class_index, by=num_class, length.out=num_feat)]
}
}
coef_path
}
#
# Internal utility functions for callbacks ------------------------------------
#

View File

@ -88,6 +88,7 @@
#' CV-based evaluation means and standard deviations for the training and test CV-sets.
#' It is created by the \code{\link{cb.evaluation.log}} callback.
#' \item \code{niter} number of boosting iterations.
#' \item \code{nfeatures} number of features in training data.
#' \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
#' parameter or randomly generated.
#' \item \code{best_iteration} iteration number with the best evaluation metric value
@ -184,6 +185,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
handle <- xgb.Booster.handle(params, list(dtrain, dtest))
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test=dtest), index = folds[[k]])
})
rm(dall)
# a "basket" to collect some results from callbacks
basket <- list()
@ -221,6 +223,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
callbacks = callbacks,
evaluation_log = evaluation_log,
niter = end_iteration,
nfeatures = ncol(data),
folds = folds
)
ret <- c(ret, basket)

View File

@ -162,6 +162,7 @@
#' (only available with early stopping).
#' \item \code{feature_names} names of the training dataset features
#' (only when comun names were defined in training data).
#' \item \code{nfeatures} number of features in training data.
#' }
#'
#' @seealso
@ -363,6 +364,7 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
bst$callbacks <- callbacks
if (!is.null(colnames(dtrain)))
bst$feature_names <- colnames(dtrain)
bst$nfeatures <- ncol(dtrain)
return(bst)
}

View File

@ -81,6 +81,8 @@ NULL
#' @importFrom Matrix colSums
#' @importFrom Matrix sparse.model.matrix
#' @importFrom Matrix sparseVector
#' @importFrom Matrix sparseMatrix
#' @importFrom Matrix t
#' @importFrom data.table data.table
#' @importFrom data.table is.data.table
#' @importFrom data.table as.data.table

View File

@ -104,6 +104,7 @@ An object of class \code{xgb.cv.synchronous} with the following elements:
CV-based evaluation means and standard deviations for the training and test CV-sets.
It is created by the \code{\link{cb.evaluation.log}} callback.
\item \code{niter} number of boosting iterations.
\item \code{nfeatures} number of features in training data.
\item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
parameter or randomly generated.
\item \code{best_iteration} iteration number with the best evaluation metric value

View File

@ -155,6 +155,7 @@ An object of class \code{xgb.Booster} with the following elements:
(only available with early stopping).
\item \code{feature_names} names of the training dataset features
(only when comun names were defined in training data).
\item \code{nfeatures} number of features in training data.
}
}
\description{

View File

@ -2,18 +2,47 @@ context('Test generalized linear models')
require(xgboost)
test_that("glm works", {
test_that("gblinear works", {
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
expect_equal(class(dtrain), "xgb.DMatrix")
expect_equal(class(dtest), "xgb.DMatrix")
param <- list(objective = "binary:logistic", booster = "gblinear",
nthread = 2, alpha = 0.0001, lambda = 1)
nthread = 2, eta = 0.8, alpha = 0.0001, lambda = 0.0001)
watchlist <- list(eval = dtest, train = dtrain)
num_round <- 2
bst <- xgb.train(param, dtrain, num_round, watchlist)
n <- 5 # iterations
ERR_UL <- 0.005 # upper limit for the test set error
VERB <- 0 # chatterbox switch
param$updater = 'shotgun'
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle')
ypred <- predict(bst, dtest)
expect_equal(length(getinfo(dtest, 'label')), 1611)
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic',
callbacks = list(cb.gblinear.history()))
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
h <- xgb.gblinear.history(bst)
expect_equal(dim(h), c(n, ncol(dtrain) + 1))
expect_is(h, "matrix")
param$updater = 'coord_descent'
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic')
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle')
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
bst <- xgb.train(param, dtrain, 2, watchlist, verbose = VERB, feature_selector = 'greedy')
expect_lt(bst$evaluation_log$eval_error[2], ERR_UL)
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty',
top_n = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
h <- xgb.gblinear.history(bst)
expect_equal(dim(h), c(n, ncol(dtrain) + 1))
expect_s4_class(h, "dgCMatrix")
})

View File

@ -119,7 +119,7 @@ ColIterator(const std::vector<bst_uint>& fset) {
}
bool SparsePageDMatrix::TryInitColData() {
bool SparsePageDMatrix::TryInitColData(bool sorted) {
// load meta data.
std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
{
@ -140,6 +140,8 @@ bool SparsePageDMatrix::TryInitColData() {
files.push_back(std::move(fdata));
}
col_iter_.reset(new ColPageIter(std::move(files)));
// warning: no attempt to check here whether the cached data was sorted
col_iter_->sorted = sorted;
return true;
}
@ -147,7 +149,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
float pkeep,
size_t max_row_perbatch, bool sorted) {
if (HaveColAccess(sorted)) return;
if (TryInitColData()) return;
if (TryInitColData(sorted)) return;
const MetaInfo& info = this->info();
if (max_row_perbatch == std::numeric_limits<size_t>::max()) {
max_row_perbatch = kMaxRowPerBatch;
@ -291,8 +293,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
fo.reset(nullptr);
}
// initialize column data
CHECK(TryInitColData());
col_iter_->sorted = sorted;
CHECK(TryInitColData(sorted));
}
} // namespace data

View File

@ -116,7 +116,7 @@ class SparsePageDMatrix : public DMatrix {
* \brief Try to initialize column data.
* \return true if data already exists, false if they do not.
*/
bool TryInitColData();
bool TryInitColData(bool sorted);
// source data pointer.
std::unique_ptr<DataSource> source_;
// the cache prefix

View File

@ -21,14 +21,12 @@ namespace gbm {
DMLC_REGISTRY_FILE_TAG(gblinear);
// training parameter
// training parameters
struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
/*! \brief learning_rate */
std::string updater;
// flag to print out detailed breakdown of runtime
int debug_verbose;
float tolerance;
// declare parameters
size_t max_row_perbatch;
int debug_verbose;
DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
DMLC_DECLARE_FIELD(updater)
.set_default("shotgun")
@ -37,6 +35,9 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("Stop if largest weight update is smaller than this number.");
DMLC_DECLARE_FIELD(max_row_perbatch)
.set_default(std::numeric_limits<size_t>::max())
.describe("Maximum rows per batch.");
DMLC_DECLARE_FIELD(debug_verbose)
.set_lower_bound(0)
.set_default(0)
@ -84,12 +85,10 @@ class GBLinear : public GradientBooster {
if (!p_fmat->HaveColAccess(false)) {
std::vector<bool> enabled(p_fmat->info().num_col, true);
p_fmat->InitColAccess(enabled, 1.0f, std::numeric_limits<size_t>::max(),
false);
p_fmat->InitColAccess(enabled, 1.0f, param.max_row_perbatch, false);
}
model.LazyInitModel();
this->LazySumWeights(p_fmat);
if (!this->CheckConvergence()) {
@ -191,40 +190,7 @@ class GBLinear : public GradientBooster {
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {
const int ngroup = model.param.num_output_group;
const unsigned nfeature = model.param.num_feature;
std::stringstream fo("");
if (format == "json") {
fo << " { \"bias\": [" << std::endl;
for (int gid = 0; gid < ngroup; ++gid) {
if (gid != 0) fo << "," << std::endl;
fo << " " << model.bias()[gid];
}
fo << std::endl << " ]," << std::endl
<< " \"weight\": [" << std::endl;
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
if (i != 0 || gid != 0) fo << "," << std::endl;
fo << " " << model[i][gid];
}
}
fo << std::endl << " ]" << std::endl << " }";
} else {
fo << "bias:\n";
for (int gid = 0; gid < ngroup; ++gid) {
fo << model.bias()[gid] << std::endl;
}
fo << "weight:\n";
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
fo << model[i][gid] << std::endl;
}
}
}
std::vector<std::string> v;
v.push_back(fo.str());
return v;
return model.DumpModel(fmap, with_stats, format);
}
protected:
@ -272,9 +238,12 @@ class GBLinear : public GradientBooster {
bool CheckConvergence() {
if (param.tolerance == 0.0f) return false;
if (is_converged) return true;
if (previous_model.weight.size() != model.weight.size()) return false;
if (previous_model.weight.size() != model.weight.size()) {
previous_model = model;
return false;
}
float largest_dw = 0.0;
for (auto i = 0; i < model.weight.size(); i++) {
for (size_t i = 0; i < model.weight.size(); i++) {
largest_dw = std::max(
largest_dw, std::abs(model.weight[i] - previous_model.weight[i]));
}
@ -287,7 +256,7 @@ class GBLinear : public GradientBooster {
void LazySumWeights(DMatrix *p_fmat) {
if (!sum_weight_complete) {
auto &info = p_fmat->info();
for (int i = 0; i < info.num_row; i++) {
for (size_t i = 0; i < info.num_row; i++) {
sum_instance_weight += info.GetWeight(i);
}
sum_weight_complete = true;

View File

@ -4,7 +4,9 @@
#pragma once
#include <dmlc/io.h>
#include <dmlc/parameter.h>
#include <xgboost/feature_map.h>
#include <vector>
#include <string>
#include <cstring>
namespace xgboost {
@ -68,6 +70,44 @@ class GBLinearModel {
inline const bst_float* operator[](size_t i) const {
return &weight[i * param.num_output_group];
}
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
std::string format) const {
const int ngroup = param.num_output_group;
const unsigned nfeature = param.num_feature;
std::stringstream fo("");
if (format == "json") {
fo << " { \"bias\": [" << std::endl;
for (int gid = 0; gid < ngroup; ++gid) {
if (gid != 0) fo << "," << std::endl;
fo << " " << this->bias()[gid];
}
fo << std::endl << " ]," << std::endl
<< " \"weight\": [" << std::endl;
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
if (i != 0 || gid != 0) fo << "," << std::endl;
fo << " " << (*this)[i][gid];
}
}
fo << std::endl << " ]" << std::endl << " }";
} else {
fo << "bias:\n";
for (int gid = 0; gid < ngroup; ++gid) {
fo << this->bias()[gid] << std::endl;
}
fo << "weight:\n";
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
fo << (*this)[i][gid] << std::endl;
}
}
}
std::vector<std::string> v;
v.push_back(fo.str());
return v;
}
};
} // namespace gbm
} // namespace xgboost

View File

@ -7,6 +7,7 @@
#include <string>
#include <utility>
#include <vector>
#include <limits>
#include "../common/random.h"
namespace xgboost {
@ -19,26 +20,21 @@ namespace linear {
* \param sum_grad The sum gradient.
* \param sum_hess The sum hess.
* \param w The weight.
* \param reg_lambda Unnormalised L2 penalty.
* \param reg_alpha Unnormalised L1 penalty.
* \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
* \param reg_lambda Unnormalised L2 penalty.
*
* \return The weight update.
*/
inline double CoordinateDelta(double sum_grad, double sum_hess, double w,
double reg_lambda, double reg_alpha,
double sum_instance_weight) {
reg_alpha *= sum_instance_weight;
reg_lambda *= sum_instance_weight;
double reg_alpha, double reg_lambda) {
if (sum_hess < 1e-5f) return 0.0f;
double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda);
const double sum_grad_l2 = sum_grad + reg_lambda * w;
const double sum_hess_l2 = sum_hess + reg_lambda;
const double tmp = w - sum_grad_l2 / sum_hess_l2;
if (tmp >= 0) {
return std::max(
-(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w);
return std::max(-(sum_grad_l2 + reg_alpha) / sum_hess_l2, -w);
} else {
return std::min(
-(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w);
return std::min(-(sum_grad_l2 - reg_alpha) / sum_hess_l2, -w);
}
}
@ -50,7 +46,6 @@ inline double CoordinateDelta(double sum_grad, double sum_hess, double w,
*
* \return The weight update.
*/
inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
return -sum_grad / sum_hess;
}
@ -66,15 +61,14 @@ inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
*
* \return The gradient and diagonal Hessian entry for a given feature.
*/
inline std::pair<double, double> GetGradient(
int group_idx, int num_group, int fidx, const std::vector<bst_gpair> &gpair,
inline std::pair<double, double> GetGradient(int group_idx, int num_group, int fidx,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat) {
double sum_grad = 0.0, sum_hess = 0.0;
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
while (iter->Next()) {
const ColBatch &batch = iter->Value();
ColBatch::Inst col = batch[fidx];
ColBatch::Inst col = batch[0];
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
@ -88,7 +82,7 @@ inline std::pair<double, double> GetGradient(
}
/**
* \brief Get the gradient with respect to a single feature. Multithreaded.
* \brief Get the gradient with respect to a single feature. Row-wise multithreaded.
*
* \param group_idx Zero-based index of the group.
* \param num_group Number of groups.
@ -98,16 +92,14 @@ inline std::pair<double, double> GetGradient(
*
* \return The gradient and diagonal Hessian entry for a given feature.
*/
inline std::pair<double, double> GetGradientParallel(
int group_idx, int num_group, int fidx,
const std::vector<bst_gpair> &gpair, DMatrix *p_fmat) {
inline std::pair<double, double> GetGradientParallel(int group_idx, int num_group, int fidx,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat) {
double sum_grad = 0.0, sum_hess = 0.0;
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
while (iter->Next()) {
const ColBatch &batch = iter->Value();
ColBatch::Inst col = batch[fidx];
ColBatch::Inst col = batch[0];
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
for (bst_omp_uint j = 0; j < ndata; ++j) {
@ -122,7 +114,7 @@ inline std::pair<double, double> GetGradientParallel(
}
/**
* \brief Get the gradient with respect to the bias. Multithreaded.
* \brief Get the gradient with respect to the bias. Row-wise multithreaded.
*
* \param group_idx Zero-based index of the group.
* \param num_group Number of groups.
@ -131,9 +123,8 @@ inline std::pair<double, double> GetGradientParallel(
*
* \return The gradient and diagonal Hessian entry for the bias.
*/
inline std::pair<double, double> GetBiasGradientParallel(
int group_idx, int num_group, const std::vector<bst_gpair> &gpair,
inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_group,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat) {
const RowSet &rowset = p_fmat->buffered_rowset();
double sum_grad = 0.0, sum_hess = 0.0;
@ -159,15 +150,14 @@ inline std::pair<double, double> GetBiasGradientParallel(
* \param in_gpair The gradient vector to be updated.
* \param p_fmat The input feature matrix.
*/
inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
float dw, std::vector<bst_gpair> *in_gpair,
DMatrix *p_fmat) {
if (dw == 0.0f) return;
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
while (iter->Next()) {
const ColBatch &batch = iter->Value();
ColBatch::Inst col = batch[fidx];
ColBatch::Inst col = batch[0];
// update grad value
const bst_omp_uint num_row = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
@ -188,9 +178,7 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
* \param in_gpair The gradient vector to be updated.
* \param p_fmat The input feature matrix.
*/
inline void UpdateBiasResidualParallel(int group_idx, int num_group,
float dbias,
inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
std::vector<bst_gpair> *in_gpair,
DMatrix *p_fmat) {
if (dbias == 0.0f) return;
@ -205,114 +193,292 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group,
}
/**
* \class FeatureSelector
*
* \brief Abstract class for stateful feature selection in coordinate descent
* algorithms.
* \brief Abstract class for stateful feature selection or ordering
* in coordinate descent algorithms.
*/
class FeatureSelector {
public:
static FeatureSelector *Create(std::string name);
/*! \brief factory method */
static FeatureSelector *Create(int choice);
/*! \brief virtual destructor */
virtual ~FeatureSelector() {}
/**
* \brief Setting up the selector state prior to looping through features.
*
* \param model The model.
* \param gpair The gpair.
* \param p_fmat The feature matrix.
* \param alpha Regularisation alpha.
* \param lambda Regularisation lambda.
* \param param A parameter with algorithm-dependent use.
*/
virtual void Setup(const gbm::GBLinearModel &model,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
float alpha, float lambda, int param) {}
/**
* \brief Select next coordinate to update.
*
* \param iteration The iteration.
* \param iteration The iteration in a loop through features
* \param model The model.
* \param group_idx Zero-based index of the group.
* \param gpair The gpair.
* \param p_fmat The feature matrix.
* \param alpha Regularisation alpha.
* \param lambda Regularisation lambda.
* \param sum_instance_weight The sum instance weight.
*
* \return The index of the selected feature. -1 indicates the bias term.
* \return The index of the selected feature. -1 indicates none selected.
*/
virtual int SelectNextFeature(int iteration,
virtual int NextFeature(int iteration,
const gbm::GBLinearModel &model,
int group_idx,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda,
double sum_instance_weight) = 0;
DMatrix *p_fmat, float alpha, float lambda) = 0;
};
/**
* \class CyclicFeatureSelector
*
* \brief Deterministic selection by cycling through coordinates one at a time.
* \brief Deterministic selection by cycling through features one at a time.
*/
class CyclicFeatureSelector : public FeatureSelector {
public:
int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
int NextFeature(int iteration, const gbm::GBLinearModel &model,
int group_idx, const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda,
double sum_instance_weight) override {
DMatrix *p_fmat, float alpha, float lambda) override {
return iteration % model.param.num_feature;
}
};
/**
* \class RandomFeatureSelector
*
* \brief A random coordinate selector.
* \brief Similar to Cyclyc but with random feature shuffling prior to each update.
* \note Its randomness is controllable by setting a random seed.
*/
class ShuffleFeatureSelector : public FeatureSelector {
public:
void Setup(const gbm::GBLinearModel &model,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda, int param) override {
if (feat_index.size() == 0) {
feat_index.resize(model.param.num_feature);
std::iota(feat_index.begin(), feat_index.end(), 0);
}
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
}
int NextFeature(int iteration, const gbm::GBLinearModel &model,
int group_idx, const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda) override {
return feat_index[iteration % model.param.num_feature];
}
protected:
std::vector<bst_uint> feat_index;
};
/**
* \brief A random (with replacement) coordinate selector.
* \note Its randomness is controllable by setting a random seed.
*/
class RandomFeatureSelector : public FeatureSelector {
public:
int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
int NextFeature(int iteration, const gbm::GBLinearModel &model,
int group_idx, const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda,
double sum_instance_weight) override {
DMatrix *p_fmat, float alpha, float lambda) override {
return common::GlobalRandom()() % model.param.num_feature;
}
};
/**
* \class GreedyFeatureSelector
*
* \brief Select coordinate with the greatest gradient magnitude.
* \note It has O(num_feature^2) complexity. It is fully deterministic.
*
* \note It allows restricting the selection to top_k features per group with
* the largest magnitude of univariate weight change, by passing the top_k value
* through the `param` argument of Setup(). That would reduce the complexity to
* O(num_feature*top_k).
*/
class GreedyFeatureSelector : public FeatureSelector {
public:
int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
void Setup(const gbm::GBLinearModel &model,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda, int param) override {
top_k = static_cast<bst_uint>(param);
const bst_uint ngroup = model.param.num_output_group;
if (param <= 0) top_k = std::numeric_limits<bst_uint>::max();
if (counter.size() == 0) {
counter.resize(ngroup);
gpair_sums.resize(model.param.num_feature * ngroup);
}
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
counter[gid] = 0u;
}
}
int NextFeature(int iteration, const gbm::GBLinearModel &model,
int group_idx, const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda,
double sum_instance_weight) override {
// Find best
DMatrix *p_fmat, float alpha, float lambda) override {
// k-th selected feature for a group
auto k = counter[group_idx]++;
// stop after either reaching top-K or going through all the features in a group
if (k >= top_k || counter[group_idx] == model.param.num_feature) return -1;
const int ngroup = model.param.num_output_group;
const bst_omp_uint nfeat = model.param.num_feature;
// Calculate univariate gradient sums
std::fill(gpair_sums.begin(), gpair_sums.end(), std::make_pair(0., 0.));
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const ColBatch::Inst col = batch[i];
const bst_uint ndata = col.length;
auto &sums = gpair_sums[group_idx * nfeat + i];
for (bst_uint j = 0u; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
auto &p = gpair[col[j].index * ngroup + group_idx];
if (p.GetHess() < 0.f) continue;
sums.first += p.GetGrad() * v;
sums.second += p.GetHess() * v * v;
}
}
}
// Find a feature with the largest magnitude of weight change
int best_fidx = 0;
double best_weight_update = 0.0f;
for (auto fidx = 0U; fidx < model.param.num_feature; fidx++) {
const float w = model[fidx][group_idx];
auto gradient = GetGradientParallel(
group_idx, model.param.num_output_group, fidx, gpair, p_fmat);
float dw = static_cast<float>(
CoordinateDelta(gradient.first, gradient.second, w, lambda, alpha,
sum_instance_weight));
if (std::abs(dw) > std::abs(best_weight_update)) {
for (bst_omp_uint fidx = 0; fidx < nfeat; ++fidx) {
auto &s = gpair_sums[group_idx * nfeat + fidx];
float dw = std::abs(static_cast<bst_float>(
CoordinateDelta(s.first, s.second, model[fidx][group_idx], alpha, lambda)));
if (dw > best_weight_update) {
best_weight_update = dw;
best_fidx = fidx;
}
}
return best_fidx;
}
protected:
bst_uint top_k;
std::vector<bst_uint> counter;
std::vector<std::pair<double, double>> gpair_sums;
};
inline FeatureSelector *FeatureSelector::Create(std::string name) {
if (name == "cyclic") {
/**
* \brief Thrifty, approximately-greedy feature selector.
*
* \note Prior to cyclic updates, reorders features in descending magnitude of
* their univariate weight changes. This operation is multithreaded and is a
* linear complexity approximation of the quadratic greedy selection.
*
* \note It allows restricting the selection to top_k features per group with
* the largest magnitude of univariate weight change, by passing the top_k value
* through the `param` argument of Setup().
*/
class ThriftyFeatureSelector : public FeatureSelector {
public:
void Setup(const gbm::GBLinearModel &model,
const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda, int param) override {
top_k = static_cast<bst_uint>(param);
if (param <= 0) top_k = std::numeric_limits<bst_uint>::max();
const bst_uint ngroup = model.param.num_output_group;
const bst_omp_uint nfeat = model.param.num_feature;
if (deltaw.size() == 0) {
deltaw.resize(nfeat * ngroup);
sorted_idx.resize(nfeat * ngroup);
counter.resize(ngroup);
gpair_sums.resize(nfeat * ngroup);
}
// Calculate univariate gradient sums
std::fill(gpair_sums.begin(), gpair_sums.end(), std::make_pair(0., 0.));
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
// column-parallel is usually faster than row-parallel
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const ColBatch::Inst col = batch[i];
const bst_uint ndata = col.length;
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
auto &sums = gpair_sums[gid * nfeat + i];
for (bst_uint j = 0u; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
auto &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.f) continue;
sums.first += p.GetGrad() * v;
sums.second += p.GetHess() * v * v;
}
}
}
}
// rank by descending weight magnitude within the groups
std::fill(deltaw.begin(), deltaw.end(), 0.f);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
bst_float *pdeltaw = &deltaw[0];
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
// Calculate univariate weight changes
for (bst_omp_uint i = 0; i < nfeat; ++i) {
auto ii = gid * nfeat + i;
auto &s = gpair_sums[ii];
deltaw[ii] = static_cast<bst_float>(CoordinateDelta(
s.first, s.second, model[i][gid], alpha, lambda));
}
// sort in descending order of deltaw abs values
auto start = sorted_idx.begin() + gid * nfeat;
std::sort(start, start + nfeat,
[pdeltaw](size_t i, size_t j) {
return std::abs(*(pdeltaw + i)) > std::abs(*(pdeltaw + j));
});
counter[gid] = 0u;
}
}
int NextFeature(int iteration, const gbm::GBLinearModel &model,
int group_idx, const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, float alpha, float lambda) override {
// k-th selected feature for a group
auto k = counter[group_idx]++;
// stop after either reaching top-N or going through all the features in a group
if (k >= top_k || counter[group_idx] == model.param.num_feature) return -1;
// note that sorted_idx stores the "long" indices
const size_t grp_offset = group_idx * model.param.num_feature;
return static_cast<int>(sorted_idx[grp_offset + k] - grp_offset);
}
protected:
bst_uint top_k;
std::vector<bst_float> deltaw;
std::vector<size_t> sorted_idx;
std::vector<bst_uint> counter;
std::vector<std::pair<double, double>> gpair_sums;
};
/**
* \brief A set of available FeatureSelector's
*/
enum FeatureSelectorEnum {
kCyclic = 0,
kShuffle,
kThrifty,
kGreedy,
kRandom
};
inline FeatureSelector *FeatureSelector::Create(int choice) {
switch (choice) {
case kCyclic:
return new CyclicFeatureSelector();
} else if (name == "random") {
return new RandomFeatureSelector();
} else if (name == "greedy") {
case kShuffle:
return new ShuffleFeatureSelector();
case kThrifty:
return new ThriftyFeatureSelector();
case kGreedy:
return new GreedyFeatureSelector();
} else {
LOG(FATAL) << name << ": unknown coordinate selector";
case kRandom:
return new RandomFeatureSelector();
default:
LOG(FATAL) << "unknown coordinate selector: " << choice;
}
return nullptr;
}

View File

@ -20,8 +20,8 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
float reg_lambda;
/*! \brief regularization weight for L1 norm */
float reg_alpha;
std::string feature_selector;
float maximum_weight;
int feature_selector;
int top_k;
int debug_verbose;
// declare parameters
DMLC_DECLARE_PARAMETER(CoordinateTrainParam) {
@ -38,17 +38,35 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
.set_default(0.0f)
.describe("L1 regularization on weights.");
DMLC_DECLARE_FIELD(feature_selector)
.set_default("cyclic")
.describe(
"Feature selection algorithm, one of cyclic/random/greedy");
.set_default(kCyclic)
.add_enum("cyclic", kCyclic)
.add_enum("shuffle", kShuffle)
.add_enum("thrifty", kThrifty)
.add_enum("greedy", kGreedy)
.add_enum("random", kRandom)
.describe("Feature selection or ordering method.");
DMLC_DECLARE_FIELD(top_k)
.set_lower_bound(0)
.set_default(0)
.describe("The number of top features to select in 'thrifty' feature_selector. "
"The value of zero means using all the features.");
DMLC_DECLARE_FIELD(debug_verbose)
.set_lower_bound(0)
.set_default(0)
.describe("flag to print out detailed breakdown of runtime");
// alias of parameters
DMLC_DECLARE_ALIAS(learning_rate, eta);
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
}
/*! \brief Denormalizes the regularization penalties - to be called at each update */
void DenormalizePenalties(double sum_instance_weight) {
reg_lambda_denorm = reg_lambda * sum_instance_weight;
reg_alpha_denorm = reg_alpha * sum_instance_weight;
}
// denormalizated regularization penalties
float reg_lambda_denorm;
float reg_alpha_denorm;
};
/**
@ -66,47 +84,47 @@ class CoordinateUpdater : public LinearUpdater {
selector.reset(FeatureSelector::Create(param.feature_selector));
monitor.Init("CoordinateUpdater", param.debug_verbose);
}
void Update(std::vector<bst_gpair> *in_gpair, DMatrix *p_fmat,
gbm::GBLinearModel *model, double sum_instance_weight) override {
// Calculate bias
for (int group_idx = 0; group_idx < model->param.num_output_group;
++group_idx) {
auto grad = GetBiasGradientParallel(
group_idx, model->param.num_output_group, *in_gpair, p_fmat);
auto dbias = static_cast<float>(
param.learning_rate * CoordinateDeltaBias(grad.first, grad.second));
param.DenormalizePenalties(sum_instance_weight);
const int ngroup = model->param.num_output_group;
// update bias
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
auto grad = GetBiasGradientParallel(group_idx, ngroup, *in_gpair, p_fmat);
auto dbias = static_cast<float>(param.learning_rate *
CoordinateDeltaBias(grad.first, grad.second));
model->bias()[group_idx] += dbias;
UpdateBiasResidualParallel(group_idx, model->param.num_output_group,
dbias, in_gpair, p_fmat);
UpdateBiasResidualParallel(group_idx, ngroup, dbias, in_gpair, p_fmat);
}
for (int group_idx = 0; group_idx < model->param.num_output_group;
++group_idx) {
for (auto i = 0U; i < model->param.num_feature; i++) {
int fidx = selector->SelectNextFeature(
i, *model, group_idx, *in_gpair, p_fmat, param.reg_alpha,
param.reg_lambda, sum_instance_weight);
this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model,
sum_instance_weight);
// prepare for updating the weights
selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm,
param.reg_lambda_denorm, param.top_k);
// update weights
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
for (unsigned i = 0U; i < model->param.num_feature; i++) {
int fidx = selector->NextFeature(i, *model, group_idx, *in_gpair, p_fmat,
param.reg_alpha_denorm, param.reg_lambda_denorm);
if (fidx < 0) break;
this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model);
}
}
}
void UpdateFeature(int fidx, int group_idx, std::vector<bst_gpair> *in_gpair,
DMatrix *p_fmat, gbm::GBLinearModel *model,
double sum_instance_weight) {
inline void UpdateFeature(int fidx, int group_idx, std::vector<bst_gpair> *in_gpair,
DMatrix *p_fmat, gbm::GBLinearModel *model) {
const int ngroup = model->param.num_output_group;
bst_float &w = (*model)[fidx][group_idx];
monitor.Start("GetGradientParallel");
auto gradient = GetGradientParallel(
group_idx, model->param.num_output_group, fidx, *in_gpair, p_fmat);
auto gradient = GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
monitor.Stop("GetGradientParallel");
auto dw = static_cast<float>(
param.learning_rate *
CoordinateDelta(gradient.first, gradient.second, w, param.reg_lambda,
param.reg_alpha, sum_instance_weight));
CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm,
param.reg_lambda_denorm));
w += dw;
monitor.Start("UpdateResidualParallel");
UpdateResidualParallel(fidx, group_idx, model->param.num_output_group, dw,
in_gpair, p_fmat);
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
monitor.Stop("UpdateResidualParallel");
}

View File

@ -19,11 +19,12 @@ struct ShotgunTrainParam : public dmlc::Parameter<ShotgunTrainParam> {
float reg_lambda;
/*! \brief regularization weight for L1 norm */
float reg_alpha;
int feature_selector;
// declare parameters
DMLC_DECLARE_PARAMETER(ShotgunTrainParam) {
DMLC_DECLARE_FIELD(learning_rate)
.set_lower_bound(0.0f)
.set_default(1.0f)
.set_default(0.5f)
.describe("Learning rate of each update.");
DMLC_DECLARE_FIELD(reg_lambda)
.set_lower_bound(0.0f)
@ -33,75 +34,79 @@ struct ShotgunTrainParam : public dmlc::Parameter<ShotgunTrainParam> {
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("L1 regularization on weights.");
DMLC_DECLARE_FIELD(feature_selector)
.set_default(kCyclic)
.add_enum("cyclic", kCyclic)
.add_enum("shuffle", kShuffle)
.describe("Feature selection or ordering method.");
// alias of parameters
DMLC_DECLARE_ALIAS(learning_rate, eta);
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
}
/*! \brief Denormalizes the regularization penalties - to be called at each update */
void DenormalizePenalties(double sum_instance_weight) {
reg_lambda_denorm = reg_lambda * sum_instance_weight;
reg_alpha_denorm = reg_alpha * sum_instance_weight;
}
// denormalizated regularization penalties
float reg_lambda_denorm;
float reg_alpha_denorm;
};
class ShotgunUpdater : public LinearUpdater {
public:
// set training parameter
void Init(
const std::vector<std::pair<std::string, std::string> > &args) override {
void Init(const std::vector<std::pair<std::string, std::string> > &args) override {
param.InitAllowUnknown(args);
selector.reset(FeatureSelector::Create(param.feature_selector));
}
void Update(std::vector<bst_gpair> *in_gpair, DMatrix *p_fmat,
gbm::GBLinearModel *model, double sum_instance_weight) override {
param.DenormalizePenalties(sum_instance_weight);
std::vector<bst_gpair> &gpair = *in_gpair;
const int ngroup = model->param.num_output_group;
const RowSet &rowset = p_fmat->buffered_rowset();
// for all the output group
// update bias
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
for (bst_omp_uint i = 0; i < ndata; ++i) {
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
if (p.GetHess() >= 0.0f) {
sum_grad += p.GetGrad();
sum_hess += p.GetHess();
}
}
// remove bias effect
bst_float dw = static_cast<bst_float>(
param.learning_rate * CoordinateDeltaBias(sum_grad, sum_hess));
model->bias()[gid] += dw;
// update grad value
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
if (p.GetHess() >= 0.0f) {
p += bst_gpair(p.GetHess() * dw, 0);
}
}
auto grad = GetBiasGradientParallel(gid, ngroup, *in_gpair, p_fmat);
auto dbias = static_cast<bst_float>(param.learning_rate *
CoordinateDeltaBias(grad.first, grad.second));
model->bias()[gid] += dbias;
UpdateBiasResidualParallel(gid, ngroup, dbias, in_gpair, p_fmat);
}
// lock-free parallel updates of weights
selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm, 0);
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
while (iter->Next()) {
// number of features
const ColBatch &batch = iter->Value();
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const bst_uint fid = batch.col_index[i];
ColBatch::Inst col = batch[i];
int ii = selector->NextFeature(i, *model, 0, *in_gpair, p_fmat,
param.reg_alpha_denorm, param.reg_lambda_denorm);
if (ii < 0) continue;
const bst_uint fid = batch.col_index[ii];
ColBatch::Inst col = batch[ii];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {
const bst_float v = col[j].fvalue;
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.0f) continue;
const bst_float v = col[j].fvalue;
sum_grad += p.GetGrad() * v;
sum_hess += p.GetHess() * v * v;
}
bst_float &w = (*model)[fid][gid];
bst_float dw = static_cast<bst_float>(
param.learning_rate *
CoordinateDelta(sum_grad, sum_hess, w, param.reg_lambda,
param.reg_alpha, sum_instance_weight));
CoordinateDelta(sum_grad, sum_hess, w, param.reg_alpha_denorm,
param.reg_lambda_denorm));
if (dw == 0.f) continue;
w += dw;
// update grad value
// update grad values
for (bst_uint j = 0; j < col.length; ++j) {
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.0f) continue;
@ -112,8 +117,11 @@ class ShotgunUpdater : public LinearUpdater {
}
}
// training parameter
protected:
// training parameters
ShotgunTrainParam param;
std::unique_ptr<FeatureSelector> selector;
};
DMLC_REGISTER_PARAMETER(ShotgunTrainParam);

View File

@ -12,7 +12,7 @@ TEST(Linear, shotgun) {
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("shotgun"));
updater->Init({});
updater->Init({{"eta", "1."}});
std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
xgboost::bst_gpair(-5, 1.0));
xgboost::gbm::GBLinearModel model;

View File

@ -3,9 +3,17 @@ from __future__ import print_function
import itertools as it
import numpy as np
import sys
import os
import glob
import testing as tm
import unittest
import xgboost as xgb
try:
from sklearn import metrics, datasets
from sklearn.linear_model import ElasticNet
from sklearn.preprocessing import scale
except ImportError:
None
rng = np.random.RandomState(199)
@ -21,39 +29,35 @@ def is_float(s):
def xgb_get_weights(bst):
return [float(s) for s in bst.get_dump()[0].split() if is_float(s)]
return np.array([float(s) for s in bst.get_dump()[0].split() if is_float(s)])
# Check gradient/subgradient = 0
def check_least_squares_solution(X, y, pred, tol, reg_alpha, reg_lambda, weights):
reg_alpha = reg_alpha * len(y)
reg_lambda = reg_lambda * len(y)
r = np.subtract(y, pred)
g = X.T.dot(r)
g = np.subtract(g, np.multiply(reg_lambda, weights))
for i in range(0, len(weights)):
if weights[i] == 0.0:
assert abs(g[i]) <= reg_alpha
else:
assert np.isclose(g[i], np.sign(weights[i]) * reg_alpha, rtol=tol, atol=tol)
def check_ElasticNet(X, y, pred, tol, reg_alpha, reg_lambda, weights):
enet = ElasticNet(alpha=reg_alpha + reg_lambda,
l1_ratio=reg_alpha / (reg_alpha + reg_lambda))
enet.fit(X, y)
enet_pred = enet.predict(X)
assert np.isclose(weights, enet.coef_, rtol=tol, atol=tol).all()
assert np.isclose(enet_pred, pred, rtol=tol, atol=tol).all()
def train_diabetes(param_in):
from sklearn import datasets
data = datasets.load_diabetes()
dtrain = xgb.DMatrix(data.data, label=data.target)
X = scale(data.data)
dtrain = xgb.DMatrix(X, label=data.target)
param = {}
param.update(param_in)
bst = xgb.train(param, dtrain, num_rounds)
xgb_pred = bst.predict(dtrain)
check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
check_ElasticNet(X, data.target, xgb_pred, 1e-2,
param['alpha'], param['lambda'],
xgb_get_weights(bst)[1:])
def train_breast_cancer(param_in):
from sklearn import metrics, datasets
data = datasets.load_breast_cancer()
dtrain = xgb.DMatrix(data.data, label=data.target)
X = scale(data.data)
dtrain = xgb.DMatrix(X, label=data.target)
param = {'objective': 'binary:logistic'}
param.update(param_in)
bst = xgb.train(param, dtrain, num_rounds)
@ -63,9 +67,8 @@ def train_breast_cancer(param_in):
def train_classification(param_in):
from sklearn import metrics, datasets
X, y = datasets.make_classification(random_state=rng,
scale=100) # Scale is necessary otherwise regularisation parameters will force all coefficients to 0
X, y = datasets.make_classification(random_state=rng)
X = scale(X)
dtrain = xgb.DMatrix(X, label=y)
param = {'objective': 'binary:logistic'}
param.update(param_in)
@ -76,10 +79,11 @@ def train_classification(param_in):
def train_classification_multi(param_in):
from sklearn import metrics, datasets
num_class = 3
X, y = datasets.make_classification(n_samples=10, random_state=rng, scale=100, n_classes=num_class, n_informative=4,
X, y = datasets.make_classification(n_samples=100, random_state=rng,
n_classes=num_class, n_informative=4,
n_features=4, n_redundant=0)
X = scale(X)
dtrain = xgb.DMatrix(X, label=y)
param = {'objective': 'multi:softmax', 'num_class': num_class}
param.update(param_in)
@ -90,20 +94,42 @@ def train_classification_multi(param_in):
def train_boston(param_in):
from sklearn import datasets
data = datasets.load_boston()
dtrain = xgb.DMatrix(data.data, label=data.target)
X = scale(data.data)
dtrain = xgb.DMatrix(X, label=data.target)
param = {}
param.update(param_in)
bst = xgb.train(param, dtrain, num_rounds)
xgb_pred = bst.predict(dtrain)
check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
check_ElasticNet(X, data.target, xgb_pred, 1e-2,
param['alpha'], param['lambda'],
xgb_get_weights(bst)[1:])
def train_external_mem(param_in):
data = datasets.load_boston()
X = scale(data.data)
y = data.target
param = {}
param.update(param_in)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(param, dtrain, num_rounds)
xgb_pred = bst.predict(dtrain)
np.savetxt('tmptmp_1234.csv', np.hstack((y.reshape(len(y), 1), X)),
delimiter=',', fmt='%10.9f')
dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
bst = xgb.train(param, dtrain, num_rounds)
xgb_pred_ext = bst.predict(dtrain)
assert np.abs(xgb_pred_ext - xgb_pred).max() < 1e-3
del dtrain, bst
for f in glob.glob("tmptmp_*"):
os.remove(f)
# Enumerates all permutations of variable parameters
def assert_updater_accuracy(linear_updater, variable_param):
param = {'booster': 'gblinear', 'updater': linear_updater, 'tolerance': 1e-8}
param = {'booster': 'gblinear', 'updater': linear_updater, 'eta': 1.,
'top_k': 10, 'tolerance': 1e-5, 'nthread': 2}
names = sorted(variable_param)
combinations = it.product(*(variable_param[Name] for Name in names))
@ -118,16 +144,17 @@ def assert_updater_accuracy(linear_updater, variable_param):
train_classification(param_tmp)
train_classification_multi(param_tmp)
train_breast_cancer(param_tmp)
train_external_mem(param_tmp)
class TestLinear(unittest.TestCase):
def test_coordinate(self):
tm._skip_if_no_sklearn()
variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0],
'coordinate_selection': ['cyclic', 'random', 'greedy']}
variable_param = {'alpha': [.005, .1], 'lambda': [.005],
'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']}
assert_updater_accuracy('coord_descent', variable_param)
def test_shotgun(self):
tm._skip_if_no_sklearn()
variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0]}
variable_param = {'alpha': [.005, .1], 'lambda': [.005, .1]}
assert_updater_accuracy('shotgun', variable_param)