Fix potential race in feature constraint. (#10719)
This commit is contained in:
parent
e9f1abc1f0
commit
402e7837fb
@ -108,9 +108,11 @@ struct BitFieldContainer {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
size_t min_size = min(NumValues(), rhs.NumValues());
|
||||
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
|
||||
if (tid < min_size) {
|
||||
Data()[tid] |= rhs.Data()[tid];
|
||||
if (this->Check(tid) || rhs.Check(tid)) {
|
||||
this->Set(tid);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -126,16 +128,20 @@ struct BitFieldContainer {
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
|
||||
size_t min_size = min(NumValues(), rhs.NumValues());
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
|
||||
if (tid < min_size) {
|
||||
Data()[tid] &= rhs.Data()[tid];
|
||||
if (this->Check(tid) && rhs.Check(tid)) {
|
||||
this->Set(tid);
|
||||
} else {
|
||||
this->Clear(tid);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
#else
|
||||
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
|
||||
size_t min_size = std::min(NumValues(), rhs.NumValues());
|
||||
std::size_t min_size = std::min(NumValues(), rhs.NumValues());
|
||||
for (size_t i = 0; i < min_size; ++i) {
|
||||
Data()[i] &= rhs.Data()[i];
|
||||
}
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
@ -279,10 +278,6 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
|
||||
}
|
||||
// enable constraints from feature
|
||||
node |= feature;
|
||||
// clear the buffer after use
|
||||
if (tid < feature.Capacity()) {
|
||||
feature.Clear(tid);
|
||||
}
|
||||
|
||||
// enable constraints from parent
|
||||
left |= node;
|
||||
@ -304,7 +299,7 @@ void FeatureInteractionConstraintDevice::Split(
|
||||
<< " Split node: " << node_id << " and its left child: "
|
||||
<< left_id << " cannot be the same.";
|
||||
CHECK_NE(node_id, right_id)
|
||||
<< " Split node: " << node_id << " and its left child: "
|
||||
<< " Split node: " << node_id << " and its right child: "
|
||||
<< right_id << " cannot be the same.";
|
||||
CHECK_LT(right_id, s_node_constraints_.size());
|
||||
CHECK_NE(s_node_constraints_.size(), 0);
|
||||
@ -330,6 +325,9 @@ void FeatureInteractionConstraintDevice::Split(
|
||||
feature_buffer_,
|
||||
feature_id,
|
||||
node, left, right);
|
||||
|
||||
// clear the buffer after use
|
||||
thrust::fill_n(dh::CachingThrustPolicy(), feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
/**
|
||||
* Copyright 2019-2023, XGBoost contributors
|
||||
* Copyright 2019-2024, XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <cinttypes>
|
||||
#include <string>
|
||||
#include <bitset>
|
||||
|
||||
#include <cstdint>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "../../../src/tree/constraints.cuh"
|
||||
#include "../../../src/tree/param.h"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
@ -36,9 +37,7 @@ std::string GetConstraintsStr() {
|
||||
}
|
||||
|
||||
tree::TrainParam GetParameter() {
|
||||
std::vector<std::pair<std::string, std::string>> args{
|
||||
{"interaction_constraints", GetConstraintsStr()}
|
||||
};
|
||||
Args args{{"interaction_constraints", GetConstraintsStr()}};
|
||||
tree::TrainParam param;
|
||||
param.Init(args);
|
||||
return param;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user