Use dynamic types for array interface columns instead of templates (#5108)
This commit is contained in:
@@ -23,21 +23,21 @@ TEST(ArrayInterfaceHandler, Error) {
|
||||
|
||||
auto const& column_obj = get<Object>(column);
|
||||
// missing version
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||
column["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
// missing data
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||
column["data"] = j_data;
|
||||
// missing typestr
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||
column["typestr"] = String("<f4");
|
||||
// nullptr is not valid
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||
thrust::device_vector<float> d_data(kRows);
|
||||
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(d_data.data().get()))),
|
||||
Json(Boolean(false))};
|
||||
column["data"] = j_data;
|
||||
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj));
|
||||
EXPECT_NO_THROW(Columnar c(column_obj));
|
||||
|
||||
std::vector<Json> j_mask_shape {Json(Integer(static_cast<Integer::Int>(kRows - 1)))};
|
||||
column["mask"] = Object();
|
||||
@@ -46,7 +46,7 @@ TEST(ArrayInterfaceHandler, Error) {
|
||||
column["mask"]["typestr"] = String("<i1");
|
||||
column["mask"]["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
// shape of mask and data doesn't match.
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -75,6 +75,23 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
|
||||
return column;
|
||||
}
|
||||
|
||||
void TestGetElement() {
|
||||
thrust::device_vector<float> data;
|
||||
auto j_column = GenerateDenseColumn("<f4", 3, &data);
|
||||
auto const& column_obj = get<Object const>(j_column);
|
||||
Columnar foreign_column(column_obj);
|
||||
|
||||
EXPECT_NO_THROW({
|
||||
dh::LaunchN(0, 1, [=] __device__(size_t idx) {
|
||||
KERNEL_CHECK(foreign_column.GetElement(0) == 0.0f);
|
||||
KERNEL_CHECK(foreign_column.GetElement(1) == 2.0f);
|
||||
KERNEL_CHECK(foreign_column.GetElement(2) == 4.0f);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Columnar, GetElement) { TestGetElement(); }
|
||||
|
||||
void TestDenseColumn(std::unique_ptr<data::SimpleCSRSource> const& source,
|
||||
size_t n_rows, size_t n_cols) {
|
||||
auto const& data = source->page_.data.HostVector();
|
||||
@@ -384,4 +401,4 @@ TEST(SimpleCSRSource, Types) {
|
||||
TestDenseColumn(source, kRows, kCols);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user