[MT-TREE] Support prediction cache and model slicing. (#8968)
- Fix the prediction range. - Support the prediction cache in mt-hist. - Support model slicing. - Make the booster a Python iterable by defining `__iter__`. - Clean up removed and deprecated parameters. - Add a new field to the output model, `iteration_indptr`, which points to the range of trees belonging to each iteration.
This commit is contained in:
@@ -677,9 +677,6 @@ template <typename Partitioner>
|
||||
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
|
||||
std::vector<Partitioner> const &partitioner,
|
||||
linalg::VectorView<float> out_preds) {
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
CHECK(p_last_tree);
|
||||
auto const &tree = *p_last_tree;
|
||||
CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
|
||||
size_t n_nodes = p_last_tree->GetNodes().size();
|
||||
@@ -687,7 +684,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
|
||||
CHECK_EQ(part.Size(), n_nodes);
|
||||
common::BlockedSpace2d space(
|
||||
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
|
||||
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
|
||||
common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
|
||||
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
|
||||
auto const &rowset = part[nidx];
|
||||
auto leaf_value = tree[nidx].LeafValue();
|
||||
@@ -698,5 +695,42 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Partitioner>
|
||||
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
|
||||
std::vector<Partitioner> const &partitioner,
|
||||
linalg::MatrixView<float> out_preds) {
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
CHECK(p_last_tree);
|
||||
|
||||
auto const &tree = *p_last_tree;
|
||||
if (!tree.IsMultiTarget()) {
|
||||
UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, out_preds.Slice(linalg::All(), 0));
|
||||
return;
|
||||
}
|
||||
|
||||
auto const *mttree = tree.GetMultiTargetTree();
|
||||
auto n_nodes = mttree->Size();
|
||||
auto n_targets = tree.NumTargets();
|
||||
CHECK_EQ(out_preds.Shape(1), n_targets);
|
||||
CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
|
||||
|
||||
for (auto &part : partitioner) {
|
||||
CHECK_EQ(part.Size(), n_nodes);
|
||||
common::BlockedSpace2d space(
|
||||
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
|
||||
common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
|
||||
if (tree.IsLeaf(nidx)) {
|
||||
auto const &rowset = part[nidx];
|
||||
auto leaf_value = mttree->LeafValue(nidx);
|
||||
for (std::size_t const *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
for (std::size_t i = 0; i < n_targets; ++i) {
|
||||
out_preds(*it, i) += leaf_value(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::tree
|
||||
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
|
||||
|
||||
@@ -116,7 +116,7 @@ class GloablApproxBuilder {
|
||||
return nodes.front();
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
|
||||
void UpdatePredictionCache(DMatrix const *data, linalg::MatrixView<float> out_preds) const {
|
||||
monitor_->Start(__func__);
|
||||
// Caching prediction seems redundant for approx tree method, as sketching takes up
|
||||
// majority of training time.
|
||||
@@ -303,7 +303,7 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
|
||||
if (data != cached_ || !pimpl_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -517,7 +517,7 @@ struct GPUHistMakerDevice {
|
||||
});
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
|
||||
bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {
|
||||
if (positions.empty()) {
|
||||
return false;
|
||||
}
|
||||
@@ -535,11 +535,12 @@ struct GPUHistMakerDevice {
|
||||
h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice,
|
||||
ctx_->CUDACtx()->Stream()));
|
||||
auto d_nodes = dh::ToSpan(nodes);
|
||||
CHECK_EQ(out_preds_d.Shape(1), 1);
|
||||
dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),
|
||||
[=] XGBOOST_DEVICE(std::size_t idx) mutable {
|
||||
bst_node_t nidx = d_position[idx];
|
||||
auto weight = d_nodes[nidx].LeafValue();
|
||||
out_preds_d(idx) += weight;
|
||||
out_preds_d(idx, 0) += weight;
|
||||
});
|
||||
return true;
|
||||
}
|
||||
@@ -858,7 +859,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::VectorView<bst_float> p_out_preds) override {
|
||||
linalg::MatrixView<bst_float> p_out_preds) override {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -125,6 +125,7 @@ class MultiTargetHistBuilder {
|
||||
std::vector<CommonRowPartitioner> partitioner_;
|
||||
// Pointer to last updated tree, used for update prediction cache.
|
||||
RegTree const *p_last_tree_{nullptr};
|
||||
DMatrix const * p_last_fmat_{nullptr};
|
||||
|
||||
ObjInfo const *task_{nullptr};
|
||||
|
||||
@@ -147,6 +148,7 @@ class MultiTargetHistBuilder {
|
||||
void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
|
||||
monitor_->Start(__func__);
|
||||
|
||||
p_last_fmat_ = p_fmat;
|
||||
std::size_t page_id = 0;
|
||||
bst_bin_t n_total_bins = 0;
|
||||
partitioner_.clear();
|
||||
@@ -312,6 +314,19 @@ class MultiTargetHistBuilder {
|
||||
task_{task} {
|
||||
monitor_->Init(__func__);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(DMatrix const *data, linalg::MatrixView<float> out_preds) const {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
|
||||
return false;
|
||||
}
|
||||
monitor_->Start(__func__);
|
||||
CHECK_EQ(out_preds.Size(), data->Info().num_row_ * p_last_tree_->NumTargets());
|
||||
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, out_preds);
|
||||
monitor_->Stop(__func__);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class HistBuilder {
|
||||
@@ -347,7 +362,7 @@ class HistBuilder {
|
||||
monitor_->Init(__func__);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
|
||||
bool UpdatePredictionCache(DMatrix const *data, linalg::MatrixView<float> out_preds) const {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
|
||||
@@ -582,12 +597,11 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
|
||||
if (p_impl_) {
|
||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (p_mtimpl_) {
|
||||
// Not yet supported.
|
||||
return false;
|
||||
return p_mtimpl_->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user