Extensible binary serialization format for DMatrix::MetaInfo (#5187)

* Turn xgboost::DataType into C++11 enum class

* New binary serialization format for DMatrix::MetaInfo

* Fix clang-tidy

* Fix c++ test

* Implement new format proposal

* Move helper functions to anonymous namespace; remove unneeded field

* Fix lint

* Add shape.

* Keep only roundtrip test.

* Fix test.

* various fixes

* Update data.cc

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho
2020-01-23 11:33:17 -08:00
committed by GitHub
parent b4f952bd22
commit 44469a0ca9
5 changed files with 193 additions and 44 deletions

View File

@@ -1,11 +1,13 @@
/*!
* Copyright 2015-2019 by Contributors
* Copyright 2015-2020 by Contributors
* \file data.cc
*/
#include <dmlc/registry.h>
#include <cstring>
#include "dmlc/io.h"
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "sparse_page_writer.h"
@@ -29,7 +31,96 @@ DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPa
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
} // namespace dmlc
namespace {
/*!
 * \brief Emit one scalar field record to the stream.
 *
 * The record is written as four consecutive items: the field name, its data
 * type tag, an is_scalar flag set to true, and finally the value.
 *
 * \param strm  stream receiving the record
 * \param name  name stored with the field
 * \param type  data type tag stored with the field
 * \param field value to store
 */
template <typename T>
void SaveScalarField(dmlc::Stream *strm, const std::string &name,
                     xgboost::DataType type, const T &field) {
  const bool is_scalar = true;
  strm->Write(name);
  strm->Write(type);
  strm->Write(is_scalar);
  strm->Write(field);
}
/*!
 * \brief Emit one vector field record to the stream.
 *
 * The record is written as six consecutive items: the field name, its data
 * type tag, an is_scalar flag set to false, the two shape entries (first, then
 * second) and finally the vector itself.
 *
 * \param strm  stream receiving the record
 * \param name  name stored with the field
 * \param type  data type tag stored with the field
 * \param shape pair stored ahead of the payload; callers in this file pass {size, 1}
 * \param field vector to store
 */
template <typename T>
void SaveVectorField(dmlc::Stream *strm, const std::string &name,
                     xgboost::DataType type, std::pair<uint64_t, uint64_t> shape,
                     const std::vector<T>& field) {
  const bool is_scalar = false;
  const uint64_t first_dim = shape.first;
  const uint64_t second_dim = shape.second;

  // Header: name, type tag and the scalar/vector discriminator.
  strm->Write(name);
  strm->Write(type);
  strm->Write(is_scalar);

  // Shape, then payload.
  strm->Write(first_dim);
  strm->Write(second_dim);
  strm->Write(field);
}
/*!
 * \brief Overload for HostDeviceVector: forwards the vector returned by
 *        ConstHostVector() to the std::vector overload of SaveVectorField.
 *
 * \param strm  stream receiving the record
 * \param name  name stored with the field
 * \param type  data type tag stored with the field
 * \param shape pair stored ahead of the payload
 * \param field container whose host-side vector is stored
 */
template <typename T>
void SaveVectorField(dmlc::Stream* strm, const std::string& name,
                     xgboost::DataType type, std::pair<uint64_t, uint64_t> shape,
                     const xgboost::HostDeviceVector<T>& field) {
  const std::vector<T>& host_values = field.ConstHostVector();
  SaveVectorField(strm, name, type, shape, host_values);
}
/*!
 * \brief Read one scalar field record from the stream and validate its header.
 *
 * The record is expected to hold, in order: the field name, its data type tag,
 * the is_scalar flag and the value. A failed read, an unexpected name or type,
 * or a record flagged as a vector aborts through CHECK.
 *
 * \param strm          stream to read from
 * \param expected_name name the next record must carry
 * \param expected_type data type tag the next record must carry
 * \param field         destination for the value
 */
template <typename T>
void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name,
                     xgboost::DataType expected_type, T* field) {
  const std::string invalid {"MetaInfo: Invalid format. "};
  std::string name;
  xgboost::DataType type;
  bool is_scalar;
  CHECK(strm->Read(&name)) << invalid;
  CHECK_EQ(name, expected_name)
    << invalid << " Expected field: " << expected_name << ", got: " << name;
  CHECK(strm->Read(&type)) << invalid;
  CHECK(type == expected_type)
    << invalid << "Expected field of type: " << static_cast<int>(expected_type) << ", "
    << "got field type: " << static_cast<int>(type);
  CHECK(strm->Read(&is_scalar)) << invalid;
  CHECK(is_scalar)
    << invalid << "Expected field " << expected_name << " to be a scalar; got a vector";
  // Use the typed Read so the value is decoded the same way SaveScalarField's
  // typed Write(field) encoded it. The raw Read(ptr, size) overload reports a
  // byte count, so a truncated but non-empty read would have passed the CHECK.
  CHECK(strm->Read(field)) << invalid;
}
/*!
 * \brief Read one vector field record from the stream and validate its header.
 *
 * The record is expected to hold, in order: the field name, its data type tag,
 * the is_scalar flag, the two shape entries and the vector payload. A failed
 * read, an unexpected name or type, a record flagged as a scalar, a column
 * count other than 1, or a payload whose length disagrees with the recorded
 * row count aborts through CHECK.
 *
 * \param strm          stream to read from
 * \param expected_name name the next record must carry
 * \param expected_type data type tag the next record must carry
 * \param field         destination vector for the payload
 */
