Fix evaluate root split. (#5948)

This commit is contained in:
Jiaming Yuan 2020-07-29 19:33:29 +08:00 committed by GitHub
parent 071e10c1d1
commit e4a273e1da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -228,13 +228,14 @@ void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
return 0; return 0;
}); });
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
auto num_segments = out_splits.size();
cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes, cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes,
feature_best_splits.data(), out_splits.data(), feature_best_splits.data(), out_splits.data(),
2, reduce_offset, reduce_offset + 1); num_segments, reduce_offset, reduce_offset + 1);
dh::TemporaryArray<int8_t> temp(temp_storage_bytes); dh::TemporaryArray<int8_t> temp(temp_storage_bytes);
cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes, cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes,
feature_best_splits.data(), out_splits.data(), feature_best_splits.data(), out_splits.data(),
2, reduce_offset, reduce_offset + 1); num_segments, reduce_offset, reduce_offset + 1);
} }
template <typename GradientSumT> template <typename GradientSumT>