Initial support for multi-target tree. (#8616)

* Implement multi-target for hist.

- Add new hist tree builder.
- Move data fetchers for tests.
- Dispatch function calls in gbm base on the tree type.
This commit is contained in:
Jiaming Yuan
2023-03-22 23:49:56 +08:00
committed by GitHub
parent ea04d4c46c
commit 151882dd26
34 changed files with 856 additions and 389 deletions

View File

@@ -286,8 +286,8 @@ struct LearnerModelParamLegacy;
* \brief Strategy for building multi-target models.
*/
enum class MultiStrategy : std::int32_t {
kComposite = 0,
kMonolithic = 1,
kOneOutputPerTree = 0,
kMultiOutputTree = 1,
};
/**
@@ -317,7 +317,7 @@ struct LearnerModelParam {
/**
* \brief Strategy for building multi-target models.
*/
MultiStrategy multi_strategy{MultiStrategy::kComposite};
MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree};
LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
@@ -338,7 +338,7 @@ struct LearnerModelParam {
void Copy(LearnerModelParam const& that);
[[nodiscard]] bool IsVectorLeaf() const noexcept {
return multi_strategy == MultiStrategy::kMonolithic;
return multi_strategy == MultiStrategy::kMultiOutputTree;
}
[[nodiscard]] bst_target_t OutputLength() const noexcept { return this->num_output_group; }
[[nodiscard]] bst_target_t LeafLength() const noexcept {

View File

@@ -530,17 +530,17 @@ class TensorView {
/**
* \brief Number of items in the tensor.
*/
LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; }
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
/**
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
*/
LINALG_HD [[nodiscard]] bool Contiguous() const {
[[nodiscard]] LINALG_HD bool Contiguous() const {
return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
}
/**
* \brief Whether it's a c-contiguous array.
*/
LINALG_HD [[nodiscard]] bool CContiguous() const {
[[nodiscard]] LINALG_HD bool CContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
// It's contiguous if the stride can be calculated from shape.
@@ -550,7 +550,7 @@ class TensorView {
/**
* \brief Whether it's a f-contiguous array.
*/
LINALG_HD [[nodiscard]] bool FContiguous() const {
[[nodiscard]] LINALG_HD bool FContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
// It's contiguous if the stride can be calculated from shape.