add nrow to getinfo

This commit is contained in:
hetong 2015-01-19 13:35:11 -08:00
parent 43c13d82ba
commit a1e188aa75
2 changed files with 9 additions and 3 deletions

View File

@ -32,10 +32,15 @@ setMethod("getinfo", signature = "xgb.DMatrix",
if (class(object) != "xgb.DMatrix") {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix")
}
if (name != "label" && name != "weight" && name != "base_margin") {
if (name != "label" && name != "weight" &&
name != "base_margin" && name != "nrow") {
stop(paste("xgb.getinfo: unknown info name", name))
}
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
if (name != "nrow"){
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
} else {
ret <- .Call("XGDMatrixNumRow_R", object)
}
return(ret)
})

View File

@ -52,7 +52,8 @@ setMethod("predict", signature = "xgb.Booster",
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option),
as.integer(ntreelimit), PACKAGE = "xgboost")
if (predleaf){
if (length(ret) == nrow(newdata)){
len <- getinfo(newdata, "nrow")
if (length(ret) == len){
ret <- matrix(ret,ncol = 1)
} else {
ret <- matrix(ret, ncol = nrow(newdata))