Add XGBoosterGetNumFeature (#5856)
- add GetNumFeature to Learner - add XGBoosterGetNumFeature to C API - update c-api-demo accordingly
This commit is contained in:
parent
e0c179c7cc
commit
970b4b3fa2
@ -60,6 +60,10 @@ int main(int argc, char** argv) {
|
||||
printf("%s\n", eval_result);
|
||||
}
|
||||
|
||||
bst_ulong num_feature = 0;
|
||||
safe_xgboost(XGBoosterGetNumFeature(booster, &num_feature));
|
||||
printf("num_feature: %llu\n", num_feature);
|
||||
|
||||
// predict
|
||||
bst_ulong out_len = 0;
|
||||
const float* out_result = NULL;
|
||||
|
||||
@ -563,6 +563,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
||||
const char *name,
|
||||
const char *value);
|
||||
|
||||
/*!
|
||||
* \brief get number of features
|
||||
* \param out number of features
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
||||
bst_ulong *out);
|
||||
|
||||
/*!
|
||||
* \brief update the model in one round using dtrain
|
||||
* \param handle handle
|
||||
|
||||
@ -158,6 +158,12 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
|
||||
*/
|
||||
virtual void SetParam(const std::string& key, const std::string& value) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Get the number of features of the booster.
|
||||
* \return number of features
|
||||
*/
|
||||
virtual uint32_t GetNumFeature() = 0;
|
||||
|
||||
/*!
|
||||
* \brief Set additional attribute to the Booster.
|
||||
*
|
||||
|
||||
@ -395,6 +395,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
*out = static_cast<Learner*>(handle)->GetNumFeature();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
|
||||
@ -384,6 +384,10 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t GetNumFeature() override {
|
||||
return learner_model_param_.num_feature;
|
||||
}
|
||||
|
||||
void SetAttr(const std::string& key, const std::string& value) override {
|
||||
attributes_[key] = value;
|
||||
mparam_.contain_extra_attrs = 1;
|
||||
|
||||
@ -112,9 +112,10 @@ TEST(CAPI, ConfigIO) {
|
||||
|
||||
TEST(CAPI, JsonModelIO) {
|
||||
size_t constexpr kRows = 10;
|
||||
size_t constexpr kCols = 10;
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
||||
std::vector<bst_float> labels(kRows);
|
||||
for (size_t i = 0; i < labels.size(); ++i) {
|
||||
@ -131,6 +132,10 @@ TEST(CAPI, JsonModelIO) {
|
||||
XGBoosterSaveModel(handle, modelfile_0.c_str());
|
||||
XGBoosterLoadModel(handle, modelfile_0.c_str());
|
||||
|
||||
bst_ulong num_feature {0};
|
||||
ASSERT_EQ(XGBoosterGetNumFeature(handle, &num_feature), 0);
|
||||
ASSERT_EQ(num_feature, kCols);
|
||||
|
||||
std::string modelfile_1 = tempdir.path + "/model_1.json";
|
||||
XGBoosterSaveModel(handle, modelfile_1.c_str());
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user