diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 1228a9a61..922af0eb0 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -428,7 +428,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA } else if (!reshape && n_groups != 1) { arr <- ret } else if (reshape && n_groups != 1) { - arr <- matrix(arr, ncol = 3, byrow = TRUE) + arr <- matrix(arr, ncol = n_groups, byrow = TRUE) } arr <- drop(arr) if (length(dim(arr)) == 1) { diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index a97eb01ba..e90232dae 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -150,6 +150,21 @@ test_that("train and predict softprob", { mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2)) expect_equal(mpred, mpred1) + + d <- cbind( + x1 = rnorm(100), + x2 = rnorm(100), + x3 = rnorm(100) + ) + y <- sample.int(10, 100, replace = TRUE) - 1 + dtrain <- xgb.DMatrix(data = d, info = list(label = y)) + booster <- xgb.train( + params = list(tree_method = "hist"), data = dtrain, nrounds = 4, num_class = 10, + objective = "multi:softprob" + ) + predt <- predict(booster, as.matrix(d), reshape = TRUE, strict_shape = FALSE) + expect_equal(ncol(predt), 10) + expect_equal(rowSums(predt), rep(1, 100), tolerance = 1e-7) }) test_that("train and predict softmax", {