fix dart bug (#1882)

This commit is contained in:
wxchan 2016-12-20 01:01:28 +08:00 committed by Tianqi Chen
parent fa97259d66
commit cee4aafb93

View File

@ -680,7 +680,7 @@ class Dart : public GBTree {
// normalize_type 1
float factor = 1.0 / (1.0 + lr);
for (size_t i = 0; i < idx_drop.size(); ++i) {
weight_drop[i] *= factor;
weight_drop[idx_drop[i]] *= factor;
}
for (size_t i = 0; i < size_new_trees; ++i) {
weight_drop.push_back(factor);
@ -689,7 +689,7 @@ class Dart : public GBTree {
// normalize_type 0
float factor = 1.0 * num_drop / (num_drop + lr);
for (size_t i = 0; i < idx_drop.size(); ++i) {
weight_drop[i] *= factor;
weight_drop[idx_drop[i]] *= factor;
}
for (size_t i = 0; i < size_new_trees; ++i) {
weight_drop.push_back(1.0 / (num_drop + lr));