* Clean up logic for converting tree_method to updater sequence * Use C++11 enum class for extra safety Compiler will give warnings if switch statements don't handle all possible values of C++11 enum class. Also allow enum class to be used as DMLC parameter. * Fix compiler error + lint * Address reviewer comment * Better docstring for DECLARE_FIELD_ENUM_CLASS * Fix lint * Add C++ test to see if tree_method is recognized * Fix clang-tidy error * Add test_learner.h to R package * Update comments * Fix lint error
This commit is contained in:
committed by
GitHub
parent
583c88bce7
commit
ad68865d6b
55
tests/cpp/common/test_enum_class_param.cc
Normal file
55
tests/cpp/common/test_enum_class_param.cc
Normal file
@@ -0,0 +1,55 @@
|
||||
#include "../../../src/common/enum_class_param.h"
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
enum class Foo : int {
|
||||
kBar = 0, kFrog = 1, kCat = 2, kDog = 3
|
||||
};
|
||||
|
||||
DECLARE_FIELD_ENUM_CLASS(Foo);
|
||||
|
||||
struct MyParam : dmlc::Parameter<MyParam> {
|
||||
Foo foo;
|
||||
int bar;
|
||||
DMLC_DECLARE_PARAMETER(MyParam) {
|
||||
DMLC_DECLARE_FIELD(foo)
|
||||
.set_default(Foo::kBar)
|
||||
.add_enum("bar", Foo::kBar)
|
||||
.add_enum("frog", Foo::kFrog)
|
||||
.add_enum("cat", Foo::kCat)
|
||||
.add_enum("dog", Foo::kDog);
|
||||
DMLC_DECLARE_FIELD(bar)
|
||||
.set_default(-1);
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTER_PARAMETER(MyParam);
|
||||
|
||||
TEST(EnumClassParam, Basic) {
|
||||
MyParam param;
|
||||
std::map<std::string, std::string> kwargs{
|
||||
{"foo", "frog"}, {"bar", "10"}
|
||||
};
|
||||
// try initializing
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kFrog);
|
||||
ASSERT_EQ(param.bar, 10);
|
||||
|
||||
// try all possible enum values
|
||||
kwargs["foo"] = "bar";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kBar);
|
||||
kwargs["foo"] = "frog";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kFrog);
|
||||
kwargs["foo"] = "cat";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kCat);
|
||||
kwargs["foo"] = "dog";
|
||||
param.Init(kwargs);
|
||||
ASSERT_EQ(param.foo, Foo::kDog);
|
||||
|
||||
// try setting non-existent enum value
|
||||
kwargs["foo"] = "human";
|
||||
ASSERT_THROW(param.Init(kwargs), dmlc::ParamError);
|
||||
}
|
||||
@@ -2,9 +2,20 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include "helpers.h"
|
||||
#include "./test_learner.h"
|
||||
#include "xgboost/learner.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class LearnerTestHookAdapter {
|
||||
public:
|
||||
static inline std::string GetUpdaterSequence(const Learner* learner) {
|
||||
const LearnerTestHook* hook = dynamic_cast<const LearnerTestHook*>(learner);
|
||||
CHECK(hook) << "LearnerImpl did not inherit from LearnerTestHook";
|
||||
return hook->GetUpdaterSequence();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(learner, Test) {
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
auto args = {arg("tree_method", "exact")};
|
||||
@@ -15,4 +26,33 @@ TEST(learner, Test) {
|
||||
|
||||
delete mat_ptr;
|
||||
}
|
||||
|
||||
TEST(learner, SelectTreeMethod) {
|
||||
using arg = std::pair<std::string, std::string>;
|
||||
auto mat_ptr = CreateDMatrix(10, 10, 0);
|
||||
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||
|
||||
// Test if `tree_method` can be set
|
||||
learner->Configure({arg("tree_method", "approx")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_histmaker,prune");
|
||||
learner->Configure({arg("tree_method", "exact")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_colmaker,prune");
|
||||
learner->Configure({arg("tree_method", "hist")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_fast_histmaker");
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
learner->Configure({arg("tree_method", "gpu_exact")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_gpu,prune");
|
||||
learner->Configure({arg("tree_method", "gpu_hist")});
|
||||
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
|
||||
"grow_gpu_hist");
|
||||
#endif
|
||||
|
||||
delete mat_ptr;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
22
tests/cpp/test_learner.h
Normal file
22
tests/cpp/test_learner.h
Normal file
@@ -0,0 +1,22 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file test_learner.h
|
||||
* \brief Hook to access implementation class of Learner
|
||||
* \author Hyunsu Philip Cho
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
||||
#define XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
class LearnerTestHook {
|
||||
private:
|
||||
virtual std::string GetUpdaterSequence() const = 0;
|
||||
// allow friend access to C++ tests for Learner
|
||||
friend class LearnerTestHookAdapter;
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TESTS_CPP_TEST_LEARNER_H_
|
||||
Reference in New Issue
Block a user