enable returning prediction in cv
This commit is contained in:
parent
89d5e67b78
commit
947f0a926d
@ -39,7 +39,7 @@ setMethod("getinfo", signature = "xgb.DMatrix",
|
|||||||
if (name != "nrow"){
|
if (name != "nrow"){
|
||||||
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
|
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
|
||||||
} else {
|
} else {
|
||||||
ret <- .Call("XGDMatrixNumRow_R", object, PACKAGE = "xgboost")
|
ret <- xgb.numrow(object)
|
||||||
}
|
}
|
||||||
return(ret)
|
return(ret)
|
||||||
})
|
})
|
||||||
|
|||||||
@ -131,7 +131,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# iteratively evaluate one iteration
|
# iteratively evaluate one iteration
|
||||||
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
|
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) {
|
||||||
if (class(booster) != "xgb.Booster") {
|
if (class(booster) != "xgb.Booster") {
|
||||||
stop("xgb.eval: first argument must be type xgb.Booster")
|
stop("xgb.eval: first argument must be type xgb.Booster")
|
||||||
}
|
}
|
||||||
@ -169,18 +169,27 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
|
|||||||
} else {
|
} else {
|
||||||
msg <- ""
|
msg <- ""
|
||||||
}
|
}
|
||||||
|
if (prediction){
|
||||||
|
preds <- predict(booster,watchlist[[2]])
|
||||||
|
return(list(msg,preds))
|
||||||
|
}
|
||||||
return(msg)
|
return(msg)
|
||||||
}
|
}
|
||||||
#------------------------------------------
|
#------------------------------------------
|
||||||
# helper functions for cross validation
|
# helper functions for cross validation
|
||||||
#
|
#
|
||||||
xgb.cv.mknfold <- function(dall, nfold, param) {
|
xgb.cv.mknfold <- function(dall, nfold, param) {
|
||||||
randidx <- sample(1 : xgb.numrow(dall))
|
if (nfold <= 1) {
|
||||||
kstep <- length(randidx) / nfold
|
stop("nfold must be bigger than 1")
|
||||||
idset <- list()
|
|
||||||
for (i in 1:nfold) {
|
|
||||||
idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]
|
|
||||||
}
|
}
|
||||||
|
randidx <- sample(1 : xgb.numrow(dall))
|
||||||
|
kstep <- length(randidx) %/% nfold
|
||||||
|
idset <- list()
|
||||||
|
for (i in 1:(nfold-1)) {
|
||||||
|
idset[[i]] = randidx[1:kstep]
|
||||||
|
randidx = setdiff(randidx,idset[[i]])
|
||||||
|
}
|
||||||
|
idset[[nfold]] = randidx
|
||||||
ret <- list()
|
ret <- list()
|
||||||
for (k in 1:nfold) {
|
for (k in 1:nfold) {
|
||||||
dtest <- slice(dall, idset[[k]])
|
dtest <- slice(dall, idset[[k]])
|
||||||
@ -193,7 +202,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) {
|
|||||||
dtrain <- slice(dall, didx)
|
dtrain <- slice(dall, didx)
|
||||||
bst <- xgb.Booster(param, list(dtrain, dtest))
|
bst <- xgb.Booster(param, list(dtrain, dtest))
|
||||||
watchlist = list(train=dtrain, test=dtest)
|
watchlist = list(train=dtrain, test=dtest)
|
||||||
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist)
|
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=idset[[k]])
|
||||||
}
|
}
|
||||||
return (ret)
|
return (ret)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,6 +31,9 @@
|
|||||||
#' @param nrounds the max number of iterations
|
#' @param nrounds the max number of iterations
|
||||||
#' @param nfold number of folds used
|
#' @param nfold number of folds used
|
||||||
#' @param label option field, when data is Matrix
|
#' @param label option field, when data is Matrix
|
||||||
|
#' @param missing Missing is only used when input is dense matrix, pick a float
|
||||||
|
# value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values.
|
||||||
|
#' @param prediction A logical value indicating whether to return the prediction vector.
|
||||||
#' @param showsd \code{boolean}, whether show standard deviation of cross validation
|
#' @param showsd \code{boolean}, whether show standard deviation of cross validation
|
||||||
#' @param metrics, list of evaluation metrics to be used in corss validation,
|
#' @param metrics, list of evaluation metrics to be used in corss validation,
|
||||||
#' when it is not specified, the evaluation metric is chosen according to objective function.
|
#' when it is not specified, the evaluation metric is chosen according to objective function.
|
||||||
@ -71,7 +74,8 @@
|
|||||||
#' @export
|
#' @export
|
||||||
#'
|
#'
|
||||||
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
|
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
|
||||||
showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, verbose = T,...) {
|
prediction = FALSE, showsd = TRUE, metrics=list(),
|
||||||
|
obj = NULL, feval = NULL, verbose = T,...) {
|
||||||
if (typeof(params) != "list") {
|
if (typeof(params) != "list") {
|
||||||
stop("xgb.cv: first argument params must be list")
|
stop("xgb.cv: first argument params must be list")
|
||||||
}
|
}
|
||||||
@ -90,13 +94,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
}
|
}
|
||||||
|
|
||||||
folds <- xgb.cv.mknfold(dtrain, nfold, params)
|
folds <- xgb.cv.mknfold(dtrain, nfold, params)
|
||||||
|
predictValues <- rep(0,xgb.numrow(dtrain))
|
||||||
history <- c()
|
history <- c()
|
||||||
for (i in 1:nrounds) {
|
for (i in 1:nrounds) {
|
||||||
msg <- list()
|
msg <- list()
|
||||||
for (k in 1:nfold) {
|
for (k in 1:nfold) {
|
||||||
fd <- folds[[k]]
|
fd <- folds[[k]]
|
||||||
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
|
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
|
||||||
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
|
if (!prediction){
|
||||||
|
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
|
||||||
|
} else {
|
||||||
|
res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
|
||||||
|
predictValues[fd$index] <- res[[2]]
|
||||||
|
msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ret <- xgb.cv.aggcv(msg, showsd)
|
ret <- xgb.cv.aggcv(msg, showsd)
|
||||||
history <- c(history, ret)
|
history <- c(history, ret)
|
||||||
@ -115,5 +126,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
split <- str_split(string = history, pattern = "\t")
|
split <- str_split(string = history, pattern = "\t")
|
||||||
|
|
||||||
for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.list %>% {vec <- .; rbindlist(list(dt, vec), use.names = F, fill = F)}
|
for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.list %>% {vec <- .; rbindlist(list(dt, vec), use.names = F, fill = F)}
|
||||||
dt
|
|
||||||
|
if (prediction) {
|
||||||
|
return(list(dt,predictValues))
|
||||||
|
}
|
||||||
|
return(dt)
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user