Fix slice and get info. (#5552)
This commit is contained in:
@@ -188,9 +188,10 @@ getinfo <- function(object, ...) UseMethod("getinfo")
|
||||
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
if (typeof(name) != "character" ||
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow')) {
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound')) {
|
||||
stop("getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow'")
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
|
||||
}
|
||||
if (name != "nrow"){
|
||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||
|
||||
@@ -50,6 +50,12 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
|
||||
labels <- getinfo(dtest, 'label')
|
||||
expect_equal(test_label, getinfo(dtest, 'label'))
|
||||
|
||||
expect_true(setinfo(dtest, 'label_lower_bound', test_label))
|
||||
expect_equal(test_label, getinfo(dtest, 'label_lower_bound'))
|
||||
|
||||
expect_true(setinfo(dtest, 'label_upper_bound', test_label))
|
||||
expect_equal(test_label, getinfo(dtest, 'label_upper_bound'))
|
||||
|
||||
expect_true(length(getinfo(dtest, 'weight')) == 0)
|
||||
expect_true(length(getinfo(dtest, 'base_margin')) == 0)
|
||||
|
||||
@@ -59,7 +65,7 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
|
||||
expect_error(setinfo(dtest, 'group', test_label))
|
||||
|
||||
# providing character values will give a warning
|
||||
expect_warning( setinfo(dtest, 'weight', rep('a', nrow(test_data))) )
|
||||
expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data))))
|
||||
|
||||
# any other label should error
|
||||
expect_error(setinfo(dtest, 'asdf', test_label))
|
||||
|
||||
Reference in New Issue
Block a user