@@ -169,6 +169,7 @@ struct GPUHistMakerDevice {
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogram<GradientSumT> hist{};
|
||||
|
||||
dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair;
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::caching_device_vector<int> monotone_constraints;
|
||||
@@ -269,7 +270,13 @@ struct GPUHistMakerDevice {
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
|
||||
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
|
||||
if (d_gpair.size() != dh_gpair->Size()) {
|
||||
d_gpair.resize(dh_gpair->Size());
|
||||
}
|
||||
thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(),
|
||||
dh_gpair->ConstDevicePointer() + dh_gpair->Size(),
|
||||
d_gpair.begin());
|
||||
auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user