[MT-TREE] Support prediction cache and model slicing. (#8968)
- Fix prediction range. - Support prediction cache in mt-hist. - Support model slicing. - Make the booster a Python iterable by defining `__iter__`. - Clean up removed/deprecated parameters. - Add a new field, `iteration_indptr`, to the output model; it points to the range of trees belonging to each iteration.
This commit is contained in:
@@ -1,15 +1,55 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2019-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "gbtree_model.h"
|
||||
#include "gbtree.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
#include <algorithm> // for transform, max_element
|
||||
#include <cstddef> // for size_t
|
||||
#include <numeric> // for partial_sum
|
||||
#include <ostream> // for operator<<, basic_ostream
|
||||
#include <utility> // for move, pair
|
||||
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "dmlc/base.h" // for BeginPtr
|
||||
#include "dmlc/io.h" // for Stream
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/json.h" // for Json, get, Integer, Array, FromJson, ToJson, Json...
|
||||
#include "xgboost/learner.h" // for LearnerModelParam
|
||||
#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost::gbm {
|
||||
namespace {
|
||||
// For creating the tree indptr from old models.
|
||||
void MakeIndptr(GBTreeModel* out_model) {
|
||||
auto const& tree_info = out_model->tree_info;
|
||||
if (tree_info.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto n_groups = *std::max_element(tree_info.cbegin(), tree_info.cend()) + 1;
|
||||
|
||||
auto& indptr = out_model->iteration_indptr;
|
||||
auto layer_trees = out_model->param.num_parallel_tree * n_groups;
|
||||
CHECK_NE(layer_trees, 0);
|
||||
indptr.resize(out_model->param.num_trees / layer_trees + 1, 0);
|
||||
indptr[0] = 0;
|
||||
|
||||
for (std::size_t i = 1; i < indptr.size(); ++i) {
|
||||
indptr[i] = n_groups * out_model->param.num_parallel_tree;
|
||||
}
|
||||
std::partial_sum(indptr.cbegin(), indptr.cend(), indptr.begin());
|
||||
}
|
||||
|
||||
// Validate the consistency of the model.
|
||||
void Validate(GBTreeModel const& model) {
|
||||
CHECK_EQ(model.trees.size(), model.param.num_trees);
|
||||
CHECK_EQ(model.tree_info.size(), model.param.num_trees);
|
||||
// True even if the model is empty since we should always have 0 as the first element.
|
||||
CHECK_EQ(model.iteration_indptr.back(), model.param.num_trees);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void GBTreeModel::Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));
|
||||
|
||||
@@ -61,6 +101,9 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MakeIndptr(this);
|
||||
Validate(*this);
|
||||
}
|
||||
|
||||
void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
@@ -72,10 +115,10 @@ void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
CHECK(ctx_);
|
||||
common::ParallelFor(trees.size(), ctx_->Threads(), [&](auto t) {
|
||||
auto const& tree = trees[t];
|
||||
Json tree_json{Object()};
|
||||
tree->SaveModel(&tree_json);
|
||||
tree_json["id"] = Integer{static_cast<Integer::Int>(t)};
|
||||
trees_json[t] = std::move(tree_json);
|
||||
Json jtree{Object{}};
|
||||
tree->SaveModel(&jtree);
|
||||
jtree["id"] = Integer{static_cast<Integer::Int>(t)};
|
||||
trees_json[t] = std::move(jtree);
|
||||
});
|
||||
|
||||
std::vector<Json> tree_info_json(tree_info.size());
|
||||
@@ -85,6 +128,11 @@ void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
|
||||
out["trees"] = Array(std::move(trees_json));
|
||||
out["tree_info"] = Array(std::move(tree_info_json));
|
||||
|
||||
std::vector<Json> jiteration_indptr(iteration_indptr.size());
|
||||
std::transform(iteration_indptr.cbegin(), iteration_indptr.cend(), jiteration_indptr.begin(),
|
||||
[](bst_tree_t i) { return Integer{i}; });
|
||||
out["iteration_indptr"] = Array{std::move(jiteration_indptr)};
|
||||
}
|
||||
|
||||
void GBTreeModel::LoadModel(Json const& in) {
|
||||
@@ -93,22 +141,59 @@ void GBTreeModel::LoadModel(Json const& in) {
|
||||
trees.clear();
|
||||
trees_to_update.clear();
|
||||
|
||||
auto const& jmodel = get<Object const>(in);
|
||||
|
||||
auto const& trees_json = get<Array const>(in["trees"]);
|
||||
trees.resize(trees_json.size());
|
||||
CHECK_EQ(trees_json.size(), param.num_trees);
|
||||
trees.resize(param.num_trees);
|
||||
|
||||
auto const& tree_info_json = get<Array const>(in["tree_info"]);
|
||||
CHECK_EQ(tree_info_json.size(), param.num_trees);
|
||||
tree_info.resize(param.num_trees);
|
||||
|
||||
CHECK(ctx_);
|
||||
common::ParallelFor(trees_json.size(), ctx_->Threads(), [&](auto t) {
|
||||
auto tree_id = get<Integer>(trees_json[t]["id"]);
|
||||
trees.at(tree_id).reset(new RegTree());
|
||||
trees.at(tree_id)->LoadModel(trees_json[t]);
|
||||
|
||||
common::ParallelFor(param.num_trees, ctx_->Threads(), [&](auto t) {
|
||||
auto tree_id = get<Integer const>(trees_json[t]["id"]);
|
||||
trees.at(tree_id).reset(new RegTree{});
|
||||
trees[tree_id]->LoadModel(trees_json[t]);
|
||||
});
|
||||
|
||||
tree_info.resize(param.num_trees);
|
||||
auto const& tree_info_json = get<Array const>(in["tree_info"]);
|
||||
for (int32_t i = 0; i < param.num_trees; ++i) {
|
||||
for (bst_tree_t i = 0; i < param.num_trees; ++i) {
|
||||
tree_info[i] = get<Integer const>(tree_info_json[i]);
|
||||
}
|
||||
|
||||
auto indptr_it = jmodel.find("iteration_indptr");
|
||||
iteration_indptr.clear();
|
||||
if (indptr_it != jmodel.cend()) {
|
||||
auto const& vec = get<Array const>(indptr_it->second);
|
||||
iteration_indptr.resize(vec.size());
|
||||
std::transform(vec.cbegin(), vec.cend(), iteration_indptr.begin(),
|
||||
[](Json const& v) { return get<Integer const>(v); });
|
||||
} else {
|
||||
MakeIndptr(this);
|
||||
}
|
||||
|
||||
Validate(*this);
|
||||
}
|
||||
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
/**
 * Append the trees produced by one boosting iteration to the model.
 *
 * For vector-leaf models only the first group of `new_trees` is committed (as
 * group 0); otherwise one group is committed per output, up to
 * `learner_model_param->OutputLength()`. A new entry is then pushed onto
 * `iteration_indptr` marking the end of this iteration's trees, and the model is
 * validated.
 *
 * @param new_trees Trees of the current iteration, indexed by group; each group is
 *                  moved into the model.
 * @return The number of trees added by this call.
 */
bst_tree_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) {
  CHECK(!iteration_indptr.empty());
  CHECK_EQ(iteration_indptr.back(), param.num_trees);

  bst_tree_t n_committed{0};
  // Count the trees of one group first, then hand them over to the model.
  auto commit_group = [&](bst_target_t group) {
    auto& group_trees = new_trees[group];
    n_committed += group_trees.size();
    this->CommitModelGroup(std::move(group_trees), group);
  };

  if (!learner_model_param->IsVectorLeaf()) {
    for (bst_target_t group{0}; group < learner_model_param->OutputLength(); ++group) {
      commit_group(group);
    }
  } else {
    commit_group(0);
  }

  auto const prev_end = iteration_indptr.back();
  iteration_indptr.push_back(n_committed + prev_end);
  Validate(*this);
  return n_committed;
}
|
||||
} // namespace xgboost::gbm
|
||||
|
||||
Reference in New Issue
Block a user