Support multi-target, fit intercept for hinge. (#9850)
This commit is contained in:
@@ -23,7 +23,7 @@ void TestElementWiseKernel() {
|
||||
ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
|
||||
// CPU view
|
||||
t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
|
||||
size_t k = 0;
|
||||
std::size_t k = 0;
|
||||
for (size_t i = 0; i < l.Shape(0); ++i) {
|
||||
for (size_t j = 0; j < l.Shape(2); ++j) {
|
||||
ASSERT_EQ(k++, t(i, j));
|
||||
@@ -31,7 +31,15 @@ void TestElementWiseKernel() {
|
||||
}
|
||||
|
||||
t = l.View(device).Slice(linalg::All(), 1, linalg::All());
|
||||
ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
|
||||
cuda_impl::ElementWiseKernel(
|
||||
t, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { t(i, j) = i + j; });
|
||||
|
||||
t = l.Slice(linalg::All(), 1, linalg::All());
|
||||
for (size_t i = 0; i < l.Shape(0); ++i) {
|
||||
for (size_t j = 0; j < l.Shape(2); ++j) {
|
||||
ASSERT_EQ(i + j, t(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user