/** * Copyright 2023 by XGBoost Contributors */ #ifndef XGBOOST_TREE_IO_UTILS_H_ #define XGBOOST_TREE_IO_UTILS_H_ #include // for string #include // for enable_if_t, is_same, conditional_t #include // for vector #include "xgboost/json.h" // for Json namespace xgboost { template using FloatArrayT = std::conditional_t; template using U8ArrayT = std::conditional_t; template using I32ArrayT = std::conditional_t; template using I64ArrayT = std::conditional_t; template using IndexArrayT = std::conditional_t, I32ArrayT>; // typed array, not boolean template std::enable_if_t::value && !std::is_same::value, T> GetElem( std::vector const& arr, size_t i) { return arr[i]; } // typed array boolean template std::enable_if_t::value && std::is_same::value && std::is_same::value, bool> GetElem(std::vector const& arr, size_t i) { return arr[i] == 1; } // json array template std::enable_if_t< std::is_same::value, std::conditional_t::value, int64_t, std::conditional_t::value, bool, float>>> GetElem(std::vector const& arr, size_t i) { if (std::is_same::value && !IsA(arr[i])) { return get(arr[i]) == 1; } return get(arr[i]); } namespace tree_field { inline std::string const kLossChg{"loss_changes"}; inline std::string const kSumHess{"sum_hessian"}; inline std::string const kBaseWeight{"base_weights"}; inline std::string const kSplitIdx{"split_indices"}; inline std::string const kSplitCond{"split_conditions"}; inline std::string const kDftLeft{"default_left"}; inline std::string const kParent{"parents"}; inline std::string const kLeft{"left_children"}; inline std::string const kRight{"right_children"}; } // namespace tree_field } // namespace xgboost #endif // XGBOOST_TREE_IO_UTILS_H_