Reduce base margin to 2 dim for now. (#7455)
This commit is contained in:
@@ -55,24 +55,23 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
}
|
||||
{
|
||||
// base margin
|
||||
linalg::Tensor<float, 4> base_margin;
|
||||
base_margin.Reshape(4, 3, 2, 3);
|
||||
linalg::Tensor<float, 3> base_margin;
|
||||
base_margin.Reshape(4, 2, 3);
|
||||
auto& h_margin = base_margin.Data()->HostVector();
|
||||
std::iota(h_margin.begin(), h_margin.end(), 0.0);
|
||||
auto t_margin = base_margin.View(device).Slice(linalg::All(), linalg::All(), 0, linalg::All());
|
||||
ASSERT_EQ(t_margin.Shape().size(), 3);
|
||||
auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
|
||||
ASSERT_EQ(t_margin.Shape().size(), 2);
|
||||
|
||||
info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()});
|
||||
auto const& h_result = info.base_margin_.View(-1);
|
||||
ASSERT_EQ(h_result.Shape().size(), 3);
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_margin = base_margin.View(-1);
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
|
||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||
auto i0 = std::get<0>(tup);
|
||||
auto i1 = std::get<1>(tup);
|
||||
auto i2 = std::get<2>(tup);
|
||||
// Sliced at 3^th dimension.
|
||||
auto v_1 = in_margin(i0, i1, 0, i2);
|
||||
// Sliced at second dimension.
|
||||
auto v_1 = in_margin(i0, 0, i1);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
return v_0;
|
||||
});
|
||||
|
||||
@@ -254,7 +254,7 @@ TEST(SimpleDMatrix, Slice) {
|
||||
std::iota(upper.begin(), upper.end(), 1.0f);
|
||||
|
||||
auto& margin = p_m->Info().base_margin_;
|
||||
margin = linalg::Tensor<float, 3>{{kRows, kClasses}, GenericParameter::kCpuId};
|
||||
margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, GenericParameter::kCpuId};
|
||||
|
||||
std::array<int32_t, 3> ridxs {1, 3, 5};
|
||||
std::unique_ptr<DMatrix> out { p_m->Slice(ridxs) };
|
||||
|
||||
Reference in New Issue
Block a user