From 43c13d82badc3cdd0ccca92a7c2f0d8b0ccdf8c7 Mon Sep 17 00:00:00 2001 From: hetong Date: Mon, 19 Jan 2015 10:34:14 -0800 Subject: [PATCH] add leaf example in R --- R-package/R/predict.xgb.Booster.R | 16 +++++++++++++--- R-package/demo/predict_leaf_indices.R | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 R-package/demo/predict_leaf_indices.R diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index 62a64f0b5..8e1982049 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -11,7 +11,7 @@ setClass("xgb.Booster") #' value of sum of functions, when outputmargin=TRUE, the prediction is #' untransformed margin value. In logistic regression, outputmargin=T will #' output value before logistic transformation. -#' @param predleaf whether predict leaf index instead +#' @param predleaf whether predict leaf index instead. If set to TRUE, the output will be a matrix object. #' @param ntreelimit limit number of trees used in prediction, this parameter is #' only valid for gbtree, but not for gblinear. set it to be value bigger #' than 0. It will use all trees by default. @@ -26,7 +26,8 @@ setClass("xgb.Booster") #' @export #' setMethod("predict", signature = "xgb.Booster", - definition = function(object, newdata, missing = NULL, outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { + definition = function(object, newdata, missing = NULL, + outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { if (class(newdata) != "xgb.DMatrix") { if (is.null(missing)) { newdata <- xgb.DMatrix(newdata) @@ -48,7 +49,16 @@ setMethod("predict", signature = "xgb.Booster", if (predleaf) { option <- option + 2 } - ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option), as.integer(ntreelimit), PACKAGE = "xgboost") + ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option), + as.integer(ntreelimit), PACKAGE = "xgboost") + if (predleaf){ + if (length(ret) == nrow(newdata)){ + ret <- matrix(ret,ncol = 1) + } else { + ret <- matrix(ret, ncol = nrow(newdata)) + ret <- t(ret) + } + } return(ret) }) diff --git a/R-package/demo/predict_leaf_indices.R b/R-package/demo/predict_leaf_indices.R new file mode 100644 index 000000000..1fc64ba4a --- /dev/null +++ b/R-package/demo/predict_leaf_indices.R @@ -0,0 +1,22 @@ +require(xgboost) +# load in the agaricus dataset +data(agaricus.train, package='xgboost') +data(agaricus.test, package='xgboost') +dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) +dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) + +param <- list(max.depth=2,eta=1,silent=1,objective='binary:logistic') +watchlist <- list(eval = dtest, train = dtrain) +nround = 5 + +# training the model for two rounds +bst = xgb.train(param, dtrain, nround, watchlist) +cat('start testing prediction from first n trees\n') +labels <- getinfo(dtest,'label') + +### predict using first 2 tree +pred_with_leaf = predict(bst, dtest, ntreelimit = 2, predleaf = TRUE) +head(pred_with_leaf) +# by default, we predict using all the trees +pred_with_leaf = predict(bst, dtest, predleaf = TRUE) +head(pred_with_leaf)