Don't set seed on CLI interface. (#5563)
This commit is contained in:
parent
ccd30e4491
commit
b809f5d8b8
@ -17,6 +17,7 @@ def _train_internal(params, dtrain,
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
evals = list(evals)
|
||||
params = params.copy()
|
||||
if isinstance(params, dict) \
|
||||
and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
|
||||
@ -340,7 +340,6 @@ int CLIRunTask(int argc, char *argv[]) {
|
||||
|
||||
common::ConfigParser cp(argv[1]);
|
||||
auto cfg = cp.Parse();
|
||||
cfg.emplace_back("seed", "0");
|
||||
|
||||
for (int i = 2; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
|
||||
@ -13,7 +13,7 @@ booster = gbtree
|
||||
objective = reg:squarederror
|
||||
eta = 1.0
|
||||
gamma = 1.0
|
||||
seed = 0
|
||||
seed = {seed}
|
||||
min_child_weight = 0
|
||||
max_depth = 3
|
||||
task = {task}
|
||||
@ -41,14 +41,18 @@ eval[test] = {data_path}
|
||||
exe = os.path.join(project_root, exe)
|
||||
assert os.path.exists(exe)
|
||||
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model_out = os.path.join(tmpdir, 'test_load_cli_model')
|
||||
model_out_cli = os.path.join(tmpdir, 'test_load_cli_model-cli.bin')
|
||||
model_out_py = os.path.join(tmpdir, 'test_cli_model-py.bin')
|
||||
config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
|
||||
|
||||
train_conf = self.template.format(data_path=data_path,
|
||||
seed=seed,
|
||||
task='train',
|
||||
model_in='NULL',
|
||||
model_out=model_out,
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
@ -59,8 +63,9 @@ eval[test] = {data_path}
|
||||
predict_out = os.path.join(tmpdir,
|
||||
'test_load_cli_model-prediction')
|
||||
predict_conf = self.template.format(task='pred',
|
||||
seed=seed,
|
||||
data_path=data_path,
|
||||
model_in=model_out,
|
||||
model_in=model_out_cli,
|
||||
model_out='NULL',
|
||||
test_path=data_path,
|
||||
name_pred=predict_out)
|
||||
@ -76,16 +81,24 @@ eval[test] = {data_path}
|
||||
'objective': 'reg:squarederror',
|
||||
'eta': 1.0,
|
||||
'gamma': 1.0,
|
||||
'seed': 0,
|
||||
'seed': seed,
|
||||
'min_child_weight': 0,
|
||||
'max_depth': 3
|
||||
}
|
||||
data = xgboost.DMatrix(data_path)
|
||||
booster = xgboost.train(parameters, data, num_boost_round=10)
|
||||
booster.save_model(model_out_py)
|
||||
py_predt = booster.predict(data)
|
||||
|
||||
numpy.testing.assert_allclose(cli_predt, py_predt)
|
||||
|
||||
cli_model = xgboost.Booster(model_file=model_out)
|
||||
cli_model = xgboost.Booster(model_file=model_out_cli)
|
||||
cli_predt = cli_model.predict(data)
|
||||
numpy.testing.assert_allclose(cli_predt, py_predt)
|
||||
|
||||
with open(model_out_cli, 'rb') as fd:
|
||||
cli_model_bin = fd.read()
|
||||
with open(model_out_py, 'rb') as fd:
|
||||
py_model_bin = fd.read()
|
||||
|
||||
assert hash(cli_model_bin) == hash(py_model_bin)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user