From a1e188aa7503d1a4ccdc09ef1e7e5663f02fb4b7 Mon Sep 17 00:00:00 2001 From: hetong Date: Mon, 19 Jan 2015 13:35:11 -0800 Subject: [PATCH] add nrow to getinfo --- R-package/R/getinfo.xgb.DMatrix.R | 9 +++++++-- R-package/R/predict.xgb.Booster.R | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/R-package/R/getinfo.xgb.DMatrix.R b/R-package/R/getinfo.xgb.DMatrix.R index ed61ba654..830bbbffa 100644 --- a/R-package/R/getinfo.xgb.DMatrix.R +++ b/R-package/R/getinfo.xgb.DMatrix.R @@ -32,10 +32,15 @@ setMethod("getinfo", signature = "xgb.DMatrix", if (class(object) != "xgb.DMatrix") { stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix") } - if (name != "label" && name != "weight" && name != "base_margin") { + if (name != "label" && name != "weight" && + name != "base_margin" && name != "nrow") { stop(paste("xgb.getinfo: unknown info name", name)) } - ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost") + if (name != "nrow"){ + ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost") + } else { + ret <- .Call("XGDMatrixNumRow_R", object) + } return(ret) }) diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index 8e1982049..cdb3f3dc1 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -52,7 +52,8 @@ setMethod("predict", signature = "xgb.Booster", ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option), as.integer(ntreelimit), PACKAGE = "xgboost") if (predleaf){ - if (length(ret) == nrow(newdata)){ + len <- getinfo(newdata, "nrow") + if (length(ret) == len){ ret <- matrix(ret,ncol = 1) } else { ret <- matrix(ret, ncol = nrow(newdata))