Merge pull request #149 from tqchen/unity

add proptype of predleaf in R, fix bug in lambda rank
This commit is contained in:
Tianqi Chen 2015-01-19 09:08:19 -08:00
commit 1ea23d3390
4 changed files with 16 additions and 8 deletions

View File

@ -11,6 +11,7 @@ setClass("xgb.Booster")
#' value of sum of functions, when outputmargin=TRUE, the prediction is #' value of sum of functions, when outputmargin=TRUE, the prediction is
#' untransformed margin value. In logistic regression, outputmargin=T will #' untransformed margin value. In logistic regression, outputmargin=T will
#' output value before logistic transformation. #' output value before logistic transformation.
#' @param predleaf whether predict leaf index instead
#' @param ntreelimit limit number of trees used in prediction, this parameter is #' @param ntreelimit limit number of trees used in prediction, this parameter is
#' only valid for gbtree, but not for gblinear. set it to be value bigger #' only valid for gbtree, but not for gblinear. set it to be value bigger
#' than 0. It will use all trees by default. #' than 0. It will use all trees by default.
@ -25,7 +26,7 @@ setClass("xgb.Booster")
#' @export #' @export
#' #'
setMethod("predict", signature = "xgb.Booster", setMethod("predict", signature = "xgb.Booster",
definition = function(object, newdata, missing = NULL, outputmargin = FALSE, ntreelimit = NULL) { definition = function(object, newdata, missing = NULL, outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
if (class(newdata) != "xgb.DMatrix") { if (class(newdata) != "xgb.DMatrix") {
if (is.null(missing)) { if (is.null(missing)) {
newdata <- xgb.DMatrix(newdata) newdata <- xgb.DMatrix(newdata)
@ -40,7 +41,14 @@ setMethod("predict", signature = "xgb.Booster",
stop("predict: ntreelimit must be equal to or greater than 1") stop("predict: ntreelimit must be equal to or greater than 1")
} }
} }
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), as.integer(ntreelimit), PACKAGE = "xgboost") option = 0
if (outputmargin) {
option <- option + 1
}
if (predleaf) {
option <- option + 2
}
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(predleaf), as.integer(ntreelimit), PACKAGE = "xgboost")
return(ret) return(ret)
}) })

View File

@ -248,12 +248,12 @@ extern "C" {
asInteger(iter), asInteger(iter),
BeginPtr(vec_dmats), BeginPtr(vec_sptr), len)); BeginPtr(vec_dmats), BeginPtr(vec_sptr), len));
} }
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit) { SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit) {
_WrapperBegin(); _WrapperBegin();
bst_ulong olen; bst_ulong olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle), const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat), R_ExternalPtrAddr(dmat),
asInteger(output_margin), asInteger(option_mask),
asInteger(ntree_limit), asInteger(ntree_limit),
&olen); &olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen)); SEXP ret = PROTECT(allocVector(REALSXP, olen));

View File

@ -111,10 +111,10 @@ extern "C" {
* \brief make prediction based on dmat * \brief make prediction based on dmat
* \param handle handle * \param handle handle
* \param dmat data matrix * \param dmat data matrix
* \param output_margin whether only output raw margin value * \param option_mask output_margin:1 predict_leaf:2
* \param ntree_limit limit number of trees used in prediction * \param ntree_limit limit number of trees used in prediction
*/ */
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit); SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit);
/*! /*!
* \brief load model from existing file * \brief load model from existing file
* \param handle handle * \param handle handle

View File

@ -348,9 +348,9 @@ class LambdaRankObj : public IObjFunction {
float h = loss.SecondOrderGradient(p, 1.0f); float h = loss.SecondOrderGradient(p, 1.0f);
// accumulate gradient and hessian in both pid, and nid // accumulate gradient and hessian in both pid, and nid
gpair[pos.rindex].grad += g * w; gpair[pos.rindex].grad += g * w;
gpair[pos.rindex].hess += 2.0f * h; gpair[pos.rindex].hess += 2.0f * w * h;
gpair[neg.rindex].grad -= g * w; gpair[neg.rindex].grad -= g * w;
gpair[neg.rindex].hess += 2.0f * h; gpair[neg.rindex].hess += 2.0f * w * h;
} }
} }
} }