diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION
index 3fd72c27b..5db56d07d 100644
--- a/R-package/DESCRIPTION
+++ b/R-package/DESCRIPTION
@@ -1,8 +1,8 @@
 Package: xgboost
 Type: Package
 Title: Extreme Gradient Boosting
-Version: 0.7.0
-Date: 2018-01-22
+Version: 0.7.0.1
+Date: 2018-02-25
 Author: Tianqi Chen, Tong He, Michael Benesty, Vadim Khotilovich, Yuan Tang
diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE
index e0fe76a58..aff4864fc 100644
--- a/R-package/NAMESPACE
+++ b/R-package/NAMESPACE
@@ -18,6 +18,7 @@ export("xgb.parameters<-")
 export(cb.cv.predict)
 export(cb.early.stop)
 export(cb.evaluation.log)
+export(cb.gblinear.history)
 export(cb.print.evaluation)
 export(cb.reset.parameters)
 export(cb.save.model)
@@ -32,6 +33,7 @@ export(xgb.attributes)
 export(xgb.create.features)
 export(xgb.cv)
 export(xgb.dump)
+export(xgb.gblinear.history)
 export(xgb.ggplot.deepness)
 export(xgb.ggplot.importance)
 export(xgb.importance)
@@ -52,7 +54,9 @@ importClassesFrom(Matrix,dgeMatrix)
 importFrom(Matrix,cBind)
 importFrom(Matrix,colSums)
 importFrom(Matrix,sparse.model.matrix)
+importFrom(Matrix,sparseMatrix)
 importFrom(Matrix,sparseVector)
+importFrom(Matrix,t)
 importFrom(data.table,":=")
 importFrom(data.table,as.data.table)
 importFrom(data.table,data.table)
diff --git a/R-package/R/callbacks.R b/R-package/R/callbacks.R
index 80a6c66ca..48977ff5e 100644
--- a/R-package/R/callbacks.R
+++ b/R-package/R/callbacks.R
@@ -524,6 +524,225 @@ cb.cv.predict <- function(save_models = FALSE) {
 }
 
+#' Callback closure for collecting the model coefficients history of a gblinear booster
+#' during its training.
+#'
+#' @param sparse when set to FALSE/TRUE, a dense/sparse matrix is used to store the result.
+#'        Sparse format is useful when one expects only a subset of coefficients to be non-zero,
+#'        e.g. when using the "thrifty" feature selector with a fairly small number of top features
+#'        selected per iteration.
+#'
+#' @details
+#' To keep things fast and simple, the gblinear booster does not internally store the history of
+#' linear model coefficients at each boosting iteration. This callback provides a workaround for
+#' storing the coefficients' path, by extracting them after each training iteration.
+#'
+#' The callback function expects the following values to be set in its calling frame:
+#' \code{bst} (or \code{bst_folds}).
+#'
+#' @return
+#' Results are stored in the \code{coefs} element of the closure.
+#' The \code{\link{xgb.gblinear.history}} convenience function provides an easy way to access it.
+#' With \code{xgb.train}, it is either a dense or a sparse matrix.
+#' With \code{xgb.cv}, it is a list (one element per CV fold) of such matrices.
+#'
+#' @seealso
+#' \code{\link{callbacks}}, \code{\link{xgb.gblinear.history}}.
+#'
+#' @examples
+#' #### Binary classification:
+#' #
+#' # In the iris dataset, it is hard to linearly separate the Versicolor class from the rest
+#' # without considering the 2nd order interactions:
+#' x <- model.matrix(Species ~ .^2, iris)[,-1]
+#' colnames(x)
+#' dtrain <- xgb.DMatrix(scale(x), label = 1*(iris$Species == "versicolor"))
+#' param <- list(booster = "gblinear", objective = "reg:logistic", eval_metric = "auc",
+#'               lambda = 0.0003, alpha = 0.0003, nthread = 2)
+#' # For 'shotgun', which is the default linear updater, using high eta values may result in
+#' # unstable behaviour in some datasets. With this simple dataset, however, the high learning
+#' # rate does not break the convergence, but allows us to illustrate the typical pattern of
+#' # "stochastic explosion" behaviour of this lock-free algorithm at early boosting iterations.
+#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 1.,
+#'                  callbacks = list(cb.gblinear.history()))
+#' # Extract the coefficients' path and plot them vs boosting iteration number:
+#' coef_path <- xgb.gblinear.history(bst)
+#' matplot(coef_path, type = 'l')
+#'
+#' # With the deterministic coordinate descent updater, it is safer to use higher learning rates.
+#' # Will try the classical componentwise boosting which selects a single best feature per round:
+#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 0.8,
+#'                  updater = 'coord_descent', feature_selector = 'thrifty', top_k = 1,
+#'                  callbacks = list(cb.gblinear.history()))
+#' xgb.gblinear.history(bst) %>% matplot(type = 'l')
+#' # Componentwise boosting is known to have a similar effect to Lasso regularization.
+#' # Try experimenting with various values of top_k, eta, nrounds,
+#' # as well as different feature_selectors.
+#'
+#' # For xgb.cv:
+#' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 100, eta = 0.8,
+#'               callbacks = list(cb.gblinear.history()))
+#' # coefficients in the CV fold #3
+#' xgb.gblinear.history(bst)[[3]] %>% matplot(type = 'l')
+#'
+#'
+#' #### Multiclass classification:
+#' #
+#' dtrain <- xgb.DMatrix(scale(x), label = as.numeric(iris$Species) - 1)
+#' param <- list(booster = "gblinear", objective = "multi:softprob", num_class = 3,
+#'               lambda = 0.0003, alpha = 0.0003, nthread = 2)
+#' # For the default linear updater 'shotgun' it is sometimes helpful
+#' # to use a smaller eta to reduce instability
+#' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 70, eta = 0.5,
+#'                  callbacks = list(cb.gblinear.history()))
+#' # Will plot the coefficient paths separately for each class:
+#' xgb.gblinear.history(bst, class_index = 0) %>% matplot(type = 'l')
+#' xgb.gblinear.history(bst, class_index = 1) %>% matplot(type = 'l')
+#' xgb.gblinear.history(bst, class_index = 2) %>% matplot(type = 'l')
+#'
+#' # CV:
+#' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 70, eta = 0.5,
+#'               callbacks = list(cb.gblinear.history(FALSE)))
+#' # 1st fold of 1st class
+#' xgb.gblinear.history(bst, class_index = 0)[[1]] %>% matplot(type = 'l')
+#'
+#' @export
+cb.gblinear.history <- function(sparse=FALSE) {
+  coefs <- NULL
+
+  init <- function(env) {
+    if (!is.null(env$bst)) { # xgb.train:
+      coef_path <- list()
+    } else if (!is.null(env$bst_folds)) { # xgb.cv:
+      coef_path <- rep(list(), length(env$bst_folds))
+    } else stop("Parent frame has neither 'bst' nor 'bst_folds'")
+  }
+
+  # convert from list to (sparse) matrix
+  list2mat <- function(coef_list) {
+    if (sparse) {
+      coef_mat <- sparseMatrix(x = unlist(lapply(coef_list, slot, "x")),
+                               i = unlist(lapply(coef_list, slot, "i")),
+                               p = c(0, cumsum(sapply(coef_list, function(x) length(x@x)))),
+                               dims = c(length(coef_list[[1]]), length(coef_list)))
+      return(t(coef_mat))
+    } else {
+      return(do.call(rbind, coef_list))
+    }
+  }
+
+  finalizer <- function(env) {
+    if (length(coefs) == 0)
+      return()
+    if (!is.null(env$bst)) { # xgb.train:
+      coefs <<- list2mat(coefs)
+    } else { # xgb.cv:
+      # first lapply transposes the list
+      coefs <<- lapply(seq_along(coefs[[1]]), function(i) lapply(coefs, "[[", i)) %>%
+        lapply(function(x) list2mat(x))
+    }
+  }
+
+  extract.coef <- function(env) {
+    if (!is.null(env$bst)) { # xgb.train:
+      cf <- as.numeric(grep('(booster|bias|weigh)', xgb.dump(env$bst), invert = TRUE, value = TRUE))
+      if (sparse) cf <- as(cf, "sparseVector")
+    } else { # xgb.cv:
+      cf <- vector("list", length(env$bst_folds))
+      for (i in seq_along(env$bst_folds)) {
+        dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst))
+        cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
+        if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
+      }
+    }
+    cf
+  }
+
+  callback <- function(env = parent.frame(), finalize = FALSE) {
+    if (is.null(coefs)) init(env)
+    if (finalize) return(finalizer(env))
+    cf <- extract.coef(env)
+    coefs <<- c(coefs, list(cf))
+  }
+
+  attr(callback, 'call') <- match.call()
+  attr(callback, 'name') <- 'cb.gblinear.history'
+  callback
+}
+
+#' Extract gblinear coefficients history.
+#'
+#' A helper function to extract the matrix of linear coefficients' history
+#' from a gblinear model created while using the \code{cb.gblinear.history()}
+#' callback.
+#'
+#' @param model either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained
+#'        using the \code{cb.gblinear.history()} callback.
+#' @param class_index zero-based class index to extract the coefficients for only that
+#'        specific class in a multinomial multiclass model. When it is NULL, all the
+#'        coefficients are returned. Has no effect in non-multiclass models.
+#'
+#' @return
+#' For an \code{xgb.train} result, a matrix (either dense or sparse) with the columns
+#' corresponding to the model's coefficients (in the same order as \code{xgb.dump()} would
+#' return them) and the rows corresponding to boosting iterations.
+#'
+#' For an \code{xgb.cv} result, a list of such matrices is returned with the elements
+#' corresponding to CV folds.
+#'
+#' @examples
+#' # See the examples in \code{\link{cb.gblinear.history}}.
+#'
+#' @export
+xgb.gblinear.history <- function(model, class_index = NULL) {
+
+  if (!(inherits(model, "xgb.Booster") ||
+        inherits(model, "xgb.cv.synchronous")))
+    stop("model must be an object of either xgb.Booster or xgb.cv.synchronous class")
+  is_cv <- inherits(model, "xgb.cv.synchronous")
+
+  if (is.null(model[["callbacks"]]) || is.null(model$callbacks[["cb.gblinear.history"]]))
+    stop("model must be trained while using the cb.gblinear.history() callback")
+
+  if (!is_cv) {
+    # extract num_class & num_feat from the internal model
+    dmp <- xgb.dump(model)
+    if (length(dmp) < 2 || dmp[2] != "bias:")
+      stop("It does not appear to be a gblinear model")
+    dmp <- dmp[-c(1, 2)]
+    n <- which(dmp == 'weight:')
+    if (length(n) != 1)
+      stop("It does not appear to be a gblinear model")
+    num_class <- n - 1
+    num_feat <- (length(dmp) - n) / num_class
+  } else {
+    # in case of CV, the object is expected to have this info
+    if (model$params$booster != "gblinear")
+      stop("It does not appear to be a gblinear model")
+    num_class <- NVL(model$params$num_class, 1)
+    num_feat <- model$nfeatures
+    if (is.null(num_feat))
+      stop("This xgb.cv result does not have nfeatures info")
+  }
+
+  if (!is.null(class_index) &&
+      num_class > 1 &&
+      (class_index[1] < 0 || class_index[1] >= num_class))
+    stop("class_index has to be within [0,", num_class - 1, "]")
+
+  coef_path <- environment(model$callbacks$cb.gblinear.history)[["coefs"]]
+  if (!is.null(class_index) && num_class > 1) {
+    coef_path <- if (is.list(coef_path)) {
+      lapply(coef_path,
+             function(x) x[, seq(1 + class_index, by = num_class, length.out = num_feat)])
+    } else {
+      coef_path <- coef_path[, seq(1 + class_index, by = num_class, length.out = num_feat)]
+    }
+  }
+  coef_path
+}
+
+
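In the multiclass case, gblinear lays the dumped coefficients out feature-major with the classes interleaved, which is what the seq(1 + class_index, by = num_class, length.out = num_feat) indexing above relies on. A minimal base-R sketch of that column layout (the matrix h below is a hypothetical stand-in for a collected history, introduced only for illustration):

# hypothetical history: 2 boosting iterations, 2 features, 3 classes;
# columns ordered as in xgb.dump():
# f0_class0, f0_class1, f0_class2, f1_class0, f1_class1, f1_class2
num_class <- 3
num_feat <- 2
h <- matrix(1:12, nrow = 2, byrow = TRUE)
class_index <- 1  # zero-based, as in xgb.gblinear.history()
h[, seq(1 + class_index, by = num_class, length.out = num_feat)]  # picks columns 2 and 5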
 #
 # Internal utility functions for callbacks ------------------------------------
 #
diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R
index 4b6a699d6..54c9f2d0b 100644
--- a/R-package/R/xgb.cv.R
+++ b/R-package/R/xgb.cv.R
@@ -88,6 +88,7 @@
 #'         CV-based evaluation means and standard deviations for the training and test CV-sets.
 #'         It is created by the \code{\link{cb.evaluation.log}} callback.
 #'   \item \code{niter} number of boosting iterations.
+#'   \item \code{nfeatures} number of features in training data.
 #'   \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
 #'         parameter or randomly generated.
 #'   \item \code{best_iteration} iteration number with the best evaluation metric value
@@ -184,6 +185,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
     handle <- xgb.Booster.handle(params, list(dtrain, dtest))
     list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
   })
+  rm(dall)
 
   # a "basket" to collect some results from callbacks
   basket <- list()
@@ -221,6 +223,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
               callbacks = callbacks,
               evaluation_log = evaluation_log,
               niter = end_iteration,
+              nfeatures = ncol(data),
               folds = folds
   )
   ret <- c(ret, basket)
diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R
index ab39fa0e9..26e6bc737 100644
--- a/R-package/R/xgb.train.R
+++ b/R-package/R/xgb.train.R
@@ -162,6 +162,7 @@
 #'         (only available with early stopping).
 #'   \item \code{feature_names} names of the training dataset features
 #'         (only when column names were defined in training data).
+#'   \item \code{nfeatures} number of features in training data.
 #' }
 #'
 #' @seealso
@@ -363,6 +364,7 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
   bst$callbacks <- callbacks
   if (!is.null(colnames(dtrain)))
     bst$feature_names <- colnames(dtrain)
-
+  bst$nfeatures <- ncol(dtrain)
+
   return(bst)
 }
diff --git a/R-package/R/xgboost.R b/R-package/R/xgboost.R
index 6991a0f83..c76e8b14e 100644
--- a/R-package/R/xgboost.R
+++ b/R-package/R/xgboost.R
@@ -81,6 +81,8 @@ NULL
 #' @importFrom Matrix colSums
 #' @importFrom Matrix sparse.model.matrix
 #' @importFrom Matrix sparseVector
+#' @importFrom Matrix sparseMatrix
+#' @importFrom Matrix t
 #' @importFrom data.table data.table
 #' @importFrom data.table is.data.table
 #' @importFrom data.table as.data.table
diff --git a/R-package/man/xgb.cv.Rd b/R-package/man/xgb.cv.Rd
index 19bba4fdc..31d41324a 100644
--- a/R-package/man/xgb.cv.Rd
+++ b/R-package/man/xgb.cv.Rd
@@ -104,6 +104,7 @@ An object of class \code{xgb.cv.synchronous} with the following elements:
         CV-based evaluation means and standard deviations for the training and test CV-sets.
         It is created by the \code{\link{cb.evaluation.log}} callback.
   \item \code{niter} number of boosting iterations.
+  \item \code{nfeatures} number of features in training data.
   \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
         parameter or randomly generated.
   \item \code{best_iteration} iteration number with the best evaluation metric value
diff --git a/R-package/man/xgb.train.Rd b/R-package/man/xgb.train.Rd
index a4776f4fd..b93298911 100644
--- a/R-package/man/xgb.train.Rd
+++ b/R-package/man/xgb.train.Rd
@@ -155,6 +155,7 @@ An object of class \code{xgb.Booster} with the following elements:
         (only available with early stopping).
   \item \code{feature_names} names of the training dataset features
         (only when column names were defined in training data).
+  \item \code{nfeatures} number of features in training data.
 }
 }
 \description{
diff --git a/R-package/tests/testthat/test_glm.R b/R-package/tests/testthat/test_glm.R
index dc7b6efab..7293e1ada 100644
--- a/R-package/tests/testthat/test_glm.R
+++ b/R-package/tests/testthat/test_glm.R
@@ -2,18 +2,47 @@ context('Test generalized linear models')
 
 require(xgboost)
 
-test_that("glm works", {
+test_that("gblinear works", {
   data(agaricus.train, package='xgboost')
   data(agaricus.test, package='xgboost')
   dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
   dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
-  expect_equal(class(dtrain), "xgb.DMatrix")
-  expect_equal(class(dtest), "xgb.DMatrix")
+
   param <- list(objective = "binary:logistic", booster = "gblinear",
-                nthread = 2, alpha = 0.0001, lambda = 1)
+                nthread = 2, eta = 0.8, alpha = 0.0001, lambda = 0.0001)
   watchlist <- list(eval = dtest, train = dtrain)
-  num_round <- 2
-  bst <- xgb.train(param, dtrain, num_round, watchlist)
+
+  n <- 5          # iterations
+  ERR_UL <- 0.005 # upper limit for the test set error
+  VERB <- 0       # chatterbox switch
+
+  param$updater <- 'shotgun'
+  bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle')
   ypred <- predict(bst, dtest)
   expect_equal(length(getinfo(dtest, 'label')), 1611)
+  expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
+
+  bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic',
+                   callbacks = list(cb.gblinear.history()))
+  expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
+  h <- xgb.gblinear.history(bst)
+  expect_equal(dim(h), c(n, ncol(dtrain) + 1))
+  expect_is(h, "matrix")
+
+  param$updater <- 'coord_descent'
+  bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic')
+  expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
+
+  bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle')
+  expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
+
+  bst <- xgb.train(param, dtrain, 2, watchlist, verbose = VERB, feature_selector = 'greedy')
+  expect_lt(bst$evaluation_log$eval_error[2], ERR_UL)
+
+  bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty',
+                   top_k = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
+  expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
+  h <- xgb.gblinear.history(bst)
+  expect_equal(dim(h), c(n, ncol(dtrain) + 1))
+  expect_s4_class(h, "dgCMatrix")
 })
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index 8ad2edc6a..f1cc48e57 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -119,7 +119,7 @@ ColIterator(const std::vector<bst_uint>& fset) {
 }
 
-bool SparsePageDMatrix::TryInitColData() {
+bool SparsePageDMatrix::TryInitColData(bool sorted) {
   // load meta data.
   std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
   {
@@ -140,6 +140,8 @@ bool SparsePageDMatrix::TryInitColData() {
     files.push_back(std::move(fdata));
   }
   col_iter_.reset(new ColPageIter(std::move(files)));
+  // warning: no attempt to check here whether the cached data was sorted
+  col_iter_->sorted = sorted;
   return true;
 }
 
@@ -147,7 +149,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
                                       float pkeep, size_t max_row_perbatch,
                                       bool sorted) {
   if (HaveColAccess(sorted)) return;
-  if (TryInitColData()) return;
+  if (TryInitColData(sorted)) return;
   const MetaInfo& info = this->info();
   if (max_row_perbatch == std::numeric_limits<size_t>::max()) {
     max_row_perbatch = kMaxRowPerBatch;
   }
@@ -291,8 +293,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
     fo.reset(nullptr);
   }
   // initialize column data
-  CHECK(TryInitColData());
-  col_iter_->sorted = sorted;
+  CHECK(TryInitColData(sorted));
 }
 }  // namespace data
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index 4c99e72cc..597a223b9 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -116,7 +116,7 @@ class SparsePageDMatrix : public DMatrix {
    * \brief Try to initialize column data.
    * \return true if data already exists, false if they do not.
    */
-  bool TryInitColData();
+  bool TryInitColData(bool sorted);
   // source data pointer.
   std::unique_ptr<DataSource> source_;
   // the cache prefix
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index dde5231c5..6e14c2e6b 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -21,14 +21,12 @@ namespace gbm {
 
 DMLC_REGISTRY_FILE_TAG(gblinear);
 
-// training parameter
+// training parameters
 struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
-  /*! \brief learning_rate */
   std::string updater;
-  // flag to print out detailed breakdown of runtime
-  int debug_verbose;
   float tolerance;
-  // declare parameters
+  size_t max_row_perbatch;
+  int debug_verbose;
   DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
     DMLC_DECLARE_FIELD(updater)
        .set_default("shotgun")
@@ -37,6 +35,9 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
        .set_lower_bound(0.0f)
        .set_default(0.0f)
        .describe("Stop if largest weight update is smaller than this number.");
+    DMLC_DECLARE_FIELD(max_row_perbatch)
+       .set_default(std::numeric_limits<size_t>::max())
+       .describe("Maximum rows per batch.");
     DMLC_DECLARE_FIELD(debug_verbose)
        .set_lower_bound(0)
        .set_default(0)
        .describe("flag to print out detailed breakdown of runtime");
@@ -84,12 +85,10 @@ class GBLinear : public GradientBooster {
     if (!p_fmat->HaveColAccess(false)) {
       std::vector<bool> enabled(p_fmat->info().num_col, true);
-      p_fmat->InitColAccess(enabled, 1.0f, std::numeric_limits<size_t>::max(),
-                            false);
+      p_fmat->InitColAccess(enabled, 1.0f, param.max_row_perbatch, false);
     }
     model.LazyInitModel();
-
     this->LazySumWeights(p_fmat);
 
     if (!this->CheckConvergence()) {
@@ -191,40 +190,7 @@ class GBLinear : public GradientBooster {
   std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
                                      std::string format) const override {
-    const int ngroup = model.param.num_output_group;
-    const unsigned nfeature = model.param.num_feature;
-
-    std::stringstream fo("");
-    if (format == "json") {
-      fo << " { \"bias\": [" << std::endl;
-      for (int gid = 0; gid < ngroup; ++gid) {
-        if (gid != 0) fo << "," << std::endl;
-        fo << " " << model.bias()[gid];
-      }
-      fo << std::endl << " ]," << std::endl
-         << " \"weight\": [" << std::endl;
-      for (unsigned i = 0; i < nfeature; ++i) {
-        for (int gid = 0; gid < ngroup; ++gid) {
-          if (i != 0 || gid != 0) fo << "," << std::endl;
-          fo << " " << model[i][gid];
-        }
-      }
fo << std::endl << " ]" << std::endl << " }"; - } else { - fo << "bias:\n"; - for (int gid = 0; gid < ngroup; ++gid) { - fo << model.bias()[gid] << std::endl; - } - fo << "weight:\n"; - for (unsigned i = 0; i < nfeature; ++i) { - for (int gid = 0; gid < ngroup; ++gid) { - fo << model[i][gid] << std::endl; - } - } - } - std::vector v; - v.push_back(fo.str()); - return v; + return model.DumpModel(fmap, with_stats, format); } protected: @@ -272,9 +238,12 @@ class GBLinear : public GradientBooster { bool CheckConvergence() { if (param.tolerance == 0.0f) return false; if (is_converged) return true; - if (previous_model.weight.size() != model.weight.size()) return false; + if (previous_model.weight.size() != model.weight.size()) { + previous_model = model; + return false; + } float largest_dw = 0.0; - for (auto i = 0; i < model.weight.size(); i++) { + for (size_t i = 0; i < model.weight.size(); i++) { largest_dw = std::max( largest_dw, std::abs(model.weight[i] - previous_model.weight[i])); } @@ -287,7 +256,7 @@ class GBLinear : public GradientBooster { void LazySumWeights(DMatrix *p_fmat) { if (!sum_weight_complete) { auto &info = p_fmat->info(); - for (int i = 0; i < info.num_row; i++) { + for (size_t i = 0; i < info.num_row; i++) { sum_instance_weight += info.GetWeight(i); } sum_weight_complete = true; diff --git a/src/gbm/gblinear_model.h b/src/gbm/gblinear_model.h index 72fcedb80..10e4ffe0c 100644 --- a/src/gbm/gblinear_model.h +++ b/src/gbm/gblinear_model.h @@ -4,7 +4,9 @@ #pragma once #include #include +#include #include +#include #include namespace xgboost { @@ -68,6 +70,44 @@ class GBLinearModel { inline const bst_float* operator[](size_t i) const { return &weight[i * param.num_output_group]; } + + std::vector DumpModel(const FeatureMap& fmap, bool with_stats, + std::string format) const { + const int ngroup = param.num_output_group; + const unsigned nfeature = param.num_feature; + + std::stringstream fo(""); + if (format == "json") { + fo << " { \"bias\": [" << std::endl; + for (int gid = 0; gid < ngroup; ++gid) { + if (gid != 0) fo << "," << std::endl; + fo << " " << this->bias()[gid]; + } + fo << std::endl << " ]," << std::endl + << " \"weight\": [" << std::endl; + for (unsigned i = 0; i < nfeature; ++i) { + for (int gid = 0; gid < ngroup; ++gid) { + if (i != 0 || gid != 0) fo << "," << std::endl; + fo << " " << (*this)[i][gid]; + } + } + fo << std::endl << " ]" << std::endl << " }"; + } else { + fo << "bias:\n"; + for (int gid = 0; gid < ngroup; ++gid) { + fo << this->bias()[gid] << std::endl; + } + fo << "weight:\n"; + for (unsigned i = 0; i < nfeature; ++i) { + for (int gid = 0; gid < ngroup; ++gid) { + fo << (*this)[i][gid] << std::endl; + } + } + } + std::vector v; + v.push_back(fo.str()); + return v; + } }; } // namespace gbm } // namespace xgboost diff --git a/src/linear/coordinate_common.h b/src/linear/coordinate_common.h index 41955e4c7..141bb68a1 100644 --- a/src/linear/coordinate_common.h +++ b/src/linear/coordinate_common.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "../common/random.h" namespace xgboost { @@ -19,26 +20,21 @@ namespace linear { * \param sum_grad The sum gradient. * \param sum_hess The sum hess. * \param w The weight. - * \param reg_lambda Unnormalised L2 penalty. * \param reg_alpha Unnormalised L1 penalty. - * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty. + * \param reg_lambda Unnormalised L2 penalty. * * \return The weight update. 
 */
-
 inline double CoordinateDelta(double sum_grad, double sum_hess, double w,
-                              double reg_lambda, double reg_alpha,
-                              double sum_instance_weight) {
-  reg_alpha *= sum_instance_weight;
-  reg_lambda *= sum_instance_weight;
+                              double reg_alpha, double reg_lambda) {
   if (sum_hess < 1e-5f) return 0.0f;
-  double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda);
+  const double sum_grad_l2 = sum_grad + reg_lambda * w;
+  const double sum_hess_l2 = sum_hess + reg_lambda;
+  const double tmp = w - sum_grad_l2 / sum_hess_l2;
   if (tmp >= 0) {
-    return std::max(
-        -(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w);
+    return std::max(-(sum_grad_l2 + reg_alpha) / sum_hess_l2, -w);
   } else {
-    return std::min(
-        -(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w);
+    return std::min(-(sum_grad_l2 - reg_alpha) / sum_hess_l2, -w);
  }
 }
 
@@ -50,7 +46,6 @@ inline double CoordinateDelta(double sum_grad, double sum_hess, double w,
 *
 * \return The weight update.
 */
-
 inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
   return -sum_grad / sum_hess;
 }
@@ -66,15 +61,14 @@ inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
 *
 * \return The gradient and diagonal Hessian entry for a given feature.
 */
-
-inline std::pair<double, double> GetGradient(
-    int group_idx, int num_group, int fidx, const std::vector<bst_gpair> &gpair,
-    DMatrix *p_fmat) {
+inline std::pair<double, double> GetGradient(int group_idx, int num_group, int fidx,
+                                             const std::vector<bst_gpair> &gpair,
+                                             DMatrix *p_fmat) {
   double sum_grad = 0.0, sum_hess = 0.0;
-  dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
+  dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
   while (iter->Next()) {
     const ColBatch &batch = iter->Value();
-    ColBatch::Inst col = batch[fidx];
+    ColBatch::Inst col = batch[0];
     const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
     for (bst_omp_uint j = 0; j < ndata; ++j) {
       const bst_float v = col[j].fvalue;
@@ -88,7 +82,7 @@ inline std::pair<double, double> GetGradient(
 }
 
 /**
- * \brief Get the gradient with respect to a single feature. Multithreaded.
+ * \brief Get the gradient with respect to a single feature. Row-wise multithreaded.
 *
 * \param group_idx Zero-based index of the group.
 * \param num_group Number of groups.
@@ -98,16 +92,14 @@ inline std::pair<double, double> GetGradient(
 *
 * \return The gradient and diagonal Hessian entry for a given feature.
 */
-
-inline std::pair<double, double> GetGradientParallel(
-    int group_idx, int num_group, int fidx,
-
-    const std::vector<bst_gpair> &gpair, DMatrix *p_fmat) {
+inline std::pair<double, double> GetGradientParallel(int group_idx, int num_group, int fidx,
+                                                     const std::vector<bst_gpair> &gpair,
+                                                     DMatrix *p_fmat) {
   double sum_grad = 0.0, sum_hess = 0.0;
-  dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
+  dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
   while (iter->Next()) {
     const ColBatch &batch = iter->Value();
-    ColBatch::Inst col = batch[fidx];
+    ColBatch::Inst col = batch[0];
     const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
 #pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
     for (bst_omp_uint j = 0; j < ndata; ++j) {
@@ -122,7 +114,7 @@ inline std::pair<double, double> GetGradientParallel(
 }
 
 /**
- * \brief Get the gradient with respect to the bias. Multithreaded.
+ * \brief Get the gradient with respect to the bias. Row-wise multithreaded.
 *
 * \param group_idx Zero-based index of the group.
 * \param num_group Number of groups.
@@ -131,10 +123,9 @@ inline std::pair<double, double> GetGradientParallel(
 *
 * \return The gradient and diagonal Hessian entry for the bias.
 */
-
-inline std::pair<double, double> GetBiasGradientParallel(
-    int group_idx, int num_group, const std::vector<bst_gpair> &gpair,
-    DMatrix *p_fmat) {
+inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_group,
+                                                         const std::vector<bst_gpair> &gpair,
+                                                         DMatrix *p_fmat) {
   const RowSet &rowset = p_fmat->buffered_rowset();
   double sum_grad = 0.0, sum_hess = 0.0;
   const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
@@ -159,15 +150,14 @@ inline std::pair<double, double> GetBiasGradientParallel(
 * \param in_gpair The gradient vector to be updated.
 * \param p_fmat The input feature matrix.
 */
-
 inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
                                    float dw, std::vector<bst_gpair> *in_gpair,
                                    DMatrix *p_fmat) {
   if (dw == 0.0f) return;
-  dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
+  dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
   while (iter->Next()) {
     const ColBatch &batch = iter->Value();
-    ColBatch::Inst col = batch[fidx];
+    ColBatch::Inst col = batch[0];
     // update grad value
     const bst_omp_uint num_row = static_cast<bst_omp_uint>(col.length);
 #pragma omp parallel for schedule(static)
@@ -188,9 +178,7 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
 * \param in_gpair The gradient vector to be updated.
 * \param p_fmat The input feature matrix.
 */
-
-inline void UpdateBiasResidualParallel(int group_idx, int num_group,
-                                       float dbias,
+inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
                                        std::vector<bst_gpair> *in_gpair,
                                        DMatrix *p_fmat) {
   if (dbias == 0.0f) return;
@@ -205,114 +193,292 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
 }
 
 /**
- * \class FeatureSelector
- *
- * \brief Abstract class for stateful feature selection in coordinate descent
- * algorithms.
+ * \brief Abstract class for stateful feature selection or ordering
+ * in coordinate descent algorithms.
 */
-
 class FeatureSelector {
  public:
-  static FeatureSelector *Create(std::string name);
+  /*! \brief factory method */
+  static FeatureSelector *Create(int choice);
   /*! \brief virtual destructor */
   virtual ~FeatureSelector() {}
-
+  /**
+   * \brief Setting up the selector state prior to looping through features.
+   *
+   * \param model The model.
+   * \param gpair The gpair.
+   * \param p_fmat The feature matrix.
+   * \param alpha Regularisation alpha.
+   * \param lambda Regularisation lambda.
+   * \param param A parameter with algorithm-dependent use.
+   */
+  virtual void Setup(const gbm::GBLinearModel &model,
+                     const std::vector<bst_gpair> &gpair,
+                     DMatrix *p_fmat,
+                     float alpha, float lambda, int param) {}
   /**
   * \brief Select next coordinate to update.
   *
-   * \param iteration The iteration.
-   * \param model The model.
-   * \param group_idx Zero-based index of the group.
-   * \param gpair The gpair.
-   * \param p_fmat The feature matrix.
-   * \param alpha Regularisation alpha.
-   * \param lambda Regularisation lambda.
-   * \param sum_instance_weight The sum instance weight.
+   * \param iteration The iteration in a loop through features
+   * \param model The model.
+   * \param group_idx Zero-based index of the group.
+   * \param gpair The gpair.
+   * \param p_fmat The feature matrix.
+   * \param alpha Regularisation alpha.
+   * \param lambda Regularisation lambda.
   *
-   * \return The index of the selected feature. -1 indicates the bias term.
+   * \return The index of the selected feature. -1 indicates none selected.
   */
-
-  virtual int SelectNextFeature(int iteration,
-                                const gbm::GBLinearModel &model,
-                                int group_idx,
-                                const std::vector<bst_gpair> &gpair,
-                                DMatrix *p_fmat, float alpha, float lambda,
-                                double sum_instance_weight) = 0;
+  virtual int NextFeature(int iteration,
+                          const gbm::GBLinearModel &model,
+                          int group_idx,
+                          const std::vector<bst_gpair> &gpair,
+                          DMatrix *p_fmat, float alpha, float lambda) = 0;
 };
 
 /**
- * \class CyclicFeatureSelector
- *
- * \brief Deterministic selection by cycling through coordinates one at a time.
+ * \brief Deterministic selection by cycling through features one at a time.
 */
-
 class CyclicFeatureSelector : public FeatureSelector {
 public:
-  int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
-                        int group_idx, const std::vector<bst_gpair> &gpair,
-                        DMatrix *p_fmat, float alpha, float lambda,
-                        double sum_instance_weight) override {
+  int NextFeature(int iteration, const gbm::GBLinearModel &model,
+                  int group_idx, const std::vector<bst_gpair> &gpair,
+                  DMatrix *p_fmat, float alpha, float lambda) override {
    return iteration % model.param.num_feature;
  }
 };
 
 /**
- * \class RandomFeatureSelector
- *
- * \brief A random coordinate selector.
+ * \brief Similar to Cyclic but with random feature shuffling prior to each update.
+ * \note Its randomness is controllable by setting a random seed.
 */
+class ShuffleFeatureSelector : public FeatureSelector {
+ public:
+  void Setup(const gbm::GBLinearModel &model,
+             const std::vector<bst_gpair> &gpair,
+             DMatrix *p_fmat, float alpha, float lambda, int param) override {
+    if (feat_index.size() == 0) {
+      feat_index.resize(model.param.num_feature);
+      std::iota(feat_index.begin(), feat_index.end(), 0);
+    }
+    std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
+  }
+
+  int NextFeature(int iteration, const gbm::GBLinearModel &model,
+                  int group_idx, const std::vector<bst_gpair> &gpair,
+                  DMatrix *p_fmat, float alpha, float lambda) override {
+    return feat_index[iteration % model.param.num_feature];
+  }
+
+ protected:
+  std::vector<bst_uint> feat_index;
+};
+
+/**
+ * \brief A random (with replacement) coordinate selector.
+ * \note Its randomness is controllable by setting a random seed.
+ */
 class RandomFeatureSelector : public FeatureSelector {
 public:
-  int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
-                        int group_idx, const std::vector<bst_gpair> &gpair,
-                        DMatrix *p_fmat, float alpha, float lambda,
-                        double sum_instance_weight) override {
+  int NextFeature(int iteration, const gbm::GBLinearModel &model,
+                  int group_idx, const std::vector<bst_gpair> &gpair,
+                  DMatrix *p_fmat, float alpha, float lambda) override {
    return common::GlobalRandom()() % model.param.num_feature;
  }
 };
 
 /**
- * \class GreedyFeatureSelector
- *
 * \brief Select coordinate with the greatest gradient magnitude.
+ * \note It has O(num_feature^2) complexity. It is fully deterministic.
+ *
+ * \note It allows restricting the selection to top_k features per group with
+ * the largest magnitude of univariate weight change, by passing the top_k value
+ * through the `param` argument of Setup(). That would reduce the complexity to
+ * O(num_feature*top_k).
 */
-
 class GreedyFeatureSelector : public FeatureSelector {
 public:
-  int SelectNextFeature(int iteration, const gbm::GBLinearModel &model,
-                        int group_idx, const std::vector<bst_gpair> &gpair,
-                        DMatrix *p_fmat, float alpha, float lambda,
-                        double sum_instance_weight) override {
-    // Find best
+  void Setup(const gbm::GBLinearModel &model,
+             const std::vector<bst_gpair> &gpair,
+             DMatrix *p_fmat, float alpha, float lambda, int param) override {
+    top_k = static_cast<bst_uint>(param);
+    const bst_uint ngroup = model.param.num_output_group;
+    if (param <= 0) top_k = std::numeric_limits<bst_uint>::max();
+    if (counter.size() == 0) {
+      counter.resize(ngroup);
+      gpair_sums.resize(model.param.num_feature * ngroup);
+    }
+    for (bst_uint gid = 0u; gid < ngroup; ++gid) {
+      counter[gid] = 0u;
+    }
+  }
+
+  int NextFeature(int iteration, const gbm::GBLinearModel &model,
+                  int group_idx, const std::vector<bst_gpair> &gpair,
+                  DMatrix *p_fmat, float alpha, float lambda) override {
+    // k-th selected feature for a group
+    auto k = counter[group_idx]++;
+    // stop after either reaching top-K or going through all the features in a group
+    if (k >= top_k || counter[group_idx] == model.param.num_feature) return -1;
+
+    const int ngroup = model.param.num_output_group;
+    const bst_omp_uint nfeat = model.param.num_feature;
+    // Calculate univariate gradient sums
+    std::fill(gpair_sums.begin(), gpair_sums.end(), std::make_pair(0., 0.));
+    dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
+    while (iter->Next()) {
+      const ColBatch &batch = iter->Value();
+      #pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nfeat; ++i) {
+        const ColBatch::Inst col = batch[i];
+        const bst_uint ndata = col.length;
+        auto &sums = gpair_sums[group_idx * nfeat + i];
+        for (bst_uint j = 0u; j < ndata; ++j) {
+          const bst_float v = col[j].fvalue;
+          auto &p = gpair[col[j].index * ngroup + group_idx];
+          if (p.GetHess() < 0.f) continue;
+          sums.first += p.GetGrad() * v;
+          sums.second += p.GetHess() * v * v;
+        }
+      }
+    }
+    // Find a feature with the largest magnitude of weight change
    int best_fidx = 0;
    double best_weight_update = 0.0f;
-
-    for (auto fidx = 0U; fidx < model.param.num_feature; fidx++) {
-      const float w = model[fidx][group_idx];
-      auto gradient = GetGradientParallel(
-          group_idx, model.param.num_output_group, fidx, gpair, p_fmat);
-      float dw = static_cast<float>(
-          CoordinateDelta(gradient.first, gradient.second, w, lambda, alpha,
-                          sum_instance_weight));
-      if (std::abs(dw) > std::abs(best_weight_update)) {
+    for (bst_omp_uint fidx = 0; fidx < nfeat; ++fidx) {
+      auto &s = gpair_sums[group_idx * nfeat + fidx];
+      float dw = std::abs(static_cast<bst_float>(
+          CoordinateDelta(s.first, s.second, model[fidx][group_idx], alpha, lambda)));
+      if (dw > best_weight_update) {
        best_weight_update = dw;
        best_fidx = fidx;
      }
    }
    return best_fidx;
  }
+
+ protected:
+  bst_uint top_k;
+  std::vector<bst_uint> counter;
+  std::vector<std::pair<double, double>> gpair_sums;
 };
 
+/**
+ * \brief Thrifty, approximately-greedy feature selector.
+ *
+ * \note Prior to cyclic updates, reorders features in descending magnitude of
+ * their univariate weight changes. This operation is multithreaded and is a
+ * linear complexity approximation of the quadratic greedy selection.
+ *
+ * \note It allows restricting the selection to top_k features per group with
+ * the largest magnitude of univariate weight change, by passing the top_k value
+ * through the `param` argument of Setup().
+ */
+class ThriftyFeatureSelector : public FeatureSelector {
+ public:
+  void Setup(const gbm::GBLinearModel &model,
+             const std::vector<bst_gpair> &gpair,
+             DMatrix *p_fmat, float alpha, float lambda, int param) override {
+    top_k = static_cast<bst_uint>(param);
+    if (param <= 0) top_k = std::numeric_limits<bst_uint>::max();
+    const bst_uint ngroup = model.param.num_output_group;
+    const bst_omp_uint nfeat = model.param.num_feature;
+
+    if (deltaw.size() == 0) {
+      deltaw.resize(nfeat * ngroup);
+      sorted_idx.resize(nfeat * ngroup);
+      counter.resize(ngroup);
+      gpair_sums.resize(nfeat * ngroup);
+    }
+    // Calculate univariate gradient sums
+    std::fill(gpair_sums.begin(), gpair_sums.end(), std::make_pair(0., 0.));
+    dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
+    while (iter->Next()) {
+      const ColBatch &batch = iter->Value();
+      // column-parallel is usually faster than row-parallel
+      #pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nfeat; ++i) {
+        const ColBatch::Inst col = batch[i];
+        const bst_uint ndata = col.length;
+        for (bst_uint gid = 0u; gid < ngroup; ++gid) {
+          auto &sums = gpair_sums[gid * nfeat + i];
+          for (bst_uint j = 0u; j < ndata; ++j) {
+            const bst_float v = col[j].fvalue;
+            auto &p = gpair[col[j].index * ngroup + gid];
+            if (p.GetHess() < 0.f) continue;
+            sums.first += p.GetGrad() * v;
+            sums.second += p.GetHess() * v * v;
+          }
+        }
+      }
+    }
+    // rank by descending weight magnitude within the groups
+    std::fill(deltaw.begin(), deltaw.end(), 0.f);
+    std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
+    bst_float *pdeltaw = &deltaw[0];
+    for (bst_uint gid = 0u; gid < ngroup; ++gid) {
+      // Calculate univariate weight changes
+      for (bst_omp_uint i = 0; i < nfeat; ++i) {
+        auto ii = gid * nfeat + i;
+        auto &s = gpair_sums[ii];
+        deltaw[ii] = static_cast<bst_float>(CoordinateDelta(
+            s.first, s.second, model[i][gid], alpha, lambda));
+      }
+      // sort in descending order of deltaw abs values
+      auto start = sorted_idx.begin() + gid * nfeat;
+      std::sort(start, start + nfeat,
+                [pdeltaw](size_t i, size_t j) {
+                  return std::abs(*(pdeltaw + i)) > std::abs(*(pdeltaw + j));
+                });
+      counter[gid] = 0u;
+    }
+  }
+
+  int NextFeature(int iteration, const gbm::GBLinearModel &model,
+                  int group_idx, const std::vector<bst_gpair> &gpair,
+                  DMatrix *p_fmat, float alpha, float lambda) override {
+    // k-th selected feature for a group
+    auto k = counter[group_idx]++;
+    // stop after either reaching top_k or going through all the features in a group
+    if (k >= top_k || counter[group_idx] == model.param.num_feature) return -1;
+    // note that sorted_idx stores the "long" indices
+    const size_t grp_offset = group_idx * model.param.num_feature;
+    return static_cast<int>(sorted_idx[grp_offset + k] - grp_offset);
+  }
+
+ protected:
+  bst_uint top_k;
+  std::vector<bst_float> deltaw;
+  std::vector<size_t> sorted_idx;
+  std::vector<bst_uint> counter;
+  std::vector<std::pair<double, double>> gpair_sums;
+};
+
+/**
+ * \brief A set of available FeatureSelector's
+ */
+enum FeatureSelectorEnum {
+  kCyclic = 0,
+  kShuffle,
+  kThrifty,
+  kGreedy,
+  kRandom
+};
+
-inline FeatureSelector *FeatureSelector::Create(std::string name) {
-  if (name == "cyclic") {
-    return new CyclicFeatureSelector();
-  } else if (name == "random") {
-    return new RandomFeatureSelector();
-  } else if (name == "greedy") {
-    return new GreedyFeatureSelector();
-  } else {
-    LOG(FATAL) << name << ": unknown coordinate selector";
+inline FeatureSelector *FeatureSelector::Create(int choice) {
+  switch (choice) {
+    case kCyclic:
+      return new CyclicFeatureSelector();
+    case kShuffle:
+      return new ShuffleFeatureSelector();
+    case kThrifty:
+      return new ThriftyFeatureSelector();
+    case kGreedy:
+      return new GreedyFeatureSelector();
+    case kRandom:
+      return new RandomFeatureSelector();
+    default:
+      LOG(FATAL) << "unknown coordinate selector: " << choice;
   }
   return nullptr;
 }
diff --git a/src/linear/updater_coordinate.cc b/src/linear/updater_coordinate.cc
index 4f8a58b55..4caf37ca0 100644
--- a/src/linear/updater_coordinate.cc
+++ b/src/linear/updater_coordinate.cc
@@ -20,8 +20,8 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
   float reg_lambda;
   /*! \brief regularization weight for L1 norm */
   float reg_alpha;
-  std::string feature_selector;
-  float maximum_weight;
+  int feature_selector;
+  int top_k;
   int debug_verbose;
   // declare parameters
   DMLC_DECLARE_PARAMETER(CoordinateTrainParam) {
@@ -38,17 +38,35 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
        .set_default(0.0f)
        .describe("L1 regularization on weights.");
     DMLC_DECLARE_FIELD(feature_selector)
-        .set_default("cyclic")
-        .describe(
-            "Feature selection algorithm, one of cyclic/random/greedy");
+        .set_default(kCyclic)
+        .add_enum("cyclic", kCyclic)
+        .add_enum("shuffle", kShuffle)
+        .add_enum("thrifty", kThrifty)
+        .add_enum("greedy", kGreedy)
+        .add_enum("random", kRandom)
+        .describe("Feature selection or ordering method.");
+    DMLC_DECLARE_FIELD(top_k)
+        .set_lower_bound(0)
+        .set_default(0)
+        .describe("The number of top features to select in 'thrifty' feature_selector. "
+                  "The value of zero means using all the features.");
     DMLC_DECLARE_FIELD(debug_verbose)
        .set_lower_bound(0)
        .set_default(0)
        .describe("flag to print out detailed breakdown of runtime");
     // alias of parameters
+    DMLC_DECLARE_ALIAS(learning_rate, eta);
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
     DMLC_DECLARE_ALIAS(reg_alpha, alpha);
   }
+  /*! \brief Denormalizes the regularization penalties - to be called at each update */
+  void DenormalizePenalties(double sum_instance_weight) {
+    reg_lambda_denorm = reg_lambda * sum_instance_weight;
+    reg_alpha_denorm = reg_alpha * sum_instance_weight;
+  }
+  // denormalized regularization penalties
+  float reg_lambda_denorm;
+  float reg_alpha_denorm;
 };
 
 /**
@@ -66,47 +84,47 @@ class CoordinateUpdater : public LinearUpdater {
     selector.reset(FeatureSelector::Create(param.feature_selector));
     monitor.Init("CoordinateUpdater", param.debug_verbose);
   }
+
   void Update(std::vector<bst_gpair> *in_gpair, DMatrix *p_fmat,
               gbm::GBLinearModel *model, double sum_instance_weight) override {
-    // Calculate bias
-    for (int group_idx = 0; group_idx < model->param.num_output_group;
-         ++group_idx) {
-      auto grad = GetBiasGradientParallel(
-          group_idx, model->param.num_output_group, *in_gpair, p_fmat);
-      auto dbias = static_cast<float>(
-          param.learning_rate * CoordinateDeltaBias(grad.first, grad.second));
+    param.DenormalizePenalties(sum_instance_weight);
+    const int ngroup = model->param.num_output_group;
+    // update bias
+    for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
+      auto grad = GetBiasGradientParallel(group_idx, ngroup, *in_gpair, p_fmat);
+      auto dbias = static_cast<float>(param.learning_rate *
+                                      CoordinateDeltaBias(grad.first, grad.second));
       model->bias()[group_idx] += dbias;
-      UpdateBiasResidualParallel(group_idx, model->param.num_output_group,
-                                 dbias, in_gpair, p_fmat);
+      UpdateBiasResidualParallel(group_idx, ngroup, dbias, in_gpair, p_fmat);
     }
-    for (int group_idx = 0; group_idx < model->param.num_output_group;
-         ++group_idx) {
-      for (auto i = 0U; i < model->param.num_feature; i++) {
-        int fidx = selector->SelectNextFeature(
-            i, *model, group_idx, *in_gpair, p_fmat, param.reg_alpha,
-            param.reg_lambda, sum_instance_weight);
-        this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model,
-                            sum_instance_weight);
+    // prepare for updating the weights
+    selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm,
+                    param.reg_lambda_denorm, param.top_k);
+    // update weights
+    for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
+      for (unsigned i = 0U; i < model->param.num_feature; i++) {
+        int fidx = selector->NextFeature(i, *model, group_idx, *in_gpair, p_fmat,
+                                         param.reg_alpha_denorm, param.reg_lambda_denorm);
+        if (fidx < 0) break;
+        this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model);
+      }
+    }
+  }
 
-  void UpdateFeature(int fidx, int group_idx, std::vector<bst_gpair> *in_gpair,
-                     DMatrix *p_fmat, gbm::GBLinearModel *model,
-                     double sum_instance_weight) {
+  inline void UpdateFeature(int fidx, int group_idx, std::vector<bst_gpair> *in_gpair,
+                            DMatrix *p_fmat, gbm::GBLinearModel *model) {
+    const int ngroup = model->param.num_output_group;
     bst_float &w = (*model)[fidx][group_idx];
     monitor.Start("GetGradientParallel");
-    auto gradient = GetGradientParallel(
-        group_idx, model->param.num_output_group, fidx, *in_gpair, p_fmat);
+    auto gradient = GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
     monitor.Stop("GetGradientParallel");
     auto dw = static_cast<bst_float>(
         param.learning_rate *
-        CoordinateDelta(gradient.first, gradient.second, w, param.reg_lambda,
-                        param.reg_alpha, sum_instance_weight));
+        CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm,
+                        param.reg_lambda_denorm));
     w += dw;
     monitor.Start("UpdateResidualParallel");
-    UpdateResidualParallel(fidx, group_idx, model->param.num_output_group, dw,
-                           in_gpair, p_fmat);
+    UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
     monitor.Stop("UpdateResidualParallel");
   }
diff --git a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc
index 02d740031..a15f22bba 100644
--- a/src/linear/updater_shotgun.cc
+++ b/src/linear/updater_shotgun.cc
@@ -19,11 +19,12 @@ struct ShotgunTrainParam : public dmlc::Parameter<ShotgunTrainParam> {
   float reg_lambda;
   /*! \brief regularization weight for L1 norm */
   float reg_alpha;
+  int feature_selector;
   // declare parameters
   DMLC_DECLARE_PARAMETER(ShotgunTrainParam) {
     DMLC_DECLARE_FIELD(learning_rate)
        .set_lower_bound(0.0f)
-        .set_default(1.0f)
+        .set_default(0.5f)
        .describe("Learning rate of each update.");
     DMLC_DECLARE_FIELD(reg_lambda)
        .set_lower_bound(0.0f)
@@ -33,75 +34,79 @@ struct ShotgunTrainParam : public dmlc::Parameter<ShotgunTrainParam> {
        .set_lower_bound(0.0f)
        .set_default(0.0f)
        .describe("L1 regularization on weights.");
+    DMLC_DECLARE_FIELD(feature_selector)
+        .set_default(kCyclic)
+        .add_enum("cyclic", kCyclic)
+        .add_enum("shuffle", kShuffle)
+        .describe("Feature selection or ordering method.");
     // alias of parameters
     DMLC_DECLARE_ALIAS(learning_rate, eta);
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
     DMLC_DECLARE_ALIAS(reg_alpha, alpha);
   }
+  /*! \brief Denormalizes the regularization penalties - to be called at each update */
+  void DenormalizePenalties(double sum_instance_weight) {
+    reg_lambda_denorm = reg_lambda * sum_instance_weight;
+    reg_alpha_denorm = reg_alpha * sum_instance_weight;
+  }
+  // denormalized regularization penalties
+  float reg_lambda_denorm;
+  float reg_alpha_denorm;
 };
 
 class ShotgunUpdater : public LinearUpdater {
 public:
   // set training parameter
-  void Init(
-      const std::vector<std::pair<std::string, std::string>> &args) override {
+  void Init(const std::vector<std::pair<std::string, std::string>> &args) override {
     param.InitAllowUnknown(args);
+    selector.reset(FeatureSelector::Create(param.feature_selector));
   }
+
   void Update(std::vector<bst_gpair> *in_gpair, DMatrix *p_fmat,
               gbm::GBLinearModel *model, double sum_instance_weight) override {
+    param.DenormalizePenalties(sum_instance_weight);
     std::vector<bst_gpair> &gpair = *in_gpair;
     const int ngroup = model->param.num_output_group;
-    const RowSet &rowset = p_fmat->buffered_rowset();
-    // for all the output group
+
+    // update bias
     for (int gid = 0; gid < ngroup; ++gid) {
-      double sum_grad = 0.0, sum_hess = 0.0;
-      const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
-#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
-      for (bst_omp_uint i = 0; i < ndata; ++i) {
-        bst_gpair &p = gpair[rowset[i] * ngroup + gid];
-        if (p.GetHess() >= 0.0f) {
-          sum_grad += p.GetGrad();
-          sum_hess += p.GetHess();
-        }
-      }
-      // remove bias effect
-      bst_float dw = static_cast<bst_float>(
-          param.learning_rate * CoordinateDeltaBias(sum_grad, sum_hess));
-      model->bias()[gid] += dw;
-// update grad value
-#pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < ndata; ++i) {
-        bst_gpair &p = gpair[rowset[i] * ngroup + gid];
-        if (p.GetHess() >= 0.0f) {
-          p += bst_gpair(p.GetHess() * dw, 0);
-        }
-      }
+      auto grad = GetBiasGradientParallel(gid, ngroup, *in_gpair, p_fmat);
+      auto dbias = static_cast<bst_float>(param.learning_rate *
+                                          CoordinateDeltaBias(grad.first, grad.second));
+      model->bias()[gid] += dbias;
+      UpdateBiasResidualParallel(gid, ngroup, dbias, in_gpair, p_fmat);
     }
+
+    // lock-free parallel updates of weights
+    selector->Setup(*model, *in_gpair, p_fmat,
+                    param.reg_alpha_denorm, param.reg_lambda_denorm, 0);
     dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
     while (iter->Next()) {
-      // number of features
       const ColBatch &batch = iter->Value();
       const bst_omp_uint nfeat = static_cast<bst_omp_uint>(batch.size);
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nfeat; ++i) {
-        const bst_uint fid = batch.col_index[i];
-        ColBatch::Inst col = batch[i];
+        int ii = selector->NextFeature(i, *model, 0, *in_gpair, p_fmat,
+                                       param.reg_alpha_denorm, param.reg_lambda_denorm);
+        if (ii < 0) continue;
+        const bst_uint fid = batch.col_index[ii];
+        ColBatch::Inst col = batch[ii];
         for (int gid = 0; gid < ngroup; ++gid) {
           double sum_grad = 0.0, sum_hess = 0.0;
           for (bst_uint j = 0; j < col.length; ++j) {
-            const bst_float v = col[j].fvalue;
             bst_gpair &p = gpair[col[j].index * ngroup + gid];
             if (p.GetHess() < 0.0f) continue;
+            const bst_float v = col[j].fvalue;
             sum_grad += p.GetGrad() * v;
             sum_hess += p.GetHess() * v * v;
           }
           bst_float &w = (*model)[fid][gid];
           bst_float dw = static_cast<bst_float>(
               param.learning_rate *
-              CoordinateDelta(sum_grad, sum_hess, w, param.reg_lambda,
-                              param.reg_alpha, sum_instance_weight));
+              CoordinateDelta(sum_grad, sum_hess, w, param.reg_alpha_denorm,
+                              param.reg_lambda_denorm));
+          if (dw == 0.f) continue;
           w += dw;
-          // update grad value
+          // update grad values
           for (bst_uint j = 0; j < col.length; ++j) {
             bst_gpair &p = gpair[col[j].index * ngroup + gid];
             if (p.GetHess() < 0.0f) continue;
@@ -112,8 +117,11 @@ class ShotgunUpdater : public LinearUpdater {
     }
   }
 
-  // training parameter
+ protected:
+  // training parameters
   ShotgunTrainParam param;
+
+  std::unique_ptr<FeatureSelector> selector;
 };
 
 DMLC_REGISTER_PARAMETER(ShotgunTrainParam);
diff --git a/tests/cpp/linear/test_linear.cc b/tests/cpp/linear/test_linear.cc
index a5f7756c0..92ad8095c 100644
--- a/tests/cpp/linear/test_linear.cc
+++ b/tests/cpp/linear/test_linear.cc
@@ -12,7 +12,7 @@ TEST(Linear, shotgun) {
   mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
   auto updater = std::unique_ptr<xgboost::LinearUpdater>(
       xgboost::LinearUpdater::Create("shotgun"));
-  updater->Init({});
+  updater->Init({{"eta", "1."}});
   std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
                                         xgboost::bst_gpair(-5, 1.0));
   xgboost::gbm::GBLinearModel model;
diff --git a/tests/python/test_linear.py b/tests/python/test_linear.py
index fd85441e4..26e91ec93 100644
--- a/tests/python/test_linear.py
+++ b/tests/python/test_linear.py
@@ -3,9 +3,17 @@ from __future__ import print_function
 import itertools as it
 import numpy as np
 import sys
+import os
+import glob
 import testing as tm
 import unittest
 import xgboost as xgb
+try:
+    from sklearn import metrics, datasets
+    from sklearn.linear_model import ElasticNet
+    from sklearn.preprocessing import scale
+except ImportError:
+    pass
 
 rng = np.random.RandomState(199)
 
@@ -21,39 +29,35 @@ def is_float(s):
 
 
 def xgb_get_weights(bst):
-    return [float(s) for s in bst.get_dump()[0].split() if is_float(s)]
+    return np.array([float(s) for s in bst.get_dump()[0].split() if is_float(s)])
 
 
-# Check gradient/subgradient = 0
-def check_least_squares_solution(X, y, pred, tol, reg_alpha, reg_lambda, weights):
-    reg_alpha = reg_alpha * len(y)
-    reg_lambda = reg_lambda * len(y)
-    r = np.subtract(y, pred)
-    g = X.T.dot(r)
-    g = np.subtract(g, np.multiply(reg_lambda, weights))
-    for i in range(0, len(weights)):
-        if weights[i] == 0.0:
-            assert abs(g[i]) <= reg_alpha
-        else:
-            assert np.isclose(g[i], np.sign(weights[i]) * reg_alpha, rtol=tol, atol=tol)
+def check_ElasticNet(X, y, pred, tol, reg_alpha, reg_lambda, weights):
+    enet = ElasticNet(alpha=reg_alpha + reg_lambda,
+                      l1_ratio=reg_alpha / (reg_alpha + reg_lambda))
+    enet.fit(X, y)
+    enet_pred = enet.predict(X)
+    assert np.isclose(weights, enet.coef_, rtol=tol, atol=tol).all()
+    assert np.isclose(enet_pred, pred, rtol=tol, atol=tol).all()
 
 
 def train_diabetes(param_in):
-    from sklearn import datasets
     data = datasets.load_diabetes()
-    dtrain = xgb.DMatrix(data.data, label=data.target)
+    X = scale(data.data)
+    dtrain = xgb.DMatrix(X, label=data.target)
     param = {}
     param.update(param_in)
     bst = xgb.train(param, dtrain, num_rounds)
     xgb_pred = bst.predict(dtrain)
-    check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
-                                 xgb_get_weights(bst)[1:])
+    check_ElasticNet(X, data.target, xgb_pred, 1e-2,
+                     param['alpha'], param['lambda'],
+                     xgb_get_weights(bst)[1:])
 
 
 def train_breast_cancer(param_in):
-    from sklearn import metrics, datasets
     data = datasets.load_breast_cancer()
-    dtrain = xgb.DMatrix(data.data, label=data.target)
+    X = scale(data.data)
+    dtrain = xgb.DMatrix(X, label=data.target)
     param = {'objective': 'binary:logistic'}
     param.update(param_in)
     bst = xgb.train(param, dtrain, num_rounds)
@@ -63,9 +67,8 @@ def train_breast_cancer(param_in):
 
 
 def train_classification(param_in):
-    from sklearn import metrics, datasets
-    X, y = datasets.make_classification(random_state=rng,
-                                        scale=100)  # Scale is necessary otherwise regularisation parameters will force all coefficients to 0
+    X, y = datasets.make_classification(random_state=rng)
+    X = scale(X)
     dtrain = xgb.DMatrix(X, label=y)
     param = {'objective': 'binary:logistic'}
     param.update(param_in)
@@ -76,10 +79,11 @@ def train_classification(param_in):
 
 def train_classification_multi(param_in):
-    from sklearn import metrics, datasets
     num_class = 3
-    X, y = datasets.make_classification(n_samples=10, random_state=rng, scale=100, n_classes=num_class, n_informative=4,
+    X, y = datasets.make_classification(n_samples=100, random_state=rng,
+                                        n_classes=num_class, n_informative=4,
                                         n_features=4, n_redundant=0)
+    X = scale(X)
     dtrain = xgb.DMatrix(X, label=y)
     param = {'objective': 'multi:softmax', 'num_class': num_class}
     param.update(param_in)
@@ -90,20 +94,42 @@ def train_classification_multi(param_in):
 
 def train_boston(param_in):
-    from sklearn import datasets
     data = datasets.load_boston()
-    dtrain = xgb.DMatrix(data.data, label=data.target)
+    X = scale(data.data)
+    dtrain = xgb.DMatrix(X, label=data.target)
     param = {}
     param.update(param_in)
     bst = xgb.train(param, dtrain, num_rounds)
     xgb_pred = bst.predict(dtrain)
-    check_least_squares_solution(data.data, data.target, xgb_pred, 1e-2, param['alpha'], param['lambda'],
-                                 xgb_get_weights(bst)[1:])
+    check_ElasticNet(X, data.target, xgb_pred, 1e-2,
+                     param['alpha'], param['lambda'],
+                     xgb_get_weights(bst)[1:])
+
+
+def train_external_mem(param_in):
+    data = datasets.load_boston()
+    X = scale(data.data)
+    y = data.target
+    param = {}
+    param.update(param_in)
+    dtrain = xgb.DMatrix(X, label=y)
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred = bst.predict(dtrain)
+    np.savetxt('tmptmp_1234.csv', np.hstack((y.reshape(len(y), 1), X)),
+               delimiter=',', fmt='%10.9f')
+    dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
+    bst = xgb.train(param, dtrain, num_rounds)
+    xgb_pred_ext = bst.predict(dtrain)
+    assert np.abs(xgb_pred_ext - xgb_pred).max() < 1e-3
+    del dtrain, bst
+    for f in glob.glob("tmptmp_*"):
+        os.remove(f)
 
 
 # Enumerates all permutations of variable parameters
 def assert_updater_accuracy(linear_updater, variable_param):
-    param = {'booster': 'gblinear', 'updater': linear_updater, 'tolerance': 1e-8}
+    param = {'booster': 'gblinear', 'updater': linear_updater, 'eta': 1.,
+             'top_k': 10, 'tolerance': 1e-5, 'nthread': 2}
     names = sorted(variable_param)
     combinations = it.product(*(variable_param[Name] for Name in names))
 
@@ -118,16 +144,17 @@ def assert_updater_accuracy(linear_updater, variable_param):
         train_classification(param_tmp)
         train_classification_multi(param_tmp)
         train_breast_cancer(param_tmp)
+        train_external_mem(param_tmp)
 
 
 class TestLinear(unittest.TestCase):
     def test_coordinate(self):
         tm._skip_if_no_sklearn()
-        variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0],
-                          'coordinate_selection': ['cyclic', 'random', 'greedy']}
+        variable_param = {'alpha': [.005, .1], 'lambda': [.005],
+                          'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']}
         assert_updater_accuracy('coord_descent', variable_param)
 
     def test_shotgun(self):
         tm._skip_if_no_sklearn()
-        variable_param = {'alpha': [1.0, 5.0], 'lambda': [1.0, 5.0]}
+        variable_param = {'alpha': [.005, .1], 'lambda': [.005, .1]}
         assert_updater_accuracy('shotgun', variable_param)
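Putting the R-side pieces of this patch together, the intended end-to-end usage looks roughly like the sketch below. It assumes a build that includes these changes (the agaricus data ships with the package); everything it calls is introduced in the diff above.

library(xgboost)
data(agaricus.train, package = 'xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
param <- list(booster = "gblinear", objective = "binary:logistic",
              eta = 0.8, alpha = 0.0001, lambda = 0.0001, nthread = 2)
# 'thrifty' with top_k = 1 approximates componentwise boosting; a sparse
# history keeps only the few non-zero coefficients of each iteration.
bst <- xgb.train(param, dtrain, nrounds = 10,
                 updater = 'coord_descent',
                 feature_selector = 'thrifty', top_k = 1,
                 callbacks = list(cb.gblinear.history(sparse = TRUE)))
coef_path <- xgb.gblinear.history(bst)  # rows: iterations; columns: coefficients

The same feature_selector strings and the top_k parameter map onto the FeatureSelectorEnum values registered in coordinate_common.h, so the R and Python tests above exercise exactly this code path.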