Cleanup on device sketch. (#5874)

* Remove old functions.

* Merge weighted and un-weighted into a common interface.
This commit is contained in:
Jiaming Yuan
2020-07-14 10:15:54 +08:00
committed by GitHub
parent 9f85e92602
commit dd445af56e
10 changed files with 97 additions and 209 deletions

View File

@@ -228,31 +228,6 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx,
});
}
template <typename AdapterT>
EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
int max_bin, common::Span<size_t> row_counts_span,
size_t row_stride) {
common::HistogramCuts cuts =
common::AdapterDeviceSketch(adapter, max_bin, missing);
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
auto& batch = adapter->Value();
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
adapter->NumRows());
CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
}
#define ELLPACK_SPECIALIZATION(__ADAPTER_T) \
template EllpackPageImpl::EllpackPageImpl( \
__ADAPTER_T* adapter, float missing, bool is_dense, int nthread, int max_bin, \
common::Span<size_t> row_counts_span, \
size_t row_stride);
ELLPACK_SPECIALIZATION(data::CudfAdapter)
ELLPACK_SPECIALIZATION(data::CupyAdapter)
template <typename AdapterBatch>
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
bool is_dense, int nthread,

View File

@@ -159,12 +159,6 @@ class EllpackPageImpl {
*/
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
template <typename AdapterT>
explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
int max_bin,
common::Span<size_t> row_counts_span,
size_t row_stride);
template <typename AdapterBatch>
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
common::Span<size_t> row_counts_span,

View File

@@ -75,8 +75,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
auto* p_sketch = &sketch_containers.back();
proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
common::AdapterDeviceSketch(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
});
auto batch_rows = num_rows();