print method; construct from initial xgb.Booster
This commit is contained in:
parent
264c222fe0
commit
bdf14007b5
@ -15,8 +15,11 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
|
||||
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
|
||||
} else if (typeof(modelfile) == "raw") {
|
||||
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
|
||||
} else if (class(modelfile) == "xgb.Booster") {
|
||||
modelfile <- xgb.Booster.check(modelfile, saveraw=TRUE)
|
||||
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile$raw, PACKAGE = "xgboost")
|
||||
} else {
|
||||
stop("modelfile must be character or raw vector")
|
||||
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
|
||||
}
|
||||
}
|
||||
class(handle) <- "xgb.Booster.handle"
|
||||
@ -28,8 +31,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
|
||||
|
||||
# Convert xgb.Booster.handle to xgb.Booster
|
||||
# internal utility function
|
||||
xgb.handleToBooster <- function(handle, raw = NULL)
|
||||
{
|
||||
xgb.handleToBooster <- function(handle, raw = NULL) {
|
||||
bst <- list(handle = handle, raw = raw)
|
||||
class(bst) <- "xgb.Booster"
|
||||
return(bst)
|
||||
@ -43,7 +45,7 @@ xgb.get.handle <- function(object) {
|
||||
xgb.Booster.handle = object,
|
||||
stop("argument must be of either xgb.Booster or xgb.Booster.handle class")
|
||||
)
|
||||
if (is.null(handle) | .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
|
||||
if (is.null(handle) || .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
|
||||
stop("invalid xgb.Booster.handle")
|
||||
}
|
||||
handle
|
||||
@ -51,8 +53,7 @@ xgb.get.handle <- function(object) {
|
||||
|
||||
# Check whether an xgb.Booster object is complete
|
||||
# internal utility function
|
||||
xgb.Booster.check <- function(bst, saveraw = TRUE)
|
||||
{
|
||||
xgb.Booster.check <- function(bst, saveraw = TRUE) {
|
||||
isnull <- is.null(bst$handle)
|
||||
if (!isnull) {
|
||||
isnull <- .Call("XGCheckNullPtr_R", bst$handle, PACKAGE="xgboost")
|
||||
@ -108,38 +109,23 @@ xgb.Booster.check <- function(bst, saveraw = TRUE)
|
||||
#' @export
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA,
|
||||
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
|
||||
if (class(object) != "xgb.Booster"){
|
||||
stop("predict: model in prediction must be of class xgb.Booster")
|
||||
} else {
|
||||
object <- xgb.Booster.check(object, saveraw = FALSE)
|
||||
}
|
||||
if (class(newdata) != "xgb.DMatrix") {
|
||||
|
||||
object <- xgb.Booster.check(object, saveraw = FALSE)
|
||||
if (class(newdata) != "xgb.DMatrix")
|
||||
newdata <- xgb.DMatrix(newdata, missing = missing)
|
||||
}
|
||||
if (is.null(ntreelimit)) {
|
||||
ntreelimit <- 0
|
||||
} else {
|
||||
if (ntreelimit < 1){
|
||||
stop("predict: ntreelimit must be equal to or greater than 1")
|
||||
}
|
||||
}
|
||||
option <- 0
|
||||
if (outputmargin) {
|
||||
option <- option + 1
|
||||
}
|
||||
if (predleaf) {
|
||||
option <- option + 2
|
||||
}
|
||||
ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option),
|
||||
if (is.null(ntreelimit))
|
||||
ntreelimit <- NVL(object$best_ntreelimit, 0)
|
||||
if (ntreelimit < 0)
|
||||
stop("ntreelimit must be positive")
|
||||
|
||||
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf)
|
||||
|
||||
ret <- .Call("XGBoosterPredict_R", object$handle, newdata, option[1],
|
||||
as.integer(ntreelimit), PACKAGE = "xgboost")
|
||||
if (predleaf){
|
||||
len <- getinfo(newdata, "nrow")
|
||||
if (length(ret) == len){
|
||||
ret <- matrix(ret,ncol = 1)
|
||||
} else {
|
||||
ret <- matrix(ret, ncol = len)
|
||||
ret <- t(ret)
|
||||
}
|
||||
len <- nrow(newdata)
|
||||
ret <- if (length(ret) == len) matrix(ret, ncol = 1)
|
||||
else t(matrix(ret, ncol = len))
|
||||
}
|
||||
return(ret)
|
||||
}
|
||||
@ -317,3 +303,87 @@ xgb.attributes <- function(object) {
|
||||
}
|
||||
object
|
||||
}
|
||||
|
||||
|
||||
|
||||
#' Print xgb.Booster
|
||||
#'
|
||||
#' Print information about xgb.Booster.
|
||||
#'
|
||||
#' @param x an xgb.Booster object
|
||||
#' @param verbose whether to print detailed data (e.g., attribute values)
|
||||
#' @param ... not currently used
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' train <- agaricus.train
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||
#' eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
|
||||
#' attr(bst, 'myattr') <- 'memo'
|
||||
#'
|
||||
#' print(bst)
|
||||
#' print(bst, verbose=TRUE)
|
||||
#'
|
||||
#' @export
|
||||
print.xgb.Booster <- function(x, verbose=FALSE, ...) {
|
||||
cat('##### xgb.Booster\n')
|
||||
|
||||
if (is.null(x$handle) || .Call("XGCheckNullPtr_R", x$handle, PACKAGE="xgboost")) {
|
||||
cat("handle is invalid\n")
|
||||
return(x)
|
||||
}
|
||||
|
||||
cat('raw: ')
|
||||
if (!is.null(x$raw)) cat(format(object.size(x$raw), units="auto"), '\n')
|
||||
else cat('NULL\n')
|
||||
|
||||
if (!is.null(x$call)) {
|
||||
cat('call:\n ')
|
||||
print(x$call)
|
||||
}
|
||||
|
||||
if (!is.null(x$params)) {
|
||||
cat('params (as set within xgb.train):\n')
|
||||
cat( ' ',
|
||||
paste(names(x$params),
|
||||
paste0('"', unlist(x$params), '"'),
|
||||
sep=' = ', collapse=', '), '\n', sep='')
|
||||
}
|
||||
# TODO: need an interface to access all the xgboosts parameters
|
||||
|
||||
attrs <- xgb.attributes(x)
|
||||
if (length(attrs) > 0) {
|
||||
cat('xgb.attributes:\n')
|
||||
if (verbose) {
|
||||
cat( paste(paste0(' ',names(attrs)),
|
||||
paste0('"', unlist(attrs), '"'),
|
||||
sep=' = ', collapse='\n'), '\n', sep='')
|
||||
} else {
|
||||
cat(' ', paste(names(attrs), collapse=', '), '\n', sep='')
|
||||
}
|
||||
}
|
||||
|
||||
if (!is.null(x$callbacks) && length(x$callbacks) > 0) {
|
||||
cat('callbacks:\n')
|
||||
lapply(callback.calls(x$callbacks), function(x) {
|
||||
cat(' ')
|
||||
print(x)
|
||||
})
|
||||
}
|
||||
|
||||
for (n in setdiff(names(x), c('handle', 'raw', 'call', 'params', 'callbacks','evaluation_log'))) {
|
||||
if (is.atomic(x[[n]])) {
|
||||
cat(n, ': ', x[[n]], '\n', sep='')
|
||||
} else {
|
||||
cat(n, ':\n\t', sep='')
|
||||
print(x[[n]])
|
||||
}
|
||||
}
|
||||
|
||||
if (!is.null(x$evaluation_log)) {
|
||||
cat('evaluation_log:\n')
|
||||
print(x$evaluation_log, row.names = FALSE, topn = 2)
|
||||
}
|
||||
|
||||
invisible(x)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user