Cleanup on device sketch. (#5874)

* Remove old functions.

* Merge weighted and un-weighted into a common interface.
This commit is contained in:
Jiaming Yuan
2020-07-14 10:15:54 +08:00
committed by GitHub
parent 9f85e92602
commit dd445af56e
10 changed files with 97 additions and 209 deletions

View File

@@ -253,22 +253,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
<< "Weight size should equal to number of groups.";
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
auto it =
thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx + base_rowid) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
d_temp_weights[idx] = weights[group];
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid);
d_temp_weights[idx] = weights[group_idx];
});
} else {
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}