Check support status for categorical features. (#9946)

Jiaming Yuan 2024-01-04 16:51:33 +08:00 committed by GitHub
parent db396ee340
commit c03a4d5088
7 changed files with 116 additions and 40 deletions
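In short: MetaInfo now records whether any feature is categorical, and components that cannot consume categorical data (the `gblinear` booster and the `exact`/`grow_colmaker` tree method) fail fast with a clear error instead of silently treating category codes as numbers. A minimal sketch of the user-visible behavior, using hypothetical toy data and assuming a build that includes this change:

import pandas as pd
import xgboost as xgb

# Hypothetical toy data: one categorical and one numerical column.
X = pd.DataFrame(
    {"c": pd.Categorical(["a", "b", "a", "c"]), "f": [0.0, 1.0, 2.0, 3.0]}
)
y = [0, 1, 0, 1]
Xy = xgb.DMatrix(X, y, enable_categorical=True)

# `exact` has no categorical support; the new check surfaces this up front.
try:
    xgb.train({"tree_method": "exact"}, Xy, num_boost_round=1)
except ValueError as err:  # XGBoostError derives from ValueError
    print(err)  # ... doesn't support categorical features.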

include/xgboost/data.h

@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
  * \file data.h
  * \brief The input data structure of xgboost.
  * \author Tianqi Chen
@@ -158,15 +158,15 @@ class MetaInfo {
   void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
   void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
 
-  /*
-   * \brief Extend with other MetaInfo.
+  /**
+   * @brief Extend with other MetaInfo.
    *
-   * \param that The other MetaInfo object.
+   * @param that The other MetaInfo object.
    *
-   * \param accumulate_rows Whether rows need to be accumulated in this function. If
+   * @param accumulate_rows Whether rows need to be accumulated in this function. If
    *                        client code knows number of rows in advance, set this
    *                        parameter to false.
-   * \param check_column Whether the extend method should check the consistency of
+   * @param check_column Whether the extend method should check the consistency of
    *                     columns.
    */
   void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
@@ -203,6 +203,10 @@ class MetaInfo {
    *        learning where labels are only available on worker 0.
    */
   bool ShouldHaveLabels() const;
+  /**
+   * @brief Flag for whether the DMatrix has categorical features.
+   */
+  bool HasCategorical() const { return has_categorical_; }
 
  private:
   void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
@@ -210,6 +214,7 @@ class MetaInfo {
   /*! \brief argsort of labels */
   mutable std::vector<size_t> label_order_cache_;
+  bool has_categorical_{false};
 };
 
 /*! \brief Element from a sparse vector */

src/common/error_msg.h

@@ -1,5 +1,5 @@
 /**
- * Copyright 2023 by XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  *
  * \brief Common error message for various checks.
  */
@@ -99,5 +99,9 @@ constexpr StringView InvalidCUDAOrdinal() {
 void MismatchedDevices(Context const* booster, Context const* data);
 
 inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; }
+
+inline auto NoCategorical(std::string name) {
+  return name + " doesn't support categorical features.";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_

src/data/data.cc

@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
  * \file data.cc
  */
 #include "xgboost/data.h"
@@ -260,9 +260,14 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
   CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
 }
 
-void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<FeatureType>* types) {
+/**
+ * @brief Load feature type info from names, returns whether there are categorical features.
+ */
+[[nodiscard]] bool LoadFeatureType(std::vector<std::string> const& type_names,
+                                   std::vector<FeatureType>* types) {
   types->clear();
-  for (auto const &elem : type_names) {
+  bool has_cat{false};
+  for (auto const& elem : type_names) {
     if (elem == "int") {
       types->emplace_back(FeatureType::kNumerical);
     } else if (elem == "float") {
@@ -273,10 +278,12 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<FeatureType>* types) {
       types->emplace_back(FeatureType::kNumerical);
     } else if (elem == "c") {
       types->emplace_back(FeatureType::kCategorical);
+      has_cat = true;
     } else {
       LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}.";
     }
   }
+  return has_cat;
 }
 
 const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
@@ -340,7 +347,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
   LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names);
   LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names);
   LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights);
-  LoadFeatureType(feature_type_names, &feature_types.HostVector());
+
+  this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
 }
 
 template <typename T>
@@ -639,6 +647,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
     CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
     CHECK(info);
   }
+
   if (!std::strcmp(key, "feature_type")) {
     feature_type_names.clear();
     for (size_t i = 0; i < size; ++i) {
@@ -651,7 +660,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
         << "Length of " << key << " must be equal to number of columns.";
     }
     auto& h_feature_types = feature_types.HostVector();
-    LoadFeatureType(feature_type_names, &h_feature_types);
+    this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
   } else if (!std::strcmp(key, "feature_name")) {
     if (IsColumnSplit()) {
       std::vector<std::string> local_feature_names{};
@@ -674,9 +683,8 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
   }
 }
 
-void MetaInfo::GetFeatureInfo(const char *field,
-                              std::vector<std::string> *out_str_vecs) const {
-  auto &str_vecs = *out_str_vecs;
+void MetaInfo::GetFeatureInfo(const char* field, std::vector<std::string>* out_str_vecs) const {
+  auto& str_vecs = *out_str_vecs;
   if (!std::strcmp(field, "feature_type")) {
     str_vecs.resize(feature_type_names.size());
     std::copy(feature_type_names.cbegin(), feature_type_names.cend(), str_vecs.begin());
@@ -689,6 +697,9 @@ void MetaInfo::GetFeatureInfo(const char *field,
 }
 
 void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
+  /**
+   * shape
+   */
   if (accumulate_rows) {
     this->num_row_ += that.num_row_;
   }
@@ -702,6 +713,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
   }
   this->num_col_ = that.num_col_;
 
+  /**
+   * info with n_samples
+   */
   linalg::Stack(&this->labels, that.labels);
 
   this->weights_.SetDevice(that.weights_.Device());
@@ -715,6 +729,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
 
   linalg::Stack(&this->base_margin_, that.base_margin_);
 
+  /**
+   * group
+   */
   if (this->group_ptr_.size() == 0) {
     this->group_ptr_ = that.group_ptr_;
   } else {
@@ -727,17 +744,25 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
                           group_ptr.end());
   }
 
+  /**
+   * info with n_features
+   */
   if (!that.feature_names.empty()) {
     this->feature_names = that.feature_names;
   }
+
   if (!that.feature_type_names.empty()) {
     this->feature_type_names = that.feature_type_names;
-    auto &h_feature_types = feature_types.HostVector();
-    LoadFeatureType(this->feature_type_names, &h_feature_types);
+    auto& h_feature_types = feature_types.HostVector();
+    this->has_categorical_ = LoadFeatureType(this->feature_type_names, &h_feature_types);
   } else if (!that.feature_types.Empty()) {
+    // FIXME(jiamingy): https://github.com/dmlc/xgboost/pull/9171/files#r1440188612
     this->feature_types.Resize(that.feature_types.Size());
     this->feature_types.Copy(that.feature_types);
+    auto const& ft = this->feature_types.ConstHostVector();
+    this->has_categorical_ = std::any_of(ft.cbegin(), ft.cend(), common::IsCatOp{});
   }
+
   if (!that.feature_weights.Empty()) {
     this->feature_weights.Resize(that.feature_weights.Size());
     this->feature_weights.SetDevice(that.feature_weights.Device());
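Every path that populates `feature_types` (the binary load, `SetFeatureInfo`, and `Extend`) now records the flag returned by `LoadFeatureType`, so `HasCategorical()` stays in sync however the types were set. A minimal sketch of the `SetFeatureInfo` path from Python, assuming the `feature_types` property setter, which forwards to `MetaInfo::SetFeatureInfo` through the C API:

import numpy as np
import xgboost as xgb

X = np.random.default_rng(0).random((8, 2))
Xy = xgb.DMatrix(X, label=np.zeros(8))

# Marking the second feature as categorical ("c") routes through
# MetaInfo::SetFeatureInfo -> LoadFeatureType, which flips has_categorical_.
Xy.feature_types = ["q", "c"]

# The flag is consulted at training time, not at assignment time.
try:
    xgb.train({"booster": "gblinear"}, Xy, num_boost_round=1)
except ValueError as err:
    print(err)  # `gblinear` doesn't support categorical features.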

src/data/iterative_dmatrix.h

@@ -93,7 +93,7 @@ class IterativeDMatrix : public DMatrix {
     return nullptr;
   }
   BatchSet<SparsePage> GetRowBatches() override {
-    LOG(FATAL) << "Not implemented.";
+    LOG(FATAL) << "Not implemented for `QuantileDMatrix`.";
     return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
   }
   BatchSet<CSCPage> GetColumnBatches(Context const *) override {

src/gbm/gblinear.cc

@@ -1,5 +1,5 @@
 /**
- * Copyright 2014-2023, XGBoost Contributors
+ * Copyright 2014-2024, XGBoost Contributors
  * \file gblinear.cc
  * \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
  *        the update rule is parallel coordinate descent (shotgun)
@@ -8,25 +8,24 @@
 #include <dmlc/omp.h>
 #include <dmlc/parameter.h>
 
-#include <vector>
-#include <string>
-#include <sstream>
 #include <algorithm>
 #include <numeric>
+#include <sstream>
+#include <string>
+#include <vector>
 
+#include "../common/common.h"
+#include "../common/error_msg.h"  // NoCategorical, DeprecatedFunc
+#include "../common/threading_utils.h"
+#include "../common/timer.h"
+#include "gblinear_model.h"
 #include "xgboost/gbm.h"
 #include "xgboost/json.h"
-#include "xgboost/predictor.h"
-#include "xgboost/linear_updater.h"
-#include "xgboost/logging.h"
 #include "xgboost/learner.h"
 #include "xgboost/linalg.h"
-#include "gblinear_model.h"
-#include "../common/timer.h"
-#include "../common/common.h"
-#include "../common/threading_utils.h"
-#include "../common/error_msg.h"
+#include "xgboost/linear_updater.h"
+#include "xgboost/logging.h"
+#include "xgboost/predictor.h"
 
 namespace xgboost::gbm {
 
 DMLC_REGISTRY_FILE_TAG(gblinear);
@@ -145,6 +144,7 @@ class GBLinear : public GradientBooster {
                ObjFunction const*) override {
     monitor_.Start("DoBoost");
 
+    CHECK(!p_fmat->Info().HasCategorical()) << error::NoCategorical("`gblinear`");
     model_.LazyInitModel();
     this->LazySumWeights(p_fmat);

src/tree/updater_colmaker.cc

@@ -1,21 +1,22 @@
 /**
- * Copyright 2014-2023 by XGBoost Contributors
+ * Copyright 2014-2024, XGBoost Contributors
  * \file updater_colmaker.cc
  * \brief use columnwise update to construct a tree
  * \author Tianqi Chen
  */
-#include <vector>
-#include <cmath>
 #include <algorithm>
+#include <cmath>
+#include <vector>
 
+#include "../common/error_msg.h"  // for NoCategorical
+#include "../common/random.h"
+#include "constraints.h"
+#include "param.h"
+#include "split_evaluator.h"
+#include "xgboost/json.h"
+#include "xgboost/logging.h"
 #include "xgboost/parameter.h"
 #include "xgboost/tree_updater.h"
-#include "xgboost/logging.h"
-#include "xgboost/json.h"
-#include "param.h"
-#include "constraints.h"
-#include "../common/random.h"
-#include "split_evaluator.h"
 
 namespace xgboost::tree {
@@ -102,6 +103,9 @@ class ColMaker: public TreeUpdater {
       LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
                     "support external memory training.";
     }
+    if (dmat->Info().HasCategorical()) {
+      LOG(FATAL) << error::NoCategorical("Updater `grow_colmaker` or `exact` tree method");
+    }
     this->LazyGetColumnDensity(dmat);
     // rescale learning rate according to size of trees
     interaction_constraints_.Configure(*param, dmat->Info().num_row_);

tests/python/test_data_iterator.py

@@ -1,3 +1,5 @@
+import os
+import tempfile
 import weakref
 from typing import Any, Callable, Dict, List
 
@@ -195,3 +197,39 @@ def test_data_cache() -> None:
     assert called == 1
 
     xgb.data._proxy_transform = transform
+
+
+def test_cat_check() -> None:
+    n_batches = 3
+    n_features = 2
+    n_samples_per_batch = 16
+
+    batches = []
+
+    for i in range(n_batches):
+        X, y = tm.make_categorical(
+            n_samples=n_samples_per_batch,
+            n_features=n_features,
+            n_categories=3,
+            onehot=False,
+        )
+        batches.append((X, y))
+
+    X, y = list(zip(*batches))
+    it = tm.IteratorForTest(X, y, None, cache=None)
+    Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)
+
+    with pytest.raises(ValueError, match="categorical features"):
+        xgb.train({"tree_method": "exact"}, Xy)
+
+    Xy = xgb.DMatrix(X[0], y[0], enable_categorical=True)
+    with pytest.raises(ValueError, match="categorical features"):
+        xgb.train({"tree_method": "exact"}, Xy)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        cache_path = os.path.join(tmpdir, "cache")
+        it = tm.IteratorForTest(X, y, None, cache=cache_path)
+        Xy = xgb.DMatrix(it, enable_categorical=True)
+        with pytest.raises(ValueError, match="categorical features"):
+            xgb.train({"booster": "gblinear"}, Xy)