Fix #3342 and h2oai/h2o4gpu#625: Save predictor parameters in model file (#3856)
* Fix #3342 and h2oai/h2o4gpu#625: Save predictor parameters in model file This allows pickled models to retain predictor attributes, such as 'predictor' (whether to use CPU or GPU) and 'n_gpu' (number of GPUs to use). Related: h2oai/h2o4gpu#625 Closes #3342. TODO. Write a test. * Fix lint * Do not load GPU predictor into CPU-only XGBoost * Add a test for pickling GPU predictors * Make sample data big enough to pass multi GPU test * Update test_gpu_predictor.cu
This commit is contained in:
committed by
GitHub
parent
e04ab56b57
commit
91537e7353
@@ -7,9 +7,10 @@
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
#include "./c_api_error.h"
|
||||
@@ -52,6 +53,7 @@ class Booster {
|
||||
|
||||
inline void LazyInit() {
|
||||
if (!configured_) {
|
||||
LoadSavedParamFromAttr();
|
||||
learner_->Configure(cfg_);
|
||||
configured_ = true;
|
||||
}
|
||||
@@ -61,6 +63,25 @@ class Booster {
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadSavedParamFromAttr() {
|
||||
// Locate saved parameters from learner attributes
|
||||
const std::string prefix = "SAVED_PARAM_";
|
||||
for (const std::string& attr_name : learner_->GetAttrNames()) {
|
||||
if (attr_name.find(prefix) == 0) {
|
||||
const std::string saved_param = attr_name.substr(prefix.length());
|
||||
if (std::none_of(cfg_.begin(), cfg_.end(),
|
||||
[&](const std::pair<std::string, std::string>& x)
|
||||
{ return x.first == saved_param; })) {
|
||||
// If cfg_ contains the parameter already, skip it
|
||||
// (this is to allow the user to explicitly override its value)
|
||||
std::string saved_param_value;
|
||||
CHECK(learner_->GetAttr(attr_name, &saved_param_value));
|
||||
cfg_.emplace_back(saved_param, saved_param_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadModel(dmlc::Stream* fi) {
|
||||
learner_->Load(fi);
|
||||
initialized_ = true;
|
||||
@@ -1149,5 +1170,14 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
/* hidden method; only known to C++ test suite */
|
||||
const std::map<std::string, std::string>&
|
||||
QueryBoosterConfigurationArguments(BoosterHandle handle) {
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
return bst->learner()->GetConfigurationArguments();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user