add ntree limit

This commit is contained in:
tqchen
2014-09-01 15:10:19 -07:00
parent 4c451de90b
commit 4592e500cb
10 changed files with 53 additions and 23 deletions

View File

@@ -11,7 +11,8 @@ setClass("xgb.Booster")
#' value of sum of functions, when outputmargin=TRUE, the prediction is
#' untransformed margin value. In logistic regression, outputmargin=T will
#' output value before logistic transformation.
#'
#' @param ntreelimit limit number of trees used in prediction, this parameter is only valid for gbtree, but not for gblinear.
#' set it to be value bigger than 0
#' @examples
#' data(iris)
#' bst <- xgboost(as.matrix(iris[,1:4]),as.numeric(iris[,5]), nrounds = 2)
@@ -19,11 +20,18 @@ setClass("xgb.Booster")
#' @export
#'
setMethod("predict", signature = "xgb.Booster",
definition = function(object, newdata, outputmargin = FALSE) {
definition = function(object, newdata, outputmargin = FALSE, ntreelimit = NULL) {
if (class(newdata) != "xgb.DMatrix") {
newdata <- xgb.DMatrix(newdata)
}
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), PACKAGE = "xgboost")
if (is.null(ntreelimit)) {
ntreelimit <- 0
} else {
if (ntreelimit < 1){
stop("predict: ntreelimit must be greater equal than 1")
}
}
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), as.integer(ntreelimit), PACKAGE = "xgboost")
return(ret)
})

View File

@@ -247,12 +247,13 @@ extern "C" {
&vec_dmats[0], &vec_sptr[0], len));
_WrapperEnd();
}
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit) {
_WrapperBegin();
bst_ulong olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(output_margin),
asInteger(ntree_limit),
&olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {

View File

@@ -107,8 +107,9 @@ extern "C" {
* \param handle handle
* \param dmat data matrix
* \param output_margin whether only output raw margin value
* \param ntree_limit limit number of trees used in prediction
*/
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin);
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit);
/*!
* \brief load model from existing file
* \param handle handle