add auc evaluation metric
This commit is contained in:
parent
88787b8573
commit
7487c2f668
@ -11,6 +11,7 @@
|
||||
#include <algorithm>
|
||||
#include "../utils/xgboost_utils.h"
|
||||
#include "../utils/xgboost_omp.h"
|
||||
#include "../utils/xgboost_random.h"
|
||||
|
||||
namespace xgboost{
|
||||
namespace regression{
|
||||
@ -67,6 +68,38 @@ namespace xgboost{
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Area under curve */
|
||||
struct EvalAuc : public IEvaluator{
|
||||
inline static bool CmpFirst( const std::pair<float,float> &a, const std::pair<float,float> &b ){
|
||||
return a.first > b.first;
|
||||
}
|
||||
virtual float Eval( const std::vector<float> &preds,
|
||||
const std::vector<float> &labels ) const{
|
||||
const unsigned ndata = static_cast<unsigned>( preds.size() );
|
||||
std::vector< std::pair<float, float> > rec;
|
||||
for( unsigned i = 0; i < ndata; ++ i ){
|
||||
rec.push_back( std::make_pair( preds[i], labels[i]) );
|
||||
}
|
||||
random::Shuffle( rec );
|
||||
std::sort( rec.begin(), rec.end(), CmpFirst );
|
||||
|
||||
long npos = 0, nhit = 0;
|
||||
for( unsigned i = 0; i < ndata; ++ i ){
|
||||
if( rec[i].second > 0.5f ) {
|
||||
++ npos;
|
||||
}else{
|
||||
// this is the number of correct pairs
|
||||
nhit += npos;
|
||||
}
|
||||
}
|
||||
long nneg = ndata - npos;
|
||||
utils::Assert( nneg > 0, "the dataset only contains pos samples" );
|
||||
return static_cast<float>(nhit) / nneg / npos;
|
||||
}
|
||||
virtual const char *Name( void ) const{
|
||||
return "auc";
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Error */
|
||||
struct EvalLogLoss : public IEvaluator{
|
||||
@ -96,6 +129,7 @@ namespace xgboost{
|
||||
if (!strcmp(name, "rmse")) evals_.push_back(&rmse_);
|
||||
if (!strcmp(name, "error")) evals_.push_back(&error_);
|
||||
if (!strcmp(name, "logloss")) evals_.push_back(&logloss_);
|
||||
if (!strcmp( name, "auc")) evals_.push_back( &auc_ );
|
||||
}
|
||||
inline void Init(void){
|
||||
std::sort(evals_.begin(), evals_.end());
|
||||
@ -112,6 +146,7 @@ namespace xgboost{
|
||||
private:
|
||||
EvalRMSE rmse_;
|
||||
EvalError error_;
|
||||
EvalAuc auc_;
|
||||
EvalLogLoss logloss_;
|
||||
std::vector<const IEvaluator*> evals_;
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user