Add helper for generating batches of data. (#5756)

* Add helper for generating batches of data.

* VC keyword clash.

* Another clash.
This commit is contained in:
Jiaming Yuan
2020-06-05 09:53:56 +08:00
committed by GitHub
parent 359023c0fa
commit bd9d57f579
3 changed files with 113 additions and 13 deletions

View File

@@ -2,6 +2,7 @@
#include <algorithm>
#include "helpers.h"
#include "../../src/data/array_interface.h"
namespace xgboost {
TEST(RandomDataGenerator, DMatrix) {
@@ -41,4 +42,29 @@ TEST(RandomDataGenerator, DMatrix) {
}
}
// Exercises RandomDataGenerator::GenerateArrayInterfaceBatch: the call yields a
// list of per-batch array-interface JSON strings plus one JSON string for the
// whole array.  The test checks that
//   * exactly kBatches batch strings come back,
//   * every batch passes ArrayInterfaceHandler::Validate and reports kCols
//     columns,
//   * the row counts of all batches sum to kRows, and
//   * the whole-array description reports a (kRows, kCols) shape.
TEST(RandomDataGenerator, GenerateArrayInterfaceBatch) {
  constexpr size_t kRows = 937;
  constexpr size_t kCols = 100;
  constexpr size_t kBatches = 13;
  constexpr float kSparsity = 0.4f;

  // Passed to the generator by pointer; declared first so it outlives the
  // returned strings for the duration of the test.
  HostDeviceVector<float> storage;
  std::vector<std::string> batches;
  std::string array;
  std::tie(batches, array) =
      RandomDataGenerator{kRows, kCols, kSparsity}.GenerateArrayInterfaceBatch(
          &storage, kBatches);
  CHECK_EQ(batches.size(), kBatches);

  // Walk over the batches, validating each one and accumulating shape[0].
  size_t total_rows = 0;
  for (auto const &batch_str : batches) {
    auto j_batch = Json::Load({batch_str.c_str(), batch_str.size()});
    ArrayInterfaceHandler::Validate(get<Object const>(j_batch));
    CHECK_EQ(get<Integer>(j_batch["shape"][1]), kCols);
    total_rows += get<Integer>(j_batch["shape"][0]);
  }
  CHECK_EQ(total_rows, kRows);

  // The description of the full array must carry the overall shape.
  Json j_whole = Json::Load({array.c_str(), array.size()});
  CHECK_EQ(get<Integer>(j_whole["shape"][0]), kRows);
  CHECK_EQ(get<Integer>(j_whole["shape"][1]), kCols);
}
} // namespace xgboost