Fix #3342 and h2oai/h2o4gpu#625: Save predictor parameters in model file (#3856)
* Fix #3342 and h2oai/h2o4gpu#625: Save predictor parameters in model file This allows pickled models to retain predictor attributes, such as 'predictor' (whether to use CPU or GPU) and 'n_gpu' (number of GPUs to use). Related: h2oai/h2o4gpu#625 Closes #3342. TODO. Write a test. * Fix lint * Do not load GPU predictor into CPU-only XGBoost * Add a test for pickling GPU predictors * Make sample data big enough to pass multi GPU test * Update test_gpu_predictor.cu
This commit is contained in:
parent
e04ab56b57
commit
91537e7353
2
Makefile
2
Makefile
@ -250,8 +250,6 @@ Rpack: clean_all
|
|||||||
cp -r src xgboost/src/src
|
cp -r src xgboost/src/src
|
||||||
cp -r include xgboost/src/include
|
cp -r include xgboost/src/include
|
||||||
cp -r amalgamation xgboost/src/amalgamation
|
cp -r amalgamation xgboost/src/amalgamation
|
||||||
mkdir -p xgboost/src/tests/cpp
|
|
||||||
cp tests/cpp/test_learner.h xgboost/src/tests/cpp
|
|
||||||
mkdir -p xgboost/src/rabit
|
mkdir -p xgboost/src/rabit
|
||||||
cp -r rabit/include xgboost/src/rabit/include
|
cp -r rabit/include xgboost/src/rabit/include
|
||||||
cp -r rabit/src xgboost/src/rabit/src
|
cp -r rabit/src xgboost/src/rabit/src
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
|
|
||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
@ -178,6 +179,12 @@ class Learner : public rabit::Serializable {
|
|||||||
*/
|
*/
|
||||||
static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
|
static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get configuration arguments currently stored by the learner
|
||||||
|
* \return Key-value pairs representing configuration arguments
|
||||||
|
*/
|
||||||
|
virtual const std::map<std::string, std::string>& GetConfigurationArguments() const = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/*! \brief internal base score of the model */
|
/*! \brief internal base score of the model */
|
||||||
bst_float base_score_;
|
bst_float base_score_;
|
||||||
|
|||||||
@ -7,9 +7,10 @@
|
|||||||
#include <dmlc/thread_local.h>
|
#include <dmlc/thread_local.h>
|
||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "./c_api_error.h"
|
#include "./c_api_error.h"
|
||||||
@ -52,6 +53,7 @@ class Booster {
|
|||||||
|
|
||||||
inline void LazyInit() {
|
inline void LazyInit() {
|
||||||
if (!configured_) {
|
if (!configured_) {
|
||||||
|
LoadSavedParamFromAttr();
|
||||||
learner_->Configure(cfg_);
|
learner_->Configure(cfg_);
|
||||||
configured_ = true;
|
configured_ = true;
|
||||||
}
|
}
|
||||||
@ -61,6 +63,25 @@ class Booster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void LoadSavedParamFromAttr() {
|
||||||
|
// Locate saved parameters from learner attributes
|
||||||
|
const std::string prefix = "SAVED_PARAM_";
|
||||||
|
for (const std::string& attr_name : learner_->GetAttrNames()) {
|
||||||
|
if (attr_name.find(prefix) == 0) {
|
||||||
|
const std::string saved_param = attr_name.substr(prefix.length());
|
||||||
|
if (std::none_of(cfg_.begin(), cfg_.end(),
|
||||||
|
[&](const std::pair<std::string, std::string>& x)
|
||||||
|
{ return x.first == saved_param; })) {
|
||||||
|
// If cfg_ contains the parameter already, skip it
|
||||||
|
// (this is to allow the user to explicitly override its value)
|
||||||
|
std::string saved_param_value;
|
||||||
|
CHECK(learner_->GetAttr(attr_name, &saved_param_value));
|
||||||
|
cfg_.emplace_back(saved_param, saved_param_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline void LoadModel(dmlc::Stream* fi) {
|
inline void LoadModel(dmlc::Stream* fi) {
|
||||||
learner_->Load(fi);
|
learner_->Load(fi);
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
@ -1149,5 +1170,14 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* hidden method; only known to C++ test suite */
|
||||||
|
const std::map<std::string, std::string>&
|
||||||
|
QueryBoosterConfigurationArguments(BoosterHandle handle) {
|
||||||
|
CHECK_HANDLE();
|
||||||
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
|
bst->LazyInit();
|
||||||
|
return bst->learner()->GetConfigurationArguments();
|
||||||
|
}
|
||||||
|
|
||||||
// force link rabit
|
// force link rabit
|
||||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <ios>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "./common/common.h"
|
#include "./common/common.h"
|
||||||
@ -21,7 +22,6 @@
|
|||||||
#include "./common/random.h"
|
#include "./common/random.h"
|
||||||
#include "./common/enum_class_param.h"
|
#include "./common/enum_class_param.h"
|
||||||
#include "./common/timer.h"
|
#include "./common/timer.h"
|
||||||
#include "../tests/cpp/test_learner.h"
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -36,6 +36,26 @@ enum class DataSplitMode : int {
|
|||||||
kAuto = 0, kCol = 1, kRow = 2
|
kAuto = 0, kCol = 1, kRow = 2
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline bool IsFloat(const std::string& str) {
|
||||||
|
std::stringstream ss(str);
|
||||||
|
float f;
|
||||||
|
return !((ss >> std::noskipws >> f).rdstate() ^ std::ios_base::eofbit);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool IsInt(const std::string& str) {
|
||||||
|
std::stringstream ss(str);
|
||||||
|
int i;
|
||||||
|
return !((ss >> std::noskipws >> i).rdstate() ^ std::ios_base::eofbit);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string RenderParamVal(const std::string& str) {
|
||||||
|
if (IsFloat(str) || IsInt(str)) {
|
||||||
|
return str;
|
||||||
|
} else {
|
||||||
|
return std::string("'") + str + "'";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
DECLARE_FIELD_ENUM_CLASS(TreeMethod);
|
DECLARE_FIELD_ENUM_CLASS(TreeMethod);
|
||||||
@ -152,7 +172,7 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
|||||||
* \brief learner that performs gradient boosting for a specific objective
|
* \brief learner that performs gradient boosting for a specific objective
|
||||||
* function. It does training and prediction.
|
* function. It does training and prediction.
|
||||||
*/
|
*/
|
||||||
class LearnerImpl : public Learner, public LearnerTestHook {
|
class LearnerImpl : public Learner {
|
||||||
public:
|
public:
|
||||||
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
|
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||||
: cache_(std::move(cache)) {
|
: cache_(std::move(cache)) {
|
||||||
@ -330,6 +350,38 @@ class LearnerImpl : public Learner, public LearnerTestHook {
|
|||||||
if (mparam_.contain_extra_attrs != 0) {
|
if (mparam_.contain_extra_attrs != 0) {
|
||||||
std::vector<std::pair<std::string, std::string> > attr;
|
std::vector<std::pair<std::string, std::string> > attr;
|
||||||
fi->Read(&attr);
|
fi->Read(&attr);
|
||||||
|
for (auto& kv : attr) {
|
||||||
|
// Load `predictor`, `n_gpus`, `gpu_id` parameters from extra attributes
|
||||||
|
const std::string prefix = "SAVED_PARAM_";
|
||||||
|
if (kv.first.find(prefix) == 0) {
|
||||||
|
const std::string saved_param = kv.first.substr(prefix.length());
|
||||||
|
#ifdef XGBOOST_USE_CUDA
|
||||||
|
if (saved_param == "predictor" || saved_param == "n_gpus"
|
||||||
|
|| saved_param == "gpu_id") {
|
||||||
|
cfg_[saved_param] = kv.second;
|
||||||
|
LOG(INFO)
|
||||||
|
<< "Parameter '" << saved_param << "' has been recovered from "
|
||||||
|
<< "the saved model. It will be set to "
|
||||||
|
<< RenderParamVal(kv.second) << " for prediction. To "
|
||||||
|
<< "override the predictor behavior, explicitly set '"
|
||||||
|
<< saved_param << "' parameter as follows:\n"
|
||||||
|
<< " * Python package: bst.set_param('"
|
||||||
|
<< saved_param << "', [new value])\n"
|
||||||
|
<< " * R package: xgb.parameters(bst) <- list("
|
||||||
|
<< saved_param << " = [new value])\n"
|
||||||
|
<< " * JVM packages: bst.setParam(\""
|
||||||
|
<< saved_param << "\", [new value])";
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (saved_param == "predictor" && kv.second == "gpu_predictor") {
|
||||||
|
LOG(INFO) << "Parameter 'predictor' will be set to 'cpu_predictor' "
|
||||||
|
<< "since XGBoots wasn't compiled with GPU support.";
|
||||||
|
cfg_["predictor"] = "cpu_predictor";
|
||||||
|
kv.second = "cpu_predictor";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
attributes_ =
|
attributes_ =
|
||||||
std::map<std::string, std::string>(attr.begin(), attr.end());
|
std::map<std::string, std::string>(attr.begin(), attr.end());
|
||||||
}
|
}
|
||||||
@ -364,15 +416,28 @@ class LearnerImpl : public Learner, public LearnerTestHook {
|
|||||||
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
|
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
// Write `predictor`, `n_gpus`, `gpu_id` parameters as extra attributes
|
||||||
|
for (const auto& key : std::vector<std::string>{
|
||||||
|
"predictor", "n_gpus", "gpu_id"}) {
|
||||||
|
auto it = cfg_.find(key);
|
||||||
|
if (it != cfg_.end()) {
|
||||||
|
mparam.contain_extra_attrs = 1;
|
||||||
|
extra_attr.emplace_back("SAVED_PARAM_" + key, it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
fo->Write(&mparam, sizeof(LearnerModelParam));
|
fo->Write(&mparam, sizeof(LearnerModelParam));
|
||||||
fo->Write(name_obj_);
|
fo->Write(name_obj_);
|
||||||
fo->Write(name_gbm_);
|
fo->Write(name_gbm_);
|
||||||
gbm_->Save(fo);
|
gbm_->Save(fo);
|
||||||
if (mparam.contain_extra_attrs != 0) {
|
if (mparam.contain_extra_attrs != 0) {
|
||||||
std::vector<std::pair<std::string, std::string> > attr(
|
std::map<std::string, std::string> attr(attributes_);
|
||||||
attributes_.begin(), attributes_.end());
|
for (const auto& kv : extra_attr) {
|
||||||
attr.insert(attr.end(), extra_attr.begin(), extra_attr.end());
|
attr[kv.first] = kv.second;
|
||||||
fo->Write(attr);
|
}
|
||||||
|
fo->Write(std::vector<std::pair<std::string, std::string>>(
|
||||||
|
attr.begin(), attr.end()));
|
||||||
}
|
}
|
||||||
if (name_obj_ == "count:poisson") {
|
if (name_obj_ == "count:poisson") {
|
||||||
auto it = cfg_.find("max_delta_step");
|
auto it = cfg_.find("max_delta_step");
|
||||||
@ -504,6 +569,10 @@ class LearnerImpl : public Learner, public LearnerTestHook {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::map<std::string, std::string>& GetConfigurationArguments() const override {
|
||||||
|
return cfg_;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Revise `tree_method` and `updater` parameters after seeing the training
|
// Revise `tree_method` and `updater` parameters after seeing the training
|
||||||
// data matrix
|
// data matrix
|
||||||
@ -664,11 +733,6 @@ class LearnerImpl : public Learner, public LearnerTestHook {
|
|||||||
std::vector<std::shared_ptr<DMatrix> > cache_;
|
std::vector<std::shared_ptr<DMatrix> > cache_;
|
||||||
|
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
|
|
||||||
// diagnostic method reserved for C++ test learner.SelectTreeMethod
|
|
||||||
std::string GetUpdaterSequence() const override {
|
|
||||||
return cfg_.at("updater");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Learner* Learner::Create(
|
Learner* Learner::Create(
|
||||||
|
|||||||
@ -2,11 +2,25 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2017 XGBoost contributors
|
* Copyright 2017 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
|
#include <dmlc/logging.h>
|
||||||
|
#include <dmlc/filesystem.h>
|
||||||
#include <xgboost/c_api.h>
|
#include <xgboost/c_api.h>
|
||||||
#include <xgboost/predictor.h>
|
#include <xgboost/predictor.h>
|
||||||
|
#include <string>
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
inline void CheckCAPICall(int ret) {
|
||||||
|
ASSERT_EQ(ret, 0) << XGBGetLastError();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace anonymous
|
||||||
|
|
||||||
|
extern const std::map<std::string, std::string>&
|
||||||
|
QueryBoosterConfigurationArguments(BoosterHandle handle);
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace predictor {
|
namespace predictor {
|
||||||
|
|
||||||
@ -77,6 +91,80 @@ TEST(gpu_predictor, Test) {
|
|||||||
delete dmat;
|
delete dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test whether pickling preserves predictor parameters
|
||||||
|
TEST(gpu_predictor, MGPU_PicklingTest) {
|
||||||
|
int ngpu;
|
||||||
|
dh::safe_cuda(cudaGetDeviceCount(&ngpu));
|
||||||
|
|
||||||
|
dmlc::TemporaryDirectory tempdir;
|
||||||
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
|
CreateBigTestData(tmp_file, 600);
|
||||||
|
|
||||||
|
DMatrixHandle dmat[1];
|
||||||
|
BoosterHandle bst, bst2;
|
||||||
|
std::vector<bst_float> label;
|
||||||
|
for (int i = 0; i < 200; ++i) {
|
||||||
|
label.push_back((i % 2 ? 1 : 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load data matrix
|
||||||
|
CheckCAPICall(XGDMatrixCreateFromFile(tmp_file.c_str(), 0, &dmat[0]));
|
||||||
|
CheckCAPICall(XGDMatrixSetFloatInfo(dmat[0], "label", label.data(), 200));
|
||||||
|
// Create booster
|
||||||
|
CheckCAPICall(XGBoosterCreate(dmat, 1, &bst));
|
||||||
|
// Set parameters
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "seed", "0"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "base_score", "0.5"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "booster", "gbtree"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "learning_rate", "0.01"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "max_depth", "8"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "objective", "binary:logistic"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "seed", "123"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "tree_method", "gpu_hist"));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "n_gpus", std::to_string(ngpu).c_str()));
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst, "predictor", "gpu_predictor"));
|
||||||
|
|
||||||
|
// Run boosting iterations
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
CheckCAPICall(XGBoosterUpdateOneIter(bst, i, dmat[0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete matrix
|
||||||
|
CheckCAPICall(XGDMatrixFree(dmat[0]));
|
||||||
|
|
||||||
|
// Pickle
|
||||||
|
const char* dptr;
|
||||||
|
bst_ulong len;
|
||||||
|
std::string buf;
|
||||||
|
CheckCAPICall(XGBoosterGetModelRaw(bst, &len, &dptr));
|
||||||
|
buf = std::string(dptr, len);
|
||||||
|
CheckCAPICall(XGBoosterFree(bst));
|
||||||
|
|
||||||
|
// Unpickle
|
||||||
|
CheckCAPICall(XGBoosterCreate(nullptr, 0, &bst2));
|
||||||
|
CheckCAPICall(XGBoosterLoadModelFromBuffer(bst2, buf.c_str(), len));
|
||||||
|
|
||||||
|
{ // Query predictor
|
||||||
|
const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
|
||||||
|
ASSERT_EQ(kwargs.at("predictor"), "gpu_predictor");
|
||||||
|
ASSERT_EQ(kwargs.at("n_gpus"), std::to_string(ngpu).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
{ // Change n_gpus and query again
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst2, "n_gpus", "1"));
|
||||||
|
const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
|
||||||
|
ASSERT_EQ(kwargs.at("n_gpus"), "1");
|
||||||
|
}
|
||||||
|
|
||||||
|
{ // Change predictor and query again
|
||||||
|
CheckCAPICall(XGBoosterSetParam(bst2, "predictor", "cpu_predictor"));
|
||||||
|
const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
|
||||||
|
ASSERT_EQ(kwargs.at("predictor"), "cpu_predictor");
|
||||||
|
}
|
||||||
|
|
||||||
|
CheckCAPICall(XGBoosterFree(bst2));
|
||||||
|
}
|
||||||
|
|
||||||
// multi-GPU predictor test
|
// multi-GPU predictor test
|
||||||
TEST(gpu_predictor, MGPU_Test) {
|
TEST(gpu_predictor, MGPU_Test) {
|
||||||
std::unique_ptr<Predictor> gpu_predictor =
|
std::unique_ptr<Predictor> gpu_predictor =
|
||||||
|
|||||||
@ -2,20 +2,10 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "helpers.h"
|
#include "helpers.h"
|
||||||
#include "./test_learner.h"
|
|
||||||
#include "xgboost/learner.h"
|
#include "xgboost/learner.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
class LearnerTestHookAdapter {
|
|
||||||
public:
|
|
||||||
static inline std::string GetUpdaterSequence(const Learner* learner) {
|
|
||||||
const LearnerTestHook* hook = dynamic_cast<const LearnerTestHook*>(learner);
|
|
||||||
CHECK(hook) << "LearnerImpl did not inherit from LearnerTestHook";
|
|
||||||
return hook->GetUpdaterSequence();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST(learner, Test) {
|
TEST(learner, Test) {
|
||||||
typedef std::pair<std::string, std::string> arg;
|
typedef std::pair<std::string, std::string> arg;
|
||||||
auto args = {arg("tree_method", "exact")};
|
auto args = {arg("tree_method", "exact")};
|
||||||
@ -35,20 +25,20 @@ TEST(learner, SelectTreeMethod) {
|
|||||||
|
|
||||||
// Test if `tree_method` can be set
|
// Test if `tree_method` can be set
|
||||||
learner->Configure({arg("tree_method", "approx")});
|
learner->Configure({arg("tree_method", "approx")});
|
||||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_histmaker,prune");
|
"grow_histmaker,prune");
|
||||||
learner->Configure({arg("tree_method", "exact")});
|
learner->Configure({arg("tree_method", "exact")});
|
||||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_colmaker,prune");
|
"grow_colmaker,prune");
|
||||||
learner->Configure({arg("tree_method", "hist")});
|
learner->Configure({arg("tree_method", "hist")});
|
||||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_fast_histmaker");
|
"grow_fast_histmaker");
|
||||||
#ifdef XGBOOST_USE_CUDA
|
#ifdef XGBOOST_USE_CUDA
|
||||||
learner->Configure({arg("tree_method", "gpu_exact")});
|
learner->Configure({arg("tree_method", "gpu_exact")});
|
||||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_gpu,prune");
|
"grow_gpu,prune");
|
||||||
learner->Configure({arg("tree_method", "gpu_hist")});
|
learner->Configure({arg("tree_method", "gpu_hist")});
|
||||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_gpu_hist");
|
"grow_gpu_hist");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +0,0 @@
|
|||||||
/*!
|
|
||||||
* Copyright 2018 by Contributors
|
|
||||||
* \file test_learner.h
|
|
||||||
* \brief Hook to access implementation class of Learner
|
|
||||||
* \author Hyunsu Philip Cho
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
|
||||||
#define XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace xgboost {
|
|
||||||
class LearnerTestHook {
|
|
||||||
private:
|
|
||||||
virtual std::string GetUpdaterSequence() const = 0;
|
|
||||||
// allow friend access to C++ tests for Learner
|
|
||||||
friend class LearnerTestHookAdapter;
|
|
||||||
};
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|
||||||
#endif // XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
|
||||||
Loading…
x
Reference in New Issue
Block a user