Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) by Contributors 2020
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
@@ -12,9 +12,7 @@
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/survival_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
namespace xgboost::common {
|
||||
TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) {
|
||||
auto ctx = MakeCUDACtx(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &ctx));
|
||||
@@ -65,14 +63,14 @@ static inline void CheckGPairOverGridPoints(
|
||||
preds[i] = std::log(std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low));
|
||||
}
|
||||
|
||||
HostDeviceVector<GradientPair> out_gpair;
|
||||
linalg::Matrix<GradientPair> out_gpair;
|
||||
obj->GetGradient(HostDeviceVector<bst_float>(preds), info, 1, &out_gpair);
|
||||
const auto& gpair = out_gpair.HostVector();
|
||||
const auto gpair = out_gpair.HostView();
|
||||
CHECK_EQ(num_point, expected_grad.size());
|
||||
CHECK_EQ(num_point, expected_hess.size());
|
||||
for (int i = 0; i < num_point; ++i) {
|
||||
EXPECT_NEAR(gpair[i].GetGrad(), expected_grad[i], ftol);
|
||||
EXPECT_NEAR(gpair[i].GetHess(), expected_hess[i], ftol);
|
||||
EXPECT_NEAR(gpair(i).GetGrad(), expected_grad[i], ftol);
|
||||
EXPECT_NEAR(gpair(i).GetHess(), expected_hess[i], ftol);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,5 +167,4 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) {
|
||||
0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f });
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
Reference in New Issue
Block a user