From db0c9e1c2d38d0aff7f8e5012dd265f1b9a91b96 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Wed, 16 Sep 2015 21:53:51 +0900 Subject: [PATCH] BUG: incorrect model_file results in segfault --- python-package/xgboost/core.py | 5 ++++- tests/python/test_basic.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index bcb68580e..2718ca704 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -743,7 +743,10 @@ class Booster(object): Input file name or memory buffer(see also save_raw) """ if isinstance(fname, str): # assume file name - _LIB.XGBoosterLoadModel(self.handle, c_str(fname)) + if os.path.exists(fname): + _LIB.XGBoosterLoadModel(self.handle, c_str(fname)) + else: + raise ValueError("No such file: {0}") else: buf = fname length = ctypes.c_ulong(len(buf)) diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index bb6654f51..111d389a0 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -1,9 +1,20 @@ # -*- coding: utf-8 -*- import numpy as np import xgboost as xgb +import unittest + dpath = 'demo/data/' + +class TestBasic(unittest.TestCase): + + def test_load_file_invalid(self): + + self.assertRaises(ValueError, xgb.Booster, + model_file='incorrect_path') + + def test_basic(): dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')