Fix race condition in interaction constraint. (#4587)
* Split up the kernel to sync write. * QueryNode is no-longer used in Query, but kept for testing.
This commit is contained in:
parent
221e163185
commit
fdf27a5b82
@ -134,9 +134,9 @@ void FeatureInteractionConstraint::Configure(
|
||||
feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_);
|
||||
|
||||
// --- Initialize result buffers.
|
||||
output_buffer_bits_storage_.resize(n_features);
|
||||
output_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
|
||||
output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_));
|
||||
input_buffer_bits_storage_.resize(n_features);
|
||||
input_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
|
||||
input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_));
|
||||
result_buffer_.resize(n_features);
|
||||
s_result_buffer_ = dh::ToSpan(result_buffer_);
|
||||
@ -155,10 +155,10 @@ void FeatureInteractionConstraint::Reset() {
|
||||
}
|
||||
|
||||
__global__ void ClearBuffersKernel(
|
||||
BitField result_buffer_self, BitField result_buffer_input, BitField feature_buffer) {
|
||||
BitField result_buffer_output, BitField result_buffer_input) {
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < result_buffer_self.Size()) {
|
||||
result_buffer_self.Clear(tid);
|
||||
if (tid < result_buffer_output.Size()) {
|
||||
result_buffer_output.Clear(tid);
|
||||
}
|
||||
if (tid < result_buffer_input.Size()) {
|
||||
result_buffer_input.Clear(tid);
|
||||
@ -172,7 +172,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
|
||||
const int n_grids = static_cast<int>(
|
||||
dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
|
||||
ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
|
||||
output_buffer_bits_, input_buffer_bits_, feature_buffer_);
|
||||
output_buffer_bits_, input_buffer_bits_);
|
||||
}
|
||||
|
||||
common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
|
||||
@ -199,18 +199,18 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
|
||||
return {s_result_buffer_.data(), s_result_buffer_.data() + n_available};
|
||||
}
|
||||
|
||||
__global__ void QueryFeatureListKernel(common::Span<int32_t> feature_list_input,
|
||||
common::Span<int32_t> node_feature_list,
|
||||
BitField result_buffer_input,
|
||||
BitField result_buffer_output) {
|
||||
__global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
|
||||
BitField result_buffer_input) {
|
||||
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (tid < feature_list_input.size()) {
|
||||
result_buffer_input.Set(feature_list_input[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
if (tid < node_feature_list.size()) {
|
||||
result_buffer_output.Set(node_feature_list[tid]);
|
||||
}
|
||||
__global__ void QueryFeatureListKernel(BitField node_constraints,
|
||||
BitField result_buffer_input,
|
||||
BitField result_buffer_output) {
|
||||
result_buffer_output |= node_constraints;
|
||||
result_buffer_output &= result_buffer_input;
|
||||
}
|
||||
|
||||
@ -219,17 +219,19 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
|
||||
if (!has_constraint_ || nid == 0) {
|
||||
return feature_list;
|
||||
}
|
||||
auto selected = this->QueryNode(nid);
|
||||
|
||||
ClearBuffers();
|
||||
|
||||
BitField node_constraints = s_node_constraints_[nid];
|
||||
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
|
||||
|
||||
int constexpr kBlockThreads = 256;
|
||||
const int n_grids = static_cast<int>(
|
||||
dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
|
||||
SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);
|
||||
|
||||
QueryFeatureListKernel<<<n_grids, kBlockThreads>>>
|
||||
(feature_list,
|
||||
selected,
|
||||
input_buffer_bits_,
|
||||
output_buffer_bits_);
|
||||
QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
|
||||
node_constraints, input_buffer_bits_, output_buffer_bits_);
|
||||
|
||||
thrust::counting_iterator<int32_t> begin(0);
|
||||
thrust::counting_iterator<int32_t> end(result_buffer_.size());
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user