add nrow to getinfo
This commit is contained in:
parent
43c13d82ba
commit
a1e188aa75
@ -32,10 +32,15 @@ setMethod("getinfo", signature = "xgb.DMatrix",
|
||||
if (class(object) != "xgb.DMatrix") {
|
||||
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix")
|
||||
}
|
||||
if (name != "label" && name != "weight" && name != "base_margin") {
|
||||
if (name != "label" && name != "weight" &&
|
||||
name != "base_margin" && name != "nrow") {
|
||||
stop(paste("xgb.getinfo: unknown info name", name))
|
||||
}
|
||||
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
|
||||
if (name != "nrow"){
|
||||
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
|
||||
} else {
|
||||
ret <- .Call("XGDMatrixNumRow_R", object)
|
||||
}
|
||||
return(ret)
|
||||
})
|
||||
|
||||
|
||||
@ -52,7 +52,8 @@ setMethod("predict", signature = "xgb.Booster",
|
||||
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option),
|
||||
as.integer(ntreelimit), PACKAGE = "xgboost")
|
||||
if (predleaf){
|
||||
if (length(ret) == nrow(newdata)){
|
||||
len <- getinfo(newdata, "nrow")
|
||||
if (length(ret) == len){
|
||||
ret <- matrix(ret,ncol = 1)
|
||||
} else {
|
||||
ret <- matrix(ret, ncol = nrow(newdata))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user