Refactor configuration [Part II]. (#4577)

* Refactor configuration [Part II].

* General changes:
** Remove `Init` methods to avoid ambiguity.
** Remove `Configure(std::map<>)` to avoid redundant copying and prepare for
   parameter validation. (`std::vector` is returned from `InitAllowUnknown`).
** Add name to tree updaters for easier debugging.

* Learner changes:
** Make `LearnerImpl` the only source of configuration.

    All configurations are stored and carried out by `LearnerImpl::Configure()`.

** Remove booster in C API.

    It was originally kept for "compatibility reasons", but the reason was
    never stated, so here we simply remove it.

** Add a `metric_names_` field in `LearnerImpl`.
** Remove `LazyInit`.  Configuration will always be lazy.
** Run `Configure` before every iteration.

* Predictor changes:
** Allocate both cpu and gpu predictor.
** Remove cpu_predictor from gpu_predictor.

    `GBTree` is now used to dispatch the predictor.

** Remove some GPU Predictor tests.

* IO

No IO changes.  Binary model format stability is tested by comparing the
hash values of saved models between two commits.
This commit is contained in:
Jiaming Yuan
2019-07-20 08:34:56 -04:00
committed by GitHub
parent ad1192e8a3
commit f0064c07ab
69 changed files with 669 additions and 761 deletions

View File

@@ -24,82 +24,6 @@
namespace xgboost {
// booster wrapper for backward compatible reason.
class Booster {
public:
explicit Booster(const std::vector<std::shared_ptr<DMatrix> >& cache_mats)
: configured_(false),
initialized_(false),
learner_(Learner::Create(cache_mats)) {}
inline Learner* learner() { // NOLINT
return learner_.get();
}
inline void SetParam(const std::string& name, const std::string& val) {
auto it = std::find_if(cfg_.begin(), cfg_.end(),
[&name, &val](decltype(*cfg_.begin()) &x) {
if (name == "eval_metric") {
return x.first == name && x.second == val;
}
return x.first == name;
});
if (it == cfg_.end()) {
cfg_.emplace_back(name, val);
} else {
(*it).second = val;
}
if (configured_) {
learner_->Configure(cfg_);
}
}
inline void LazyInit() {
if (!configured_) {
LoadSavedParamFromAttr();
learner_->Configure(cfg_);
configured_ = true;
}
if (!initialized_) {
learner_->InitModel();
initialized_ = true;
}
}
inline void LoadSavedParamFromAttr() {
// Locate saved parameters from learner attributes
const std::string prefix = "SAVED_PARAM_";
for (const std::string& attr_name : learner_->GetAttrNames()) {
if (attr_name.find(prefix) == 0) {
const std::string saved_param = attr_name.substr(prefix.length());
if (std::none_of(cfg_.begin(), cfg_.end(),
[&](const std::pair<std::string, std::string>& x)
{ return x.first == saved_param; })) {
// If cfg_ contains the parameter already, skip it
// (this is to allow the user to explicitly override its value)
std::string saved_param_value;
CHECK(learner_->GetAttr(attr_name, &saved_param_value));
cfg_.emplace_back(saved_param, saved_param_value);
}
}
}
}
inline void LoadModel(dmlc::Stream* fi) {
learner_->Load(fi);
initialized_ = true;
}
bool IsInitialized() const { return initialized_; }
void Intialize() { initialized_ = true; }
private:
bool configured_;
bool initialized_;
std::unique_ptr<Learner> learner_;
std::vector<std::pair<std::string, std::string> > cfg_;
};
// declare the data callback.
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
void *handle, XGBoostBatchCSR batch);
@@ -861,14 +785,14 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
for (xgboost::bst_ulong i = 0; i < len; ++i) {
mats.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
}
*out = new Booster(mats);
*out = Learner::Create(mats);
API_END();
}
XGB_DLL int XGBoosterFree(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
delete static_cast<Booster*>(handle);
delete static_cast<Learner*>(handle);
API_END();
}
@@ -877,7 +801,7 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
const char *value) {
API_BEGIN();
CHECK_HANDLE();
static_cast<Booster*>(handle)->SetParam(name, value);
static_cast<Learner*>(handle)->SetParam(name, value);
API_END();
}
@@ -886,12 +810,11 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
DMatrixHandle dtrain) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
auto* bst = static_cast<Learner*>(handle);
auto *dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
bst->LazyInit();
bst->learner()->UpdateOneIter(iter, dtr->get());
bst->UpdateOneIter(iter, dtr->get());
API_END();
}
@@ -903,7 +826,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
HostDeviceVector<GradientPair> tmp_gpair;
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
auto* bst = static_cast<Learner*>(handle);
auto* dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
tmp_gpair.Resize(len);
@@ -912,8 +835,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
}
bst->LazyInit();
bst->learner()->BoostOneIter(0, dtr->get(), &tmp_gpair);
bst->BoostOneIter(0, dtr->get(), &tmp_gpair);
API_END();
}
@@ -926,7 +848,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
auto* bst = static_cast<Learner*>(handle);
std::vector<DMatrix*> data_sets;
std::vector<std::string> data_names;
@@ -935,8 +857,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
data_names.emplace_back(evnames[i]);
}
bst->LazyInit();
eval_str = bst->learner()->EvalOneIter(iter, data_sets, data_names);
eval_str = bst->EvalOneIter(iter, data_sets, data_names);
*out_str = eval_str.c_str();
API_END();
}
@@ -951,10 +872,9 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN();
CHECK_HANDLE();
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
auto *bst = static_cast<Learner*>(handle);
HostDeviceVector<bst_float> tmp_preds;
bst->learner()->Predict(
bst->Predict(
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
(option_mask & 1) != 0,
&tmp_preds, ntree_limit,
@@ -972,7 +892,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
API_BEGIN();
CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
static_cast<Booster*>(handle)->LoadModel(fi.get());
static_cast<Learner*>(handle)->Load(fi.get());
API_END();
}
@@ -980,9 +900,8 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
API_BEGIN();
CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
bst->learner()->Save(fo.get());
auto *bst = static_cast<Learner*>(handle);
bst->Save(fo.get());
API_END();
}
@@ -992,7 +911,7 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
static_cast<Booster*>(handle)->LoadModel(&fs);
static_cast<Learner*>(handle)->Load(&fs);
API_END();
}
@@ -1005,9 +924,8 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
bst->learner()->Save(&fo);
auto *bst = static_cast<Learner*>(handle);
bst->Save(&fo);
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
@@ -1022,9 +940,8 @@ inline void XGBoostDumpModelImpl(
const char*** out_models) {
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
str_vecs = bst->learner()->DumpModel(fmap, with_stats != 0, format);
auto *bst = static_cast<Learner*>(handle);
str_vecs = bst->DumpModel(fmap, with_stats != 0, format);
charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) {
charp_vecs[i] = str_vecs[i].c_str();
@@ -1093,11 +1010,11 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
const char* key,
const char** out,
int* success) {
auto* bst = static_cast<Booster*>(handle);
auto* bst = static_cast<Learner*>(handle);
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
CHECK_HANDLE();
if (bst->learner()->GetAttr(key, &ret_str)) {
if (bst->GetAttr(key, &ret_str)) {
*out = ret_str.c_str();
*success = 1;
} else {
@@ -1108,28 +1025,28 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
}
XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
const char* key,
const char* value) {
auto* bst = static_cast<Booster*>(handle);
const char* key,
const char* value) {
auto* bst = static_cast<Learner*>(handle);
API_BEGIN();
CHECK_HANDLE();
if (value == nullptr) {
bst->learner()->DelAttr(key);
bst->DelAttr(key);
} else {
bst->learner()->SetAttr(key, value);
bst->SetAttr(key, value);
}
API_END();
}
XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
xgboost::bst_ulong* out_len,
const char*** out) {
xgboost::bst_ulong* out_len,
const char*** out) {
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
auto *bst = static_cast<Booster*>(handle);
auto *bst = static_cast<Learner*>(handle);
API_BEGIN();
CHECK_HANDLE();
str_vecs = bst->learner()->GetAttrNames();
str_vecs = bst->GetAttrNames();
charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) {
charp_vecs[i] = str_vecs[i].c_str();
@@ -1140,13 +1057,13 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
}
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner());
auto* bst = static_cast<Learner*>(handle);
*version = rabit::LoadCheckPoint(bst);
if (*version != 0) {
bst->Intialize();
bst->Configure();
}
API_END();
}
@@ -1154,23 +1071,14 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst->learner());
auto* bst = static_cast<Learner*>(handle);
if (bst->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst);
} else {
rabit::CheckPoint(bst->learner());
rabit::CheckPoint(bst);
}
API_END();
}
/* hidden method; only known to C++ test suite */
const std::map<std::string, std::string>&
QueryBoosterConfigurationArguments(BoosterHandle handle) {
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
bst->LazyInit();
return bst->learner()->GetConfigurationArguments();
}
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();