Fix numpy array check logic

This commit is contained in:
sinhrks
2015-09-16 20:47:37 +09:00
parent cf2ec238a4
commit f7d434aec2
2 changed files with 32 additions and 4 deletions

View File

@@ -220,7 +220,7 @@ class DMatrix(object):
self._init_from_csr(data)
elif isinstance(data, scipy.sparse.csc_matrix):
self._init_from_csc(data)
elif isinstance(data, np.ndarray) and len(data.shape) == 2:
elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, missing)
else:
try:
@@ -278,6 +278,8 @@ class DMatrix(object):
"""
Initialize data from a 2-D numpy matrix.
"""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
data = np.array(mat.reshape(mat.size), dtype=np.float32)
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
@@ -792,11 +794,11 @@ class Booster(object):
fname : string or a memory buffer
Input file name or memory buffer(see also save_raw)
"""
if isinstance(fname, str): # assume file name
if isinstance(fname, STRING_TYPES): # assume file name
if os.path.exists(fname):
_LIB.XGBoosterLoadModel(self.handle, c_str(fname))
else:
raise ValueError("No such file: {0}")
raise ValueError("No such file: {0}".format(fname))
else:
buf = fname
length = ctypes.c_ulong(len(buf))
@@ -851,6 +853,8 @@ class Booster(object):
ctypes.byref(length),
ctypes.byref(sarr)))
else:
if fmap != '' and not os.path.exists(fmap):
raise ValueError("No such file: {0}".format(fmap))
_check_call(_LIB.XGBoosterDumpModel(self.handle,
c_str(fmap),
int(with_stats),