parent
028ec5f028
commit
8a0db293c5
@ -5,9 +5,9 @@ objective="rank:pairwise"
|
||||
|
||||
# Tree Booster Parameters
|
||||
# step size shrinkage
|
||||
eta = 0.1
|
||||
eta = 0.1
|
||||
# minimum loss reduction required to make a further partition
|
||||
gamma = 1.0
|
||||
gamma = 1.0
|
||||
# minimum sum of instance weight(hessian) needed in a child
|
||||
min_child_weight = 0.1
|
||||
# maximum depth of a tree
|
||||
@ -17,12 +17,10 @@ max_depth = 6
|
||||
# the number of round to do boosting
|
||||
num_round = 4
|
||||
# 0 means do not save any model except the final round model
|
||||
save_period = 0
|
||||
save_period = 0
|
||||
# The path of training data
|
||||
data = "mq2008.train"
|
||||
data = "mq2008.train"
|
||||
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
|
||||
eval[test] = "mq2008.vali"
|
||||
# The path of test data
|
||||
test:data = "mq2008.test"
|
||||
|
||||
|
||||
eval[test] = "mq2008.vali"
|
||||
# The path of test data
|
||||
test:data = "mq2008.test"
|
||||
|
||||
@ -268,7 +268,7 @@ class CLI {
|
||||
// always save final round
|
||||
if ((param_.save_period == 0 ||
|
||||
param_.num_round % param_.save_period != 0) &&
|
||||
param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) {
|
||||
rabit::GetRank() == 0) {
|
||||
std::ostringstream os;
|
||||
if (param_.model_out == CLIParam::kNull) {
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
|
||||
@ -22,6 +22,7 @@ model_in = {model_in}
|
||||
model_out = {model_out}
|
||||
test_path = {test_path}
|
||||
name_pred = {name_pred}
|
||||
model_dir = {model_dir}
|
||||
|
||||
num_round = 10
|
||||
data = {data_path}
|
||||
@ -59,7 +60,8 @@ eval[test] = {data_path}
|
||||
model_in='NULL',
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
name_pred='NULL',
|
||||
model_dir='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
@ -73,7 +75,8 @@ eval[test] = {data_path}
|
||||
model_in=model_out_cli,
|
||||
model_out='NULL',
|
||||
test_path=data_path,
|
||||
name_pred=predict_out)
|
||||
name_pred=predict_out,
|
||||
model_dir='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(predict_conf)
|
||||
|
||||
@ -145,7 +148,8 @@ eval[test] = {data_path}
|
||||
model_in='NULL',
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
name_pred='NULL',
|
||||
model_dir='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
@ -154,3 +158,28 @@ eval[test] = {data_path}
|
||||
model = json.load(fd)
|
||||
|
||||
assert model['learner']['gradient_booster']['name'] == 'gbtree'
|
||||
|
||||
def test_cli_save_model(self):
|
||||
'''Test save on final round'''
|
||||
exe = self.get_exe()
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.PROJECT_ROOT)
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model_out_cli = os.path.join(tmpdir, '0010.model')
|
||||
config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
|
||||
|
||||
train_conf = self.template.format(data_path=data_path,
|
||||
seed=seed,
|
||||
task='train',
|
||||
model_in='NULL',
|
||||
model_out='NULL',
|
||||
test_path='NULL',
|
||||
name_pred='NULL',
|
||||
model_dir=tmpdir)
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
subprocess.run([exe, config_path])
|
||||
assert os.path.exists(model_out_cli)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user