fix with new predict

This commit is contained in:
hetong007 2015-02-09 16:25:00 -08:00
parent f7c838ffaa
commit 0aef62dabc
4 changed files with 24 additions and 21 deletions

View File

@ -0,0 +1,16 @@
setClass("xgb.Booster.handle")
setMethod("predict", signature = "xgb.Booster.handle",
definition = function(object, ...) {
if (class(object) != "xgb.Booster.handle"){
stop("predict: model in prediction must be of class xgb.Booster.handle")
}
bst <- list(handle = object,raw = NULL)
class(bst) <- 'xgb.Booster'
bst$raw <- xgb.save.raw(bst$handle)
ret = predict(bst, ...)
return(ret)
})

View File

@ -102,42 +102,28 @@ xgb.numrow <- function(dmat) {
} }
# iteratively update booster with customized statistics # iteratively update booster with customized statistics
xgb.iter.boost <- function(booster, dtrain, gpair) { xgb.iter.boost <- function(booster, dtrain, gpair) {
if (class(booster) != "xgb.Booster") { if (class(booster) != "xgb.Booster.handle") {
stop("xgb.iter.update: first argument must be type xgb.Booster") stop("xgb.iter.update: first argument must be type xgb.Booster.handle")
} else {
if (is.null(booster$handle)) {
booster$handle <- xgb.load(booster$raw)
} else {
if (is.null(booster$raw))
booster$raw <- xgb.save.raw(booster$handle)
}
} }
if (class(dtrain) != "xgb.DMatrix") { if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix") stop("xgb.iter.update: second argument must be type xgb.DMatrix")
} }
.Call("XGBoosterBoostOneIter_R", booster$handle, dtrain, gpair$grad, gpair$hess, .Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess,
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
# iteratively update booster with dtrain # iteratively update booster with dtrain
xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
if (class(booster) != "xgb.Booster") { if (class(booster) != "xgb.Booster.handle") {
stop("xgb.iter.update: first argument must be type xgb.Booster") stop("xgb.iter.update: first argument must be type xgb.Booster.handle")
} else {
if (is.null(booster$handle)) {
booster$handle <- xgb.load(booster$raw)
} else {
if (is.null(booster$raw))
booster$raw <- xgb.save.raw(booster$handle)
}
} }
if (class(dtrain) != "xgb.DMatrix") { if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix") stop("xgb.iter.update: second argument must be type xgb.DMatrix")
} }
if (is.null(obj)) { if (is.null(obj)) {
.Call("XGBoosterUpdateOneIter_R", booster$handle, as.integer(iter), dtrain, .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain,
PACKAGE = "xgboost") PACKAGE = "xgboost")
} else { } else {
pred <- predict(booster, dtrain) pred <- predict(booster, dtrain)

View File

@ -19,6 +19,7 @@
xgb.load <- function(modelfile) { xgb.load <- function(modelfile) {
if (is.null(modelfile)) if (is.null(modelfile))
stop("xgb.load: modelfile cannot be NULL") stop("xgb.load: modelfile cannot be NULL")
bst <- list(handle = NULL,raw = NULL) bst <- list(handle = NULL,raw = NULL)
class(bst) <- 'xgb.Booster' class(bst) <- 'xgb.Booster'
bst$handle <- xgb.Booster(modelfile = modelfile) bst$handle <- xgb.Booster(modelfile = modelfile)