Fix race condition in interaction constraint. (#4587)

* Split up the kernel to synchronize writes.

* QueryNode is no longer used in Query, but is kept for testing.
This commit is contained in:
Jiaming Yuan 2019-06-21 02:47:48 +08:00 committed by GitHub
parent 221e163185
commit fdf27a5b82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -134,9 +134,9 @@ void FeatureInteractionConstraint::Configure(
feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_); feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_);
// --- Initialize result buffers. // --- Initialize result buffers.
output_buffer_bits_storage_.resize(n_features); output_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_)); output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_));
input_buffer_bits_storage_.resize(n_features); input_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_)); input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_));
result_buffer_.resize(n_features); result_buffer_.resize(n_features);
s_result_buffer_ = dh::ToSpan(result_buffer_); s_result_buffer_ = dh::ToSpan(result_buffer_);
@ -155,10 +155,10 @@ void FeatureInteractionConstraint::Reset() {
} }
__global__ void ClearBuffersKernel( __global__ void ClearBuffersKernel(
BitField result_buffer_self, BitField result_buffer_input, BitField feature_buffer) { BitField result_buffer_output, BitField result_buffer_input) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < result_buffer_self.Size()) { if (tid < result_buffer_output.Size()) {
result_buffer_self.Clear(tid); result_buffer_output.Clear(tid);
} }
if (tid < result_buffer_input.Size()) { if (tid < result_buffer_input.Size()) {
result_buffer_input.Clear(tid); result_buffer_input.Clear(tid);
@ -172,7 +172,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
const int n_grids = static_cast<int>( const int n_grids = static_cast<int>(
dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads)); dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
ClearBuffersKernel<<<n_grids, kBlockThreads>>>( ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
output_buffer_bits_, input_buffer_bits_, feature_buffer_); output_buffer_bits_, input_buffer_bits_);
} }
common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) { common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
@ -199,18 +199,18 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
return {s_result_buffer_.data(), s_result_buffer_.data() + n_available}; return {s_result_buffer_.data(), s_result_buffer_.data() + n_available};
} }
__global__ void QueryFeatureListKernel(common::Span<int32_t> feature_list_input, __global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
common::Span<int32_t> node_feature_list, BitField result_buffer_input) {
BitField result_buffer_input,
BitField result_buffer_output) {
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x; uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < feature_list_input.size()) { if (tid < feature_list_input.size()) {
result_buffer_input.Set(feature_list_input[tid]); result_buffer_input.Set(feature_list_input[tid]);
} }
if (tid < node_feature_list.size()) {
result_buffer_output.Set(node_feature_list[tid]);
} }
__global__ void QueryFeatureListKernel(BitField node_constraints,
BitField result_buffer_input,
BitField result_buffer_output) {
result_buffer_output |= node_constraints;
result_buffer_output &= result_buffer_input; result_buffer_output &= result_buffer_input;
} }
@ -219,17 +219,19 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
if (!has_constraint_ || nid == 0) { if (!has_constraint_ || nid == 0) {
return feature_list; return feature_list;
} }
auto selected = this->QueryNode(nid);
ClearBuffers();
BitField node_constraints = s_node_constraints_[nid];
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size()); CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
int constexpr kBlockThreads = 256; int constexpr kBlockThreads = 256;
const int n_grids = static_cast<int>( const int n_grids = static_cast<int>(
dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads)); dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);
QueryFeatureListKernel<<<n_grids, kBlockThreads>>> QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
(feature_list, node_constraints, input_buffer_bits_, output_buffer_bits_);
selected,
input_buffer_bits_,
output_buffer_bits_);
thrust::counting_iterator<int32_t> begin(0); thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size()); thrust::counting_iterator<int32_t> end(result_buffer_.size());