@@ -22,6 +22,7 @@ model_in = {model_in}
|
||||
model_out = {model_out}
|
||||
test_path = {test_path}
|
||||
name_pred = {name_pred}
|
||||
model_dir = {model_dir}
|
||||
|
||||
num_round = 10
|
||||
data = {data_path}
|
||||
@@ -59,7 +60,8 @@ eval[test] = {data_path}
|
||||
model_in='NULL',
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
name_pred='NULL',
|
||||
model_dir='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
@@ -73,7 +75,8 @@ eval[test] = {data_path}
|
||||
model_in=model_out_cli,
|
||||
model_out='NULL',
|
||||
test_path=data_path,
|
||||
name_pred=predict_out)
|
||||
name_pred=predict_out,
|
||||
model_dir='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(predict_conf)
|
||||
|
||||
@@ -145,7 +148,8 @@ eval[test] = {data_path}
|
||||
model_in='NULL',
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
name_pred='NULL',
|
||||
model_dir='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
@@ -154,3 +158,28 @@ eval[test] = {data_path}
|
||||
model = json.load(fd)
|
||||
|
||||
assert model['learner']['gradient_booster']['name'] == 'gbtree'
|
||||
|
||||
def test_cli_save_model(self):
|
||||
'''Test save on final round'''
|
||||
exe = self.get_exe()
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.PROJECT_ROOT)
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model_out_cli = os.path.join(tmpdir, '0010.model')
|
||||
config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
|
||||
|
||||
train_conf = self.template.format(data_path=data_path,
|
||||
seed=seed,
|
||||
task='train',
|
||||
model_in='NULL',
|
||||
model_out='NULL',
|
||||
test_path='NULL',
|
||||
name_pred='NULL',
|
||||
model_dir=tmpdir)
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
subprocess.run([exe, config_path])
|
||||
assert os.path.exists(model_out_cli)
|
||||
|
||||
Reference in New Issue
Block a user