diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 08983950c..8a2dfefb9 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -26,7 +26,7 @@ namespace xgboost { class DMatrix; /*! \brief data type accepted by xgboost interface */ -enum DataType { +enum class DataType : uint8_t { kFloat32 = 1, kDouble = 2, kUInt32 = 3, @@ -38,6 +38,9 @@ enum DataType { */ class MetaInfo { public: + /*! \brief number of data fields in MetaInfo */ + static constexpr uint64_t kNumField = 7; + /*! \brief number of rows in the data */ uint64_t num_row_{0}; /*! \brief number of columns in the data */ diff --git a/include/xgboost/host_device_vector.h b/include/xgboost/host_device_vector.h index 7a2ea3d6d..b358c7889 100644 --- a/include/xgboost/host_device_vector.h +++ b/include/xgboost/host_device_vector.h @@ -127,6 +127,8 @@ class HostDeviceVector { void Resize(size_t new_size, T v = T()); + using value_type = T; + private: HostDeviceVectorImpl* impl_; }; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index b3d25c363..756a90f73 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -340,7 +340,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, API_BEGIN(); CHECK_HANDLE(); static_cast*>(handle) - ->get()->Info().SetInfo(field, info, kFloat32, len); + ->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len); API_END(); } @@ -361,7 +361,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, API_BEGIN(); CHECK_HANDLE(); static_cast*>(handle) - ->get()->Info().SetInfo(field, info, kUInt32, len); + ->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len); API_END(); } @@ -372,7 +372,7 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, CHECK_HANDLE(); LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead."; static_cast*>(handle) - ->get()->Info().SetInfo("group", group, kUInt32, len); + ->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len); API_END(); } diff --git a/src/data/data.cc b/src/data/data.cc index 9f158851c..f030686c4 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,11 +1,13 @@ /*! - * Copyright 2015-2019 by Contributors + * Copyright 2015-2020 by Contributors * \file data.cc */ #include #include +#include "dmlc/io.h" #include "xgboost/data.h" +#include "xgboost/host_device_vector.h" #include "xgboost/logging.h" #include "xgboost/version_config.h" #include "sparse_page_writer.h" @@ -29,7 +31,96 @@ DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPa DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>); } // namespace dmlc +namespace { + +template +void SaveScalarField(dmlc::Stream *strm, const std::string &name, + xgboost::DataType type, const T &field) { + strm->Write(name); + strm->Write(type); + strm->Write(true); // is_scalar=True + strm->Write(field); +} + +template +void SaveVectorField(dmlc::Stream *strm, const std::string &name, + xgboost::DataType type, std::pair shape, + const std::vector& field) { + strm->Write(name); + strm->Write(type); + strm->Write(false); // is_scalar=False + strm->Write(shape.first); + strm->Write(shape.second); + strm->Write(field); +} + +template +void SaveVectorField(dmlc::Stream* strm, const std::string& name, + xgboost::DataType type, std::pair shape, + const xgboost::HostDeviceVector& field) { + SaveVectorField(strm, name, type, shape, field.ConstHostVector()); +} + +template +void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name, + xgboost::DataType expected_type, T* field) { + const std::string invalid {"MetaInfo: Invalid format. "}; + std::string name; + xgboost::DataType type; + bool is_scalar; + CHECK(strm->Read(&name)) << invalid; + CHECK_EQ(name, expected_name) + << invalid << " Expected field: " << expected_name << ", got: " << name; + CHECK(strm->Read(&type)) << invalid; + CHECK(type == expected_type) + << invalid << "Expected field of type: " << static_cast(expected_type) << ", " + << "got field type: " << static_cast(type); + CHECK(strm->Read(&is_scalar)) << invalid; + CHECK(is_scalar) + << invalid << "Expected field " << expected_name << " to be a scalar; got a vector"; + CHECK(strm->Read(field, sizeof(T))) << invalid; +} + +template +void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name, + xgboost::DataType expected_type, std::vector* field) { + const std::string invalid {"MetaInfo: Invalid format. "}; + std::string name; + xgboost::DataType type; + bool is_scalar; + CHECK(strm->Read(&name)) << invalid; + CHECK_EQ(name, expected_name) + << invalid << " Expected field: " << expected_name << ", got: " << name; + CHECK(strm->Read(&type)) << invalid; + CHECK(type == expected_type) + << invalid << "Expected field of type: " << static_cast(expected_type) << ", " + << "got field type: " << static_cast(type); + CHECK(strm->Read(&is_scalar)) << invalid; + CHECK(!is_scalar) + << invalid << "Expected field " << expected_name << " to be a vector; got a scalar"; + std::pair shape; + + CHECK(strm->Read(&shape.first)); + CHECK(strm->Read(&shape.second)); + // TODO(hcho3): this restriction may be lifted, once we add a field with more than 1 column. + CHECK_EQ(shape.second, 1) << invalid << "Number of columns is expected to be 1."; + + CHECK(strm->Read(field)) << invalid; +} + +template +void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name, + xgboost::DataType expected_type, + xgboost::HostDeviceVector* field) { + LoadVectorField(strm, expected_name, expected_type, &field->HostVector()); +} + +} // anonymous namespace + namespace xgboost { + +uint64_t constexpr MetaInfo::kNumField; + // implementation of inline functions void MetaInfo::Clear() { num_row_ = num_col_ = num_nonzero_ = 0; @@ -39,15 +130,42 @@ void MetaInfo::Clear() { base_margin_.HostVector().clear(); } +/* + * Binary serialization format for MetaInfo: + * + * | name | type | is_scalar | num_row | num_col | value | + * |-------------+----------+-----------+---------+---------+-----------------| + * | num_row | kUInt64 | True | NA | NA | ${num_row_} | + * | num_col | kUInt64 | True | NA | NA | ${num_col_} | + * | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} | + * | labels | kFloat32 | False | ${size} | 1 | ${labels_} | + * | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} | + * | weights | kFloat32 | False | ${size} | 1 | ${weights_} | + * | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} | + * + * Note that the scalar fields (is_scalar=True) will have num_row and num_col missing. + * Also notice the difference between the saved name and the name used in `SetInfo': + * the former uses the plural form. + */ + void MetaInfo::SaveBinary(dmlc::Stream *fo) const { Version::Save(fo); - fo->Write(&num_row_, sizeof(num_row_)); - fo->Write(&num_col_, sizeof(num_col_)); - fo->Write(&num_nonzero_, sizeof(num_nonzero_)); - fo->Write(labels_.HostVector()); - fo->Write(group_ptr_); - fo->Write(weights_.HostVector()); - fo->Write(base_margin_.HostVector()); + fo->Write(kNumField); + int field_cnt = 0; // make sure we are actually writing kNumField fields + + SaveScalarField(fo, u8"num_row", DataType::kUInt64, num_row_); ++field_cnt; + SaveScalarField(fo, u8"num_col", DataType::kUInt64, num_col_); ++field_cnt; + SaveScalarField(fo, u8"num_nonzero", DataType::kUInt64, num_nonzero_); ++field_cnt; + SaveVectorField(fo, u8"labels", DataType::kFloat32, + {labels_.Size(), 1}, labels_); ++field_cnt; + SaveVectorField(fo, u8"group_ptr", DataType::kUInt32, + {group_ptr_.size(), 1}, group_ptr_); ++field_cnt; + SaveVectorField(fo, u8"weights", DataType::kFloat32, + {weights_.Size(), 1}, weights_); ++field_cnt; + SaveVectorField(fo, u8"base_margin", DataType::kFloat32, + {base_margin_.Size(), 1}, base_margin_); ++field_cnt; + + CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields"; } void MetaInfo::LoadBinary(dmlc::Stream *fi) { @@ -59,15 +177,24 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { << Version::String(version) << " is no longer supported. " << "Please process and save your data in current version: " << Version::String(Version::Self()) << " again."; - CHECK(fi->Read(&num_row_, sizeof(num_row_)) == sizeof(num_row_)) << "MetaInfo: invalid format"; - CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format"; - CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_)) - << "MetaInfo: invalid format"; - CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format"; - CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format"; - CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format"; - CHECK(fi->Read(&base_margin_.HostVector())) << "MetaInfo: invalid format"; + const uint64_t expected_num_field = kNumField; + uint64_t num_field { 0 }; + CHECK(fi->Read(&num_field)) << "MetaInfo: invalid format"; + CHECK_GE(num_field, expected_num_field) + << "MetaInfo: insufficient number of fields (expected at least " << expected_num_field + << " fields, but the binary file only contains " << num_field << "fields.)"; + if (num_field > expected_num_field) { + LOG(WARNING) << "MetaInfo: the given binary file contains extra fields which will be ignored."; + } + + LoadScalarField(fi, u8"num_row", DataType::kUInt64, &num_row_); + LoadScalarField(fi, u8"num_col", DataType::kUInt64, &num_col_); + LoadScalarField(fi, u8"num_nonzero", DataType::kUInt64, &num_nonzero_); + LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_); + LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_); + LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_); + LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); } // try to load group information from file, if exists @@ -102,19 +229,19 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, // macro to dispatch according to specified pointer types #define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \ switch (dtype) { \ - case kFloat32: { \ + case xgboost::DataType::kFloat32: { \ auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ } \ - case kDouble: { \ + case xgboost::DataType::kDouble: { \ auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ } \ - case kUInt32: { \ + case xgboost::DataType::kUInt32: { \ auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ } \ - case kUInt64: { \ + case xgboost::DataType::kUInt64: { \ auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ } \ - default: LOG(FATAL) << "Unknown data type" << dtype; \ + default: LOG(FATAL) << "Unknown data type" << static_cast(dtype); \ } \ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 700dd87bc..380c8b142 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -1,12 +1,14 @@ -// Copyright by Contributors +// Copyright 2016-2020 by Contributors #include #include #include #include #include #include "../../../src/data/simple_csr_source.h" +#include "../../../src/common/version.h" #include "../helpers.h" +#include "xgboost/base.h" TEST(MetaInfo, GetSet) { xgboost::MetaInfo info; @@ -14,23 +16,23 @@ TEST(MetaInfo, GetSet) { double double2[2] = {1.0, 2.0}; EXPECT_EQ(info.labels_.Size(), 0); - info.SetInfo("label", double2, xgboost::kFloat32, 2); + info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2); EXPECT_EQ(info.labels_.Size(), 2); float float2[2] = {1.0f, 2.0f}; EXPECT_EQ(info.GetWeight(1), 1.0f) << "When no weights are given, was expecting default value 1"; - info.SetInfo("weight", float2, xgboost::kFloat32, 2); + info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2); EXPECT_EQ(info.GetWeight(1), 2.0f); uint32_t uint32_t2[2] = {1U, 2U}; EXPECT_EQ(info.base_margin_.Size(), 0); - info.SetInfo("base_margin", uint32_t2, xgboost::kUInt32, 2); + info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2); EXPECT_EQ(info.base_margin_.Size(), 2); uint64_t uint64_t2[2] = {1U, 2U}; EXPECT_EQ(info.group_ptr_.size(), 0); - info.SetInfo("group", uint64_t2, xgboost::kUInt64, 2); + info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2); ASSERT_EQ(info.group_ptr_.size(), 3); EXPECT_EQ(info.group_ptr_[2], 3); @@ -40,10 +42,18 @@ TEST(MetaInfo, GetSet) { TEST(MetaInfo, SaveLoadBinary) { xgboost::MetaInfo info; - double vals[2] = {1.0, 2.0}; - info.SetInfo("label", vals, xgboost::kDouble, 2); - info.num_row_ = 2; - info.num_col_ = 1; + uint64_t constexpr kRows { 64 }, kCols { 32 }; + auto generator = []() { + static float f = 0; + return f++; + }; + std::vector values (kRows); + std::generate(values.begin(), values.end(), generator); + info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows); + info.num_row_ = kRows; + info.num_col_ = kCols; dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/metainfo.binary"; @@ -54,17 +64,24 @@ TEST(MetaInfo, SaveLoadBinary) { info.SaveBinary(fs.get()); } - ASSERT_EQ(GetFileSize(tmp_file), 84) - << "Expected saved binary file size to be same as object size"; + { + // Round-trip test + std::unique_ptr fs { + dmlc::Stream::Create(tmp_file.c_str(), "r") + }; + xgboost::MetaInfo inforead; + inforead.LoadBinary(fs.get()); + ASSERT_EQ(inforead.num_row_, kRows); + EXPECT_EQ(inforead.num_row_, info.num_row_); + EXPECT_EQ(inforead.num_col_, info.num_col_); + EXPECT_EQ(inforead.num_nonzero_, info.num_nonzero_); - std::unique_ptr fs { - dmlc::Stream::Create(tmp_file.c_str(), "r") - }; - xgboost::MetaInfo inforead; - inforead.LoadBinary(fs.get()); - EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector()); - EXPECT_EQ(inforead.num_col_, info.num_col_); - EXPECT_EQ(inforead.num_row_, info.num_row_); + ASSERT_EQ(inforead.labels_.HostVector(), values); + EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector()); + EXPECT_EQ(inforead.group_ptr_, info.group_ptr_); + EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector()); + EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector()); + } } TEST(MetaInfo, LoadQid) {