/** * Copyright 2021-2023 by XGBoost contributors. */ #ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ #define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ #include // for Context #include // for Constant, Vector #include // for CHECK #include // for RegTree #include // for vector #include "../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry, MultiExpandEntry namespace xgboost::tree { inline void GetSplit(RegTree *tree, float split_value, std::vector *candidates) { CHECK(!tree->IsMultiTarget()); tree->ExpandNode( /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value, /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); candidates->front().split.split_value = split_value; candidates->front().split.sindex = 0; candidates->front().split.sindex |= (1U << 31); } inline void GetMultiSplitForTest(RegTree *tree, float split_value, std::vector *candidates) { CHECK(tree->IsMultiTarget()); auto n_targets = tree->NumTargets(); Context ctx; linalg::Vector base_weight{linalg::Constant(&ctx, 0.0f, n_targets)}; linalg::Vector left_weight{linalg::Constant(&ctx, 0.0f, n_targets)}; linalg::Vector right_weight{linalg::Constant(&ctx, 0.0f, n_targets)}; tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value, /*default_left=*/true, base_weight.HostView(), left_weight.HostView(), right_weight.HostView()); candidates->front().split.split_value = split_value; candidates->front().split.sindex = 0; candidates->front().split.sindex |= (1U << 31); } } // namespace xgboost::tree #endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_