fix early stopping and prediction

This commit is contained in:
Tong He 2015-06-21 19:46:31 -07:00
parent 6b254ec495
commit 704d9e0a13

View File

@ -186,22 +186,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
for (k in 1:nfold) { for (k in 1:nfold) {
fd <- xgb_folds[[k]] fd <- xgb_folds[[k]]
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj) succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
if (i<nrounds) {
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]] msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
} else {
if (!prediction) {
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
} else {
res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
if (mat_pred) {
pred_mat = matrix(res[[2]],num_class,length(fd$index))
predictValues[fd$index,] <- t(pred_mat)
} else {
predictValues[fd$index] <- res[[2]]
}
msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]]
}
}
} }
ret <- xgb.cv.aggcv(msg, showsd) ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret) history <- c(history, ret)
@ -228,6 +213,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
} }
if (prediction) {
for (k in 1:nfold) {
fd = xgb_folds[[k]]
res = xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
if (mat_pred) {
pred_mat = matrix(res[[2]],num_class,length(fd$index))
predictValues[fd$index,] = t(pred_mat)
} else {
predictValues[fd$index] = res[[2]]
}
}
}
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".") colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
colnamesMean <- paste(colnames, "mean") colnamesMean <- paste(colnames, "mean")
if(showsd) colnamesStd <- paste(colnames, "std") if(showsd) colnamesStd <- paste(colnames, "std")