Add XGBoosterGetNumFeature (#5856)

- add GetNumFeature to Learner
- add XGBoosterGetNumFeature to C API
- update c-api-demo accordingly
This commit is contained in:
Alexander Gugel 2020-07-14 07:25:17 +01:00 committed by GitHub
parent e0c179c7cc
commit 970b4b3fa2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 1 deletions

View File

@ -60,6 +60,10 @@ int main(int argc, char** argv) {
printf("%s\n", eval_result); printf("%s\n", eval_result);
} }
bst_ulong num_feature = 0;
safe_xgboost(XGBoosterGetNumFeature(booster, &num_feature));
printf("num_feature: %llu\n", num_feature);
// predict // predict
bst_ulong out_len = 0; bst_ulong out_len = 0;
const float* out_result = NULL; const float* out_result = NULL;

View File

@ -563,6 +563,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
const char *name, const char *name,
const char *value); const char *value);
/*!
* \brief get number of features
* \param out number of features
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
bst_ulong *out);
/*! /*!
* \brief update the model in one round using dtrain * \brief update the model in one round using dtrain
* \param handle handle * \param handle handle

View File

@ -158,6 +158,12 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
*/ */
virtual void SetParam(const std::string& key, const std::string& value) = 0; virtual void SetParam(const std::string& key, const std::string& value) = 0;
/*!
* \brief Get the number of features of the booster.
* \return number of features
*/
virtual uint32_t GetNumFeature() = 0;
/*! /*!
* \brief Set additional attribute to the Booster. * \brief Set additional attribute to the Booster.
* *

View File

@ -395,6 +395,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
API_END(); API_END();
} }
XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
*out = static_cast<Learner*>(handle)->GetNumFeature();
API_END();
}
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) { XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();

View File

@ -384,6 +384,10 @@ class LearnerConfiguration : public Learner {
} }
} }
uint32_t GetNumFeature() override {
return learner_model_param_.num_feature;
}
void SetAttr(const std::string& key, const std::string& value) override { void SetAttr(const std::string& key, const std::string& value) override {
attributes_[key] = value; attributes_[key] = value;
mparam_.contain_extra_attrs = 1; mparam_.contain_extra_attrs = 1;

View File

@ -112,9 +112,10 @@ TEST(CAPI, ConfigIO) {
TEST(CAPI, JsonModelIO) { TEST(CAPI, JsonModelIO) {
size_t constexpr kRows = 10; size_t constexpr kRows = 10;
size_t constexpr kCols = 10;
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix(); auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat}; std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
std::vector<bst_float> labels(kRows); std::vector<bst_float> labels(kRows);
for (size_t i = 0; i < labels.size(); ++i) { for (size_t i = 0; i < labels.size(); ++i) {
@ -131,6 +132,10 @@ TEST(CAPI, JsonModelIO) {
XGBoosterSaveModel(handle, modelfile_0.c_str()); XGBoosterSaveModel(handle, modelfile_0.c_str());
XGBoosterLoadModel(handle, modelfile_0.c_str()); XGBoosterLoadModel(handle, modelfile_0.c_str());
bst_ulong num_feature {0};
ASSERT_EQ(XGBoosterGetNumFeature(handle, &num_feature), 0);
ASSERT_EQ(num_feature, kCols);
std::string modelfile_1 = tempdir.path + "/model_1.json"; std::string modelfile_1 = tempdir.path + "/model_1.json";
XGBoosterSaveModel(handle, modelfile_1.c_str()); XGBoosterSaveModel(handle, modelfile_1.c_str());