Extensible binary serialization format for DMatrix::MetaInfo (#5187)
* Turn xgboost::DataType into C++11 enum class
* New binary serialization format for DMatrix::MetaInfo
* Fix clang-tidy
* Fix C++ test
* Implement new format proposal
* Move helper functions to anonymous namespace; remove unneeded field
* Fix lint
* Add shape.
* Keep only roundtrip test.
* Fix test.
* Various fixes
* Update data.cc

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
committed by GitHub
parent b4f952bd22
commit 44469a0ca9
@@ -1,12 +1,14 @@
|
||||
// Copyright by Contributors
|
||||
// Copyright 2016-2020 by Contributors
|
||||
#include <dmlc/io.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "../../../src/data/simple_csr_source.h"
|
||||
#include "../../../src/common/version.h"
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/base.h"
|
||||
|
||||
TEST(MetaInfo, GetSet) {
|
||||
xgboost::MetaInfo info;
|
||||
@@ -14,23 +16,23 @@ TEST(MetaInfo, GetSet) {
|
||||
double double2[2] = {1.0, 2.0};
|
||||
|
||||
EXPECT_EQ(info.labels_.Size(), 0);
|
||||
info.SetInfo("label", double2, xgboost::kFloat32, 2);
|
||||
info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2);
|
||||
EXPECT_EQ(info.labels_.Size(), 2);
|
||||
|
||||
float float2[2] = {1.0f, 2.0f};
|
||||
EXPECT_EQ(info.GetWeight(1), 1.0f)
|
||||
<< "When no weights are given, was expecting default value 1";
|
||||
info.SetInfo("weight", float2, xgboost::kFloat32, 2);
|
||||
info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2);
|
||||
EXPECT_EQ(info.GetWeight(1), 2.0f);
|
||||
|
||||
uint32_t uint32_t2[2] = {1U, 2U};
|
||||
EXPECT_EQ(info.base_margin_.Size(), 0);
|
||||
info.SetInfo("base_margin", uint32_t2, xgboost::kUInt32, 2);
|
||||
info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
|
||||
EXPECT_EQ(info.base_margin_.Size(), 2);
|
||||
|
||||
uint64_t uint64_t2[2] = {1U, 2U};
|
||||
EXPECT_EQ(info.group_ptr_.size(), 0);
|
||||
info.SetInfo("group", uint64_t2, xgboost::kUInt64, 2);
|
||||
info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2);
|
||||
ASSERT_EQ(info.group_ptr_.size(), 3);
|
||||
EXPECT_EQ(info.group_ptr_[2], 3);
|
||||
|
||||
@@ -40,10 +42,18 @@ TEST(MetaInfo, GetSet) {
|
||||
|
||||
TEST(MetaInfo, SaveLoadBinary) {
|
||||
xgboost::MetaInfo info;
|
||||
double vals[2] = {1.0, 2.0};
|
||||
info.SetInfo("label", vals, xgboost::kDouble, 2);
|
||||
info.num_row_ = 2;
|
||||
info.num_col_ = 1;
|
||||
uint64_t constexpr kRows { 64 }, kCols { 32 };
|
||||
auto generator = []() {
|
||||
static float f = 0;
|
||||
return f++;
|
||||
};
|
||||
std::vector<float> values (kRows);
|
||||
std::generate(values.begin(), values.end(), generator);
|
||||
info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.num_row_ = kRows;
|
||||
info.num_col_ = kCols;
|
||||
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/metainfo.binary";
|
||||
@@ -54,17 +64,24 @@ TEST(MetaInfo, SaveLoadBinary) {
|
||||
info.SaveBinary(fs.get());
|
||||
}
|
||||
|
||||
ASSERT_EQ(GetFileSize(tmp_file), 84)
|
||||
<< "Expected saved binary file size to be same as object size";
|
||||
{
|
||||
// Round-trip test
|
||||
std::unique_ptr<dmlc::Stream> fs {
|
||||
dmlc::Stream::Create(tmp_file.c_str(), "r")
|
||||
};
|
||||
xgboost::MetaInfo inforead;
|
||||
inforead.LoadBinary(fs.get());
|
||||
ASSERT_EQ(inforead.num_row_, kRows);
|
||||
EXPECT_EQ(inforead.num_row_, info.num_row_);
|
||||
EXPECT_EQ(inforead.num_col_, info.num_col_);
|
||||
EXPECT_EQ(inforead.num_nonzero_, info.num_nonzero_);
|
||||
|
||||
std::unique_ptr<dmlc::Stream> fs {
|
||||
dmlc::Stream::Create(tmp_file.c_str(), "r")
|
||||
};
|
||||
xgboost::MetaInfo inforead;
|
||||
inforead.LoadBinary(fs.get());
|
||||
EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector());
|
||||
EXPECT_EQ(inforead.num_col_, info.num_col_);
|
||||
EXPECT_EQ(inforead.num_row_, info.num_row_);
|
||||
ASSERT_EQ(inforead.labels_.HostVector(), values);
|
||||
EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector());
|
||||
EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
|
||||
EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());
|
||||
EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MetaInfo, LoadQid) {
|
||||
|
||||
Reference in New Issue
Block a user