Fix numpy array check logic
This commit is contained in:
parent
cf2ec238a4
commit
f7d434aec2
@ -220,7 +220,7 @@ class DMatrix(object):
|
||||
self._init_from_csr(data)
|
||||
elif isinstance(data, scipy.sparse.csc_matrix):
|
||||
self._init_from_csc(data)
|
||||
elif isinstance(data, np.ndarray) and len(data.shape) == 2:
|
||||
elif isinstance(data, np.ndarray):
|
||||
self._init_from_npy2d(data, missing)
|
||||
else:
|
||||
try:
|
||||
@ -278,6 +278,8 @@ class DMatrix(object):
|
||||
"""
|
||||
Initialize data from a 2-D numpy matrix.
|
||||
"""
|
||||
if len(mat.shape) != 2:
|
||||
raise ValueError('Input numpy.ndarray must be 2 dimensional')
|
||||
data = np.array(mat.reshape(mat.size), dtype=np.float32)
|
||||
self.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
@ -792,11 +794,11 @@ class Booster(object):
|
||||
fname : string or a memory buffer
|
||||
Input file name or memory buffer(see also save_raw)
|
||||
"""
|
||||
if isinstance(fname, str): # assume file name
|
||||
if isinstance(fname, STRING_TYPES): # assume file name
|
||||
if os.path.exists(fname):
|
||||
_LIB.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||
else:
|
||||
raise ValueError("No such file: {0}")
|
||||
raise ValueError("No such file: {0}".format(fname))
|
||||
else:
|
||||
buf = fname
|
||||
length = ctypes.c_ulong(len(buf))
|
||||
@ -851,6 +853,8 @@ class Booster(object):
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
else:
|
||||
if fmap != '' and not os.path.exists(fmap):
|
||||
raise ValueError("No such file: {0}".format(fmap))
|
||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
|
||||
@ -83,6 +83,31 @@ class TestBasic(unittest.TestCase):
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file='incorrect_path')
|
||||
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file=u'不正なパス')
|
||||
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(data)
|
||||
assert dm.num_row() == 5
|
||||
assert dm.num_col() == 5
|
||||
|
||||
data = np.matrix([[1, 2], [3, 4]])
|
||||
dm = xgb.DMatrix(data)
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 2
|
||||
|
||||
# 0d array
|
||||
self.assertRaises(ValueError, xgb.DMatrix, np.array(1))
|
||||
# 1d array
|
||||
self.assertRaises(ValueError, xgb.DMatrix, np.array([1, 2, 3]))
|
||||
# 3d array
|
||||
data = np.random.randn(5, 5, 5)
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data)
|
||||
# object dtype
|
||||
data = np.array([['a', 'b'], ['c', 'd']])
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data)
|
||||
|
||||
def test_plotting(self):
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
# plotting
|
||||
@ -127,4 +152,3 @@ class TestBasic(unittest.TestCase):
|
||||
assert isinstance(g, Digraph)
|
||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user