Switch missing values from 0 to NA in R package
This commit is contained in:
parent
3109069019
commit
7a94bdb60c
@ -31,7 +31,7 @@ setClass("xgb.Booster",
|
||||
#' @export
|
||||
#'
|
||||
setMethod("predict", signature = "xgb.Booster",
|
||||
definition = function(object, newdata, missing = NULL,
|
||||
definition = function(object, newdata, missing = NA,
|
||||
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
|
||||
if (class(object) != "xgb.Booster"){
|
||||
stop("predict: model in prediction must be of class xgb.Booster")
|
||||
@ -39,11 +39,7 @@ setMethod("predict", signature = "xgb.Booster",
|
||||
object <- xgb.Booster.check(object, saveraw = FALSE)
|
||||
}
|
||||
if (class(newdata) != "xgb.DMatrix") {
|
||||
if (is.null(missing)) {
|
||||
newdata <- xgb.DMatrix(newdata)
|
||||
} else {
|
||||
newdata <- xgb.DMatrix(newdata, missing = missing)
|
||||
}
|
||||
newdata <- xgb.DMatrix(newdata, missing = missing)
|
||||
}
|
||||
if (is.null(ntreelimit)) {
|
||||
ntreelimit <- 0
|
||||
|
||||
@ -103,18 +103,13 @@ xgb.Booster.check <- function(bst, saveraw = TRUE)
|
||||
## ----the following are low level iteratively function, not needed if
|
||||
## you do not want to use them ---------------------------------------
|
||||
# get dmatrix from data, label
|
||||
xgb.get.DMatrix <- function(data, label = NULL, missing = NULL, weight = NULL) {
|
||||
xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) {
|
||||
inClass <- class(data)
|
||||
if (inClass == "dgCMatrix" || inClass == "matrix") {
|
||||
if (is.null(label)) {
|
||||
stop("xgboost: need label when data is a matrix")
|
||||
}
|
||||
dtrain <- xgb.DMatrix(data, label = label)
|
||||
if (is.null(missing)){
|
||||
dtrain <- xgb.DMatrix(data, label = label)
|
||||
} else {
|
||||
dtrain <- xgb.DMatrix(data, label = label, missing = missing)
|
||||
}
|
||||
dtrain <- xgb.DMatrix(data, label = label, missing = missing)
|
||||
if (!is.null(weight)){
|
||||
xgb.setinfo(dtrain, "weight", weight)
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
||||
#' @export
|
||||
#'
|
||||
xgb.DMatrix <- function(data, info = list(), missing = 0, ...) {
|
||||
xgb.DMatrix <- function(data, info = list(), missing = NA, ...) {
|
||||
if (typeof(data) == "character") {
|
||||
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE),
|
||||
PACKAGE = "xgboost")
|
||||
|
||||
@ -91,7 +91,7 @@
|
||||
#' print(history)
|
||||
#' @export
|
||||
#'
|
||||
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
|
||||
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA,
|
||||
prediction = FALSE, showsd = TRUE, metrics=list(),
|
||||
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L,
|
||||
early.stop.round = NULL, maximize = NULL, ...) {
|
||||
@ -107,11 +107,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
if (nfold <= 1) {
|
||||
stop("nfold must be bigger than 1")
|
||||
}
|
||||
if (is.null(missing)) {
|
||||
dtrain <- xgb.get.DMatrix(data, label)
|
||||
} else {
|
||||
dtrain <- xgb.get.DMatrix(data, label, missing)
|
||||
}
|
||||
dtrain <- xgb.get.DMatrix(data, label, missing)
|
||||
dot.params = list(...)
|
||||
nms.params = names(params)
|
||||
nms.dot.params = names(dot.params)
|
||||
|
||||
@ -59,7 +59,7 @@
|
||||
#'
|
||||
#' @export
|
||||
#'
|
||||
xgboost <- function(data = NULL, label = NULL, missing = NULL, weight = NULL,
|
||||
xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
|
||||
params = list(), nrounds,
|
||||
verbose = 1, print.every.n = 1L, early.stop.round = NULL,
|
||||
maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user