add Shap Magic to check if use cat

This commit is contained in:
amdsc21
2023-05-01 21:55:14 +02:00
parent 65d83e288f
commit 313a74b582
4 changed files with 13 additions and 4 deletions

View File

@@ -428,6 +428,8 @@ class DeviceModel {
}
};
#define ShapSplitMagic 99999
struct ShapSplitCondition {
ShapSplitCondition() = default;
XGBOOST_DEVICE
@@ -437,6 +439,7 @@ struct ShapSplitCondition {
feature_upper_bound(feature_upper_bound),
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
assert(feature_lower_bound <= feature_upper_bound);
cat_flag = ShapSplitMagic;
}
/*! Feature values >= lower and < upper flow down this path. */
@@ -444,6 +447,7 @@ struct ShapSplitCondition {
float feature_upper_bound;
/*! Feature value set to true flow down this path. */
common::CatBitField categories;
int cat_flag;
/*! Do missing values flow down this path? */
bool is_missing_branch;
@@ -453,7 +457,7 @@ struct ShapSplitCondition {
if (isnan(x)) {
return is_missing_branch;
}
if (categories.Size() != 0) {
if (cat_flag == ShapSplitMagic && categories.Size() != 0) {
auto cat = static_cast<uint32_t>(x);
return categories.Check(cat);
} else {
@@ -480,7 +484,7 @@ struct ShapSplitCondition {
// Combine two split conditions on the same feature
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
// Combine duplicate features
if (categories.Size() != 0 || other.categories.Size() != 0) {
if (cat_flag == ShapSplitMagic && (categories.Size() != 0 || other.categories.Size() != 0)) {
categories = Intersect(categories, other.categories);
} else {
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);

View File

@@ -14,6 +14,11 @@
#if defined(XGBOOST_USE_HIP)
#include <hip/hip_cooperative_groups.h>
#ifdef __AMDGCN_WAVEFRONT_SIZE
#undef WAVEFRONT_SIZE
#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE
#endif
#define WARP_SIZE WAVEFRONT_SIZE
#elif defined(XGBOOST_USE_CUDA)
#define WARP_SIZE 32