Cleanup on device sketch. (#5874)
* Remove old functions. * Merge weighted and un-weighted into a common interface.
This commit is contained in:
@@ -228,31 +228,6 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin, common::Span<size_t> row_counts_span,
|
||||
size_t row_stride) {
|
||||
common::HistogramCuts cuts =
|
||||
common::AdapterDeviceSketch(adapter, max_bin, missing);
|
||||
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
|
||||
auto& batch = adapter->Value();
|
||||
|
||||
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
|
||||
adapter->NumRows());
|
||||
CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
|
||||
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
|
||||
}
|
||||
|
||||
#define ELLPACK_SPECIALIZATION(__ADAPTER_T) \
|
||||
template EllpackPageImpl::EllpackPageImpl( \
|
||||
__ADAPTER_T* adapter, float missing, bool is_dense, int nthread, int max_bin, \
|
||||
common::Span<size_t> row_counts_span, \
|
||||
size_t row_stride);
|
||||
|
||||
ELLPACK_SPECIALIZATION(data::CudfAdapter)
|
||||
ELLPACK_SPECIALIZATION(data::CupyAdapter)
|
||||
|
||||
|
||||
template <typename AdapterBatch>
|
||||
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
|
||||
bool is_dense, int nthread,
|
||||
|
||||
@@ -159,12 +159,6 @@ class EllpackPageImpl {
|
||||
*/
|
||||
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
|
||||
|
||||
template <typename AdapterT>
|
||||
explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin,
|
||||
common::Span<size_t> row_counts_span,
|
||||
size_t row_stride);
|
||||
|
||||
template <typename AdapterBatch>
|
||||
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
|
||||
common::Span<size_t> row_counts_span,
|
||||
|
||||
@@ -75,8 +75,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
proxy->Info().weights_.SetDevice(device);
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
|
||||
proxy->Info(), missing, p_sketch);
|
||||
common::AdapterDeviceSketch(value, batch_param_.max_bin,
|
||||
proxy->Info(), missing, p_sketch);
|
||||
});
|
||||
|
||||
auto batch_rows = num_rows();
|
||||
|
||||
Reference in New Issue
Block a user