print method; construct from initial xgb.Booster
This commit is contained in:
parent
264c222fe0
commit
bdf14007b5
@ -15,8 +15,11 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
|
|||||||
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
|
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
|
||||||
} else if (typeof(modelfile) == "raw") {
|
} else if (typeof(modelfile) == "raw") {
|
||||||
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
|
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
|
||||||
|
} else if (class(modelfile) == "xgb.Booster") {
|
||||||
|
modelfile <- xgb.Booster.check(modelfile, saveraw=TRUE)
|
||||||
|
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile$raw, PACKAGE = "xgboost")
|
||||||
} else {
|
} else {
|
||||||
stop("modelfile must be character or raw vector")
|
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
class(handle) <- "xgb.Booster.handle"
|
class(handle) <- "xgb.Booster.handle"
|
||||||
@ -28,8 +31,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
|
|||||||
|
|
||||||
# Convert xgb.Booster.handle to xgb.Booster
|
# Convert xgb.Booster.handle to xgb.Booster
|
||||||
# internal utility function
|
# internal utility function
|
||||||
xgb.handleToBooster <- function(handle, raw = NULL)
|
xgb.handleToBooster <- function(handle, raw = NULL) {
|
||||||
{
|
|
||||||
bst <- list(handle = handle, raw = raw)
|
bst <- list(handle = handle, raw = raw)
|
||||||
class(bst) <- "xgb.Booster"
|
class(bst) <- "xgb.Booster"
|
||||||
return(bst)
|
return(bst)
|
||||||
@ -43,7 +45,7 @@ xgb.get.handle <- function(object) {
|
|||||||
xgb.Booster.handle = object,
|
xgb.Booster.handle = object,
|
||||||
stop("argument must be of either xgb.Booster or xgb.Booster.handle class")
|
stop("argument must be of either xgb.Booster or xgb.Booster.handle class")
|
||||||
)
|
)
|
||||||
if (is.null(handle) | .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
|
if (is.null(handle) || .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
|
||||||
stop("invalid xgb.Booster.handle")
|
stop("invalid xgb.Booster.handle")
|
||||||
}
|
}
|
||||||
handle
|
handle
|
||||||
@ -51,8 +53,7 @@ xgb.get.handle <- function(object) {
|
|||||||
|
|
||||||
# Check whether an xgb.Booster object is complete
|
# Check whether an xgb.Booster object is complete
|
||||||
# internal utility function
|
# internal utility function
|
||||||
xgb.Booster.check <- function(bst, saveraw = TRUE)
|
xgb.Booster.check <- function(bst, saveraw = TRUE) {
|
||||||
{
|
|
||||||
isnull <- is.null(bst$handle)
|
isnull <- is.null(bst$handle)
|
||||||
if (!isnull) {
|
if (!isnull) {
|
||||||
isnull <- .Call("XGCheckNullPtr_R", bst$handle, PACKAGE="xgboost")
|
isnull <- .Call("XGCheckNullPtr_R", bst$handle, PACKAGE="xgboost")
|
||||||
@ -108,38 +109,23 @@ xgb.Booster.check <- function(bst, saveraw = TRUE)
|
|||||||
#' @export
|
#' @export
|
||||||
predict.xgb.Booster <- function(object, newdata, missing = NA,
|
predict.xgb.Booster <- function(object, newdata, missing = NA,
|
||||||
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
|
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
|
||||||
if (class(object) != "xgb.Booster"){
|
|
||||||
stop("predict: model in prediction must be of class xgb.Booster")
|
|
||||||
} else {
|
|
||||||
object <- xgb.Booster.check(object, saveraw = FALSE)
|
object <- xgb.Booster.check(object, saveraw = FALSE)
|
||||||
}
|
if (class(newdata) != "xgb.DMatrix")
|
||||||
if (class(newdata) != "xgb.DMatrix") {
|
|
||||||
newdata <- xgb.DMatrix(newdata, missing = missing)
|
newdata <- xgb.DMatrix(newdata, missing = missing)
|
||||||
}
|
if (is.null(ntreelimit))
|
||||||
if (is.null(ntreelimit)) {
|
ntreelimit <- NVL(object$best_ntreelimit, 0)
|
||||||
ntreelimit <- 0
|
if (ntreelimit < 0)
|
||||||
} else {
|
stop("ntreelimit must be positive")
|
||||||
if (ntreelimit < 1){
|
|
||||||
stop("predict: ntreelimit must be equal to or greater than 1")
|
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf)
|
||||||
}
|
|
||||||
}
|
ret <- .Call("XGBoosterPredict_R", object$handle, newdata, option[1],
|
||||||
option <- 0
|
|
||||||
if (outputmargin) {
|
|
||||||
option <- option + 1
|
|
||||||
}
|
|
||||||
if (predleaf) {
|
|
||||||
option <- option + 2
|
|
||||||
}
|
|
||||||
ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option),
|
|
||||||
as.integer(ntreelimit), PACKAGE = "xgboost")
|
as.integer(ntreelimit), PACKAGE = "xgboost")
|
||||||
if (predleaf){
|
if (predleaf){
|
||||||
len <- getinfo(newdata, "nrow")
|
len <- nrow(newdata)
|
||||||
if (length(ret) == len){
|
ret <- if (length(ret) == len) matrix(ret, ncol = 1)
|
||||||
ret <- matrix(ret,ncol = 1)
|
else t(matrix(ret, ncol = len))
|
||||||
} else {
|
|
||||||
ret <- matrix(ret, ncol = len)
|
|
||||||
ret <- t(ret)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return(ret)
|
return(ret)
|
||||||
}
|
}
|
||||||
@ -317,3 +303,87 @@ xgb.attributes <- function(object) {
|
|||||||
}
|
}
|
||||||
object
|
object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#' Print xgb.Booster
|
||||||
|
#'
|
||||||
|
#' Print information about xgb.Booster.
|
||||||
|
#'
|
||||||
|
#' @param x an xgb.Booster object
|
||||||
|
#' @param verbose whether to print detailed data (e.g., attribute values)
|
||||||
|
#' @param ... not currently used
|
||||||
|
#'
|
||||||
|
#' @examples
|
||||||
|
#' data(agaricus.train, package='xgboost')
|
||||||
|
#' train <- agaricus.train
|
||||||
|
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||||
|
#' eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
|
||||||
|
#' attr(bst, 'myattr') <- 'memo'
|
||||||
|
#'
|
||||||
|
#' print(bst)
|
||||||
|
#' print(bst, verbose=TRUE)
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
print.xgb.Booster <- function(x, verbose=FALSE, ...) {
|
||||||
|
cat('##### xgb.Booster\n')
|
||||||
|
|
||||||
|
if (is.null(x$handle) || .Call("XGCheckNullPtr_R", x$handle, PACKAGE="xgboost")) {
|
||||||
|
cat("handle is invalid\n")
|
||||||
|
return(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
cat('raw: ')
|
||||||
|
if (!is.null(x$raw)) cat(format(object.size(x$raw), units="auto"), '\n')
|
||||||
|
else cat('NULL\n')
|
||||||
|
|
||||||
|
if (!is.null(x$call)) {
|
||||||
|
cat('call:\n ')
|
||||||
|
print(x$call)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is.null(x$params)) {
|
||||||
|
cat('params (as set within xgb.train):\n')
|
||||||
|
cat( ' ',
|
||||||
|
paste(names(x$params),
|
||||||
|
paste0('"', unlist(x$params), '"'),
|
||||||
|
sep=' = ', collapse=', '), '\n', sep='')
|
||||||
|
}
|
||||||
|
# TODO: need an interface to access all the xgboosts parameters
|
||||||
|
|
||||||
|
attrs <- xgb.attributes(x)
|
||||||
|
if (length(attrs) > 0) {
|
||||||
|
cat('xgb.attributes:\n')
|
||||||
|
if (verbose) {
|
||||||
|
cat( paste(paste0(' ',names(attrs)),
|
||||||
|
paste0('"', unlist(attrs), '"'),
|
||||||
|
sep=' = ', collapse='\n'), '\n', sep='')
|
||||||
|
} else {
|
||||||
|
cat(' ', paste(names(attrs), collapse=', '), '\n', sep='')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is.null(x$callbacks) && length(x$callbacks) > 0) {
|
||||||
|
cat('callbacks:\n')
|
||||||
|
lapply(callback.calls(x$callbacks), function(x) {
|
||||||
|
cat(' ')
|
||||||
|
print(x)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for (n in setdiff(names(x), c('handle', 'raw', 'call', 'params', 'callbacks','evaluation_log'))) {
|
||||||
|
if (is.atomic(x[[n]])) {
|
||||||
|
cat(n, ': ', x[[n]], '\n', sep='')
|
||||||
|
} else {
|
||||||
|
cat(n, ':\n\t', sep='')
|
||||||
|
print(x[[n]])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is.null(x$evaluation_log)) {
|
||||||
|
cat('evaluation_log:\n')
|
||||||
|
print(x$evaluation_log, row.names = FALSE, topn = 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
invisible(x)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user