Fix numpy array check logic
This commit is contained in:
@@ -83,6 +83,31 @@ class TestBasic(unittest.TestCase):
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file='incorrect_path')
|
||||
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file=u'不正なパス')
|
||||
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(data)
|
||||
assert dm.num_row() == 5
|
||||
assert dm.num_col() == 5
|
||||
|
||||
data = np.matrix([[1, 2], [3, 4]])
|
||||
dm = xgb.DMatrix(data)
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 2
|
||||
|
||||
# 0d array
|
||||
self.assertRaises(ValueError, xgb.DMatrix, np.array(1))
|
||||
# 1d array
|
||||
self.assertRaises(ValueError, xgb.DMatrix, np.array([1, 2, 3]))
|
||||
# 3d array
|
||||
data = np.random.randn(5, 5, 5)
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data)
|
||||
# object dtype
|
||||
data = np.array([['a', 'b'], ['c', 'd']])
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data)
|
||||
|
||||
def test_plotting(self):
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
# plotting
|
||||
@@ -127,4 +152,3 @@ class TestBasic(unittest.TestCase):
|
||||
assert isinstance(g, Digraph)
|
||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user