[R] Fix softprob reshape. (#7126)
This commit is contained in:
@@ -150,6 +150,21 @@ test_that("train and predict softprob", {
|
||||
|
||||
mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2))
|
||||
expect_equal(mpred, mpred1)
|
||||
|
||||
d <- cbind(
|
||||
x1 = rnorm(100),
|
||||
x2 = rnorm(100),
|
||||
x3 = rnorm(100)
|
||||
)
|
||||
y <- sample.int(10, 100, replace = TRUE) - 1
|
||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y))
|
||||
booster <- xgb.train(
|
||||
params = list(tree_method = "hist"), data = dtrain, nrounds = 4, num_class = 10,
|
||||
objective = "multi:softprob"
|
||||
)
|
||||
predt <- predict(booster, as.matrix(d), reshape = TRUE, strict_shape = FALSE)
|
||||
expect_equal(ncol(predt), 10)
|
||||
expect_equal(rowSums(predt), rep(1, 100), tolerance = 1e-7)
|
||||
})
|
||||
|
||||
test_that("train and predict softmax", {
|
||||
|
||||
Reference in New Issue
Block a user