Add range-based slicing to tensor view. (#7453)
This commit is contained in:
@@ -42,9 +42,9 @@ TEST(Adapter, CSRArrayAdapter) {
|
||||
size_t n_features = 100, n_samples = 10;
|
||||
RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices);
|
||||
using linalg::MakeVec;
|
||||
auto indptr_arr = MakeVec(indptr.HostPointer(), indptr.Size()).ArrayInterfaceStr();
|
||||
auto values_arr = MakeVec(values.HostPointer(), values.Size()).ArrayInterfaceStr();
|
||||
auto indices_arr = MakeVec(indices.HostPointer(), indices.Size()).ArrayInterfaceStr();
|
||||
auto indptr_arr = ArrayInterfaceStr(MakeVec(indptr.HostPointer(), indptr.Size()));
|
||||
auto values_arr = ArrayInterfaceStr(MakeVec(values.HostPointer(), values.Size()));
|
||||
auto indices_arr = ArrayInterfaceStr(MakeVec(indices.HostPointer(), indices.Size()));
|
||||
auto adapter = data::CSRArrayAdapter(
|
||||
StringView{indptr_arr.c_str(), indptr_arr.size()},
|
||||
StringView{values_arr.c_str(), values_arr.size()},
|
||||
|
||||
@@ -19,9 +19,8 @@ TEST(ArrayInterface, Initialize) {
|
||||
ASSERT_EQ(arr_interface.type, ArrayInterfaceHandler::kF4);
|
||||
|
||||
HostDeviceVector<size_t> u64_storage(storage.Size());
|
||||
std::string u64_arr_str{linalg::TensorView<size_t const, 2>{
|
||||
u64_storage.ConstHostSpan(), {kRows, kCols}, GenericParameter::kCpuId}
|
||||
.ArrayInterfaceStr()};
|
||||
std::string u64_arr_str{ArrayInterfaceStr(linalg::TensorView<size_t const, 2>{
|
||||
u64_storage.ConstHostSpan(), {kRows, kCols}, GenericParameter::kCpuId})};
|
||||
std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(),
|
||||
u64_storage.HostSpan().begin());
|
||||
auto u64_arr = ArrayInterface<2>{u64_arr_str};
|
||||
|
||||
@@ -127,7 +127,8 @@ TEST(MetaInfo, SaveLoadBinary) {
|
||||
|
||||
auto orig_margin = info.base_margin_.View(xgboost::GenericParameter::kCpuId);
|
||||
auto read_margin = inforead.base_margin_.View(xgboost::GenericParameter::kCpuId);
|
||||
EXPECT_TRUE(std::equal(orig_margin.cbegin(), orig_margin.cend(), read_margin.cbegin()));
|
||||
EXPECT_TRUE(std::equal(orig_margin.Values().cbegin(), orig_margin.Values().cend(),
|
||||
read_margin.Values().cbegin()));
|
||||
|
||||
EXPECT_EQ(inforead.feature_type_names.size(), kCols);
|
||||
EXPECT_EQ(inforead.feature_types.Size(), kCols);
|
||||
@@ -259,9 +260,8 @@ TEST(MetaInfo, Validate) {
|
||||
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
|
||||
d_groups.SetDevice(0);
|
||||
d_groups.DevicePointer(); // pull to device
|
||||
std::string arr_interface_str{
|
||||
xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0)
|
||||
.ArrayInterfaceStr()};
|
||||
std::string arr_interface_str{ArrayInterfaceStr(
|
||||
xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
|
||||
EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
is_gpu ? labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device};
|
||||
auto s = t.Slice(linalg::All(), 0);
|
||||
|
||||
auto str = s.ArrayInterfaceStr();
|
||||
auto str = ArrayInterfaceStr(s);
|
||||
ASSERT_EQ(s.Size(), 32);
|
||||
|
||||
info.SetInfo("label", StringView{str});
|
||||
@@ -48,7 +48,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
auto& h_qid = qid.Data()->HostVector();
|
||||
std::iota(h_qid.begin(), h_qid.end(), 0);
|
||||
auto s = qid.View(device).Slice(linalg::All(), 0);
|
||||
auto str = s.ArrayInterfaceStr();
|
||||
auto str = ArrayInterfaceStr(s);
|
||||
info.SetInfo("qid", StringView{str});
|
||||
auto const& h_result = info.group_ptr_;
|
||||
ASSERT_EQ(h_result.size(), s.Size() + 1);
|
||||
@@ -62,7 +62,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
|
||||
ASSERT_EQ(t_margin.Shape().size(), 2);
|
||||
|
||||
info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()});
|
||||
info.SetInfo("base_margin", StringView{ArrayInterfaceStr(t_margin)});
|
||||
auto const& h_result = info.base_margin_.View(-1);
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_margin = base_margin.View(-1);
|
||||
|
||||
Reference in New Issue
Block a user