Fix #3342 and h2oai/h2o4gpu#625: Save predictor parameters in model file (#3856)

* Fix #3342 and h2oai/h2o4gpu#625: Save predictor parameters in model file

This allows pickled models to retain predictor attributes, such as
'predictor' (whether to use CPU or GPU) and 'n_gpu' (number of GPUs
to use). Related: h2oai/h2o4gpu#625

Closes #3342.

TODO. Write a test.

* Fix lint

* Do not load GPU predictor into CPU-only XGBoost

* Add a test for pickling GPU predictors

* Make sample data big enough to pass multi GPU test

* Update test_gpu_predictor.cu
This commit is contained in:
Philip Hyunsu Cho
2018-11-03 21:45:38 -07:00
committed by GitHub
parent e04ab56b57
commit 91537e7353
7 changed files with 206 additions and 51 deletions

View File

@@ -7,9 +7,10 @@
#include <dmlc/thread_local.h>
#include <rabit/rabit.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include "./c_api_error.h"
@@ -52,6 +53,7 @@ class Booster {
inline void LazyInit() {
if (!configured_) {
LoadSavedParamFromAttr();
learner_->Configure(cfg_);
configured_ = true;
}
@@ -61,6 +63,25 @@ class Booster {
}
}
inline void LoadSavedParamFromAttr() {
// Locate saved parameters from learner attributes
const std::string prefix = "SAVED_PARAM_";
for (const std::string& attr_name : learner_->GetAttrNames()) {
if (attr_name.find(prefix) == 0) {
const std::string saved_param = attr_name.substr(prefix.length());
if (std::none_of(cfg_.begin(), cfg_.end(),
[&](const std::pair<std::string, std::string>& x)
{ return x.first == saved_param; })) {
// If cfg_ contains the parameter already, skip it
// (this is to allow the user to explicitly override its value)
std::string saved_param_value;
CHECK(learner_->GetAttr(attr_name, &saved_param_value));
cfg_.emplace_back(saved_param, saved_param_value);
}
}
}
}
inline void LoadModel(dmlc::Stream* fi) {
learner_->Load(fi);
initialized_ = true;
@@ -1149,5 +1170,14 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_END();
}
/* hidden method; only known to C++ test suite */
const std::map<std::string, std::string>&
QueryBoosterConfigurationArguments(BoosterHandle handle) {
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
bst->LazyInit();
return bst->learner()->GetConfigurationArguments();
}
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();