template <typename T>
void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
                     xgboost::DataType expected_type, std::vector<T>* field) {
  const std::string invalid {"MetaInfo: Invalid format. "};
  std::string name;
  xgboost::DataType type;
  bool is_scalar;
  CHECK(strm->Read(&name)) << invalid;
  CHECK_EQ(name, expected_name)
    << invalid << " Expected field: " << expected_name << ", got: " << name;
  CHECK(strm->Read(&type)) << invalid;
  CHECK(type == expected_type)
    << invalid << "Expected field of type: " << static_cast<int>(expected_type) << ", "
    << "got field type: " << static_cast<int>(type);
  CHECK(strm->Read(&is_scalar)) << invalid;
  CHECK(!is_scalar)
    << invalid << "Expected field " << expected_name << " to be a vector; got a scalar";
  std::pair<uint64_t, uint64_t> shape;
  // Report truncated shape data with the same message as every other read.
  CHECK(strm->Read(&shape.first)) << invalid;
  CHECK(strm->Read(&shape.second)) << invalid;
  // TODO(hcho3): this restriction may be lifted, once we add a field with more than 1 column.
  CHECK_EQ(shape.second, 1) << invalid << "Number of columns is expected to be 1.";
  CHECK(strm->Read(field)) << invalid;
  // The writer records the vector length as the row count; a disagreement means
  // the record is corrupt even though each individual read succeeded.
  CHECK_EQ(static_cast<uint64_t>(field->size()), shape.first)
    << invalid << "Field " << expected_name << " holds " << field->size()
    << " elements, but its recorded number of rows is " << shape.first << ".";
}
/*!
 * \brief Overload for HostDeviceVector: reads the record into the vector
 *        returned by HostVector() via the std::vector overload of LoadVectorField.
 *
 * \param strm          stream to read from
 * \param expected_name name the next record must carry
 * \param expected_type data type tag the next record must carry
 * \param field         container whose host-side vector receives the payload
 */
