From ea9285dd4fbb03593770330a96c7deae6a4de79e Mon Sep 17 00:00:00 2001 From: Vadim Khotilovich Date: Sat, 14 May 2016 18:09:21 -0500 Subject: [PATCH] methods to delete an attribute and get names of available attributes --- include/xgboost/c_api.h | 17 ++++++++++++++--- include/xgboost/learner.h | 13 ++++++++++++- src/c_api/c_api.cc | 23 ++++++++++++++++++++++- src/learner.cc | 16 ++++++++++++++++ 4 files changed, 64 insertions(+), 5 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 1598aac3b..e6c09cd0b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -433,16 +433,27 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, const char** out, int *success); /*! - * \brief Set string attribute. + * \brief Set or delete string attribute. * * \param handle handle - * \param key The key of the symbol. - * \param value The value to be saved. + * \param key The key of the attribute. + * \param value The value to be saved. + * If nullptr, the attribute would be deleted. * \return 0 when success, -1 when failure happens */ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, const char* key, const char* value); +/*! + * \brief Get the names of all attribute from Booster. + * \param handle handle + * \param len the argument to hold the output length + * \param out pointer to hold the output attribute stings + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle, + bst_ulong* out_len, + const char*** out); // --- Distributed training API---- // NOTE: functions in rabit/c_api.h will be also available in libxgboost.so diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 18c782518..474437bf2 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -121,9 +121,20 @@ class Learner : public rabit::Serializable { * The property will be saved along the booster. * \param key The key of the attribute. * \param out The output value. - * \return Whether the key is contained in the attribute. + * \return Whether the key exists among booster's attributes. */ virtual bool GetAttr(const std::string& key, std::string* out) const = 0; + /*! + * \brief Delete an attribute from the booster. + * \param key The key of the attribute. + * \return Whether the key was found among booster's attributes. + */ + virtual bool DelAttr(const std::string& key) = 0; + /*! + * \brief Get a vector of attribute names from the booster. + * \return vector of attribute name strings. + */ + virtual std::vector GetAttrNames() const = 0; /*! * \return whether the model allow lazy checkpoint in rabit. */ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index cf6be2046..37fb92c24 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -685,7 +685,28 @@ int XGBoosterSetAttr(BoosterHandle handle, const char* value) { Booster* bst = static_cast(handle); API_BEGIN(); - bst->learner()->SetAttr(key, value); + if (value == nullptr) { + bst->learner()->DelAttr(key); + } else { + bst->learner()->SetAttr(key, value); + } + API_END(); +} + +int XGBoosterGetAttrNames(BoosterHandle handle, + bst_ulong* out_len, + const char*** out) { + std::vector& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str; + std::vector& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp; + Booster *bst = static_cast(handle); + API_BEGIN(); + str_vecs = bst->learner()->GetAttrNames(); + charp_vecs.resize(str_vecs.size()); + for (size_t i = 0; i < str_vecs.size(); ++i) { + charp_vecs[i] = str_vecs[i].c_str(); + } + *out = dmlc::BeginPtr(charp_vecs); + *out_len = static_cast(charp_vecs.size()); API_END(); } diff --git a/src/learner.cc b/src/learner.cc index a88f967c4..971fa45bd 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -330,6 +330,22 @@ class LearnerImpl : public Learner { return true; } + bool DelAttr(const std::string& key) override { + auto it = attributes_.find(key); + if (it == attributes_.end()) return false; + attributes_.erase(it); + return true; + } + + std::vector GetAttrNames() const override { + std::vector out; + out.reserve(attributes_.size()); + for(auto &p: attributes_) { + out.push_back(p.first); + } + return out; + } + std::pair Evaluate(DMatrix* data, std::string metric) { if (metric == "auto") metric = obj_->DefaultEvalMetric(); std::unique_ptr ev(Metric::Create(metric.c_str()));