fix with new predict
This commit is contained in:
parent
f7c838ffaa
commit
0aef62dabc
16
R-package/R/predict.xgb.Booster.handle.R
Normal file
16
R-package/R/predict.xgb.Booster.handle.R
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
setClass("xgb.Booster.handle")
|
||||||
|
|
||||||
|
setMethod("predict", signature = "xgb.Booster.handle",
|
||||||
|
definition = function(object, ...) {
|
||||||
|
if (class(object) != "xgb.Booster.handle"){
|
||||||
|
stop("predict: model in prediction must be of class xgb.Booster.handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
bst <- list(handle = object,raw = NULL)
|
||||||
|
class(bst) <- 'xgb.Booster'
|
||||||
|
bst$raw <- xgb.save.raw(bst$handle)
|
||||||
|
|
||||||
|
ret = predict(bst, ...)
|
||||||
|
return(ret)
|
||||||
|
})
|
||||||
|
|
||||||
@ -102,42 +102,28 @@ xgb.numrow <- function(dmat) {
|
|||||||
}
|
}
|
||||||
# iteratively update booster with customized statistics
|
# iteratively update booster with customized statistics
|
||||||
xgb.iter.boost <- function(booster, dtrain, gpair) {
|
xgb.iter.boost <- function(booster, dtrain, gpair) {
|
||||||
if (class(booster) != "xgb.Booster") {
|
if (class(booster) != "xgb.Booster.handle") {
|
||||||
stop("xgb.iter.update: first argument must be type xgb.Booster")
|
stop("xgb.iter.update: first argument must be type xgb.Booster.handle")
|
||||||
} else {
|
|
||||||
if (is.null(booster$handle)) {
|
|
||||||
booster$handle <- xgb.load(booster$raw)
|
|
||||||
} else {
|
|
||||||
if (is.null(booster$raw))
|
|
||||||
booster$raw <- xgb.save.raw(booster$handle)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (class(dtrain) != "xgb.DMatrix") {
|
if (class(dtrain) != "xgb.DMatrix") {
|
||||||
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
|
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
|
||||||
}
|
}
|
||||||
.Call("XGBoosterBoostOneIter_R", booster$handle, dtrain, gpair$grad, gpair$hess,
|
.Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess,
|
||||||
PACKAGE = "xgboost")
|
PACKAGE = "xgboost")
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
|
|
||||||
# iteratively update booster with dtrain
|
# iteratively update booster with dtrain
|
||||||
xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
||||||
if (class(booster) != "xgb.Booster") {
|
if (class(booster) != "xgb.Booster.handle") {
|
||||||
stop("xgb.iter.update: first argument must be type xgb.Booster")
|
stop("xgb.iter.update: first argument must be type xgb.Booster.handle")
|
||||||
} else {
|
|
||||||
if (is.null(booster$handle)) {
|
|
||||||
booster$handle <- xgb.load(booster$raw)
|
|
||||||
} else {
|
|
||||||
if (is.null(booster$raw))
|
|
||||||
booster$raw <- xgb.save.raw(booster$handle)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (class(dtrain) != "xgb.DMatrix") {
|
if (class(dtrain) != "xgb.DMatrix") {
|
||||||
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
|
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is.null(obj)) {
|
if (is.null(obj)) {
|
||||||
.Call("XGBoosterUpdateOneIter_R", booster$handle, as.integer(iter), dtrain,
|
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain,
|
||||||
PACKAGE = "xgboost")
|
PACKAGE = "xgboost")
|
||||||
} else {
|
} else {
|
||||||
pred <- predict(booster, dtrain)
|
pred <- predict(booster, dtrain)
|
||||||
|
|||||||
@ -19,6 +19,7 @@
|
|||||||
xgb.load <- function(modelfile) {
|
xgb.load <- function(modelfile) {
|
||||||
if (is.null(modelfile))
|
if (is.null(modelfile))
|
||||||
stop("xgb.load: modelfile cannot be NULL")
|
stop("xgb.load: modelfile cannot be NULL")
|
||||||
|
|
||||||
bst <- list(handle = NULL,raw = NULL)
|
bst <- list(handle = NULL,raw = NULL)
|
||||||
class(bst) <- 'xgb.Booster'
|
class(bst) <- 'xgb.Booster'
|
||||||
bst$handle <- xgb.Booster(modelfile = modelfile)
|
bst$handle <- xgb.Booster(modelfile = modelfile)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user