Fix GPU RF (#6755)

* Fix sampling.
This commit is contained in:
Jiaming Yuan
2021-03-17 06:23:35 +08:00
committed by GitHub
parent 1a73a28511
commit 4f75f514ce
4 changed files with 26 additions and 9 deletions

View File

@@ -169,6 +169,7 @@ struct GPUHistMakerDevice {
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist{};
dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair;
common::Span<GradientPair> gpair;
dh::caching_device_vector<int> monotone_constraints;
@@ -269,7 +270,13 @@ struct GPUHistMakerDevice {
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
GradientPair());
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
if (d_gpair.size() != dh_gpair->Size()) {
d_gpair.resize(dh_gpair->Size());
}
thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(),
dh_gpair->ConstDevicePointer() + dh_gpair->Size(),
d_gpair.begin());
auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
page = sample.page;
gpair = sample.gpair;