xgboost/src/tree/io_utils.h
Jiaming Yuan 5feee8d4a9
Define core multi-target regression tree structure. (#8884)
- Define a new tree struct embedded in the `RegTree`.
- Provide dispatching functions in `RegTree`.
- Fix some C++17 warnings about the use of `[[nodiscard]]` (the warning is currently disabled
  on the CI).
- Use uint32_t instead of size_t for `bst_target_t` as it has a defined size and can be used
  as part of dmlc parameter.
- Hide the `Segment` struct inside the categorical split matrix.
2023-03-09 19:03:06 +08:00

66 lines
2.3 KiB
C++

/**
* Copyright 2023 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_IO_UTILS_H_
#define XGBOOST_TREE_IO_UTILS_H_
#include <string> // for string
#include <type_traits> // for enable_if_t, is_same, conditional_t
#include <vector> // for vector
#include "xgboost/json.h" // for Json
namespace xgboost {
// Array-type selectors used when reading serialized tree fields.
//
// Each alias resolves to the matching typed array (const-qualified) when `kTyped` is true and
// to the generic `Array const` otherwise. The array types themselves are not declared in this
// header; presumably they come from "xgboost/json.h" -- confirm there.

// Floating point field: F32Array when typed, Array otherwise.
template <bool kTyped>
using FloatArrayT = std::conditional_t<kTyped, F32Array const, Array const>;

// Unsigned 8-bit field: U8Array when typed, Array otherwise.
template <bool kTyped>
using U8ArrayT = std::conditional_t<kTyped, U8Array const, Array const>;

// 32-bit integer field: I32Array when typed, Array otherwise.
template <bool kTyped>
using I32ArrayT = std::conditional_t<kTyped, I32Array const, Array const>;

// 64-bit integer field: I64Array when typed, Array otherwise.
template <bool kTyped>
using I64ArrayT = std::conditional_t<kTyped, I64Array const, Array const>;

// Index field: picks the 64-bit selector when `kFeatureIs64` is true and the 32-bit selector
// otherwise, forwarding `kTyped` to whichever one is chosen.
template <bool kTyped, bool kFeatureIs64>
using IndexArrayT = std::conditional_t<kFeatureIs64, I64ArrayT<kTyped>, I32ArrayT<kTyped>>;
// Typed array, any requested JSON type except Boolean.
//
// Participates in overload resolution only when the element type `T` is not `Json` and the
// requested JSON value type `JT` is not `Boolean`. Returns a copy of the i-th element; the
// index is not range checked.
template <typename JT, typename T>
std::enable_if_t<!(std::is_same_v<T, Json> || std::is_same_v<JT, Boolean>), T> GetElem(
    std::vector<T> const& arr, size_t i) {
  return arr[i];
}
// Typed array holding booleans as bytes.
//
// Participates in overload resolution only when the requested JSON value type `JT` is
// `Boolean` and the element type `T` is `uint8_t` (and not `Json`). The i-th byte is
// interpreted as a boolean: exactly 1 maps to true, every other value maps to false.
// The index is not range checked.
template <typename JT, typename T>
std::enable_if_t<!std::is_same_v<T, Json> && std::is_same_v<T, uint8_t> &&
                     std::is_same_v<JT, Boolean>,
                 bool>
GetElem(std::vector<T> const& arr, size_t i) {
  T const raw = arr[i];
  return raw == 1;
}
// json array
template <typename JT, typename T>
std::enable_if_t<
std::is_same<T, Json>::value,
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
GetElem(std::vector<T> const& arr, size_t i) {
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
return get<Integer const>(arr[i]) == 1;
}
return get<JT const>(arr[i]);
}
// String keys naming the per-tree fields. Declared `inline` so that this header can be
// included from several translation units while still defining each constant exactly once.
namespace tree_field {
// Per-node statistics.
inline const std::string kLossChg = "loss_changes";
inline const std::string kSumHess = "sum_hessian";
inline const std::string kBaseWeight = "base_weights";

// Split description.
inline const std::string kSplitIdx = "split_indices";
inline const std::string kSplitCond = "split_conditions";
inline const std::string kDftLeft = "default_left";

// Tree topology.
inline const std::string kParent = "parents";
inline const std::string kLeft = "left_children";
inline const std::string kRight = "right_children";
}  // namespace tree_field
} // namespace xgboost
#endif // XGBOOST_TREE_IO_UTILS_H_