Check support status for categorical features. (#9946)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* Copyright 2015-2024, XGBoost Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include "xgboost/data.h"
|
||||
@@ -260,9 +260,14 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
|
||||
}
|
||||
|
||||
void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<FeatureType>* types) {
|
||||
/**
|
||||
* @brief Load feature type info from names, returns whether there's categorical features.
|
||||
*/
|
||||
[[nodiscard]] bool LoadFeatureType(std::vector<std::string> const& type_names,
|
||||
std::vector<FeatureType>* types) {
|
||||
types->clear();
|
||||
for (auto const &elem : type_names) {
|
||||
bool has_cat{false};
|
||||
for (auto const& elem : type_names) {
|
||||
if (elem == "int") {
|
||||
types->emplace_back(FeatureType::kNumerical);
|
||||
} else if (elem == "float") {
|
||||
@@ -273,10 +278,12 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
|
||||
types->emplace_back(FeatureType::kNumerical);
|
||||
} else if (elem == "c") {
|
||||
types->emplace_back(FeatureType::kCategorical);
|
||||
has_cat = true;
|
||||
} else {
|
||||
LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}.";
|
||||
}
|
||||
}
|
||||
return has_cat;
|
||||
}
|
||||
|
||||
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
|
||||
@@ -340,7 +347,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names);
|
||||
LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names);
|
||||
LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights);
|
||||
LoadFeatureType(feature_type_names, &feature_types.HostVector());
|
||||
|
||||
this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -639,6 +647,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
|
||||
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||
CHECK(info);
|
||||
}
|
||||
|
||||
if (!std::strcmp(key, "feature_type")) {
|
||||
feature_type_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
@@ -651,7 +660,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
}
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(feature_type_names, &h_feature_types);
|
||||
this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
|
||||
} else if (!std::strcmp(key, "feature_name")) {
|
||||
if (IsColumnSplit()) {
|
||||
std::vector<std::string> local_feature_names{};
|
||||
@@ -674,9 +683,8 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::GetFeatureInfo(const char *field,
|
||||
std::vector<std::string> *out_str_vecs) const {
|
||||
auto &str_vecs = *out_str_vecs;
|
||||
void MetaInfo::GetFeatureInfo(const char* field, std::vector<std::string>* out_str_vecs) const {
|
||||
auto& str_vecs = *out_str_vecs;
|
||||
if (!std::strcmp(field, "feature_type")) {
|
||||
str_vecs.resize(feature_type_names.size());
|
||||
std::copy(feature_type_names.cbegin(), feature_type_names.cend(), str_vecs.begin());
|
||||
@@ -689,6 +697,9 @@ void MetaInfo::GetFeatureInfo(const char *field,
|
||||
}
|
||||
|
||||
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
|
||||
/**
|
||||
* shape
|
||||
*/
|
||||
if (accumulate_rows) {
|
||||
this->num_row_ += that.num_row_;
|
||||
}
|
||||
@@ -702,6 +713,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
}
|
||||
this->num_col_ = that.num_col_;
|
||||
|
||||
/**
|
||||
* info with n_samples
|
||||
*/
|
||||
linalg::Stack(&this->labels, that.labels);
|
||||
|
||||
this->weights_.SetDevice(that.weights_.Device());
|
||||
@@ -715,6 +729,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
|
||||
linalg::Stack(&this->base_margin_, that.base_margin_);
|
||||
|
||||
/**
|
||||
* group
|
||||
*/
|
||||
if (this->group_ptr_.size() == 0) {
|
||||
this->group_ptr_ = that.group_ptr_;
|
||||
} else {
|
||||
@@ -727,17 +744,25 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
group_ptr.end());
|
||||
}
|
||||
|
||||
/**
|
||||
* info with n_features
|
||||
*/
|
||||
if (!that.feature_names.empty()) {
|
||||
this->feature_names = that.feature_names;
|
||||
}
|
||||
|
||||
if (!that.feature_type_names.empty()) {
|
||||
this->feature_type_names = that.feature_type_names;
|
||||
auto &h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(this->feature_type_names, &h_feature_types);
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
this->has_categorical_ = LoadFeatureType(this->feature_type_names, &h_feature_types);
|
||||
} else if (!that.feature_types.Empty()) {
|
||||
// FIXME(jiamingy): https://github.com/dmlc/xgboost/pull/9171/files#r1440188612
|
||||
this->feature_types.Resize(that.feature_types.Size());
|
||||
this->feature_types.Copy(that.feature_types);
|
||||
auto const& ft = this->feature_types.ConstHostVector();
|
||||
this->has_categorical_ = std::any_of(ft.cbegin(), ft.cend(), common::IsCatOp{});
|
||||
}
|
||||
|
||||
if (!that.feature_weights.Empty()) {
|
||||
this->feature_weights.Resize(that.feature_weights.Size());
|
||||
this->feature_weights.SetDevice(that.feature_weights.Device());
|
||||
|
||||
@@ -93,7 +93,7 @@ class IterativeDMatrix : public DMatrix {
|
||||
return nullptr;
|
||||
}
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
LOG(FATAL) << "Not implemented for `QuantileDMatrix`.";
|
||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||
}
|
||||
BatchSet<CSCPage> GetColumnBatches(Context const *) override {
|
||||
|
||||
Reference in New Issue
Block a user