[R] Support multi-class custom objective. (#9526)

This commit is contained in:
Jiaming Yuan
2023-08-29 08:27:13 +08:00
committed by GitHub
parent 90ef250ea1
commit be6a552956
6 changed files with 106 additions and 26 deletions

View File

@@ -354,12 +354,12 @@ void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data
* @brief Make array interface for other language bindings.
*/
template <typename G, typename H>
auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, std::size_t n_samples,
std::size_t n_targets) {
auto t_grad =
linalg::MakeTensorView(ctx, common::Span{grad, n_samples * n_targets}, n_samples, n_targets);
auto t_hess =
linalg::MakeTensorView(ctx, common::Span{hess, n_samples * n_targets}, n_samples, n_targets);
auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, linalg::Order order,
std::size_t n_samples, std::size_t n_targets) {
auto t_grad = linalg::MakeTensorView(ctx, order, common::Span{grad, n_samples * n_targets},
n_samples, n_targets);
auto t_hess = linalg::MakeTensorView(ctx, order, common::Span{hess, n_samples * n_targets},
n_samples, n_targets);
auto s_grad = linalg::ArrayInterfaceStr(t_grad);
auto s_hess = linalg::ArrayInterfaceStr(t_hess);
return std::make_tuple(s_grad, s_hess);