[R] Fix softprob reshape. (#7126)

This commit is contained in:
Jiaming Yuan
2021-07-27 15:25:17 +08:00
committed by GitHub
parent 7ee7a95b84
commit 48d5de80a2
2 changed files with 16 additions and 1 deletions

View File

@@ -150,6 +150,21 @@ test_that("train and predict softprob", {
mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2))
expect_equal(mpred, mpred1)
d <- cbind(
x1 = rnorm(100),
x2 = rnorm(100),
x3 = rnorm(100)
)
y <- sample.int(10, 100, replace = TRUE) - 1
dtrain <- xgb.DMatrix(data = d, info = list(label = y))
booster <- xgb.train(
params = list(tree_method = "hist"), data = dtrain, nrounds = 4, num_class = 10,
objective = "multi:softprob"
)
predt <- predict(booster, as.matrix(d), reshape = TRUE, strict_shape = FALSE)
expect_equal(ncol(predt), 10)
expect_equal(rowSums(predt), rep(1, 100), tolerance = 1e-7)
})
test_that("train and predict softmax", {