Merge pull request #106 from tqchen/master

pull master into unity
This commit is contained in:
Tianqi Chen
2014-11-21 08:56:01 -08:00
11 changed files with 106 additions and 23 deletions

View File

@@ -25,9 +25,13 @@ setClass("xgb.Booster")
#' @export
#'
setMethod("predict", signature = "xgb.Booster",
definition = function(object, newdata, outputmargin = FALSE, ntreelimit = NULL) {
definition = function(object, newdata, missing = NULL, outputmargin = FALSE, ntreelimit = NULL) {
if (class(newdata) != "xgb.DMatrix") {
newdata <- xgb.DMatrix(newdata)
if (is.null(missing)) {
newdata <- xgb.DMatrix(newdata)
} else {
newdata <- xgb.DMatrix(newdata, missing = missing)
}
}
if (is.null(ntreelimit)) {
ntreelimit <- 0

View File

@@ -68,13 +68,17 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
## ----the following are low level iteratively function, not needed if
## you do not want to use them ---------------------------------------
# get dmatrix from data, label
xgb.get.DMatrix <- function(data, label = NULL) {
xgb.get.DMatrix <- function(data, label = NULL, missing = NULL) {
inClass <- class(data)
if (inClass == "dgCMatrix" || inClass == "matrix") {
if (is.null(label)) {
stop("xgboost: need label when data is a matrix")
}
dtrain <- xgb.DMatrix(data, label = label)
if (is.null(missing)){
dtrain <- xgb.DMatrix(data, label = label)
} else {
dtrain <- xgb.DMatrix(data, label = label, missing = missing)
}
} else {
if (!is.null(label)) {
warning("xgboost: label will be ignored.")

View File

@@ -53,7 +53,7 @@
#' "max.depth"=3, "eta"=1, "objective"="binary:logistic")
#' @export
#'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL,
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, ...) {
if (typeof(params) != "list") {
stop("xgb.cv: first argument params must be list")
@@ -61,7 +61,11 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL,
if (nfold <= 1) {
stop("nfold must be bigger than 1")
}
dtrain <- xgb.get.DMatrix(data, label)
if (is.null(missing)) {
dtrain <- xgb.get.DMatrix(data, label)
} else {
dtrain <- xgb.get.DMatrix(data, label, missing)
}
params <- append(params, list(...))
params <- append(params, list(silent=1))
for (mc in metrics) {

View File

@@ -43,9 +43,14 @@
#'
#' @export
#'
xgboost <- function(data = NULL, label = NULL, params = list(), nrounds,
xgboost <- function(data = NULL, label = NULL, missing = NULL, params = list(), nrounds,
verbose = 1, ...) {
dtrain <- xgb.get.DMatrix(data, label)
if (is.null(missing)) {
dtrain <- xgb.get.DMatrix(data, label)
} else {
dtrain <- xgb.get.DMatrix(data, label, missing)
}
params <- append(params, list(...))
if (verbose > 0) {

View File

@@ -37,3 +37,26 @@ print ('start training with user customized objective')
# training with customized objective, we can also do step by step training
# simply look at xgboost.py's implementation of train
bst <- xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
#
# there can be cases where you want additional information
# being considered besides the property of DMatrix you can get by getinfo
# you can set additional information as attributes if DMatrix
# set label attribute of dtrain to be label, we use label as an example, it can be anything
attr(dtrain, 'label') <- getinfo(dtrain, 'label')
# this is new customized objective, where you can access things you set
# same thing applies to customized evaluation function
logregobjattr <- function(preds, dtrain) {
# now you can access the attribute in customized function
labels <- attr(dtrain, 'label')
preds <- 1/(1 + exp(-preds))
grad <- preds - labels
hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess))
}
print ('start training with user customized objective, with additional attributes in DMatrix')
# training with customized objective, we can also do step by step training
# simply look at xgboost.py's implementation of train
bst <- xgb.train(param, dtrain, num_round, watchlist, logregobjattr, evalerror)