add Shap Magic to check if use cat
This commit is contained in:
parent
65d83e288f
commit
313a74b582
@ -1 +1 @@
|
||||
Subproject commit 3704f6142138766bb6e3585f496c8b7de61d2d32
|
||||
Subproject commit 4ede6a0efef5c82776cfdc9e627dfab901898be4
|
||||
@ -428,6 +428,8 @@ class DeviceModel {
|
||||
}
|
||||
};
|
||||
|
||||
#define ShapSplitMagic 99999
|
||||
|
||||
struct ShapSplitCondition {
|
||||
ShapSplitCondition() = default;
|
||||
XGBOOST_DEVICE
|
||||
@ -437,6 +439,7 @@ struct ShapSplitCondition {
|
||||
feature_upper_bound(feature_upper_bound),
|
||||
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
|
||||
assert(feature_lower_bound <= feature_upper_bound);
|
||||
cat_flag = ShapSplitMagic;
|
||||
}
|
||||
|
||||
/*! Feature values >= lower and < upper flow down this path. */
|
||||
@ -444,6 +447,7 @@ struct ShapSplitCondition {
|
||||
float feature_upper_bound;
|
||||
/*! Feature value set to true flow down this path. */
|
||||
common::CatBitField categories;
|
||||
int cat_flag;
|
||||
/*! Do missing values flow down this path? */
|
||||
bool is_missing_branch;
|
||||
|
||||
@ -453,7 +457,7 @@ struct ShapSplitCondition {
|
||||
if (isnan(x)) {
|
||||
return is_missing_branch;
|
||||
}
|
||||
if (categories.Size() != 0) {
|
||||
if (cat_flag == ShapSplitMagic && categories.Size() != 0) {
|
||||
auto cat = static_cast<uint32_t>(x);
|
||||
return categories.Check(cat);
|
||||
} else {
|
||||
@ -480,7 +484,7 @@ struct ShapSplitCondition {
|
||||
// Combine two split conditions on the same feature
|
||||
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
|
||||
// Combine duplicate features
|
||||
if (categories.Size() != 0 || other.categories.Size() != 0) {
|
||||
if (cat_flag == ShapSplitMagic && (categories.Size() != 0 || other.categories.Size() != 0)) {
|
||||
categories = Intersect(categories, other.categories);
|
||||
} else {
|
||||
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
|
||||
|
||||
@ -14,6 +14,11 @@
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
|
||||
#ifdef __AMDGCN_WAVEFRONT_SIZE
|
||||
#undef WAVEFRONT_SIZE
|
||||
#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE
|
||||
#endif
|
||||
|
||||
#define WARP_SIZE WAVEFRONT_SIZE
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
#define WARP_SIZE 32
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit af1eccf8313f0579ff190d4b76627b4559f19d1a
|
||||
Subproject commit c55a03e81ef0049efbd5575ade1664b5f29232de
|
||||
Loading…
x
Reference in New Issue
Block a user