fix segfault and add two function for handle and booster
This commit is contained in:
parent
0aef62dabc
commit
4c25600d2a
@ -36,12 +36,7 @@ setMethod("predict", signature = "xgb.Booster",
|
|||||||
if (class(object) != "xgb.Booster"){
|
if (class(object) != "xgb.Booster"){
|
||||||
stop("predict: model in prediction must be of class xgb.Booster")
|
stop("predict: model in prediction must be of class xgb.Booster")
|
||||||
} else {
|
} else {
|
||||||
if (is.null(object$handle)) {
|
object <- xgb.Booster.check(object, saveraw = FALSE)
|
||||||
object$handle <- xgb.load(object$raw)
|
|
||||||
} else {
|
|
||||||
if (is.null(object$raw))
|
|
||||||
object$raw <- xgb.save.raw(object$handle)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (class(newdata) != "xgb.DMatrix") {
|
if (class(newdata) != "xgb.DMatrix") {
|
||||||
if (is.null(missing)) {
|
if (is.null(missing)) {
|
||||||
|
|||||||
@ -6,9 +6,9 @@ setMethod("predict", signature = "xgb.Booster.handle",
|
|||||||
stop("predict: model in prediction must be of class xgb.Booster.handle")
|
stop("predict: model in prediction must be of class xgb.Booster.handle")
|
||||||
}
|
}
|
||||||
|
|
||||||
bst <- list(handle = object,raw = NULL)
|
bst <- xgb.handleToBooster(object)
|
||||||
class(bst) <- 'xgb.Booster'
|
# Avoid save a handle without update
|
||||||
bst$raw <- xgb.save.raw(bst$handle)
|
# bst$raw <- xgb.save.raw(object)
|
||||||
|
|
||||||
ret = predict(bst, ...)
|
ret = predict(bst, ...)
|
||||||
return(ret)
|
return(ret)
|
||||||
|
|||||||
@ -68,6 +68,26 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
|
|||||||
return(structure(handle, class = "xgb.Booster.handle"))
|
return(structure(handle, class = "xgb.Booster.handle"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# convert xgb.Booster.handle to xgb.Booster
|
||||||
|
xgb.handleToBooster <- function(handle)
|
||||||
|
{
|
||||||
|
bst <- list(handle = handle, raw = NULL)
|
||||||
|
class(bst) <- "xgb.Booster"
|
||||||
|
return(bst)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check whether an xgb.Booster object is complete
|
||||||
|
xgb.Booster.check <- function(bst, saveraw = TRUE)
|
||||||
|
{
|
||||||
|
if (is.null(bst$handle)) {
|
||||||
|
bst$handle <- xgb.load(bst$raw)
|
||||||
|
} else {
|
||||||
|
if (is.null(bst$raw) && saveraw)
|
||||||
|
bst$raw <- xgb.save.raw(bst$handle)
|
||||||
|
}
|
||||||
|
return(bst)
|
||||||
|
}
|
||||||
|
|
||||||
## ----the following are low level iteratively function, not needed if
|
## ----the following are low level iteratively function, not needed if
|
||||||
## you do not want to use them ---------------------------------------
|
## you do not want to use them ---------------------------------------
|
||||||
# get dmatrix from data, label
|
# get dmatrix from data, label
|
||||||
|
|||||||
@ -41,12 +41,7 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
|
|||||||
if (class(model) != "xgb.Booster") {
|
if (class(model) != "xgb.Booster") {
|
||||||
stop("model: argument must be type xgb.Booster")
|
stop("model: argument must be type xgb.Booster")
|
||||||
} else {
|
} else {
|
||||||
if (is.null(model$handle)) {
|
model <- xgb.Booster.check(model)
|
||||||
model$handle <- xgb.load(model$raw)
|
|
||||||
} else {
|
|
||||||
if (is.null(model$raw))
|
|
||||||
model$raw <- xgb.save.raw(model$handle)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) {
|
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) {
|
||||||
stop("fname: argument must be type character (when provided)")
|
stop("fname: argument must be type character (when provided)")
|
||||||
|
|||||||
@ -20,9 +20,8 @@ xgb.load <- function(modelfile) {
|
|||||||
if (is.null(modelfile))
|
if (is.null(modelfile))
|
||||||
stop("xgb.load: modelfile cannot be NULL")
|
stop("xgb.load: modelfile cannot be NULL")
|
||||||
|
|
||||||
bst <- list(handle = NULL,raw = NULL)
|
handle <- xgb.Booster(modelfile = modelfile)
|
||||||
class(bst) <- 'xgb.Booster'
|
bst <- xgb.handleToBooster(handle)
|
||||||
bst$handle <- xgb.Booster(modelfile = modelfile)
|
bst <- xgb.Booster.check(bst)
|
||||||
bst$raw <- xgb.save.raw(bst$handle)
|
|
||||||
return(bst)
|
return(bst)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,9 +22,7 @@ xgb.save <- function(model, fname) {
|
|||||||
stop("xgb.save: fname must be character")
|
stop("xgb.save: fname must be character")
|
||||||
}
|
}
|
||||||
if (class(model) == "xgb.Booster") {
|
if (class(model) == "xgb.Booster") {
|
||||||
if (is.null(model$handle)) {
|
model <- xgb.Booster.check(model)
|
||||||
model$handle <- xgb.load(model$raw)
|
|
||||||
}
|
|
||||||
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -86,9 +86,8 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
|||||||
}
|
}
|
||||||
params = append(params, list(...))
|
params = append(params, list(...))
|
||||||
|
|
||||||
bst <- list(handle = NULL,raw = NULL)
|
handle <- xgb.Booster(params, append(watchlist, dtrain))
|
||||||
class(bst) <- 'xgb.Booster'
|
bst <- xgb.handleToBooster(handle)
|
||||||
bst$handle <- xgb.Booster(params, append(watchlist, dtrain))
|
|
||||||
for (i in 1:nrounds) {
|
for (i in 1:nrounds) {
|
||||||
succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj)
|
succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj)
|
||||||
if (length(watchlist) != 0) {
|
if (length(watchlist) != 0) {
|
||||||
@ -96,6 +95,6 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
|||||||
cat(paste(msg, "\n", sep=""))
|
cat(paste(msg, "\n", sep=""))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bst$raw <- xgb.save.raw(bst$handle)
|
bst <- xgb.Booster.check(bst)
|
||||||
return(bst)
|
return(bst)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user