Extensible binary serialization format for DMatrix::MetaInfo (#5187)
* Turn xgboost::DataType into C++11 enum class * New binary serialization format for DMatrix::MetaInfo * Fix clang-tidy * Fix c++ test * Implement new format proposal * Move helper functions to anonymous namespace; remove unneeded field * Fix lint * Add shape. * Keep only roundtrip test. * Fix test. * various fixes * Update data.cc Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
b4f952bd22
commit
44469a0ca9
@ -26,7 +26,7 @@ namespace xgboost {
|
|||||||
class DMatrix;
|
class DMatrix;
|
||||||
|
|
||||||
/*! \brief data type accepted by xgboost interface */
|
/*! \brief data type accepted by xgboost interface */
|
||||||
enum DataType {
|
enum class DataType : uint8_t {
|
||||||
kFloat32 = 1,
|
kFloat32 = 1,
|
||||||
kDouble = 2,
|
kDouble = 2,
|
||||||
kUInt32 = 3,
|
kUInt32 = 3,
|
||||||
@ -38,6 +38,9 @@ enum DataType {
|
|||||||
*/
|
*/
|
||||||
class MetaInfo {
|
class MetaInfo {
|
||||||
public:
|
public:
|
||||||
|
/*! \brief number of data fields in MetaInfo */
|
||||||
|
static constexpr uint64_t kNumField = 7;
|
||||||
|
|
||||||
/*! \brief number of rows in the data */
|
/*! \brief number of rows in the data */
|
||||||
uint64_t num_row_{0};
|
uint64_t num_row_{0};
|
||||||
/*! \brief number of columns in the data */
|
/*! \brief number of columns in the data */
|
||||||
|
|||||||
@ -127,6 +127,8 @@ class HostDeviceVector {
|
|||||||
|
|
||||||
void Resize(size_t new_size, T v = T());
|
void Resize(size_t new_size, T v = T());
|
||||||
|
|
||||||
|
using value_type = T;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
HostDeviceVectorImpl<T>* impl_;
|
HostDeviceVectorImpl<T>* impl_;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -340,7 +340,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
->get()->Info().SetInfo(field, info, kFloat32, len);
|
->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -361,7 +361,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
->get()->Info().SetInfo(field, info, kUInt32, len);
|
->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -372,7 +372,7 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
|||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
->get()->Info().SetInfo("group", group, kUInt32, len);
|
->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
169
src/data/data.cc
169
src/data/data.cc
@ -1,11 +1,13 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2015-2019 by Contributors
|
* Copyright 2015-2020 by Contributors
|
||||||
* \file data.cc
|
* \file data.cc
|
||||||
*/
|
*/
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
|
#include "dmlc/io.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/version_config.h"
|
#include "xgboost/version_config.h"
|
||||||
#include "sparse_page_writer.h"
|
#include "sparse_page_writer.h"
|
||||||
@ -29,7 +31,96 @@ DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPa
|
|||||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
|
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
|
||||||
} // namespace dmlc
|
} // namespace dmlc
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void SaveScalarField(dmlc::Stream *strm, const std::string &name,
|
||||||
|
xgboost::DataType type, const T &field) {
|
||||||
|
strm->Write(name);
|
||||||
|
strm->Write(type);
|
||||||
|
strm->Write(true); // is_scalar=True
|
||||||
|
strm->Write(field);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void SaveVectorField(dmlc::Stream *strm, const std::string &name,
|
||||||
|
xgboost::DataType type, std::pair<uint64_t, uint64_t> shape,
|
||||||
|
const std::vector<T>& field) {
|
||||||
|
strm->Write(name);
|
||||||
|
strm->Write(type);
|
||||||
|
strm->Write(false); // is_scalar=False
|
||||||
|
strm->Write(shape.first);
|
||||||
|
strm->Write(shape.second);
|
||||||
|
strm->Write(field);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void SaveVectorField(dmlc::Stream* strm, const std::string& name,
|
||||||
|
xgboost::DataType type, std::pair<uint64_t, uint64_t> shape,
|
||||||
|
const xgboost::HostDeviceVector<T>& field) {
|
||||||
|
SaveVectorField(strm, name, type, shape, field.ConstHostVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name,
|
||||||
|
xgboost::DataType expected_type, T* field) {
|
||||||
|
const std::string invalid {"MetaInfo: Invalid format. "};
|
||||||
|
std::string name;
|
||||||
|
xgboost::DataType type;
|
||||||
|
bool is_scalar;
|
||||||
|
CHECK(strm->Read(&name)) << invalid;
|
||||||
|
CHECK_EQ(name, expected_name)
|
||||||
|
<< invalid << " Expected field: " << expected_name << ", got: " << name;
|
||||||
|
CHECK(strm->Read(&type)) << invalid;
|
||||||
|
CHECK(type == expected_type)
|
||||||
|
<< invalid << "Expected field of type: " << static_cast<int>(expected_type) << ", "
|
||||||
|
<< "got field type: " << static_cast<int>(type);
|
||||||
|
CHECK(strm->Read(&is_scalar)) << invalid;
|
||||||
|
CHECK(is_scalar)
|
||||||
|
<< invalid << "Expected field " << expected_name << " to be a scalar; got a vector";
|
||||||
|
CHECK(strm->Read(field, sizeof(T))) << invalid;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
|
||||||
|
xgboost::DataType expected_type, std::vector<T>* field) {
|
||||||
|
const std::string invalid {"MetaInfo: Invalid format. "};
|
||||||
|
std::string name;
|
||||||
|
xgboost::DataType type;
|
||||||
|
bool is_scalar;
|
||||||
|
CHECK(strm->Read(&name)) << invalid;
|
||||||
|
CHECK_EQ(name, expected_name)
|
||||||
|
<< invalid << " Expected field: " << expected_name << ", got: " << name;
|
||||||
|
CHECK(strm->Read(&type)) << invalid;
|
||||||
|
CHECK(type == expected_type)
|
||||||
|
<< invalid << "Expected field of type: " << static_cast<int>(expected_type) << ", "
|
||||||
|
<< "got field type: " << static_cast<int>(type);
|
||||||
|
CHECK(strm->Read(&is_scalar)) << invalid;
|
||||||
|
CHECK(!is_scalar)
|
||||||
|
<< invalid << "Expected field " << expected_name << " to be a vector; got a scalar";
|
||||||
|
std::pair<uint64_t, uint64_t> shape;
|
||||||
|
|
||||||
|
CHECK(strm->Read(&shape.first));
|
||||||
|
CHECK(strm->Read(&shape.second));
|
||||||
|
// TODO(hcho3): this restriction may be lifted, once we add a field with more than 1 column.
|
||||||
|
CHECK_EQ(shape.second, 1) << invalid << "Number of columns is expected to be 1.";
|
||||||
|
|
||||||
|
CHECK(strm->Read(field)) << invalid;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
|
||||||
|
xgboost::DataType expected_type,
|
||||||
|
xgboost::HostDeviceVector<T>* field) {
|
||||||
|
LoadVectorField(strm, expected_name, expected_type, &field->HostVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
|
uint64_t constexpr MetaInfo::kNumField;
|
||||||
|
|
||||||
// implementation of inline functions
|
// implementation of inline functions
|
||||||
void MetaInfo::Clear() {
|
void MetaInfo::Clear() {
|
||||||
num_row_ = num_col_ = num_nonzero_ = 0;
|
num_row_ = num_col_ = num_nonzero_ = 0;
|
||||||
@ -39,15 +130,42 @@ void MetaInfo::Clear() {
|
|||||||
base_margin_.HostVector().clear();
|
base_margin_.HostVector().clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Binary serialization format for MetaInfo:
|
||||||
|
*
|
||||||
|
* | name | type | is_scalar | num_row | num_col | value |
|
||||||
|
* |-------------+----------+-----------+---------+---------+-----------------|
|
||||||
|
* | num_row | kUInt64 | True | NA | NA | ${num_row_} |
|
||||||
|
* | num_col | kUInt64 | True | NA | NA | ${num_col_} |
|
||||||
|
* | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} |
|
||||||
|
* | labels | kFloat32 | False | ${size} | 1 | ${labels_} |
|
||||||
|
* | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} |
|
||||||
|
* | weights | kFloat32 | False | ${size} | 1 | ${weights_} |
|
||||||
|
* | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} |
|
||||||
|
*
|
||||||
|
* Note that the scalar fields (is_scalar=True) will have num_row and num_col missing.
|
||||||
|
* Also notice the difference between the saved name and the name used in `SetInfo':
|
||||||
|
* the former uses the plural form.
|
||||||
|
*/
|
||||||
|
|
||||||
void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||||
Version::Save(fo);
|
Version::Save(fo);
|
||||||
fo->Write(&num_row_, sizeof(num_row_));
|
fo->Write(kNumField);
|
||||||
fo->Write(&num_col_, sizeof(num_col_));
|
int field_cnt = 0; // make sure we are actually writing kNumField fields
|
||||||
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
|
|
||||||
fo->Write(labels_.HostVector());
|
SaveScalarField(fo, u8"num_row", DataType::kUInt64, num_row_); ++field_cnt;
|
||||||
fo->Write(group_ptr_);
|
SaveScalarField(fo, u8"num_col", DataType::kUInt64, num_col_); ++field_cnt;
|
||||||
fo->Write(weights_.HostVector());
|
SaveScalarField(fo, u8"num_nonzero", DataType::kUInt64, num_nonzero_); ++field_cnt;
|
||||||
fo->Write(base_margin_.HostVector());
|
SaveVectorField(fo, u8"labels", DataType::kFloat32,
|
||||||
|
{labels_.Size(), 1}, labels_); ++field_cnt;
|
||||||
|
SaveVectorField(fo, u8"group_ptr", DataType::kUInt32,
|
||||||
|
{group_ptr_.size(), 1}, group_ptr_); ++field_cnt;
|
||||||
|
SaveVectorField(fo, u8"weights", DataType::kFloat32,
|
||||||
|
{weights_.Size(), 1}, weights_); ++field_cnt;
|
||||||
|
SaveVectorField(fo, u8"base_margin", DataType::kFloat32,
|
||||||
|
{base_margin_.Size(), 1}, base_margin_); ++field_cnt;
|
||||||
|
|
||||||
|
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
|
||||||
}
|
}
|
||||||
|
|
||||||
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||||
@ -59,15 +177,24 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
|||||||
<< Version::String(version) << " is no longer supported. "
|
<< Version::String(version) << " is no longer supported. "
|
||||||
<< "Please process and save your data in current version: "
|
<< "Please process and save your data in current version: "
|
||||||
<< Version::String(Version::Self()) << " again.";
|
<< Version::String(Version::Self()) << " again.";
|
||||||
CHECK(fi->Read(&num_row_, sizeof(num_row_)) == sizeof(num_row_)) << "MetaInfo: invalid format";
|
|
||||||
CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format";
|
|
||||||
CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_))
|
|
||||||
<< "MetaInfo: invalid format";
|
|
||||||
CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format";
|
|
||||||
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
|
|
||||||
|
|
||||||
CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format";
|
const uint64_t expected_num_field = kNumField;
|
||||||
CHECK(fi->Read(&base_margin_.HostVector())) << "MetaInfo: invalid format";
|
uint64_t num_field { 0 };
|
||||||
|
CHECK(fi->Read(&num_field)) << "MetaInfo: invalid format";
|
||||||
|
CHECK_GE(num_field, expected_num_field)
|
||||||
|
<< "MetaInfo: insufficient number of fields (expected at least " << expected_num_field
|
||||||
|
<< " fields, but the binary file only contains " << num_field << "fields.)";
|
||||||
|
if (num_field > expected_num_field) {
|
||||||
|
LOG(WARNING) << "MetaInfo: the given binary file contains extra fields which will be ignored.";
|
||||||
|
}
|
||||||
|
|
||||||
|
LoadScalarField(fi, u8"num_row", DataType::kUInt64, &num_row_);
|
||||||
|
LoadScalarField(fi, u8"num_col", DataType::kUInt64, &num_col_);
|
||||||
|
LoadScalarField(fi, u8"num_nonzero", DataType::kUInt64, &num_nonzero_);
|
||||||
|
LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_);
|
||||||
|
LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_);
|
||||||
|
LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_);
|
||||||
|
LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to load group information from file, if exists
|
// try to load group information from file, if exists
|
||||||
@ -102,19 +229,19 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
|
|||||||
// macro to dispatch according to specified pointer types
|
// macro to dispatch according to specified pointer types
|
||||||
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
||||||
switch (dtype) { \
|
switch (dtype) { \
|
||||||
case kFloat32: { \
|
case xgboost::DataType::kFloat32: { \
|
||||||
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
|
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
|
||||||
} \
|
} \
|
||||||
case kDouble: { \
|
case xgboost::DataType::kDouble: { \
|
||||||
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
|
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
|
||||||
} \
|
} \
|
||||||
case kUInt32: { \
|
case xgboost::DataType::kUInt32: { \
|
||||||
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
|
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
|
||||||
} \
|
} \
|
||||||
case kUInt64: { \
|
case xgboost::DataType::kUInt64: { \
|
||||||
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
|
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
|
||||||
} \
|
} \
|
||||||
default: LOG(FATAL) << "Unknown data type" << dtype; \
|
default: LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
|
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
// Copyright by Contributors
|
// Copyright 2016-2020 by Contributors
|
||||||
#include <dmlc/io.h>
|
#include <dmlc/io.h>
|
||||||
#include <dmlc/filesystem.h>
|
#include <dmlc/filesystem.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "../../../src/data/simple_csr_source.h"
|
#include "../../../src/data/simple_csr_source.h"
|
||||||
|
#include "../../../src/common/version.h"
|
||||||
|
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
#include "xgboost/base.h"
|
||||||
|
|
||||||
TEST(MetaInfo, GetSet) {
|
TEST(MetaInfo, GetSet) {
|
||||||
xgboost::MetaInfo info;
|
xgboost::MetaInfo info;
|
||||||
@ -14,23 +16,23 @@ TEST(MetaInfo, GetSet) {
|
|||||||
double double2[2] = {1.0, 2.0};
|
double double2[2] = {1.0, 2.0};
|
||||||
|
|
||||||
EXPECT_EQ(info.labels_.Size(), 0);
|
EXPECT_EQ(info.labels_.Size(), 0);
|
||||||
info.SetInfo("label", double2, xgboost::kFloat32, 2);
|
info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2);
|
||||||
EXPECT_EQ(info.labels_.Size(), 2);
|
EXPECT_EQ(info.labels_.Size(), 2);
|
||||||
|
|
||||||
float float2[2] = {1.0f, 2.0f};
|
float float2[2] = {1.0f, 2.0f};
|
||||||
EXPECT_EQ(info.GetWeight(1), 1.0f)
|
EXPECT_EQ(info.GetWeight(1), 1.0f)
|
||||||
<< "When no weights are given, was expecting default value 1";
|
<< "When no weights are given, was expecting default value 1";
|
||||||
info.SetInfo("weight", float2, xgboost::kFloat32, 2);
|
info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2);
|
||||||
EXPECT_EQ(info.GetWeight(1), 2.0f);
|
EXPECT_EQ(info.GetWeight(1), 2.0f);
|
||||||
|
|
||||||
uint32_t uint32_t2[2] = {1U, 2U};
|
uint32_t uint32_t2[2] = {1U, 2U};
|
||||||
EXPECT_EQ(info.base_margin_.Size(), 0);
|
EXPECT_EQ(info.base_margin_.Size(), 0);
|
||||||
info.SetInfo("base_margin", uint32_t2, xgboost::kUInt32, 2);
|
info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
|
||||||
EXPECT_EQ(info.base_margin_.Size(), 2);
|
EXPECT_EQ(info.base_margin_.Size(), 2);
|
||||||
|
|
||||||
uint64_t uint64_t2[2] = {1U, 2U};
|
uint64_t uint64_t2[2] = {1U, 2U};
|
||||||
EXPECT_EQ(info.group_ptr_.size(), 0);
|
EXPECT_EQ(info.group_ptr_.size(), 0);
|
||||||
info.SetInfo("group", uint64_t2, xgboost::kUInt64, 2);
|
info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2);
|
||||||
ASSERT_EQ(info.group_ptr_.size(), 3);
|
ASSERT_EQ(info.group_ptr_.size(), 3);
|
||||||
EXPECT_EQ(info.group_ptr_[2], 3);
|
EXPECT_EQ(info.group_ptr_[2], 3);
|
||||||
|
|
||||||
@ -40,10 +42,18 @@ TEST(MetaInfo, GetSet) {
|
|||||||
|
|
||||||
TEST(MetaInfo, SaveLoadBinary) {
|
TEST(MetaInfo, SaveLoadBinary) {
|
||||||
xgboost::MetaInfo info;
|
xgboost::MetaInfo info;
|
||||||
double vals[2] = {1.0, 2.0};
|
uint64_t constexpr kRows { 64 }, kCols { 32 };
|
||||||
info.SetInfo("label", vals, xgboost::kDouble, 2);
|
auto generator = []() {
|
||||||
info.num_row_ = 2;
|
static float f = 0;
|
||||||
info.num_col_ = 1;
|
return f++;
|
||||||
|
};
|
||||||
|
std::vector<float> values (kRows);
|
||||||
|
std::generate(values.begin(), values.end(), generator);
|
||||||
|
info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||||
|
info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||||
|
info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||||
|
info.num_row_ = kRows;
|
||||||
|
info.num_col_ = kCols;
|
||||||
|
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/metainfo.binary";
|
const std::string tmp_file = tempdir.path + "/metainfo.binary";
|
||||||
@ -54,17 +64,24 @@ TEST(MetaInfo, SaveLoadBinary) {
|
|||||||
info.SaveBinary(fs.get());
|
info.SaveBinary(fs.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
ASSERT_EQ(GetFileSize(tmp_file), 84)
|
{
|
||||||
<< "Expected saved binary file size to be same as object size";
|
// Round-trip test
|
||||||
|
std::unique_ptr<dmlc::Stream> fs {
|
||||||
|
dmlc::Stream::Create(tmp_file.c_str(), "r")
|
||||||
|
};
|
||||||
|
xgboost::MetaInfo inforead;
|
||||||
|
inforead.LoadBinary(fs.get());
|
||||||
|
ASSERT_EQ(inforead.num_row_, kRows);
|
||||||
|
EXPECT_EQ(inforead.num_row_, info.num_row_);
|
||||||
|
EXPECT_EQ(inforead.num_col_, info.num_col_);
|
||||||
|
EXPECT_EQ(inforead.num_nonzero_, info.num_nonzero_);
|
||||||
|
|
||||||
std::unique_ptr<dmlc::Stream> fs {
|
ASSERT_EQ(inforead.labels_.HostVector(), values);
|
||||||
dmlc::Stream::Create(tmp_file.c_str(), "r")
|
EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector());
|
||||||
};
|
EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
|
||||||
xgboost::MetaInfo inforead;
|
EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());
|
||||||
inforead.LoadBinary(fs.get());
|
EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector());
|
||||||
EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector());
|
}
|
||||||
EXPECT_EQ(inforead.num_col_, info.num_col_);
|
|
||||||
EXPECT_EQ(inforead.num_row_, info.num_row_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MetaInfo, LoadQid) {
|
TEST(MetaInfo, LoadQid) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user