/**
 * Copyright 2023 by XGBoost contributors
 */
#include "quantile_loss_utils.h"

#include <cctype>   // std::isspace
#include <cstddef>  // std::size_t
#include <ios>      // std::ios, std::streamsize
#include <istream>  // std::istream
#include <ostream>  // std::ostream
#include <string>   // std::string
#include <vector>   // std::vector

#include "xgboost/json.h"     // F32Array,TypeCheck,get,IsA,Number,Integer,Array
#include "xgboost/json_io.h"  // JsonWriter

namespace xgboost {
namespace common {
namespace {
/**
 * @brief Whitespace test that is safe for any char value.
 *
 * std::isspace has undefined behaviour for negative values other than EOF, so the
 * char is converted to unsigned char first.
 */
bool IsSpace(char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; }
}  // anonymous namespace

/**
 * @brief Serialize a ParamFloatArray as a JSON array of 32-bit floats.
 *
 * The values are copied into an xgboost::F32Array and written with
 * xgboost::JsonWriter; the resulting JSON text is written to @p os.
 *
 * @param os    Output stream.
 * @param array Values to serialize.
 * @return @p os.
 */
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
  auto const& t = array.Get();
  xgboost::F32Array arr{t.size()};
  for (std::size_t i = 0; i < t.size(); ++i) {
    arr.Set(i, t[i]);
  }
  std::vector<char> stream;
  xgboost::JsonWriter writer{&stream};
  arr.Save(&writer);
  // Emit the serialized buffer in one call instead of one char at a time.
  os.write(stream.data(), static_cast<std::streamsize>(stream.size()));
  return os;
}

/**
 * @brief Parse a ParamFloatArray from text.
 *
 * Accepted forms (whitespace between tokens is ignored because tokens are
 * concatenated):
 *  - a single number, e.g. `0.5` or `1`;
 *  - an array of numbers, e.g. `[0.1, 0.5, 0.9]`;
 *  - the same array written with parentheses, e.g. `(0.1, 0.5, 0.9)`.
 *
 * Any previous contents of @p array are discarded. If there is nothing to parse,
 * failbit is set on @p is and @p array is left empty. Malformed JSON is reported by
 * xgboost::Json::Load, and a non-numeric array element by xgboost::TypeCheck (for the
 * parameter name "alpha").
 *
 * @param is    Input stream; read until it is exhausted.
 * @param array Destination for the parsed values.
 * @return @p is.
 */
std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
  auto& t = array.Get();
  t.clear();

  // Concatenate every whitespace-separated token, e.g. "[0.1, 0.5]" -> "[0.1,0.5]".
  // Looping on good() (not !eof()) stops on a stream that fails without hitting EOF,
  // which would otherwise loop forever.
  std::string str;
  while (is.good()) {
    std::string tmp;
    is >> tmp;
    str += tmp;
  }
  if (!str.empty() && is.eof() && !is.bad()) {
    // The whole input was consumed and at least one token was extracted, so the read
    // succeeded; drop the failbit left by a final attempt that only found trailing
    // whitespace (e.g. input "0.5 ").
    is.clear(std::ios::eofbit);
  }

  std::size_t head{0};
  while (head < str.size() && IsSpace(str[head])) {
    ++head;
  }
  if (head == str.size()) {
    // Nothing to parse: report failure instead of indexing into an empty string.
    is.setstate(std::ios::failbit);
    return is;
  }

  // Unify notation for parsing: tuple-style parentheses become JSON array brackets.
  if (str[head] == '(') {
    str[head] = '[';
  }
  auto tail = str.size() - 1;
  while (tail > head && IsSpace(str[tail])) {
    --tail;
  }
  if (str[tail] == ')') {
    str[tail] = ']';
  }

  auto jarr = xgboost::Json::Load(xgboost::StringView{str});
  // A scalar input yields exactly one element.
  if (xgboost::IsA<xgboost::Number>(jarr)) {
    t.emplace_back(xgboost::get<xgboost::Number const>(jarr));
    return is;
  }
  if (xgboost::IsA<xgboost::Integer>(jarr)) {
    t.emplace_back(static_cast<float>(xgboost::get<xgboost::Integer const>(jarr)));
    return is;
  }

  auto const& jvec = xgboost::get<xgboost::Array const>(jarr);
  t.reserve(jvec.size());
  for (auto const& v : jvec) {
    if (xgboost::IsA<xgboost::Number>(v)) {
      t.emplace_back(xgboost::get<xgboost::Number const>(v));
    } else if (xgboost::IsA<xgboost::Integer>(v)) {
      t.emplace_back(static_cast<float>(xgboost::get<xgboost::Integer const>(v)));
    } else {
      // Not a number: let TypeCheck report the type error for "alpha".
      xgboost::TypeCheck<xgboost::Number>(v, "alpha");
    }
  }
  return is;
}

DMLC_REGISTER_PARAMETER(QuantileLossParam);
}  // namespace common
}  // namespace xgboost