Support multi-target, fit intercept for hinge. (#9850)

This commit is contained in:
Jiaming Yuan
2023-12-08 05:50:41 +08:00
committed by GitHub
parent 39c637ee19
commit 42de9206fc
8 changed files with 221 additions and 155 deletions

View File

@@ -31,12 +31,10 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
auto const& h_result = info.labels.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_labels = labels.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);
auto i1 = std::get<1>(tup);
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, std::size_t j) {
// Sliced at second dimension.
auto v_1 = in_labels(i0, 0, i1);
auto v_0 = h_result(i, j);
auto v_1 = in_labels(i, 0, j);
CHECK_EQ(v_0, v_1);
});
}
@@ -65,14 +63,13 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_margin = base_margin.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);
auto i1 = std::get<1>(tup);
// Sliced at second dimension.
auto v_1 = in_margin(i0, 0, i1);
CHECK_EQ(v_0, v_1);
});
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(),
[&](std::size_t i, std::size_t j) {
// Sliced at second dimension.
auto v_0 = h_result(i, j);
auto v_1 = in_margin(i, 0, j);
CHECK_EQ(v_0, v_1);
});
}
}
} // namespace xgboost