Cleanup on device sketch. (#5874)

* Remove old functions.

* Merge weighted and un-weighted into a common interface.
This commit is contained in:
Jiaming Yuan
2020-07-14 10:15:54 +08:00
committed by GitHub
parent 9f85e92602
commit dd445af56e
10 changed files with 97 additions and 209 deletions

View File

@@ -232,11 +232,8 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
thrust::make_constant_iterator(0lu),
[=]__device__(size_t idx) -> float {
auto ridx = batch.GetElement(idx).row_idx;
auto it = thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
return weights[group];
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
return weights[group_idx];
});
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
weight_iter + begin, weight_iter + end,
@@ -277,46 +274,12 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
/*
 * \brief Sketch the single batch held by an adapter and return the resulting cuts.
 *
 * The adapter is rewound, advanced to its first batch, and checked to contain no
 * further batch.  That batch is then fed to ProcessSlidingWindow in consecutive
 * [begin, end) element ranges, and the cuts are built from the filled container.
 *
 * \param adapter  Source of the batch.  Its row and column counts must not be
 *                 data::kAdapterUnknownSize and it must yield exactly one batch;
 *                 both conditions are enforced with CHECK.
 * \param num_bins Forwarded to detail::RequiredSampleCutsPerColumn and to the
 *                 SketchContainer constructor.
 * \param missing  Forwarded unchanged to ProcessSlidingWindow; presumably the value
 *                 that marks a missing entry -- confirm against that function.
 * \param sketch_batch_num_elements Requested window size.  It is passed through
 *                 detail::SketchBatchNumElements and the returned value is used as
 *                 the loop stride.
 *
 * \return The HistogramCuts filled in by SketchContainer::MakeCuts.
 */
template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
float missing,
size_t sketch_batch_num_elements = 0) {
// NOTE(review): NumRows() is consumed here before the CHECK below verifies that
// it is a known size.
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows());
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
// Rewind the adapter and move to its first batch.
adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
// Resolve the window size actually used for the loop below; the trailing `false`
// is the argument for which the weighted code paths in this file pass `true`.
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
adapter->NumRows(), adapter->NumColumns(), std::numeric_limits<size_t>::max(),
adapter->DeviceIdx(),
num_cuts_per_feature, false);
// Enforce single batch
CHECK(!adapter->Next());
HistogramCuts cuts;
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
adapter->NumRows(), adapter->DeviceIdx());
// Walk the batch in windows of sketch_batch_num_elements; `end` is clamped to
// batch.Size() so the last window may be shorter.
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
// NOTE(review): this declaration shadows the outer `batch` obtained before the
// loop; both come from adapter->Value().
auto const& batch = adapter->Value();
ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(),
begin, end, missing, &sketch_container, num_cuts_per_feature);
}
sketch_container.MakeCuts(&cuts);
return cuts;
}
/*
* \brief Perform sketching on GPU.
*
* \param batch A batch from adapter.
* \param num_bins Bins per column.
* \param info Metainfo used for sketching.
* \param missing Floating point value that represents an invalid (missing) value.
* \param sketch_container Container for output sketch.
* \param sketch_batch_num_elements Number of elements per sliding window, use it only for
@@ -324,51 +287,37 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
*/
template <typename Batch>
void AdapterDeviceSketch(Batch batch, int num_bins,
MetaInfo const& info,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols();
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, device, num_cols,
begin, end, missing, sketch_container, num_cuts_per_feature);
}
}
bool weighted = info.weights_.Size() != 0;
/*
* \brief Perform weighted sketching on GPU.
*
* When the weights in info are empty, this function is equivalent to the unweighted version.
*/
template <typename Batch>
void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
MetaInfo const& info,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
if (info.weights_.Size() == 0) {
return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements);
}
size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols();
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info,
num_cuts_per_feature,
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container);
if (weighted) {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info,
num_cuts_per_feature,
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container);
}
} else {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, device, num_cols,
begin, end, missing, sketch_container, num_cuts_per_feature);
}
}
}
} // namespace common