support for multiclass output prob

This commit is contained in:
tqchen 2014-08-01 11:21:17 -07:00
parent ca4b3b7541
commit 0d6b977395
4 changed files with 36 additions and 7 deletions

View File

@ -39,4 +39,11 @@ pred = bst.predict( xg_test );
print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) )) print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) ))
# do the same thing again, but output probabilities
param['objective'] = 'multi:softprob'
bst = xgb.train(param, xg_train, num_round, watchlist );
# get prediction, this is in 1D array, need reshape to (nclass, ndata)
yprob = bst.predict( xg_test ).reshape( 6, test_Y.shape[0] )
ylabel = np.argmax( yprob, axis=0)
print ('predicting, classification error=%f' % (sum( int(ylabel[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) ))

View File

@ -103,7 +103,7 @@ namespace xgboost{
*/ */
inline void InitTrainer(void){ inline void InitTrainer(void){
if( mparam.num_class != 0 ){ if( mparam.num_class != 0 ){
if( name_obj_ != "multi:softmax" ){ if( name_obj_ != "multi:softmax" && name_obj_ != "multi:softprob"){
name_obj_ = "multi:softmax"; name_obj_ = "multi:softmax";
printf("auto select objective=softmax to support multi-class classification\n" ); printf("auto select objective=softmax to support multi-class classification\n" );
} }
@ -206,7 +206,7 @@ namespace xgboost{
fprintf(fo, "[%d]", iter); fprintf(fo, "[%d]", iter);
for (size_t i = 0; i < evals.size(); ++i){ for (size_t i = 0; i < evals.size(); ++i){
this->PredictRaw(preds_, *evals[i]); this->PredictRaw(preds_, *evals[i]);
obj_->PredTransform(preds_); obj_->EvalTransform(preds_);
evaluator_.Eval(fo, evname[i].c_str(), preds_, evals[i]->info); evaluator_.Eval(fo, evname[i].c_str(), preds_, evals[i]->info);
} }
fprintf(fo, "\n"); fprintf(fo, "\n");

View File

@ -41,6 +41,11 @@ namespace xgboost{
* \param preds prediction values, saves to this vector as well * \param preds prediction values, saves to this vector as well
*/ */
virtual void PredTransform(std::vector<float> &preds){} virtual void PredTransform(std::vector<float> &preds){}
/*!
* \brief transform prediction values, this is only called when Eval is called, usually it redirect to PredTransform
* \param preds prediction values, saves to this vector as well
*/
virtual void EvalTransform(std::vector<float> &preds){ this->PredTransform(preds); }
}; };
}; };
@ -114,8 +119,8 @@ namespace xgboost{
if( !strcmp("reg:logistic", name ) ) return new RegressionObj( LossType::kLogisticNeglik ); if( !strcmp("reg:logistic", name ) ) return new RegressionObj( LossType::kLogisticNeglik );
if( !strcmp("binary:logistic", name ) ) return new RegressionObj( LossType::kLogisticClassify ); if( !strcmp("binary:logistic", name ) ) return new RegressionObj( LossType::kLogisticClassify );
if( !strcmp("binary:logitraw", name ) ) return new RegressionObj( LossType::kLogisticRaw ); if( !strcmp("binary:logitraw", name ) ) return new RegressionObj( LossType::kLogisticRaw );
if( !strcmp("multi:softmax", name ) ) return new SoftmaxMultiClassObj(); if( !strcmp("multi:softmax", name ) ) return new SoftmaxMultiClassObj(0);
if( !strcmp("rank:pairwise", name ) ) return new PairwiseRankObj(); if( !strcmp("multi:softprob", name ) ) return new SoftmaxMultiClassObj(1);
if( !strcmp("rank:pairwise", name ) ) return new PairwiseRankObj(); if( !strcmp("rank:pairwise", name ) ) return new PairwiseRankObj();
if( !strcmp("rank:softmax", name ) ) return new SoftmaxRankObj(); if( !strcmp("rank:softmax", name ) ) return new SoftmaxRankObj();
utils::Error("unknown objective function type"); utils::Error("unknown objective function type");

View File

@ -112,7 +112,7 @@ namespace xgboost{
// simple softmax multi-class classification // simple softmax multi-class classification
class SoftmaxMultiClassObj : public IObjFunction{ class SoftmaxMultiClassObj : public IObjFunction{
public: public:
SoftmaxMultiClassObj(void){ SoftmaxMultiClassObj(int output_prob):output_prob(output_prob){
nclass = 0; nclass = 0;
} }
virtual ~SoftmaxMultiClassObj(){} virtual ~SoftmaxMultiClassObj(){}
@ -156,6 +156,13 @@ namespace xgboost{
} }
} }
virtual void PredTransform(std::vector<float> &preds){ virtual void PredTransform(std::vector<float> &preds){
this->Transform(preds, output_prob);
}
virtual void EvalTransform(std::vector<float> &preds){
this->Transform(preds, 0);
}
private:
inline void Transform(std::vector<float> &preds, int prob){
utils::Assert( nclass != 0, "must set num_class to use softmax" ); utils::Assert( nclass != 0, "must set num_class to use softmax" );
utils::Assert( preds.size() % nclass == 0, "SoftmaxMultiClassObj: label size and pred size does not match" ); utils::Assert( preds.size() % nclass == 0, "SoftmaxMultiClassObj: label size and pred size does not match" );
const unsigned ndata = static_cast<unsigned>(preds.size()/nclass); const unsigned ndata = static_cast<unsigned>(preds.size()/nclass);
@ -168,16 +175,26 @@ namespace xgboost{
for( int k = 0; k < nclass; ++ k ){ for( int k = 0; k < nclass; ++ k ){
rec[k] = preds[j + k * ndata]; rec[k] = preds[j + k * ndata];
} }
preds[j] = FindMaxIndex( rec ); if( prob == 0 ){
preds[j] = FindMaxIndex( rec );
}else{
Softmax( rec );
for( int k = 0; k < nclass; ++ k ){
preds[j + k * ndata] = rec[k];
}
}
} }
} }
preds.resize( ndata ); if( prob == 0 ){
preds.resize( ndata );
}
} }
virtual const char* DefaultEvalMetric(void) { virtual const char* DefaultEvalMetric(void) {
return "merror"; return "merror";
} }
private: private:
int nclass; int nclass;
int output_prob;
}; };
}; };