Initial support for multi-target tree. (#8616)
* Implement multi-target for hist. - Add new hist tree builder. - Move data fetchers for tests. - Dispatch function calls in gbm based on the tree type.
This commit is contained in:
@@ -87,30 +87,6 @@ bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
|
||||
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||
return tree[leaf].LeafValue();
|
||||
}
|
||||
|
||||
// Adds the leaf value of every tree in [tree_begin, tree_end) to the output
// for one block of rows.
//
// model          - ensemble; `trees` and `tree_info` are indexed by tree id.
// tree_begin     - first tree id to evaluate (inclusive).
// tree_end       - one past the last tree id to evaluate.
// predict_offset - row of `out_predt` that receives the block's first sample.
// thread_temp    - feature vectors; the block's samples start at `offset`.
// offset         - index in `thread_temp` of the block's first feature vector.
// block_size     - number of rows in the block.
// out_predt      - output matrix; the column written is `model.tree_info[tree_id]`.
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
                       const size_t tree_end, const size_t predict_offset,
                       const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
                       const size_t block_size, linalg::TensorView<float, 2> out_predt) {
  for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
    const size_t gid = model.tree_info[tree_id];
    auto const &tree = *model.trees[tree_id];
    auto const &cats = tree.GetCategoriesMatrix();
    auto has_categorical = tree.HasCategoricalSplit();

    // Choose the template instantiation once per tree so the per-row loop does
    // not branch on `has_categorical`.
    if (has_categorical) {
      for (std::size_t i = 0; i < block_size; ++i) {
        out_predt(predict_offset + i, gid) +=
            PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
      }
    } else {
      // This tree reports no categorical split, so use the non-categorical
      // instantiation. The previous code used <true> in both branches, which
      // made the dispatch pointless.
      // NOTE(review): <true> presumably produced the same values for such
      // trees and only cost extra per-node checks - confirm against GetLeafIndex.
      for (std::size_t i = 0; i < block_size; ++i) {
        out_predt(predict_offset + i, gid) +=
            PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
      }
    }
  }
}
|
||||
} // namespace scalar
|
||||
|
||||
namespace multi {
|
||||
@@ -128,7 +104,7 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
|
||||
}
|
||||
|
||||
template <bool has_categorical>
|
||||
void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree,
|
||||
void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tree,
|
||||
RegTree::CategoricalSplitMatrix const &cats,
|
||||
linalg::VectorView<float> out_predt) {
|
||||
bst_node_t const leaf = p_feats.HasMissing()
|
||||
@@ -140,36 +116,52 @@ void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tre
|
||||
out_predt(i) += leaf_value(i);
|
||||
}
|
||||
}
|
||||
} // namespace multi
|
||||
|
||||
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, const size_t predict_offset,
|
||||
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
|
||||
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
|
||||
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
||||
namespace {
|
||||
void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_begin,
|
||||
std::uint32_t const tree_end, std::size_t const predict_offset,
|
||||
std::vector<RegTree::FVec> const &thread_temp, std::size_t const offset,
|
||||
std::size_t const block_size, linalg::MatrixView<float> out_predt) {
|
||||
for (std::uint32_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
||||
auto const &tree = *model.trees.at(tree_id);
|
||||
auto cats = tree.GetCategoriesMatrix();
|
||||
auto const &cats = tree.GetCategoriesMatrix();
|
||||
bool has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
if (has_categorical) {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
|
||||
PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
|
||||
t_predts);
|
||||
if (tree.IsMultiTarget()) {
|
||||
if (has_categorical) {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
|
||||
multi::PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
|
||||
t_predts);
|
||||
}
|
||||
} else {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
|
||||
multi::PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(),
|
||||
cats, t_predts);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
|
||||
PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
|
||||
t_predts);
|
||||
auto const gid = model.tree_info[tree_id];
|
||||
if (has_categorical) {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
out_predt(predict_offset + i, gid) +=
|
||||
scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
} else {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
out_predt(predict_offset + i, gid) +=
|
||||
scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace multi
|
||||
|
||||
template <typename DataView>
|
||||
void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature,
|
||||
DataView* batch, const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
|
||||
DataView *batch, const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
|
||||
if (feats.Size() == 0) {
|
||||
@@ -181,8 +173,8 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
|
||||
}
|
||||
|
||||
template <typename DataView>
|
||||
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
|
||||
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
|
||||
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView *batch,
|
||||
const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
|
||||
const SparsePage::Inst inst = (*batch)[batch_offset + i];
|
||||
@@ -190,9 +182,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
static std::size_t constexpr kUnroll = 8;
|
||||
} // anonymous namespace
|
||||
|
||||
struct SparsePageView {
|
||||
bst_row_t base_rowid;
|
||||
@@ -292,7 +282,7 @@ class AdapterView {
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size>
|
||||
void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model,
|
||||
int32_t tree_begin, int32_t tree_end,
|
||||
std::uint32_t tree_begin, std::uint32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads,
|
||||
linalg::TensorView<float, 2> out_predt) {
|
||||
auto &thread_temp = *p_thread_temp;
|
||||
@@ -310,14 +300,8 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod
|
||||
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
|
||||
// process block of rows through all trees to keep cache locality
|
||||
if (model.learner_model_param->IsVectorLeaf()) {
|
||||
multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
|
||||
thread_temp, fvec_offset, block_size, out_predt);
|
||||
} else {
|
||||
scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
|
||||
thread_temp, fvec_offset, block_size, out_predt);
|
||||
}
|
||||
|
||||
PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp,
|
||||
fvec_offset, block_size, out_predt);
|
||||
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
|
||||
});
|
||||
}
|
||||
@@ -348,7 +332,6 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
|
||||
FillNodeMeanValues(tree, 0, mean_values);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// init thread buffers
|
||||
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
|
||||
int prev_thread_temp_size = out->size();
|
||||
|
||||
Reference in New Issue
Block a user