Fix CLI ranking demo. (#6439)

Save model at final round.
This commit is contained in:
Jiaming Yuan 2020-11-28 14:12:06 -05:00 committed by Hyunsu Cho
parent 028ec5f028
commit 8a0db293c5
3 changed files with 40 additions and 13 deletions

View File

@ -5,9 +5,9 @@ objective="rank:pairwise"
# Tree Booster Parameters # Tree Booster Parameters
# step size shrinkage # step size shrinkage
eta = 0.1 eta = 0.1
# minimum loss reduction required to make a further partition # minimum loss reduction required to make a further partition
gamma = 1.0 gamma = 1.0
# minimum sum of instance weight(hessian) needed in a child # minimum sum of instance weight(hessian) needed in a child
min_child_weight = 0.1 min_child_weight = 0.1
# maximum depth of a tree # maximum depth of a tree
@ -17,12 +17,10 @@ max_depth = 6
# the number of boosting rounds to run # the number of boosting rounds to run
num_round = 4 num_round = 4
# 0 means do not save any model except the final round model # 0 means do not save any model except the final round model
save_period = 0 save_period = 0
# The path of training data # The path of training data
data = "mq2008.train" data = "mq2008.train"
# The path of validation data, used to monitor the training process; here [test] sets the name of the validation set # The path of validation data, used to monitor the training process; here [test] sets the name of the validation set
eval[test] = "mq2008.vali" eval[test] = "mq2008.vali"
# The path of test data # The path of test data
test:data = "mq2008.test" test:data = "mq2008.test"

View File

@ -268,7 +268,7 @@ class CLI {
// always save final round // always save final round
if ((param_.save_period == 0 || if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) && param_.num_round % param_.save_period != 0) &&
param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) { rabit::GetRank() == 0) {
std::ostringstream os; std::ostringstream os;
if (param_.model_out == CLIParam::kNull) { if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)

View File

@ -22,6 +22,7 @@ model_in = {model_in}
model_out = {model_out} model_out = {model_out}
test_path = {test_path} test_path = {test_path}
name_pred = {name_pred} name_pred = {name_pred}
model_dir = {model_dir}
num_round = 10 num_round = 10
data = {data_path} data = {data_path}
@ -59,7 +60,8 @@ eval[test] = {data_path}
model_in='NULL', model_in='NULL',
model_out=model_out_cli, model_out=model_out_cli,
test_path='NULL', test_path='NULL',
name_pred='NULL') name_pred='NULL',
model_dir='NULL')
with open(config_path, 'w') as fd: with open(config_path, 'w') as fd:
fd.write(train_conf) fd.write(train_conf)
@ -73,7 +75,8 @@ eval[test] = {data_path}
model_in=model_out_cli, model_in=model_out_cli,
model_out='NULL', model_out='NULL',
test_path=data_path, test_path=data_path,
name_pred=predict_out) name_pred=predict_out,
model_dir='NULL')
with open(config_path, 'w') as fd: with open(config_path, 'w') as fd:
fd.write(predict_conf) fd.write(predict_conf)
@ -145,7 +148,8 @@ eval[test] = {data_path}
model_in='NULL', model_in='NULL',
model_out=model_out_cli, model_out=model_out_cli,
test_path='NULL', test_path='NULL',
name_pred='NULL') name_pred='NULL',
model_dir='NULL')
with open(config_path, 'w') as fd: with open(config_path, 'w') as fd:
fd.write(train_conf) fd.write(train_conf)
@ -154,3 +158,28 @@ eval[test] = {data_path}
model = json.load(fd) model = json.load(fd)
assert model['learner']['gradient_booster']['name'] == 'gbtree' assert model['learner']['gradient_booster']['name'] == 'gbtree'
def test_cli_save_model(self):
    '''Test that the CLI saves a model on the final round.

    ``model_out`` is left as ``NULL`` and ``model_dir`` points at a
    temporary directory, so the CLI has to pick the output file name
    itself.  The template sets ``num_round = 10``; the test expects the
    resulting file to be ``<model_dir>/0010.model``.
    '''
    exe = self.get_exe()
    data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
        root=self.PROJECT_ROOT)
    seed = 1994

    with tempfile.TemporaryDirectory() as tmpdir:
        # File the CLI is expected to create inside ``model_dir``.
        model_out_cli = os.path.join(tmpdir, '0010.model')
        config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')

        train_conf = self.template.format(data_path=data_path,
                                          seed=seed,
                                          task='train',
                                          model_in='NULL',
                                          model_out='NULL',
                                          test_path='NULL',
                                          name_pred='NULL',
                                          model_dir=tmpdir)
        with open(config_path, 'w') as fd:
            fd.write(train_conf)

        # check=True surfaces a failing CLI run as CalledProcessError
        # instead of letting it show up only as a missing file below.
        subprocess.run([exe, config_path], check=True)
        assert os.path.exists(model_out_cli), \
            'CLI did not save the final round model to {}'.format(
                model_out_cli)