diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 71269419b..258c7fb16 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -95,7 +95,18 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = } folds <- xgb.cv.mknfold(dtrain, nfold, params) - predictValues <- rep(0,xgb.numrow(dtrain)) + obj_type = params[['objective']] + mat_pred = FALSE + if (!is.null(obj_type) && obj_type=='multi:softprob') + { + num_class = params[['num_class']] + if (is.null(num_class)) + stop('must set num_class to use softmax') + predictValues <- matrix(0,xgb.numrow(dtrain),num_class) + mat_pred = TRUE + } + else + predictValues <- rep(0,xgb.numrow(dtrain)) history <- c() for (i in 1:nrounds) { msg <- list() @@ -106,7 +117,10 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]] } else { res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction) - predictValues[fd$index] <- res[[2]] + if (mat_pred) + predictValues[fd$index,] <- res[[2]] + else + predictValues[fd$index] <- res[[2]] msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]] } }