Add XGBoosterGetNumFeature (#5856)
- add GetNumFeature to Learner - add XGBoosterGetNumFeature to C API - update c-api-demo accordingly
This commit is contained in:
parent
e0c179c7cc
commit
970b4b3fa2
@ -60,6 +60,10 @@ int main(int argc, char** argv) {
|
|||||||
printf("%s\n", eval_result);
|
printf("%s\n", eval_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bst_ulong num_feature = 0;
|
||||||
|
safe_xgboost(XGBoosterGetNumFeature(booster, &num_feature));
|
||||||
|
printf("num_feature: %llu\n", num_feature);
|
||||||
|
|
||||||
// predict
|
// predict
|
||||||
bst_ulong out_len = 0;
|
bst_ulong out_len = 0;
|
||||||
const float* out_result = NULL;
|
const float* out_result = NULL;
|
||||||
|
|||||||
@ -563,6 +563,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
|||||||
const char *name,
|
const char *name,
|
||||||
const char *value);
|
const char *value);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief get number of features
|
||||||
|
* \param out number of features
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
||||||
|
bst_ulong *out);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief update the model in one round using dtrain
|
* \brief update the model in one round using dtrain
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
@ -158,6 +158,12 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
|
|||||||
*/
|
*/
|
||||||
virtual void SetParam(const std::string& key, const std::string& value) = 0;
|
virtual void SetParam(const std::string& key, const std::string& value) = 0;
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get the number of features of the booster.
|
||||||
|
* \return number of features
|
||||||
|
*/
|
||||||
|
virtual uint32_t GetNumFeature() = 0;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Set additional attribute to the Booster.
|
* \brief Set additional attribute to the Booster.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -395,6 +395,14 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
||||||
|
xgboost::bst_ulong *out) {
|
||||||
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
|
*out = static_cast<Learner*>(handle)->GetNumFeature();
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|||||||
@ -384,6 +384,10 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t GetNumFeature() override {
|
||||||
|
return learner_model_param_.num_feature;
|
||||||
|
}
|
||||||
|
|
||||||
void SetAttr(const std::string& key, const std::string& value) override {
|
void SetAttr(const std::string& key, const std::string& value) override {
|
||||||
attributes_[key] = value;
|
attributes_[key] = value;
|
||||||
mparam_.contain_extra_attrs = 1;
|
mparam_.contain_extra_attrs = 1;
|
||||||
|
|||||||
@ -112,9 +112,10 @@ TEST(CAPI, ConfigIO) {
|
|||||||
|
|
||||||
TEST(CAPI, JsonModelIO) {
|
TEST(CAPI, JsonModelIO) {
|
||||||
size_t constexpr kRows = 10;
|
size_t constexpr kRows = 10;
|
||||||
|
size_t constexpr kCols = 10;
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
|
|
||||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix();
|
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||||
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
||||||
std::vector<bst_float> labels(kRows);
|
std::vector<bst_float> labels(kRows);
|
||||||
for (size_t i = 0; i < labels.size(); ++i) {
|
for (size_t i = 0; i < labels.size(); ++i) {
|
||||||
@ -131,6 +132,10 @@ TEST(CAPI, JsonModelIO) {
|
|||||||
XGBoosterSaveModel(handle, modelfile_0.c_str());
|
XGBoosterSaveModel(handle, modelfile_0.c_str());
|
||||||
XGBoosterLoadModel(handle, modelfile_0.c_str());
|
XGBoosterLoadModel(handle, modelfile_0.c_str());
|
||||||
|
|
||||||
|
bst_ulong num_feature {0};
|
||||||
|
ASSERT_EQ(XGBoosterGetNumFeature(handle, &num_feature), 0);
|
||||||
|
ASSERT_EQ(num_feature, kCols);
|
||||||
|
|
||||||
std::string modelfile_1 = tempdir.path + "/model_1.json";
|
std::string modelfile_1 = tempdir.path + "/model_1.json";
|
||||||
XGBoosterSaveModel(handle, modelfile_1.c_str());
|
XGBoosterSaveModel(handle, modelfile_1.c_str());
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user