Partitioner for multi-target tree. (#8922)

This commit is contained in:
Jiaming Yuan
2023-03-16 18:49:34 +08:00
committed by GitHub
parent 26209a42a5
commit a093770f36
8 changed files with 239 additions and 178 deletions

View File

@@ -1,17 +1,20 @@
/*!
* Copyright 2021-2022, XGBoost contributors.
/**
* Copyright 2021-2023 by XGBoost contributors.
*/
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#include <xgboost/tree_model.h>
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for Constant, Vector
#include <xgboost/logging.h> // for CHECK
#include <xgboost/tree_model.h> // for RegTree
#include <vector>
#include <vector> // for vector
#include "../../../src/tree/hist/expand_entry.h"
#include "../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry, MultiExpandEntry
namespace xgboost {
namespace tree {
namespace xgboost::tree {
inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
CHECK(!tree->IsMultiTarget());
tree->ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
@@ -21,6 +24,22 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntr
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
} // namespace tree
} // namespace xgboost
inline void GetMultiSplitForTest(RegTree *tree, float split_value,
std::vector<MultiExpandEntry> *candidates) {
CHECK(tree->IsMultiTarget());
auto n_targets = tree->NumTargets();
Context ctx;
linalg::Vector<float> base_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
linalg::Vector<float> left_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
linalg::Vector<float> right_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, base_weight.HostView(), left_weight.HostView(),
right_weight.HostView());
candidates->front().split.split_value = split_value;
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
} // namespace xgboost::tree
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_