Switch missing values from 0 to NA in R package

This commit is contained in:
kferris 2015-10-07 18:51:47 -04:00
parent 3109069019
commit 7a94bdb60c
5 changed files with 8 additions and 21 deletions

View File

@@ -31,7 +31,7 @@ setClass("xgb.Booster",
#' @export #' @export
#' #'
setMethod("predict", signature = "xgb.Booster", setMethod("predict", signature = "xgb.Booster",
definition = function(object, newdata, missing = NULL, definition = function(object, newdata, missing = NA,
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
if (class(object) != "xgb.Booster"){ if (class(object) != "xgb.Booster"){
stop("predict: model in prediction must be of class xgb.Booster") stop("predict: model in prediction must be of class xgb.Booster")
@@ -39,12 +39,8 @@ setMethod("predict", signature = "xgb.Booster",
object <- xgb.Booster.check(object, saveraw = FALSE) object <- xgb.Booster.check(object, saveraw = FALSE)
} }
if (class(newdata) != "xgb.DMatrix") { if (class(newdata) != "xgb.DMatrix") {
if (is.null(missing)) {
newdata <- xgb.DMatrix(newdata)
} else {
newdata <- xgb.DMatrix(newdata, missing = missing) newdata <- xgb.DMatrix(newdata, missing = missing)
} }
}
if (is.null(ntreelimit)) { if (is.null(ntreelimit)) {
ntreelimit <- 0 ntreelimit <- 0
} else { } else {

View File

@@ -103,18 +103,13 @@ xgb.Booster.check <- function(bst, saveraw = TRUE)
## ----the following are low level iteratively function, not needed if ## ----the following are low level iteratively function, not needed if
## you do not want to use them --------------------------------------- ## you do not want to use them ---------------------------------------
# get dmatrix from data, label # get dmatrix from data, label
xgb.get.DMatrix <- function(data, label = NULL, missing = NULL, weight = NULL) { xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) {
inClass <- class(data) inClass <- class(data)
if (inClass == "dgCMatrix" || inClass == "matrix") { if (inClass == "dgCMatrix" || inClass == "matrix") {
if (is.null(label)) { if (is.null(label)) {
stop("xgboost: need label when data is a matrix") stop("xgboost: need label when data is a matrix")
} }
dtrain <- xgb.DMatrix(data, label = label)
if (is.null(missing)){
dtrain <- xgb.DMatrix(data, label = label)
} else {
dtrain <- xgb.DMatrix(data, label = label, missing = missing) dtrain <- xgb.DMatrix(data, label = label, missing = missing)
}
if (!is.null(weight)){ if (!is.null(weight)){
xgb.setinfo(dtrain, "weight", weight) xgb.setinfo(dtrain, "weight", weight)
} }

View File

@@ -18,7 +18,7 @@
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data') #' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
#' @export #' @export
#' #'
xgb.DMatrix <- function(data, info = list(), missing = 0, ...) { xgb.DMatrix <- function(data, info = list(), missing = NA, ...) {
if (typeof(data) == "character") { if (typeof(data) == "character") {
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE),
PACKAGE = "xgboost") PACKAGE = "xgboost")

View File

@@ -91,7 +91,7 @@
#' print(history) #' print(history)
#' @export #' @export
#' #'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL, xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA,
prediction = FALSE, showsd = TRUE, metrics=list(), prediction = FALSE, showsd = TRUE, metrics=list(),
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L, obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L,
early.stop.round = NULL, maximize = NULL, ...) { early.stop.round = NULL, maximize = NULL, ...) {
@@ -107,11 +107,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
if (nfold <= 1) { if (nfold <= 1) {
stop("nfold must be bigger than 1") stop("nfold must be bigger than 1")
} }
if (is.null(missing)) {
dtrain <- xgb.get.DMatrix(data, label)
} else {
dtrain <- xgb.get.DMatrix(data, label, missing) dtrain <- xgb.get.DMatrix(data, label, missing)
}
dot.params = list(...) dot.params = list(...)
nms.params = names(params) nms.params = names(params)
nms.dot.params = names(dot.params) nms.dot.params = names(dot.params)

View File

@@ -59,7 +59,7 @@
#' #'
#' @export #' @export
#' #'
xgboost <- function(data = NULL, label = NULL, missing = NULL, weight = NULL, xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
params = list(), nrounds, params = list(), nrounds,
verbose = 1, print.every.n = 1L, early.stop.round = NULL, verbose = 1, print.every.n = 1L, early.stop.round = NULL,
maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) { maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) {