add nrow to getinfo
This commit is contained in:
parent
43c13d82ba
commit
a1e188aa75
@ -32,10 +32,15 @@ setMethod("getinfo", signature = "xgb.DMatrix",
|
|||||||
if (class(object) != "xgb.DMatrix") {
|
if (class(object) != "xgb.DMatrix") {
|
||||||
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix")
|
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix")
|
||||||
}
|
}
|
||||||
if (name != "label" && name != "weight" && name != "base_margin") {
|
if (name != "label" && name != "weight" &&
|
||||||
|
name != "base_margin" && name != "nrow") {
|
||||||
stop(paste("xgb.getinfo: unknown info name", name))
|
stop(paste("xgb.getinfo: unknown info name", name))
|
||||||
}
|
}
|
||||||
|
if (name != "nrow"){
|
||||||
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
|
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
|
||||||
|
} else {
|
||||||
|
ret <- .Call("XGDMatrixNumRow_R", object)
|
||||||
|
}
|
||||||
return(ret)
|
return(ret)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -52,7 +52,8 @@ setMethod("predict", signature = "xgb.Booster",
|
|||||||
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option),
|
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option),
|
||||||
as.integer(ntreelimit), PACKAGE = "xgboost")
|
as.integer(ntreelimit), PACKAGE = "xgboost")
|
||||||
if (predleaf){
|
if (predleaf){
|
||||||
if (length(ret) == nrow(newdata)){
|
len <- getinfo(newdata, "nrow")
|
||||||
|
if (length(ret) == len){
|
||||||
ret <- matrix(ret,ncol = 1)
|
ret <- matrix(ret,ncol = 1)
|
||||||
} else {
|
} else {
|
||||||
ret <- matrix(ret, ncol = nrow(newdata))
|
ret <- matrix(ret, ncol = nrow(newdata))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user