Refactor configuration [Part II]. (#4577)
* Refactor configuration [Part II].
* General changes:
** Remove `Init` methods to avoid ambiguity.
** Remove `Configure(std::map<>)` to avoid redundant copying and prepare for
parameter validation. (`std::vector` is returned from `InitAllowUnknown`).
** Add name to tree updaters for easier debugging.
* Learner changes:
** Make `LearnerImpl` the only source of configuration.
All configurations are stored and carried out by `LearnerImpl::Configure()`.
** Remove booster in C API.
Originally kept for "compatibility reason", but did not state why. So here
we just remove it.
** Add a `metric_names_` field in `LearnerImpl`.
** Remove `LazyInit`. Configuration will always be lazy.
** Run `Configure` before every iteration.
* Predictor changes:
** Allocate both cpu and gpu predictor.
** Remove cpu_predictor from gpu_predictor.
`GBTree` is now used to dispatch the predictor.
** Remove some GPU Predictor tests.
* IO
No IO changes. The binary model format stability is tested by comparing
hashing value of save models between two commits
This commit is contained in:
@@ -24,82 +24,6 @@
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
// booster wrapper for backward compatible reason.
|
||||
class Booster {
|
||||
public:
|
||||
explicit Booster(const std::vector<std::shared_ptr<DMatrix> >& cache_mats)
|
||||
: configured_(false),
|
||||
initialized_(false),
|
||||
learner_(Learner::Create(cache_mats)) {}
|
||||
|
||||
inline Learner* learner() { // NOLINT
|
||||
return learner_.get();
|
||||
}
|
||||
|
||||
inline void SetParam(const std::string& name, const std::string& val) {
|
||||
auto it = std::find_if(cfg_.begin(), cfg_.end(),
|
||||
[&name, &val](decltype(*cfg_.begin()) &x) {
|
||||
if (name == "eval_metric") {
|
||||
return x.first == name && x.second == val;
|
||||
}
|
||||
return x.first == name;
|
||||
});
|
||||
if (it == cfg_.end()) {
|
||||
cfg_.emplace_back(name, val);
|
||||
} else {
|
||||
(*it).second = val;
|
||||
}
|
||||
if (configured_) {
|
||||
learner_->Configure(cfg_);
|
||||
}
|
||||
}
|
||||
|
||||
inline void LazyInit() {
|
||||
if (!configured_) {
|
||||
LoadSavedParamFromAttr();
|
||||
learner_->Configure(cfg_);
|
||||
configured_ = true;
|
||||
}
|
||||
if (!initialized_) {
|
||||
learner_->InitModel();
|
||||
initialized_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadSavedParamFromAttr() {
|
||||
// Locate saved parameters from learner attributes
|
||||
const std::string prefix = "SAVED_PARAM_";
|
||||
for (const std::string& attr_name : learner_->GetAttrNames()) {
|
||||
if (attr_name.find(prefix) == 0) {
|
||||
const std::string saved_param = attr_name.substr(prefix.length());
|
||||
if (std::none_of(cfg_.begin(), cfg_.end(),
|
||||
[&](const std::pair<std::string, std::string>& x)
|
||||
{ return x.first == saved_param; })) {
|
||||
// If cfg_ contains the parameter already, skip it
|
||||
// (this is to allow the user to explicitly override its value)
|
||||
std::string saved_param_value;
|
||||
CHECK(learner_->GetAttr(attr_name, &saved_param_value));
|
||||
cfg_.emplace_back(saved_param, saved_param_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadModel(dmlc::Stream* fi) {
|
||||
learner_->Load(fi);
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
bool IsInitialized() const { return initialized_; }
|
||||
void Intialize() { initialized_ = true; }
|
||||
|
||||
private:
|
||||
bool configured_;
|
||||
bool initialized_;
|
||||
std::unique_ptr<Learner> learner_;
|
||||
std::vector<std::pair<std::string, std::string> > cfg_;
|
||||
};
|
||||
|
||||
// declare the data callback.
|
||||
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
|
||||
void *handle, XGBoostBatchCSR batch);
|
||||
@@ -861,14 +785,14 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
mats.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
||||
}
|
||||
*out = new Booster(mats);
|
||||
*out = Learner::Create(mats);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterFree(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
delete static_cast<Booster*>(handle);
|
||||
delete static_cast<Learner*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -877,7 +801,7 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
||||
const char *value) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<Booster*>(handle)->SetParam(name, value);
|
||||
static_cast<Learner*>(handle)->SetParam(name, value);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -886,12 +810,11 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
||||
DMatrixHandle dtrain) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
auto *dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
|
||||
bst->LazyInit();
|
||||
bst->learner()->UpdateOneIter(iter, dtr->get());
|
||||
bst->UpdateOneIter(iter, dtr->get());
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -903,7 +826,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
HostDeviceVector<GradientPair> tmp_gpair;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
auto* dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
tmp_gpair.Resize(len);
|
||||
@@ -912,8 +835,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
|
||||
}
|
||||
|
||||
bst->LazyInit();
|
||||
bst->learner()->BoostOneIter(0, dtr->get(), &tmp_gpair);
|
||||
bst->BoostOneIter(0, dtr->get(), &tmp_gpair);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -926,7 +848,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
std::vector<DMatrix*> data_sets;
|
||||
std::vector<std::string> data_names;
|
||||
|
||||
@@ -935,8 +857,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
data_names.emplace_back(evnames[i]);
|
||||
}
|
||||
|
||||
bst->LazyInit();
|
||||
eval_str = bst->learner()->EvalOneIter(iter, data_sets, data_names);
|
||||
eval_str = bst->EvalOneIter(iter, data_sets, data_names);
|
||||
*out_str = eval_str.c_str();
|
||||
API_END();
|
||||
}
|
||||
@@ -951,10 +872,9 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
HostDeviceVector<bst_float> tmp_preds;
|
||||
bst->learner()->Predict(
|
||||
bst->Predict(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
|
||||
(option_mask & 1) != 0,
|
||||
&tmp_preds, ntree_limit,
|
||||
@@ -972,7 +892,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Booster*>(handle)->LoadModel(fi.get());
|
||||
static_cast<Learner*>(handle)->Load(fi.get());
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -980,9 +900,8 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Save(fo.get());
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Save(fo.get());
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -992,7 +911,7 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
static_cast<Booster*>(handle)->LoadModel(&fs);
|
||||
static_cast<Learner*>(handle)->Load(&fs);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -1005,9 +924,8 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Save(&fo);
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Save(&fo);
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||
API_END();
|
||||
@@ -1022,9 +940,8 @@ inline void XGBoostDumpModelImpl(
|
||||
const char*** out_models) {
|
||||
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
str_vecs = bst->learner()->DumpModel(fmap, with_stats != 0, format);
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
str_vecs = bst->DumpModel(fmap, with_stats != 0, format);
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
charp_vecs[i] = str_vecs[i].c_str();
|
||||
@@ -1093,11 +1010,11 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
||||
const char* key,
|
||||
const char** out,
|
||||
int* success) {
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
if (bst->learner()->GetAttr(key, &ret_str)) {
|
||||
if (bst->GetAttr(key, &ret_str)) {
|
||||
*out = ret_str.c_str();
|
||||
*success = 1;
|
||||
} else {
|
||||
@@ -1108,28 +1025,28 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
||||
const char* key,
|
||||
const char* value) {
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
const char* key,
|
||||
const char* value) {
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
if (value == nullptr) {
|
||||
bst->learner()->DelAttr(key);
|
||||
bst->DelAttr(key);
|
||||
} else {
|
||||
bst->learner()->SetAttr(key, value);
|
||||
bst->SetAttr(key, value);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
xgboost::bst_ulong* out_len,
|
||||
const char*** out) {
|
||||
xgboost::bst_ulong* out_len,
|
||||
const char*** out) {
|
||||
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
str_vecs = bst->learner()->GetAttrNames();
|
||||
str_vecs = bst->GetAttrNames();
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
charp_vecs[i] = str_vecs[i].c_str();
|
||||
@@ -1140,13 +1057,13 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst->learner());
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst);
|
||||
if (*version != 0) {
|
||||
bst->Intialize();
|
||||
bst->Configure();
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
@@ -1154,23 +1071,14 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
if (bst->learner()->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(bst->learner());
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
if (bst->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(bst);
|
||||
} else {
|
||||
rabit::CheckPoint(bst->learner());
|
||||
rabit::CheckPoint(bst);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
/* hidden method; only known to C++ test suite */
|
||||
const std::map<std::string, std::string>&
|
||||
QueryBoosterConfigurationArguments(BoosterHandle handle) {
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
return bst->learner()->GetConfigurationArguments();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user