template <typename T>
void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
                     xgboost::DataType expected_type,
                     xgboost::HostDeviceVector<T>* field) {
  std::vector<T>* host_values = &field->HostVector();
  LoadVectorField(strm, expected_name, expected_type, host_values);
}
} // anonymous namespace
namespace xgboost {
uint64_t constexpr MetaInfo::kNumField;
// implementation of inline functions
void MetaInfo::Clear() {
num_row_ = num_col_ = num_nonzero_ = 0;
@@ -39,15 +130,42 @@ void MetaInfo::Clear() {
base_margin_.HostVector().clear();
}
/*
* Binary serialization format for MetaInfo:
*
* | name | type | is_scalar | num_row | num_col | value |
* |-------------+----------+-----------+---------+---------+-----------------|
* | num_row | kUInt64 | True | NA | NA | ${num_row_} |
* | num_col | kUInt64 | True | NA | NA | ${num_col_} |
* | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} |
* | labels | kFloat32 | False | ${size} | 1 | ${labels_} |
* | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} |
* | weights | kFloat32 | False | ${size} | 1 | ${weights_} |
* | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} |
*
* Note that the scalar fields (is_scalar=True) will have num_row and num_col missing.
* Also notice the difference between the saved name and the name used in `SetInfo':
* the former uses the plural form.
*/
// Serialize this MetaInfo to the stream `fo`, starting with Version::Save(fo).
// The field-based part writes kNumField followed by one record per field
// (three scalars, then four vectors) and finally checks that the number of
// records written equals kNumField.
void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
Version::Save(fo);
// NOTE(review): the seven raw writes below look like the removed side of a
// diff hunk, superseded by the kNumField + Save*Field sequence that follows.
// Confirm against the actual file before relying on this span as code.
fo->Write(&num_row_, sizeof(num_row_));
fo->Write(&num_col_, sizeof(num_col_));
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
fo->Write(labels_.HostVector());
fo->Write(group_ptr_);
fo->Write(weights_.HostVector());
fo->Write(base_margin_.HostVector());
// Field count first; LoadBinary reads it back and compares it with kNumField.
fo->Write(kNumField);
int field_cnt = 0; // make sure we are actually writing kNumField fields
// Scalar records.
SaveScalarField(fo, u8"num_row", DataType::kUInt64, num_row_); ++field_cnt;
SaveScalarField(fo, u8"num_col", DataType::kUInt64, num_col_); ++field_cnt;
SaveScalarField(fo, u8"num_nonzero", DataType::kUInt64, num_nonzero_); ++field_cnt;
// Vector records; each is written with shape {size, 1}.
SaveVectorField(fo, u8"labels", DataType::kFloat32,
{labels_.Size(), 1}, labels_); ++field_cnt;
SaveVectorField(fo, u8"group_ptr", DataType::kUInt32,
{group_ptr_.size(), 1}, group_ptr_); ++field_cnt;
SaveVectorField(fo, u8"weights", DataType::kFloat32,
{weights_.Size(), 1}, weights_); ++field_cnt;
SaveVectorField(fo, u8"base_margin", DataType::kFloat32,
{base_margin_.Size(), 1}, base_margin_); ++field_cnt;
// Guards against adding or removing a record without updating kNumField.
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
}
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
@@ -59,15 +177,24 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
<< Version::String(version) << " is no longer supported. "
<< "Please process and save your data in current version: "
<< Version::String(Version::Self()) << " again.";
CHECK(fi->Read(&num_row_, sizeof(num_row_)) == sizeof(num_row_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_))
<< "MetaInfo: invalid format";
CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format";
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format";
CHECK(fi->Read(&base_margin_.HostVector())) << "MetaInfo: invalid format";
const uint64_t expected_num_field = kNumField;
uint64_t num_field { 0 };
CHECK(fi->Read(&num_field)) << "MetaInfo: invalid format";
CHECK_GE(num_field, expected_num_field)
<< "MetaInfo: insufficient number of fields (expected at least " << expected_num_field
<< " fields, but the binary file only contains " << num_field << "fields.)";
if (num_field > expected_num_field) {
LOG(WARNING) << "MetaInfo: the given binary file contains extra fields which will be ignored.";
}
LoadScalarField(fi, u8"num_row", DataType::kUInt64, &num_row_);
LoadScalarField(fi, u8"num_col", DataType::kUInt64, &num_col_);
LoadScalarField(fi, u8"num_nonzero", DataType::kUInt64, &num_nonzero_);
LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_);
LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_);
LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_);
LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_);
}
// try to load group information from file, if exists
@@ -102,19 +229,19 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
// macro to dispatch according to specified pointer types
//
// Switches on `dtype`; for each recognized tag it reinterprets `old_ptr` as a
// const pointer to the matching element type (float, double, uint32_t or
// uint64_t), binds it to a local named `cast_ptr`, and then runs `proc`.
// Any other tag aborts through LOG(FATAL).
//
// The scoped-enum tag is printed through static_cast<int>: a uint8_t operand
// would be streamed as a raw character instead of a number.
//
// NOTE(review): the unqualified `case kXxx:` labels and the first `default:`
// line look like the removed side of a diff hunk; they are left untouched here.
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
switch (dtype) { \
case kFloat32: { \
case xgboost::DataType::kFloat32: { \
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
} \
case kDouble: { \
case xgboost::DataType::kDouble: { \
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
} \
case kUInt32: { \
case xgboost::DataType::kUInt32: { \
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
} \
case kUInt64: { \
case xgboost::DataType::kUInt64: { \
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
} \
default: LOG(FATAL) << "Unknown data type" << dtype; \
default: LOG(FATAL) << "Unknown data type" << static_cast<int>(dtype); \
} \
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {