[R] Fix softprob reshape. (#7126)

This commit is contained in:
Jiaming Yuan 2021-07-27 15:25:17 +08:00 committed by GitHub
parent 7ee7a95b84
commit 48d5de80a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -428,7 +428,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
} else if (!reshape && n_groups != 1) { } else if (!reshape && n_groups != 1) {
arr <- ret arr <- ret
} else if (reshape && n_groups != 1) { } else if (reshape && n_groups != 1) {
arr <- matrix(arr, ncol = 3, byrow = TRUE) arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
} }
arr <- drop(arr) arr <- drop(arr)
if (length(dim(arr)) == 1) { if (length(dim(arr)) == 1) {

View File

@ -150,6 +150,21 @@ test_that("train and predict softprob", {
mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2)) mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2))
expect_equal(mpred, mpred1) expect_equal(mpred, mpred1)
d <- cbind(
x1 = rnorm(100),
x2 = rnorm(100),
x3 = rnorm(100)
)
y <- sample.int(10, 100, replace = TRUE) - 1
dtrain <- xgb.DMatrix(data = d, info = list(label = y))
booster <- xgb.train(
params = list(tree_method = "hist"), data = dtrain, nrounds = 4, num_class = 10,
objective = "multi:softprob"
)
predt <- predict(booster, as.matrix(d), reshape = TRUE, strict_shape = FALSE)
expect_equal(ncol(predt), 10)
expect_equal(rowSums(predt), rep(1, 100), tolerance = 1e-7)
}) })
test_that("train and predict softmax", { test_that("train and predict softmax", {