diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index fead30413..4b2bb0d2a 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -16,6 +16,8 @@ #' only care about the relative ordering of data points within each group, #' so it doesn't make sense to assign weights to individual data points. #' @param base_margin Base margin used for boosting from existing model. +#' +#' In the case of multi-output models, one can also pass multi-dimensional base_margin. #' @param missing a float value to represents missing values in data (used only when input is a dense matrix). #' It is useful when a 0 or some other extreme value represents missing values in data. #' @param silent whether to suppress printing an informational message after loading from a file. @@ -439,9 +441,7 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) { return(TRUE) } if (name == "base_margin") { - # if (length(info)!=nrow(object)) - # stop("The length of base margin must equal to the number of rows in the input data") - .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) + .Call(XGDMatrixSetInfo_R, object, name, info) return(TRUE) } if (name == "group") { diff --git a/R-package/man/xgb.DMatrix.Rd b/R-package/man/xgb.DMatrix.Rd index 95cc8d3cd..a1ef39f0b 100644 --- a/R-package/man/xgb.DMatrix.Rd +++ b/R-package/man/xgb.DMatrix.Rd @@ -36,7 +36,9 @@ is assigned to each group (not each data point). This is because we only care about the relative ordering of data points within each group, so it doesn't make sense to assign weights to individual data points.} -\item{base_margin}{Base margin used for boosting from existing model.} +\item{base_margin}{Base margin used for boosting from existing model. + + In the case of multi-output models, one can also pass multi-dimensional base_margin.} \item{missing}{a float value to represents missing values in data (used only when input is a dense matrix). It is useful when a 0 or some other extreme value represents missing values in data.} diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 87a73d84b..55a699687 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", { m <- xgb.DMatrix(df, enable_categorical = TRUE) expect_equal(getinfo(m, "feature_type"), c("c", "c")) }) + +test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", { + set.seed(123) + x <- matrix(rnorm(100 * 10), nrow = 100) + y <- matrix(rnorm(100 * 2), nrow = 100) + b <- matrix(rnorm(100 * 2), nrow = 100) + model <- xgb.train( + data = xgb.DMatrix(data = x, label = y, nthread = n_threads), + params = list( + objective = "reg:squarederror", + tree_method = "hist", + multi_strategy = "multi_output_tree", + base_score = 0, + nthread = n_threads + ), + nround = 1 + ) + pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE) + pred_w_base <- predict( + model, + xgb.DMatrix(data = x, base_margin = b, nthread = n_threads), + nthread = n_threads, + reshape = TRUE + ) + expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5) +})