[Blocking] Fix #3840: Clean up logic for parsing tree_method parameter (#3849)

* Clean up logic for converting tree_method to updater sequence

* Use C++11 enum class for extra safety

Compiler will give warnings if switch statements don't handle all
possible values of C++11 enum class.

Also allow enum class to be used as DMLC parameter.

* Fix compiler error + lint

* Address reviewer comment

* Better docstring for DECLARE_FIELD_ENUM_CLASS

* Fix lint

* Add C++ test to see if tree_method is recognized

* Fix clang-tidy error

* Add test_learner.h to R package

* Update comments

* Fix lint error
This commit is contained in:
Philip Hyunsu Cho
2018-11-01 19:33:35 -07:00
committed by GitHub
parent 583c88bce7
commit ad68865d6b
6 changed files with 358 additions and 54 deletions

View File

@@ -0,0 +1,55 @@
#include "../../../src/common/enum_class_param.h"
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
enum class Foo : int {
kBar = 0, kFrog = 1, kCat = 2, kDog = 3
};
DECLARE_FIELD_ENUM_CLASS(Foo);
struct MyParam : dmlc::Parameter<MyParam> {
Foo foo;
int bar;
DMLC_DECLARE_PARAMETER(MyParam) {
DMLC_DECLARE_FIELD(foo)
.set_default(Foo::kBar)
.add_enum("bar", Foo::kBar)
.add_enum("frog", Foo::kFrog)
.add_enum("cat", Foo::kCat)
.add_enum("dog", Foo::kDog);
DMLC_DECLARE_FIELD(bar)
.set_default(-1);
}
};
DMLC_REGISTER_PARAMETER(MyParam);
TEST(EnumClassParam, Basic) {
MyParam param;
std::map<std::string, std::string> kwargs{
{"foo", "frog"}, {"bar", "10"}
};
// try initializing
param.Init(kwargs);
ASSERT_EQ(param.foo, Foo::kFrog);
ASSERT_EQ(param.bar, 10);
// try all possible enum values
kwargs["foo"] = "bar";
param.Init(kwargs);
ASSERT_EQ(param.foo, Foo::kBar);
kwargs["foo"] = "frog";
param.Init(kwargs);
ASSERT_EQ(param.foo, Foo::kFrog);
kwargs["foo"] = "cat";
param.Init(kwargs);
ASSERT_EQ(param.foo, Foo::kCat);
kwargs["foo"] = "dog";
param.Init(kwargs);
ASSERT_EQ(param.foo, Foo::kDog);
// try setting non-existent enum value
kwargs["foo"] = "human";
ASSERT_THROW(param.Init(kwargs), dmlc::ParamError);
}

View File

@@ -2,9 +2,20 @@
#include <gtest/gtest.h>
#include <vector>
#include "helpers.h"
#include "./test_learner.h"
#include "xgboost/learner.h"
namespace xgboost {
class LearnerTestHookAdapter {
public:
static inline std::string GetUpdaterSequence(const Learner* learner) {
const LearnerTestHook* hook = dynamic_cast<const LearnerTestHook*>(learner);
CHECK(hook) << "LearnerImpl did not inherit from LearnerTestHook";
return hook->GetUpdaterSequence();
}
};
TEST(learner, Test) {
typedef std::pair<std::string, std::string> arg;
auto args = {arg("tree_method", "exact")};
@@ -15,4 +26,33 @@ TEST(learner, Test) {
delete mat_ptr;
}
TEST(learner, SelectTreeMethod) {
using arg = std::pair<std::string, std::string>;
auto mat_ptr = CreateDMatrix(10, 10, 0);
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
// Test if `tree_method` can be set
learner->Configure({arg("tree_method", "approx")});
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
"grow_histmaker,prune");
learner->Configure({arg("tree_method", "exact")});
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
"grow_colmaker,prune");
learner->Configure({arg("tree_method", "hist")});
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
"grow_fast_histmaker");
#ifdef XGBOOST_USE_CUDA
learner->Configure({arg("tree_method", "gpu_exact")});
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
"grow_gpu,prune");
learner->Configure({arg("tree_method", "gpu_hist")});
ASSERT_EQ(LearnerTestHookAdapter::GetUpdaterSequence(learner.get()),
"grow_gpu_hist");
#endif
delete mat_ptr;
}
} // namespace xgboost

22
tests/cpp/test_learner.h Normal file
View File

@@ -0,0 +1,22 @@
/*!
* Copyright 2018 by Contributors
* \file test_learner.h
* \brief Hook to access implementation class of Learner
* \author Hyunsu Philip Cho
*/
#ifndef XGBOOST_TESTS_CPP_TEST_LEARNER_H_
#define XGBOOST_TESTS_CPP_TEST_LEARNER_H_
#include <string>
namespace xgboost {
class LearnerTestHook {
private:
virtual std::string GetUpdaterSequence() const = 0;
// allow friend access to C++ tests for Learner
friend class LearnerTestHookAdapter;
};
} // namespace xgboost
#endif // XGBOOST_TESTS_CPP_TEST_LEARNER_H_