BUG: incorrect model_file results in segfault
This commit is contained in:
parent
ae43fd7c7a
commit
db0c9e1c2d
@ -743,7 +743,10 @@ class Booster(object):
|
||||
Input file name or memory buffer(see also save_raw)
|
||||
"""
|
||||
if isinstance(fname, str): # assume file name
|
||||
_LIB.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||
if os.path.exists(fname):
|
||||
_LIB.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||
else:
|
||||
raise ValueError("No such file: {0}")
|
||||
else:
|
||||
buf = fname
|
||||
length = ctypes.c_ulong(len(buf))
|
||||
|
||||
@ -1,9 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
|
||||
|
||||
dpath = 'demo/data/'
|
||||
|
||||
|
||||
class TestBasic(unittest.TestCase):
|
||||
|
||||
def test_load_file_invalid(self):
|
||||
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file='incorrect_path')
|
||||
|
||||
|
||||
def test_basic():
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user