[R] Fix softprob reshape. (#7126)
This commit is contained in:
parent
7ee7a95b84
commit
48d5de80a2
@ -428,7 +428,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
} else if (!reshape && n_groups != 1) {
|
||||
arr <- ret
|
||||
} else if (reshape && n_groups != 1) {
|
||||
arr <- matrix(arr, ncol = 3, byrow = TRUE)
|
||||
arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
|
||||
}
|
||||
arr <- drop(arr)
|
||||
if (length(dim(arr)) == 1) {
|
||||
|
||||
@ -150,6 +150,21 @@ test_that("train and predict softprob", {
|
||||
|
||||
mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2))
|
||||
expect_equal(mpred, mpred1)
|
||||
|
||||
d <- cbind(
|
||||
x1 = rnorm(100),
|
||||
x2 = rnorm(100),
|
||||
x3 = rnorm(100)
|
||||
)
|
||||
y <- sample.int(10, 100, replace = TRUE) - 1
|
||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y))
|
||||
booster <- xgb.train(
|
||||
params = list(tree_method = "hist"), data = dtrain, nrounds = 4, num_class = 10,
|
||||
objective = "multi:softprob"
|
||||
)
|
||||
predt <- predict(booster, as.matrix(d), reshape = TRUE, strict_shape = FALSE)
|
||||
expect_equal(ncol(predt), 10)
|
||||
expect_equal(rowSums(predt), rep(1, 100), tolerance = 1e-7)
|
||||
})
|
||||
|
||||
test_that("train and predict softmax", {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user