Support column split in GPU predictor (#9343)
This commit is contained in:
@@ -467,7 +467,6 @@ class ColumnSplitHelper {
|
||||
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
|
||||
auto const &tree = *model_.trees[tree_id];
|
||||
auto const &cats = tree.GetCategoriesMatrix();
|
||||
auto const has_categorical = tree.HasCategoricalSplit();
|
||||
bst_node_t n_nodes = tree.GetNodes().size();
|
||||
|
||||
for (bst_node_t nid = 0; nid < n_nodes; nid++) {
|
||||
@@ -484,16 +483,10 @@ class ColumnSplitHelper {
|
||||
}
|
||||
|
||||
auto const fvalue = feat.GetFvalue(split_index);
|
||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||
auto const node_categories =
|
||||
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||
if (!common::Decision(node_categories, fvalue)) {
|
||||
decision_bits_.Set(bit_index);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (fvalue >= node.SplitCond()) {
|
||||
auto const decision = tree.HasCategoricalSplit()
|
||||
? GetDecision<true>(node, nid, fvalue, cats)
|
||||
: GetDecision<false>(node, nid, fvalue, cats);
|
||||
if (decision) {
|
||||
decision_bits_.Set(bit_index);
|
||||
}
|
||||
}
|
||||
@@ -511,7 +504,7 @@ class ColumnSplitHelper {
|
||||
if (missing_bits_.Check(bit_index)) {
|
||||
return node.DefaultChild();
|
||||
} else {
|
||||
return node.LeftChild() + decision_bits_.Check(bit_index);
|
||||
return node.LeftChild() + !decision_bits_.Check(bit_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user