Move feature names and types of DMatrix from Python to C++. (#5858)
* Add thread local return entry for DMatrix. * Save feature name and feature type in binary file. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -10,7 +10,6 @@
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/io.h"
|
||||
|
||||
|
||||
TEST(CAPI, XGDMatrixCreateFromMatDT) {
|
||||
std::vector<int> col0 = {0, -1, 3};
|
||||
std::vector<float> col1 = {-4.0f, 2.0f, 0.0f};
|
||||
@@ -148,4 +147,48 @@ TEST(CAPI, CatchDMLCError) {
|
||||
EXPECT_THROW({ dmlc::Stream::Create("foo", "r"); }, dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(CAPI, DMatrixSetFeatureName) {
|
||||
size_t constexpr kRows = 10;
|
||||
bst_feature_t constexpr kCols = 2;
|
||||
|
||||
DMatrixHandle handle;
|
||||
std::vector<float> data(kCols * kRows, 1.5);
|
||||
|
||||
XGDMatrixCreateFromMat_omp(data.data(), kRows, kCols,
|
||||
std::numeric_limits<float>::quiet_NaN(), &handle,
|
||||
0);
|
||||
std::vector<std::string> feature_names;
|
||||
for (bst_feature_t i = 0; i < kCols; ++i) {
|
||||
feature_names.emplace_back(std::to_string(i));
|
||||
}
|
||||
std::vector<char const*> c_feature_names;
|
||||
c_feature_names.resize(feature_names.size());
|
||||
std::transform(feature_names.cbegin(), feature_names.cend(),
|
||||
c_feature_names.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
XGDMatrixSetStrFeatureInfo(handle, u8"feature_name", c_feature_names.data(),
|
||||
c_feature_names.size());
|
||||
bst_ulong out_len = 0;
|
||||
char const **c_out_features;
|
||||
XGDMatrixGetStrFeatureInfo(handle, u8"feature_name", &out_len,
|
||||
&c_out_features);
|
||||
|
||||
CHECK_EQ(out_len, kCols);
|
||||
std::vector<std::string> out_features;
|
||||
for (bst_ulong i = 0; i < out_len; ++i) {
|
||||
ASSERT_EQ(std::to_string(i), c_out_features[i]);
|
||||
}
|
||||
|
||||
char const* feat_types [] {"i", "q"};
|
||||
static_assert(sizeof(feat_types)/ sizeof(feat_types[0]) == kCols, "");
|
||||
XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, kCols);
|
||||
char const **c_out_types;
|
||||
XGDMatrixGetStrFeatureInfo(handle, u8"feature_type", &out_len,
|
||||
&c_out_types);
|
||||
for (bst_ulong i = 0; i < out_len; ++i) {
|
||||
ASSERT_STREQ(feat_types[i], c_out_types[i]);
|
||||
}
|
||||
|
||||
XGDMatrixFree(handle);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -39,6 +39,36 @@ TEST(MetaInfo, GetSet) {
|
||||
ASSERT_EQ(info.group_ptr_.size(), 0);
|
||||
}
|
||||
|
||||
TEST(MetaInfo, GetSetFeature) {
|
||||
xgboost::MetaInfo info;
|
||||
EXPECT_THROW(info.SetFeatureInfo("", nullptr, 0), dmlc::Error);
|
||||
EXPECT_THROW(info.SetFeatureInfo("foo", nullptr, 0), dmlc::Error);
|
||||
EXPECT_NO_THROW(info.SetFeatureInfo("feature_name", nullptr, 0));
|
||||
EXPECT_NO_THROW(info.SetFeatureInfo("feature_type", nullptr, 0));
|
||||
ASSERT_EQ(info.feature_type_names.size(), 0);
|
||||
ASSERT_EQ(info.feature_types.Size(), 0);
|
||||
ASSERT_EQ(info.feature_names.size(), 0);
|
||||
|
||||
size_t constexpr kCols = 19;
|
||||
std::vector<std::string> types(kCols, u8"float");
|
||||
std::vector<char const*> c_types(kCols);
|
||||
std::transform(types.cbegin(), types.cend(), c_types.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
// Info has 0 column
|
||||
EXPECT_THROW(
|
||||
info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()),
|
||||
dmlc::Error);
|
||||
info.num_col_ = kCols;
|
||||
EXPECT_NO_THROW(
|
||||
info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
|
||||
|
||||
// Test clear.
|
||||
info.SetFeatureInfo("feature_type", nullptr, 0);
|
||||
ASSERT_EQ(info.feature_type_names.size(), 0);
|
||||
ASSERT_EQ(info.feature_types.Size(), 0);
|
||||
// Other conditions are tested in `SaveLoadBinary`.
|
||||
}
|
||||
|
||||
TEST(MetaInfo, SaveLoadBinary) {
|
||||
xgboost::MetaInfo info;
|
||||
uint64_t constexpr kRows { 64 }, kCols { 32 };
|
||||
@@ -51,9 +81,22 @@ TEST(MetaInfo, SaveLoadBinary) {
|
||||
info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
|
||||
info.num_row_ = kRows;
|
||||
info.num_col_ = kCols;
|
||||
|
||||
auto featname = u8"特征名";
|
||||
std::vector<std::string> types(kCols, u8"float");
|
||||
std::vector<char const*> c_types(kCols);
|
||||
std::transform(types.cbegin(), types.cend(), c_types.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size());
|
||||
std::vector<std::string> names(kCols, featname);
|
||||
std::vector<char const*> c_names(kCols);
|
||||
std::transform(names.cbegin(), names.cend(), c_names.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size());;
|
||||
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/metainfo.binary";
|
||||
{
|
||||
@@ -80,6 +123,23 @@ TEST(MetaInfo, SaveLoadBinary) {
|
||||
EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
|
||||
EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());
|
||||
EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector());
|
||||
|
||||
EXPECT_EQ(inforead.feature_type_names.size(), kCols);
|
||||
EXPECT_EQ(inforead.feature_types.Size(), kCols);
|
||||
EXPECT_TRUE(std::all_of(inforead.feature_type_names.cbegin(),
|
||||
inforead.feature_type_names.cend(),
|
||||
[](auto const &str) { return str == u8"float"; }));
|
||||
auto h_ft = inforead.feature_types.HostSpan();
|
||||
EXPECT_TRUE(std::all_of(h_ft.cbegin(), h_ft.cend(), [](auto f) {
|
||||
return f == xgboost::FeatureType::kNumerical;
|
||||
}));
|
||||
|
||||
EXPECT_EQ(inforead.feature_names.size(), kCols);
|
||||
EXPECT_TRUE(std::all_of(inforead.feature_names.cbegin(),
|
||||
inforead.feature_names.cend(),
|
||||
[=](auto const& str) {
|
||||
return str == featname;
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user