Merge pull request #1153 from khotilov/seed_in_configure
Fixes for repeated Configure calls
This commit is contained in:
commit
2f2ad21de4
@ -32,7 +32,15 @@ class Booster {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void SetParam(const std::string& name, const std::string& val) {
|
inline void SetParam(const std::string& name, const std::string& val) {
|
||||||
cfg_.push_back(std::make_pair(name, val));
|
auto it = std::find_if(cfg_.begin(), cfg_.end(),
|
||||||
|
[&name](decltype(*cfg_.begin()) &x) {
|
||||||
|
return x.first == name;
|
||||||
|
});
|
||||||
|
if (it == cfg_.end()) {
|
||||||
|
cfg_.push_back(std::make_pair(name, val));
|
||||||
|
} else {
|
||||||
|
(*it).second = val;
|
||||||
|
}
|
||||||
if (configured_) {
|
if (configured_) {
|
||||||
learner_->Configure(cfg_);
|
learner_->Configure(cfg_);
|
||||||
}
|
}
|
||||||
@ -277,7 +285,7 @@ int XGDMatrixCreateFromCSC(const bst_ulong* col_ptr,
|
|||||||
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
|
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
|
||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mat.info.num_row = mat.row_ptr_.size() - 1;
|
mat.info.num_row = mat.row_ptr_.size() - 1;
|
||||||
mat.info.num_col = static_cast<uint64_t>(ncol);
|
mat.info.num_col = static_cast<uint64_t>(ncol);
|
||||||
mat.info.num_nonzero = nelem;
|
mat.info.num_nonzero = nelem;
|
||||||
|
|||||||
@ -189,10 +189,10 @@ class LearnerImpl : public Learner {
|
|||||||
mparam.InitAllowUnknown(args);
|
mparam.InitAllowUnknown(args);
|
||||||
name_obj_ = cfg_["objective"];
|
name_obj_ = cfg_["objective"];
|
||||||
name_gbm_ = cfg_["booster"];
|
name_gbm_ = cfg_["booster"];
|
||||||
|
// set seed only before the model is initialized
|
||||||
|
common::GlobalRandom().seed(tparam.seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
common::GlobalRandom().seed(tparam.seed);
|
|
||||||
|
|
||||||
// set number of features correctly.
|
// set number of features correctly.
|
||||||
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
||||||
cfg_["num_class"] = common::ToString(mparam.num_class);
|
cfg_["num_class"] = common::ToString(mparam.num_class);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user