Reduce base margin to 2 dim for now. (#7455)

This commit is contained in:
Jiaming Yuan
2021-11-27 00:46:13 +08:00
committed by GitHub
parent bf7bb575b4
commit 557ffc4bf5
7 changed files with 33 additions and 33 deletions

View File

@@ -55,24 +55,23 @@ inline void TestMetaInfoStridedData(int32_t device) {
}
{
// base margin
linalg::Tensor<float, 4> base_margin;
base_margin.Reshape(4, 3, 2, 3);
linalg::Tensor<float, 3> base_margin;
base_margin.Reshape(4, 2, 3);
auto& h_margin = base_margin.Data()->HostVector();
std::iota(h_margin.begin(), h_margin.end(), 0.0);
auto t_margin = base_margin.View(device).Slice(linalg::All(), linalg::All(), 0, linalg::All());
ASSERT_EQ(t_margin.Shape().size(), 3);
auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
ASSERT_EQ(t_margin.Shape().size(), 2);
info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()});
auto const& h_result = info.base_margin_.View(-1);
ASSERT_EQ(h_result.Shape().size(), 3);
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_margin = base_margin.View(-1);
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);
auto i1 = std::get<1>(tup);
auto i2 = std::get<2>(tup);
// Sliced at 3^th dimension.
auto v_1 = in_margin(i0, i1, 0, i2);
// Sliced at second dimension.
auto v_1 = in_margin(i0, 0, i1);
CHECK_EQ(v_0, v_1);
return v_0;
});

View File

@@ -254,7 +254,7 @@ TEST(SimpleDMatrix, Slice) {
std::iota(upper.begin(), upper.end(), 1.0f);
auto& margin = p_m->Info().base_margin_;
margin = linalg::Tensor<float, 3>{{kRows, kClasses}, GenericParameter::kCpuId};
margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, GenericParameter::kCpuId};
std::array<int32_t, 3> ridxs {1, 3, 5};
std::unique_ptr<DMatrix> out { p_m->Slice(ridxs) };