methods to delete an attribute and get names of available attributes
This commit is contained in:
parent
9c26566eb0
commit
ea9285dd4f
@ -433,16 +433,27 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
||||
const char** out,
|
||||
int *success);
|
||||
/*!
|
||||
* \brief Set string attribute.
|
||||
* \brief Set or delete string attribute.
|
||||
*
|
||||
* \param handle handle
|
||||
* \param key The key of the symbol.
|
||||
* \param value The value to be saved.
|
||||
* \param key The key of the attribute.
|
||||
* \param value The value to be saved.
|
||||
* If nullptr, the attribute would be deleted.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
||||
const char* key,
|
||||
const char* value);
|
||||
/*!
|
||||
* \brief Get the names of all attribute from Booster.
|
||||
* \param handle handle
|
||||
* \param len the argument to hold the output length
|
||||
* \param out pointer to hold the output attribute stings
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
bst_ulong* out_len,
|
||||
const char*** out);
|
||||
|
||||
// --- Distributed training API----
|
||||
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
|
||||
|
||||
@ -121,9 +121,20 @@ class Learner : public rabit::Serializable {
|
||||
* The property will be saved along the booster.
|
||||
* \param key The key of the attribute.
|
||||
* \param out The output value.
|
||||
* \return Whether the key is contained in the attribute.
|
||||
* \return Whether the key exists among booster's attributes.
|
||||
*/
|
||||
virtual bool GetAttr(const std::string& key, std::string* out) const = 0;
|
||||
/*!
|
||||
* \brief Delete an attribute from the booster.
|
||||
* \param key The key of the attribute.
|
||||
* \return Whether the key was found among booster's attributes.
|
||||
*/
|
||||
virtual bool DelAttr(const std::string& key) = 0;
|
||||
/*!
|
||||
* \brief Get a vector of attribute names from the booster.
|
||||
* \return vector of attribute name strings.
|
||||
*/
|
||||
virtual std::vector<std::string> GetAttrNames() const = 0;
|
||||
/*!
|
||||
* \return whether the model allow lazy checkpoint in rabit.
|
||||
*/
|
||||
|
||||
@ -685,7 +685,28 @@ int XGBoosterSetAttr(BoosterHandle handle,
|
||||
const char* value) {
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
API_BEGIN();
|
||||
bst->learner()->SetAttr(key, value);
|
||||
if (value == nullptr) {
|
||||
bst->learner()->DelAttr(key);
|
||||
} else {
|
||||
bst->learner()->SetAttr(key, value);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
bst_ulong* out_len,
|
||||
const char*** out) {
|
||||
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
API_BEGIN();
|
||||
str_vecs = bst->learner()->GetAttrNames();
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
charp_vecs[i] = str_vecs[i].c_str();
|
||||
}
|
||||
*out = dmlc::BeginPtr(charp_vecs);
|
||||
*out_len = static_cast<bst_ulong>(charp_vecs.size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -330,6 +330,22 @@ class LearnerImpl : public Learner {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DelAttr(const std::string& key) override {
|
||||
auto it = attributes_.find(key);
|
||||
if (it == attributes_.end()) return false;
|
||||
attributes_.erase(it);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetAttrNames() const override {
|
||||
std::vector<std::string> out;
|
||||
out.reserve(attributes_.size());
|
||||
for(auto &p: attributes_) {
|
||||
out.push_back(p.first);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
|
||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user