BUG: incorrect model_file results in segfault

This commit is contained in:
sinhrks 2015-09-16 21:53:51 +09:00
parent ae43fd7c7a
commit db0c9e1c2d
2 changed files with 15 additions and 1 deletions

View File

@ -743,7 +743,10 @@ class Booster(object):
Input file name or memory buffer(see also save_raw) Input file name or memory buffer(see also save_raw)
""" """
if isinstance(fname, str): # assume file name if isinstance(fname, str): # assume file name
_LIB.XGBoosterLoadModel(self.handle, c_str(fname)) if os.path.exists(fname):
_LIB.XGBoosterLoadModel(self.handle, c_str(fname))
else:
raise ValueError("No such file: {0}")
else: else:
buf = fname buf = fname
length = ctypes.c_ulong(len(buf)) length = ctypes.c_ulong(len(buf))

View File

@ -1,9 +1,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import unittest
dpath = 'demo/data/' dpath = 'demo/data/'
class TestBasic(unittest.TestCase):
def test_load_file_invalid(self):
self.assertRaises(ValueError, xgb.Booster,
model_file='incorrect_path')
def test_basic(): def test_basic():
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')