Support multi-target, fit intercept for hinge. (#9850)
This commit is contained in:
@@ -31,12 +31,10 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
|
||||
auto const& h_result = info.labels.View(DeviceOrd::CPU());
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_labels = labels.View(DeviceOrd::CPU());
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
|
||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||
auto i0 = std::get<0>(tup);
|
||||
auto i1 = std::get<1>(tup);
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, std::size_t j) {
|
||||
// Sliced at second dimension.
|
||||
auto v_1 = in_labels(i0, 0, i1);
|
||||
auto v_0 = h_result(i, j);
|
||||
auto v_1 = in_labels(i, 0, j);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
});
|
||||
}
|
||||
@@ -65,14 +63,13 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
|
||||
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_margin = base_margin.View(DeviceOrd::CPU());
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
|
||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||
auto i0 = std::get<0>(tup);
|
||||
auto i1 = std::get<1>(tup);
|
||||
// Sliced at second dimension.
|
||||
auto v_1 = in_margin(i0, 0, i1);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
});
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(),
|
||||
[&](std::size_t i, std::size_t j) {
|
||||
// Sliced at second dimension.
|
||||
auto v_0 = h_result(i, j);
|
||||
auto v_1 = in_margin(i, 0, j);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user