add a indicator opt

This commit is contained in:
tqchen 2015-06-02 11:38:06 -07:00
parent bc7f6b37b0
commit e5dd894960
2 changed files with 9 additions and 7 deletions

View File

@ -155,12 +155,12 @@ struct TrainParam{
return dw;
}
/*! \brief whether need forward small to big search: default right */
inline bool need_forward_search(float col_density = 0.0f) const {
inline bool need_forward_search(float col_density, bool indicator) const {
return this->default_direction == 2 ||
(default_direction == 0 && (col_density < opt_dense_col));
(default_direction == 0 && (col_density < opt_dense_col) && !indicator);
}
/*! \brief whether need backward big to small search: default left */
inline bool need_backward_search(float col_density = 0.0f) const {
inline bool need_backward_search(float col_density, bool indicator) const {
return this->default_direction != 2;
}
/*! \brief given the loss change, whether we need to invode prunning */

View File

@ -234,8 +234,9 @@ class ColMaker: public IUpdater {
const IFMatrix &fmat,
const std::vector<bst_gpair> &gpair,
const BoosterInfo &info) {
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid));
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid));
const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind);
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid), ind);
const std::vector<int> &qexpand = qexpand_;
#pragma omp parallel
{
@ -530,11 +531,12 @@ class ColMaker: public IUpdater {
const bst_uint fid = batch.col_index[i];
const int tid = omp_get_thread_num();
const ColBatch::Inst c = batch[i];
if (param.need_forward_search(fmat.GetColDensity(fid))) {
const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue;
if (param.need_forward_search(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data, c.data + c.length, +1,
fid, gpair, info, stemp[tid]);
}
if (param.need_backward_search(fmat.GetColDensity(fid))) {
if (param.need_backward_search(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
fid, gpair, info, stemp[tid]);
}