GPUTreeShap (#6038)
parent b3193052b3
commit 9a4e8b1d81
.gitmodules (vendored): 3 changes
@@ -4,3 +4,6 @@
 [submodule "cub"]
   path = cub
   url = https://github.com/NVlabs/cub
+[submodule "gputreeshap"]
+  path = gputreeshap
+  url = https://github.com/rapidsai/gputreeshap.git
gputreeshap (submodule): 1 change
@@ -0,0 +1 @@
+Subproject commit a3d4c44cc6a0a6c3870e7cebcd1ef1d09d7bc0cb
@@ -9,6 +9,7 @@ if (USE_CUDA)
   file(GLOB_RECURSE CUDA_SOURCES *.cu *.cuh)
   target_sources(objxgboost PRIVATE ${CUDA_SOURCES})
   target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_CUDA=1)
+  target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap)
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
     target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/)
   endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
@@ -474,8 +474,18 @@ class TemporaryArray {
   using AllocT = XGBCachingDeviceAllocator<T>;
   using value_type = T;  // NOLINT
   explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); }
+  TemporaryArray(size_t n, T val) : size_(n) {
+    ptr_ = AllocT().allocate(n);
+    this->fill(val);
+  }
   ~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); }
+
+  void fill(T val)  // NOLINT
+  {
+    int device = 0;
+    dh::safe_cuda(cudaGetDevice(&device));
+    auto d_data = ptr_.get();
+    LaunchN(device, this->size(), [=] __device__(size_t idx) { d_data[idx] = val; });
+  }
   thrust::device_ptr<T> data() { return ptr_; }  // NOLINT
   size_t size() { return size_; }  // NOLINT

@@ -238,11 +238,11 @@ class GBTree : public GradientBooster {

   void PredictContribution(DMatrix* p_fmat,
                            std::vector<bst_float>* out_contribs,
-                           unsigned ntree_limit, bool approximate, int condition,
-                           unsigned condition_feature) override {
+                           unsigned ntree_limit, bool approximate,
+                           int condition, unsigned condition_feature) override {
     CHECK(configured_);
-    cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
-                                        ntree_limit, nullptr, approximate);
+    this->GetPredictor()->PredictContribution(
+        p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate);
   }

   void PredictInteractionContributions(DMatrix* p_fmat,
@@ -5,6 +5,7 @@
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/fill.h>
+#include <GPUTreeShap/gpu_treeshap.h>
 #include <memory>

 #include "xgboost/data.h"
@@ -27,53 +28,20 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor);
 struct SparsePageView {
   common::Span<const Entry> d_data;
   common::Span<const bst_row_t> d_row_ptr;
+  bst_feature_t num_features;

   XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
-                                common::Span<const bst_row_t> row_ptr) :
-      d_data{data}, d_row_ptr{row_ptr} {}
-};
-
-struct SparsePageLoader {
-  bool use_shared;
-  common::Span<const bst_row_t> d_row_ptr;
-  common::Span<const Entry> d_data;
-  bst_feature_t num_features;
-  float* smem;
-  size_t entry_start;
-
-  __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
-                              bst_row_t num_rows, size_t entry_start)
-      : use_shared(use_shared),
-        d_row_ptr(data.d_row_ptr),
-        d_data(data.d_data),
-        num_features(num_features),
-        entry_start(entry_start) {
-    extern __shared__ float _smem[];
-    smem = _smem;
-    // Copy instances
-    if (use_shared) {
-      bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-      int shared_elements = blockDim.x * num_features;
-      dh::BlockFill(smem, shared_elements, nanf(""));
-      __syncthreads();
-      if (global_idx < num_rows) {
-        bst_uint elem_begin = d_row_ptr[global_idx];
-        bst_uint elem_end = d_row_ptr[global_idx + 1];
-        for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
-          Entry elem = d_data[elem_idx - entry_start];
-          smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
-        }
-      }
-      __syncthreads();
-    }
-  }
-  __device__ float GetFvalue(int ridx, int fidx) const {
-    if (use_shared) {
-      return smem[threadIdx.x * num_features + fidx];
-    } else {
+                                common::Span<const bst_row_t> row_ptr,
+                                bst_feature_t num_features)
+      : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}
+  __device__ float GetElement(size_t ridx, size_t fidx) const {
     // Binary search
-      auto begin_ptr = d_data.begin() + (d_row_ptr[ridx] - entry_start);
-      auto end_ptr = d_data.begin() + (d_row_ptr[ridx + 1] - entry_start);
+    auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
+    auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
+    if (end_ptr - begin_ptr == this->NumCols()) {
+      // Bypass span check for dense data
+      return d_data.data()[d_row_ptr[ridx] + fidx].fvalue;
+    }
     common::Span<const Entry>::iterator previous_middle;
     while (end_ptr != begin_ptr) {
       auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
@@ -94,6 +62,46 @@ struct SparsePageLoader {
     // Value is missing
     return nanf("");
   }
+  XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
+  XGBOOST_DEVICE size_t NumCols() const { return num_features; }
 };

+struct SparsePageLoader {
+  bool use_shared;
+  SparsePageView data;
+  float* smem;
+  size_t entry_start;
+
+  __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
+                              bst_row_t num_rows, size_t entry_start)
+      : use_shared(use_shared),
+        data(data),
+        entry_start(entry_start) {
+    extern __shared__ float _smem[];
+    smem = _smem;
+    // Copy instances
+    if (use_shared) {
+      bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
+      int shared_elements = blockDim.x * data.num_features;
+      dh::BlockFill(smem, shared_elements, nanf(""));
+      __syncthreads();
+      if (global_idx < num_rows) {
+        bst_uint elem_begin = data.d_row_ptr[global_idx];
+        bst_uint elem_end = data.d_row_ptr[global_idx + 1];
+        for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
+          Entry elem = data.d_data[elem_idx - entry_start];
+          smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
+        }
+      }
+      __syncthreads();
+    }
+  }
+  __device__ float GetElement(size_t ridx, size_t fidx) const {
+    if (use_shared) {
+      return smem[threadIdx.x * data.num_features + fidx];
+    } else {
+      return data.GetElement(ridx, fidx);
+    }
+  }
+};
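Illustrative note: the refactor above renames GetFvalue to GetElement and moves the element lookup into SparsePageView, which also gains NumRows/NumCols, so the same view can be handed to GPUTreeShap as its dataset while SparsePageLoader keeps only the shared-memory fast path. A rough Python sketch of the lookup strategy (direct indexing for dense rows, otherwise binary search over feature indices, NaN for a missing value); the helper name and row layout here are hypothetical and for illustration only, not part of the commit:

import math
from bisect import bisect_left

def get_element(row_entries, fidx, num_cols):
    # row_entries: (feature_index, value) pairs of one CSR row, sorted by feature index.
    if len(row_entries) == num_cols:      # dense row: position equals feature index
        return row_entries[fidx][1]
    keys = [e[0] for e in row_entries]
    i = bisect_left(keys, fidx)           # binary search on the feature index
    if i < len(keys) and keys[i] == fidx:
        return row_entries[i][1]
    return math.nan                       # value is missing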
@@ -103,7 +111,7 @@ struct EllpackLoader {
                              bst_feature_t num_features, bst_row_t num_rows,
                              size_t entry_start)
       : matrix{m} {}
-  __device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
+  __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
     auto gidx = matrix.GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return nan("");
@@ -150,7 +158,7 @@ struct DeviceAdapterLoader {
       __syncthreads();
     }
   }
-  DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
+  DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
     if (use_shared) {
       return smem[threadIdx.x * columns + fidx];
     }
@@ -163,7 +171,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
                                Loader* loader) {
   RegTree::Node n = tree[0];
   while (!n.IsLeaf()) {
-    float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
+    float fvalue = loader->GetElement(ridx, n.SplitIndex());
     // Missing value
     if (isnan(fvalue)) {
       n = tree[n.DefaultChild()];
@@ -273,7 +281,8 @@ class GPUPredictor : public xgboost::Predictor {
       use_shared = false;
     }
     size_t entry_start = 0;
-    SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
+    SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                        num_features);
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<SparsePageLoader, SparsePageView>,
         data,
@@ -447,6 +456,60 @@ class GPUPredictor : public xgboost::Predictor {
     }
   }

+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           const gbm::GBTreeModel& model, unsigned ntree_limit,
+                           std::vector<bst_float>* tree_weights,
+                           bool approximate, int condition,
+                           unsigned condition_feature) override {
+    if (approximate) {
+      LOG(FATAL) << "[Internal error]: " << __func__
+                 << " approximate is not implemented in GPU Predictor.";
+    }
+
+    uint32_t real_ntree_limit =
+        ntree_limit * model.learner_model_param->num_output_group;
+    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
+      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
+    }
+
+    const int ngroup = model.learner_model_param->num_output_group;
+    CHECK_NE(ngroup, 0);
+    // allocate space for (number of features + bias) times the number of rows
+    std::vector<bst_float>& contribs = *out_contribs;
+    size_t contributions_columns =
+        model.learner_model_param->num_feature + 1;  // +1 for bias
+    contribs.resize(p_fmat->Info().num_row_ * contributions_columns *
+                    model.learner_model_param->num_output_group);
+    dh::TemporaryArray<float> phis(contribs.size(), 0.0);
+    p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
+    const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
+    float base_score = model.learner_model_param->base_score;
+    auto d_phis = phis.data().get();
+    // Add the base margin term to last column
+    dh::LaunchN(
+        generic_param_->gpu_id,
+        p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
+        [=] __device__(size_t idx) {
+          d_phis[(idx + 1) * contributions_columns - 1] =
+              margin.empty() ? base_score : margin[idx];
+        });
+
+    const auto& paths = this->ExtractPaths(model, real_ntree_limit);
+    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
+      batch.data.SetDevice(generic_param_->gpu_id);
+      batch.offset.SetDevice(generic_param_->gpu_id);
+      SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                       model.learner_model_param->num_feature);
+      gpu_treeshap::GPUTreeShap(
+          X, paths, ngroup,
+          phis.data().get() + batch.base_rowid * contributions_columns);
+    }
+    dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
+                                  sizeof(float) * phis.size(),
+                                  cudaMemcpyDefault));
+  }
+
  protected:
   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
@@ -478,16 +541,6 @@ class GPUPredictor : public xgboost::Predictor {
                << " is not implemented in GPU Predictor.";
   }

-  void PredictContribution(DMatrix* p_fmat,
-                           std::vector<bst_float>* out_contribs,
-                           const gbm::GBTreeModel& model, unsigned ntree_limit,
-                           std::vector<bst_float>* tree_weights,
-                           bool approximate, int condition,
-                           unsigned condition_feature) override {
-    LOG(FATAL) << "[Internal error]: " << __func__
-               << " is not implemented in GPU Predictor.";
-  }
-
   void PredictInteractionContributions(DMatrix* p_fmat,
                                        std::vector<bst_float>* out_contribs,
                                        const gbm::GBTreeModel& model,
@@ -510,6 +563,49 @@ class GPUPredictor : public xgboost::Predictor {
     }
   }

+  std::vector<gpu_treeshap::PathElement> ExtractPaths(
+      const gbm::GBTreeModel& model, size_t tree_limit) {
+    std::vector<gpu_treeshap::PathElement> paths;
+    size_t path_idx = 0;
+    CHECK_LE(tree_limit, model.trees.size());
+    for (auto i = 0ull; i < tree_limit; i++) {
+      const auto& tree = *model.trees.at(i);
+      size_t group = model.tree_info[i];
+      const auto& nodes = tree.GetNodes();
+      for (auto j = 0ull; j < nodes.size(); j++) {
+        if (nodes[j].IsLeaf() && !nodes[j].IsDeleted()) {
+          auto child = nodes[j];
+          float v = child.LeafValue();
+          size_t child_idx = j;
+          const float inf = std::numeric_limits<float>::infinity();
+          while (!child.IsRoot()) {
+            float child_cover = tree.Stat(child_idx).sum_hess;
+            float parent_cover = tree.Stat(child.Parent()).sum_hess;
+            float zero_fraction = child_cover / parent_cover;
+            CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
+            auto parent = nodes[child.Parent()];
+            CHECK(parent.LeftChild() == child_idx ||
+                  parent.RightChild() == child_idx);
+            bool is_left_path = parent.LeftChild() == child_idx;
+            bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
+                                   (parent.DefaultLeft() && is_left_path);
+            float lower_bound = is_left_path ? -inf : parent.SplitCond();
+            float upper_bound = is_left_path ? parent.SplitCond() : inf;
+            paths.emplace_back(path_idx, parent.SplitIndex(), group,
+                               lower_bound, upper_bound, is_missing_path,
+                               zero_fraction, v);
+            child_idx = child.Parent();
+            child = parent;
+          }
+          // Root node has feature -1
+          paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
+          path_idx++;
+        }
+      }
+    }
+    return paths;
+  }
+
   std::mutex lock_;
   DeviceModel model_;
   size_t max_shared_memory_bytes_;
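Illustrative note: ExtractPaths turns each leaf into one leaf-to-root sequence of GPUTreeShap path elements: the parent's split feature, the (lower, upper) interval of feature values that follow the edge, whether missing values are routed down it, the fraction of training cover (sum_hess) flowing through it, and the leaf value, closed by a root element with feature -1. A toy re-enactment in plain Python, using a hypothetical dict-based stump rather than RegTree, purely for illustration:

import math

# Hypothetical stump: node 0 splits feature 0 at 0.5 (default left);
# node 1 is its left leaf (value -1.0), node 2 its right leaf (value 1.0).
nodes = {
    0: dict(parent=None, split_index=0, split_cond=0.5, default_left=True, left=1, cover=5.0),
    1: dict(parent=0, leaf=-1.0, cover=2.0),
    2: dict(parent=0, leaf=1.0, cover=3.0),
}

def extract_paths(nodes, group=0):
    paths, path_idx, inf = [], 0, math.inf
    for nid, node in nodes.items():
        if "leaf" not in node:
            continue                                   # only leaves start a path
        v, child = node["leaf"], nid
        while nodes[child]["parent"] is not None:
            pid = nodes[child]["parent"]
            parent = nodes[pid]
            zero_fraction = nodes[child]["cover"] / parent["cover"]  # cover fraction on this edge
            is_left = parent["left"] == child
            is_missing = parent["default_left"] == is_left           # default direction follows the edge?
            lower = -inf if is_left else parent["split_cond"]
            upper = parent["split_cond"] if is_left else inf
            paths.append((path_idx, parent["split_index"], group,
                          lower, upper, is_missing, zero_fraction, v))
            child = pid
        paths.append((path_idx, -1, group, -inf, inf, False, 1.0, v))  # root element, feature -1
        path_idx += 1
    return paths

print(extract_paths(nodes))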
@@ -163,5 +163,61 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
 TEST(GpuPredictor, LesserFeatures) {
   TestPredictionWithLesserFeatures("gpu_predictor");
 }
+// Very basic test of empty model
+TEST(GPUPredictor, ShapStump) {
+  cudaSetDevice(0);
+  LearnerModelParam param;
+  param.num_feature = 1;
+  param.num_output_group = 1;
+  param.base_score = 0.5;
+  gbm::GBTreeModel model(&param);
+  std::vector<std::unique_ptr<RegTree>> trees;
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
+  model.CommitModel(std::move(trees), 0);
+
+  auto gpu_lparam = CreateEmptyGenericParam(0);
+  std::unique_ptr<Predictor> gpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
+  gpu_predictor->Configure({});
+  std::vector<float > phis;
+  auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
+  gpu_predictor->PredictContribution(dmat.get(), &phis, model);
+  EXPECT_EQ(phis[0], 0.0);
+  EXPECT_EQ(phis[1], param.base_score);
+  EXPECT_EQ(phis[2], 0.0);
+  EXPECT_EQ(phis[3], param.base_score);
+  EXPECT_EQ(phis[4], 0.0);
+  EXPECT_EQ(phis[5], param.base_score);
+}
+TEST(GPUPredictor, Shap) {
+  LearnerModelParam param;
+  param.num_feature = 1;
+  param.num_output_group = 1;
+  param.base_score = 0.5;
+  gbm::GBTreeModel model(&param);
+  std::vector<std::unique_ptr<RegTree>> trees;
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
+  trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
+  model.CommitModel(std::move(trees), 0);
+
+  auto gpu_lparam = CreateEmptyGenericParam(0);
+  auto cpu_lparam = CreateEmptyGenericParam(-1);
+  std::unique_ptr<Predictor> gpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
+  std::unique_ptr<Predictor> cpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &cpu_lparam));
+  gpu_predictor->Configure({});
+  cpu_predictor->Configure({});
+  std::vector<float > phis;
+  std::vector<float > cpu_phis;
+  auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
+  gpu_predictor->PredictContribution(dmat.get(), &phis, model);
+  cpu_predictor->PredictContribution(dmat.get(), &cpu_phis, model);
+  for(auto i = 0ull; i < phis.size(); i++)
+  {
+    EXPECT_NEAR(cpu_phis[i], phis[i], 1e-3);
+  }
+}
+
 }  // namespace predictor
 }  // namespace xgboost
@@ -4,6 +4,7 @@ import pytest

 import numpy as np
 import xgboost as xgb
+from hypothesis import given, strategies, assume, settings, note

 sys.path.append("tests/python")
 import testing as tm
@@ -11,6 +12,12 @@ from test_predict import run_threaded_predict  # noqa

 rng = np.random.RandomState(1994)

+shap_parameter_strategy = strategies.fixed_dictionaries({
+    'max_depth': strategies.integers(0, 11),
+    'max_leaves': strategies.integers(0, 256),
+    'num_parallel_tree': strategies.sampled_from([1, 10]),
+})
+

 class TestGPUPredict(unittest.TestCase):
     def test_predict(self):
@@ -149,7 +156,8 @@ class TestGPUPredict(unittest.TestCase):

         # Don't do this on Windows, see issue #5793
         if sys.platform.startswith("win"):
-            pytest.skip('Multi-threaded in-place prediction with cuPy is not working on Windows')
+            pytest.skip(
+                'Multi-threaded in-place prediction with cuPy is not working on Windows')
         for i in range(10):
             run_threaded_predict(X, rows, predict_dense)

@@ -185,3 +193,24 @@ class TestGPUPredict(unittest.TestCase):

         for i in range(10):
             run_threaded_predict(X, rows, predict_df)
+
+    @given(strategies.integers(1, 200),
+           tm.dataset_strategy, shap_parameter_strategy, strategies.booleans())
+    @settings(deadline=None)
+    def test_shap(self, num_rounds, dataset, param, all_rows):
+        param.update({"predictor": "gpu_predictor", "gpu_id": 0})
+        param = dataset.set_params(param)
+        dmat = dataset.get_dmat()
+        bst = xgb.train(param, dmat, num_rounds)
+        if all_rows:
+            test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
+        else:
+            test_dmat = xgb.DMatrix(dataset.X[0:1, :])
+        shap = bst.predict(test_dmat, pred_contribs=True)
+        bst.set_param({"predictor": "cpu_predictor"})
+        cpu_shap = bst.predict(test_dmat, pred_contribs=True)
+        margin = bst.predict(test_dmat, output_margin=True)
+        assert np.allclose(shap, cpu_shap, 1e-3, 1e-3)
+        # feature contributions should add up to predictions
+        assume(len(dataset.y) > 0)
+        assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)
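Illustrative note: the property this test leans on, GPU and CPU SHAP values agreeing and each row's contributions (bias column included) summing to the untransformed margin, can be reproduced from the public API. A minimal sketch, assuming a CUDA-capable build of this branch; the data and parameters below are arbitrary illustrations:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(200, 10), np.random.rand(200)
dtrain = xgb.DMatrix(X, y)
bst = xgb.train({"tree_method": "gpu_hist", "predictor": "gpu_predictor"},
                dtrain, num_boost_round=20)

shap = bst.predict(dtrain, pred_contribs=True)   # shape (200, 11): 10 features + bias column
margin = bst.predict(dtrain, output_margin=True)
assert np.allclose(shap.sum(axis=1), margin, 1e-3, 1e-3)  # contributions add up to the margin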
@@ -131,6 +131,7 @@ class TestDataset:
         self.metric = metric
         self.X, self.y = get_dataset()
         self.w = None
+        self.margin = None

     def set_params(self, params_in):
         params_in['objective'] = self.objective
@@ -140,13 +141,13 @@ class TestDataset:
         return params_in

     def get_dmat(self):
-        return xgb.DMatrix(self.X, self.y, self.w)
+        return xgb.DMatrix(self.X, self.y, self.w, base_margin=self.margin)

     def get_device_dmat(self):
         w = None if self.w is None else cp.array(self.w)
         X = cp.array(self.X, dtype=np.float32)
         y = cp.array(self.y, dtype=np.float32)
-        return xgb.DeviceQuantileDMatrix(X, y, w)
+        return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)

     def get_external_dmat(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -157,7 +158,7 @@ class TestDataset:
             uri = path + '?format=csv&label_column=0#tmptmp_'
             # The uri looks like:
             # 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_'
-            return xgb.DMatrix(uri, weight=self.w)
+            return xgb.DMatrix(uri, weight=self.w, base_margin=self.margin)

     def __repr__(self):
         return self.name
@@ -206,16 +207,23 @@ _unweighted_datasets_strategy = strategies.sampled_from(


 @strategies.composite
-def _dataset_and_weight(draw):
+def _dataset_weight_margin(draw):
     data = draw(_unweighted_datasets_strategy)
     if draw(strategies.booleans()):
         data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
+    if draw(strategies.booleans()):
+        num_class = 1
+        if data.objective == "multi:softmax":
+            num_class = int(np.max(data.y) + 1)
+        data.margin = draw(
+            arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0)))
+
     return data


 # A strategy for drawing from a set of example datasets
 # May add random weights to the dataset
-dataset_strategy = _dataset_and_weight()
+dataset_strategy = _dataset_weight_margin()


 def non_increasing(L, tolerance=1e-4):