Bug mixing DMatrix's with and without feature names

This commit is contained in:
sinhrks
2016-04-29 13:51:34 +09:00
parent ff4dda2102
commit 6bab164d80
4 changed files with 25 additions and 14 deletions

View File

@@ -91,7 +91,7 @@ class TestBasic(unittest.TestCase):
# reset
dm.feature_names = None
assert dm.feature_names is None
self.assertEqual(dm.feature_names, ['f0', 'f1', 'f2', 'f3', 'f4'])
assert dm.feature_types is None
def test_feature_names(self):

View File

@@ -99,3 +99,20 @@ class TestModels(unittest.TestCase):
num_round = 2
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed=0, show_stdv=False)
def test_feature_names_validation(self):
X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,))
dm1 = xgb.DMatrix(X, y)
dm2 = xgb.DMatrix(X, y, feature_names=("a", "b", "c"))
bst = xgb.train([], dm1)
bst.predict(dm1) # success
self.assertRaises(ValueError, bst.predict, dm2)
bst.predict(dm1) # success
bst = xgb.train([], dm2)
bst.predict(dm2) # success
self.assertRaises(ValueError, bst.predict, dm1)
bst.predict(dm2) # success