add Shap Magic to check if use cat
This commit is contained in:
Submodule rocgputreeshap updated: 3704f61421...4ede6a0efe
@@ -428,6 +428,8 @@ class DeviceModel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#define ShapSplitMagic 99999
|
||||||
|
|
||||||
struct ShapSplitCondition {
|
struct ShapSplitCondition {
|
||||||
ShapSplitCondition() = default;
|
ShapSplitCondition() = default;
|
||||||
XGBOOST_DEVICE
|
XGBOOST_DEVICE
|
||||||
@@ -437,6 +439,7 @@ struct ShapSplitCondition {
|
|||||||
feature_upper_bound(feature_upper_bound),
|
feature_upper_bound(feature_upper_bound),
|
||||||
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
|
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
|
||||||
assert(feature_lower_bound <= feature_upper_bound);
|
assert(feature_lower_bound <= feature_upper_bound);
|
||||||
|
cat_flag = ShapSplitMagic;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! Feature values >= lower and < upper flow down this path. */
|
/*! Feature values >= lower and < upper flow down this path. */
|
||||||
@@ -444,6 +447,7 @@ struct ShapSplitCondition {
|
|||||||
float feature_upper_bound;
|
float feature_upper_bound;
|
||||||
/*! Feature value set to true flow down this path. */
|
/*! Feature value set to true flow down this path. */
|
||||||
common::CatBitField categories;
|
common::CatBitField categories;
|
||||||
|
int cat_flag;
|
||||||
/*! Do missing values flow down this path? */
|
/*! Do missing values flow down this path? */
|
||||||
bool is_missing_branch;
|
bool is_missing_branch;
|
||||||
|
|
||||||
@@ -453,7 +457,7 @@ struct ShapSplitCondition {
|
|||||||
if (isnan(x)) {
|
if (isnan(x)) {
|
||||||
return is_missing_branch;
|
return is_missing_branch;
|
||||||
}
|
}
|
||||||
if (categories.Size() != 0) {
|
if (cat_flag == ShapSplitMagic && categories.Size() != 0) {
|
||||||
auto cat = static_cast<uint32_t>(x);
|
auto cat = static_cast<uint32_t>(x);
|
||||||
return categories.Check(cat);
|
return categories.Check(cat);
|
||||||
} else {
|
} else {
|
||||||
@@ -480,7 +484,7 @@ struct ShapSplitCondition {
|
|||||||
// Combine two split conditions on the same feature
|
// Combine two split conditions on the same feature
|
||||||
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
|
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
|
||||||
// Combine duplicate features
|
// Combine duplicate features
|
||||||
if (categories.Size() != 0 || other.categories.Size() != 0) {
|
if (cat_flag == ShapSplitMagic && (categories.Size() != 0 || other.categories.Size() != 0)) {
|
||||||
categories = Intersect(categories, other.categories);
|
categories = Intersect(categories, other.categories);
|
||||||
} else {
|
} else {
|
||||||
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
|
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
|
||||||
|
|||||||
@@ -14,6 +14,11 @@
|
|||||||
#if defined(XGBOOST_USE_HIP)
|
#if defined(XGBOOST_USE_HIP)
|
||||||
#include <hip/hip_cooperative_groups.h>
|
#include <hip/hip_cooperative_groups.h>
|
||||||
|
|
||||||
|
#ifdef __AMDGCN_WAVEFRONT_SIZE
|
||||||
|
#undef WAVEFRONT_SIZE
|
||||||
|
#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE
|
||||||
|
#endif
|
||||||
|
|
||||||
#define WARP_SIZE WAVEFRONT_SIZE
|
#define WARP_SIZE WAVEFRONT_SIZE
|
||||||
#elif defined(XGBOOST_USE_CUDA)
|
#elif defined(XGBOOST_USE_CUDA)
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
|||||||
Submodule warp-primitives updated: af1eccf831...c55a03e81e
Reference in New Issue
Block a user