[MT-TREE] Support prediction cache and model slicing. (#8968)

- Fix prediction range.
- Support prediction cache in mt-hist.
- Support model slicing.
- Make the booster a Python iterable by defining `__iter__`.
- Clean up removed/deprecated parameters.
- Add a new field, `iteration_indptr`, to the output model; it points to the range of trees belonging to each iteration.
This commit is contained in:
Jiaming Yuan
2023-03-27 23:10:54 +08:00
committed by GitHub
parent c2b3a13e70
commit acc110c251
30 changed files with 502 additions and 343 deletions

View File

@@ -677,9 +677,6 @@ template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
linalg::VectorView<float> out_preds) {
CHECK_GT(out_preds.Size(), 0U);
CHECK(p_last_tree);
auto const &tree = *p_last_tree;
CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
size_t n_nodes = p_last_tree->GetNodes().size();
@@ -687,7 +684,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
auto const &rowset = part[nidx];
auto leaf_value = tree[nidx].LeafValue();
@@ -698,5 +695,42 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
});
}
}
/**
 * @brief Add the leaf values of the last built tree to a 2-dim prediction cache.
 *
 * When the tree is not a multi-target tree, the work is forwarded to the overload that
 * accepts the slice `(All, 0)` of `out_preds`. Otherwise, for every leaf node, the leaf
 * vector is accumulated into `out_preds(row, target)` for each row stored in that node's
 * row set.
 *
 * @param ctx          Context supplying the number of threads.
 * @param p_last_tree  The tree whose leaves are applied; must not be null.
 * @param partitioner  One partitioner per entry; each must have as many entries as the
 *                     multi-target tree has nodes.
 * @param out_preds    Non-empty CPU view; its second dimension must equal the number of
 *                     targets of the tree.
 */
template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
                               std::vector<Partitioner> const &partitioner,
                               linalg::MatrixView<float> out_preds) {
  CHECK_GT(out_preds.Size(), 0U);
  CHECK(p_last_tree);

  auto const &tree = *p_last_tree;
  if (!tree.IsMultiTarget()) {
    // Not a multi-target tree: hand over to the overload working on a single slice.
    UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, out_preds.Slice(linalg::All(), 0));
    return;
  }

  auto const *p_mt_tree = tree.GetMultiTargetTree();
  auto n_nodes = p_mt_tree->Size();
  auto n_targets = tree.NumTargets();
  CHECK_EQ(out_preds.Shape(1), n_targets);
  CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);

  for (auto &part : partitioner) {
    CHECK_EQ(part.Size(), n_nodes);
    // First dimension: node index; second dimension: rows held by that node.
    common::BlockedSpace2d blocks(
        part.Size(), [&](size_t node_idx) { return part[node_idx].Size(); }, 1024);
    common::ParallelFor2d(
        blocks, ctx->Threads(), [&](bst_node_t node_idx, common::Range1d row_range) {
          // Only leaves contribute to the prediction.
          if (!tree.IsLeaf(node_idx)) {
            return;
          }
          auto const &node_rows = part[node_idx];
          auto weights = p_mt_tree->LeafValue(node_idx);
          std::size_t const *p_row = node_rows.begin + row_range.begin();
          while (p_row < node_rows.begin + row_range.end()) {
            for (std::size_t target = 0; target < n_targets; ++target) {
              out_preds(*p_row, target) += weights(target);
            }
            ++p_row;
          }
        });
  }
}
} // namespace xgboost::tree
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

View File

@@ -116,7 +116,7 @@ class GloablApproxBuilder {
return nodes.front();
}
void UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
void UpdatePredictionCache(DMatrix const *data, linalg::MatrixView<float> out_preds) const {
monitor_->Start(__func__);
// Caching prediction seems redundant for approx tree method, as sketching takes up
// majority of training time.
@@ -303,7 +303,7 @@ class GlobalApproxUpdater : public TreeUpdater {
}
}
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
if (data != cached_ || !pimpl_) {
return false;
}

View File

@@ -517,7 +517,7 @@ struct GPUHistMakerDevice {
});
}
bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {
if (positions.empty()) {
return false;
}
@@ -535,11 +535,12 @@ struct GPUHistMakerDevice {
h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice,
ctx_->CUDACtx()->Stream()));
auto d_nodes = dh::ToSpan(nodes);
CHECK_EQ(out_preds_d.Shape(1), 1);
dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),
[=] XGBOOST_DEVICE(std::size_t idx) mutable {
bst_node_t nidx = d_position[idx];
auto weight = d_nodes[nidx].LeafValue();
out_preds_d(idx) += weight;
out_preds_d(idx, 0) += weight;
});
return true;
}
@@ -858,7 +859,7 @@ class GPUHistMaker : public TreeUpdater {
}
bool UpdatePredictionCache(const DMatrix* data,
linalg::VectorView<bst_float> p_out_preds) override {
linalg::MatrixView<bst_float> p_out_preds) override {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}

View File

@@ -125,6 +125,7 @@ class MultiTargetHistBuilder {
std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
RegTree const *p_last_tree_{nullptr};
DMatrix const * p_last_fmat_{nullptr};
ObjInfo const *task_{nullptr};
@@ -147,6 +148,7 @@ class MultiTargetHistBuilder {
void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
monitor_->Start(__func__);
p_last_fmat_ = p_fmat;
std::size_t page_id = 0;
bst_bin_t n_total_bins = 0;
partitioner_.clear();
@@ -312,6 +314,19 @@ class MultiTargetHistBuilder {
task_{task} {
monitor_->Init(__func__);
}
/**
 * @brief Update the prediction cache using the last tree built by this builder.
 *
 * @param data       The matrix the caller wants cached predictions for.
 * @param out_preds  Prediction view; its size must equal number of rows times number of
 *                   targets of the last tree.
 *
 * @return false when no matrix or tree has been recorded yet, or when `data` is not the
 *         recorded matrix; true once the cache has been updated.
 */
bool UpdatePredictionCache(DMatrix const *data, linalg::MatrixView<float> out_preds) const {
  // The recorded matrix pointer remains valid provided this method is used together
  // with Update().
  bool const cache_usable = p_last_fmat_ && p_last_tree_ && data == p_last_fmat_;
  if (!cache_usable) {
    return false;
  }

  monitor_->Start(__func__);
  CHECK_EQ(out_preds.Size(), data->Info().num_row_ * p_last_tree_->NumTargets());
  UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, out_preds);
  monitor_->Stop(__func__);

  return true;
}
};
class HistBuilder {
@@ -347,7 +362,7 @@ class HistBuilder {
monitor_->Init(__func__);
}
bool UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
bool UpdatePredictionCache(DMatrix const *data, linalg::MatrixView<float> out_preds) const {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
@@ -582,12 +597,11 @@ class QuantileHistMaker : public TreeUpdater {
}
}
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
if (p_impl_) {
return p_impl_->UpdatePredictionCache(data, out_preds);
} else if (p_mtimpl_) {
// Not yet supported.
return false;
return p_mtimpl_->UpdatePredictionCache(data, out_preds);
} else {
return false;
}