Initial support for multi-target tree. (#8616)

* Implement multi-target for hist.

- Add new hist tree builder.
- Move data fetchers for tests.
- Dispatch function calls in gbm base on the tree type.
This commit is contained in:
Jiaming Yuan
2023-03-22 23:49:56 +08:00
committed by GitHub
parent ea04d4c46c
commit 151882dd26
34 changed files with 856 additions and 389 deletions

View File

@@ -87,30 +87,6 @@ bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
return tree[leaf].LeafValue();
}
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, const size_t predict_offset,
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
const size_t gid = model.tree_info[tree_id];
auto const &tree = *model.trees[tree_id];
auto const &cats = tree.GetCategoriesMatrix();
auto has_categorical = tree.HasCategoricalSplit();
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
}
}
}
} // namespace scalar
namespace multi {
@@ -128,7 +104,7 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
}
template <bool has_categorical>
void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree,
void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tree,
RegTree::CategoricalSplitMatrix const &cats,
linalg::VectorView<float> out_predt) {
bst_node_t const leaf = p_feats.HasMissing()
@@ -140,36 +116,52 @@ void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tre
out_predt(i) += leaf_value(i);
}
}
} // namespace multi
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, const size_t predict_offset,
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
namespace {
void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_begin,
std::uint32_t const tree_end, std::size_t const predict_offset,
std::vector<RegTree::FVec> const &thread_temp, std::size_t const offset,
std::size_t const block_size, linalg::MatrixView<float> out_predt) {
for (std::uint32_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
auto const &tree = *model.trees.at(tree_id);
auto cats = tree.GetCategoriesMatrix();
auto const &cats = tree.GetCategoriesMatrix();
bool has_categorical = tree.HasCategoricalSplit();
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
t_predts);
if (tree.IsMultiTarget()) {
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
multi::PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
t_predts);
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
multi::PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(),
cats, t_predts);
}
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
t_predts);
auto const gid = model.tree_info[tree_id];
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
}
}
}
}
} // namespace multi
template <typename DataView>
void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature,
DataView* batch, const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
DataView *batch, const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
for (size_t i = 0; i < block_size; ++i) {
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
if (feats.Size() == 0) {
@@ -181,8 +173,8 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
}
template <typename DataView>
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView *batch,
const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
for (size_t i = 0; i < block_size; ++i) {
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
const SparsePage::Inst inst = (*batch)[batch_offset + i];
@@ -190,9 +182,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc
}
}
namespace {
static std::size_t constexpr kUnroll = 8;
} // anonymous namespace
struct SparsePageView {
bst_row_t base_rowid;
@@ -292,7 +282,7 @@ class AdapterView {
template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model,
int32_t tree_begin, int32_t tree_end,
std::uint32_t tree_begin, std::uint32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads,
linalg::TensorView<float, 2> out_predt) {
auto &thread_temp = *p_thread_temp;
@@ -310,14 +300,8 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
// process block of rows through all trees to keep cache locality
if (model.learner_model_param->IsVectorLeaf()) {
multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
thread_temp, fvec_offset, block_size, out_predt);
} else {
scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
thread_temp, fvec_offset, block_size, out_predt);
}
PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp,
fvec_offset, block_size, out_predt);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
});
}
@@ -348,7 +332,6 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
FillNodeMeanValues(tree, 0, mean_values);
}
namespace {
// init thread buffers
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
int prev_thread_temp_size = out->size();