diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index 0c50b2504..902260258 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -31,7 +31,7 @@ setClass("xgb.Booster", #' @export #' setMethod("predict", signature = "xgb.Booster", - definition = function(object, newdata, missing = NULL, + definition = function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { if (class(object) != "xgb.Booster"){ stop("predict: model in prediction must be of class xgb.Booster") @@ -39,11 +39,7 @@ setMethod("predict", signature = "xgb.Booster", object <- xgb.Booster.check(object, saveraw = FALSE) } if (class(newdata) != "xgb.DMatrix") { - if (is.null(missing)) { - newdata <- xgb.DMatrix(newdata) - } else { - newdata <- xgb.DMatrix(newdata, missing = missing) - } + newdata <- xgb.DMatrix(newdata, missing = missing) } if (is.null(ntreelimit)) { ntreelimit <- 0 diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 732ef0d11..eecc5e260 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -103,18 +103,13 @@ xgb.Booster.check <- function(bst, saveraw = TRUE) ## ----the following are low level iteratively function, not needed if ## you do not want to use them --------------------------------------- # get dmatrix from data, label -xgb.get.DMatrix <- function(data, label = NULL, missing = NULL, weight = NULL) { +xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) { inClass <- class(data) if (inClass == "dgCMatrix" || inClass == "matrix") { if (is.null(label)) { stop("xgboost: need label when data is a matrix") } - dtrain <- xgb.DMatrix(data, label = label) - if (is.null(missing)){ - dtrain <- xgb.DMatrix(data, label = label) - } else { - dtrain <- xgb.DMatrix(data, label = label, missing = missing) - } + dtrain <- xgb.DMatrix(data, label = label, missing = missing) if (!is.null(weight)){ xgb.setinfo(dtrain, "weight", weight) } diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 8c3ea80bc..970fab394 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -18,7 +18,7 @@ #' dtrain <- xgb.DMatrix('xgb.DMatrix.data') #' @export #' -xgb.DMatrix <- function(data, info = list(), missing = 0, ...) { +xgb.DMatrix <- function(data, info = list(), missing = NA, ...) { if (typeof(data) == "character") { handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), PACKAGE = "xgboost") diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index a5364db52..9811bba38 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -91,7 +91,7 @@ #' print(history) #' @export #' -xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL, +xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA, prediction = FALSE, showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L, early.stop.round = NULL, maximize = NULL, ...) { @@ -107,11 +107,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = if (nfold <= 1) { stop("nfold must be bigger than 1") } - if (is.null(missing)) { - dtrain <- xgb.get.DMatrix(data, label) - } else { - dtrain <- xgb.get.DMatrix(data, label, missing) - } + dtrain <- xgb.get.DMatrix(data, label, missing) dot.params = list(...) nms.params = names(params) nms.dot.params = names(dot.params) diff --git a/R-package/R/xgboost.R b/R-package/R/xgboost.R index 164dc1838..e11052add 100644 --- a/R-package/R/xgboost.R +++ b/R-package/R/xgboost.R @@ -59,7 +59,7 @@ #' #' @export #' -xgboost <- function(data = NULL, label = NULL, missing = NULL, weight = NULL, +xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL, params = list(), nrounds, verbose = 1, print.every.n = 1L, early.stop.round = NULL, maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) {