* Clean up logic for converting tree_method to updater sequence * Use C++11 enum class for extra safety Compiler will give warnings if switch statements don't handle all possible values of C++11 enum class. Also allow enum class to be used as DMLC parameter. * Fix compiler error + lint * Address reviewer comment * Better docstring for DECLARE_FIELD_ENUM_CLASS * Fix lint * Add C++ test to see if tree_method is recognized * Fix clang-tidy error * Add test_learner.h to R package * Update comments * Fix lint error
This commit is contained in:
parent
583c88bce7
commit
ad68865d6b
2
Makefile
2
Makefile
@ -250,6 +250,8 @@ Rpack: clean_all
|
||||
cp -r src xgboost/src/src
|
||||
cp -r include xgboost/src/include
|
||||
cp -r amalgamation xgboost/src/amalgamation
|
||||
mkdir -p xgboost/src/tests/cpp
|
||||
cp tests/cpp/test_learner.h xgboost/src/tests/cpp
|
||||
mkdir -p xgboost/src/rabit
|
||||
cp -r rabit/include xgboost/src/rabit/include
|
||||
cp -r rabit/src xgboost/src/rabit/src
|
||||
|
||||
81
src/common/enum_class_param.h
Normal file
81
src/common/enum_class_param.h
Normal file
@ -0,0 +1,81 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file enum_class_param.h
|
||||
* \brief macro for using C++11 enum class as DMLC parameter
|
||||
* \author Hyunsu Philip Cho
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_COMMON_ENUM_CLASS_PARAM_H_
|
||||
#define XGBOOST_COMMON_ENUM_CLASS_PARAM_H_
|
||||
|
||||
#include <dmlc/parameter.h>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
/*!
|
||||
* \brief Specialization of FieldEntry for enum class (backed by int)
|
||||
*
|
||||
* Use this macro to use C++11 enum class as DMLC parameters
|
||||
*
|
||||
* Usage:
|
||||
*
|
||||
* \code{.cpp}
|
||||
*
|
||||
* // enum class must inherit from int type
|
||||
* enum class Foo : int {
|
||||
* kBar = 0, kFrog = 1, kCat = 2, kDog = 3
|
||||
* };
|
||||
*
|
||||
* // This line is needed to prevent compilation error
|
||||
* DECLARE_FIELD_ENUM_CLASS(Foo);
|
||||
*
|
||||
* // Now define DMLC parameter as usual;
|
||||
* // enum classes can now be members.
|
||||
* struct MyParam : dmlc::Parameter<MyParam> {
|
||||
* Foo foo;
|
||||
* DMLC_DECLARE_PARAMETER(MyParam) {
|
||||
* DMLC_DECLARE_FIELD(foo)
|
||||
* .set_default(Foo::kBar)
|
||||
* .add_enum("bar", Foo::kBar)
|
||||
* .add_enum("frog", Foo::kFrog)
|
||||
* .add_enum("cat", Foo::kCat)
|
||||
* .add_enum("dog", Foo::kDog);
|
||||
* }
|
||||
* };
|
||||
*
|
||||
* DMLC_REGISTER_PARAMETER(MyParam);
|
||||
* \endcode
|
||||
*/
|
||||
#define DECLARE_FIELD_ENUM_CLASS(EnumClass) \
|
||||
namespace dmlc { \
|
||||
namespace parameter { \
|
||||
template <> \
|
||||
class FieldEntry<EnumClass> : public FieldEntry<int> { \
|
||||
public: \
|
||||
FieldEntry<EnumClass>() { \
|
||||
static_assert( \
|
||||
std::is_same<int, typename std::underlying_type<EnumClass>::type>::value, \
|
||||
"enum class must be backed by int"); \
|
||||
is_enum_ = true; \
|
||||
} \
|
||||
using Super = FieldEntry<int>; \
|
||||
void Set(void *head, const std::string &value) const override { \
|
||||
Super::Set(head, value); \
|
||||
} \
|
||||
inline FieldEntry<EnumClass>& add_enum(const std::string &key, EnumClass value) { \
|
||||
Super::add_enum(key, static_cast<int>(value)); \
|
||||
return *this; \
|
||||
} \
|
||||
inline FieldEntry<EnumClass>& set_default(const EnumClass& default_value) { \
|
||||
default_value_ = static_cast<int>(default_value); \
|
||||
has_default_ = true; \
|
||||
return *this; \
|
||||
} \
|
||||
inline void Init(const std::string &key, void *head, EnumClass& ref) { /* NOLINT */ \
|
||||
Super::Init(key, head, *reinterpret_cast<int*>(&ref)); \
|
||||
} \
|
||||
}; \
|
||||
} /* namespace parameter */ \
|
||||
} /* namespace dmlc */
|
||||
|
||||
#endif // XGBOOST_COMMON_ENUM_CLASS_PARAM_H_
|
||||
212
src/learner.cc
212
src/learner.cc
@ -19,14 +19,28 @@
|
||||
#include "./common/host_device_vector.h"
|
||||
#include "./common/io.h"
|
||||
#include "./common/random.h"
|
||||
#include "common/timer.h"
|
||||
#include "./common/enum_class_param.h"
|
||||
#include "./common/timer.h"
|
||||
#include "../tests/cpp/test_learner.h"
|
||||
|
||||
namespace {
|
||||
|
||||
const char* kMaxDeltaStepDefaultValue = "0.7";
|
||||
|
||||
enum class TreeMethod : int {
|
||||
kAuto = 0, kApprox = 1, kExact = 2, kHist = 3,
|
||||
kGPUExact = 4, kGPUHist = 5
|
||||
};
|
||||
|
||||
enum class DataSplitMode : int {
|
||||
kAuto = 0, kCol = 1, kRow = 2
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
DECLARE_FIELD_ENUM_CLASS(TreeMethod);
|
||||
DECLARE_FIELD_ENUM_CLASS(DataSplitMode);
|
||||
|
||||
namespace xgboost {
|
||||
// implementation of base learner.
|
||||
bool Learner::AllowLazyCheckPoint() const {
|
||||
@ -80,9 +94,9 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
|
||||
// whether seed the PRNG each iteration
|
||||
bool seed_per_iteration;
|
||||
// data split mode, can be row, col, or none.
|
||||
int dsplit;
|
||||
DataSplitMode dsplit;
|
||||
// tree construction method
|
||||
int tree_method;
|
||||
TreeMethod tree_method;
|
||||
// internal test flag
|
||||
std::string test_flag;
|
||||
// number of threads to use if OpenMP is enabled
|
||||
@ -103,19 +117,19 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
|
||||
"this option will be switched on automatically on distributed "
|
||||
"mode.");
|
||||
DMLC_DECLARE_FIELD(dsplit)
|
||||
.set_default(0)
|
||||
.add_enum("auto", 0)
|
||||
.add_enum("col", 1)
|
||||
.add_enum("row", 2)
|
||||
.set_default(DataSplitMode::kAuto)
|
||||
.add_enum("auto", DataSplitMode::kAuto)
|
||||
.add_enum("col", DataSplitMode::kCol)
|
||||
.add_enum("row", DataSplitMode::kRow)
|
||||
.describe("Data split mode for distributed training.");
|
||||
DMLC_DECLARE_FIELD(tree_method)
|
||||
.set_default(0)
|
||||
.add_enum("auto", 0)
|
||||
.add_enum("approx", 1)
|
||||
.add_enum("exact", 2)
|
||||
.add_enum("hist", 3)
|
||||
.add_enum("gpu_exact", 4)
|
||||
.add_enum("gpu_hist", 5)
|
||||
.set_default(TreeMethod::kAuto)
|
||||
.add_enum("auto", TreeMethod::kAuto)
|
||||
.add_enum("approx", TreeMethod::kApprox)
|
||||
.add_enum("exact", TreeMethod::kExact)
|
||||
.add_enum("hist", TreeMethod::kHist)
|
||||
.add_enum("gpu_exact", TreeMethod::kGPUExact)
|
||||
.add_enum("gpu_hist", TreeMethod::kGPUHist)
|
||||
.describe("Choice of tree construction method.");
|
||||
DMLC_DECLARE_FIELD(test_flag).set_default("").describe(
|
||||
"Internal test flag");
|
||||
@ -138,7 +152,7 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
||||
* \brief learner that performs gradient boosting for a specific objective
|
||||
* function. It does training and prediction.
|
||||
*/
|
||||
class LearnerImpl : public Learner {
|
||||
class LearnerImpl : public Learner, public LearnerTestHook {
|
||||
public:
|
||||
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||
: cache_(std::move(cache)) {
|
||||
@ -154,37 +168,49 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
void ConfigureUpdaters() {
|
||||
if (tparam_.tree_method == 0 || tparam_.tree_method == 1 ||
|
||||
tparam_.tree_method == 2) {
|
||||
if (cfg_.count("updater") == 0) {
|
||||
if (tparam_.dsplit == 1) {
|
||||
cfg_["updater"] = "distcol";
|
||||
} else if (tparam_.dsplit == 2) {
|
||||
cfg_["updater"] = "grow_histmaker,prune";
|
||||
}
|
||||
}
|
||||
} else if (tparam_.tree_method == 3) {
|
||||
/* histogram-based algorithm */
|
||||
LOG(CONSOLE) << "Tree method is selected to be \'hist\', which uses a "
|
||||
"single updater "
|
||||
<< "grow_fast_histmaker.";
|
||||
/* Choose updaters according to tree_method parameters */
|
||||
if (cfg_.count("updater") > 0) {
|
||||
LOG(CONSOLE) << "DANGER AHEAD: You have manually specified `updater` "
|
||||
"parameter. The `tree_method` parameter will be ignored. "
|
||||
"Incorrect sequence of updaters will produce undefined "
|
||||
"behavior. For common uses, we recommend using "
|
||||
"`tree_method` parameter instead.";
|
||||
return;
|
||||
}
|
||||
|
||||
switch (tparam_.tree_method) {
|
||||
case TreeMethod::kAuto:
|
||||
// Use heuristic to choose between 'exact' and 'approx'
|
||||
// This choice is deferred to PerformTreeMethodHeuristic().
|
||||
break;
|
||||
case TreeMethod::kApprox:
|
||||
cfg_["updater"] = "grow_histmaker,prune";
|
||||
break;
|
||||
case TreeMethod::kExact:
|
||||
cfg_["updater"] = "grow_colmaker,prune";
|
||||
break;
|
||||
case TreeMethod::kHist:
|
||||
LOG(CONSOLE) << "Tree method is selected to be 'hist', which uses a "
|
||||
"single updater grow_fast_histmaker.";
|
||||
cfg_["updater"] = "grow_fast_histmaker";
|
||||
} else if (tparam_.tree_method == 4) {
|
||||
break;
|
||||
case TreeMethod::kGPUExact:
|
||||
this->AssertGPUSupport();
|
||||
if (cfg_.count("updater") == 0) {
|
||||
cfg_["updater"] = "grow_gpu,prune";
|
||||
}
|
||||
cfg_["updater"] = "grow_gpu,prune";
|
||||
if (cfg_.count("predictor") == 0) {
|
||||
cfg_["predictor"] = "gpu_predictor";
|
||||
}
|
||||
} else if (tparam_.tree_method == 5) {
|
||||
break;
|
||||
case TreeMethod::kGPUHist:
|
||||
this->AssertGPUSupport();
|
||||
if (cfg_.count("updater") == 0) {
|
||||
cfg_["updater"] = "grow_gpu_hist";
|
||||
}
|
||||
cfg_["updater"] = "grow_gpu_hist";
|
||||
if (cfg_.count("predictor") == 0) {
|
||||
cfg_["predictor"] = "gpu_predictor";
|
||||
}
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown tree_method ("
|
||||
<< static_cast<int>(tparam_.tree_method) << ") detected";
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,8 +240,8 @@ class LearnerImpl : public Learner {
|
||||
|
||||
// add additional parameters
|
||||
// These are cosntraints that need to be satisfied.
|
||||
if (tparam_.dsplit == 0 && rabit::IsDistributed()) {
|
||||
tparam_.dsplit = 2;
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
if (cfg_.count("num_class") != 0) {
|
||||
@ -376,7 +402,7 @@ class LearnerImpl : public Learner {
|
||||
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->LazyInitDMatrix(train);
|
||||
this->PerformTreeMethodHeuristic(train);
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_);
|
||||
monitor_.Stop("PredictRaw");
|
||||
@ -393,7 +419,7 @@ class LearnerImpl : public Learner {
|
||||
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->LazyInitDMatrix(train);
|
||||
this->PerformTreeMethodHeuristic(train);
|
||||
gbm_->DoBoost(train, in_gpair);
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
@ -412,7 +438,7 @@ class LearnerImpl : public Learner {
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
<< ev->Eval(preds_.ConstHostVector(), data_sets[i]->Info(),
|
||||
tparam_.dsplit == 2);
|
||||
tparam_.dsplit == DataSplitMode::kRow);
|
||||
}
|
||||
}
|
||||
|
||||
@ -456,7 +482,7 @@ class LearnerImpl : public Learner {
|
||||
obj_->EvalTransform(&preds_);
|
||||
return std::make_pair(metric,
|
||||
ev->Eval(preds_.ConstHostVector(), data->Info(),
|
||||
tparam_.dsplit == 2));
|
||||
tparam_.dsplit == DataSplitMode::kRow));
|
||||
}
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
@ -479,21 +505,94 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
protected:
|
||||
// check if p_train is ready to used by training.
|
||||
// if not, initialize the column access.
|
||||
inline void LazyInitDMatrix(DMatrix* p_train) {
|
||||
if (tparam_.tree_method == 3 || tparam_.tree_method == 4 ||
|
||||
tparam_.tree_method == 5 || name_gbm_ == "gblinear") {
|
||||
// Revise `tree_method` and `updater` parameters after seeing the training
|
||||
// data matrix
|
||||
inline void PerformTreeMethodHeuristic(DMatrix* p_train) {
|
||||
if (name_gbm_ != "gbtree" || cfg_.count("updater") > 0) {
|
||||
// 1. This method is not applicable for non-tree learners
|
||||
// 2. This method is disabled when `updater` parameter is explicitly
|
||||
// set, since only experts are expected to do so.
|
||||
return;
|
||||
}
|
||||
|
||||
if (!p_train->SingleColBlock() && cfg_.count("updater") == 0) {
|
||||
if (tparam_.tree_method == 2) {
|
||||
LOG(CONSOLE) << "tree method is set to be 'exact',"
|
||||
<< " but currently we are only able to proceed with "
|
||||
"approximate algorithm";
|
||||
const TreeMethod current_tree_method = tparam_.tree_method;
|
||||
if (rabit::IsDistributed()) {
|
||||
/* Choose tree_method='approx' when distributed training is activated */
|
||||
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
|
||||
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
|
||||
if (tparam_.dsplit == DataSplitMode::kCol) {
|
||||
// 'distcol' updater hidden until it becomes functional again
|
||||
// See discussion at https://github.com/dmlc/xgboost/issues/1832
|
||||
LOG(FATAL) << "Column-wise data split is currently not supported.";
|
||||
}
|
||||
cfg_["updater"] = "grow_histmaker,prune";
|
||||
switch (current_tree_method) {
|
||||
case TreeMethod::kAuto:
|
||||
LOG(CONSOLE) << "Tree method is automatically selected to be 'approx' "
|
||||
"for distributed training.";
|
||||
break;
|
||||
case TreeMethod::kApprox:
|
||||
// things are okay, do nothing
|
||||
break;
|
||||
case TreeMethod::kExact:
|
||||
case TreeMethod::kHist:
|
||||
LOG(CONSOLE) << "Tree method was set to be '"
|
||||
<< (current_tree_method == TreeMethod::kExact ?
|
||||
"exact" : "hist")
|
||||
<< "', but only 'approx' is available for distributed "
|
||||
"training. The `tree_method` parameter is now being "
|
||||
"changed to 'approx'";
|
||||
break;
|
||||
case TreeMethod::kGPUExact:
|
||||
case TreeMethod::kGPUHist:
|
||||
LOG(FATAL) << "Distributed training is not available with GPU algoritms";
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown tree_method ("
|
||||
<< static_cast<int>(current_tree_method) << ") detected";
|
||||
}
|
||||
tparam_.tree_method = TreeMethod::kApprox;
|
||||
} else if (!p_train->SingleColBlock()) {
|
||||
/* Some tree methods are not available for external-memory DMatrix */
|
||||
switch (current_tree_method) {
|
||||
case TreeMethod::kAuto:
|
||||
LOG(CONSOLE) << "Tree method is automatically set to 'approx' "
|
||||
"since external-memory data matrix is used.";
|
||||
break;
|
||||
case TreeMethod::kApprox:
|
||||
// things are okay, do nothing
|
||||
break;
|
||||
case TreeMethod::kExact:
|
||||
LOG(CONSOLE) << "Tree method was set to be 'exact', "
|
||||
"but currently we are only able to proceed with "
|
||||
"approximate algorithm ('approx') because external-"
|
||||
"memory data matrix is used.";
|
||||
break;
|
||||
case TreeMethod::kHist:
|
||||
// things are okay, do nothing
|
||||
break;
|
||||
case TreeMethod::kGPUExact:
|
||||
case TreeMethod::kGPUHist:
|
||||
LOG(FATAL)
|
||||
<< "External-memory data matrix is not available with GPU algorithms";
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown tree_method ("
|
||||
<< static_cast<int>(current_tree_method) << ") detected";
|
||||
}
|
||||
tparam_.tree_method = TreeMethod::kApprox;
|
||||
} else if (p_train->Info().num_row_ >= (4UL << 20UL)
|
||||
&& current_tree_method == TreeMethod::kAuto) {
|
||||
/* Choose tree_method='approx' automatically for large data matrix */
|
||||
LOG(CONSOLE) << "Tree method is automatically selected to be "
|
||||
"'approx' for faster speed. To use old behavior "
|
||||
"(exact greedy algorithm on single machine), "
|
||||
"set tree_method to 'exact'.";
|
||||
tparam_.tree_method = TreeMethod::kApprox;
|
||||
}
|
||||
|
||||
/* If tree_method was changed, re-configure updaters and gradient boosters */
|
||||
if (tparam_.tree_method != current_tree_method) {
|
||||
ConfigureUpdaters();
|
||||
if (gbm_ != nullptr) {
|
||||
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
@ -565,6 +664,11 @@ class LearnerImpl : public Learner {
|
||||
std::vector<std::shared_ptr<DMatrix> > cache_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
|
||||
// diagnostic method reserved for C++ test learner.SelectTreeMethod
|
||||
std::string GetUpdaterSequence() const override {
|
||||
return cfg_.at("updater");
|
||||
}
|
||||
};
|
||||
|
||||
Learner* Learner::Create(
|
||||
|
||||
55
tests/cpp/common/test_enum_class_param.cc
Normal file
55
tests/cpp/common/test_enum_class_param.cc
Normal file
@ -0,0 +1,55 @@
|
||||
#include "../../../src/common/enum_class_param.h"
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
enum class Foo : int {
|
||||
kBar = 0, kFrog = 1, kCat = 2, kDog = 3
|
||||
};
|
||||
|
||||
DECLARE_FIELD_ENUM_CLASS(Foo);
|
||||
|
||||
struct MyParam : dmlc::Parameter<MyParam> {
|
||||
Foo foo;
|
||||
int bar;
|
||||
DMLC_DECLARE_PARAMETER(MyParam) {
|
||||
DMLC_DECLARE_FIELD(foo)
|
||||
.set_default(Foo::kBar)
|
||||
.add_enum("bar", Foo::kBar)
|
||||
.add_enum("frog", Foo::kFrog)
|
||||
.add_enum("cat", Foo::kCat)
|
||||
.add_enum("dog", Foo::kDog);
|
||||
DMLC_DECLARE_FIELD(bar)
|
||||
.set_default(-1);
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTER_PARAMETER(MyParam);
|
||||
|
||||
TEST(EnumClassParam, Basic) {
|
||||
MyParam param;
|
||||
std::map<std::string, std::string> kwargs{
|
||||
{"foo", "frog"}, {"bar", "10"}
|
||||
};
|
||||
// try initializing
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kFrog);
|
||||
ASSERT_EQ(param.bar, 10);
|
||||
|
||||
// try all possible enum values
|
||||
kwargs["foo"] = "bar";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kBar);
|
||||
kwargs["foo"] = "frog";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kFrog);
|
||||
kwargs["foo"] = "cat";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kCat);
|
||||
kwargs["foo"] = "dog";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kDog);
|
||||
|
||||
// try setting non-existent enum value
|
||||
kwargs["foo"] = "human";
|
||||
ASSERT_THROW(param.Init(kwargs), dmlc::ParamError);
|
||||
}
|
||||
@ -2,9 +2,20 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include "helpers.h"
|
||||
#include "./test_learner.h"
|
||||
#include "xgboost/learner.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class LearnerTestHookAdapter {
|
||||
public:
|
||||
static inline std::string GetUpdaterSequence(const Learner* learner) {
|
||||
const LearnerTestHook* hook = dynamic_cast<const LearnerTestHook*>(learner);
|
||||
CHECK(hook) << "LearnerImpl did not inherit from LearnerTestHook";
|
||||
return hook->GetUpdaterSequence();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(learner, Test) {
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
auto args = {arg("tree_method", "exact")};
|
||||
@ -15,4 +26,33 @@ TEST(learner, Test) {
|
||||
|
||||
delete mat_ptr;
|
||||
}
|
||||
|
||||
TEST(learner, SelectTreeMethod) {
|
||||
using arg = std::pair<std::string, std::string>;
|
||||
auto mat_ptr = CreateDMatrix(10, 10, 0);
|
||||
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||
|
||||
// Test if `tree_method` can be set
|
||||
learner->Configure({arg("tree_method", "approx")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_histmaker,prune");
|
||||
learner->Configure({arg("tree_method", "exact")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_colmaker,prune");
|
||||
learner->Configure({arg("tree_method", "hist")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_fast_histmaker");
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
learner->Configure({arg("tree_method", "gpu_exact")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_gpu,prune");
|
||||
learner->Configure({arg("tree_method", "gpu_hist")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_gpu_hist");
|
||||
#endif
|
||||
|
||||
delete mat_ptr;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
22
tests/cpp/test_learner.h
Normal file
22
tests/cpp/test_learner.h
Normal file
@ -0,0 +1,22 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file test_learner.h
|
||||
* \brief Hook to access implementation class of Learner
|
||||
* \author Hyunsu Philip Cho
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
||||
#define XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
class LearnerTestHook {
|
||||
private:
|
||||
virtual std::string GetUpdaterSequence() const = 0;
|
||||
// allow friend access to C++ tests for Learner
|
||||
friend class LearnerTestHookAdapter;
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
||||
Loading…
x
Reference in New Issue
Block a user