From 313a74b58237042bca07cb6a850174727a75b0e8 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Mon, 1 May 2023 21:55:14 +0200 Subject: [PATCH] add Shap Magic to check if use cat --- rocgputreeshap | 2 +- src/predictor/gpu_predictor.cu | 8 ++++++-- src/tree/gpu_hist/evaluate_splits.cu | 5 +++++ warp-primitives | 2 +- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/rocgputreeshap b/rocgputreeshap index 3704f6142..4ede6a0ef 160000 --- a/rocgputreeshap +++ b/rocgputreeshap @@ -1 +1 @@ -Subproject commit 3704f6142138766bb6e3585f496c8b7de61d2d32 +Subproject commit 4ede6a0efef5c82776cfdc9e627dfab901898be4 diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 6676022b5..b50bcf399 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -428,6 +428,8 @@ class DeviceModel { } }; +#define ShapSplitMagic 99999 + struct ShapSplitCondition { ShapSplitCondition() = default; XGBOOST_DEVICE @@ -437,6 +439,7 @@ struct ShapSplitCondition { feature_upper_bound(feature_upper_bound), is_missing_branch(is_missing_branch), categories{std::move(cats)} { assert(feature_lower_bound <= feature_upper_bound); + cat_flag = ShapSplitMagic; } /*! Feature values >= lower and < upper flow down this path. */ @@ -444,6 +447,7 @@ struct ShapSplitCondition { float feature_upper_bound; /*! Feature value set to true flow down this path. */ common::CatBitField categories; + int cat_flag; /*! Do missing values flow down this path? */ bool is_missing_branch; @@ -453,7 +457,7 @@ struct ShapSplitCondition { if (isnan(x)) { return is_missing_branch; } - if (categories.Size() != 0) { + if (cat_flag == ShapSplitMagic && categories.Size() != 0) { auto cat = static_cast(x); return categories.Check(cat); } else { @@ -480,7 +484,7 @@ struct ShapSplitCondition { // Combine two split conditions on the same feature XGBOOST_DEVICE void Merge(ShapSplitCondition other) { // Combine duplicate features - if (categories.Size() != 0 || other.categories.Size() != 0) { + if (cat_flag == ShapSplitMagic && (categories.Size() != 0 || other.categories.Size() != 0)) { categories = Intersect(categories, other.categories); } else { feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index c6baa97b6..f3970c9ec 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -14,6 +14,11 @@ #if defined(XGBOOST_USE_HIP) #include +#ifdef __AMDGCN_WAVEFRONT_SIZE +#undef WAVEFRONT_SIZE +#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE +#endif + #define WARP_SIZE WAVEFRONT_SIZE #elif defined(XGBOOST_USE_CUDA) #define WARP_SIZE 32 diff --git a/warp-primitives b/warp-primitives index af1eccf83..c55a03e81 160000 --- a/warp-primitives +++ b/warp-primitives @@ -1 +1 @@ -Subproject commit af1eccf8313f0579ff190d4b76627b4559f19d1a +Subproject commit c55a03e81ef0049efbd5575ade1664b5f29232de