diff --git a/src/common/bitfield.h b/src/common/bitfield.h index 621078764..6ecd7fcdf 100644 --- a/src/common/bitfield.h +++ b/src/common/bitfield.h @@ -108,9 +108,11 @@ struct BitFieldContainer { #if defined(__CUDA_ARCH__) __device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) { auto tid = blockIdx.x * blockDim.x + threadIdx.x; - size_t min_size = min(NumValues(), rhs.NumValues()); + std::size_t min_size = std::min(this->Capacity(), rhs.Capacity()); if (tid < min_size) { - Data()[tid] |= rhs.Data()[tid]; + if (this->Check(tid) || rhs.Check(tid)) { + this->Set(tid); + } } return *this; } @@ -126,16 +128,20 @@ struct BitFieldContainer { #if defined(__CUDA_ARCH__) __device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) { - size_t min_size = min(NumValues(), rhs.NumValues()); auto tid = blockIdx.x * blockDim.x + threadIdx.x; + std::size_t min_size = std::min(this->Capacity(), rhs.Capacity()); if (tid < min_size) { - Data()[tid] &= rhs.Data()[tid]; + if (this->Check(tid) && rhs.Check(tid)) { + this->Set(tid); + } else { + this->Clear(tid); + } } return *this; } #else BitFieldContainer& operator&=(BitFieldContainer const& rhs) { - size_t min_size = std::min(NumValues(), rhs.NumValues()); + std::size_t min_size = std::min(NumValues(), rhs.NumValues()); for (size_t i = 0; i < min_size; ++i) { Data()[i] &= rhs.Data()[i]; } diff --git a/src/tree/constraints.cu b/src/tree/constraints.cu index ae1d3073c..121d80094 100644 --- a/src/tree/constraints.cu +++ b/src/tree/constraints.cu @@ -6,7 +6,6 @@ #include #include -#include #include #include @@ -279,10 +278,6 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature, } // enable constraints from feature node |= feature; - // clear the buffer after use - if (tid < feature.Capacity()) { - feature.Clear(tid); - } // enable constraints from parent left |= node; @@ -304,7 +299,7 @@ void FeatureInteractionConstraintDevice::Split( << " Split node: " << node_id << " and its left child: " << left_id << " cannot be the same."; CHECK_NE(node_id, right_id) - << " Split node: " << node_id << " and its left child: " + << " Split node: " << node_id << " and its right child: " << right_id << " cannot be the same."; CHECK_LT(right_id, s_node_constraints_.size()); CHECK_NE(s_node_constraints_.size(), 0); @@ -330,6 +325,9 @@ void FeatureInteractionConstraintDevice::Split( feature_buffer_, feature_id, node, left, right); + + // clear the buffer after use + thrust::fill_n(dh::CachingThrustPolicy(), feature_buffer_.Data(), feature_buffer_.NumValues(), 0); } } // namespace xgboost diff --git a/tests/cpp/tree/test_constraints.cu b/tests/cpp/tree/test_constraints.cu index 09e72a1d2..2af54d892 100644 --- a/tests/cpp/tree/test_constraints.cu +++ b/tests/cpp/tree/test_constraints.cu @@ -1,16 +1,17 @@ /** - * Copyright 2019-2023, XGBoost contributors + * Copyright 2019-2024, XGBoost contributors */ #include #include #include -#include -#include -#include + +#include #include +#include + +#include "../../../src/common/device_helpers.cuh" #include "../../../src/tree/constraints.cuh" #include "../../../src/tree/param.h" -#include "../../../src/common/device_helpers.cuh" namespace xgboost { namespace { @@ -36,9 +37,7 @@ std::string GetConstraintsStr() { } tree::TrainParam GetParameter() { - std::vector> args{ - {"interaction_constraints", GetConstraintsStr()} - }; + Args args{{"interaction_constraints", GetConstraintsStr()}}; tree::TrainParam param; param.Init(args); return param;