Convert labels into tensor. (#7456)
* Add a new ctor to tensor for `initilizer_list`. * Change labels from host device vector to tensor. * Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
@@ -16,30 +16,27 @@ namespace xgboost {
|
||||
inline void TestMetaInfoStridedData(int32_t device) {
|
||||
MetaInfo info;
|
||||
{
|
||||
// label
|
||||
HostDeviceVector<float> labels;
|
||||
labels.Resize(64);
|
||||
auto& h_labels = labels.HostVector();
|
||||
std::iota(h_labels.begin(), h_labels.end(), 0.0f);
|
||||
bool is_gpu = device >= 0;
|
||||
if (is_gpu) {
|
||||
labels.SetDevice(0);
|
||||
}
|
||||
// labels
|
||||
linalg::Tensor<float, 3> labels;
|
||||
labels.Reshape(4, 2, 3);
|
||||
auto& h_label = labels.Data()->HostVector();
|
||||
std::iota(h_label.begin(), h_label.end(), 0.0);
|
||||
auto t_labels = labels.View(device).Slice(linalg::All(), 0, linalg::All());
|
||||
ASSERT_EQ(t_labels.Shape().size(), 2);
|
||||
|
||||
auto t = linalg::TensorView<float const, 2>{
|
||||
is_gpu ? labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device};
|
||||
auto s = t.Slice(linalg::All(), 0);
|
||||
|
||||
auto str = ArrayInterfaceStr(s);
|
||||
ASSERT_EQ(s.Size(), 32);
|
||||
|
||||
info.SetInfo("label", StringView{str});
|
||||
auto const& h_result = info.labels_.HostVector();
|
||||
ASSERT_EQ(h_result.size(), 32);
|
||||
|
||||
for (auto v : h_result) {
|
||||
ASSERT_EQ(static_cast<int32_t>(v) % 2, 0);
|
||||
}
|
||||
info.SetInfo("label", StringView{ArrayInterfaceStr(t_labels)});
|
||||
auto const& h_result = info.labels.View(-1);
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_labels = labels.View(-1);
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
|
||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||
auto i0 = std::get<0>(tup);
|
||||
auto i1 = std::get<1>(tup);
|
||||
// Sliced at second dimension.
|
||||
auto v_1 = in_labels(i0, 0, i1);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
return v_0;
|
||||
});
|
||||
}
|
||||
{
|
||||
// qid
|
||||
|
||||
Reference in New Issue
Block a user