now support distributed evaluation
This commit is contained in:
@@ -11,6 +11,8 @@
|
||||
#include <cmath>
|
||||
#include <climits>
|
||||
#include <algorithm>
|
||||
// rabit library for synchronization
|
||||
#include <rabit.h>
|
||||
#include "./evaluation.h"
|
||||
#include "./helper_utils.h"
|
||||
|
||||
@@ -23,7 +25,8 @@ namespace learner {
|
||||
template<typename Derived>
|
||||
struct EvalEWiseBase : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
@@ -37,7 +40,11 @@ struct EvalEWiseBase : public IEvaluator {
|
||||
sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
return Derived::GetFinal(sum, wsum);
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
@@ -113,7 +120,9 @@ struct EvalCTest: public IEvaluator {
|
||||
return name_.c_str();
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(!distributed, "metric %s do not support distributed evaluation", name_.c_str());
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
size_t ngroup = preds.size() / info.labels.size() - 1;
|
||||
@@ -150,7 +159,9 @@ struct EvalAMS : public IEvaluator {
|
||||
utils::Check(std::sscanf(name, "ams@%f", &ratio_) == 1, "invalid ams format");
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(!distributed, "metric AMS do not support distributed evaluation");
|
||||
using namespace std;
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
|
||||
@@ -212,7 +223,9 @@ struct EvalPrecisionRatio : public IEvaluator{
|
||||
}
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(!distributed, "metric %s do not support distributed evaluation", Name());
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Assert(preds.size() % info.labels.size() == 0,
|
||||
"label size predict size not match");
|
||||
@@ -252,7 +265,8 @@ struct EvalPrecisionRatio : public IEvaluator{
|
||||
/*! \brief Area under curve, for both classification and rank */
|
||||
struct EvalAuc : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label size predict size not match");
|
||||
@@ -299,8 +313,14 @@ struct EvalAuc : public IEvaluator {
|
||||
sum_auc += sum_pospair / (sum_npos*sum_nneg);
|
||||
}
|
||||
}
|
||||
// return average AUC over list
|
||||
return static_cast<float>(sum_auc) / ngroup;
|
||||
if (distributed) {
|
||||
float dat[2]; dat[0] = sum_auc; dat[1] = ngroup;
|
||||
// approximate the global AUC as (sum of per-group AUC over all nodes) / (total number of groups)
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<float>(sum_auc) / ngroup;
|
||||
}
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return "auc";
|
||||
@@ -311,7 +331,8 @@ struct EvalAuc : public IEvaluator {
|
||||
struct EvalRankList : public IEvaluator {
|
||||
public:
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
"label size predict size not match");
|
||||
// quick consistency when group is not available
|
||||
@@ -336,7 +357,14 @@ struct EvalRankList : public IEvaluator {
|
||||
sum_metric += this->EvalMetric(rec);
|
||||
}
|
||||
}
|
||||
return static_cast<float>(sum_metric) / ngroup;
|
||||
if (distributed) {
|
||||
float dat[2]; dat[0] = sum_metric; dat[1] = ngroup;
|
||||
// approximate the global ranking metric as (sum of per-group metric over all nodes) / (total number of groups)
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<float>(sum_metric) / ngroup;
|
||||
}
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
|
||||
@@ -19,9 +19,13 @@ struct IEvaluator{
|
||||
* \brief evaluate a specific metric
|
||||
* \param preds prediction
|
||||
* \param info information, including label etc.
|
||||
* \param distributed whether a call to Allreduce is needed to gather
|
||||
* the average statistics across all nodes;
|
||||
* this is only supported by some metrics
|
||||
*/
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const = 0;
|
||||
const MetaInfo &info,
|
||||
bool distributed = false) const = 0;
|
||||
/*! \return name of metric */
|
||||
virtual const char *Name(void) const = 0;
|
||||
/*! \brief virtual destructor */
|
||||
@@ -70,10 +74,11 @@ class EvalSet{
|
||||
}
|
||||
inline std::string Eval(const char *evname,
|
||||
const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const MetaInfo &info,
|
||||
bool distributed = false) {
|
||||
std::string result = "";
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
float res = evals_[i]->Eval(preds, info);
|
||||
float res = evals_[i]->Eval(preds, info, distributed);
|
||||
char tmp[1024];
|
||||
utils::SPrintf(tmp, sizeof(tmp), "\t%s-%s:%f", evname, evals_[i]->Name(), res);
|
||||
result += tmp;
|
||||
|
||||
@@ -287,7 +287,7 @@ class BoostLearner : public rabit::ISerializable {
|
||||
for (size_t i = 0; i < evals.size(); ++i) {
|
||||
this->PredictRaw(*evals[i], &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
res += evaluator_.Eval(evname[i].c_str(), preds_, evals[i]->info);
|
||||
res += evaluator_.Eval(evname[i].c_str(), preds_, evals[i]->info, distributed_mode == 2);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user