Support column split in GPU predictor (#9343)

This commit is contained in:
Rong Ou
2023-07-02 13:05:34 -07:00
committed by GitHub
parent f90771eec6
commit 3a0f787703
5 changed files with 288 additions and 25 deletions

View File

@@ -467,7 +467,6 @@ class ColumnSplitHelper {
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
auto const &tree = *model_.trees[tree_id];
auto const &cats = tree.GetCategoriesMatrix();
auto const has_categorical = tree.HasCategoricalSplit();
bst_node_t n_nodes = tree.GetNodes().size();
for (bst_node_t nid = 0; nid < n_nodes; nid++) {
@@ -484,16 +483,10 @@ class ColumnSplitHelper {
}
auto const fvalue = feat.GetFvalue(split_index);
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto const node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
if (!common::Decision(node_categories, fvalue)) {
decision_bits_.Set(bit_index);
}
continue;
}
if (fvalue >= node.SplitCond()) {
auto const decision = tree.HasCategoricalSplit()
? GetDecision<true>(node, nid, fvalue, cats)
: GetDecision<false>(node, nid, fvalue, cats);
if (decision) {
decision_bits_.Set(bit_index);
}
}
@@ -511,7 +504,7 @@ class ColumnSplitHelper {
if (missing_bits_.Check(bit_index)) {
return node.DefaultChild();
} else {
return node.LeftChild() + decision_bits_.Check(bit_index);
return node.LeftChild() + !decision_bits_.Check(bit_index);
}
}