add a indicator opt
This commit is contained in:
parent
bc7f6b37b0
commit
e5dd894960
@ -155,12 +155,12 @@ struct TrainParam{
|
||||
return dw;
|
||||
}
|
||||
/*! \brief whether need forward small to big search: default right */
|
||||
inline bool need_forward_search(float col_density = 0.0f) const {
|
||||
inline bool need_forward_search(float col_density, bool indicator) const {
|
||||
return this->default_direction == 2 ||
|
||||
(default_direction == 0 && (col_density < opt_dense_col));
|
||||
(default_direction == 0 && (col_density < opt_dense_col) && !indicator);
|
||||
}
|
||||
/*! \brief whether need backward big to small search: default left */
|
||||
inline bool need_backward_search(float col_density = 0.0f) const {
|
||||
inline bool need_backward_search(float col_density, bool indicator) const {
|
||||
return this->default_direction != 2;
|
||||
}
|
||||
/*! \brief given the loss change, whether we need to invode prunning */
|
||||
|
||||
@ -234,8 +234,9 @@ class ColMaker: public IUpdater {
|
||||
const IFMatrix &fmat,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
const BoosterInfo &info) {
|
||||
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid));
|
||||
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid));
|
||||
const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
|
||||
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind);
|
||||
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid), ind);
|
||||
const std::vector<int> &qexpand = qexpand_;
|
||||
#pragma omp parallel
|
||||
{
|
||||
@ -530,11 +531,12 @@ class ColMaker: public IUpdater {
|
||||
const bst_uint fid = batch.col_index[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
const ColBatch::Inst c = batch[i];
|
||||
if (param.need_forward_search(fmat.GetColDensity(fid))) {
|
||||
const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue;
|
||||
if (param.need_forward_search(fmat.GetColDensity(fid), ind)) {
|
||||
this->EnumerateSplit(c.data, c.data + c.length, +1,
|
||||
fid, gpair, info, stemp[tid]);
|
||||
}
|
||||
if (param.need_backward_search(fmat.GetColDensity(fid))) {
|
||||
if (param.need_backward_search(fmat.GetColDensity(fid), ind)) {
|
||||
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
|
||||
fid, gpair, info, stemp[tid]);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user