:Merge branch 'unity'
Conflicts: src/gbm/gbtree-inl.hpp src/learner/evaluation-inl.hpp src/tree/param.h
This commit is contained in:
commit
e4b9ee22fa
8
Makefile
8
Makefile
@ -5,9 +5,9 @@ export LDFLAGS= -pthread -lm
|
||||
# add include path to Rinternals.h here
|
||||
|
||||
ifeq ($(no_omp),1)
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -DDISABLE_OPENMP
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -DDISABLE_OPENMP -funroll-loops
|
||||
else
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp -funroll-loops
|
||||
endif
|
||||
|
||||
# expose these flags to R CMD SHLIB
|
||||
@ -18,11 +18,11 @@ BIN = xgboost
|
||||
OBJ =
|
||||
SLIB = wrapper/libxgboostwrapper.so
|
||||
RLIB = wrapper/libxgboostR.so
|
||||
.PHONY: clean all R
|
||||
.PHONY: clean all R python
|
||||
|
||||
all: $(BIN) wrapper/libxgboostwrapper.so
|
||||
R: wrapper/libxgboostR.so
|
||||
|
||||
python: wrapper/libxgboostwrapper.so
|
||||
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
|
||||
# now the wrapper takes in two files. io and wrapper part
|
||||
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||
|
||||
@ -117,17 +117,13 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
}
|
||||
|
||||
std::vector<float> &preds = *out_preds;
|
||||
preds.resize(0);
|
||||
const size_t stride = info.num_row * mparam.num_output_group;
|
||||
preds.resize(stride * (mparam.size_leaf_vector+1));
|
||||
// start collecting the prediction
|
||||
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const SparseBatch &batch = iter->Value();
|
||||
utils::Assert(batch.base_rowid * mparam.num_output_group == preds.size(),
|
||||
"base_rowid is not set correctly");
|
||||
// output convention: nrow * k, where nrow is number of rows
|
||||
// k is number of group
|
||||
preds.resize(preds.size() + batch.size * mparam.num_output_group);
|
||||
// parallel over local batch
|
||||
const unsigned nsize = static_cast<unsigned>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
@ -135,13 +131,13 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
const int tid = omp_get_thread_num();
|
||||
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
const unsigned root_idx = info.GetRoot(ridx);
|
||||
utils::Assert(static_cast<size_t>(ridx) < info.num_row, "data row index exceed bound");
|
||||
// loop over output groups
|
||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||
preds[ridx * mparam.num_output_group + gid] =
|
||||
this->Pred(batch[i],
|
||||
buffer_offset < 0 ? -1 : buffer_offset+ridx,
|
||||
gid, root_idx, &feats);
|
||||
this->Pred(batch[i],
|
||||
buffer_offset < 0 ? -1 : buffer_offset + ridx,
|
||||
gid, info.GetRoot(ridx), &feats,
|
||||
&preds[ridx * mparam.num_output_group + gid], stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -211,24 +207,34 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
mparam.num_trees += tparam.num_parallel_tree;
|
||||
}
|
||||
// make a prediction for a single instance
|
||||
inline float Pred(const SparseBatch::Inst &inst,
|
||||
int64_t buffer_index,
|
||||
int bst_group,
|
||||
unsigned root_index,
|
||||
tree::RegTree::FVec *p_feats) {
|
||||
inline void Pred(const SparseBatch::Inst &inst,
|
||||
int64_t buffer_index,
|
||||
int bst_group,
|
||||
unsigned root_index,
|
||||
tree::RegTree::FVec *p_feats,
|
||||
float *out_pred, size_t stride) {
|
||||
size_t itop = 0;
|
||||
float psum = 0.0f;
|
||||
// sum of leaf vector
|
||||
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
|
||||
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
|
||||
// load buffered results if any
|
||||
if (bid >= 0) {
|
||||
itop = pred_counter[bid];
|
||||
psum = pred_buffer[bid];
|
||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||
vec_psum[i] = pred_buffer[bid + i + 1];
|
||||
}
|
||||
}
|
||||
if (itop != trees.size()) {
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = itop; i < trees.size(); ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
psum += trees[i]->Predict(*p_feats, root_index);
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
|
||||
psum += (*trees[i])[tid].leaf_value();
|
||||
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
|
||||
vec_psum[j] += trees[i]->leafvec(tid)[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
@ -237,8 +243,14 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
if (bid >= 0) {
|
||||
pred_counter[bid] = static_cast<unsigned>(trees.size());
|
||||
pred_buffer[bid] = psum;
|
||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||
pred_buffer[bid + i + 1] = vec_psum[i];
|
||||
}
|
||||
}
|
||||
out_pred[0] = psum;
|
||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||
out_pred[stride * (i + 1)] = vec_psum[i];
|
||||
}
|
||||
return psum;
|
||||
}
|
||||
// --- data structure ---
|
||||
/*! \brief training parameters */
|
||||
@ -291,14 +303,17 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
* suppose we have n instance and k group, output will be k*n
|
||||
*/
|
||||
int num_output_group;
|
||||
/*! \brief size of leaf vector needed in tree */
|
||||
int size_leaf_vector;
|
||||
/*! \brief reserved parameters */
|
||||
int reserved[32];
|
||||
int reserved[31];
|
||||
/*! \brief constructor */
|
||||
ModelParam(void) {
|
||||
num_trees = 0;
|
||||
num_roots = num_feature = 0;
|
||||
num_pbuffer = 0;
|
||||
num_output_group = 1;
|
||||
size_leaf_vector = 0;
|
||||
memset(reserved, 0, sizeof(reserved));
|
||||
}
|
||||
/*!
|
||||
@ -311,10 +326,11 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
if (!strcmp("num_output_group", name)) num_output_group = atol(val);
|
||||
if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);
|
||||
if (!strcmp("bst:num_feature", name)) num_feature = atoi(val);
|
||||
if (!strcmp("bst:size_leaf_vector", name)) size_leaf_vector = atoi(val);
|
||||
}
|
||||
/*! \return size of prediction buffer actually needed */
|
||||
inline size_t PredBufferSize(void) const {
|
||||
return num_output_group * num_pbuffer;
|
||||
return num_output_group * num_pbuffer * (size_leaf_vector + 1);
|
||||
}
|
||||
/*!
|
||||
* \brief get the buffer offset given a buffer index and group id
|
||||
@ -323,7 +339,7 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
|
||||
if (buffer_index < 0) return -1;
|
||||
utils::Check(buffer_index < num_pbuffer, "buffer_index exceed num_pbuffer");
|
||||
return buffer_index + num_pbuffer * bst_group;
|
||||
return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
|
||||
}
|
||||
};
|
||||
// training parameter
|
||||
|
||||
@ -24,9 +24,10 @@ template<typename Derived>
|
||||
struct EvalEWiseBase : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
@ -99,6 +100,45 @@ struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief ctest */
|
||||
struct EvalCTest: public IEvaluator {
|
||||
EvalCTest(IEvaluator *base, const char *name)
|
||||
: base_(base), name_(name) {}
|
||||
virtual ~EvalCTest(void) {
|
||||
delete base_;
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
size_t ngroup = preds.size() / info.labels.size() - 1;
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
utils::Check(ngroup > 1, "pred size does not meet requirement");
|
||||
utils::Check(ndata == info.info.fold_index.size(), "need fold index");
|
||||
double wsum = 0.0;
|
||||
for (size_t k = 0; k < ngroup; ++k) {
|
||||
std::vector<float> tpred;
|
||||
MetaInfo tinfo;
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
if (info.info.fold_index[i] == k) {
|
||||
tpred.push_back(preds[i + (k + 1) * ndata]);
|
||||
tinfo.labels.push_back(info.labels[i]);
|
||||
tinfo.weights.push_back(info.GetWeight(i));
|
||||
}
|
||||
}
|
||||
wsum += base_->Eval(tpred, tinfo);
|
||||
}
|
||||
return wsum / ngroup;
|
||||
}
|
||||
|
||||
private:
|
||||
IEvaluator *base_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public IEvaluator {
|
||||
public:
|
||||
@ -109,7 +149,7 @@ struct EvalAMS : public IEvaluator {
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
utils::Check(info.weights.size() == ndata, "we need weight to evaluate ams");
|
||||
std::vector< std::pair<float, unsigned> > rec(ndata);
|
||||
|
||||
@ -206,10 +246,14 @@ struct EvalPrecisionRatio : public IEvaluator{
|
||||
struct EvalAuc : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
utils::Check(preds.size() == info.labels.size(), "label size predict size not match");
|
||||
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(preds.size());
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label size predict size not match");
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels.size());
|
||||
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
utils::Check(gptr.back() == preds.size(),
|
||||
utils::Check(gptr.back() == info.labels.size(),
|
||||
"EvalAuc: group structure must match number of prediction");
|
||||
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
||||
// sum statictis
|
||||
|
||||
@ -45,7 +45,9 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
|
||||
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
|
||||
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
|
||||
if (!strncmp(name, "ndcg", 4)) return new EvalNDCG(name);
|
||||
if (!strncmp(name, "ct-", 3)) return new EvalCTest(CreateEvaluator(name+3), name);
|
||||
|
||||
utils::Error("unknown evaluation metric type: %s", name);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -123,7 +123,7 @@ class RegLossObj : public IObjFunction{
|
||||
float p = loss.PredTransform(preds[i]);
|
||||
float w = info.GetWeight(j);
|
||||
if (info.labels[j] == 1.0f) w *= scale_pos_weight;
|
||||
gpair[j] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
|
||||
gpair[i] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
|
||||
loss.SecondOrderGradient(p, info.labels[j]) * w);
|
||||
}
|
||||
}
|
||||
|
||||
@ -270,6 +270,7 @@ class TreeModel {
|
||||
param.num_nodes = param.num_roots;
|
||||
nodes.resize(param.num_nodes);
|
||||
stats.resize(param.num_nodes);
|
||||
leaf_vector.resize(param.num_nodes * param.size_leaf_vector, 0.0f);
|
||||
for (int i = 0; i < param.num_nodes; i ++) {
|
||||
nodes[i].set_leaf(0.0f);
|
||||
nodes[i].set_parent(-1);
|
||||
|
||||
134
src/tree/param.h
134
src/tree/param.h
@ -22,10 +22,10 @@ struct TrainParam{
|
||||
//----- the rest parameters are less important ----
|
||||
// minimum amount of hessian(weight) allowed in a child
|
||||
float min_child_weight;
|
||||
// weight decay parameter used to control leaf fitting
|
||||
// L2 regularization factor
|
||||
float reg_lambda;
|
||||
// reg method
|
||||
int reg_method;
|
||||
// L1 regularization factor
|
||||
float reg_alpha;
|
||||
// default direction choice
|
||||
int default_direction;
|
||||
// whether we want to do subsample
|
||||
@ -36,6 +36,8 @@ struct TrainParam{
|
||||
float colsample_bytree;
|
||||
// speed optimization for dense column
|
||||
float opt_dense_col;
|
||||
// leaf vector size
|
||||
int size_leaf_vector;
|
||||
// number of threads to be used for tree construction,
|
||||
// if OpenMP is enabled, if equals 0, use system default
|
||||
int nthread;
|
||||
@ -45,13 +47,14 @@ struct TrainParam{
|
||||
min_child_weight = 1.0f;
|
||||
max_depth = 6;
|
||||
reg_lambda = 1.0f;
|
||||
reg_method = 2;
|
||||
reg_alpha = 0.0f;
|
||||
default_direction = 0;
|
||||
subsample = 1.0f;
|
||||
colsample_bytree = 1.0f;
|
||||
colsample_bylevel = 1.0f;
|
||||
opt_dense_col = 1.0f;
|
||||
nthread = 0;
|
||||
size_leaf_vector = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
@ -63,15 +66,17 @@ struct TrainParam{
|
||||
if (!strcmp(name, "gamma")) min_split_loss = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "eta")) learning_rate = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "lambda")) reg_lambda = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "alpha")) reg_alpha = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "reg_method")) reg_method = atoi(val);
|
||||
if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
|
||||
if (!strcmp(name, "max_depth")) max_depth = atoi(val);
|
||||
if (!strcmp(name, "nthread")) nthread = atoi(val);
|
||||
if (!strcmp(name, "default_direction")) {
|
||||
@ -82,31 +87,31 @@ struct TrainParam{
|
||||
}
|
||||
// calculate the cost of loss function
|
||||
inline double CalcGain(double sum_grad, double sum_hess) const {
|
||||
if (sum_hess < min_child_weight) {
|
||||
return 0.0;
|
||||
if (sum_hess < min_child_weight) return 0.0;
|
||||
if (reg_alpha == 0.0f) {
|
||||
return Sqr(sum_grad) / (sum_hess + reg_lambda);
|
||||
} else {
|
||||
return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
|
||||
}
|
||||
switch (reg_method) {
|
||||
case 1 : return Sqr(ThresholdL1(sum_grad, reg_lambda)) / sum_hess;
|
||||
case 2 : return Sqr(sum_grad) / (sum_hess + reg_lambda);
|
||||
case 3 : return
|
||||
Sqr(ThresholdL1(sum_grad, 0.5 * reg_lambda)) /
|
||||
(sum_hess + 0.5 * reg_lambda);
|
||||
default: return Sqr(sum_grad) / sum_hess;
|
||||
}
|
||||
// calculate cost of loss function with four stati
|
||||
inline double CalcGain(double sum_grad, double sum_hess,
|
||||
double test_grad, double test_hess) const {
|
||||
double w = CalcWeight(sum_grad, sum_hess);
|
||||
double ret = test_grad * w + 0.5 * (test_hess + reg_lambda) * Sqr(w);
|
||||
if (reg_alpha == 0.0f) {
|
||||
return - 2.0 * ret;
|
||||
} else {
|
||||
return - 2.0 * (ret + reg_alpha * std::abs(w));
|
||||
}
|
||||
}
|
||||
// calculate weight given the statistics
|
||||
inline double CalcWeight(double sum_grad, double sum_hess) const {
|
||||
if (sum_hess < min_child_weight) {
|
||||
return 0.0;
|
||||
if (sum_hess < min_child_weight) return 0.0;
|
||||
if (reg_alpha == 0.0f) {
|
||||
return -sum_grad / (sum_hess + reg_lambda);
|
||||
} else {
|
||||
switch (reg_method) {
|
||||
case 1: return - ThresholdL1(sum_grad, reg_lambda) / sum_hess;
|
||||
case 2: return - sum_grad / (sum_hess + reg_lambda);
|
||||
case 3: return
|
||||
- ThresholdL1(sum_grad, 0.5 * reg_lambda) /
|
||||
(sum_hess + 0.5 * reg_lambda);
|
||||
default: return - sum_grad / sum_hess;
|
||||
}
|
||||
return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
|
||||
}
|
||||
}
|
||||
/*! \brief whether need forward small to big search: default right */
|
||||
@ -153,6 +158,9 @@ struct GradStats {
|
||||
inline void Clear(void) {
|
||||
sum_grad = sum_hess = 0.0f;
|
||||
}
|
||||
/*! \brief check if necessary information is ready */
|
||||
inline static void CheckInfo(const BoosterInfo &info) {
|
||||
}
|
||||
/*!
|
||||
* \brief accumulate statistics,
|
||||
* \param gpair the vector storing the gradient statistics
|
||||
@ -188,14 +196,88 @@ struct GradStats {
|
||||
}
|
||||
/*! \brief set leaf vector value based on statistics */
|
||||
inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const{
|
||||
}
|
||||
protected:
|
||||
}
|
||||
// constructor to allow inheritance
|
||||
GradStats(void) {}
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(double grad, double hess) {
|
||||
sum_grad += grad; sum_hess += hess;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief vectorized cv statistics */
|
||||
template<unsigned vsize>
|
||||
struct CVGradStats : public GradStats {
|
||||
// additional statistics
|
||||
GradStats train[vsize], valid[vsize];
|
||||
// constructor
|
||||
explicit CVGradStats(const TrainParam ¶m) {
|
||||
utils::Check(param.size_leaf_vector == vsize,
|
||||
"CVGradStats: vsize must match size_leaf_vector");
|
||||
this->Clear();
|
||||
}
|
||||
/*! \brief check if necessary information is ready */
|
||||
inline static void CheckInfo(const BoosterInfo &info) {
|
||||
utils::Check(info.fold_index.size() != 0,
|
||||
"CVGradStats: require fold_index");
|
||||
}
|
||||
/*! \brief clear the statistics */
|
||||
inline void Clear(void) {
|
||||
GradStats::Clear();
|
||||
for (unsigned i = 0; i < vsize; ++i) {
|
||||
train[i].Clear(); valid[i].Clear();
|
||||
}
|
||||
}
|
||||
inline void Add(const std::vector<bst_gpair> &gpair,
|
||||
const BoosterInfo &info,
|
||||
bst_uint ridx) {
|
||||
GradStats::Add(gpair[ridx].grad, gpair[ridx].hess);
|
||||
const size_t step = info.fold_index.size();
|
||||
for (unsigned i = 0; i < vsize; ++i) {
|
||||
const bst_gpair &b = gpair[(i + 1) * step + ridx];
|
||||
if (info.fold_index[ridx] == i) {
|
||||
valid[i].Add(b.grad, b.hess);
|
||||
} else {
|
||||
train[i].Add(b.grad, b.hess);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief calculate gain of the solution */
|
||||
inline double CalcGain(const TrainParam ¶m) const {
|
||||
double ret = 0.0;
|
||||
for (unsigned i = 0; i < vsize; ++i) {
|
||||
ret += param.CalcGain(train[i].sum_grad,
|
||||
train[i].sum_hess,
|
||||
vsize * valid[i].sum_grad,
|
||||
vsize * valid[i].sum_hess);
|
||||
}
|
||||
return ret / vsize;
|
||||
}
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(const CVGradStats &b) {
|
||||
GradStats::Add(b);
|
||||
for (unsigned i = 0; i < vsize; ++i) {
|
||||
train[i].Add(b.train[i]);
|
||||
valid[i].Add(b.valid[i]);
|
||||
}
|
||||
}
|
||||
/*! \brief set current value to a - b */
|
||||
inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) {
|
||||
GradStats::SetSubstract(a, b);
|
||||
for (int i = 0; i < vsize; ++i) {
|
||||
train[i].SetSubstract(a.train[i], b.train[i]);
|
||||
valid[i].SetSubstract(a.valid[i], b.valid[i]);
|
||||
}
|
||||
}
|
||||
/*! \brief set leaf vector value based on statistics */
|
||||
inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const{
|
||||
for (int i = 0; i < vsize; ++i) {
|
||||
vec[i] = param.learning_rate *
|
||||
param.CalcWeight(train[i].sum_grad, train[i].sum_hess);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief statistics that is helpful to store
|
||||
* and represent a split solution for the tree
|
||||
|
||||
@ -62,6 +62,8 @@ inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
|
||||
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
|
||||
if (!strcmp(name, "refresh")) return new TreeRefresher<FMatrix, GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker2")) return new ColMaker<FMatrix, CVGradStats<2> >();
|
||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker<FMatrix, CVGradStats<5> >();
|
||||
utils::Error("unknown updater:%s", name);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -27,6 +27,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
const FMatrix &fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
TStats::CheckInfo(info);
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
@ -81,7 +82,6 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
RegTree *p_tree) {
|
||||
this->InitData(gpair, fmat, info.root_index, *p_tree);
|
||||
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
|
||||
|
||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||
this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
|
||||
this->ResetPosition(this->qexpand, fmat, *p_tree);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user