fix early stopping and prediction
This commit is contained in:
parent
6b254ec495
commit
704d9e0a13
@ -186,22 +186,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
for (k in 1:nfold) {
|
for (k in 1:nfold) {
|
||||||
fd <- xgb_folds[[k]]
|
fd <- xgb_folds[[k]]
|
||||||
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
|
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
|
||||||
if (i<nrounds) {
|
|
||||||
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
|
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
|
||||||
} else {
|
|
||||||
if (!prediction) {
|
|
||||||
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
|
|
||||||
} else {
|
|
||||||
res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
|
|
||||||
if (mat_pred) {
|
|
||||||
pred_mat = matrix(res[[2]],num_class,length(fd$index))
|
|
||||||
predictValues[fd$index,] <- t(pred_mat)
|
|
||||||
} else {
|
|
||||||
predictValues[fd$index] <- res[[2]]
|
|
||||||
}
|
|
||||||
msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ret <- xgb.cv.aggcv(msg, showsd)
|
ret <- xgb.cv.aggcv(msg, showsd)
|
||||||
history <- c(history, ret)
|
history <- c(history, ret)
|
||||||
@ -228,6 +213,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (prediction) {
|
||||||
|
for (k in 1:nfold) {
|
||||||
|
fd = xgb_folds[[k]]
|
||||||
|
res = xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
|
||||||
|
if (mat_pred) {
|
||||||
|
pred_mat = matrix(res[[2]],num_class,length(fd$index))
|
||||||
|
predictValues[fd$index,] = t(pred_mat)
|
||||||
|
} else {
|
||||||
|
predictValues[fd$index] = res[[2]]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
|
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
|
||||||
colnamesMean <- paste(colnames, "mean")
|
colnamesMean <- paste(colnames, "mean")
|
||||||
if(showsd) colnamesStd <- paste(colnames, "std")
|
if(showsd) colnamesStd <- paste(colnames, "std")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user