Implement GPU predict leaf. (#6187)
This commit is contained in:
parent
7f101d1b33
commit
8a17610666
@ -144,7 +144,7 @@ class GradientBooster : public Model, public Configurable {
|
|||||||
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
||||||
*/
|
*/
|
||||||
virtual void PredictLeaf(DMatrix* dmat,
|
virtual void PredictLeaf(DMatrix* dmat,
|
||||||
std::vector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
unsigned ntree_limit = 0) = 0;
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -164,10 +164,6 @@ class Predictor {
|
|||||||
unsigned ntree_limit = 0) = 0;
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn virtual void Predictor::PredictLeaf(DMatrix* dmat,
|
|
||||||
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned
|
|
||||||
* ntree_limit = 0) = 0;
|
|
||||||
*
|
|
||||||
* \brief predict the leaf index of each tree, the output will be nsample *
|
* \brief predict the leaf index of each tree, the output will be nsample *
|
||||||
* ntree vector this is only valid in gbtree predictor.
|
* ntree vector this is only valid in gbtree predictor.
|
||||||
*
|
*
|
||||||
@ -177,7 +173,7 @@ class Predictor {
|
|||||||
* \param ntree_limit (Optional) The ntree limit.
|
* \param ntree_limit (Optional) The ntree limit.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model,
|
const gbm::GBTreeModel& model,
|
||||||
unsigned ntree_limit = 0) = 0;
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
|
|||||||
@ -147,9 +147,7 @@ class GBLinear : public GradientBooster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictLeaf(DMatrix*,
|
void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned) override {
|
||||||
std::vector<bst_float>*,
|
|
||||||
unsigned) override {
|
|
||||||
LOG(FATAL) << "gblinear does not support prediction of leaf index";
|
LOG(FATAL) << "gblinear does not support prediction of leaf index";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -278,10 +278,9 @@ class GBTree : public GradientBooster {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PredictLeaf(DMatrix* p_fmat,
|
void PredictLeaf(DMatrix* p_fmat,
|
||||||
std::vector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
unsigned ntree_limit) override {
|
unsigned ntree_limit) override {
|
||||||
CHECK(configured_);
|
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
|
||||||
cpu_predictor_->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictContribution(DMatrix* p_fmat,
|
void PredictContribution(DMatrix* p_fmat,
|
||||||
|
|||||||
@ -1106,7 +1106,7 @@ class LearnerImpl : public LearnerIO {
|
|||||||
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
|
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
|
||||||
approx_contribs);
|
approx_contribs);
|
||||||
} else if (pred_leaf) {
|
} else if (pred_leaf) {
|
||||||
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
|
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
|
||||||
} else {
|
} else {
|
||||||
auto local_cache = this->GetPredictionCache();
|
auto local_cache = this->GetPredictionCache();
|
||||||
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
|
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
|
||||||
|
|||||||
@ -345,7 +345,7 @@ class CPUPredictor : public Predictor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
|
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||||
const int nthread = omp_get_max_threads();
|
const int nthread = omp_get_max_threads();
|
||||||
InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
|
InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
|
||||||
@ -355,7 +355,7 @@ class CPUPredictor : public Predictor {
|
|||||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||||
ntree_limit = static_cast<unsigned>(model.trees.size());
|
ntree_limit = static_cast<unsigned>(model.trees.size());
|
||||||
}
|
}
|
||||||
std::vector<bst_float>& preds = *out_preds;
|
std::vector<bst_float>& preds = out_preds->HostVector();
|
||||||
preds.resize(info.num_row_ * ntree_limit);
|
preds.resize(info.num_row_ * ntree_limit);
|
||||||
// start collecting the prediction
|
// start collecting the prediction
|
||||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
|
|||||||
@ -169,7 +169,7 @@ struct DeviceAdapterLoader {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Loader>
|
template <typename Loader>
|
||||||
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
|
||||||
common::Span<FeatureType const> split_types,
|
common::Span<FeatureType const> split_types,
|
||||||
common::Span<RegTree::Segment const> d_cat_ptrs,
|
common::Span<RegTree::Segment const> d_cat_ptrs,
|
||||||
common::Span<uint32_t const> d_categories,
|
common::Span<uint32_t const> d_categories,
|
||||||
@ -201,6 +201,49 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
|||||||
return tree[nidx].LeafValue();
|
return tree[nidx].LeafValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Loader>
|
||||||
|
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
|
||||||
|
Loader const& loader) {
|
||||||
|
bst_node_t nidx = 0;
|
||||||
|
RegTree::Node n = tree[nidx];
|
||||||
|
while (!n.IsLeaf()) {
|
||||||
|
float fvalue = loader.GetElement(ridx, n.SplitIndex());
|
||||||
|
// Missing value
|
||||||
|
if (isnan(fvalue)) {
|
||||||
|
nidx = n.DefaultChild();
|
||||||
|
n = tree[nidx];
|
||||||
|
} else {
|
||||||
|
if (fvalue < n.SplitCond()) {
|
||||||
|
nidx = n.LeftChild();
|
||||||
|
n = tree[nidx];
|
||||||
|
} else {
|
||||||
|
nidx = n.RightChild();
|
||||||
|
n = tree[nidx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nidx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Loader, typename Data>
|
||||||
|
__global__ void PredictLeafKernel(Data data,
|
||||||
|
common::Span<const RegTree::Node> d_nodes,
|
||||||
|
common::Span<float> d_out_predictions,
|
||||||
|
common::Span<size_t const> d_tree_segments,
|
||||||
|
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||||
|
size_t num_rows, size_t entry_start, bool use_shared) {
|
||||||
|
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
if (ridx >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Loader loader(data, use_shared, num_features, num_rows, entry_start);
|
||||||
|
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||||
|
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||||
|
auto leaf = GetLeafIndex(ridx, d_tree, loader);
|
||||||
|
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Loader, typename Data>
|
template <typename Loader, typename Data>
|
||||||
__global__ void
|
__global__ void
|
||||||
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||||
@ -437,6 +480,19 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <size_t kBlockThreads>
|
||||||
|
size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
|
||||||
|
// No way max_shared_memory_bytes that is equal to 0.
|
||||||
|
CHECK_GT(max_shared_memory_bytes, 0);
|
||||||
|
size_t shared_memory_bytes =
|
||||||
|
static_cast<size_t>(sizeof(float) * cols * kBlockThreads);
|
||||||
|
if (shared_memory_bytes > max_shared_memory_bytes) {
|
||||||
|
shared_memory_bytes = 0;
|
||||||
|
}
|
||||||
|
return shared_memory_bytes;
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
class GPUPredictor : public xgboost::Predictor {
|
class GPUPredictor : public xgboost::Predictor {
|
||||||
private:
|
private:
|
||||||
@ -450,13 +506,10 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
size_t num_rows = batch.Size();
|
size_t num_rows = batch.Size();
|
||||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||||
|
|
||||||
auto shared_memory_bytes =
|
size_t shared_memory_bytes =
|
||||||
static_cast<size_t>(sizeof(float) * num_features * BLOCK_THREADS);
|
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes_);
|
||||||
bool use_shared = true;
|
bool use_shared = shared_memory_bytes != 0;
|
||||||
if (shared_memory_bytes > max_shared_memory_bytes_) {
|
|
||||||
shared_memory_bytes = 0;
|
|
||||||
use_shared = false;
|
|
||||||
}
|
|
||||||
size_t entry_start = 0;
|
size_t entry_start = 0;
|
||||||
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||||
num_features);
|
num_features);
|
||||||
@ -608,13 +661,9 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
const uint32_t BLOCK_THREADS = 128;
|
const uint32_t BLOCK_THREADS = 128;
|
||||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
|
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
|
||||||
|
|
||||||
auto shared_memory_bytes =
|
size_t shared_memory_bytes =
|
||||||
static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
|
SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes);
|
||||||
bool use_shared = true;
|
bool use_shared = shared_memory_bytes != 0;
|
||||||
if (shared_memory_bytes > max_shared_memory_bytes) {
|
|
||||||
shared_memory_bytes = 0;
|
|
||||||
use_shared = false;
|
|
||||||
}
|
|
||||||
size_t entry_start = 0;
|
size_t entry_start = 0;
|
||||||
|
|
||||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||||
@ -780,11 +829,65 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
<< " is not implemented in GPU Predictor.";
|
<< " is not implemented in GPU Predictor.";
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictLeaf(DMatrix*, std::vector<bst_float>*,
|
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
|
||||||
const gbm::GBTreeModel&,
|
const gbm::GBTreeModel& model,
|
||||||
unsigned) override {
|
unsigned ntree_limit) override {
|
||||||
LOG(FATAL) << "[Internal error]: " << __func__
|
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||||
<< " is not implemented in GPU Predictor.";
|
ConfigureDevice(generic_param_->gpu_id);
|
||||||
|
|
||||||
|
const MetaInfo& info = p_fmat->Info();
|
||||||
|
constexpr uint32_t kBlockThreads = 128;
|
||||||
|
size_t shared_memory_bytes =
|
||||||
|
SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes_);
|
||||||
|
bool use_shared = shared_memory_bytes != 0;
|
||||||
|
bst_feature_t num_features = info.num_col_;
|
||||||
|
bst_row_t num_rows = info.num_row_;
|
||||||
|
size_t entry_start = 0;
|
||||||
|
|
||||||
|
uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group;
|
||||||
|
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
|
||||||
|
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
|
||||||
|
}
|
||||||
|
predictions->SetDevice(generic_param_->gpu_id);
|
||||||
|
predictions->Resize(num_rows * real_ntree_limit);
|
||||||
|
model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id);
|
||||||
|
|
||||||
|
if (p_fmat->PageExists<SparsePage>()) {
|
||||||
|
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
|
batch.data.SetDevice(generic_param_->gpu_id);
|
||||||
|
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||||
|
bst_row_t batch_offset = 0;
|
||||||
|
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||||
|
model.learner_model_param->num_feature};
|
||||||
|
size_t num_rows = batch.Size();
|
||||||
|
auto grid =
|
||||||
|
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
|
||||||
|
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
|
||||||
|
PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
|
||||||
|
model_.nodes.ConstDeviceSpan(),
|
||||||
|
predictions->DeviceSpan().subspan(batch_offset),
|
||||||
|
model_.tree_segments.ConstDeviceSpan(),
|
||||||
|
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
|
||||||
|
entry_start, use_shared);
|
||||||
|
batch_offset += batch.Size();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
|
||||||
|
bst_row_t batch_offset = 0;
|
||||||
|
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
|
||||||
|
size_t num_rows = batch.Size();
|
||||||
|
auto grid =
|
||||||
|
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
|
||||||
|
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
|
||||||
|
PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
|
||||||
|
model_.nodes.ConstDeviceSpan(),
|
||||||
|
predictions->DeviceSpan().subspan(batch_offset),
|
||||||
|
model_.tree_segments.ConstDeviceSpan(),
|
||||||
|
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
|
||||||
|
entry_start, use_shared);
|
||||||
|
batch_offset += batch.Size();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
|
void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
|
||||||
@ -801,7 +904,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
|
|
||||||
std::mutex lock_;
|
std::mutex lock_;
|
||||||
DeviceModel model_;
|
DeviceModel model_;
|
||||||
size_t max_shared_memory_bytes_;
|
size_t max_shared_memory_bytes_ { 0 };
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||||
|
|||||||
@ -348,6 +348,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
|
|||||||
gen.GenerateDense(&out->Info().labels_);
|
gen.GenerateDense(&out->Info().labels_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (device_ >= 0) {
|
||||||
|
out->Info().labels_.SetDevice(device_);
|
||||||
|
for (auto const& page : out->GetBatches<SparsePage>()) {
|
||||||
|
page.data.SetDevice(device_);
|
||||||
|
page.offset.SetDevice(device_);
|
||||||
|
}
|
||||||
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -46,9 +46,10 @@ TEST(CpuPredictor, Basic) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test predict leaf
|
// Test predict leaf
|
||||||
std::vector<float> leaf_out_predictions;
|
HostDeviceVector<float> leaf_out_predictions;
|
||||||
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
||||||
for (auto v : leaf_out_predictions) {
|
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
|
||||||
|
for (auto v : h_leaf_out_predictions) {
|
||||||
ASSERT_EQ(v, 0);
|
ASSERT_EQ(v, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,10 +113,11 @@ TEST(CpuPredictor, ExternalMemory) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test predict leaf
|
// Test predict leaf
|
||||||
std::vector<float> leaf_out_predictions;
|
HostDeviceVector<float> leaf_out_predictions;
|
||||||
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
||||||
ASSERT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_);
|
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
|
||||||
for (const auto& v : leaf_out_predictions) {
|
ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_);
|
||||||
|
for (const auto& v : h_leaf_out_predictions) {
|
||||||
ASSERT_EQ(v, 0);
|
ASSERT_EQ(v, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -190,6 +190,7 @@ TEST(GPUPredictor, ShapStump) {
|
|||||||
EXPECT_EQ(phis[4], 0.0);
|
EXPECT_EQ(phis[4], 0.0);
|
||||||
EXPECT_EQ(phis[5], param.base_score);
|
EXPECT_EQ(phis[5], param.base_score);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GPUPredictor, Shap) {
|
TEST(GPUPredictor, Shap) {
|
||||||
LearnerModelParam param;
|
LearnerModelParam param;
|
||||||
param.num_feature = 1;
|
param.num_feature = 1;
|
||||||
@ -224,5 +225,28 @@ TEST(GPUPredictor, Shap) {
|
|||||||
TEST(GPUPredictor, CategoricalPrediction) {
|
TEST(GPUPredictor, CategoricalPrediction) {
|
||||||
TestCategoricalPrediction("gpu_predictor");
|
TestCategoricalPrediction("gpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Leaf prediction on the GPU predictor: with the model produced by
// CreateTestModel, every predicted leaf index is expected to be 0.
TEST(GPUPredictor, PredictLeafBasic) {
  size_t constexpr kRows = 5, kCols = 5;
  auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();
  auto lparam = CreateEmptyGenericParam(GPUIDX);
  std::unique_ptr<Predictor> gpu_predictor{
      Predictor::Create("gpu_predictor", &lparam)};
  gpu_predictor->Configure({});

  LearnerModelParam param;
  param.num_feature = kCols;
  param.base_score = 0.0;
  param.num_output_group = 1;

  gbm::GBTreeModel model = CreateTestModel(&param);

  HostDeviceVector<float> leaf_out_predictions;
  gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
  // Reading the host view makes the device results visible to the assertions.
  auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
  for (auto v : h_leaf_out_predictions) {
    ASSERT_EQ(v, 0);
  }
}
|
||||||
} // namespace predictor
|
} // namespace predictor
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ from hypothesis import given, strategies, assume, settings, note
|
|||||||
sys.path.append("tests/python")
|
sys.path.append("tests/python")
|
||||||
import testing as tm
|
import testing as tm
|
||||||
from test_predict import run_threaded_predict # noqa
|
from test_predict import run_threaded_predict # noqa
|
||||||
|
from test_predict import run_predict_leaf # noqa
|
||||||
|
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
@ -18,6 +20,11 @@ shap_parameter_strategy = strategies.fixed_dictionaries({
|
|||||||
'num_parallel_tree': strategies.sampled_from([1, 10]),
|
'num_parallel_tree': strategies.sampled_from([1, 10]),
|
||||||
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
|
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
|
||||||
|
|
||||||
|
# Hypothesis search space for the leaf-prediction tests: tree depth between
# 1 and 8, and either a single tree or a forest of 4 parallel trees per round.
predict_parameter_strategy = strategies.fixed_dictionaries(
    dict(
        max_depth=strategies.integers(1, 8),
        num_parallel_tree=strategies.sampled_from([1, 4]),
    )
)
|
||||||
|
|
||||||
|
|
||||||
class TestGPUPredict(unittest.TestCase):
|
class TestGPUPredict(unittest.TestCase):
|
||||||
def test_predict(self):
|
def test_predict(self):
|
||||||
@ -223,3 +230,34 @@ class TestGPUPredict(unittest.TestCase):
|
|||||||
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
|
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
|
||||||
margin,
|
margin,
|
||||||
1e-3, 1e-3)
|
1e-3, 1e-3)
|
||||||
|
|
||||||
|
def test_predict_leaf_basic(self):
    """GPU and CPU predictors must report identical leaf indices."""
    # Evaluation order is kept: GPU predictor first, then CPU predictor.
    leaves = [run_predict_leaf(name)
              for name in ('gpu_predictor', 'cpu_predictor')]
    gpu_leaf, cpu_leaf = leaves
    np.testing.assert_equal(gpu_leaf, cpu_leaf)
|
||||||
|
|
||||||
|
def run_predict_leaf_booster(self, param, num_rounds, dataset):
    """Train a booster on ``dataset`` and compare CPU and GPU leaf predictions.

    ``param`` is first passed through ``dataset.set_params``; the booster is
    trained for ``num_rounds`` rounds and then asked for leaf indices once
    with each predictor. Both results must be equal.
    """
    param = dataset.set_params(param)
    m = dataset.get_dmat()
    booster = xgb.train(param, dtrain=dataset.get_dmat(),
                        num_boost_round=num_rounds)

    def leaf_with(predictor_name):
        # Switch the predictor on the trained booster, then predict leaves.
        booster.set_param({'predictor': predictor_name})
        return booster.predict(m, pred_leaf=True)

    cpu_leaf = leaf_with('cpu_predictor')
    gpu_leaf = leaf_with('gpu_predictor')

    np.testing.assert_equal(cpu_leaf, gpu_leaf)
|
||||||
|
|
||||||
|
@given(predict_parameter_strategy, tm.dataset_strategy)
@settings(deadline=None)
def test_predict_leaf_gbtree(self, param, dataset):
    """CPU/GPU leaf prediction agreement for the gbtree booster (10 rounds)."""
    param.update({'booster': 'gbtree', 'tree_method': 'gpu_hist'})
    self.run_predict_leaf_booster(param, 10, dataset)
|
||||||
|
|
||||||
|
@given(predict_parameter_strategy, tm.dataset_strategy)
@settings(deadline=None)
def test_predict_leaf_dart(self, param, dataset):
    """CPU/GPU leaf prediction agreement for the dart booster (10 rounds)."""
    param.update({'booster': 'dart', 'tree_method': 'gpu_hist'})
    self.run_predict_leaf_booster(param, 10, dataset)
|
||||||
|
|||||||
@ -23,6 +23,49 @@ def run_threaded_predict(X, rows, predict_func):
|
|||||||
assert f.result()
|
assert f.result()
|
||||||
|
|
||||||
|
|
||||||
|
def run_predict_leaf(predictor):
    """Train a small multi-class forest and validate its leaf predictions.

    A model with 5 classes and 4 parallel trees is trained for 10 rounds on
    100 random rows with the given ``predictor``. The function checks that
    predicting on an empty matrix yields no rows, that the leaf matrix has one
    column per tree, and that all trees of the same forest agree on the leaf.

    Returns the leaf matrix predicted for the training data.
    """
    rows, cols = 100, 4
    classes = 5
    num_parallel_tree = 4
    num_boost_round = 10

    rng = np.random.RandomState(1994)
    X = rng.randn(rows, cols)
    y = rng.randint(low=0, high=classes, size=rows)
    m = xgb.DMatrix(X, y)

    params = {'num_parallel_tree': num_parallel_tree, 'num_class': classes,
              'predictor': predictor, 'tree_method': 'hist'}
    booster = xgb.train(params, m, num_boost_round=num_boost_round)

    empty = xgb.DMatrix(np.ones(shape=(0, cols)))
    empty_leaf = booster.predict(empty, pred_leaf=True)
    assert empty_leaf.shape[0] == 0

    leaf = booster.predict(m, pred_leaf=True)
    assert leaf.shape[0] == rows
    assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round

    # Columns are grouped as round -> class -> parallel tree, so the last axis
    # of this view holds the trees belonging to one forest.
    forests = leaf.reshape(rows, num_boost_round, classes, num_parallel_tree)
    assert forests.shape[-1] == num_parallel_tree
    # no subsampling so tree in same forest should output same
    # leaf.
    assert np.all(forests == forests[..., :1])
    return leaf
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_leaf():
    """Run the shared leaf-prediction checks with the CPU predictor."""
    run_predict_leaf(predictor='cpu_predictor')
|
||||||
|
|
||||||
|
|
||||||
class TestInplacePredict(unittest.TestCase):
|
class TestInplacePredict(unittest.TestCase):
|
||||||
'''Tests for running inplace prediction'''
|
'''Tests for running inplace prediction'''
|
||||||
def test_predict(self):
|
def test_predict(self):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user