Fix matrix attributes not sliced (#4311)
This commit is contained in:
parent
5c2575535f
commit
956e73f183
@ -301,12 +301,17 @@ slice.xgb.DMatrix <- function(object, idxset, ...) {
|
|||||||
|
|
||||||
attr_list <- attributes(object)
|
attr_list <- attributes(object)
|
||||||
nr <- nrow(object)
|
nr <- nrow(object)
|
||||||
len <- sapply(attr_list, length)
|
len <- sapply(attr_list, NROW)
|
||||||
ind <- which(len == nr)
|
ind <- which(len == nr)
|
||||||
if (length(ind) > 0) {
|
if (length(ind) > 0) {
|
||||||
nms <- names(attr_list)[ind]
|
nms <- names(attr_list)[ind]
|
||||||
for (i in seq_along(ind)) {
|
for (i in seq_along(ind)) {
|
||||||
attr(ret, nms[i]) <- attr(object, nms[i])[idxset]
|
obj_attr <- attr(object, nms[i])
|
||||||
|
if (NCOL(obj_attr) > 1) {
|
||||||
|
attr(ret, nms[i]) <- obj_attr[idxset,]
|
||||||
|
} else {
|
||||||
|
attr(ret, nms[i]) <- obj_attr[idxset]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return(structure(ret, class = "xgb.DMatrix"))
|
return(structure(ret, class = "xgb.DMatrix"))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user