add Shap Magic to check if use cat
This commit is contained in:
parent
65d83e288f
commit
313a74b582
@ -1 +1 @@
|
|||||||
Subproject commit 3704f6142138766bb6e3585f496c8b7de61d2d32
|
Subproject commit 4ede6a0efef5c82776cfdc9e627dfab901898be4
|
||||||
@ -428,6 +428,8 @@ class DeviceModel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#define ShapSplitMagic 99999
|
||||||
|
|
||||||
struct ShapSplitCondition {
|
struct ShapSplitCondition {
|
||||||
ShapSplitCondition() = default;
|
ShapSplitCondition() = default;
|
||||||
XGBOOST_DEVICE
|
XGBOOST_DEVICE
|
||||||
@ -437,6 +439,7 @@ struct ShapSplitCondition {
|
|||||||
feature_upper_bound(feature_upper_bound),
|
feature_upper_bound(feature_upper_bound),
|
||||||
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
|
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
|
||||||
assert(feature_lower_bound <= feature_upper_bound);
|
assert(feature_lower_bound <= feature_upper_bound);
|
||||||
|
cat_flag = ShapSplitMagic;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! Feature values >= lower and < upper flow down this path. */
|
/*! Feature values >= lower and < upper flow down this path. */
|
||||||
@ -444,6 +447,7 @@ struct ShapSplitCondition {
|
|||||||
float feature_upper_bound;
|
float feature_upper_bound;
|
||||||
/*! Feature value set to true flow down this path. */
|
/*! Feature value set to true flow down this path. */
|
||||||
common::CatBitField categories;
|
common::CatBitField categories;
|
||||||
|
int cat_flag;
|
||||||
/*! Do missing values flow down this path? */
|
/*! Do missing values flow down this path? */
|
||||||
bool is_missing_branch;
|
bool is_missing_branch;
|
||||||
|
|
||||||
@ -453,7 +457,7 @@ struct ShapSplitCondition {
|
|||||||
if (isnan(x)) {
|
if (isnan(x)) {
|
||||||
return is_missing_branch;
|
return is_missing_branch;
|
||||||
}
|
}
|
||||||
if (categories.Size() != 0) {
|
if (cat_flag == ShapSplitMagic && categories.Size() != 0) {
|
||||||
auto cat = static_cast<uint32_t>(x);
|
auto cat = static_cast<uint32_t>(x);
|
||||||
return categories.Check(cat);
|
return categories.Check(cat);
|
||||||
} else {
|
} else {
|
||||||
@ -480,7 +484,7 @@ struct ShapSplitCondition {
|
|||||||
// Combine two split conditions on the same feature
|
// Combine two split conditions on the same feature
|
||||||
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
|
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
|
||||||
// Combine duplicate features
|
// Combine duplicate features
|
||||||
if (categories.Size() != 0 || other.categories.Size() != 0) {
|
if (cat_flag == ShapSplitMagic && (categories.Size() != 0 || other.categories.Size() != 0)) {
|
||||||
categories = Intersect(categories, other.categories);
|
categories = Intersect(categories, other.categories);
|
||||||
} else {
|
} else {
|
||||||
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
|
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
|
||||||
|
|||||||
@ -14,6 +14,11 @@
|
|||||||
#if defined(XGBOOST_USE_HIP)
|
#if defined(XGBOOST_USE_HIP)
|
||||||
#include <hip/hip_cooperative_groups.h>
|
#include <hip/hip_cooperative_groups.h>
|
||||||
|
|
||||||
|
#ifdef __AMDGCN_WAVEFRONT_SIZE
|
||||||
|
#undef WAVEFRONT_SIZE
|
||||||
|
#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE
|
||||||
|
#endif
|
||||||
|
|
||||||
#define WARP_SIZE WAVEFRONT_SIZE
|
#define WARP_SIZE WAVEFRONT_SIZE
|
||||||
#elif defined(XGBOOST_USE_CUDA)
|
#elif defined(XGBOOST_USE_CUDA)
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
Subproject commit af1eccf8313f0579ff190d4b76627b4559f19d1a
|
Subproject commit c55a03e81ef0049efbd5575ade1664b5f29232de
|
||||||
Loading…
x
Reference in New Issue
Block a user