GPUTreeShap (#6038)
This commit is contained in:
@@ -474,8 +474,18 @@ class TemporaryArray {
|
||||
using AllocT = XGBCachingDeviceAllocator<T>;
|
||||
using value_type = T; // NOLINT
|
||||
explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); }
|
||||
TemporaryArray(size_t n, T val) : size_(n) {
|
||||
ptr_ = AllocT().allocate(n);
|
||||
this->fill(val);
|
||||
}
|
||||
~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); }
|
||||
|
||||
void fill(T val) // NOLINT
|
||||
{
|
||||
int device = 0;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
auto d_data = ptr_.get();
|
||||
LaunchN(device, this->size(), [=] __device__(size_t idx) { d_data[idx] = val; });
|
||||
}
|
||||
thrust::device_ptr<T> data() { return ptr_; } // NOLINT
|
||||
size_t size() { return size_; } // NOLINT
|
||||
|
||||
|
||||
Reference in New Issue
Block a user