add softmax

This commit is contained in:
tqchen 2014-04-30 22:11:26 -07:00
parent 38577d45b0
commit ec14d32756
2 changed files with 66 additions and 0 deletions

View File

@ -107,6 +107,7 @@ namespace xgboost{
IObjFunction* CreateObjFunction( const char *name ){
if( !strcmp("reg", name ) ) return new RegressionObj();
if( !strcmp("rank", name ) ) return new PairwiseRankObj();
if( !strcmp("softmax", name ) ) return new SoftmaxObj();
utils::Error("unknown objective function type");
return NULL;
}

View File

@ -47,6 +47,71 @@ namespace xgboost{
};
};
namespace regrank{
// simple softmax rak
class SoftmaxObj : public IObjFunction{
public:
SoftmaxObj(void){
}
virtual ~SoftmaxObj(){}
virtual void SetParam(const char *name, const char *val){
}
virtual void GetGradient(const std::vector<float>& preds,
const DMatrix::Info &info,
int iter,
std::vector<float> &grad,
std::vector<float> &hess ) {
grad.resize(preds.size()); hess.resize(preds.size());
const std::vector<unsigned> &gptr = info.group_ptr;
utils::Assert( gptr.size() != 0 && gptr.back() == preds.size(), "rank loss must have group file" );
const unsigned ngroup = static_cast<unsigned>( gptr.size() - 1 );
#pragma omp parallel
{
std::vector< float > rec;
#pragma for schedule(static)
for (unsigned k = 0; k < ngroup; ++k){
rec.clear();
int nhit = 0;
for(unsigned j = gptr[k]; j < gptr[k+1]; ++j ){
rec.push_back( preds[j] );
grad[j] = hess[j] = 0.0f;
nhit += info.labels[j];
}
Softmax( rec );
if( nhit == 1 ){
for(unsigned j = gptr[k]; j < gptr[k+1]; ++j ){
float p = rec[ j - gptr[k] ];
grad[j] = p - info.labels[j];
hess[j] = 2.0f * p * ( 1.0f - p );
}
}else{
utils::Assert( nhit == 0, "softmax does not allow multiple labels" );
}
}
}
}
virtual const char* DefaultEvalMetric(void) {
return "pre@1";
}
private:
inline static void Softmax( std::vector<float>& rec ){
float wmax = rec[0];
for( size_t i = 1; i < rec.size(); ++ i ){
wmax = std::max( rec[i], wmax );
}
double wsum = 0.0f;
for( size_t i = 0; i < rec.size(); ++ i ){
rec[i] = expf(rec[i]-wmax);
wsum += rec[i];
}
for( size_t i = 0; i < rec.size(); ++ i ){
rec[i] /= wsum;
}
}
};
};
namespace regrank{
// simple pairwise rank
class PairwiseRankObj : public IObjFunction{