Fix potential race in feature constraint. (#10719)

This commit is contained in:
Jiaming Yuan 2024-08-21 16:50:31 +08:00 committed by GitHub
parent e9f1abc1f0
commit 402e7837fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 19 deletions

View File

@ -108,9 +108,11 @@ struct BitFieldContainer {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) { __device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t min_size = min(NumValues(), rhs.NumValues()); std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
if (tid < min_size) { if (tid < min_size) {
Data()[tid] |= rhs.Data()[tid]; if (this->Check(tid) || rhs.Check(tid)) {
this->Set(tid);
}
} }
return *this; return *this;
} }
@ -126,16 +128,20 @@ struct BitFieldContainer {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) { __device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = min(NumValues(), rhs.NumValues());
auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto tid = blockIdx.x * blockDim.x + threadIdx.x;
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
if (tid < min_size) { if (tid < min_size) {
Data()[tid] &= rhs.Data()[tid]; if (this->Check(tid) && rhs.Check(tid)) {
this->Set(tid);
} else {
this->Clear(tid);
}
} }
return *this; return *this;
} }
#else #else
BitFieldContainer& operator&=(BitFieldContainer const& rhs) { BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = std::min(NumValues(), rhs.NumValues()); std::size_t min_size = std::min(NumValues(), rhs.NumValues());
for (size_t i = 0; i < min_size; ++i) { for (size_t i = 0; i < min_size; ++i) {
Data()[i] &= rhs.Data()[i]; Data()[i] &= rhs.Data()[i];
} }

View File

@ -6,7 +6,6 @@
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <algorithm>
#include <string> #include <string>
#include <set> #include <set>
@ -279,10 +278,6 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
} }
// enable constraints from feature // enable constraints from feature
node |= feature; node |= feature;
// clear the buffer after use
if (tid < feature.Capacity()) {
feature.Clear(tid);
}
// enable constraints from parent // enable constraints from parent
left |= node; left |= node;
@ -304,7 +299,7 @@ void FeatureInteractionConstraintDevice::Split(
<< " Split node: " << node_id << " and its left child: " << " Split node: " << node_id << " and its left child: "
<< left_id << " cannot be the same."; << left_id << " cannot be the same.";
CHECK_NE(node_id, right_id) CHECK_NE(node_id, right_id)
<< " Split node: " << node_id << " and its left child: " << " Split node: " << node_id << " and its right child: "
<< right_id << " cannot be the same."; << right_id << " cannot be the same.";
CHECK_LT(right_id, s_node_constraints_.size()); CHECK_LT(right_id, s_node_constraints_.size());
CHECK_NE(s_node_constraints_.size(), 0); CHECK_NE(s_node_constraints_.size(), 0);
@ -330,6 +325,9 @@ void FeatureInteractionConstraintDevice::Split(
feature_buffer_, feature_buffer_,
feature_id, feature_id,
node, left, right); node, left, right);
// clear the buffer after use
thrust::fill_n(dh::CachingThrustPolicy(), feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
} }
} // namespace xgboost } // namespace xgboost

View File

@ -1,16 +1,17 @@
/** /**
* Copyright 2019-2023, XGBoost contributors * Copyright 2019-2024, XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thrust/copy.h> #include <thrust/copy.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <cinttypes>
#include <string> #include <cstdint>
#include <bitset>
#include <set> #include <set>
#include <string>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/tree/constraints.cuh" #include "../../../src/tree/constraints.cuh"
#include "../../../src/tree/param.h" #include "../../../src/tree/param.h"
#include "../../../src/common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
namespace { namespace {
@ -36,9 +37,7 @@ std::string GetConstraintsStr() {
} }
tree::TrainParam GetParameter() { tree::TrainParam GetParameter() {
std::vector<std::pair<std::string, std::string>> args{ Args args{{"interaction_constraints", GetConstraintsStr()}};
{"interaction_constraints", GetConstraintsStr()}
};
tree::TrainParam param; tree::TrainParam param;
param.Init(args); param.Init(args);
return param; return param;