diff --git a/demo/c-api/c-api-demo.c b/demo/c-api/c-api-demo.c index c476357bd..9b1837f0e 100644 --- a/demo/c-api/c-api-demo.c +++ b/demo/c-api/c-api-demo.c @@ -60,6 +60,10 @@ int main(int argc, char** argv) { printf("%s\n", eval_result); } + bst_ulong num_feature = 0; + safe_xgboost(XGBoosterGetNumFeature(booster, &num_feature)); + printf("num_feature: %llu\n", num_feature); + // predict bst_ulong out_len = 0; const float* out_result = NULL; diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index c117f62f3..794cbdf19 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -563,6 +563,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle, const char *name, const char *value); +/*! + * \brief get number of features + * \param out number of features + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle, + bst_ulong *out); + /*! * \brief update the model in one round using dtrain * \param handle handle diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index a608bc1b8..bfe9fd4d5 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -158,6 +158,12 @@ class Learner : public Model, public Configurable, public rabit::Serializable { */ virtual void SetParam(const std::string& key, const std::string& value) = 0; + /*! + * \brief Get the number of features of the booster. + * \return number of features + */ + virtual uint32_t GetNumFeature() = 0; + /*! * \brief Set additional attribute to the Booster. * diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 824e46f71..c3d814d44 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -395,6 +395,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle, API_END(); } +XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle, + xgboost::bst_ulong *out) { + API_BEGIN(); + CHECK_HANDLE(); + *out = static_cast(handle)->GetNumFeature(); + API_END(); +} + XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) { API_BEGIN(); CHECK_HANDLE(); diff --git a/src/learner.cc b/src/learner.cc index ebfdeccc3..dcb6c5783 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -384,6 +384,10 @@ class LearnerConfiguration : public Learner { } } + uint32_t GetNumFeature() override { + return learner_model_param_.num_feature; + } + void SetAttr(const std::string& key, const std::string& value) override { attributes_[key] = value; mparam_.contain_extra_attrs = 1; diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 2ba36a16d..f4c2722fe 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -112,9 +112,10 @@ TEST(CAPI, ConfigIO) { TEST(CAPI, JsonModelIO) { size_t constexpr kRows = 10; + size_t constexpr kCols = 10; dmlc::TemporaryDirectory tempdir; - auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix(); + auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); std::vector> mat {p_dmat}; std::vector labels(kRows); for (size_t i = 0; i < labels.size(); ++i) { @@ -131,6 +132,10 @@ TEST(CAPI, JsonModelIO) { XGBoosterSaveModel(handle, modelfile_0.c_str()); XGBoosterLoadModel(handle, modelfile_0.c_str()); + bst_ulong num_feature {0}; + ASSERT_EQ(XGBoosterGetNumFeature(handle, &num_feature), 0); + ASSERT_EQ(num_feature, kCols); + std::string modelfile_1 = tempdir.path + "/model_1.json"; XGBoosterSaveModel(handle, modelfile_1.c_str());