add a indicator opt
This commit is contained in:
parent
bc7f6b37b0
commit
e5dd894960
@ -155,12 +155,12 @@ struct TrainParam{
|
|||||||
return dw;
|
return dw;
|
||||||
}
|
}
|
||||||
/*! \brief whether need forward small to big search: default right */
|
/*! \brief whether need forward small to big search: default right */
|
||||||
inline bool need_forward_search(float col_density = 0.0f) const {
|
inline bool need_forward_search(float col_density, bool indicator) const {
|
||||||
return this->default_direction == 2 ||
|
return this->default_direction == 2 ||
|
||||||
(default_direction == 0 && (col_density < opt_dense_col));
|
(default_direction == 0 && (col_density < opt_dense_col) && !indicator);
|
||||||
}
|
}
|
||||||
/*! \brief whether need backward big to small search: default left */
|
/*! \brief whether need backward big to small search: default left */
|
||||||
inline bool need_backward_search(float col_density = 0.0f) const {
|
inline bool need_backward_search(float col_density, bool indicator) const {
|
||||||
return this->default_direction != 2;
|
return this->default_direction != 2;
|
||||||
}
|
}
|
||||||
/*! \brief given the loss change, whether we need to invode prunning */
|
/*! \brief given the loss change, whether we need to invode prunning */
|
||||||
|
|||||||
@ -234,8 +234,9 @@ class ColMaker: public IUpdater {
|
|||||||
const IFMatrix &fmat,
|
const IFMatrix &fmat,
|
||||||
const std::vector<bst_gpair> &gpair,
|
const std::vector<bst_gpair> &gpair,
|
||||||
const BoosterInfo &info) {
|
const BoosterInfo &info) {
|
||||||
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid));
|
const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
|
||||||
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid));
|
bool need_forward = param.need_forward_search(fmat.GetColDensity(fid), ind);
|
||||||
|
bool need_backward = param.need_backward_search(fmat.GetColDensity(fid), ind);
|
||||||
const std::vector<int> &qexpand = qexpand_;
|
const std::vector<int> &qexpand = qexpand_;
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
@ -530,11 +531,12 @@ class ColMaker: public IUpdater {
|
|||||||
const bst_uint fid = batch.col_index[i];
|
const bst_uint fid = batch.col_index[i];
|
||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
const ColBatch::Inst c = batch[i];
|
const ColBatch::Inst c = batch[i];
|
||||||
if (param.need_forward_search(fmat.GetColDensity(fid))) {
|
const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue;
|
||||||
|
if (param.need_forward_search(fmat.GetColDensity(fid), ind)) {
|
||||||
this->EnumerateSplit(c.data, c.data + c.length, +1,
|
this->EnumerateSplit(c.data, c.data + c.length, +1,
|
||||||
fid, gpair, info, stemp[tid]);
|
fid, gpair, info, stemp[tid]);
|
||||||
}
|
}
|
||||||
if (param.need_backward_search(fmat.GetColDensity(fid))) {
|
if (param.need_backward_search(fmat.GetColDensity(fid), ind)) {
|
||||||
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
|
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
|
||||||
fid, gpair, info, stemp[tid]);
|
fid, gpair, info, stemp[tid]);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user