Fix potential race in feature constraint. (#10719)
This commit is contained in:
parent
e9f1abc1f0
commit
402e7837fb
@ -108,9 +108,11 @@ struct BitFieldContainer {
|
|||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
|
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
size_t min_size = min(NumValues(), rhs.NumValues());
|
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
|
||||||
if (tid < min_size) {
|
if (tid < min_size) {
|
||||||
Data()[tid] |= rhs.Data()[tid];
|
if (this->Check(tid) || rhs.Check(tid)) {
|
||||||
|
this->Set(tid);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -126,16 +128,20 @@ struct BitFieldContainer {
|
|||||||
|
|
||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
|
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
|
||||||
size_t min_size = min(NumValues(), rhs.NumValues());
|
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
|
||||||
if (tid < min_size) {
|
if (tid < min_size) {
|
||||||
Data()[tid] &= rhs.Data()[tid];
|
if (this->Check(tid) && rhs.Check(tid)) {
|
||||||
|
this->Set(tid);
|
||||||
|
} else {
|
||||||
|
this->Clear(tid);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
|
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
|
||||||
size_t min_size = std::min(NumValues(), rhs.NumValues());
|
std::size_t min_size = std::min(NumValues(), rhs.NumValues());
|
||||||
for (size_t i = 0; i < min_size; ++i) {
|
for (size_t i = 0; i < min_size; ++i) {
|
||||||
Data()[i] &= rhs.Data()[i];
|
Data()[i] &= rhs.Data()[i];
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,7 +6,6 @@
|
|||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
#include <thrust/iterator/counting_iterator.h>
|
#include <thrust/iterator/counting_iterator.h>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
@ -279,10 +278,6 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
|
|||||||
}
|
}
|
||||||
// enable constraints from feature
|
// enable constraints from feature
|
||||||
node |= feature;
|
node |= feature;
|
||||||
// clear the buffer after use
|
|
||||||
if (tid < feature.Capacity()) {
|
|
||||||
feature.Clear(tid);
|
|
||||||
}
|
|
||||||
|
|
||||||
// enable constraints from parent
|
// enable constraints from parent
|
||||||
left |= node;
|
left |= node;
|
||||||
@ -304,7 +299,7 @@ void FeatureInteractionConstraintDevice::Split(
|
|||||||
<< " Split node: " << node_id << " and its left child: "
|
<< " Split node: " << node_id << " and its left child: "
|
||||||
<< left_id << " cannot be the same.";
|
<< left_id << " cannot be the same.";
|
||||||
CHECK_NE(node_id, right_id)
|
CHECK_NE(node_id, right_id)
|
||||||
<< " Split node: " << node_id << " and its left child: "
|
<< " Split node: " << node_id << " and its right child: "
|
||||||
<< right_id << " cannot be the same.";
|
<< right_id << " cannot be the same.";
|
||||||
CHECK_LT(right_id, s_node_constraints_.size());
|
CHECK_LT(right_id, s_node_constraints_.size());
|
||||||
CHECK_NE(s_node_constraints_.size(), 0);
|
CHECK_NE(s_node_constraints_.size(), 0);
|
||||||
@ -330,6 +325,9 @@ void FeatureInteractionConstraintDevice::Split(
|
|||||||
feature_buffer_,
|
feature_buffer_,
|
||||||
feature_id,
|
feature_id,
|
||||||
node, left, right);
|
node, left, right);
|
||||||
|
|
||||||
|
// clear the buffer after use
|
||||||
|
thrust::fill_n(dh::CachingThrustPolicy(), feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,16 +1,17 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023, XGBoost contributors
|
* Copyright 2019-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
#include <cinttypes>
|
|
||||||
#include <string>
|
#include <cstdint>
|
||||||
#include <bitset>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "../../../src/common/device_helpers.cuh"
|
||||||
#include "../../../src/tree/constraints.cuh"
|
#include "../../../src/tree/constraints.cuh"
|
||||||
#include "../../../src/tree/param.h"
|
#include "../../../src/tree/param.h"
|
||||||
#include "../../../src/common/device_helpers.cuh"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace {
|
namespace {
|
||||||
@ -36,9 +37,7 @@ std::string GetConstraintsStr() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tree::TrainParam GetParameter() {
|
tree::TrainParam GetParameter() {
|
||||||
std::vector<std::pair<std::string, std::string>> args{
|
Args args{{"interaction_constraints", GetConstraintsStr()}};
|
||||||
{"interaction_constraints", GetConstraintsStr()}
|
|
||||||
};
|
|
||||||
tree::TrainParam param;
|
tree::TrainParam param;
|
||||||
param.Init(args);
|
param.Init(args);
|
||||||
return param;
|
return param;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user