Refactor of FastHistMaker to allow for custom regularisation methods (#3335)
* Refactor to allow for custom regularisation methods * Implement compositional SplitEvaluator framework * Fixed segfault when no monotone_constraints are supplied. * Change pid to parentID * test_monotone_constraints.py now passes * Refactor ColMaker and DistColMaker to use SplitEvaluator * Performance optimisation when no monotone_constraints specified * Fix linter messages * Fix a few more linter errors * Update the amalgamation * Add bounds check * Add check for leaf node * Fix linter error in param.h * Fix clang-tidy errors on CI * Fix incorrect function name * Fix clang-tidy error in updater_fast_hist.cc * Enable SSE2 for Win32 R MinGW Addresses https://github.com/dmlc/xgboost/pull/3335#issuecomment-400535752 * Add contributor
This commit is contained in:
parent
cafc621914
commit
64b8cffde3
@ -74,3 +74,4 @@ List of Contributors
|
|||||||
* [Yi-Lin Juang](https://github.com/frankyjuang)
|
* [Yi-Lin Juang](https://github.com/frankyjuang)
|
||||||
* [Andrew Hannigan](https://github.com/andrewhannigan)
|
* [Andrew Hannigan](https://github.com/andrewhannigan)
|
||||||
* [Andy Adinets](https://github.com/canonizer)
|
* [Andy Adinets](https://github.com/canonizer)
|
||||||
|
* [Henry Gouk](https://github.com/henrygouk)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\
|
|||||||
|
|
||||||
# disable the use of thread_local for 32 bit windows:
|
# disable the use of thread_local for 32 bit windows:
|
||||||
ifeq ($(R_OSTYPE)$(WIN),windows)
|
ifeq ($(R_OSTYPE)$(WIN),windows)
|
||||||
XGB_RFLAGS += -DDMLC_CXX11_THREAD_LOCAL=0
|
XGB_RFLAGS += -DDMLC_CXX11_THREAD_LOCAL=0 -msse2 -mfpmath=sse
|
||||||
endif
|
endif
|
||||||
$(foreach v, $(XGB_RFLAGS), $(warning $(v)))
|
$(foreach v, $(XGB_RFLAGS), $(warning $(v)))
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\
|
|||||||
|
|
||||||
# disable the use of thread_local for 32 bit windows:
|
# disable the use of thread_local for 32 bit windows:
|
||||||
ifeq ($(R_OSTYPE)$(WIN),windows)
|
ifeq ($(R_OSTYPE)$(WIN),windows)
|
||||||
XGB_RFLAGS += -DDMLC_CXX11_THREAD_LOCAL=0
|
XGB_RFLAGS += -DDMLC_CXX11_THREAD_LOCAL=0 -msse2 -mfpmath=sse
|
||||||
endif
|
endif
|
||||||
$(foreach v, $(XGB_RFLAGS), $(warning $(v)))
|
$(foreach v, $(XGB_RFLAGS), $(warning $(v)))
|
||||||
|
|
||||||
|
|||||||
@ -43,6 +43,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// tress
|
// tress
|
||||||
|
#include "../src/tree/split_evaluator.cc"
|
||||||
#include "../src/tree/tree_model.cc"
|
#include "../src/tree/tree_model.cc"
|
||||||
#include "../src/tree/tree_updater.cc"
|
#include "../src/tree/tree_updater.cc"
|
||||||
#include "../src/tree/updater_colmaker.cc"
|
#include "../src/tree/updater_colmaker.cc"
|
||||||
|
|||||||
@ -12,6 +12,7 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
@ -76,6 +77,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
int gpu_id;
|
int gpu_id;
|
||||||
// number of GPUs to use
|
// number of GPUs to use
|
||||||
int n_gpus;
|
int n_gpus;
|
||||||
|
// the criteria to use for ranking splits
|
||||||
|
std::string split_evaluator;
|
||||||
// declare the parameters
|
// declare the parameters
|
||||||
DMLC_DECLARE_PARAMETER(TrainParam) {
|
DMLC_DECLARE_PARAMETER(TrainParam) {
|
||||||
DMLC_DECLARE_FIELD(learning_rate)
|
DMLC_DECLARE_FIELD(learning_rate)
|
||||||
@ -183,7 +186,9 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
.set_lower_bound(-1)
|
.set_lower_bound(-1)
|
||||||
.set_default(1)
|
.set_default(1)
|
||||||
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
|
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
|
||||||
|
DMLC_DECLARE_FIELD(split_evaluator)
|
||||||
|
.set_default("monotonic")
|
||||||
|
.describe("The criteria to use for ranking splits");
|
||||||
// add alias of parameters
|
// add alias of parameters
|
||||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||||
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||||
|
|||||||
273
src/tree/split_evaluator.cc
Normal file
273
src/tree/split_evaluator.cc
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2018 by Contributors
|
||||||
|
* \file split_evaluator.cc
|
||||||
|
* \brief Contains implementations of different split evaluators.
|
||||||
|
*/
|
||||||
|
#include "split_evaluator.h"
|
||||||
|
#include <dmlc/registry.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include "param.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
#include "../common/host_device_vector.h"
|
||||||
|
|
||||||
|
#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
|
||||||
|
|
||||||
|
namespace dmlc {
|
||||||
|
DMLC_REGISTRY_ENABLE(::xgboost::tree::SplitEvaluatorReg);
|
||||||
|
} // namespace dmlc
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
|
||||||
|
SplitEvaluator* SplitEvaluator::Create(const std::string& name) {
|
||||||
|
auto* e = ::dmlc::Registry< ::xgboost::tree::SplitEvaluatorReg>
|
||||||
|
::Get()->Find(name);
|
||||||
|
if (e == nullptr) {
|
||||||
|
LOG(FATAL) << "Unknown SplitEvaluator " << name;
|
||||||
|
}
|
||||||
|
return (e->body)();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default implementations of some virtual methods that aren't always needed
|
||||||
|
void SplitEvaluator::Init(
|
||||||
|
const std::vector<std::pair<std::string, std::string> >& args) {}
|
||||||
|
void SplitEvaluator::Reset() {}
|
||||||
|
void SplitEvaluator::AddSplit(bst_uint nodeid,
|
||||||
|
bst_uint leftid,
|
||||||
|
bst_uint rightid,
|
||||||
|
bst_uint featureid,
|
||||||
|
bst_float leftweight,
|
||||||
|
bst_float rightweight) {}
|
||||||
|
|
||||||
|
//! \brief Encapsulates the parameters for by the RidgePenalty
|
||||||
|
struct RidgePenaltyParams : public dmlc::Parameter<RidgePenaltyParams> {
|
||||||
|
float reg_lambda;
|
||||||
|
float reg_gamma;
|
||||||
|
|
||||||
|
DMLC_DECLARE_PARAMETER(RidgePenaltyParams) {
|
||||||
|
DMLC_DECLARE_FIELD(reg_lambda)
|
||||||
|
.set_lower_bound(0.0)
|
||||||
|
.set_default(1.0)
|
||||||
|
.describe("L2 regularization on leaf weight");
|
||||||
|
DMLC_DECLARE_FIELD(reg_gamma)
|
||||||
|
.set_lower_bound(0.0f)
|
||||||
|
.set_default(0.0f)
|
||||||
|
.describe("Cost incurred by adding a new leaf node to the tree");
|
||||||
|
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||||
|
DMLC_DECLARE_ALIAS(reg_gamma, gamma);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(RidgePenaltyParams);
|
||||||
|
|
||||||
|
/*! \brief Applies an L2 penalty and per-leaf penalty. */
|
||||||
|
class RidgePenalty final : public SplitEvaluator {
|
||||||
|
public:
|
||||||
|
void Init(
|
||||||
|
const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
|
params_.InitAllowUnknown(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
SplitEvaluator* GetHostClone() const override {
|
||||||
|
auto r = new RidgePenalty();
|
||||||
|
r->params_ = this->params_;
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float ComputeSplitScore(bst_uint nodeid,
|
||||||
|
bst_uint featureid,
|
||||||
|
const GradStats& left,
|
||||||
|
const GradStats& right) const override {
|
||||||
|
// parentID is not needed for this split evaluator. Just use 0.
|
||||||
|
return ComputeScore(0, left) + ComputeScore(0, right);
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float ComputeScore(bst_uint parentID, const GradStats& stats)
|
||||||
|
const override {
|
||||||
|
return (stats.sum_grad * stats.sum_grad)
|
||||||
|
/ (stats.sum_hess + params_.reg_lambda) - params_.reg_gamma;
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
|
||||||
|
const override {
|
||||||
|
return -stats.sum_grad / (stats.sum_hess + params_.reg_lambda);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
RidgePenaltyParams params_;
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_SPLIT_EVALUATOR(RidgePenalty, "ridge")
|
||||||
|
.describe("Use an L2 penalty term for the weights and a cost per leaf node")
|
||||||
|
.set_body([]() {
|
||||||
|
return new RidgePenalty();
|
||||||
|
});
|
||||||
|
|
||||||
|
/*! \brief Encapsulates the parameters required by the MonotonicConstraint
|
||||||
|
split evaluator
|
||||||
|
*/
|
||||||
|
struct MonotonicConstraintParams
|
||||||
|
: public dmlc::Parameter<MonotonicConstraintParams> {
|
||||||
|
std::vector<bst_int> monotone_constraints;
|
||||||
|
float reg_lambda;
|
||||||
|
float reg_gamma;
|
||||||
|
|
||||||
|
DMLC_DECLARE_PARAMETER(MonotonicConstraintParams) {
|
||||||
|
DMLC_DECLARE_FIELD(reg_lambda)
|
||||||
|
.set_lower_bound(0.0)
|
||||||
|
.set_default(1.0)
|
||||||
|
.describe("L2 regularization on leaf weight");
|
||||||
|
DMLC_DECLARE_FIELD(reg_gamma)
|
||||||
|
.set_lower_bound(0.0f)
|
||||||
|
.set_default(0.0f)
|
||||||
|
.describe("Cost incurred by adding a new leaf node to the tree");
|
||||||
|
DMLC_DECLARE_FIELD(monotone_constraints)
|
||||||
|
.set_default(std::vector<bst_int>())
|
||||||
|
.describe("Constraint of variable monotonicity");
|
||||||
|
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||||
|
DMLC_DECLARE_ALIAS(reg_gamma, gamma);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(MonotonicConstraintParams);
|
||||||
|
|
||||||
|
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
|
||||||
|
features.
|
||||||
|
*/
|
||||||
|
class MonotonicConstraint final : public SplitEvaluator {
|
||||||
|
public:
|
||||||
|
void Init(const std::vector<std::pair<std::string, std::string> >& args)
|
||||||
|
override {
|
||||||
|
params_.InitAllowUnknown(args);
|
||||||
|
Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset() override {
|
||||||
|
lower_.resize(1, -std::numeric_limits<bst_float>::max());
|
||||||
|
upper_.resize(1, std::numeric_limits<bst_float>::max());
|
||||||
|
}
|
||||||
|
|
||||||
|
SplitEvaluator* GetHostClone() const override {
|
||||||
|
if (params_.monotone_constraints.size() == 0) {
|
||||||
|
// No monotone constraints specified, make a RidgePenalty evaluator
|
||||||
|
using std::pair;
|
||||||
|
using std::string;
|
||||||
|
using std::to_string;
|
||||||
|
using std::vector;
|
||||||
|
auto c = new RidgePenalty();
|
||||||
|
vector<pair<string, string> > args;
|
||||||
|
args.emplace_back(
|
||||||
|
pair<string, string>("reg_lambda", to_string(params_.reg_lambda)));
|
||||||
|
args.emplace_back(
|
||||||
|
pair<string, string>("reg_gamma", to_string(params_.reg_gamma)));
|
||||||
|
c->Init(args);
|
||||||
|
c->Reset();
|
||||||
|
return c;
|
||||||
|
} else {
|
||||||
|
auto c = new MonotonicConstraint();
|
||||||
|
c->params_ = this->params_;
|
||||||
|
c->Reset();
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float ComputeSplitScore(bst_uint nodeid,
|
||||||
|
bst_uint featureid,
|
||||||
|
const GradStats& left,
|
||||||
|
const GradStats& right) const override {
|
||||||
|
bst_float infinity = std::numeric_limits<bst_float>::infinity();
|
||||||
|
bst_int constraint = GetConstraint(featureid);
|
||||||
|
|
||||||
|
bst_float score = ComputeScore(nodeid, left) + ComputeScore(nodeid, right);
|
||||||
|
bst_float leftweight = ComputeWeight(nodeid, left);
|
||||||
|
bst_float rightweight = ComputeWeight(nodeid, right);
|
||||||
|
|
||||||
|
if (constraint == 0) {
|
||||||
|
return score;
|
||||||
|
} else if (constraint > 0) {
|
||||||
|
return leftweight <= rightweight ? score : -infinity;
|
||||||
|
} else {
|
||||||
|
return leftweight >= rightweight ? score : -infinity;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float ComputeScore(bst_uint parentID, const GradStats& stats)
|
||||||
|
const override {
|
||||||
|
bst_float w = ComputeWeight(parentID, stats);
|
||||||
|
|
||||||
|
return -(2.0 * stats.sum_grad * w + (stats.sum_hess + params_.reg_lambda)
|
||||||
|
* w * w);
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
|
||||||
|
const override {
|
||||||
|
bst_float weight = -stats.sum_grad / (stats.sum_hess + params_.reg_lambda);
|
||||||
|
|
||||||
|
if (parentID == ROOT_PARENT_ID) {
|
||||||
|
// This is the root node
|
||||||
|
return weight;
|
||||||
|
} else if (weight < lower_.at(parentID)) {
|
||||||
|
return lower_.at(parentID);
|
||||||
|
} else if (weight > upper_.at(parentID)) {
|
||||||
|
return upper_.at(parentID);
|
||||||
|
} else {
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddSplit(bst_uint nodeid,
|
||||||
|
bst_uint leftid,
|
||||||
|
bst_uint rightid,
|
||||||
|
bst_uint featureid,
|
||||||
|
bst_float leftweight,
|
||||||
|
bst_float rightweight) override {
|
||||||
|
bst_uint newsize = std::max(leftid, rightid) + 1;
|
||||||
|
lower_.resize(newsize);
|
||||||
|
upper_.resize(newsize);
|
||||||
|
bst_int constraint = GetConstraint(featureid);
|
||||||
|
|
||||||
|
bst_float mid = (leftweight + rightweight) / 2;
|
||||||
|
CHECK(!std::isnan(mid));
|
||||||
|
CHECK(nodeid < upper_.size());
|
||||||
|
|
||||||
|
upper_[leftid] = upper_.at(nodeid);
|
||||||
|
upper_[rightid] = upper_.at(nodeid);
|
||||||
|
lower_[leftid] = lower_.at(nodeid);
|
||||||
|
lower_[rightid] = lower_.at(nodeid);
|
||||||
|
|
||||||
|
if (constraint < 0) {
|
||||||
|
lower_[leftid] = mid;
|
||||||
|
upper_[rightid] = mid;
|
||||||
|
} else if (constraint > 0) {
|
||||||
|
upper_[leftid] = mid;
|
||||||
|
lower_[rightid] = mid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MonotonicConstraintParams params_;
|
||||||
|
std::vector<bst_float> lower_;
|
||||||
|
std::vector<bst_float> upper_;
|
||||||
|
|
||||||
|
inline bst_int GetConstraint(bst_uint featureid) const {
|
||||||
|
if (featureid < params_.monotone_constraints.size()) {
|
||||||
|
return params_.monotone_constraints[featureid];
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
|
||||||
|
.describe("Enforces that the tree is monotonically increasing/decreasing "
|
||||||
|
"w.r.t. specified features")
|
||||||
|
.set_body([]() {
|
||||||
|
return new MonotonicConstraint();
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
87
src/tree/split_evaluator.h
Normal file
87
src/tree/split_evaluator.h
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2018 by Contributors
|
||||||
|
* \file split_evaluator.h
|
||||||
|
* \brief Used for implementing a loss term specific to decision trees. Useful for custom regularisation.
|
||||||
|
* \author Henry Gouk
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef XGBOOST_TREE_SPLIT_EVALUATOR_H_
|
||||||
|
#define XGBOOST_TREE_SPLIT_EVALUATOR_H_
|
||||||
|
|
||||||
|
#include <dmlc/registry.h>
|
||||||
|
#include <xgboost/base.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
|
||||||
|
// Should GradStats be in this header, rather than param.h?
|
||||||
|
struct GradStats;
|
||||||
|
|
||||||
|
class SplitEvaluator {
|
||||||
|
public:
|
||||||
|
// Factory method for constructing new SplitEvaluators
|
||||||
|
static SplitEvaluator* Create(const std::string& name);
|
||||||
|
|
||||||
|
virtual ~SplitEvaluator() = default;
|
||||||
|
|
||||||
|
// Used to initialise any regularisation hyperparameters provided by the user
|
||||||
|
virtual void Init(
|
||||||
|
const std::vector<std::pair<std::string, std::string> >& args);
|
||||||
|
|
||||||
|
// Resets the SplitEvaluator to the state it was in after the Init was called
|
||||||
|
virtual void Reset();
|
||||||
|
|
||||||
|
// This will create a clone of the SplitEvaluator in host memory
|
||||||
|
virtual SplitEvaluator* GetHostClone() const = 0;
|
||||||
|
|
||||||
|
// Computes the score (negative loss) resulting from performing this split
|
||||||
|
virtual bst_float ComputeSplitScore(bst_uint nodeid,
|
||||||
|
bst_uint featureid,
|
||||||
|
const GradStats& left,
|
||||||
|
const GradStats& right) const = 0;
|
||||||
|
|
||||||
|
// Compute the Score for a node with the given stats
|
||||||
|
virtual bst_float ComputeScore(bst_uint parentid, const GradStats& stats)
|
||||||
|
const = 0;
|
||||||
|
|
||||||
|
// Compute the weight for a node with the given stats
|
||||||
|
virtual bst_float ComputeWeight(bst_uint parentid, const GradStats& stats)
|
||||||
|
const = 0;
|
||||||
|
|
||||||
|
virtual void AddSplit(bst_uint nodeid,
|
||||||
|
bst_uint leftid,
|
||||||
|
bst_uint rightid,
|
||||||
|
bst_uint featureid,
|
||||||
|
bst_float leftweight,
|
||||||
|
bst_float rightweight);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SplitEvaluatorReg
|
||||||
|
: public dmlc::FunctionRegEntryBase<SplitEvaluatorReg,
|
||||||
|
std::function<SplitEvaluator* ()> > {};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Macro to register tree split evaluator.
|
||||||
|
*
|
||||||
|
* \code
|
||||||
|
* // example of registering a split evaluator
|
||||||
|
* XGBOOST_REGISTER_SPLIT_EVALUATOR(SplitEval, "splitEval")
|
||||||
|
* .describe("Some split evaluator")
|
||||||
|
* .set_body([]() {
|
||||||
|
* return new SplitEval();
|
||||||
|
* });
|
||||||
|
* \endcode
|
||||||
|
*/
|
||||||
|
#define XGBOOST_REGISTER_SPLIT_EVALUATOR(UniqueID, Name) \
|
||||||
|
static DMLC_ATTRIBUTE_UNUSED ::xgboost::tree::SplitEvaluatorReg& \
|
||||||
|
__make_ ## SplitEvaluatorReg ## _ ## UniqueID ## __ = \
|
||||||
|
::dmlc::Registry< ::xgboost::tree::SplitEvaluatorReg>::Get()->__REGISTER__(Name) //NOLINT
|
||||||
|
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#endif // XGBOOST_TREE_SPLIT_EVALUATOR_H_
|
||||||
@ -13,6 +13,7 @@
|
|||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "../common/bitmap.h"
|
#include "../common/bitmap.h"
|
||||||
#include "../common/sync.h"
|
#include "../common/sync.h"
|
||||||
|
#include "split_evaluator.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -20,24 +21,26 @@ namespace tree {
|
|||||||
DMLC_REGISTRY_FILE_TAG(updater_colmaker);
|
DMLC_REGISTRY_FILE_TAG(updater_colmaker);
|
||||||
|
|
||||||
/*! \brief column-wise update to construct a tree */
|
/*! \brief column-wise update to construct a tree */
|
||||||
template<typename TStats, typename TConstraint>
|
|
||||||
class ColMaker: public TreeUpdater {
|
class ColMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
|
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||||
|
spliteval_->Init(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||||
DMatrix* dmat,
|
DMatrix* dmat,
|
||||||
const std::vector<RegTree*> &trees) override {
|
const std::vector<RegTree*> &trees) override {
|
||||||
TStats::CheckInfo(dmat->Info());
|
GradStats::CheckInfo(dmat->Info());
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param_.learning_rate;
|
float lr = param_.learning_rate;
|
||||||
param_.learning_rate = lr / trees.size();
|
param_.learning_rate = lr / trees.size();
|
||||||
TConstraint::Init(¶m_, dmat->Info().num_col_);
|
|
||||||
// build tree
|
// build tree
|
||||||
for (auto tree : trees) {
|
for (auto tree : trees) {
|
||||||
Builder builder(param_);
|
Builder builder(
|
||||||
|
param_,
|
||||||
|
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()));
|
||||||
builder.Update(gpair->HostVector(), dmat, tree);
|
builder.Update(gpair->HostVector(), dmat, tree);
|
||||||
}
|
}
|
||||||
param_.learning_rate = lr;
|
param_.learning_rate = lr;
|
||||||
@ -46,13 +49,15 @@ class ColMaker: public TreeUpdater {
|
|||||||
protected:
|
protected:
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
|
// SplitEvaluator that will be cloned for each Builder
|
||||||
|
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||||
// data structure
|
// data structure
|
||||||
/*! \brief per thread x per node entry to store tmp data */
|
/*! \brief per thread x per node entry to store tmp data */
|
||||||
struct ThreadEntry {
|
struct ThreadEntry {
|
||||||
/*! \brief statistics of data */
|
/*! \brief statistics of data */
|
||||||
TStats stats;
|
GradStats stats;
|
||||||
/*! \brief extra statistics of data */
|
/*! \brief extra statistics of data */
|
||||||
TStats stats_extra;
|
GradStats stats_extra;
|
||||||
/*! \brief last feature value scanned */
|
/*! \brief last feature value scanned */
|
||||||
bst_float last_fvalue;
|
bst_float last_fvalue;
|
||||||
/*! \brief first feature value scanned */
|
/*! \brief first feature value scanned */
|
||||||
@ -66,7 +71,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
};
|
};
|
||||||
struct NodeEntry {
|
struct NodeEntry {
|
||||||
/*! \brief statics for node entry */
|
/*! \brief statics for node entry */
|
||||||
TStats stats;
|
GradStats stats;
|
||||||
/*! \brief loss of this node, without split */
|
/*! \brief loss of this node, without split */
|
||||||
bst_float root_gain;
|
bst_float root_gain;
|
||||||
/*! \brief weight calculated related to current data */
|
/*! \brief weight calculated related to current data */
|
||||||
@ -82,24 +87,41 @@ class ColMaker: public TreeUpdater {
|
|||||||
class Builder {
|
class Builder {
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
explicit Builder(const TrainParam& param) : param_(param), nthread_(omp_get_max_threads()) {}
|
explicit Builder(const TrainParam& param,
|
||||||
|
std::unique_ptr<SplitEvaluator> spliteval)
|
||||||
|
: param_(param), nthread_(omp_get_max_threads()),
|
||||||
|
spliteval_(std::move(spliteval)) {}
|
||||||
// update one tree, growing
|
// update one tree, growing
|
||||||
virtual void Update(const std::vector<GradientPair>& gpair,
|
virtual void Update(const std::vector<GradientPair>& gpair,
|
||||||
DMatrix* p_fmat,
|
DMatrix* p_fmat,
|
||||||
RegTree* p_tree) {
|
RegTree* p_tree) {
|
||||||
|
std::vector<int> newnodes;
|
||||||
this->InitData(gpair, *p_fmat, *p_tree);
|
this->InitData(gpair, *p_fmat, *p_tree);
|
||||||
this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
|
this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
|
||||||
for (int depth = 0; depth < param_.max_depth; ++depth) {
|
for (int depth = 0; depth < param_.max_depth; ++depth) {
|
||||||
this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
|
this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
|
||||||
this->ResetPosition(qexpand_, p_fmat, *p_tree);
|
this->ResetPosition(qexpand_, p_fmat, *p_tree);
|
||||||
this->UpdateQueueExpand(*p_tree, &qexpand_);
|
this->UpdateQueueExpand(*p_tree, qexpand_, &newnodes);
|
||||||
this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
|
this->InitNewNode(newnodes, gpair, *p_fmat, *p_tree);
|
||||||
|
for (auto nid : qexpand_) {
|
||||||
|
if ((*p_tree)[nid].IsLeaf()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int cleft = (*p_tree)[nid].LeftChild();
|
||||||
|
int cright = (*p_tree)[nid].RightChild();
|
||||||
|
spliteval_->AddSplit(nid,
|
||||||
|
cleft,
|
||||||
|
cright,
|
||||||
|
snode_[nid].best.SplitIndex(),
|
||||||
|
snode_[cleft].weight,
|
||||||
|
snode_[cright].weight);
|
||||||
|
}
|
||||||
|
qexpand_ = newnodes;
|
||||||
// if nothing left to be expand, break
|
// if nothing left to be expand, break
|
||||||
if (qexpand_.size() == 0) break;
|
if (qexpand_.size() == 0) break;
|
||||||
}
|
}
|
||||||
// set all the rest expanding nodes to leaf
|
// set all the rest expanding nodes to leaf
|
||||||
for (size_t i = 0; i < qexpand_.size(); ++i) {
|
for (const int nid : qexpand_) {
|
||||||
const int nid = qexpand_[i];
|
|
||||||
(*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
|
(*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
|
||||||
}
|
}
|
||||||
// remember auxiliary statistics in the tree node
|
// remember auxiliary statistics in the tree node
|
||||||
@ -170,8 +192,8 @@ class ColMaker: public TreeUpdater {
|
|||||||
// reserve a small space
|
// reserve a small space
|
||||||
stemp_.clear();
|
stemp_.clear();
|
||||||
stemp_.resize(this->nthread_, std::vector<ThreadEntry>());
|
stemp_.resize(this->nthread_, std::vector<ThreadEntry>());
|
||||||
for (size_t i = 0; i < stemp_.size(); ++i) {
|
for (auto& i : stemp_) {
|
||||||
stemp_[i].clear(); stemp_[i].reserve(256);
|
i.clear(); i.reserve(256);
|
||||||
}
|
}
|
||||||
snode_.reserve(256);
|
snode_.reserve(256);
|
||||||
}
|
}
|
||||||
@ -193,11 +215,10 @@ class ColMaker: public TreeUpdater {
|
|||||||
const RegTree& tree) {
|
const RegTree& tree) {
|
||||||
{
|
{
|
||||||
// setup statistics space for each tree node
|
// setup statistics space for each tree node
|
||||||
for (size_t i = 0; i < stemp_.size(); ++i) {
|
for (auto& i : stemp_) {
|
||||||
stemp_[i].resize(tree.param.num_nodes, ThreadEntry(param_));
|
i.resize(tree.param.num_nodes, ThreadEntry(param_));
|
||||||
}
|
}
|
||||||
snode_.resize(tree.param.num_nodes, NodeEntry(param_));
|
snode_.resize(tree.param.num_nodes, NodeEntry(param_));
|
||||||
constraints_.resize(tree.param.num_nodes);
|
|
||||||
}
|
}
|
||||||
const RowSet &rowset = fmat.BufferedRowset();
|
const RowSet &rowset = fmat.BufferedRowset();
|
||||||
const MetaInfo& info = fmat.Info();
|
const MetaInfo& info = fmat.Info();
|
||||||
@ -212,43 +233,33 @@ class ColMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
// sum the per thread statistics together
|
// sum the per thread statistics together
|
||||||
for (int nid : qexpand) {
|
for (int nid : qexpand) {
|
||||||
TStats stats(param_);
|
GradStats stats(param_);
|
||||||
for (size_t tid = 0; tid < stemp_.size(); ++tid) {
|
for (auto& s : stemp_) {
|
||||||
stats.Add(stemp_[tid][nid].stats);
|
stats.Add(s[nid].stats);
|
||||||
}
|
}
|
||||||
// update node statistics
|
// update node statistics
|
||||||
snode_[nid].stats = stats;
|
snode_[nid].stats = stats;
|
||||||
}
|
}
|
||||||
// setup constraints before calculating the weight
|
|
||||||
for (int nid : qexpand) {
|
|
||||||
if (tree[nid].IsRoot()) continue;
|
|
||||||
const int pid = tree[nid].Parent();
|
|
||||||
constraints_[pid].SetChild(param_, tree[pid].SplitIndex(),
|
|
||||||
snode_[tree[pid].LeftChild()].stats,
|
|
||||||
snode_[tree[pid].RightChild()].stats,
|
|
||||||
&constraints_[tree[pid].LeftChild()],
|
|
||||||
&constraints_[tree[pid].RightChild()]);
|
|
||||||
}
|
|
||||||
// calculating the weights
|
// calculating the weights
|
||||||
for (int nid : qexpand) {
|
for (int nid : qexpand) {
|
||||||
|
bst_uint parentid = tree[nid].Parent();
|
||||||
snode_[nid].root_gain = static_cast<float>(
|
snode_[nid].root_gain = static_cast<float>(
|
||||||
constraints_[nid].CalcGain(param_, snode_[nid].stats));
|
spliteval_->ComputeScore(parentid, snode_[nid].stats));
|
||||||
snode_[nid].weight = static_cast<float>(
|
snode_[nid].weight = static_cast<float>(
|
||||||
constraints_[nid].CalcWeight(param_, snode_[nid].stats));
|
spliteval_->ComputeWeight(parentid, snode_[nid].stats));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*! \brief update queue expand add in new leaves */
|
/*! \brief update queue expand add in new leaves */
|
||||||
inline void UpdateQueueExpand(const RegTree& tree, std::vector<int>* p_qexpand) {
|
inline void UpdateQueueExpand(const RegTree& tree,
|
||||||
std::vector<int> &qexpand = *p_qexpand;
|
const std::vector<int> &qexpand,
|
||||||
std::vector<int> newnodes;
|
std::vector<int>* p_newnodes) {
|
||||||
|
p_newnodes->clear();
|
||||||
for (int nid : qexpand) {
|
for (int nid : qexpand) {
|
||||||
if (!tree[ nid ].IsLeaf()) {
|
if (!tree[ nid ].IsLeaf()) {
|
||||||
newnodes.push_back(tree[nid].LeftChild());
|
p_newnodes->push_back(tree[nid].LeftChild());
|
||||||
newnodes.push_back(tree[nid].RightChild());
|
p_newnodes->push_back(tree[nid].RightChild());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// use new nodes for qexpand
|
|
||||||
qexpand = newnodes;
|
|
||||||
}
|
}
|
||||||
// parallel find the best split of current fid
|
// parallel find the best split of current fid
|
||||||
// this function does not support nested functions
|
// this function does not support nested functions
|
||||||
@ -289,7 +300,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint j = 0; j < nnode; ++j) {
|
for (bst_omp_uint j = 0; j < nnode; ++j) {
|
||||||
const int nid = qexpand[j];
|
const int nid = qexpand[j];
|
||||||
TStats sum(param_), tmp(param_), c(param_);
|
GradStats sum(param_), tmp(param_), c(param_);
|
||||||
for (int tid = 0; tid < this->nthread_; ++tid) {
|
for (int tid = 0; tid < this->nthread_; ++tid) {
|
||||||
tmp = stemp_[tid][nid].stats;
|
tmp = stemp_[tid][nid].stats;
|
||||||
stemp_[tid][nid].stats = sum;
|
stemp_[tid][nid].stats = sum;
|
||||||
@ -316,8 +327,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
if (c.sum_hess >= param_.min_child_weight &&
|
if (c.sum_hess >= param_.min_child_weight &&
|
||||||
e.stats.sum_hess >= param_.min_child_weight) {
|
e.stats.sum_hess >= param_.min_child_weight) {
|
||||||
auto loss_chg = static_cast<bst_float>(
|
auto loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
|
||||||
param_, param_.monotone_constraints[fid], e.stats, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
e.best.Update(loss_chg, fid, fsplit, false);
|
e.best.Update(loss_chg, fid, fsplit, false);
|
||||||
}
|
}
|
||||||
@ -328,8 +338,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
if (c.sum_hess >= param_.min_child_weight &&
|
if (c.sum_hess >= param_.min_child_weight &&
|
||||||
tmp.sum_hess >= param_.min_child_weight) {
|
tmp.sum_hess >= param_.min_child_weight) {
|
||||||
auto loss_chg = static_cast<bst_float>(
|
auto loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
|
||||||
param_, param_.monotone_constraints[fid], tmp, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
e.best.Update(loss_chg, fid, fsplit, true);
|
e.best.Update(loss_chg, fid, fsplit, true);
|
||||||
}
|
}
|
||||||
@ -342,8 +351,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
if (c.sum_hess >= param_.min_child_weight &&
|
if (c.sum_hess >= param_.min_child_weight &&
|
||||||
tmp.sum_hess >= param_.min_child_weight) {
|
tmp.sum_hess >= param_.min_child_weight) {
|
||||||
auto loss_chg = static_cast<bst_float>(
|
auto loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
|
||||||
param_, param_.monotone_constraints[fid], tmp, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true);
|
e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true);
|
||||||
}
|
}
|
||||||
@ -352,7 +360,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
// rescan, generate candidate split
|
// rescan, generate candidate split
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
TStats c(param_), cright(param_);
|
GradStats c(param_), cright(param_);
|
||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
std::vector<ThreadEntry> &temp = stemp_[tid];
|
std::vector<ThreadEntry> &temp = stemp_[tid];
|
||||||
bst_uint step = (col.length + this->nthread_ - 1) / this->nthread_;
|
bst_uint step = (col.length + this->nthread_ - 1) / this->nthread_;
|
||||||
@ -375,8 +383,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
if (c.sum_hess >= param_.min_child_weight &&
|
if (c.sum_hess >= param_.min_child_weight &&
|
||||||
e.stats.sum_hess >= param_.min_child_weight) {
|
e.stats.sum_hess >= param_.min_child_weight) {
|
||||||
auto loss_chg = static_cast<bst_float>(
|
auto loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
|
||||||
param_, param_.monotone_constraints[fid], e.stats, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f,
|
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f,
|
||||||
false);
|
false);
|
||||||
@ -388,8 +395,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
if (c.sum_hess >= param_.min_child_weight &&
|
if (c.sum_hess >= param_.min_child_weight &&
|
||||||
cright.sum_hess >= param_.min_child_weight) {
|
cright.sum_hess >= param_.min_child_weight) {
|
||||||
auto loss_chg = static_cast<bst_float>(
|
auto loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, c, cright) -
|
||||||
param_, param_.monotone_constraints[fid], c, cright) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
|
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
|
||||||
}
|
}
|
||||||
@ -404,7 +410,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
// update enumeration solution
|
// update enumeration solution
|
||||||
inline void UpdateEnumeration(int nid, GradientPair gstats,
|
inline void UpdateEnumeration(int nid, GradientPair gstats,
|
||||||
bst_float fvalue, int d_step, bst_uint fid,
|
bst_float fvalue, int d_step, bst_uint fid,
|
||||||
TStats &c, std::vector<ThreadEntry> &temp) { // NOLINT(*)
|
GradStats &c, std::vector<ThreadEntry> &temp) { // NOLINT(*)
|
||||||
// get the statistics of nid
|
// get the statistics of nid
|
||||||
ThreadEntry &e = temp[nid];
|
ThreadEntry &e = temp[nid];
|
||||||
// test if first hit, this is fine, because we set 0 during init
|
// test if first hit, this is fine, because we set 0 during init
|
||||||
@ -420,13 +426,11 @@ class ColMaker: public TreeUpdater {
|
|||||||
bst_float loss_chg;
|
bst_float loss_chg;
|
||||||
if (d_step == -1) {
|
if (d_step == -1) {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
|
||||||
param_, param_.monotone_constraints[fid], c, e.stats) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
} else {
|
} else {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
|
||||||
param_, param_.monotone_constraints[fid], e.stats, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
}
|
}
|
||||||
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
|
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
|
||||||
@ -451,7 +455,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
temp[nid].stats.Clear();
|
temp[nid].stats.Clear();
|
||||||
}
|
}
|
||||||
// left statistics
|
// left statistics
|
||||||
TStats c(param_);
|
GradStats c(param_);
|
||||||
// local cache buffer for position and gradient pair
|
// local cache buffer for position and gradient pair
|
||||||
constexpr int kBuffer = 32;
|
constexpr int kBuffer = 32;
|
||||||
int buf_position[kBuffer] = {};
|
int buf_position[kBuffer] = {};
|
||||||
@ -502,13 +506,11 @@ class ColMaker: public TreeUpdater {
|
|||||||
bst_float loss_chg;
|
bst_float loss_chg;
|
||||||
if (d_step == -1) {
|
if (d_step == -1) {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
|
||||||
param_, param_.monotone_constraints[fid], c, e.stats) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
} else {
|
} else {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
|
||||||
param_, param_.monotone_constraints[fid], e.stats, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
}
|
}
|
||||||
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
|
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
|
||||||
@ -527,7 +529,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
std::vector<ThreadEntry> &temp) { // NOLINT(*)
|
std::vector<ThreadEntry> &temp) { // NOLINT(*)
|
||||||
// use cacheline aware optimization
|
// use cacheline aware optimization
|
||||||
if (TStats::kSimpleStats != 0 && param_.cache_opt != 0) {
|
if (GradStats::kSimpleStats != 0 && param_.cache_opt != 0) {
|
||||||
EnumerateSplitCacheOpt(begin, end, d_step, fid, gpair, temp);
|
EnumerateSplitCacheOpt(begin, end, d_step, fid, gpair, temp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -537,7 +539,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
temp[nid].stats.Clear();
|
temp[nid].stats.Clear();
|
||||||
}
|
}
|
||||||
// left statistics
|
// left statistics
|
||||||
TStats c(param_);
|
GradStats c(param_);
|
||||||
for (const Entry *it = begin; it != end; it += d_step) {
|
for (const Entry *it = begin; it != end; it += d_step) {
|
||||||
const bst_uint ridx = it->index;
|
const bst_uint ridx = it->index;
|
||||||
const int nid = position_[ridx];
|
const int nid = position_[ridx];
|
||||||
@ -559,13 +561,11 @@ class ColMaker: public TreeUpdater {
|
|||||||
bst_float loss_chg;
|
bst_float loss_chg;
|
||||||
if (d_step == -1) {
|
if (d_step == -1) {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
|
||||||
param_, param_.monotone_constraints[fid], c, e.stats) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
} else {
|
} else {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
|
||||||
param_, param_.monotone_constraints[fid], e.stats, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
}
|
}
|
||||||
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
|
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
|
||||||
@ -585,13 +585,11 @@ class ColMaker: public TreeUpdater {
|
|||||||
bst_float loss_chg;
|
bst_float loss_chg;
|
||||||
if (d_step == -1) {
|
if (d_step == -1) {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
|
||||||
param_, param_.monotone_constraints[fid], c, e.stats) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
} else {
|
} else {
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraints_[nid].CalcSplitGain(
|
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
|
||||||
param_, param_.monotone_constraints[fid], e.stats, c) -
|
|
||||||
snode_[nid].root_gain);
|
snode_[nid].root_gain);
|
||||||
}
|
}
|
||||||
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
|
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
|
||||||
@ -782,40 +780,43 @@ class ColMaker: public TreeUpdater {
|
|||||||
std::vector<NodeEntry> snode_;
|
std::vector<NodeEntry> snode_;
|
||||||
/*! \brief queue of nodes to be expanded */
|
/*! \brief queue of nodes to be expanded */
|
||||||
std::vector<int> qexpand_;
|
std::vector<int> qexpand_;
|
||||||
// constraint value
|
// Evaluates splits and computes optimal weights for a given split
|
||||||
std::vector<TConstraint> constraints_;
|
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// distributed column maker
|
// distributed column maker
|
||||||
template<typename TStats, typename TConstraint>
|
class DistColMaker : public ColMaker {
|
||||||
class DistColMaker : public ColMaker<TStats, TConstraint> {
|
|
||||||
public:
|
public:
|
||||||
DistColMaker() : builder_(param_) {
|
|
||||||
pruner_.reset(TreeUpdater::Create("prune"));
|
|
||||||
}
|
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
|
pruner_.reset(TreeUpdater::Create("prune"));
|
||||||
pruner_->Init(args);
|
pruner_->Init(args);
|
||||||
|
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||||
|
spliteval_->Init(args);
|
||||||
}
|
}
|
||||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||||
DMatrix* dmat,
|
DMatrix* dmat,
|
||||||
const std::vector<RegTree*> &trees) override {
|
const std::vector<RegTree*> &trees) override {
|
||||||
TStats::CheckInfo(dmat->Info());
|
GradStats::CheckInfo(dmat->Info());
|
||||||
CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
|
CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
|
||||||
|
Builder builder(
|
||||||
|
param_,
|
||||||
|
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()));
|
||||||
// build the tree
|
// build the tree
|
||||||
builder_.Update(gpair->HostVector(), dmat, trees[0]);
|
builder.Update(gpair->HostVector(), dmat, trees[0]);
|
||||||
//// prune the tree, note that pruner will sync the tree
|
//// prune the tree, note that pruner will sync the tree
|
||||||
pruner_->Update(gpair, dmat, trees);
|
pruner_->Update(gpair, dmat, trees);
|
||||||
// update position after the tree is pruned
|
// update position after the tree is pruned
|
||||||
builder_.UpdatePosition(dmat, *trees[0]);
|
builder.UpdatePosition(dmat, *trees[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Builder : public ColMaker<TStats, TConstraint>::Builder {
|
class Builder : public ColMaker::Builder {
|
||||||
public:
|
public:
|
||||||
explicit Builder(const TrainParam ¶m)
|
explicit Builder(const TrainParam ¶m,
|
||||||
: ColMaker<TStats, TConstraint>::Builder(param) {}
|
std::unique_ptr<SplitEvaluator> spliteval)
|
||||||
|
: ColMaker::Builder(param, std::move(spliteval)) {}
|
||||||
inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
|
inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
|
||||||
const RowSet &rowset = p_fmat->BufferedRowset();
|
const RowSet &rowset = p_fmat->BufferedRowset();
|
||||||
const auto ndata = static_cast<bst_omp_uint>(rowset.Size());
|
const auto ndata = static_cast<bst_omp_uint>(rowset.Size());
|
||||||
@ -929,55 +930,20 @@ class DistColMaker : public ColMaker<TStats, TConstraint> {
|
|||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
// pointer to the builder
|
// Cloned for each builder instantiation
|
||||||
Builder builder_;
|
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||||
};
|
|
||||||
|
|
||||||
// simple switch to defer implementation.
|
|
||||||
class TreeUpdaterSwitch : public TreeUpdater {
|
|
||||||
public:
|
|
||||||
TreeUpdaterSwitch() = default;
|
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
|
||||||
for (auto &kv : args) {
|
|
||||||
if (kv.first == "monotone_constraints" && kv.second.length() != 0) {
|
|
||||||
monotone_ = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (inner_ == nullptr) {
|
|
||||||
if (monotone_) {
|
|
||||||
inner_.reset(new ColMaker<GradStats, ValueConstraint>());
|
|
||||||
} else {
|
|
||||||
inner_.reset(new ColMaker<GradStats, NoConstraint>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inner_->Init(args);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair,
|
|
||||||
DMatrix* data,
|
|
||||||
const std::vector<RegTree*>& trees) override {
|
|
||||||
CHECK(inner_ != nullptr);
|
|
||||||
inner_->Update(gpair, data, trees);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// monotone constraints
|
|
||||||
bool monotone_{false};
|
|
||||||
// internal implementation
|
|
||||||
std::unique_ptr<TreeUpdater> inner_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||||
.describe("Grow tree with parallelization over columns.")
|
.describe("Grow tree with parallelization over columns.")
|
||||||
.set_body([]() {
|
.set_body([]() {
|
||||||
return new TreeUpdaterSwitch();
|
return new ColMaker();
|
||||||
});
|
});
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(DistColMaker, "distcol")
|
XGBOOST_REGISTER_TREE_UPDATER(DistColMaker, "distcol")
|
||||||
.describe("Distributed column split version of tree maker.")
|
.describe("Distributed column split version of tree maker.")
|
||||||
.set_body([]() {
|
.set_body([]() {
|
||||||
return new DistColMaker<GradStats, NoConstraint>();
|
return new DistColMaker();
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "./fast_hist_param.h"
|
#include "./fast_hist_param.h"
|
||||||
|
#include "./split_evaluator.h"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "../common/bitmap.h"
|
#include "../common/bitmap.h"
|
||||||
#include "../common/sync.h"
|
#include "../common/sync.h"
|
||||||
@ -42,7 +43,6 @@ DMLC_REGISTRY_FILE_TAG(updater_fast_hist);
|
|||||||
DMLC_REGISTER_PARAMETER(FastHistParam);
|
DMLC_REGISTER_PARAMETER(FastHistParam);
|
||||||
|
|
||||||
/*! \brief construct a tree using quantized feature values */
|
/*! \brief construct a tree using quantized feature values */
|
||||||
template<typename TStats, typename TConstraint>
|
|
||||||
class FastHistMaker: public TreeUpdater {
|
class FastHistMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
@ -54,12 +54,19 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
fhparam_.InitAllowUnknown(args);
|
fhparam_.InitAllowUnknown(args);
|
||||||
is_gmat_initialized_ = false;
|
is_gmat_initialized_ = false;
|
||||||
|
|
||||||
|
// initialise the split evaluator
|
||||||
|
if (!spliteval_) {
|
||||||
|
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||||
|
}
|
||||||
|
|
||||||
|
spliteval_->Init(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair,
|
void Update(HostDeviceVector<GradientPair>* gpair,
|
||||||
DMatrix* dmat,
|
DMatrix* dmat,
|
||||||
const std::vector<RegTree*>& trees) override {
|
const std::vector<RegTree*>& trees) override {
|
||||||
TStats::CheckInfo(dmat->Info());
|
GradStats::CheckInfo(dmat->Info());
|
||||||
if (is_gmat_initialized_ == false) {
|
if (is_gmat_initialized_ == false) {
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
hmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
hmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
||||||
@ -77,10 +84,13 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param_.learning_rate;
|
float lr = param_.learning_rate;
|
||||||
param_.learning_rate = lr / trees.size();
|
param_.learning_rate = lr / trees.size();
|
||||||
TConstraint::Init(¶m_, dmat->Info().num_col_);
|
|
||||||
// build tree
|
// build tree
|
||||||
if (!builder_) {
|
if (!builder_) {
|
||||||
builder_.reset(new Builder(param_, fhparam_, std::move(pruner_)));
|
builder_.reset(new Builder(
|
||||||
|
param_,
|
||||||
|
fhparam_,
|
||||||
|
std::move(pruner_),
|
||||||
|
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone())));
|
||||||
}
|
}
|
||||||
for (auto tree : trees) {
|
for (auto tree : trees) {
|
||||||
builder_->Update
|
builder_->Update
|
||||||
@ -115,7 +125,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
// data structure
|
// data structure
|
||||||
struct NodeEntry {
|
struct NodeEntry {
|
||||||
/*! \brief statics for node entry */
|
/*! \brief statics for node entry */
|
||||||
TStats stats;
|
GradStats stats;
|
||||||
/*! \brief loss of this node, without split */
|
/*! \brief loss of this node, without split */
|
||||||
bst_float root_gain;
|
bst_float root_gain;
|
||||||
/*! \brief weight calculated related to current data */
|
/*! \brief weight calculated related to current data */
|
||||||
@ -134,9 +144,11 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
// constructor
|
// constructor
|
||||||
explicit Builder(const TrainParam& param,
|
explicit Builder(const TrainParam& param,
|
||||||
const FastHistParam& fhparam,
|
const FastHistParam& fhparam,
|
||||||
std::unique_ptr<TreeUpdater> pruner)
|
std::unique_ptr<TreeUpdater> pruner,
|
||||||
|
std::unique_ptr<SplitEvaluator> spliteval)
|
||||||
: param_(param), fhparam_(fhparam), pruner_(std::move(pruner)),
|
: param_(param), fhparam_(fhparam), pruner_(std::move(pruner)),
|
||||||
p_last_tree_(nullptr), p_last_fmat_(nullptr) {}
|
spliteval_(std::move(spliteval)), p_last_tree_(nullptr),
|
||||||
|
p_last_fmat_(nullptr) {}
|
||||||
// update one tree, growing
|
// update one tree, growing
|
||||||
virtual void Update(const GHistIndexMatrix& gmat,
|
virtual void Update(const GHistIndexMatrix& gmat,
|
||||||
const GHistIndexBlockMatrix& gmatb,
|
const GHistIndexBlockMatrix& gmatb,
|
||||||
@ -158,6 +170,8 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
std::vector<GradientPair>& gpair_h = gpair->HostVector();
|
std::vector<GradientPair>& gpair_h = gpair->HostVector();
|
||||||
|
|
||||||
|
spliteval_->Reset();
|
||||||
|
|
||||||
tstart = dmlc::GetTime();
|
tstart = dmlc::GetTime();
|
||||||
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
|
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
|
||||||
std::vector<bst_uint> feat_set = feat_index_;
|
std::vector<bst_uint> feat_set = feat_index_;
|
||||||
@ -215,6 +229,9 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
tstart = dmlc::GetTime();
|
tstart = dmlc::GetTime();
|
||||||
this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
|
this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
|
||||||
this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree);
|
this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree);
|
||||||
|
bst_uint featureid = snode_[nid].best.SplitIndex();
|
||||||
|
spliteval_->AddSplit(nid, cleft, cright, featureid,
|
||||||
|
snode_[cleft].weight, snode_[cright].weight);
|
||||||
time_init_new_node += dmlc::GetTime() - tstart;
|
time_init_new_node += dmlc::GetTime() - tstart;
|
||||||
|
|
||||||
tstart = dmlc::GetTime();
|
tstart = dmlc::GetTime();
|
||||||
@ -483,10 +500,10 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
for (bst_omp_uint i = 0; i < nfeature; ++i) {
|
for (bst_omp_uint i = 0; i < nfeature; ++i) {
|
||||||
const bst_uint fid = feat_set[i];
|
const bst_uint fid = feat_set[i];
|
||||||
const unsigned tid = omp_get_thread_num();
|
const unsigned tid = omp_get_thread_num();
|
||||||
this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], constraints_[nid], info,
|
this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], info,
|
||||||
&best_split_tloc_[tid], fid);
|
&best_split_tloc_[tid], fid, nid);
|
||||||
this->EnumerateSplit(+1, gmat, hist[nid], snode_[nid], constraints_[nid], info,
|
this->EnumerateSplit(+1, gmat, hist[nid], snode_[nid], info,
|
||||||
&best_split_tloc_[tid], fid);
|
&best_split_tloc_[tid], fid, nid);
|
||||||
}
|
}
|
||||||
for (unsigned tid = 0; tid < nthread; ++tid) {
|
for (unsigned tid = 0; tid < nthread; ++tid) {
|
||||||
snode_[nid].best.Update(best_split_tloc_[tid]);
|
snode_[nid].best.Update(best_split_tloc_[tid]);
|
||||||
@ -629,75 +646,6 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ApplySplitSparseDataOld(const RowSetCollection::Elem rowset,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
|
||||||
bst_uint lower_bound,
|
|
||||||
bst_uint upper_bound,
|
|
||||||
bst_int split_cond,
|
|
||||||
bool default_left) {
|
|
||||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
|
||||||
constexpr int kUnroll = 8; // loop unrolling factor
|
|
||||||
const size_t nrows = rowset.end - rowset.begin;
|
|
||||||
const size_t rest = nrows % kUnroll;
|
|
||||||
#pragma omp parallel for num_threads(nthread_) schedule(static)
|
|
||||||
for (bst_omp_uint i = 0; i < nrows - rest; i += kUnroll) {
|
|
||||||
size_t rid[kUnroll];
|
|
||||||
GHistIndexRow row[kUnroll];
|
|
||||||
const uint32_t* p[kUnroll];
|
|
||||||
bst_uint tid = omp_get_thread_num();
|
|
||||||
auto& left = row_split_tloc[tid].left;
|
|
||||||
auto& right = row_split_tloc[tid].right;
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
rid[k] = rowset.begin[i + k];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
row[k] = gmat[rid[k]];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
p[k] = std::lower_bound(row[k].index, row[k].index + row[k].size, lower_bound);
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
|
|
||||||
CHECK_LT(*p[k],
|
|
||||||
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
|
|
||||||
if (static_cast<int32_t>(*p[k]) <= split_cond) {
|
|
||||||
left.push_back(rid[k]);
|
|
||||||
} else {
|
|
||||||
right.push_back(rid[k]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (default_left) {
|
|
||||||
left.push_back(rid[k]);
|
|
||||||
} else {
|
|
||||||
right.push_back(rid[k]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (size_t i = nrows - rest; i < nrows; ++i) {
|
|
||||||
const size_t rid = rowset.begin[i];
|
|
||||||
const auto row = gmat[rid];
|
|
||||||
const auto p = std::lower_bound(row.index, row.index + row.size, lower_bound);
|
|
||||||
auto& left = row_split_tloc[0].left;
|
|
||||||
auto& right = row_split_tloc[0].right;
|
|
||||||
if (p != row.index + row.size && *p < upper_bound) {
|
|
||||||
CHECK_LT(*p, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
|
|
||||||
if (static_cast<int32_t>(*p) <= split_cond) {
|
|
||||||
left.push_back(rid);
|
|
||||||
} else {
|
|
||||||
right.push_back(rid);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (default_left) {
|
|
||||||
left.push_back(rid);
|
|
||||||
} else {
|
|
||||||
right.push_back(rid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
|
inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
|
||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
@ -776,10 +724,8 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const RegTree& tree) {
|
const RegTree& tree) {
|
||||||
{
|
{
|
||||||
snode_.resize(tree.param.num_nodes, NodeEntry(param_));
|
snode_.resize(tree.param.num_nodes, NodeEntry(param_));
|
||||||
constraints_.resize(tree.param.num_nodes);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup constraints before calculating the weight
|
|
||||||
{
|
{
|
||||||
auto& stats = snode_[nid].stats;
|
auto& stats = snode_[nid].stats;
|
||||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||||
@ -801,22 +747,15 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
stats.Add(gpair[*it]);
|
stats.Add(gpair[*it]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!tree[nid].IsRoot()) {
|
|
||||||
const int pid = tree[nid].Parent();
|
|
||||||
constraints_[pid].SetChild(param_, tree[pid].SplitIndex(),
|
|
||||||
snode_[tree[pid].LeftChild()].stats,
|
|
||||||
snode_[tree[pid].RightChild()].stats,
|
|
||||||
&constraints_[tree[pid].LeftChild()],
|
|
||||||
&constraints_[tree[pid].RightChild()]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculating the weights
|
// calculating the weights
|
||||||
{
|
{
|
||||||
|
bst_uint parentid = tree[nid].Parent();
|
||||||
snode_[nid].root_gain = static_cast<float>(
|
snode_[nid].root_gain = static_cast<float>(
|
||||||
constraints_[nid].CalcGain(param_, snode_[nid].stats));
|
spliteval_->ComputeScore(parentid, snode_[nid].stats));
|
||||||
snode_[nid].weight = static_cast<float>(
|
snode_[nid].weight = static_cast<float>(
|
||||||
constraints_[nid].CalcWeight(param_, snode_[nid].stats));
|
spliteval_->ComputeWeight(parentid, snode_[nid].stats));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -825,10 +764,10 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
const GHistRow& hist,
|
const GHistRow& hist,
|
||||||
const NodeEntry& snode,
|
const NodeEntry& snode,
|
||||||
const TConstraint& constraint,
|
|
||||||
const MetaInfo& info,
|
const MetaInfo& info,
|
||||||
SplitEntry* p_best,
|
SplitEntry* p_best,
|
||||||
bst_uint fid) {
|
bst_uint fid,
|
||||||
|
bst_uint nodeID) {
|
||||||
CHECK(d_step == +1 || d_step == -1);
|
CHECK(d_step == +1 || d_step == -1);
|
||||||
|
|
||||||
// aliases
|
// aliases
|
||||||
@ -836,8 +775,8 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const std::vector<bst_float>& cut_val = gmat.cut->cut;
|
const std::vector<bst_float>& cut_val = gmat.cut->cut;
|
||||||
|
|
||||||
// statistics on both sides of split
|
// statistics on both sides of split
|
||||||
TStats c(param_);
|
GradStats c(param_);
|
||||||
TStats e(param_);
|
GradStats e(param_);
|
||||||
// best split so far
|
// best split so far
|
||||||
SplitEntry best;
|
SplitEntry best;
|
||||||
|
|
||||||
@ -872,13 +811,13 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
if (d_step > 0) {
|
if (d_step > 0) {
|
||||||
// forward enumeration: split at right bound of each bin
|
// forward enumeration: split at right bound of each bin
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraint.CalcSplitGain(param_, param_.monotone_constraints[fid], e, c) -
|
spliteval_->ComputeSplitScore(nodeID, fid, e, c) -
|
||||||
snode.root_gain);
|
snode.root_gain);
|
||||||
split_pt = cut_val[i];
|
split_pt = cut_val[i];
|
||||||
} else {
|
} else {
|
||||||
// backward enumeration: split at left bound of each bin
|
// backward enumeration: split at left bound of each bin
|
||||||
loss_chg = static_cast<bst_float>(
|
loss_chg = static_cast<bst_float>(
|
||||||
constraint.CalcSplitGain(param_, param_.monotone_constraints[fid], c, e) -
|
spliteval_->ComputeSplitScore(nodeID, fid, c, e) -
|
||||||
snode.root_gain);
|
snode.root_gain);
|
||||||
if (i == imin) {
|
if (i == imin) {
|
||||||
// for leftmost bin, left bound is the smallest feature value
|
// for leftmost bin, left bound is the smallest feature value
|
||||||
@ -942,14 +881,12 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
GHistBuilder hist_builder_;
|
GHistBuilder hist_builder_;
|
||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
|
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||||
|
|
||||||
// back pointers to tree and data matrix
|
// back pointers to tree and data matrix
|
||||||
const RegTree* p_last_tree_;
|
const RegTree* p_last_tree_;
|
||||||
const DMatrix* p_last_fmat_;
|
const DMatrix* p_last_fmat_;
|
||||||
|
|
||||||
// constraint value
|
|
||||||
std::vector<TConstraint> constraints_;
|
|
||||||
|
|
||||||
using ExpandQueue =
|
using ExpandQueue =
|
||||||
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
|
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
|
||||||
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
||||||
@ -961,47 +898,13 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
std::unique_ptr<Builder> builder_;
|
std::unique_ptr<Builder> builder_;
|
||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
};
|
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||||
|
|
||||||
// simple switch to defer implementation.
|
|
||||||
class FastHistTreeUpdaterSwitch : public TreeUpdater {
|
|
||||||
public:
|
|
||||||
FastHistTreeUpdaterSwitch() = default;
|
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
|
||||||
for (auto &kv : args) {
|
|
||||||
if (kv.first == "monotone_constraints" && kv.second.length() != 0) {
|
|
||||||
monotone_ = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (inner_ == nullptr) {
|
|
||||||
if (monotone_) {
|
|
||||||
inner_.reset(new FastHistMaker<GradStats, ValueConstraint>());
|
|
||||||
} else {
|
|
||||||
inner_.reset(new FastHistMaker<GradStats, NoConstraint>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inner_->Init(args);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair,
|
|
||||||
DMatrix* data,
|
|
||||||
const std::vector<RegTree*>& trees) override {
|
|
||||||
CHECK(inner_ != nullptr);
|
|
||||||
inner_->Update(gpair, data, trees);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// monotone constraints
|
|
||||||
bool monotone_{false};
|
|
||||||
// internal implementation
|
|
||||||
std::unique_ptr<TreeUpdater> inner_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
|
||||||
.describe("Grow tree using quantized histogram.")
|
.describe("Grow tree using quantized histogram.")
|
||||||
.set_body([]() {
|
.set_body([]() {
|
||||||
return new FastHistTreeUpdaterSwitch();
|
return new FastHistMaker();
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user