add Dart booster (#1220)
This commit is contained in:
committed by
Tianqi Chen
parent
e034fdf74c
commit
949d1e3027
@@ -17,6 +17,8 @@
|
||||
#include <limits>
|
||||
#include "../common/common.h"
|
||||
|
||||
#include "../common/random.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
@@ -47,6 +49,42 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief training parameters */
|
||||
struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
|
||||
/*! \brief whether to not print info during training */
|
||||
bool silent;
|
||||
/*! \brief type of sampling algorithm */
|
||||
int sample_type;
|
||||
/*! \brief type of normalization algorithm */
|
||||
int normalize_type;
|
||||
/*! \brief how many trees are dropped */
|
||||
float rate_drop;
|
||||
/*! \brief whether to drop trees */
|
||||
float skip_drop;
|
||||
/*! \brief learning step size for a time */
|
||||
float learning_rate;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(DartTrainParam) {
|
||||
DMLC_DECLARE_FIELD(silent).set_default(false)
|
||||
.describe("Not print information during trainig.");
|
||||
DMLC_DECLARE_FIELD(sample_type).set_default(0)
|
||||
.add_enum("uniform", 0)
|
||||
.add_enum("weighted", 1)
|
||||
.describe("Different types of sampling algorithm.");
|
||||
DMLC_DECLARE_FIELD(normalize_type).set_default(0)
|
||||
.add_enum("tree", 0)
|
||||
.add_enum("forest", 1)
|
||||
.describe("Different types of normalization algorithm.");
|
||||
DMLC_DECLARE_FIELD(rate_drop).set_range(0.0f, 1.0f).set_default(0.0f)
|
||||
.describe("Parameter of how many trees are dropped.");
|
||||
DMLC_DECLARE_FIELD(skip_drop).set_range(0.0f, 1.0f).set_default(0.0f)
|
||||
.describe("Parameter of whether to drop trees.");
|
||||
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
|
||||
.describe("Learning rate(step size) of update.");
|
||||
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief model parameters */
|
||||
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
/*! \brief number of trees */
|
||||
@@ -313,8 +351,9 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
// commit new trees all at once
|
||||
inline void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
|
||||
int bst_group) {
|
||||
virtual void
|
||||
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
|
||||
int bst_group) {
|
||||
for (size_t i = 0; i < new_trees.size(); ++i) {
|
||||
trees.push_back(std::move(new_trees[i]));
|
||||
tree_info.push_back(bst_group);
|
||||
@@ -475,14 +514,236 @@ class GBTree : public GradientBooster {
|
||||
std::vector<std::unique_ptr<TreeUpdater> > updaters;
|
||||
};
|
||||
|
||||
// dart
|
||||
class Dart : public GBTree {
|
||||
public:
|
||||
Dart() {}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
|
||||
GBTree::Configure(cfg);
|
||||
if (trees.size() == 0) {
|
||||
dparam.InitAllowUnknown(cfg);
|
||||
}
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
GBTree::Load(fi);
|
||||
weight_drop.resize(mparam.num_trees);
|
||||
if (mparam.num_trees != 0) {
|
||||
fi->Read(&weight_drop);
|
||||
}
|
||||
}
|
||||
|
||||
void Save(dmlc::Stream* fo) const override {
|
||||
GBTree::Save(fo);
|
||||
if (weight_drop.size() != 0) {
|
||||
fo->Write(weight_drop);
|
||||
}
|
||||
}
|
||||
|
||||
// predict the leaf scores with dropout if ntree_limit = 0
|
||||
void Predict(DMatrix* p_fmat,
|
||||
int64_t buffer_offset,
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
DropTrees(ntree_limit);
|
||||
const MetaInfo& info = p_fmat->info();
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
InitThreadTemp(nthread);
|
||||
std::vector<float> &preds = *out_preds;
|
||||
const size_t stride = p_fmat->info().num_row * mparam.num_output_group;
|
||||
preds.resize(stride * (mparam.size_leaf_vector+1));
|
||||
// start collecting the prediction
|
||||
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
|
||||
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
// parallel over local batch
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec &feats = thread_temp[tid];
|
||||
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
CHECK_LT(static_cast<size_t>(ridx), info.num_row);
|
||||
// loop over output groups
|
||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||
this->Pred(batch[i],
|
||||
buffer_offset < 0 ? -1 : buffer_offset + ridx,
|
||||
gid, info.GetRoot(ridx), &feats,
|
||||
&preds[ridx * mparam.num_output_group + gid], stride,
|
||||
ntree_limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Predict(const SparseBatch::Inst& inst,
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
DropTrees(1);
|
||||
if (thread_temp.size() == 0) {
|
||||
thread_temp.resize(1, RegTree::FVec());
|
||||
thread_temp[0].Init(mparam.num_feature);
|
||||
}
|
||||
out_preds->resize(mparam.num_output_group * (mparam.size_leaf_vector+1));
|
||||
// loop over output groups
|
||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||
this->Pred(inst, -1, gid, root_index, &thread_temp[0],
|
||||
&(*out_preds)[gid], mparam.num_output_group,
|
||||
ntree_limit);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// commit new trees all at once
|
||||
virtual void
|
||||
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
|
||||
int bst_group) {
|
||||
for (size_t i = 0; i < new_trees.size(); ++i) {
|
||||
trees.push_back(std::move(new_trees[i]));
|
||||
tree_info.push_back(bst_group);
|
||||
}
|
||||
mparam.num_trees += static_cast<int>(new_trees.size());
|
||||
size_t num_drop = NormalizeTrees(new_trees.size());
|
||||
if (dparam.silent != 1) {
|
||||
LOG(INFO) << "drop " << num_drop << " trees, "
|
||||
<< "weight = " << weight_drop.back();
|
||||
}
|
||||
}
|
||||
// predict the leaf scores without dropped trees
|
||||
inline void Pred(const RowBatch::Inst &inst,
|
||||
int64_t buffer_index,
|
||||
int bst_group,
|
||||
unsigned root_index,
|
||||
RegTree::FVec *p_feats,
|
||||
float *out_pred,
|
||||
size_t stride,
|
||||
unsigned ntree_limit) {
|
||||
float psum = 0.0f;
|
||||
// sum of leaf vector
|
||||
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
|
||||
const int64_t bid = this->BufferOffset(buffer_index, bst_group);
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
bool drop = (std::find(idx_drop.begin(), idx_drop.end(), i) != idx_drop.end());
|
||||
if (!drop) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
|
||||
psum += weight_drop[i] * (*trees[i])[tid].leaf_value();
|
||||
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
|
||||
vec_psum[j] += weight_drop[i] * trees[i]->leafvec(tid)[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
// updated the buffered results
|
||||
if (bid >= 0 && ntree_limit == 0) {
|
||||
pred_counter[bid] = static_cast<unsigned>(trees.size());
|
||||
pred_buffer[bid] = psum;
|
||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||
pred_buffer[bid + i + 1] = vec_psum[i];
|
||||
}
|
||||
}
|
||||
out_pred[0] = psum;
|
||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||
out_pred[stride * (i + 1)] = vec_psum[i];
|
||||
}
|
||||
}
|
||||
|
||||
// select dropped trees
|
||||
inline void DropTrees(unsigned ntree_limit_drop) {
|
||||
std::uniform_real_distribution<> runif(0.0, 1.0);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
// reset
|
||||
idx_drop.clear();
|
||||
// sample dropped trees
|
||||
bool skip = false;
|
||||
if (dparam.skip_drop > 0.0) skip = (runif(rnd) < dparam.skip_drop);
|
||||
if (ntree_limit_drop == 0 && !skip) {
|
||||
if (dparam.sample_type == 1) {
|
||||
float sum_weight = 0.0;
|
||||
for (size_t i = 0; i < weight_drop.size(); ++i) {
|
||||
sum_weight += weight_drop[i];
|
||||
}
|
||||
for (size_t i = 0; i < weight_drop.size(); ++i) {
|
||||
if (runif(rnd) < dparam.rate_drop * weight_drop.size() * weight_drop[i] / sum_weight) {
|
||||
idx_drop.push_back(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < weight_drop.size(); ++i) {
|
||||
if (runif(rnd) < dparam.rate_drop) {
|
||||
idx_drop.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// set normalization factors
|
||||
inline size_t NormalizeTrees(size_t size_new_trees) {
|
||||
float lr = 1.0 * dparam.learning_rate / size_new_trees;
|
||||
size_t num_drop = idx_drop.size();
|
||||
if (num_drop == 0) {
|
||||
for (size_t i = 0; i < size_new_trees; ++i) {
|
||||
weight_drop.push_back(1.0);
|
||||
}
|
||||
} else {
|
||||
if (dparam.normalize_type == 1) {
|
||||
// normalize_type 1
|
||||
float factor = 1.0 / (1.0 + lr);
|
||||
for (size_t i = 0; i < idx_drop.size(); ++i) {
|
||||
weight_drop[i] *= factor;
|
||||
}
|
||||
for (size_t i = 0; i < size_new_trees; ++i) {
|
||||
weight_drop.push_back(lr * factor);
|
||||
}
|
||||
} else {
|
||||
// normalize_type 0
|
||||
float factor = 1.0 * num_drop / (num_drop + lr);
|
||||
for (size_t i = 0; i < idx_drop.size(); ++i) {
|
||||
weight_drop[i] *= factor;
|
||||
}
|
||||
for (size_t i = 0; i < size_new_trees; ++i) {
|
||||
weight_drop.push_back(1.0 * lr / (num_drop + lr));
|
||||
}
|
||||
}
|
||||
}
|
||||
// reset
|
||||
idx_drop.clear();
|
||||
return num_drop;
|
||||
}
|
||||
|
||||
// --- data structure ---
|
||||
// training parameter
|
||||
DartTrainParam dparam;
|
||||
/*! \brief prediction buffer */
|
||||
std::vector<float> weight_drop;
|
||||
// indexes of dropped trees
|
||||
std::vector<size_t> idx_drop;
|
||||
};
|
||||
|
||||
// register the ojective functions
|
||||
DMLC_REGISTER_PARAMETER(GBTreeModelParam);
|
||||
DMLC_REGISTER_PARAMETER(GBTreeTrainParam);
|
||||
DMLC_REGISTER_PARAMETER(DartTrainParam);
|
||||
|
||||
XGBOOST_REGISTER_GBM(GBTree, "gbtree")
|
||||
.describe("Tree booster, gradient boosted trees.")
|
||||
.set_body([]() {
|
||||
return new GBTree();
|
||||
});
|
||||
XGBOOST_REGISTER_GBM(Dart, "dart")
|
||||
.describe("Tree booster, dart.")
|
||||
.set_body([]() {
|
||||
return new Dart();
|
||||
});
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user