Use `UpdateAllowUnknown' for non-model related parameter. (#4961)

* Use `UpdateAllowUnknown' for non-model related parameter.

Model parameter can not pack an additional boolean value due to binary IO
format.  This commit deals only with non-model related parameter configuration.

* Add tidy command line arg for use-dmlc-gtest.
This commit is contained in:
Jiaming Yuan
2019-10-23 05:50:12 -04:00
committed by GitHub
parent f24be2efb4
commit ac457c56a2
44 changed files with 189 additions and 112 deletions

View File

@@ -25,7 +25,7 @@ namespace gbm {
DMLC_REGISTRY_FILE_TAG(gblinear);
// training parameters
struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
std::string updater;
float tolerance;
size_t max_row_perbatch;
@@ -64,7 +64,7 @@ class GBLinear : public GradientBooster {
if (model_.weight.size() == 0) {
model_.param.InitAllowUnknown(cfg);
}
param_.InitAllowUnknown(cfg);
param_.UpdateAllowUnknown(cfg);
updater_.reset(LinearUpdater::Create(param_.updater, learner_param_));
updater_->Configure(cfg);
monitor_.Init("GBLinear");

View File

@@ -34,7 +34,7 @@ DMLC_REGISTRY_FILE_TAG(gbtree);
void GBTree::Configure(const Args& cfg) {
this->cfg_ = cfg;
tparam_.InitAllowUnknown(cfg);
tparam_.UpdateAllowUnknown(cfg);
model_.Configure(cfg);
@@ -295,7 +295,7 @@ class Dart : public GBTree {
void Configure(const Args& cfg) override {
GBTree::Configure(cfg);
if (model_.trees.size() == 0) {
dparam_.InitAllowUnknown(cfg);
dparam_.UpdateAllowUnknown(cfg);
}
}

View File

@@ -48,7 +48,7 @@ namespace xgboost {
namespace gbm {
/*! \brief training parameters */
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
/*!
* \brief number of parallel trees constructed each iteration
* use this option to support boosted random forest
@@ -95,7 +95,7 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
};
/*! \brief training parameters */
struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
/*! \brief type of sampling algorithm */
int sample_type;
/*! \brief type of normalization algorithm */