Support multi-target, fit intercept for hinge. (#9850)

This commit is contained in:
Jiaming Yuan
2023-12-08 05:50:41 +08:00
committed by GitHub
parent 39c637ee19
commit 42de9206fc
8 changed files with 221 additions and 155 deletions

View File

@@ -23,7 +23,7 @@ void TestElementWiseKernel() {
ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
// CPU view
t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
size_t k = 0;
std::size_t k = 0;
for (size_t i = 0; i < l.Shape(0); ++i) {
for (size_t j = 0; j < l.Shape(2); ++j) {
ASSERT_EQ(k++, t(i, j));
@@ -31,7 +31,15 @@ void TestElementWiseKernel() {
}
t = l.View(device).Slice(linalg::All(), 1, linalg::All());
ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
cuda_impl::ElementWiseKernel(
t, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { t(i, j) = i + j; });
t = l.Slice(linalg::All(), 1, linalg::All());
for (size_t i = 0; i < l.Shape(0); ++i) {
for (size_t j = 0; j < l.Shape(2); ++j) {
ASSERT_EQ(i + j, t(i, j));
}
}
}
{