Fixed issue 3605. (#3628)

* Fixed issue 3605.

- https://github.com/dmlc/xgboost/issues/3605

* Fixed the bug in a better way.

* Added a test to catch the bug.

* Fixed linter errors.
This commit is contained in:
Andy Adinets
2018-08-28 19:50:52 +02:00
committed by Philip Hyunsu Cho
parent 78bea0d204
commit 58d783df16
3 changed files with 18 additions and 8 deletions

View File

@@ -257,13 +257,13 @@ struct GPUSketcher {
n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
// if less elements than cuts: copy all elements with their weights
if (n_cuts_ > n_unique) {
auto weights2_iter = weights2_.begin();
auto fvalues_iter = fvalues_cur_.begin();
auto cuts_iter = cuts_d_.begin() + icol * n_cuts_;
float* weights2_ptr = weights2_.data().get();
float* fvalues_ptr = fvalues_cur_.data().get();
WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
bst_float rmax = weights2_iter[i];
bst_float rmin = i > 0 ? weights2_iter[i - 1] : 0;
cuts_iter[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_iter[i]);
bst_float rmax = weights2_ptr[i];
bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0;
cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
});
} else if (n_cuts_cur_[icol] > 0) {
// if more elements than cuts: use binary search on cumulative weights