Use `UpdateAllowUnknown' for non-model related parameter. (#4961)

* Use `UpdateAllowUnknown' for non-model related parameter.

Model parameters cannot pack an additional boolean value because of their binary IO
format, so this commit only changes the configuration of non-model-related parameters.

* Add a `--use-dmlc-gtest` command-line argument to the tidy script.
This commit is contained in:
Jiaming Yuan 2019-10-23 05:50:12 -04:00 committed by GitHub
parent f24be2efb4
commit ac457c56a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 189 additions and 112 deletions

View File

@ -225,7 +225,7 @@ using GradientPairInteger = detail::GradientPairInternal<int64_t>;
using Args = std::vector<std::pair<std::string, std::string> >;
/*! \brief small eps gap for minimum split decision. */
const bst_float kRtEps = 1e-6f;
constexpr bst_float kRtEps = 1e-6f;
/*! \brief define unsigned long for openmp loop */
using omp_ulong = dmlc::omp_ulong; // NOLINT

View File

@ -5,14 +5,13 @@
#ifndef XGBOOST_GENERIC_PARAMETERS_H_
#define XGBOOST_GENERIC_PARAMETERS_H_
#include <dmlc/parameter.h>
#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <string>
namespace xgboost {
struct GenericParameter : public dmlc::Parameter<GenericParameter> {
struct GenericParameter : public XGBoostParameter<GenericParameter> {
// stored random seed
int seed;
// whether to seed the PRNG each iteration

View File

@ -5,6 +5,7 @@
#define XGBOOST_JSON_H_
#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <string>
#include <map>
@ -533,7 +534,7 @@ using Null = JsonNull;
// Utils tailored for XGBoost.
template <typename Type>
Object toJson(dmlc::Parameter<Type> const& param) {
Object toJson(XGBoostParameter<Type> const& param) {
Object obj;
for (auto const& kv : param.__DICT__()) {
obj[kv.first] = kv.second;
@ -542,13 +543,13 @@ Object toJson(dmlc::Parameter<Type> const& param) {
}
template <typename Type>
void fromJson(Json const& obj, dmlc::Parameter<Type>* param) {
void fromJson(Json const& obj, XGBoostParameter<Type>* param) {
auto const& j_param = get<Object const>(obj);
std::map<std::string, std::string> m;
for (auto const& kv : j_param) {
m[kv.first] = get<String const>(kv.second);
}
param->InitAllowUnknown(m);
param->UpdateAllowUnknown(m);
}
} // namespace xgboost
#endif // XGBOOST_JSON_H_

View File

@ -9,9 +9,10 @@
#define XGBOOST_LOGGING_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <dmlc/thread_local.h>
#include <xgboost/base.h>
#include <xgboost/parameter.h>
#include <sstream>
#include <map>
@ -35,7 +36,7 @@ class BaseLogger {
};
// Parsing both silent and debug_verbose is to provide backward compatibility.
struct ConsoleLoggerParam : public dmlc::Parameter<ConsoleLoggerParam> {
struct ConsoleLoggerParam : public XGBoostParameter<ConsoleLoggerParam> {
bool silent; // deprecated.
int verbosity;

View File

@ -60,6 +60,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
// other arguments are set by the algorithm.
DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1)
.describe("Number of start root of trees.");
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
DMLC_DECLARE_FIELD(num_feature)
.describe("Number of features used in tree construction.");
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
@ -83,7 +84,7 @@ struct RTreeNodeStat {
/*! \brief weight of current node */
bst_float base_weight;
/*! \brief number of children that are leaf nodes, known up to now */
int leaf_child_cnt;
int leaf_child_cnt {0};
bool operator==(const RTreeNodeStat& b) const {
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
@ -98,6 +99,7 @@ class RegTree : public Model {
public:
/*! \brief auxiliary statistics of node to help tree building */
using SplitCondT = bst_float;
static constexpr int32_t kInvalidNodeId {-1};
/*! \brief tree node */
class Node {
public:
@ -106,6 +108,12 @@ class RegTree : public Model {
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
"Node: 64 bit align");
}
/*!
 * \brief Construct a node from explicit links and a split description.
 * \param cleft value stored as the left child index (cleft_)
 * \param cright value stored as the right child index (cright_)
 * \param parent value stored verbatim in parent_; no bit manipulation is done here
 * \param split_ind forwarded unchanged to SetSplit()
 * \param split_cond forwarded unchanged to SetSplit()
 * \param default_left forwarded unchanged to SetSplit()
 */
Node(int32_t cleft, int32_t cright, int32_t parent,
uint32_t split_ind, float split_cond, bool default_left) :
parent_{parent}, cleft_{cleft}, cright_{cright} {
// The split fields are not initialized directly; they are filled in by SetSplit().
this->SetSplit(split_ind, split_cond, default_left);
}
/*! \brief index of left child */
XGBOOST_DEVICE int LeftChild() const {
return this->cleft_;
@ -219,11 +227,11 @@ class RegTree : public Model {
};
// pointer to parent, highest bit is used to
// indicate whether it's a left child or not
int parent_;
int32_t parent_{kInvalidNodeId};
// pointer to left, right
int cleft_, cright_;
int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
// split feature index, left split or right split depends on the highest bit
unsigned sindex_{0};
uint32_t sindex_{0};
// extra info
Info info_;
};

View File

@ -5,7 +5,7 @@
* This plugin defines the additional metric function.
*/
#include <xgboost/base.h>
#include <dmlc/parameter.h>
#include <xgboost/parameter.h>
#include <xgboost/objective.h>
#include <xgboost/json.h>
@ -16,7 +16,7 @@ namespace obj {
// You do not have to use it.
// see http://dmlc-core.readthedocs.org/en/latest/parameter.html
// for introduction of this module.
struct MyLogisticParam : public dmlc::Parameter<MyLogisticParam> {
struct MyLogisticParam : public XGBoostParameter<MyLogisticParam> {
float scale_neg_weight;
// declare parameters
DMLC_DECLARE_PARAMETER(MyLogisticParam) {
@ -32,7 +32,7 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam);
class MyLogistic : public ObjFunction {
public:
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,

View File

@ -3,10 +3,11 @@
* \file sparse_page_lz4_format.cc
* XGBoost Plugin to enable LZ4 compressed format on the external memory pages.
*/
#include <dmlc/registry.h>
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <dmlc/registry.h>
#include <dmlc/parameter.h>
#include <xgboost/parameter.h>
#include <lz4.h>
#include <lz4hc.h>
#include "../../src/data/sparse_page_writer.h"

View File

@ -1,6 +1,6 @@
# coding: utf-8
# pylint: disable= invalid-name, unused-import
"""For compatibility"""
"""For compatibility and optional dependencies."""
from __future__ import absolute_import
@ -16,22 +16,22 @@ if PY3:
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = (str,)
def py_str(x):
"""convert c string back to python string"""
return x.decode('utf-8')
else:
STRING_TYPES = (basestring,) # pylint: disable=undefined-variable
def py_str(x):
"""convert c string back to python string"""
return x
########################################################################################
###############################################################################
# START NUMPY PATHLIB ATTRIBUTION
########################################################################################
# os.PathLike compatibility used in Numpy: https://github.com/numpy/numpy/tree/v1.17.0
###############################################################################
# os.PathLike compatibility used in Numpy:
# https://github.com/numpy/numpy/tree/v1.17.0
# Attribution:
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/compat/py3k.py#L188-L247
# Backport os.fs_path, os.PathLike, and PurePath.__fspath__
@ -56,7 +56,6 @@ else:
return True
return hasattr(subclass, '__fspath__')
def os_fspath(path):
"""Return the path representation of a path-like object.
If str or bytes is passed in, it is returned unchanged. Otherwise the
@ -84,9 +83,9 @@ else:
raise TypeError("expected {}.__fspath__() to return str or bytes, "
"not {}".format(path_type.__name__,
type(path_repr).__name__))
########################################################################################
###############################################################################
# END NUMPY PATHLIB ATTRIBUTION
########################################################################################
###############################################################################
# pickle
try:

View File

@ -12,6 +12,8 @@
#include <xgboost/learner.h>
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <dmlc/timer.h>
#include <iomanip>
#include <ctime>
@ -30,7 +32,7 @@ enum CLITask {
kPredict = 2
};
struct CLIParam : public dmlc::Parameter<CLIParam> {
struct CLIParam : public XGBoostParameter<CLIParam> {
/*! \brief the task name */
int task;
/*! \brief whether evaluate training statistics */
@ -123,7 +125,7 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
// customized configure function of CLIParam
inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
this->cfg = _cfg;
this->InitAllowUnknown(_cfg);
this->UpdateAllowUnknown(_cfg);
for (const auto& kv : _cfg) {
if (!strncmp("eval[", kv.first.c_str(), 5)) {
char evname[256];

View File

@ -25,7 +25,7 @@ namespace gbm {
DMLC_REGISTRY_FILE_TAG(gblinear);
// training parameters
struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
std::string updater;
float tolerance;
size_t max_row_perbatch;
@ -64,7 +64,7 @@ class GBLinear : public GradientBooster {
if (model_.weight.size() == 0) {
model_.param.InitAllowUnknown(cfg);
}
param_.InitAllowUnknown(cfg);
param_.UpdateAllowUnknown(cfg);
updater_.reset(LinearUpdater::Create(param_.updater, learner_param_));
updater_->Configure(cfg);
monitor_.Init("GBLinear");

View File

@ -34,7 +34,7 @@ DMLC_REGISTRY_FILE_TAG(gbtree);
void GBTree::Configure(const Args& cfg) {
this->cfg_ = cfg;
tparam_.InitAllowUnknown(cfg);
tparam_.UpdateAllowUnknown(cfg);
model_.Configure(cfg);
@ -295,7 +295,7 @@ class Dart : public GBTree {
void Configure(const Args& cfg) override {
GBTree::Configure(cfg);
if (model_.trees.size() == 0) {
dparam_.InitAllowUnknown(cfg);
dparam_.UpdateAllowUnknown(cfg);
}
}

View File

@ -48,7 +48,7 @@ namespace xgboost {
namespace gbm {
/*! \brief training parameters */
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
/*!
* \brief number of parallel trees constructed each iteration
* use this option to support boosted random forest
@ -95,7 +95,7 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
};
/*! \brief training parameters */
struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
/*! \brief type of sampling algorithm */
int sample_type;
/*! \brief type of normalization algorithm */

View File

@ -5,13 +5,8 @@
* \author Tianqi Chen
*/
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <dmlc/any.h>
#include <xgboost/feature_map.h>
#include <xgboost/learner.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <xgboost/generic_parameters.h>
#include <dmlc/parameter.h>
#include <algorithm>
#include <iomanip>
#include <limits>
@ -21,6 +16,12 @@
#include <utility>
#include <vector>
#include "xgboost/feature_map.h"
#include "xgboost/learner.h"
#include "xgboost/base.h"
#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "common/common.h"
#include "common/io.h"
@ -103,7 +104,7 @@ struct LearnerModelParam : public dmlc::Parameter<LearnerModelParam> {
}
};
struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none.
DataSplitMode dsplit;
// flag to disable default metric
@ -155,9 +156,9 @@ class LearnerImpl : public Learner {
auto old_tparam = tparam_;
Args args = {cfg_.cbegin(), cfg_.cend()};
tparam_.InitAllowUnknown(args);
tparam_.UpdateAllowUnknown(args);
generic_param_.InitAllowUnknown(args);
generic_param_.UpdateAllowUnknown(args);
generic_param_.CheckDeprecated();
ConsoleLogger::Configure(args);
@ -208,7 +209,7 @@ class LearnerImpl : public Learner {
}
void Load(dmlc::Stream* fi) override {
generic_param_.InitAllowUnknown(Args{});
generic_param_.UpdateAllowUnknown(Args{});
tparam_.Init(std::vector<std::pair<std::string, std::string>>{});
// TODO(tqchen) mark deprecation of old format.
common::PeekableInStream fp(fi);
@ -314,7 +315,7 @@ class LearnerImpl : public Learner {
cfg_.insert(n.cbegin(), n.cend());
Args args = {cfg_.cbegin(), cfg_.cend()};
generic_param_.InitAllowUnknown(args);
generic_param_.UpdateAllowUnknown(args);
gbm_->Configure(args);
obj_->Configure({cfg_.begin(), cfg_.end()});

View File

@ -10,6 +10,7 @@
#include <limits>
#include "xgboost/data.h"
#include "xgboost/parameter.h"
#include "./param.h"
#include "../gbm/gblinear_model.h"
#include "../common/random.h"
@ -17,7 +18,7 @@
namespace xgboost {
namespace linear {
struct CoordinateParam : public dmlc::Parameter<CoordinateParam> {
struct CoordinateParam : public XGBoostParameter<CoordinateParam> {
int top_k;
DMLC_DECLARE_PARAMETER(CoordinateParam) {
DMLC_DECLARE_FIELD(top_k)

View File

@ -5,7 +5,7 @@
*/
#ifndef XGBOOST_LINEAR_PARAM_H_
#define XGBOOST_LINEAR_PARAM_H_
#include <dmlc/parameter.h>
#include "xgboost/parameter.h"
namespace xgboost {
namespace linear {
@ -20,7 +20,7 @@ enum FeatureSelectorEnum {
kRandom
};
struct LinearTrainParam : public dmlc::Parameter<LinearTrainParam> {
struct LinearTrainParam : public XGBoostParameter<LinearTrainParam> {
/*! \brief learning_rate */
float learning_rate;
/*! \brief regularization weight for L2 norm */

View File

@ -26,9 +26,9 @@ class CoordinateUpdater : public LinearUpdater {
// set training parameter
void Configure(Args const& args) override {
const std::vector<std::pair<std::string, std::string> > rest {
tparam_.InitAllowUnknown(args)
tparam_.UpdateAllowUnknown(args)
};
cparam_.InitAllowUnknown(rest);
cparam_.UpdateAllowUnknown(rest);
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
monitor_.Init("CoordinateUpdater");
}

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2018 by Contributors
* Copyright 2018-2019 by Contributors
* \author Rory Mitchell
*/
@ -36,7 +36,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// set training parameter
void Configure(Args const& args) override {
tparam_.InitAllowUnknown(args);
tparam_.UpdateAllowUnknown(args);
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
monitor_.Init("GPUCoordinateUpdater");
}

View File

@ -15,7 +15,7 @@ class ShotgunUpdater : public LinearUpdater {
public:
// set training parameter
void Configure(Args const& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
if (param_.feature_selector != kCyclic &&
param_.feature_selector != kShuffle) {
LOG(FATAL) << "Unsupported feature selector for shotgun updater.\n"

View File

@ -5,12 +5,13 @@
* \author Tianqi Chen
*/
#include <rabit/rabit.h>
#include <dmlc/parameter.h>
#include <xgboost/logging.h>
#include <iostream>
#include <map>
#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
void dmlc::CustomLogMessage::Log(const std::string& msg) {
@ -51,7 +52,7 @@ bool ConsoleLogger::ShouldLog(LogVerbosity verbosity) {
}
void ConsoleLogger::Configure(Args const& args) {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
// Deprecated, but when trying to display deprecation message some R
// tests trying to catch stdout will fail.
if (param_.silent) {

View File

@ -327,7 +327,6 @@ struct EvalEWiseBase : public Metric {
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label and prediction size not match, "
<< "hint: use merror or mlogloss for multi-class classification";
const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
int device = tparam_->gpu_id;
auto result =

View File

@ -5,7 +5,6 @@
#ifndef XGBOOST_METRIC_METRIC_COMMON_H_
#define XGBOOST_METRIC_METRIC_COMMON_H_
#include <dmlc/parameter.h>
#include "../common/common.h"
namespace xgboost {

View File

@ -172,7 +172,6 @@ struct EvalMClassBase : public Metric {
CHECK_GE(nclass, 1U)
<< "mlogloss and merror are only used for multi-class classification,"
<< " use logloss for binary classification";
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
int device = tparam_->gpu_id;
auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);

View File

@ -5,16 +5,18 @@
* \author Tianqi Chen
*/
#include <dmlc/omp.h>
#include <dmlc/parameter.h>
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <algorithm>
#include <limits>
#include <utility>
#include "xgboost/parameter.h"
#include "xgboost/data.h"
#include "xgboost/logging.h"
#include "xgboost/objective.h"
#include "xgboost/json.h"
#include "../common/common.h"
#include "../common/math.h"
#include "../common/transform.h"

View File

@ -1,13 +1,13 @@
/*!
* Copyright 2017-2018 by Contributors
*/
#include <dmlc/parameter.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <memory>
#include "xgboost/parameter.h"
#include "xgboost/data.h"
#include "xgboost/predictor.h"
#include "xgboost/tree_model.h"

View File

@ -1,4 +1,3 @@
/*!
* Copyright 2017-2019 XGBoost contributors
*/

View File

@ -7,20 +7,20 @@
#ifndef XGBOOST_TREE_PARAM_H_
#define XGBOOST_TREE_PARAM_H_
#include <dmlc/parameter.h>
#include <xgboost/data.h>
#include <cmath>
#include <cstring>
#include <limits>
#include <string>
#include <vector>
#include "xgboost/parameter.h"
#include "xgboost/data.h"
namespace xgboost {
namespace tree {
/*! \brief training parameters for regression tree */
struct TrainParam : public dmlc::Parameter<TrainParam> {
struct TrainParam : public XGBoostParameter<TrainParam> {
// learning step size for a time
float learning_rate;
// minimum loss change required for a split

View File

@ -5,6 +5,7 @@
*/
#include <dmlc/json.h>
#include <dmlc/registry.h>
#include <algorithm>
#include <unordered_set>
#include <vector>
@ -15,7 +16,8 @@
#include <utility>
#include "xgboost/logging.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/parameter.h"
#include "param.h"
#include "split_evaluator.h"
#include "../common/common.h"
@ -67,7 +69,7 @@ bool SplitEvaluator::CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid)
}
//! \brief Encapsulates the parameters for ElasticNet
struct ElasticNetParams : public dmlc::Parameter<ElasticNetParams> {
struct ElasticNetParams : public XGBoostParameter<ElasticNetParams> {
bst_float reg_lambda;
bst_float reg_alpha;
// maximum delta update we can add in weight estimation
@ -105,7 +107,7 @@ class ElasticNet final : public SplitEvaluator {
}
}
void Init(const Args& args) override {
params_.InitAllowUnknown(args);
params_.UpdateAllowUnknown(args);
}
SplitEvaluator* GetHostClone() const override {
@ -185,7 +187,7 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
split evaluator
*/
struct MonotonicConstraintParams
: public dmlc::Parameter<MonotonicConstraintParams> {
: public XGBoostParameter<MonotonicConstraintParams> {
std::vector<bst_int> monotone_constraints;
DMLC_DECLARE_PARAMETER(MonotonicConstraintParams) {
@ -212,7 +214,7 @@ class MonotonicConstraint final : public SplitEvaluator {
void Init(const Args& args)
override {
inner_->Init(args);
params_.InitAllowUnknown(args);
params_.UpdateAllowUnknown(args);
Reset();
}
@ -337,7 +339,7 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
split evaluator
*/
struct InteractionConstraintParams
: public dmlc::Parameter<InteractionConstraintParams> {
: public XGBoostParameter<InteractionConstraintParams> {
std::string interaction_constraints;
bst_uint num_feature;
@ -371,7 +373,7 @@ class InteractionConstraint final : public SplitEvaluator {
void Init(const Args& args)
override {
inner_->Init(args);
params_.InitAllowUnknown(args);
params_.UpdateAllowUnknown(args);
Reset();
}

View File

@ -423,7 +423,7 @@ XGBOOST_REGISTER_TREE_IO(JsonGenerator, "json")
return new JsonGenerator(fmap, attrs, with_stats);
});
struct GraphvizParam : public dmlc::Parameter<GraphvizParam> {
struct GraphvizParam : public XGBoostParameter<GraphvizParam> {
std::string yes_color;
std::string no_color;
std::string rankdir;
@ -462,7 +462,7 @@ class GraphvizGenerator : public TreeGenerator {
public:
GraphvizGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {
param_.InitAllowUnknown(std::map<std::string, std::string>{});
param_.UpdateAllowUnknown(std::map<std::string, std::string>{});
using KwArg = std::map<std::string, std::map<std::string, std::string>>;
KwArg kwargs;
if (attrs.length() != 0) {

View File

@ -31,7 +31,7 @@ namespace tree {
class BaseMaker: public TreeUpdater {
public:
void Configure(const Args& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
}
protected:

View File

@ -26,7 +26,7 @@ DMLC_REGISTRY_FILE_TAG(updater_colmaker);
class ColMaker: public TreeUpdater {
public:
void Configure(const Args& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
spliteval_->Init(args);
}
@ -773,7 +773,7 @@ class ColMaker: public TreeUpdater {
class DistColMaker : public ColMaker {
public:
void Configure(const Args& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
pruner_.reset(TreeUpdater::Create("prune", tparam_));
pruner_->Configure(args);
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));

View File

@ -16,6 +16,7 @@
#include <vector>
#include "xgboost/host_device_vector.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "../common/common.h"
@ -38,7 +39,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
// training parameters specific to this algorithm
struct GPUHistMakerTrainParam
: public dmlc::Parameter<GPUHistMakerTrainParam> {
: public XGBoostParameter<GPUHistMakerTrainParam> {
bool single_precision_histogram;
// number of rows in a single GPU batch
int gpu_batch_nrows;
@ -969,9 +970,9 @@ class GPUHistMakerSpecialised {
public:
GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {}
void Configure(const Args& args, GenericParameter const* generic_param) {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
generic_param_ = generic_param;
hist_maker_param_.InitAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);
device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device";
@ -1107,7 +1108,7 @@ class GPUHistMakerSpecialised {
class GPUHistMaker : public TreeUpdater {
public:
void Configure(const Args& args) override {
hist_maker_param_.InitAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);
float_maker_.reset();
double_maker_.reset();
if (hist_maker_param_.single_precision_histogram) {

View File

@ -30,7 +30,7 @@ class TreePruner: public TreeUpdater {
// set training parameter
void Configure(const Args& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
syncher_->Configure(args);
}
// update the tree, do pruning

View File

@ -38,7 +38,7 @@ void QuantileHistMaker::Configure(const Args& args) {
pruner_.reset(TreeUpdater::Create("prune", tparam_));
}
pruner_->Configure(args);
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
is_gmat_initialized_ = false;
// initialize the split evaluator

View File

@ -22,7 +22,7 @@ DMLC_REGISTRY_FILE_TAG(updater_refresh);
class TreeRefresher: public TreeUpdater {
public:
void Configure(const Args& args) override {
param_.InitAllowUnknown(args);
param_.UpdateAllowUnknown(args);
}
char const* Name() const override {
return "refresh";

View File

@ -26,17 +26,22 @@ def call(args):
class ClangTidy(object):
'''
clang tidy wrapper.
''' clang tidy wrapper.
Args:
cpp_lint: Run linter on C++ source code.
cuda_lint: Run linter on CUDA source code.
args: Command line arguments.
cpp_lint: Run linter on C++ source code.
cuda_lint: Run linter on CUDA source code.
use_dmlc_gtest: Whether to use gtest bundled in dmlc-core.
'''
def __init__(self, cpp_lint, cuda_lint):
self.cpp_lint = cpp_lint
self.cuda_lint = cuda_lint
def __init__(self, args):
self.cpp_lint = args.cpp
self.cuda_lint = args.cuda
self.use_dmlc_gtest = args.use_dmlc_gtest
print('Run linter on CUDA: ', self.cuda_lint)
print('Run linter on C++:', self.cpp_lint)
print('Use dmlc gtest:', self.use_dmlc_gtest)
if not self.cpp_lint and not self.cuda_lint:
raise ValueError('Both --cpp and --cuda are set to 0.')
self.root_path = os.path.abspath(os.path.curdir)
@ -58,7 +63,12 @@ class ClangTidy(object):
os.mkdir(self.cdb_path)
os.chdir(self.cdb_path)
cmake_args = ['cmake', '..', '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
'-DGOOGLE_TEST=ON', '-DUSE_DMLC_GTEST=ON']
'-DGOOGLE_TEST=ON']
if self.use_dmlc_gtest:
cmake_args.append('-DUSE_DMLC_GTEST=ON')
else:
cmake_args.append('-DUSE_DMLC_GTEST=OFF')
if self.cuda_lint:
cmake_args.extend(['-DUSE_CUDA=ON', '-DUSE_NCCL=ON'])
subprocess.run(cmake_args)
@ -234,11 +244,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run clang-tidy.')
parser.add_argument('--cpp', type=int, default=1)
parser.add_argument('--cuda', type=int, default=1)
parser.add_argument('--use-dmlc-gtest', type=int, default=1,
help='Whether to use gtest bundled in dmlc-core.')
args = parser.parse_args()
test_tidy()
with ClangTidy(args.cpp, args.cuda) as linter:
with ClangTidy(args) as linter:
passed = linter.run()
if not passed:
sys.exit(1)

View File

@ -16,6 +16,8 @@ TEST(Monitor, Logging) {
Args args = {std::make_pair("verbosity", "3")};
ConsoleLogger::Configure(args);
ASSERT_EQ(ConsoleLogger::GlobalVerbosity(), ConsoleLogger::LogVerbosity::kDebug);
testing::internal::CaptureStderr();
run_monitor();
std::string output = testing::internal::GetCapturedStderr();
@ -28,6 +30,8 @@ TEST(Monitor, Logging) {
run_monitor();
output = testing::internal::GetCapturedStderr();
ASSERT_EQ(output.size(), 0);
ConsoleLogger::Configure(Args{{"verbosity", "1"}});
}
} // namespace common
} // namespace xgboost

View File

@ -1,5 +1,9 @@
#include <dmlc/parameter.h>
/*!
* Copyright (c) by Contributors 2019
*/
#include <gtest/gtest.h>
#include <xgboost/base.h>
#include <xgboost/parameter.h>
enum class Foo : int {
@ -8,10 +12,10 @@ enum class Foo : int {
DECLARE_FIELD_ENUM_CLASS(Foo);
struct MyParam : dmlc::Parameter<MyParam> {
struct MyEnumParam : xgboost::XGBoostParameter<MyEnumParam> {
Foo foo;
int bar;
DMLC_DECLARE_PARAMETER(MyParam) {
DMLC_DECLARE_PARAMETER(MyEnumParam) {
DMLC_DECLARE_FIELD(foo)
.set_default(Foo::kBar)
.add_enum("bar", Foo::kBar)
@ -23,10 +27,10 @@ struct MyParam : dmlc::Parameter<MyParam> {
}
};
DMLC_REGISTER_PARAMETER(MyParam);
DMLC_REGISTER_PARAMETER(MyEnumParam);
TEST(EnumClassParam, Basic) {
MyParam param;
MyEnumParam param;
std::map<std::string, std::string> kwargs{
{"foo", "frog"}, {"bar", "10"}
};
@ -53,3 +57,44 @@ TEST(EnumClassParam, Basic) {
kwargs["foo"] = "human";
ASSERT_THROW(param.Init(kwargs), dmlc::ParamError);
}
/*!
 * \brief Minimal parameter struct used to exercise UpdateAllowUnknown.
 *
 * Both fields carry an in-class initializer of zero and a different declared
 * default, so the test can tell "never configured" apart from "defaults applied".
 */
struct UpdatableParam : xgboost::XGBoostParameter<UpdatableParam> {
float f { 0.0f };
double d { 0.0 };
DMLC_DECLARE_PARAMETER(UpdatableParam) {
DMLC_DECLARE_FIELD(f)
.set_default(11.0f);
// Use a double literal: `d` is a double, and a float literal would be rounded
// to float precision before being widened.
DMLC_DECLARE_FIELD(d)
.set_default(2.71828);
}
};
DMLC_REGISTER_PARAMETER(UpdatableParam);
// Checks UpdateAllowUnknown on UpdatableParam in two scenarios:
//  1. an empty first call followed by a partial update, and
//  2. a first call that already supplies one of the fields.
TEST(XGBoostParameter, Update) {
{
UpdatableParam p;
auto constexpr kRtEps = xgboost::kRtEps;
p.UpdateAllowUnknown(xgboost::Args{});
// When it's not initialized, perform set_default.
ASSERT_NEAR(p.f, 11.0f, kRtEps);
ASSERT_NEAR(p.d, 2.71828f, kRtEps);
// Overwrite d by hand so a later reset to its default would be detectable.
p.d = 3.14149;
// Second call names only `f`.
p.UpdateAllowUnknown(xgboost::Args{{"f", "2.71828"}});
ASSERT_NEAR(p.f, 2.71828f, kRtEps);
// p.d is unaffected by the update.
ASSERT_NEAR(p.d, 3.14149, kRtEps);
}
{
UpdatableParam p;
auto constexpr kRtEps = xgboost::kRtEps;
// First call on a fresh object that supplies `f` only.
p.UpdateAllowUnknown(xgboost::Args{{"f", "2.71828"}});
ASSERT_NEAR(p.f, 2.71828f, kRtEps);
// `d` was not supplied, so it is expected to hold its declared default.
ASSERT_NEAR(p.d, 2.71828, kRtEps); // default
}
}

View File

@ -11,7 +11,7 @@ TEST(GBTree, SelectTreeMethod) {
size_t constexpr kCols = 10;
GenericParameter generic_param;
generic_param.InitAllowUnknown(Args{});
generic_param.UpdateAllowUnknown(Args{});
std::unique_ptr<GradientBooster> p_gbm{
GradientBooster::Create("gbtree", &generic_param, {}, 0)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
@ -36,7 +36,7 @@ TEST(GBTree, SelectTreeMethod) {
ASSERT_EQ(tparam.predictor, "cpu_predictor");
#ifdef XGBOOST_USE_CUDA
generic_param.InitAllowUnknown(Args{{"gpu_id", "0"}});
generic_param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
gbtree.Configure({{"tree_method", "gpu_hist"}, {"num_feature", n_feat}});
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
ASSERT_EQ(tparam.predictor, "gpu_predictor");
@ -64,7 +64,7 @@ TEST(GBTree, ChoosePredictor) {
std::string n_feat = std::to_string(kCols);
Args args {{"tree_method", "approx"}, {"num_feature", n_feat}};
GenericParameter generic_param;
generic_param.InitAllowUnknown(Args{{"gpu_id", "0"}});
generic_param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
auto& data = (*(p_mat->GetBatches<SparsePage>().begin())).data;

View File

@ -9,7 +9,7 @@ TEST(Objective, UnknownFunction) {
xgboost::ObjFunction* obj = nullptr;
xgboost::GenericParameter tparam;
std::vector<std::pair<std::string, std::string>> args;
tparam.InitAllowUnknown(args);
tparam.UpdateAllowUnknown(args);
EXPECT_ANY_THROW(obj = xgboost::ObjFunction::Create("unknown_name", &tparam));
EXPECT_NO_THROW(obj = xgboost::ObjFunction::Create("reg:squarederror", &tparam));

View File

@ -38,7 +38,7 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) {
TEST(Objective, DeclareUnifiedTest(NDCG_Json_IO)) {
xgboost::GenericParameter tparam;
tparam.InitAllowUnknown(Args{});
tparam.UpdateAllowUnknown(Args{});
std::unique_ptr<xgboost::ObjFunction> obj {
xgboost::ObjFunction::Create("rank:ndcg", &tparam)

View File

@ -53,6 +53,7 @@ TEST(Logging, Basic) {
output = testing::internal::GetCapturedStderr();
ASSERT_NE(output.find("Test Log Console"), std::string::npos);
args["silent"] = "False";
args["verbosity"] = "1"; // restore
ConsoleLogger::Configure({args.cbegin(), args.cend()});
}

View File

@ -33,7 +33,7 @@ TEST(Updater, Prune) {
// prepare tree
RegTree tree = RegTree();
tree.param.InitAllowUnknown(cfg);
tree.param.UpdateAllowUnknown(cfg);
std::vector<RegTree*> trees {&tree};
// prepare pruner
std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &lparam));

View File

@ -262,7 +262,7 @@ class QuantileHistMock : public QuantileHistMaker {
gmat.Init((*dmat_).get(), kMaxBins);
RegTree tree = RegTree();
tree.param.InitAllowUnknown(cfg_);
tree.param.UpdateAllowUnknown(cfg_);
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
@ -273,7 +273,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestBuildHist() {
RegTree tree = RegTree();
tree.param.InitAllowUnknown(cfg_);
tree.param.UpdateAllowUnknown(cfg_);
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
@ -284,7 +284,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestEvaluateSplit() {
RegTree tree = RegTree();
tree.param.InitAllowUnknown(cfg_);
tree.param.UpdateAllowUnknown(cfg_);
builder_->TestEvaluateSplit(gmatb_, tree);
}

View File

@ -28,7 +28,7 @@ TEST(Updater, Refresh) {
RegTree tree = RegTree();
auto lparam = CreateEmptyGenericParam(GPUIDX);
tree.param.InitAllowUnknown(cfg);
tree.param.UpdateAllowUnknown(cfg);
std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));