Fix race condition in interaction constraint. (#4587)

* Split up the kernel to synchronize writes.

* QueryNode is no longer used in Query, but is kept for testing.
This commit is contained in:
Jiaming Yuan 2019-06-21 02:47:48 +08:00 committed by GitHub
parent 221e163185
commit fdf27a5b82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -134,9 +134,9 @@ void FeatureInteractionConstraint::Configure(
feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_);
// --- Initialize result buffers.
output_buffer_bits_storage_.resize(n_features);
output_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_));
input_buffer_bits_storage_.resize(n_features);
input_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_));
result_buffer_.resize(n_features);
s_result_buffer_ = dh::ToSpan(result_buffer_);
@@ -155,10 +155,10 @@ void FeatureInteractionConstraint::Reset() {
}
__global__ void ClearBuffersKernel(
BitField result_buffer_self, BitField result_buffer_input, BitField feature_buffer) {
BitField result_buffer_output, BitField result_buffer_input) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < result_buffer_self.Size()) {
result_buffer_self.Clear(tid);
if (tid < result_buffer_output.Size()) {
result_buffer_output.Clear(tid);
}
if (tid < result_buffer_input.Size()) {
result_buffer_input.Clear(tid);
@@ -172,7 +172,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
const int n_grids = static_cast<int>(
dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
output_buffer_bits_, input_buffer_bits_, feature_buffer_);
output_buffer_bits_, input_buffer_bits_);
}
common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
@@ -199,18 +199,18 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
return {s_result_buffer_.data(), s_result_buffer_.data() + n_available};
}
__global__ void QueryFeatureListKernel(common::Span<int32_t> feature_list_input,
common::Span<int32_t> node_feature_list,
BitField result_buffer_input,
BitField result_buffer_output) {
// Kernel: for every position in feature_list_input, calls Set() on
// result_buffer_input with the value stored at that position.
//
//   feature_list_input  - span of int32 values; entry `tid` is handed to Set().
//   result_buffer_input - BitField that is written through Set().
//
// Expects a 1-D launch with at least feature_list_input.size() threads in total;
// thread `tid` handles list entry `tid`.
// NOTE(review): different threads may target bits that share a storage word;
// presumably BitField::Set is atomic on device — confirm against BitField.
__global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
BitField result_buffer_input) {
// Flat global thread index for a 1-D grid of 1-D blocks.
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
// Threads whose index falls past the end of the list do nothing.
if (tid < feature_list_input.size()) {
result_buffer_input.Set(feature_list_input[tid]);
}
}
if (tid < node_feature_list.size()) {
result_buffer_output.Set(node_feature_list[tid]);
}
// Kernel: combines three bit fields in place on result_buffer_output.
// The output is first OR-ed with node_constraints and then AND-ed with
// result_buffer_input, i.e. output = (output | node_constraints) & input.
//
//   node_constraints     - BitField OR-ed into the output.
//   result_buffer_input  - BitField used as the AND mask.
//   result_buffer_output - BitField updated in place; its prior contents take
//                          part in the OR.
//
// NOTE(review): no thread index is computed in this kernel; presumably the
// BitField `|=` and `&=` operators split the work across threads internally
// when run on device — confirm against the BitField implementation.
__global__ void QueryFeatureListKernel(BitField node_constraints,
BitField result_buffer_input,
BitField result_buffer_output) {
result_buffer_output |= node_constraints;
result_buffer_output &= result_buffer_input;
}
@@ -219,17 +219,19 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
if (!has_constraint_ || nid == 0) {
return feature_list;
}
auto selected = this->QueryNode(nid);
ClearBuffers();
BitField node_constraints = s_node_constraints_[nid];
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
int constexpr kBlockThreads = 256;
const int n_grids = static_cast<int>(
dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);
QueryFeatureListKernel<<<n_grids, kBlockThreads>>>
(feature_list,
selected,
input_buffer_bits_,
output_buffer_bits_);
QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
node_constraints, input_buffer_bits_, output_buffer_bits_);
thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size());