From e5dd894960b6735336df90335d8f7da753b95553 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Jun 2015 11:38:06 -0700 Subject: [PATCH] add a indicator opt --- src/tree/param.h | 6 +++--- src/tree/updater_colmaker-inl.hpp | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/tree/param.h b/src/tree/param.h index 1bffcb32c..20ba1e6c0 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -155,12 +155,12 @@ struct TrainParam{ return dw; } /*! \brief whether need forward small to big search: default right */ - inline bool need_forward_search(float col_density = 0.0f) const { + inline bool need_forward_search(float col_density, bool indicator) const { return this->default_direction == 2 || - (default_direction == 0 && (col_density < opt_dense_col)); + (default_direction == 0 && (col_density < opt_dense_col) && !indicator); } /*! \brief whether need backward big to small search: default left */ - inline bool need_backward_search(float col_density = 0.0f) const { + inline bool need_backward_search(float col_density, bool indicator) const { return this->default_direction != 2; } /*! \brief given the loss change, whether we need to invode prunning */ diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index b52842a93..db3581aac 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -234,8 +234,9 @@ class ColMaker: public IUpdater { const IFMatrix &fmat, const std::vector &gpair, const BoosterInfo &info) { - bool need_forward = param.need_forward_search(fmat.GetColDensity(fid)); - bool need_backward = param.need_backward_search(fmat.GetColDensity(fid)); + const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue; + bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind); + bool need_backward = param.need_backward_search(fmat.GetColDensity(fid), ind); const std::vector &qexpand = qexpand_; #pragma omp parallel { @@ -530,11 +531,12 @@ class ColMaker: public IUpdater { const bst_uint fid = batch.col_index[i]; const int tid = omp_get_thread_num(); const ColBatch::Inst c = batch[i]; - if (param.need_forward_search(fmat.GetColDensity(fid))) { + const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue; + if (param.need_forward_search(fmat.GetColDensity(fid), ind)) { this->EnumerateSplit(c.data, c.data + c.length, +1, fid, gpair, info, stemp[tid]); } - if (param.need_backward_search(fmat.GetColDensity(fid))) { + if (param.need_backward_search(fmat.GetColDensity(fid), ind)) { this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1, fid, gpair, info, stemp[tid]); }