add Shap Magic to check if use cat

This commit is contained in:
amdsc21 2023-05-01 21:55:14 +02:00
parent 65d83e288f
commit 313a74b582
4 changed files with 13 additions and 4 deletions

@ -1 +1 @@
Subproject commit 3704f6142138766bb6e3585f496c8b7de61d2d32
Subproject commit 4ede6a0efef5c82776cfdc9e627dfab901898be4

View File

@ -428,6 +428,8 @@ class DeviceModel {
}
};
#define ShapSplitMagic 99999
struct ShapSplitCondition {
ShapSplitCondition() = default;
XGBOOST_DEVICE
@ -437,6 +439,7 @@ struct ShapSplitCondition {
feature_upper_bound(feature_upper_bound),
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
assert(feature_lower_bound <= feature_upper_bound);
cat_flag = ShapSplitMagic;
}
/*! Feature values >= lower and < upper flow down this path. */
@ -444,6 +447,7 @@ struct ShapSplitCondition {
float feature_upper_bound;
/*! Feature value set to true flow down this path. */
common::CatBitField categories;
int cat_flag;
/*! Do missing values flow down this path? */
bool is_missing_branch;
@ -453,7 +457,7 @@ struct ShapSplitCondition {
if (isnan(x)) {
return is_missing_branch;
}
if (categories.Size() != 0) {
if (cat_flag == ShapSplitMagic && categories.Size() != 0) {
auto cat = static_cast<uint32_t>(x);
return categories.Check(cat);
} else {
@ -480,7 +484,7 @@ struct ShapSplitCondition {
// Combine two split conditions on the same feature
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
// Combine duplicate features
if (categories.Size() != 0 || other.categories.Size() != 0) {
if (cat_flag == ShapSplitMagic && (categories.Size() != 0 || other.categories.Size() != 0)) {
categories = Intersect(categories, other.categories);
} else {
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);

View File

@ -14,6 +14,11 @@
#if defined(XGBOOST_USE_HIP)
#include <hip/hip_cooperative_groups.h>
#ifdef __AMDGCN_WAVEFRONT_SIZE
#undef WAVEFRONT_SIZE
#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE
#endif
#define WARP_SIZE WAVEFRONT_SIZE
#elif defined(XGBOOST_USE_CUDA)
#define WARP_SIZE 32

@ -1 +1 @@
Subproject commit af1eccf8313f0579ff190d4b76627b4559f19d1a
Subproject commit c55a03e81ef0049efbd5575ade1664b5f29232de