add a indicator opt

This commit is contained in:
tqchen 2015-06-02 11:38:06 -07:00
parent bc7f6b37b0
commit e5dd894960
2 changed files with 9 additions and 7 deletions

View File

@ -155,12 +155,12 @@ struct TrainParam{
return dw; return dw;
} }
/*! \brief whether need forward small to big search: default right */ /*! \brief whether need forward small to big search: default right */
inline bool need_forward_search(float col_density = 0.0f) const { inline bool need_forward_search(float col_density, bool indicator) const {
return this->default_direction == 2 || return this->default_direction == 2 ||
(default_direction == 0 && (col_density < opt_dense_col)); (default_direction == 0 && (col_density < opt_dense_col) && !indicator);
} }
/*! \brief whether need backward big to small search: default left */ /*! \brief whether need backward big to small search: default left */
inline bool need_backward_search(float col_density = 0.0f) const { inline bool need_backward_search(float col_density, bool indicator) const {
return this->default_direction != 2; return this->default_direction != 2;
} }
/*! \brief given the loss change, whether we need to invode prunning */ /*! \brief given the loss change, whether we need to invode prunning */

View File

@ -234,8 +234,9 @@ class ColMaker: public IUpdater {
const IFMatrix &fmat, const IFMatrix &fmat,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
const BoosterInfo &info) { const BoosterInfo &info) {
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid)); const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid)); bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind);
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid), ind);
const std::vector<int> &qexpand = qexpand_; const std::vector<int> &qexpand = qexpand_;
#pragma omp parallel #pragma omp parallel
{ {
@ -530,11 +531,12 @@ class ColMaker: public IUpdater {
const bst_uint fid = batch.col_index[i]; const bst_uint fid = batch.col_index[i];
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
const ColBatch::Inst c = batch[i]; const ColBatch::Inst c = batch[i];
if (param.need_forward_search(fmat.GetColDensity(fid))) { const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue;
if (param.need_forward_search(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data, c.data + c.length, +1, this->EnumerateSplit(c.data, c.data + c.length, +1,
fid, gpair, info, stemp[tid]); fid, gpair, info, stemp[tid]);
} }
if (param.need_backward_search(fmat.GetColDensity(fid))) { if (param.need_backward_search(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1, this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
fid, gpair, info, stemp[tid]); fid, gpair, info, stemp[tid]);
} }