enable returning prediction in cv
This commit is contained in:
@@ -131,7 +131,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
||||
}
|
||||
|
||||
# iteratively evaluate one iteration
|
||||
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
|
||||
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) {
|
||||
if (class(booster) != "xgb.Booster") {
|
||||
stop("xgb.eval: first argument must be type xgb.Booster")
|
||||
}
|
||||
@@ -169,18 +169,27 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
|
||||
} else {
|
||||
msg <- ""
|
||||
}
|
||||
if (prediction){
|
||||
preds <- predict(booster,watchlist[[2]])
|
||||
return(list(msg,preds))
|
||||
}
|
||||
return(msg)
|
||||
}
|
||||
}
|
||||
#------------------------------------------
|
||||
# helper functions for cross validation
|
||||
#
|
||||
xgb.cv.mknfold <- function(dall, nfold, param) {
|
||||
randidx <- sample(1 : xgb.numrow(dall))
|
||||
kstep <- length(randidx) / nfold
|
||||
idset <- list()
|
||||
for (i in 1:nfold) {
|
||||
idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]
|
||||
if (nfold <= 1) {
|
||||
stop("nfold must be bigger than 1")
|
||||
}
|
||||
randidx <- sample(1 : xgb.numrow(dall))
|
||||
kstep <- length(randidx) %/% nfold
|
||||
idset <- list()
|
||||
for (i in 1:(nfold-1)) {
|
||||
idset[[i]] = randidx[1:kstep]
|
||||
randidx = setdiff(randidx,idset[[i]])
|
||||
}
|
||||
idset[[nfold]] = randidx
|
||||
ret <- list()
|
||||
for (k in 1:nfold) {
|
||||
dtest <- slice(dall, idset[[k]])
|
||||
@@ -193,7 +202,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) {
|
||||
dtrain <- slice(dall, didx)
|
||||
bst <- xgb.Booster(param, list(dtrain, dtest))
|
||||
watchlist = list(train=dtrain, test=dtest)
|
||||
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist)
|
||||
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=idset[[k]])
|
||||
}
|
||||
return (ret)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user