Additional improvements for gblinear (#3134)
* fix rebase conflict
* [core] additional gblinear improvements
* [R] callback for gblinear coefficients history
* force eta=1 for gblinear python tests
* add top_k to GreedyFeatureSelector
* set eta=1 in shotgun test
* [core] fix SparsePage processing in gblinear; col-wise multithreading in greedy updater
* set sorted flag within TryInitColData
* gblinear tests: use scale, add external memory test
* fix multiclass for greedy updater
* fix whitespace
* fix typo
parent a1b48afa41
commit 706be4e5d4
@ -1,8 +1,8 @@
|
||||
Package: xgboost
|
||||
Type: Package
|
||||
Title: Extreme Gradient Boosting
|
||||
Version: 0.7.0
|
||||
Date: 2018-01-22
|
||||
Version: 0.7.0.1
|
||||
Date: 2018-02-25
|
||||
Author: Tianqi Chen <tianqi.tchen@gmail.com>, Tong He <hetong007@gmail.com>,
|
||||
Michael Benesty <michael@benesty.fr>, Vadim Khotilovich <khotilovich@gmail.com>,
|
||||
Yuan Tang <terrytangyuan@gmail.com>
|
||||
|
||||
@ -18,6 +18,7 @@ export("xgb.parameters<-")
|
||||
export(cb.cv.predict)
|
||||
export(cb.early.stop)
|
||||
export(cb.evaluation.log)
|
||||
export(cb.gblinear.history)
|
||||
export(cb.print.evaluation)
|
||||
export(cb.reset.parameters)
|
||||
export(cb.save.model)
|
||||
@ -32,6 +33,7 @@ export(xgb.attributes)
|
||||
export(xgb.create.features)
|
||||
export(xgb.cv)
|
||||
export(xgb.dump)
|
||||
export(xgb.gblinear.history)
|
||||
export(xgb.ggplot.deepness)
|
||||
export(xgb.ggplot.importance)
|
||||
export(xgb.importance)
|
||||
@ -52,7 +54,9 @@ importClassesFrom(Matrix,dgeMatrix)
|
||||
importFrom(Matrix,cBind)
|
||||
importFrom(Matrix,colSums)
|
||||
importFrom(Matrix,sparse.model.matrix)
|
||||
importFrom(Matrix,sparseMatrix)
|
||||
importFrom(Matrix,sparseVector)
|
||||
importFrom(Matrix,t)
|
||||
importFrom(data.table,":=")
|
||||
importFrom(data.table,as.data.table)
|
||||
importFrom(data.table,data.table)
|
||||
|
||||
@ -524,6 +524,225 @@ cb.cv.predict <- function(save_models = FALSE) {
|
||||
}
|
||||
|
||||
|
||||
#' Callback closure for collecting the model coefficients history of a gblinear booster
|
||||
#' during its training.
|
||||
#'
|
||||
#' @param sparse when set to FALSE/TRUE, a dense/sparse matrix is used to store the result.
#' Sparse format is useful when one expects only a subset of coefficients to be non-zero,
#' when using the "thrifty" feature selector with a fairly small number of top features
#' selected per iteration.
|
||||
#'
|
||||
#' @details
|
||||
#' To keep things fast and simple, gblinear booster does not internally store the history of linear
|
||||
#' model coefficients at each boosting iteration. This callback provides a workaround for storing
|
||||
#' the coefficients' path, by extracting them after each training iteration.
|
||||
#'
|
||||
#' Callback function expects the following values to be set in its calling frame:
|
||||
#' \code{bst} (or \code{bst_folds}).
|
||||
#'
|
||||
#' @return
|
||||
#' Results are stored in the \code{coefs} element of the closure.
|
||||
#' The \code{\link{xgb.gblinear.history}} convenience function provides an easy way to access it.
|
||||
#' With \code{xgb.train}, it is either a dense or a sparse matrix.
#' With \code{xgb.cv}, it is a list (one element per fold) of such matrices.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{callbacks}}, \code{\link{xgb.gblinear.history}}.
|
||||
#'
|
||||
#' @examples
|
||||
#' #### Binary classification:
|
||||
#' #
|
||||
#' # In the iris dataset, it is hard to linearly separate Versicolor class from the rest
|
||||
#' # without considering the 2nd order interactions:
|
||||
#' x <- model.matrix(Species ~ .^2, iris)[,-1]
|
||||
#' colnames(x)
|
||||
#' dtrain <- xgb.DMatrix(scale(x), label = 1*(iris$Species == "versicolor"))
|
||||
#' param <- list(booster = "gblinear", objective = "reg:logistic", eval_metric = "auc",
|
||||
#' lambda = 0.0003, alpha = 0.0003, nthread = 2)
|
||||
#' # For 'shotgun', which is the default linear updater, using high eta values may result in
|
||||
#' # unstable behaviour in some datasets. With this simple dataset, however, the high learning
|
||||
#' # rate does not break the convergence, but allows us to illustrate the typical pattern of
|
||||
#' # "stochastic explosion" behaviour of this lock-free algorithm at early boosting iterations.
|
||||
#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 1.,
|
||||
#' callbacks = list(cb.gblinear.history()))
|
||||
#' # Extract the coefficients' path and plot them vs boosting iteration number:
|
||||
#' coef_path <- xgb.gblinear.history(bst)
|
||||
#' matplot(coef_path, type = 'l')
|
||||
#'
|
||||
#' # With the deterministic coordinate descent updater, it is safer to use higher learning rates.
|
||||
#' # Will try the classical componentwise boosting which selects a single best feature per round:
|
||||
#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 0.8,
|
||||
#' updater = 'coord_descent', feature_selector = 'thrifty', top_k = 1,
|
||||
#' callbacks = list(cb.gblinear.history()))
|
||||
#' xgb.gblinear.history(bst) %>% matplot(type = 'l')
|
||||
#' # Componentwise boosting is known to have a similar effect to Lasso regularization.
|
||||
#' # Try experimenting with various values of top_k, eta, nrounds,
|
||||
#' # as well as different feature_selectors.
|
||||
#'
|
||||
#' # For xgb.cv:
|
||||
#' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 100, eta = 0.8,
|
||||
#' callbacks = list(cb.gblinear.history()))
|
||||
#' # coefficients in the CV fold #3
|
||||
#' xgb.gblinear.history(bst)[[3]] %>% matplot(type = 'l')
|
||||
#'
|
||||
#'
|
||||
#' #### Multiclass classification:
|
||||
#' #
|
||||
#' dtrain <- xgb.DMatrix(scale(x), label = as.numeric(iris$Species) - 1)
|
||||
#' param <- list(booster = "gblinear", objective = "multi:softprob", num_class = 3,
|
||||
#' lambda = 0.0003, alpha = 0.0003, nthread = 2)
|
||||
#' # For the default linear updater 'shotgun', it is sometimes helpful
#' # to use a smaller eta to reduce instability.
|
||||
#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 70, eta = 0.5,
|
||||
#' callbacks = list(cb.gblinear.history()))
|
||||
#' # Will plot the coefficient paths separately for each class:
|
||||
#' xgb.gblinear.history(bst, class_index = 0) %>% matplot(type = 'l')
|
||||
#' xgb.gblinear.history(bst, class_index = 1) %>% matplot(type = 'l')
|
||||
#' xgb.gblinear.history(bst, class_index = 2) %>% matplot(type = 'l')
|
||||
#'
|
||||
#' # CV:
|
||||
#' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 70, eta = 0.5,
|
||||
#' callbacks = list(cb.gblinear.history(F)))
|
||||
#' # 1st fold of the 1st class
|
||||
#' xgb.gblinear.history(bst, class_index = 0)[[1]] %>% matplot(type = 'l')
|
||||
#'
|
||||
#' @export
|
||||
cb.gblinear.history <- function(sparse=FALSE) {
|
||||
coefs <- NULL
|
||||
|
||||
init <- function(env) {
|
||||
if (!is.null(env$bst)) { # xgb.train:
|
||||
coef_path <- list()
|
||||
} else if (!is.null(env$bst_folds)) { # xgb.cv:
|
||||
coef_path <- rep(list(), length(env$bst_folds))
|
||||
} else stop("Parent frame has neither 'bst' nor 'bst_folds'")
|
||||
}
|
||||
|
||||
# convert from list to (sparse) matrix
|
||||
list2mat <- function(coef_list) {
|
||||
if (sparse) {
|
||||
coef_mat <- sparseMatrix(x = unlist(lapply(coef_list, slot, "x")),
|
||||
i = unlist(lapply(coef_list, slot, "i")),
|
||||
p = c(0, cumsum(sapply(coef_list, function(x) length(x@x)))),
|
||||
dims = c(length(coef_list[[1]]), length(coef_list)))
|
||||
return(t(coef_mat))
|
||||
} else {
|
||||
return(do.call(rbind, coef_list))
|
||||
}
|
||||
}
|
||||
|
||||
finalizer <- function(env) {
|
||||
if (length(coefs) == 0)
|
||||
return()
|
||||
if (!is.null(env$bst)) { # # xgb.train:
|
||||
coefs <<- list2mat(coefs)
|
||||
} else { # xgb.cv:
|
||||
# first lapply transposes the list
|
||||
coefs <<- lapply(seq_along(coefs[[1]]), function(i) lapply(coefs, "[[", i)) %>%
|
||||
lapply(function(x) list2mat(x))
|
||||
}
|
||||
}
|
||||
|
||||
extract.coef <- function(env) {
|
||||
if (!is.null(env$bst)) { # # xgb.train:
|
||||
cf <- as.numeric(grep('(booster|bias|weigh)', xgb.dump(env$bst), invert = TRUE, value = TRUE))
|
||||
if (sparse) cf <- as(cf, "sparseVector")
|
||||
} else { # xgb.cv:
|
||||
cf <- vector("list", length(env$bst_folds))
|
||||
for (i in seq_along(env$bst_folds)) {
|
||||
dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst))
|
||||
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
|
||||
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
|
||||
}
|
||||
}
|
||||
cf
|
||||
}
|
||||
|
||||
callback <- function(env = parent.frame(), finalize = FALSE) {
|
||||
if (is.null(coefs)) init(env)
|
||||
if (finalize) return(finalizer(env))
|
||||
cf <- extract.coef(env)
|
||||
coefs <<- c(coefs, list(cf))
|
||||
}
|
||||
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.gblinear.history'
|
||||
callback
|
||||
}
|
||||
|
||||
#' Extract gblinear coefficients history.
|
||||
#'
|
||||
#' A helper function to extract the matrix of linear coefficients' history
|
||||
#' from a gblinear model created while using the \code{cb.gblinear.history()}
|
||||
#' callback.
|
||||
#'
|
||||
#' @param model either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained
|
||||
#' using the \code{cb.gblinear.history()} callback.
|
||||
#' @param class_index zero-based class index to extract the coefficients for only that
#' specific class in a multinomial multiclass model. When it is NULL, all the
#' coefficients are returned. Has no effect in non-multiclass models.
|
||||
#'
|
||||
#' @return
|
||||
#' For an \code{xgb.train} result, a matrix (either dense or sparse) with the columns
#' corresponding to the model coefficients (in the same order as \code{xgb.dump()} would
#' return them) and the rows corresponding to boosting iterations.
|
||||
#'
|
||||
#' For an \code{xgb.cv} result, a list of such matrices is returned with the elements
|
||||
#' corresponding to CV folds.
|
||||
#'
|
||||
#' @examples
|
||||
#' See \code{\link{cb.gblinear.history}}
|
||||
#'
|
||||
#' @export
|
||||
xgb.gblinear.history <- function(model, class_index = NULL) {
|
||||
|
||||
if (!(inherits(model, "xgb.Booster") ||
|
||||
inherits(model, "xgb.cv.synchronous")))
|
||||
stop("model must be an object of either xgb.Booster or xgb.cv.synchronous class")
|
||||
is_cv <- inherits(model, "xgb.cv.synchronous")
|
||||
|
||||
if (is.null(model[["callbacks"]]) || is.null(model$callbacks[["cb.gblinear.history"]]))
|
||||
stop("model must be trained while using the cb.gblinear.history() callback")
|
||||
|
||||
if (!is_cv) {
|
||||
# extract num_class & num_feat from the internal model
|
||||
dmp <- xgb.dump(model)
|
||||
if(length(dmp) < 2 || dmp[2] != "bias:")
|
||||
stop("It does not appear to be a gblinear model")
|
||||
dmp <- dmp[-c(1,2)]
|
||||
n <- which(dmp == 'weight:')
|
||||
if(length(n) != 1)
|
||||
stop("It does not appear to be a gblinear model")
|
||||
num_class <- n - 1
|
||||
num_feat <- (length(dmp) - 4) / num_class
|
||||
} else {
|
||||
# in case of CV, the object is expected to have this info
|
||||
if (model$params$booster != "gblinear")
|
||||
stop("It does not appear to be a gblinear model")
|
||||
num_class <- NVL(model$params$num_class, 1)
|
||||
num_feat <- model$nfeatures
|
||||
if (is.null(num_feat))
|
||||
stop("This xgb.cv result does not have nfeatures info")
|
||||
}
|
||||
|
||||
if (!is.null(class_index) &&
|
||||
num_class > 1 &&
|
||||
(class_index[1] < 0 || class_index[1] >= num_class))
|
||||
stop("class_index has to be within [0,", num_class - 1, "]")
|
||||
|
||||
coef_path <- environment(model$callbacks$cb.gblinear.history)[["coefs"]]
|
||||
if (!is.null(class_index) && num_class > 1) {
|
||||
coef_path <- if (is.list(coef_path)) {
|
||||
lapply(coef_path,
|
||||
function(x) x[, seq(1 + class_index, by=num_class, length.out=num_feat)])
|
||||
} else {
|
||||
coef_path <- coef_path[, seq(1 + class_index, by=num_class, length.out=num_feat)]
|
||||
}
|
||||
}
|
||||
coef_path
|
||||
}
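As a hedged illustration of the lookup a few lines above ('bst' is assumed to be a binary or regression gblinear booster trained with the cb.gblinear.history() callback, so no class_index slicing is involved):

cb <- bst$callbacks$cb.gblinear.history
coef_hist <- environment(cb)[["coefs"]]
# identical to xgb.gblinear.history(bst); for an xgb.cv result this is instead a list with one matrix per fold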
|
||||
|
||||
|
||||
#
|
||||
# Internal utility functions for callbacks ------------------------------------
|
||||
#
|
||||
|
||||
@ -88,6 +88,7 @@
|
||||
#' CV-based evaluation means and standard deviations for the training and test CV-sets.
|
||||
#' It is created by the \code{\link{cb.evaluation.log}} callback.
|
||||
#' \item \code{niter} number of boosting iterations.
|
||||
#' \item \code{nfeatures} number of features in training data.
|
||||
#' \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
|
||||
#' parameter or randomly generated.
|
||||
#' \item \code{best_iteration} iteration number with the best evaluation metric value
|
||||
@ -184,6 +185,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
handle <- xgb.Booster.handle(params, list(dtrain, dtest))
|
||||
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test=dtest), index = folds[[k]])
|
||||
})
|
||||
rm(dall)
|
||||
# a "basket" to collect some results from callbacks
|
||||
basket <- list()
|
||||
|
||||
@ -221,6 +223,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
callbacks = callbacks,
|
||||
evaluation_log = evaluation_log,
|
||||
niter = end_iteration,
|
||||
nfeatures = ncol(data),
|
||||
folds = folds
|
||||
)
|
||||
ret <- c(ret, basket)
|
||||
|
||||
@ -162,6 +162,7 @@
|
||||
#' (only available with early stopping).
|
||||
#' \item \code{feature_names} names of the training dataset features
|
||||
#' (only when column names were defined in training data).
|
||||
#' \item \code{nfeatures} number of features in training data.
|
||||
#' }
|
||||
#'
|
||||
#' @seealso
|
||||
@ -363,6 +364,7 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
bst$callbacks <- callbacks
|
||||
if (!is.null(colnames(dtrain)))
|
||||
bst$feature_names <- colnames(dtrain)
|
||||
bst$nfeatures <- ncol(dtrain)
|
||||
|
||||
return(bst)
|
||||
}
|
||||
|
||||
@ -81,6 +81,8 @@ NULL
|
||||
#' @importFrom Matrix colSums
|
||||
#' @importFrom Matrix sparse.model.matrix
|
||||
#' @importFrom Matrix sparseVector
|
||||
#' @importFrom Matrix sparseMatrix
|
||||
#' @importFrom Matrix t
|
||||
#' @importFrom data.table data.table
|
||||
#' @importFrom data.table is.data.table
|
||||
#' @importFrom data.table as.data.table
|
||||
|
||||
@ -104,6 +104,7 @@ An object of class \code{xgb.cv.synchronous} with the following elements:
|
||||
CV-based evaluation means and standard deviations for the training and test CV-sets.
|
||||
It is created by the \code{\link{cb.evaluation.log}} callback.
|
||||
\item \code{niter} number of boosting iterations.
|
||||
\item \code{nfeatures} number of features in training data.
|
||||
\item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
|
||||
parameter or randomly generated.
|
||||
\item \code{best_iteration} iteration number with the best evaluation metric value
|
||||
|
||||
@ -155,6 +155,7 @@ An object of class \code{xgb.Booster} with the following elements:
|
||||
(only available with early stopping).
|
||||
\item \code{feature_names} names of the training dataset features
|
||||
(only when column names were defined in training data).
|
||||
\item \code{nfeatures} number of features in training data.
|
||||
}
|
||||
}
|
||||
\description{
|
||||
|
||||
@ -2,18 +2,47 @@ context('Test generalized linear models')
|
||||
|
||||
require(xgboost)
|
||||
|
||||
test_that("glm works", {
|
||||
test_that("gblinear works", {
|
||||
data(agaricus.train, package='xgboost')
|
||||
data(agaricus.test, package='xgboost')
|
||||
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
|
||||
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
|
||||
expect_equal(class(dtrain), "xgb.DMatrix")
|
||||
expect_equal(class(dtest), "xgb.DMatrix")
|
||||
|
||||
param <- list(objective = "binary:logistic", booster = "gblinear",
|
||||
nthread = 2, alpha = 0.0001, lambda = 1)
|
||||
nthread = 2, eta = 0.8, alpha = 0.0001, lambda = 0.0001)
|
||||
watchlist <- list(eval = dtest, train = dtrain)
|
||||
num_round <- 2
|
||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||
|
||||
n <- 5 # iterations
|
||||
ERR_UL <- 0.005 # upper limit for the test set error
|
||||
VERB <- 0 # chatterbox switch
|
||||
|
||||
param$updater = 'shotgun'
|
||||
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle')
|
||||
ypred <- predict(bst, dtest)
|
||||
expect_equal(length(getinfo(dtest, 'label')), 1611)
|
||||
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
|
||||
|
||||
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic',
|
||||
callbacks = list(cb.gblinear.history()))
|
||||
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
|
||||
h <- xgb.gblinear.history(bst)
|
||||
expect_equal(dim(h), c(n, ncol(dtrain) + 1))
|
||||
expect_is(h, "matrix")
|
||||
|
||||
param$updater = 'coord_descent'
|
||||
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic')
|
||||
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
|
||||
|
||||
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle')
|
||||
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
|
||||
|
||||
bst <- xgb.train(param, dtrain, 2, watchlist, verbose = VERB, feature_selector = 'greedy')
|
||||
expect_lt(bst$evaluation_log$eval_error[2], ERR_UL)
|
||||
|
||||
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty',
|
||||
top_k = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
|
||||
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
|
||||
h <- xgb.gblinear.history(bst)
|
||||
expect_equal(dim(h), c(n, ncol(dtrain) + 1))
|
||||
expect_s4_class(h, "dgCMatrix")
|
||||
})
|
||||
|
||||
@ -119,7 +119,7 @@ ColIterator(const std::vector<bst_uint>& fset) {
|
||||
}
|
||||
|
||||
|
||||
bool SparsePageDMatrix::TryInitColData() {
|
||||
bool SparsePageDMatrix::TryInitColData(bool sorted) {
|
||||
// load meta data.
|
||||
std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
|
||||
{
|
||||
@ -140,6 +140,8 @@ bool SparsePageDMatrix::TryInitColData() {
|
||||
files.push_back(std::move(fdata));
|
||||
}
|
||||
col_iter_.reset(new ColPageIter(std::move(files)));
|
||||
// warning: no attempt to check here whether the cached data was sorted
|
||||
col_iter_->sorted = sorted;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -147,7 +149,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch, bool sorted) {
|
||||
if (HaveColAccess(sorted)) return;
|
||||
if (TryInitColData()) return;
|
||||
if (TryInitColData(sorted)) return;
|
||||
const MetaInfo& info = this->info();
|
||||
if (max_row_perbatch == std::numeric_limits<size_t>::max()) {
|
||||
max_row_perbatch = kMaxRowPerBatch;
|
||||
@ -291,8 +293,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
||||
fo.reset(nullptr);
|
||||
}
|
||||
// initialize column data
|
||||
CHECK(TryInitColData());
|
||||
col_iter_->sorted = sorted;
|
||||
CHECK(TryInitColData(sorted));
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
|
||||
@ -116,7 +116,7 @@ class SparsePageDMatrix : public DMatrix {
|
||||
* \brief Try to initialize column data.
|
||||
* \return true if data already exists, false if they do not.
|
||||
*/
|
||||
bool TryInitColData();
|
||||
bool TryInitColData(bool sorted);
|
||||
// source data pointer.
|
||||
std::unique_ptr<DataSource> source_;
|
||||
// the cache prefix
|
||||
|
||||
@ -21,14 +21,12 @@ namespace gbm {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gblinear);
|
||||
|
||||
// training parameter
|
||||
// training parameters
|
||||
struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
|
||||
/*! \brief learning_rate */
|
||||
std::string updater;
|
||||
// flag to print out detailed breakdown of runtime
|
||||
int debug_verbose;
|
||||
float tolerance;
|
||||
// declare parameters
|
||||
size_t max_row_perbatch;
|
||||
int debug_verbose;
|
||||
DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
|
||||
DMLC_DECLARE_FIELD(updater)
|
||||
.set_default("shotgun")
|
||||
@ -37,6 +35,9 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("Stop if largest weight update is smaller than this number.");
|
||||
DMLC_DECLARE_FIELD(max_row_perbatch)
|
||||
.set_default(std::numeric_limits<size_t>::max())
|
||||
.describe("Maximum rows per batch.");
|
||||
DMLC_DECLARE_FIELD(debug_verbose)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
@ -84,12 +85,10 @@ class GBLinear : public GradientBooster {
|
||||
|
||||
if (!p_fmat->HaveColAccess(false)) {
|
||||
std::vector<bool> enabled(p_fmat->info().num_col, true);
|
||||
p_fmat->InitColAccess(enabled, 1.0f, std::numeric_limits<size_t>::max(),
|
||||
false);
|
||||
p_fmat->InitColAccess(enabled, 1.0f, param.max_row_perbatch, false);
|
||||
}
|
||||
|
||||
model.LazyInitModel();
|
||||
|
||||
this->LazySumWeights(p_fmat);
|
||||
|
||||
if (!this->CheckConvergence()) {
|
||||
@ -191,40 +190,7 @@ class GBLinear : public GradientBooster {
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const override {
|
||||
const int ngroup = model.param.num_output_group;
|
||||
const unsigned nfeature = model.param.num_feature;
|
||||
|
||||
std::stringstream fo("");
|
||||
if (format == "json") {
|
||||
fo << " { \"bias\": [" << std::endl;
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
if (gid != 0) fo << "," << std::endl;
|
||||
fo << " " << model.bias()[gid];
|
||||
}
|
||||
fo << std::endl << " ]," << std::endl
|
||||
<< " \"weight\": [" << std::endl;
|
||||
for (unsigned i = 0; i < nfeature; ++i) {
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
if (i != 0 || gid != 0) fo << "," << std::endl;
|
||||
fo << " " << model[i][gid];
|
||||
}
|
||||
}
|
||||
fo << std::endl << " ]" << std::endl << " }";
|
||||
} else {
|
||||
fo << "bias:\n";
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
fo << model.bias()[gid] << std::endl;
|
||||
}
|
||||
fo << "weight:\n";
|
||||
for (unsigned i = 0; i < nfeature; ++i) {
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
fo << model[i][gid] << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<std::string> v;
|
||||
v.push_back(fo.str());
|
||||
return v;
|
||||
return model.DumpModel(fmap, with_stats, format);
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -272,9 +238,12 @@ class GBLinear : public GradientBooster {
|
||||
bool CheckConvergence() {
|
||||
if (param.tolerance == 0.0f) return false;
|
||||
if (is_converged) return true;
|
||||
if (previous_model.weight.size() != model.weight.size()) return false;
|
||||
if (previous_model.weight.size() != model.weight.size()) {
|
||||
previous_model = model;
|
||||
return false;
|
||||
}
|
||||
float largest_dw = 0.0;
|
||||
for (auto i = 0; i < model.weight.size(); i++) {
|
||||
for (size_t i = 0; i < model.weight.size(); i++) {
|
||||
largest_dw = std::max(
|
||||
largest_dw, std::abs(model.weight[i] - previous_model.weight[i]));
|
||||
}
|
||||
@ -287,7 +256,7 @@ class GBLinear : public GradientBooster {
|
||||
void LazySumWeights(DMatrix *p_fmat) {
|
||||
if (!sum_weight_complete) {
|
||||
auto &info = p_fmat->info();
|
||||
for (int i = 0; i < info.num_row; i++) {
|
||||
for (size_t i = 0; i < info.num_row; i++) {
|
||||
sum_instance_weight += info.GetWeight(i);
|
||||
}
|
||||
sum_weight_complete = true;
|
||||
|
||||
@ -4,7 +4,9 @@
|
||||
#pragma once
|
||||
#include <dmlc/io.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
|
||||
namespace xgboost {
|
||||
@ -68,6 +70,44 @@ class GBLinearModel {
|
||||
inline const bst_float* operator[](size_t i) const {
|
||||
return &weight[i * param.num_output_group];
|
||||
}
|
||||
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) const {
|
||||
const int ngroup = param.num_output_group;
|
||||
const unsigned nfeature = param.num_feature;
|
||||
|
||||
std::stringstream fo("");
|
||||
if (format == "json") {
|
||||
fo << " { \"bias\": [" << std::endl;
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
if (gid != 0) fo << "," << std::endl;
|
||||
fo << " " << this->bias()[gid];
|
||||
}
|
||||
fo << std::endl << " ]," << std::endl
|
||||
<< " \"weight\": [" << std::endl;
|
||||
for (unsigned i = 0; i < nfeature; ++i) {
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
if (i != 0 || gid != 0) fo << "," << std::endl;
|
||||
fo << " " << (*this)[i][gid];
|
||||
}
|
||||
}
|
||||
fo << std::endl << " ]" << std::endl << " }";
|
||||
} else {
|
||||
fo << "bias:\n";
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
fo << this->bias()[gid] << std::endl;
|
||||
}
|
||||
fo << "weight:\n";
|
||||
for (unsigned i = 0; i < nfeature; ++i) {
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
fo << (*this)[i][gid] << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<std::string> v;
|
||||
v.push_back(fo.str());
|
||||
return v;
|
||||
}
|
||||
};
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "../common/random.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -19,26 +20,21 @@ namespace linear {
|
||||
* \param sum_grad The sum gradient.
|
||||
* \param sum_hess The sum hess.
|
||||
* \param w The weight.
|
||||
* \param reg_lambda Unnormalised L2 penalty.
|
||||
* \param reg_alpha Unnormalised L1 penalty.
|
||||
* \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
|
||||
* \param reg_lambda Unnormalised L2 penalty.
|
||||
*
|
||||
* \return The weight update.
|
||||
*/
|
||||
|
||||
inline double CoordinateDelta(double sum_grad, double sum_hess, double w,
|
||||
double reg_lambda, double reg_alpha,
|
||||
double sum_instance_weight) {
|
||||
reg_alpha *= sum_instance_weight;
|
||||
reg_lambda *= sum_instance_weight;
|
||||
double reg_alpha, double reg_lambda) {
|
||||
if (sum_hess < 1e-5f) return 0.0f;
|
||||
double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda);
|
||||
const double sum_grad_l2 = sum_grad + reg_lambda * w;
|
||||
const double sum_hess_l2 = sum_hess + reg_lambda;
|
||||
const double tmp = w - sum_grad_l2 / sum_hess_l2;
|
||||
if (tmp >= 0) {
|
||||
return std::max(
|
||||
-(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w);
|
||||
return std::max(-(sum_grad_l2 + reg_alpha) / sum_hess_l2, -w);
|
||||
} else {
|
||||
return std::min(
|
||||
-(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w);
|
||||
return std::min(-(sum_grad_l2 - reg_alpha) / sum_hess_l2, -w);
|
||||
}
|
||||
}
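For reference, the branch above is the usual elastic-net soft-thresholding step. A sketch in LaTeX, writing $G$ and $H$ for the accumulated gradient and Hessian of the coordinate, and assuming $\alpha$ and $\lambda$ have already been scaled by the sum of instance weights (which is what the callers in this commit do via DenormalizePenalties()):

$$
\Delta w =
\begin{cases}
\max\!\left(-\dfrac{G + \lambda w + \alpha}{H + \lambda},\; -w\right) & \text{if } w - \dfrac{G + \lambda w}{H + \lambda} \ge 0,\\[6pt]
\min\!\left(-\dfrac{G + \lambda w - \alpha}{H + \lambda},\; -w\right) & \text{otherwise,}
\end{cases}
\qquad \Delta w = 0 \ \text{if } H < 10^{-5}.
$$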
|
||||
|
||||
@ -50,7 +46,6 @@ inline double CoordinateDelta(double sum_grad, double sum_hess, double w,
|
||||
*
|
||||
* \return The weight update.
|
||||
*/
|
||||
|
||||
inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
|
||||
return -sum_grad / sum_hess;
|
||||
}
|
||||
@ -66,15 +61,14 @@ inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
|
||||
*
|
||||
* \return The gradient and diagonal Hessian entry for a given feature.
|
||||
*/
|
||||
|
||||
inline std::pair<double, double> GetGradient(
|
||||
int group_idx, int num_group, int fidx, const std::vector<bst_gpair> &gpair,
|
||||
inline std::pair<double, double> GetGradient(int group_idx, int num_group, int fidx,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
ColBatch::Inst col = batch[fidx];
|
||||
ColBatch::Inst col = batch[0];
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
const bst_float v = col[j].fvalue;
|
||||
@ -88,7 +82,7 @@ inline std::pair<double, double> GetGradient(
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get the gradient with respect to a single feature. Multithreaded.
|
||||
* \brief Get the gradient with respect to a single feature. Row-wise multithreaded.
|
||||
*
|
||||
* \param group_idx Zero-based index of the group.
|
||||
* \param num_group Number of groups.
|
||||
@ -98,16 +92,14 @@ inline std::pair<double, double> GetGradient(
|
||||
*
|
||||
* \return The gradient and diagonal Hessian entry for a given feature.
|
||||
*/
|
||||
|
||||
inline std::pair<double, double> GetGradientParallel(
|
||||
int group_idx, int num_group, int fidx,
|
||||
|
||||
const std::vector<bst_gpair> &gpair, DMatrix *p_fmat) {
|
||||
inline std::pair<double, double> GetGradientParallel(int group_idx, int num_group, int fidx,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
ColBatch::Inst col = batch[fidx];
|
||||
ColBatch::Inst col = batch[0];
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
|
||||
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
@ -122,7 +114,7 @@ inline std::pair<double, double> GetGradientParallel(
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get the gradient with respect to the bias. Multithreaded.
|
||||
* \brief Get the gradient with respect to the bias. Row-wise multithreaded.
|
||||
*
|
||||
* \param group_idx Zero-based index of the group.
|
||||
* \param num_group Number of groups.
|
||||
@ -131,9 +123,8 @@ inline std::pair<double, double> GetGradientParallel(
|
||||
*
|
||||
* \return The gradient and diagonal Hessian entry for the bias.
|
||||
*/
|
||||
|
||||
inline std::pair<double, double> GetBiasGradientParallel(
|
||||
int group_idx, int num_group, const std::vector<bst_gpair> &gpair,
|
||||
inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_group,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat) {
|
||||
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
@ -159,15 +150,14 @@ inline std::pair<double, double> GetBiasGradientParallel(
|
||||
* \param in_gpair The gradient vector to be updated.
|
||||
* \param p_fmat The input feature matrix.
|
||||
*/
|
||||
|
||||
inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
||||
float dw, std::vector<bst_gpair> *in_gpair,
|
||||
DMatrix *p_fmat) {
|
||||
if (dw == 0.0f) return;
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
ColBatch::Inst col = batch[fidx];
|
||||
ColBatch::Inst col = batch[0];
|
||||
// update grad value
|
||||
const bst_omp_uint num_row = static_cast<bst_omp_uint>(col.length);
|
||||
#pragma omp parallel for schedule(static)
|
||||
@ -188,9 +178,7 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
||||
* \param in_gpair The gradient vector to be updated.
|
||||
* \param p_fmat The input feature matrix.
|
||||
*/
|
||||
|
||||
inline void UpdateBiasResidualParallel(int group_idx, int num_group,
|
||||
float dbias,
|
||||
inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
|
||||
std::vector<bst_gpair> *in_gpair,
|
||||
DMatrix *p_fmat) {
|
||||
if (dbias == 0.0f) return;
|
||||
@ -205,114 +193,292 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group,
|
||||
}
|
||||
|
||||
/**
|
||||
* \class FeatureSelector
|
||||
*
|
||||
* \brief Abstract class for stateful feature selection in coordinate descent
|
||||
* algorithms.
|
||||
* \brief Abstract class for stateful feature selection or ordering
|
||||
* in coordinate descent algorithms.
|
||||
*/
|
||||
|
||||
class FeatureSelector {
|
||||
public:
|
||||
static FeatureSelector *Create(std::string name);
|
||||
/*! \brief factory method */
|
||||
static FeatureSelector *Create(int choice);
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~FeatureSelector() {}
|
||||
|
||||
/**
|
||||
* \brief Setting up the selector state prior to looping through features.
|
||||
*
|
||||
* \param model The model.
|
||||
* \param gpair The gpair.
|
||||
* \param p_fmat The feature matrix.
|
||||
* \param alpha Regularisation alpha.
|
||||
* \param lambda Regularisation lambda.
|
||||
* \param param A parameter with algorithm-dependent use.
|
||||
*/
|
||||
virtual void Setup(const gbm::GBLinearModel &model,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
float alpha, float lambda, int param) {}
|
||||
/**
|
||||
* \brief Select next coordinate to update.
|
||||
*
|
||||
* \param iteration The iteration.
|
||||
* \param iteration The iteration in a loop through features
|
||||
* \param model The model.
|
||||
* \param group_idx Zero-based index of the group.
|
||||
* \param gpair The gpair.
|
||||
* \param p_fmat The feature matrix.
|
||||
* \param alpha Regularisation alpha.
|
||||
* \param lambda Regularisation lambda.
|
||||
* \param sum_instance_weight The sum instance weight.
|
||||
*
|
||||
* \return The index of the selected feature. -1 indicates the bias term.
|
||||
* \return The index of the selected feature. -1 indicates none selected.
|
||||
*/
|
||||
|
||||
virtual int SelectNextFeature(int iteration,
|
||||
virtual int NextFeature(int iteration,
|
||||
const gbm::GBLinearModel &model,
|
||||
int group_idx,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda,
|
||||
double sum_instance_weight) = 0;
|
||||
DMatrix *p_fmat, float alpha, float lambda) = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* \class CyclicFeatureSelector
|
||||
*
|
||||
* \brief Deterministic selection by cycling through coordinates one at a time.
|
||||
* \brief Deterministic selection by cycling through features one at a time.
|
||||
*/
|
||||
|
||||
class CyclicFeatureSelector : public FeatureSelector {
|
||||
public:
|
||||
int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int group_idx, const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda,
|
||||
double sum_instance_weight) override {
|
||||
DMatrix *p_fmat, float alpha, float lambda) override {
|
||||
return iteration % model.param.num_feature;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \class RandomFeatureSelector
|
||||
*
|
||||
* \brief A random coordinate selector.
|
||||
* \brief Similar to Cyclic but with random feature shuffling prior to each update.
|
||||
* \note Its randomness is controllable by setting a random seed.
|
||||
*/
|
||||
class ShuffleFeatureSelector : public FeatureSelector {
|
||||
public:
|
||||
void Setup(const gbm::GBLinearModel &model,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda, int param) override {
|
||||
if (feat_index.size() == 0) {
|
||||
feat_index.resize(model.param.num_feature);
|
||||
std::iota(feat_index.begin(), feat_index.end(), 0);
|
||||
}
|
||||
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
|
||||
}
|
||||
|
||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int group_idx, const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda) override {
|
||||
return feat_index[iteration % model.param.num_feature];
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<bst_uint> feat_index;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief A random (with replacement) coordinate selector.
|
||||
* \note Its randomness is controllable by setting a random seed.
|
||||
*/
|
||||
class RandomFeatureSelector : public FeatureSelector {
|
||||
public:
|
||||
int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int group_idx, const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda,
|
||||
double sum_instance_weight) override {
|
||||
DMatrix *p_fmat, float alpha, float lambda) override {
|
||||
return common::GlobalRandom()() % model.param.num_feature;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \class GreedyFeatureSelector
|
||||
*
|
||||
* \brief Select coordinate with the greatest gradient magnitude.
|
||||
* \note It has O(num_feature^2) complexity. It is fully deterministic.
|
||||
*
|
||||
* \note It allows restricting the selection to top_k features per group with
|
||||
* the largest magnitude of univariate weight change, by passing the top_k value
|
||||
* through the `param` argument of Setup(). That would reduce the complexity to
|
||||
* O(num_feature*top_k).
|
||||
*/
|
||||
|
||||
class GreedyFeatureSelector : public FeatureSelector {
|
||||
public:
|
||||
int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
void Setup(const gbm::GBLinearModel &model,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda, int param) override {
|
||||
top_k = static_cast<bst_uint>(param);
|
||||
const bst_uint ngroup = model.param.num_output_group;
|
||||
if (param <= 0) top_k = std::numeric_limits<bst_uint>::max();
|
||||
if (counter.size() == 0) {
|
||||
counter.resize(ngroup);
|
||||
gpair_sums.resize(model.param.num_feature * ngroup);
|
||||
}
|
||||
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
||||
counter[gid] = 0u;
|
||||
}
|
||||
}
|
||||
|
||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int group_idx, const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda,
|
||||
double sum_instance_weight) override {
|
||||
// Find best
|
||||
DMatrix *p_fmat, float alpha, float lambda) override {
|
||||
// k-th selected feature for a group
|
||||
auto k = counter[group_idx]++;
|
||||
// stop after either reaching top-K or going through all the features in a group
|
||||
if (k >= top_k || counter[group_idx] == model.param.num_feature) return -1;
|
||||
|
||||
const int ngroup = model.param.num_output_group;
|
||||
const bst_omp_uint nfeat = model.param.num_feature;
|
||||
// Calculate univariate gradient sums
|
||||
std::fill(gpair_sums.begin(), gpair_sums.end(), std::make_pair(0., 0.));
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
const ColBatch::Inst col = batch[i];
|
||||
const bst_uint ndata = col.length;
|
||||
auto &sums = gpair_sums[group_idx * nfeat + i];
|
||||
for (bst_uint j = 0u; j < ndata; ++j) {
|
||||
const bst_float v = col[j].fvalue;
|
||||
auto &p = gpair[col[j].index * ngroup + group_idx];
|
||||
if (p.GetHess() < 0.f) continue;
|
||||
sums.first += p.GetGrad() * v;
|
||||
sums.second += p.GetHess() * v * v;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Find a feature with the largest magnitude of weight change
|
||||
int best_fidx = 0;
|
||||
double best_weight_update = 0.0f;
|
||||
|
||||
for (auto fidx = 0U; fidx < model.param.num_feature; fidx++) {
|
||||
const float w = model[fidx][group_idx];
|
||||
auto gradient = GetGradientParallel(
|
||||
group_idx, model.param.num_output_group, fidx, gpair, p_fmat);
|
||||
float dw = static_cast<float>(
|
||||
CoordinateDelta(gradient.first, gradient.second, w, lambda, alpha,
|
||||
sum_instance_weight));
|
||||
if (std::abs(dw) > std::abs(best_weight_update)) {
|
||||
for (bst_omp_uint fidx = 0; fidx < nfeat; ++fidx) {
|
||||
auto &s = gpair_sums[group_idx * nfeat + fidx];
|
||||
float dw = std::abs(static_cast<bst_float>(
|
||||
CoordinateDelta(s.first, s.second, model[fidx][group_idx], alpha, lambda)));
|
||||
if (dw > best_weight_update) {
|
||||
best_weight_update = dw;
|
||||
best_fidx = fidx;
|
||||
}
|
||||
}
|
||||
return best_fidx;
|
||||
}
|
||||
|
||||
protected:
|
||||
bst_uint top_k;
|
||||
std::vector<bst_uint> counter;
|
||||
std::vector<std::pair<double, double>> gpair_sums;
|
||||
};
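Since capping the greedy search by top_k is new in this commit, a hedged R-side sketch of how it would be requested (parameter names as exposed by the coord_descent updater below; 'dtrain' is assumed to exist):

param <- list(booster = "gblinear", updater = "coord_descent",
              feature_selector = "greedy", top_k = 5,
              lambda = 1e-4, alpha = 1e-4, objective = "reg:logistic")
# bst <- xgb.train(param, dtrain, nrounds = 50)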
|
||||
|
||||
inline FeatureSelector *FeatureSelector::Create(std::string name) {
|
||||
if (name == "cyclic") {
|
||||
/**
|
||||
* \brief Thrifty, approximately-greedy feature selector.
|
||||
*
|
||||
* \note Prior to cyclic updates, reorders features in descending magnitude of
|
||||
* their univariate weight changes. This operation is multithreaded and is a
|
||||
* linear complexity approximation of the quadratic greedy selection.
|
||||
*
|
||||
* \note It allows restricting the selection to top_k features per group with
|
||||
* the largest magnitude of univariate weight change, by passing the top_k value
|
||||
* through the `param` argument of Setup().
|
||||
*/
|
||||
class ThriftyFeatureSelector : public FeatureSelector {
|
||||
public:
|
||||
void Setup(const gbm::GBLinearModel &model,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda, int param) override {
|
||||
top_k = static_cast<bst_uint>(param);
|
||||
if (param <= 0) top_k = std::numeric_limits<bst_uint>::max();
|
||||
const bst_uint ngroup = model.param.num_output_group;
|
||||
const bst_omp_uint nfeat = model.param.num_feature;
|
||||
|
||||
if (deltaw.size() == 0) {
|
||||
deltaw.resize(nfeat * ngroup);
|
||||
sorted_idx.resize(nfeat * ngroup);
|
||||
counter.resize(ngroup);
|
||||
gpair_sums.resize(nfeat * ngroup);
|
||||
}
|
||||
// Calculate univariate gradient sums
|
||||
std::fill(gpair_sums.begin(), gpair_sums.end(), std::make_pair(0., 0.));
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
// column-parallel is usually faster than row-parallel
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
const ColBatch::Inst col = batch[i];
|
||||
const bst_uint ndata = col.length;
|
||||
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
||||
auto &sums = gpair_sums[gid * nfeat + i];
|
||||
for (bst_uint j = 0u; j < ndata; ++j) {
|
||||
const bst_float v = col[j].fvalue;
|
||||
auto &p = gpair[col[j].index * ngroup + gid];
|
||||
if (p.GetHess() < 0.f) continue;
|
||||
sums.first += p.GetGrad() * v;
|
||||
sums.second += p.GetHess() * v * v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// rank by descending weight magnitude within the groups
|
||||
std::fill(deltaw.begin(), deltaw.end(), 0.f);
|
||||
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
|
||||
bst_float *pdeltaw = &deltaw[0];
|
||||
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
||||
// Calculate univariate weight changes
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
auto ii = gid * nfeat + i;
|
||||
auto &s = gpair_sums[ii];
|
||||
deltaw[ii] = static_cast<bst_float>(CoordinateDelta(
|
||||
s.first, s.second, model[i][gid], alpha, lambda));
|
||||
}
|
||||
// sort in descending order of deltaw abs values
|
||||
auto start = sorted_idx.begin() + gid * nfeat;
|
||||
std::sort(start, start + nfeat,
|
||||
[pdeltaw](size_t i, size_t j) {
|
||||
return std::abs(*(pdeltaw + i)) > std::abs(*(pdeltaw + j));
|
||||
});
|
||||
counter[gid] = 0u;
|
||||
}
|
||||
}
|
||||
|
||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||
int group_idx, const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat, float alpha, float lambda) override {
|
||||
// k-th selected feature for a group
|
||||
auto k = counter[group_idx]++;
|
||||
// stop after either reaching top-N or going through all the features in a group
|
||||
if (k >= top_k || counter[group_idx] == model.param.num_feature) return -1;
|
||||
// note that sorted_idx stores the "long" indices
|
||||
const size_t grp_offset = group_idx * model.param.num_feature;
|
||||
return static_cast<int>(sorted_idx[grp_offset + k] - grp_offset);
|
||||
}
|
||||
|
||||
protected:
|
||||
bst_uint top_k;
|
||||
std::vector<bst_float> deltaw;
|
||||
std::vector<size_t> sorted_idx;
|
||||
std::vector<bst_uint> counter;
|
||||
std::vector<std::pair<double, double>> gpair_sums;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief The set of available FeatureSelectors.
|
||||
*/
|
||||
enum FeatureSelectorEnum {
|
||||
kCyclic = 0,
|
||||
kShuffle,
|
||||
kThrifty,
|
||||
kGreedy,
|
||||
kRandom
|
||||
};
|
||||
|
||||
inline FeatureSelector *FeatureSelector::Create(int choice) {
|
||||
switch (choice) {
|
||||
case kCyclic:
|
||||
return new CyclicFeatureSelector();
|
||||
} else if (name == "random") {
|
||||
return new RandomFeatureSelector();
|
||||
} else if (name == "greedy") {
|
||||
case kShuffle:
|
||||
return new ShuffleFeatureSelector();
|
||||
case kThrifty:
|
||||
return new ThriftyFeatureSelector();
|
||||
case kGreedy:
|
||||
return new GreedyFeatureSelector();
|
||||
} else {
|
||||
LOG(FATAL) << name << ": unknown coordinate selector";
|
||||
case kRandom:
|
||||
return new RandomFeatureSelector();
|
||||
default:
|
||||
LOG(FATAL) << "unknown coordinate selector: " << choice;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -20,8 +20,8 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
|
||||
float reg_lambda;
|
||||
/*! \brief regularization weight for L1 norm */
|
||||
float reg_alpha;
|
||||
std::string feature_selector;
|
||||
float maximum_weight;
|
||||
int feature_selector;
|
||||
int top_k;
|
||||
int debug_verbose;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(CoordinateTrainParam) {
|
||||
@ -38,17 +38,35 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
|
||||
.set_default(0.0f)
|
||||
.describe("L1 regularization on weights.");
|
||||
DMLC_DECLARE_FIELD(feature_selector)
|
||||
.set_default("cyclic")
|
||||
.describe(
|
||||
"Feature selection algorithm, one of cyclic/random/greedy");
|
||||
.set_default(kCyclic)
|
||||
.add_enum("cyclic", kCyclic)
|
||||
.add_enum("shuffle", kShuffle)
|
||||
.add_enum("thrifty", kThrifty)
|
||||
.add_enum("greedy", kGreedy)
|
||||
.add_enum("random", kRandom)
|
||||
.describe("Feature selection or ordering method.");
|
||||
DMLC_DECLARE_FIELD(top_k)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("The number of top features to select in 'thrifty' feature_selector. "
|
||||
"The value of zero means using all the features.");
|
||||
DMLC_DECLARE_FIELD(debug_verbose)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("flag to print out detailed breakdown of runtime");
|
||||
// alias of parameters
|
||||
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||
}
|
||||
/*! \brief Denormalizes the regularization penalties - to be called at each update */
|
||||
void DenormalizePenalties(double sum_instance_weight) {
|
||||
reg_lambda_denorm = reg_lambda * sum_instance_weight;
|
||||
reg_alpha_denorm = reg_alpha * sum_instance_weight;
|
||||
}
|
||||
// denormalized regularization penalties
|
||||
float reg_lambda_denorm;
|
||||
float reg_alpha_denorm;
|
||||
};
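A brief sketch of what DenormalizePenalties() computes ($h_i$ denotes the instance weights summed by the gblinear booster's LazySumWeights()):

$$\lambda_{\text{denorm}} = \lambda \sum_i h_i, \qquad \alpha_{\text{denorm}} = \alpha \sum_i h_i,$$

so reg_lambda and reg_alpha act as per-instance (normalised) penalties whose effective strength does not depend on the size of the training set.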
|
||||
|
||||
/**
|
||||
@ -66,47 +84,47 @@ class CoordinateUpdater : public LinearUpdater {
|
||||
selector.reset(FeatureSelector::Create(param.feature_selector));
|
||||
monitor.Init("CoordinateUpdater", param.debug_verbose);
|
||||
}
|
||||
|
||||
void Update(std::vector<bst_gpair> *in_gpair, DMatrix *p_fmat,
|
||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||
// Calculate bias
|
||||
for (int group_idx = 0; group_idx < model->param.num_output_group;
|
||||
++group_idx) {
|
||||
auto grad = GetBiasGradientParallel(
|
||||
group_idx, model->param.num_output_group, *in_gpair, p_fmat);
|
||||
auto dbias = static_cast<float>(
|
||||
param.learning_rate * CoordinateDeltaBias(grad.first, grad.second));
|
||||
param.DenormalizePenalties(sum_instance_weight);
|
||||
const int ngroup = model->param.num_output_group;
|
||||
// update bias
|
||||
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
||||
auto grad = GetBiasGradientParallel(group_idx, ngroup, *in_gpair, p_fmat);
|
||||
auto dbias = static_cast<float>(param.learning_rate *
|
||||
CoordinateDeltaBias(grad.first, grad.second));
|
||||
model->bias()[group_idx] += dbias;
|
||||
UpdateBiasResidualParallel(group_idx, model->param.num_output_group,
|
||||
dbias, in_gpair, p_fmat);
|
||||
UpdateBiasResidualParallel(group_idx, ngroup, dbias, in_gpair, p_fmat);
|
||||
}
|
||||
for (int group_idx = 0; group_idx < model->param.num_output_group;
|
||||
++group_idx) {
|
||||
for (auto i = 0U; i < model->param.num_feature; i++) {
|
||||
int fidx = selector->SelectNextFeature(
|
||||
i, *model, group_idx, *in_gpair, p_fmat, param.reg_alpha,
|
||||
param.reg_lambda, sum_instance_weight);
|
||||
this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model,
|
||||
sum_instance_weight);
|
||||
// prepare for updating the weights
|
||||
selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm,
|
||||
param.reg_lambda_denorm, param.top_k);
|
||||
// update weights
|
||||
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
||||
for (unsigned i = 0U; i < model->param.num_feature; i++) {
|
||||
int fidx = selector->NextFeature(i, *model, group_idx, *in_gpair, p_fmat,
|
||||
param.reg_alpha_denorm, param.reg_lambda_denorm);
|
||||
if (fidx < 0) break;
|
||||
this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateFeature(int fidx, int group_idx, std::vector<bst_gpair> *in_gpair,
|
||||
DMatrix *p_fmat, gbm::GBLinearModel *model,
|
||||
double sum_instance_weight) {
|
||||
inline void UpdateFeature(int fidx, int group_idx, std::vector<bst_gpair> *in_gpair,
|
||||
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
||||
const int ngroup = model->param.num_output_group;
|
||||
bst_float &w = (*model)[fidx][group_idx];
|
||||
monitor.Start("GetGradientParallel");
|
||||
auto gradient = GetGradientParallel(
|
||||
group_idx, model->param.num_output_group, fidx, *in_gpair, p_fmat);
|
||||
auto gradient = GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
|
||||
monitor.Stop("GetGradientParallel");
|
||||
auto dw = static_cast<float>(
|
||||
param.learning_rate *
|
||||
CoordinateDelta(gradient.first, gradient.second, w, param.reg_lambda,
|
||||
param.reg_alpha, sum_instance_weight));
|
||||
CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm,
|
||||
param.reg_lambda_denorm));
|
||||
w += dw;
|
||||
monitor.Start("UpdateResidualParallel");
|
||||
UpdateResidualParallel(fidx, group_idx, model->param.num_output_group, dw,
|
||||
in_gpair, p_fmat);
|
||||
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
|
||||
monitor.Stop("UpdateResidualParallel");
|
||||
}
|
||||
|
||||
|
||||
@ -19,11 +19,12 @@ struct ShotgunTrainParam : public dmlc::Parameter<ShotgunTrainParam> {
|
||||
float reg_lambda;
|
||||
/*! \brief regularization weight for L1 norm */
|
||||
float reg_alpha;
|
||||
int feature_selector;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(ShotgunTrainParam) {
|
||||
DMLC_DECLARE_FIELD(learning_rate)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(1.0f)
|
||||
.set_default(0.5f)
|
||||
.describe("Learning rate of each update.");
|
||||
DMLC_DECLARE_FIELD(reg_lambda)
|
||||
.set_lower_bound(0.0f)
|
||||
@ -33,75 +34,79 @@ struct ShotgunTrainParam : public dmlc::Parameter<ShotgunTrainParam> {
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("L1 regularization on weights.");
|
||||
DMLC_DECLARE_FIELD(feature_selector)
|
||||
.set_default(kCyclic)
|
||||
.add_enum("cyclic", kCyclic)
|
||||
.add_enum("shuffle", kShuffle)
|
||||
.describe("Feature selection or ordering method.");
|
||||
// alias of parameters
|
||||
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||
}
|
||||
/*! \brief Denormalizes the regularization penalties - to be called at each update */
|
||||
void DenormalizePenalties(double sum_instance_weight) {
|
||||
reg_lambda_denorm = reg_lambda * sum_instance_weight;
|
||||
reg_alpha_denorm = reg_alpha * sum_instance_weight;
|
||||
}
|
||||
// denormalized regularization penalties
|
||||
float reg_lambda_denorm;
|
||||
float reg_alpha_denorm;
|
||||
};
|
||||
|
||||
class ShotgunUpdater : public LinearUpdater {
|
||||
public:
|
||||
// set training parameter
|
||||
void Init(
|
||||
const std::vector<std::pair<std::string, std::string> > &args) override {
|
||||
void Init(const std::vector<std::pair<std::string, std::string> > &args) override {
|
||||
param.InitAllowUnknown(args);
|
||||
selector.reset(FeatureSelector::Create(param.feature_selector));
|
||||
}
|
||||
|
||||
void Update(std::vector<bst_gpair> *in_gpair, DMatrix *p_fmat,
|
||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||
param.DenormalizePenalties(sum_instance_weight);
|
||||
std::vector<bst_gpair> &gpair = *in_gpair;
|
||||
const int ngroup = model->param.num_output_group;
|
||||
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||
// for all the output group
|
||||
|
||||
// update bias
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
|
||||
if (p.GetHess() >= 0.0f) {
|
||||
sum_grad += p.GetGrad();
|
||||
sum_hess += p.GetHess();
|
||||
}
|
||||
}
|
||||
// remove bias effect
|
||||
bst_float dw = static_cast<bst_float>(
|
||||
param.learning_rate * CoordinateDeltaBias(sum_grad, sum_hess));
|
||||
model->bias()[gid] += dw;
|
||||
// update grad value
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
|
||||
if (p.GetHess() >= 0.0f) {
|
||||
p += bst_gpair(p.GetHess() * dw, 0);
|
||||
}
|
||||
}
|
||||
auto grad = GetBiasGradientParallel(gid, ngroup, *in_gpair, p_fmat);
|
||||
auto dbias = static_cast<bst_float>(param.learning_rate *
|
||||
CoordinateDeltaBias(grad.first, grad.second));
|
||||
model->bias()[gid] += dbias;
|
||||
UpdateBiasResidualParallel(gid, ngroup, dbias, in_gpair, p_fmat);
|
||||
}
|
||||

    // lock-free parallel updates of weights
    selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm, 0);
    dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
    while (iter->Next()) {
      // number of features
      const ColBatch &batch = iter->Value();
      const bst_omp_uint nfeat = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
      for (bst_omp_uint i = 0; i < nfeat; ++i) {
        const bst_uint fid = batch.col_index[i];
        ColBatch::Inst col = batch[i];
        int ii = selector->NextFeature(i, *model, 0, *in_gpair, p_fmat,
                                       param.reg_alpha_denorm, param.reg_lambda_denorm);
        if (ii < 0) continue;
        const bst_uint fid = batch.col_index[ii];
        ColBatch::Inst col = batch[ii];
        for (int gid = 0; gid < ngroup; ++gid) {
          double sum_grad = 0.0, sum_hess = 0.0;
          for (bst_uint j = 0; j < col.length; ++j) {
            const bst_float v = col[j].fvalue;
            bst_gpair &p = gpair[col[j].index * ngroup + gid];
            if (p.GetHess() < 0.0f) continue;
            const bst_float v = col[j].fvalue;
            sum_grad += p.GetGrad() * v;
            sum_hess += p.GetHess() * v * v;
          }
          bst_float &w = (*model)[fid][gid];
          bst_float dw = static_cast<bst_float>(
              param.learning_rate *
              CoordinateDelta(sum_grad, sum_hess, w, param.reg_lambda,
                              param.reg_alpha, sum_instance_weight));
              CoordinateDelta(sum_grad, sum_hess, w, param.reg_alpha_denorm,
                              param.reg_lambda_denorm));
          if (dw == 0.f) continue;
          w += dw;
          // update grad value
          // update grad values
          for (bst_uint j = 0; j < col.length; ++j) {
            bst_gpair &p = gpair[col[j].index * ngroup + gid];
            if (p.GetHess() < 0.0f) continue;
@@ -112,8 +117,11 @@ class ShotgunUpdater : public LinearUpdater {
    }
  }

  // training parameter
 protected:
  // training parameters
  ShotgunTrainParam param;

  std::unique_ptr<FeatureSelector> selector;
};

DMLC_REGISTER_PARAMETER(ShotgunTrainParam);

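The per-feature step relies on CoordinateDelta, which for an elastic-net penalty reduces to a soft-thresholded single-coordinate Newton update. The sketch below is written from the surrounding call sites and is an approximation of that helper, not a copy of it; treat the exact clipping behaviour as an assumption.

def coordinate_delta(sum_grad, sum_hess, w, reg_alpha, reg_lambda):
    # sum_grad/sum_hess: gradient and hessian accumulated over one feature column;
    # w: current weight; reg_alpha/reg_lambda: denormalized L1/L2 penalties.
    if sum_hess < 1e-5:
        return 0.0
    grad_l2 = sum_grad + reg_lambda * w   # ridge term enters the gradient
    hess_l2 = sum_hess + reg_lambda       # ...and the curvature
    if w - grad_l2 / hess_l2 >= 0.0:
        # L1 soft-thresholding; never step past zero from above...
        return max(-(grad_l2 + reg_alpha) / hess_l2, -w)
    else:
        # ...or from below.
        return min(-(grad_l2 - reg_alpha) / hess_l2, -w)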
@@ -12,7 +12,7 @@ TEST(Linear, shotgun) {
  mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
  auto updater = std::unique_ptr<xgboost::LinearUpdater>(
      xgboost::LinearUpdater::Create("shotgun"));
  updater->Init({});
  updater->Init({{"eta", "1."}});
  std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
                                        xgboost::bst_gpair(-5, 1.0));
  xgboost::gbm::GBLinearModel model;

@@ -3,9 +3,17 @@ from __future__ import print_function
import itertools as it
import numpy as np
import sys
import os
import glob
import testing as tm
import unittest
import xgboost as xgb
try:
    from sklearn import metrics, datasets
    from sklearn.linear_model import ElasticNet
    from sklearn.preprocessing import scale
except ImportError:
    None

rng = np.random.RandomState(199)

@@ -21,39 +29,35 @@ def is_float(s):


def xgb_get_weights(bst):
    return [float(s) for s in bst.get_dump()[0].split() if is_float(s)]
    return np.array([float(s) for s in bst.get_dump()[0].split() if is_float(s)])


# Check gradient/subgradient = 0
def check_least_squares_solution(X, y, pred, tol, reg_alpha, reg_lambda, weights):
    reg_alpha = reg_alpha * len(y)
    reg_lambda = reg_lambda * len(y)
    r = np.subtract(y, pred)
    g = X.T.dot(r)
    g = np.subtract(g, np.multiply(reg_lambda, weights))
    for i in range(0, len(weights)):
        if weights[i] == 0.0:
            assert abs(g[i]) <= reg_alpha
        else:
            assert np.isclose(g[i], np.sign(weights[i]) * reg_alpha, rtol=tol, atol=tol)
def check_ElasticNet(X, y, pred, tol, reg_alpha, reg_lambda, weights):
    enet = ElasticNet(alpha=reg_alpha + reg_lambda,
                      l1_ratio=reg_alpha / (reg_alpha + reg_lambda))
    enet.fit(X, y)
    enet_pred = enet.predict(X)
    assert np.isclose(weights, enet.coef_, rtol=tol, atol=tol).all()
    assert np.isclose(enet_pred, pred, rtol=tol, atol=tol).all()
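The check above hinges on mapping xgboost's (alpha, lambda) penalties onto scikit-learn's ElasticNet(alpha, l1_ratio) parameterisation. Spelled out as a small helper; this assumes both sides effectively optimise the per-instance squared error plus alpha*|w|_1 + 0.5*lambda*|w|_2^2, which is what the denormalized penalties amount to.

def sklearn_enet_params(reg_alpha, reg_lambda):
    # ElasticNet minimises ||y - Xw||^2/(2n) + a*r*|w|_1 + 0.5*a*(1-r)*|w|_2^2,
    # so equating penalty terms gives a = alpha + lambda, r = alpha / (alpha + lambda).
    return {'alpha': reg_alpha + reg_lambda,
            'l1_ratio': reg_alpha / (reg_alpha + reg_lambda)}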


def train_diabetes(param_in):
    from sklearn import datasets
    data = datasets.load_diabetes()
    dtrain = xgb.DMatrix(data.data, label=data.target)
    X = scale(data.data)
    dtrain = xgb.DMatrix(X, label=data.target)
    param = {}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
    check_ElasticNet(X, data.target, xgb_pred, 1e-2,
                     param['alpha'], param['lambda'],
                     xgb_get_weights(bst)[1:])


def train_breast_cancer(param_in):
    from sklearn import metrics, datasets
    data = datasets.load_breast_cancer()
    dtrain = xgb.DMatrix(data.data, label=data.target)
    X = scale(data.data)
    dtrain = xgb.DMatrix(X, label=data.target)
    param = {'objective': 'binary:logistic'}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
@@ -63,9 +67,8 @@ def train_breast_cancer(param_in):


def train_classification(param_in):
    from sklearn import metrics, datasets
    X, y = datasets.make_classification(random_state=rng,
                                        scale=100)  # Scale is necessary otherwise regularisation parameters will force all coefficients to 0
    X, y = datasets.make_classification(random_state=rng)
    X = scale(X)
    dtrain = xgb.DMatrix(X, label=y)
    param = {'objective': 'binary:logistic'}
    param.update(param_in)
@@ -76,10 +79,11 @@ def train_classification(param_in):


def train_classification_multi(param_in):
    from sklearn import metrics, datasets
    num_class = 3
    X, y = datasets.make_classification(n_samples=10, random_state=rng, scale=100, n_classes=num_class, n_informative=4,
    X, y = datasets.make_classification(n_samples=100, random_state=rng,
                                        n_classes=num_class, n_informative=4,
                                        n_features=4, n_redundant=0)
    X = scale(X)
    dtrain = xgb.DMatrix(X, label=y)
    param = {'objective': 'multi:softmax', 'num_class': num_class}
    param.update(param_in)
@@ -90,20 +94,42 @@ def train_classification_multi(param_in):


def train_boston(param_in):
    from sklearn import datasets
    data = datasets.load_boston()
    dtrain = xgb.DMatrix(data.data, label=data.target)
    X = scale(data.data)
    dtrain = xgb.DMatrix(X, label=data.target)
    param = {}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
    check_ElasticNet(X, data.target, xgb_pred, 1e-2,
                     param['alpha'], param['lambda'],
                     xgb_get_weights(bst)[1:])


def train_external_mem(param_in):
    data = datasets.load_boston()
    X = scale(data.data)
    y = data.target
    param = {}
    param.update(param_in)
    dtrain = xgb.DMatrix(X, label=y)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    np.savetxt('tmptmp_1234.csv', np.hstack((y.reshape(len(y), 1), X)),
               delimiter=',', fmt='%10.9f')
    dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred_ext = bst.predict(dtrain)
    assert np.abs(xgb_pred_ext - xgb_pred).max() < 1e-3
    del dtrain, bst
    for f in glob.glob("tmptmp_*"):
        os.remove(f)
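The external-memory round trip in train_external_mem uses the cache-suffix form of the DMatrix URI: the part before '#' names the data source (here a CSV with the label in column 0), the part after '#' is the prefix for the binary cache pages xgboost writes to disk. A standalone sketch of the same pattern, with made-up file names and arbitrary parameter values:

import numpy as np
import xgboost as xgb

X = np.random.rand(200, 5)
y = X.dot(np.arange(1, 6)) + 0.1 * np.random.randn(200)
# Label in the first column, features after it, matching label_column=0 below.
np.savetxt('demo_data.csv', np.hstack((y.reshape(-1, 1), X)), delimiter=',')

dtrain = xgb.DMatrix('demo_data.csv?format=csv&label_column=0#demo_cache')
bst = xgb.train({'booster': 'gblinear', 'updater': 'coord_descent',
                 'eta': 1.0, 'alpha': 0.005, 'lambda': 0.005},
                dtrain, num_boost_round=100)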


# Enumerates all permutations of variable parameters
def assert_updater_accuracy(linear_updater, variable_param):
    param = {'booster': 'gblinear', 'updater': linear_updater, 'tolerance': 1e-8}
    param = {'booster': 'gblinear', 'updater': linear_updater, 'eta': 1.,
             'top_k': 10, 'tolerance': 1e-5, 'nthread': 2}
    names = sorted(variable_param)
    combinations = it.product(*(variable_param[Name] for Name in names))

@@ -118,16 +144,17 @@ def assert_updater_accuracy(linear_updater, variable_param):
        train_classification(param_tmp)
        train_classification_multi(param_tmp)
        train_breast_cancer(param_tmp)
        train_external_mem(param_tmp)


class TestLinear(unittest.TestCase):
    def test_coordinate(self):
        tm._skip_if_no_sklearn()
        variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0],
                          'coordinate_selection': ['cyclic', 'random', 'greedy']}
        variable_param = {'alpha': [.005, .1], 'lambda': [.005],
                          'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']}
        assert_updater_accuracy('coord_descent', variable_param)

    def test_shotgun(self):
        tm._skip_if_no_sklearn()
        variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0]}
        variable_param = {'alpha': [.005, .1], 'lambda': [.005, .1]}
        assert_updater_accuracy('shotgun', variable_param)

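To exercise the options introduced in this change directly, the following example trains a gblinear booster with the coordinate-descent updater, the 'thrifty' feature selector and a top_k budget. The parameter values are arbitrary and only meant to illustrate the knobs touched by these tests, not recommended settings.

import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.preprocessing import scale

data = load_diabetes()
dtrain = xgb.DMatrix(scale(data.data), label=data.target)
params = {'booster': 'gblinear', 'updater': 'coord_descent',
          'feature_selector': 'thrifty', 'top_k': 3,
          'eta': 1.0, 'alpha': 0.01, 'lambda': 0.01}
bst = xgb.train(params, dtrain, num_boost_round=50)
print(bst.get_dump()[0])  # bias and per-feature weights of the fitted linear model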