Check support status for categorical features. (#9946)
This commit is contained in:
parent
db396ee340
commit
c03a4d5088
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2015-2023 by XGBoost Contributors
|
* Copyright 2015-2024, XGBoost Contributors
|
||||||
* \file data.h
|
* \file data.h
|
||||||
* \brief The input data structure of xgboost.
|
* \brief The input data structure of xgboost.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -158,15 +158,15 @@ class MetaInfo {
|
|||||||
void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
|
void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
|
||||||
void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
|
void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
|
||||||
|
|
||||||
/*
|
/**
|
||||||
* \brief Extend with other MetaInfo.
|
* @brief Extend with other MetaInfo.
|
||||||
*
|
*
|
||||||
* \param that The other MetaInfo object.
|
* @param that The other MetaInfo object.
|
||||||
*
|
*
|
||||||
* \param accumulate_rows Whether rows need to be accumulated in this function. If
|
* @param accumulate_rows Whether rows need to be accumulated in this function. If
|
||||||
* client code knows number of rows in advance, set this
|
* client code knows number of rows in advance, set this
|
||||||
* parameter to false.
|
* parameter to false.
|
||||||
* \param check_column Whether the extend method should check the consistency of
|
* @param check_column Whether the extend method should check the consistency of
|
||||||
* columns.
|
* columns.
|
||||||
*/
|
*/
|
||||||
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
|
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
|
||||||
@ -203,6 +203,10 @@ class MetaInfo {
|
|||||||
* learning where labels are only available on worker 0.
|
* learning where labels are only available on worker 0.
|
||||||
*/
|
*/
|
||||||
bool ShouldHaveLabels() const;
|
bool ShouldHaveLabels() const;
|
||||||
|
/**
|
||||||
|
* @brief Flag for whether the DMatrix has categorical features.
|
||||||
|
*/
|
||||||
|
bool HasCategorical() const { return has_categorical_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
|
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
|
||||||
@ -210,6 +214,7 @@ class MetaInfo {
|
|||||||
|
|
||||||
/*! \brief argsort of labels */
|
/*! \brief argsort of labels */
|
||||||
mutable std::vector<size_t> label_order_cache_;
|
mutable std::vector<size_t> label_order_cache_;
|
||||||
|
bool has_categorical_{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief Element from a sparse vector */
|
/*! \brief Element from a sparse vector */
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023 by XGBoost contributors
|
* Copyright 2023-2024, XGBoost contributors
|
||||||
*
|
*
|
||||||
* \brief Common error message for various checks.
|
* \brief Common error message for various checks.
|
||||||
*/
|
*/
|
||||||
@ -99,5 +99,9 @@ constexpr StringView InvalidCUDAOrdinal() {
|
|||||||
void MismatchedDevices(Context const* booster, Context const* data);
|
void MismatchedDevices(Context const* booster, Context const* data);
|
||||||
|
|
||||||
inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; }
|
inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; }
|
||||||
|
|
||||||
|
inline auto NoCategorical(std::string name) {
|
||||||
|
return name + " doesn't support categorical features.";
|
||||||
|
}
|
||||||
} // namespace xgboost::error
|
} // namespace xgboost::error
|
||||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2015-2023 by XGBoost Contributors
|
* Copyright 2015-2024, XGBoost Contributors
|
||||||
* \file data.cc
|
* \file data.cc
|
||||||
*/
|
*/
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -260,9 +260,14 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
|||||||
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
|
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
|
||||||
}
|
}
|
||||||
|
|
||||||
void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<FeatureType>* types) {
|
/**
|
||||||
|
* @brief Load feature type info from names, returns whether there's categorical features.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] bool LoadFeatureType(std::vector<std::string> const& type_names,
|
||||||
|
std::vector<FeatureType>* types) {
|
||||||
types->clear();
|
types->clear();
|
||||||
for (auto const &elem : type_names) {
|
bool has_cat{false};
|
||||||
|
for (auto const& elem : type_names) {
|
||||||
if (elem == "int") {
|
if (elem == "int") {
|
||||||
types->emplace_back(FeatureType::kNumerical);
|
types->emplace_back(FeatureType::kNumerical);
|
||||||
} else if (elem == "float") {
|
} else if (elem == "float") {
|
||||||
@ -273,10 +278,12 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
|
|||||||
types->emplace_back(FeatureType::kNumerical);
|
types->emplace_back(FeatureType::kNumerical);
|
||||||
} else if (elem == "c") {
|
} else if (elem == "c") {
|
||||||
types->emplace_back(FeatureType::kCategorical);
|
types->emplace_back(FeatureType::kCategorical);
|
||||||
|
has_cat = true;
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}.";
|
LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return has_cat;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
|
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
|
||||||
@ -340,7 +347,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
|||||||
LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names);
|
LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names);
|
||||||
LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names);
|
LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names);
|
||||||
LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights);
|
LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights);
|
||||||
LoadFeatureType(feature_type_names, &feature_types.HostVector());
|
|
||||||
|
this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -639,6 +647,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
|
|||||||
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||||
CHECK(info);
|
CHECK(info);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!std::strcmp(key, "feature_type")) {
|
if (!std::strcmp(key, "feature_type")) {
|
||||||
feature_type_names.clear();
|
feature_type_names.clear();
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
@ -651,7 +660,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
|
|||||||
<< "Length of " << key << " must be equal to number of columns.";
|
<< "Length of " << key << " must be equal to number of columns.";
|
||||||
}
|
}
|
||||||
auto& h_feature_types = feature_types.HostVector();
|
auto& h_feature_types = feature_types.HostVector();
|
||||||
LoadFeatureType(feature_type_names, &h_feature_types);
|
this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
|
||||||
} else if (!std::strcmp(key, "feature_name")) {
|
} else if (!std::strcmp(key, "feature_name")) {
|
||||||
if (IsColumnSplit()) {
|
if (IsColumnSplit()) {
|
||||||
std::vector<std::string> local_feature_names{};
|
std::vector<std::string> local_feature_names{};
|
||||||
@ -674,9 +683,8 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MetaInfo::GetFeatureInfo(const char *field,
|
void MetaInfo::GetFeatureInfo(const char* field, std::vector<std::string>* out_str_vecs) const {
|
||||||
std::vector<std::string> *out_str_vecs) const {
|
auto& str_vecs = *out_str_vecs;
|
||||||
auto &str_vecs = *out_str_vecs;
|
|
||||||
if (!std::strcmp(field, "feature_type")) {
|
if (!std::strcmp(field, "feature_type")) {
|
||||||
str_vecs.resize(feature_type_names.size());
|
str_vecs.resize(feature_type_names.size());
|
||||||
std::copy(feature_type_names.cbegin(), feature_type_names.cend(), str_vecs.begin());
|
std::copy(feature_type_names.cbegin(), feature_type_names.cend(), str_vecs.begin());
|
||||||
@ -689,6 +697,9 @@ void MetaInfo::GetFeatureInfo(const char *field,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
|
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
|
||||||
|
/**
|
||||||
|
* shape
|
||||||
|
*/
|
||||||
if (accumulate_rows) {
|
if (accumulate_rows) {
|
||||||
this->num_row_ += that.num_row_;
|
this->num_row_ += that.num_row_;
|
||||||
}
|
}
|
||||||
@ -702,6 +713,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
|||||||
}
|
}
|
||||||
this->num_col_ = that.num_col_;
|
this->num_col_ = that.num_col_;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* info with n_samples
|
||||||
|
*/
|
||||||
linalg::Stack(&this->labels, that.labels);
|
linalg::Stack(&this->labels, that.labels);
|
||||||
|
|
||||||
this->weights_.SetDevice(that.weights_.Device());
|
this->weights_.SetDevice(that.weights_.Device());
|
||||||
@ -715,6 +729,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
|||||||
|
|
||||||
linalg::Stack(&this->base_margin_, that.base_margin_);
|
linalg::Stack(&this->base_margin_, that.base_margin_);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* group
|
||||||
|
*/
|
||||||
if (this->group_ptr_.size() == 0) {
|
if (this->group_ptr_.size() == 0) {
|
||||||
this->group_ptr_ = that.group_ptr_;
|
this->group_ptr_ = that.group_ptr_;
|
||||||
} else {
|
} else {
|
||||||
@ -727,17 +744,25 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
|||||||
group_ptr.end());
|
group_ptr.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* info with n_features
|
||||||
|
*/
|
||||||
if (!that.feature_names.empty()) {
|
if (!that.feature_names.empty()) {
|
||||||
this->feature_names = that.feature_names;
|
this->feature_names = that.feature_names;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!that.feature_type_names.empty()) {
|
if (!that.feature_type_names.empty()) {
|
||||||
this->feature_type_names = that.feature_type_names;
|
this->feature_type_names = that.feature_type_names;
|
||||||
auto &h_feature_types = feature_types.HostVector();
|
auto& h_feature_types = feature_types.HostVector();
|
||||||
LoadFeatureType(this->feature_type_names, &h_feature_types);
|
this->has_categorical_ = LoadFeatureType(this->feature_type_names, &h_feature_types);
|
||||||
} else if (!that.feature_types.Empty()) {
|
} else if (!that.feature_types.Empty()) {
|
||||||
|
// FIXME(jiamingy): https://github.com/dmlc/xgboost/pull/9171/files#r1440188612
|
||||||
this->feature_types.Resize(that.feature_types.Size());
|
this->feature_types.Resize(that.feature_types.Size());
|
||||||
this->feature_types.Copy(that.feature_types);
|
this->feature_types.Copy(that.feature_types);
|
||||||
|
auto const& ft = this->feature_types.ConstHostVector();
|
||||||
|
this->has_categorical_ = std::any_of(ft.cbegin(), ft.cend(), common::IsCatOp{});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!that.feature_weights.Empty()) {
|
if (!that.feature_weights.Empty()) {
|
||||||
this->feature_weights.Resize(that.feature_weights.Size());
|
this->feature_weights.Resize(that.feature_weights.Size());
|
||||||
this->feature_weights.SetDevice(that.feature_weights.Device());
|
this->feature_weights.SetDevice(that.feature_weights.Device());
|
||||||
|
|||||||
@ -93,7 +93,7 @@ class IterativeDMatrix : public DMatrix {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
BatchSet<SparsePage> GetRowBatches() override {
|
BatchSet<SparsePage> GetRowBatches() override {
|
||||||
LOG(FATAL) << "Not implemented.";
|
LOG(FATAL) << "Not implemented for `QuantileDMatrix`.";
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||||
}
|
}
|
||||||
BatchSet<CSCPage> GetColumnBatches(Context const *) override {
|
BatchSet<CSCPage> GetColumnBatches(Context const *) override {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2014-2023, XGBoost Contributors
|
* Copyright 2014-2024, XGBoost Contributors
|
||||||
* \file gblinear.cc
|
* \file gblinear.cc
|
||||||
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
|
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
|
||||||
* the update rule is parallel coordinate descent (shotgun)
|
* the update rule is parallel coordinate descent (shotgun)
|
||||||
@ -8,25 +8,24 @@
|
|||||||
#include <dmlc/omp.h>
|
#include <dmlc/omp.h>
|
||||||
#include <dmlc/parameter.h>
|
#include <dmlc/parameter.h>
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <sstream>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/common.h"
|
||||||
|
#include "../common/error_msg.h" // NoCategorical, DeprecatedFunc
|
||||||
|
#include "../common/threading_utils.h"
|
||||||
|
#include "../common/timer.h"
|
||||||
|
#include "gblinear_model.h"
|
||||||
#include "xgboost/gbm.h"
|
#include "xgboost/gbm.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/predictor.h"
|
|
||||||
#include "xgboost/linear_updater.h"
|
|
||||||
#include "xgboost/logging.h"
|
|
||||||
#include "xgboost/learner.h"
|
#include "xgboost/learner.h"
|
||||||
#include "xgboost/linalg.h"
|
#include "xgboost/linalg.h"
|
||||||
|
#include "xgboost/linear_updater.h"
|
||||||
#include "gblinear_model.h"
|
#include "xgboost/logging.h"
|
||||||
#include "../common/timer.h"
|
#include "xgboost/predictor.h"
|
||||||
#include "../common/common.h"
|
|
||||||
#include "../common/threading_utils.h"
|
|
||||||
#include "../common/error_msg.h"
|
|
||||||
|
|
||||||
namespace xgboost::gbm {
|
namespace xgboost::gbm {
|
||||||
DMLC_REGISTRY_FILE_TAG(gblinear);
|
DMLC_REGISTRY_FILE_TAG(gblinear);
|
||||||
@ -145,6 +144,7 @@ class GBLinear : public GradientBooster {
|
|||||||
ObjFunction const*) override {
|
ObjFunction const*) override {
|
||||||
monitor_.Start("DoBoost");
|
monitor_.Start("DoBoost");
|
||||||
|
|
||||||
|
CHECK(!p_fmat->Info().HasCategorical()) << error::NoCategorical("`gblinear`");
|
||||||
model_.LazyInitModel();
|
model_.LazyInitModel();
|
||||||
this->LazySumWeights(p_fmat);
|
this->LazySumWeights(p_fmat);
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,22 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2014-2023 by XGBoost Contributors
|
* Copyright 2014-2024, XGBoost Contributors
|
||||||
* \file updater_colmaker.cc
|
* \file updater_colmaker.cc
|
||||||
* \brief use columnwise update to construct a tree
|
* \brief use columnwise update to construct a tree
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#include <vector>
|
|
||||||
#include <cmath>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/error_msg.h" // for NoCategorical
|
||||||
|
#include "../common/random.h"
|
||||||
|
#include "constraints.h"
|
||||||
|
#include "param.h"
|
||||||
|
#include "split_evaluator.h"
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/parameter.h"
|
#include "xgboost/parameter.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
#include "xgboost/logging.h"
|
|
||||||
#include "xgboost/json.h"
|
|
||||||
#include "param.h"
|
|
||||||
#include "constraints.h"
|
|
||||||
#include "../common/random.h"
|
|
||||||
#include "split_evaluator.h"
|
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
|
|
||||||
@ -102,6 +103,9 @@ class ColMaker: public TreeUpdater {
|
|||||||
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
|
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
|
||||||
"support external memory training.";
|
"support external memory training.";
|
||||||
}
|
}
|
||||||
|
if (dmat->Info().HasCategorical()) {
|
||||||
|
LOG(FATAL) << error::NoCategorical("Updater `grow_colmaker` or `exact` tree method");
|
||||||
|
}
|
||||||
this->LazyGetColumnDensity(dmat);
|
this->LazyGetColumnDensity(dmat);
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
|
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any, Callable, Dict, List
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
@ -195,3 +197,39 @@ def test_data_cache() -> None:
|
|||||||
assert called == 1
|
assert called == 1
|
||||||
|
|
||||||
xgb.data._proxy_transform = transform
|
xgb.data._proxy_transform = transform
|
||||||
|
|
||||||
|
|
||||||
|
def test_cat_check() -> None:
|
||||||
|
n_batches = 3
|
||||||
|
n_features = 2
|
||||||
|
n_samples_per_batch = 16
|
||||||
|
|
||||||
|
batches = []
|
||||||
|
|
||||||
|
for i in range(n_batches):
|
||||||
|
X, y = tm.make_categorical(
|
||||||
|
n_samples=n_samples_per_batch,
|
||||||
|
n_features=n_features,
|
||||||
|
n_categories=3,
|
||||||
|
onehot=False,
|
||||||
|
)
|
||||||
|
batches.append((X, y))
|
||||||
|
|
||||||
|
X, y = list(zip(*batches))
|
||||||
|
it = tm.IteratorForTest(X, y, None, cache=None)
|
||||||
|
Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="categorical features"):
|
||||||
|
xgb.train({"tree_method": "exact"}, Xy)
|
||||||
|
|
||||||
|
Xy = xgb.DMatrix(X[0], y[0], enable_categorical=True)
|
||||||
|
with pytest.raises(ValueError, match="categorical features"):
|
||||||
|
xgb.train({"tree_method": "exact"}, Xy)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
cache_path = os.path.join(tmpdir, "cache")
|
||||||
|
|
||||||
|
it = tm.IteratorForTest(X, y, None, cache=cache_path)
|
||||||
|
Xy = xgb.DMatrix(it, enable_categorical=True)
|
||||||
|
with pytest.raises(ValueError, match="categorical features"):
|
||||||
|
xgb.train({"booster": "gblinear"}, Xy)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user