[R] more attribute handling functionality

This commit is contained in:
Vadim Khotilovich
2016-05-14 18:11:29 -05:00
parent ea9285dd4f
commit 8664217a5a
7 changed files with 273 additions and 64 deletions

View File

@@ -2,31 +2,28 @@
# internal utility function
xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
if (typeof(cachelist) != "list") {
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
stop("xgb.Booster only accepts list of DMatrix as cachelist")
}
for (dm in cachelist) {
if (class(dm) != "xgb.DMatrix") {
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
stop("xgb.Booster only accepts list of DMatrix as cachelist")
}
}
handle <- .Call("XGBoosterCreate_R", cachelist, PACKAGE = "xgboost")
if (length(params) != 0) {
for (i in 1:length(params)) {
p <- params[i]
.Call("XGBoosterSetParam_R", handle, gsub("\\.", "_", names(p)), as.character(p),
PACKAGE = "xgboost")
}
}
if (!is.null(modelfile)) {
if (typeof(modelfile) == "character") {
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
} else if (typeof(modelfile) == "raw") {
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
} else {
stop("xgb.Booster: modelfile must be character or raw vector")
stop("modelfile must be character or raw vector")
}
}
return(structure(handle, class = "xgb.Booster.handle"))
class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) {
xgb.parameters(handle) <- params
}
return(handle)
}
# Convert xgb.Booster.handle to xgb.Booster
@@ -38,6 +35,20 @@ xgb.handleToBooster <- function(handle, raw = NULL)
return(bst)
}
# Return a verified to be valid handle out of either xgb.Booster.handle or xgb.Booster
# internal utility function
xgb.get.handle <- function(object) {
handle <- switch(class(object)[1],
xgb.Booster = object$handle,
xgb.Booster.handle = object,
stop("argument must be of either xgb.Booster or xgb.Booster.handle class")
)
if (is.null(handle) | .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
stop("invalid xgb.Booster.handle")
}
handle
}
# Check whether an xgb.Booster object is complete
# internal utility function
xgb.Booster.check <- function(bst, saveraw = TRUE)
@@ -146,28 +157,43 @@ predict.xgb.Booster.handle <- function(object, ...) {
#' Accessors for serializable attributes of a model.
#'
#' These methods allow to manipulate key-value attribute strings of an xgboost model.
#' These methods allow to manipulate the key-value attribute strings of an xgboost model.
#'
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}.
#' @param which a non-empty character string specifying which attribute is to be accessed.
#' @param value a value of an attribute. Non-character values are converted to character.
#' When length of a \code{value} vector is more than one, only the first element is used.
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}.
#' @param name a non-empty character string specifying which attribute is to be accessed.
#' @param value a value of an attribute for \code{xgb.attr<-}; for \code{xgb.attributes<-}
#' it's a list (or an object coercible to a list) with the names of attributes to set
#' and the elements corresponding to attribute values.
#' Non-character values are converted to character.
#' When attribute value is not a scalar, only the first index is used.
#' Use \code{NULL} to remove an attribute.
#'
#' @details
#' Note that the xgboost model attributes are a separate concept from the attributes in R.
#' Specifically, they refer to key-value strings that can be attached to an xgboost model
#' and stored within the model's binary representation.
#' The primary purpose of xgboost model attributes is to store some meta-data about the model.
#' Note that they are a separate concept from the object attributes in R.
#' Specifically, they refer to key-value strings that can be attached to an xgboost model,
#' stored together with the model's binary representation, and accessed later
#' (from R or any other interface).
#' In contrast, any R-attribute assigned to an R-object of \code{xgb.Booster} class
#' would not be saved by \code{xgb.save}, since xgboost model is an external memory object
#' would not be saved by \code{xgb.save} because an xgboost model is an external memory object
#' and its serialization is handled extrnally.
#' Also, setting an attribute that has the same name as one of xgboost's parameters wouldn't
#' change the value of that parameter for a model.
#' Use \code{\link{`xgb.parameters<-`}} to set or change model parameters.
#'
#' Also note that the attribute setter would usually work more efficiently for \code{xgb.Booster.handle}
#' than for \code{xgb.Booster}, since only just a handle would need to be copied.
#' The attribute setters would usually work more efficiently for \code{xgb.Booster.handle}
#' than for \code{xgb.Booster}, since only just a handle (pointer) would need to be copied.
#'
#' The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
#' but doesn't delete the existing attributes which don't have their names in \code{names(attributes)}.
#'
#' @return
#' \code{xgb.attr} returns either a string value of an attribute
#' or \code{NULL} if an attribute wasn't stored in a model.
#'
#' \code{xgb.attributes} returns a list of all attribute stored in a model
#' or \code{NULL} if a model has no stored attributes.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
@@ -177,42 +203,117 @@ predict.xgb.Booster.handle <- function(object, ...) {
#'
#' xgb.attr(bst, "my_attribute") <- "my attribute value"
#' print(xgb.attr(bst, "my_attribute"))
#' xgb.attributes(bst) <- list(a = 123, b = "abc")
#'
#' xgb.save(bst, 'xgb.model')
#' bst1 <- xgb.load('xgb.model')
#' print(xgb.attr(bst1, "my_attribute"))
#'
#' print(xgb.attributes(bst1))
#'
#' # deletion:
#' xgb.attr(bst1, "my_attribute") <- NULL
#' print(xgb.attributes(bst1))
#' xgb.attributes(bst1) <- list(a = NULL, b = NULL)
#' print(xgb.attributes(bst1))
#'
#' @rdname xgb.attr
#' @export
xgb.attr <- function(object, which) {
if (is.null(which) | nchar(as.character(which)[1]) == 0) stop("invalid attribute name")
handle = xgb.get.handle(object, "xgb.attr")
.Call("XGBoosterGetAttr_R", handle, as.character(which)[1], PACKAGE="xgboost")
xgb.attr <- function(object, name) {
if (is.null(name) || nchar(as.character(name[1])) == 0) stop("invalid attribute name")
handle <- xgb.get.handle(object)
.Call("XGBoosterGetAttr_R", handle, as.character(name[1]), PACKAGE="xgboost")
}
#' @rdname xgb.attr
#' @export
`xgb.attr<-` <- function(object, which, value) {
if (is.null(which) | nchar(as.character(which)[1]) == 0) stop("invalid attribute name")
handle = xgb.get.handle(object, "xgb.attr")
# TODO: setting NULL value to remove an attribute
.Call("XGBoosterSetAttr_R", handle, as.character(which)[1], as.character(value)[1], PACKAGE="xgboost")
`xgb.attr<-` <- function(object, name, value) {
if (is.null(name) || nchar(as.character(name[1])) == 0) stop("invalid attribute name")
handle <- xgb.get.handle(object)
if (!is.null(value)) {
# Coerce the elements to be scalar strings.
# Q: should we warn user about non-scalar elements?
value <- as.character(value[1])
}
.Call("XGBoosterSetAttr_R", handle, as.character(name[1]), value, PACKAGE="xgboost")
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
object$raw <- xgb.save.raw(object$handle)
}
object
}
# Return a valid handle out of either xgb.Booster.handle or xgb.Booster
# internal utility function
xgb.get.handle <- function(object, caller="") {
handle = switch(class(object),
xgb.Booster = object$handle,
xgb.Booster.handle = object,
stop(caller, ": argument must be either xgb.Booster or xgb.Booster.handle")
)
if (is.null(handle) | .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
stop(caller, ": invalid xgb.Booster.handle")
}
handle
#' @rdname xgb.attr
#' @export
xgb.attributes <- function(object) {
handle <- xgb.get.handle(object)
attr_names <- .Call("XGBoosterGetAttrNames_R", handle, PACKAGE="xgboost")
if (is.null(attr_names)) return(NULL)
res <- lapply(attr_names, function(x) {
.Call("XGBoosterGetAttr_R", handle, x, PACKAGE="xgboost")
})
names(res) <- attr_names
res
}
#' @rdname xgb.attr
#' @export
`xgb.attributes<-` <- function(object, value) {
a <- as.list(value)
if (is.null(names(a)) || any(nchar(names(a)) == 0)) {
stop("attribute names cannot be empty strings")
}
# Coerce the elements to be scalar strings.
# Q: should we warn a user about non-scalar elements?
a <- lapply(a, function(x) {
if (is.null(x)) return(NULL)
as.character(x[1])
})
handle <- xgb.get.handle(object)
for (i in seq_along(a)) {
.Call("XGBoosterSetAttr_R", handle, names(a[i]), a[[i]], PACKAGE="xgboost")
}
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
}
object
}
#' Accessors for model parameters.
#'
#' Only the setter for xgboost parameters is currently implemented.
#'
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}.
#' @param value a list (or an object coercible to a list) with the names of parameters to set
#' and the elements corresponding to parameter values.
#'
#' @details
#' Note that the setter would usually work more efficiently for \code{xgb.Booster.handle}
#' than for \code{xgb.Booster}, since only just a handle would need to be copied.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
#'
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
#' eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
#'
#' xgb.parameters(bst) <- list(eta = 0.1)
#'
#' @rdname xgb.parameters
#' @export
`xgb.parameters<-` <- function(object, value) {
if (length(value) == 0) return(object)
p <- as.list(value)
if (is.null(names(p)) || any(nchar(names(p)) == 0)) {
stop("parameter names cannot be empty strings")
}
names(p) <- gsub("\\.", "_", names(p))
p <- lapply(p, function(x) as.character(x)[1])
handle <- xgb.get.handle(object)
for (i in seq_along(p)) {
.Call("XGBoosterSetParam_R", handle, names(p[i]), p[[i]], PACKAGE = "xgboost")
}
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
}
object
}