R accessors for model attributes

This commit is contained in:
Vadim Khotilovich
2016-05-02 00:20:44 -05:00
parent 0839aed380
commit 79c7c9e5bb
4 changed files with 140 additions and 0 deletions

View File

@@ -142,3 +142,77 @@ predict.xgb.Booster.handle <- function(object, ...) {
ret <- predict(bst, ...)
return(ret)
}
#' Accessors for serializable attributes of a model.
#'
#' These methods allow to manipulate key-value attribute strings of an xgboost model.
#'
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}.
#' @param which a non-empty character string specifying which attribute is to be accessed.
#' @param value a value of an attribute. Non-character values are converted to character.
#' When length of a \code{value} vector is more than one, only the first element is used.
#'
#' @details
#' Note that the xgboost model attributes are a separate concept from the attributes in R.
#' Specifically, they refer to key-value strings that can be attached to an xgboost model
#' and stored within the model's binary representation.
#' In contrast, any R-attribute assigned to an R-object of \code{xgb.Booster} class
#' would not be saved by \code{xgb.save}, since xgboost model is an external memory object
#' and its serialization is handled extrnally.
#'
#' Also note that the attribute setter would usually work more efficiently for \code{xgb.Booster.handle}
#' than for \code{xgb.Booster}, since only just a handle would need to be copied.
#'
#' @return
#' \code{xgb.attr} returns either a string value of an attribute
#' or \code{NULL} if an attribute wasn't stored in a model.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
#'
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
#' eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
#'
#' xgb.attr(bst, "my_attribute") <- "my attribute value"
#' print(xgb.attr(bst, "my_attribute"))
#'
#' xgb.save(bst, 'xgb.model')
#' bst1 <- xgb.load('xgb.model')
#' print(xgb.attr(bst1, "my_attribute"))
#'
#' @rdname xgb.attr
#' @export
xgb.attr <- function(object, which) {
if (is.null(which) | nchar(as.character(which)[1]) == 0) stop("invalid attribute name")
handle = xgb.get.handle(object, "xgb.attr")
.Call("XGBoosterGetAttr_R", handle, as.character(which)[1], PACKAGE="xgboost")
}
#' @rdname xgb.attr
#' @export
`xgb.attr<-` <- function(object, which, value) {
if (is.null(which) | nchar(as.character(which)[1]) == 0) stop("invalid attribute name")
handle = xgb.get.handle(object, "xgb.attr")
# TODO: setting NULL value to remove an attribute
.Call("XGBoosterSetAttr_R", handle, as.character(which)[1], as.character(value)[1], PACKAGE="xgboost")
if (is(object, 'xgb.Booster') & !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
}
object
}
# Return a valid handle out of either xgb.Booster.handle or xgb.Booster
# internal utility function
xgb.get.handle <- function(object, caller="") {
handle = switch(class(object),
xgb.Booster = object$handle,
xgb.Booster.handle = object,
stop(caller, ": argument must be either xgb.Booster or xgb.Booster.handle")
)
if (is.null(handle) | .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
stop(caller, ": invalid xgb.Booster.handle")
}
handle
}