check in softmax multiclass

tqchen 2014-08-17 19:16:17 -07:00
parent e77df13815
commit 9df8bb1397
5 changed files with 116 additions and 22 deletions

@@ -42,8 +42,9 @@ print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) ))
 # do the same thing again, but output probabilities
 param['objective'] = 'multi:softprob'
 bst = xgb.train(param, xg_train, num_round, watchlist );
-# get prediction, this is in 1D array, need reshape to (nclass, ndata)
-yprob = bst.predict( xg_test ).reshape( 6, test_Y.shape[0] )
-ylabel = np.argmax( yprob, axis=0)
+# Note: this convention has been changed since xgboost-unity
+# get prediction, this is in 1D array, need reshape to (ndata, nclass)
+yprob = bst.predict( xg_test ).reshape( test_Y.shape[0], 6 )
+ylabel = np.argmax(yprob, axis=1)
 print ('predicting, classification error=%f' % (sum( int(ylabel[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) ))
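The demo change above is the user-visible half of the commit: multi:softprob output is now laid out example-major, (ndata, nclass), instead of class-major, (nclass, ndata), so both the reshape and the argmax axis flip. A minimal NumPy sketch of the two layouts, assuming the demo's 6 classes:

import numpy as np

ndata, nclass = 4, 6
flat = np.arange(ndata * nclass, dtype=np.float32)  # stand-in for bst.predict(xg_test)

old = flat.reshape(nclass, ndata)   # pre-unity: row k holds class k for every example
label_old = np.argmax(old, axis=0)

new = flat.reshape(ndata, nclass)   # xgboost-unity: row j holds all class scores of example j
label_new = np.argmax(new, axis=1)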

@@ -225,7 +225,8 @@ extern "C"{
    bst->CheckInit(dtr);
    bst->BoostOneIter(*dtr, grad, hess, len);
  }
-  const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[], const char *evnames[], size_t len) {
+  const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
+                                   const char *evnames[], size_t len) {
    Booster *bst = static_cast<Booster*>(handle);
    std::vector<std::string> names;
    std::vector<const DataMatrix*> mats;
@@ -247,8 +248,7 @@ extern "C"{
    static_cast<const Booster*>(handle)->SaveModel(fname);
  }
  const char** XGBoosterDumpModel(void *handle, const char *fmap, size_t *len){
-    using namespace xgboost::utils;
-    FeatMap featmap;
+    utils::FeatMap featmap;
    if (strlen(fmap) != 0) {
      featmap.LoadText(fmap);
    }

@@ -79,6 +79,7 @@ class BoostLearner {
     if (!strcmp(name, "silent")) silent = atoi(val);
     if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
     if (!strcmp("seed", name)) random::Seed(atoi(val));
+    if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
     if (gbm_ == NULL) {
       if (!strcmp(name, "objective")) name_obj_ = val;
       if (!strcmp(name, "booster")) name_gbm_ = val;

@@ -7,7 +7,9 @@
  */
 #include <vector>
 #include <cmath>
+#include "../data.h"
 #include "./objective.h"
+#include "./helper_utils.h"

 namespace xgboost {
 namespace learner {
@@ -133,6 +135,94 @@ class RegLossObj : public IObjFunction{
   float scale_pos_weight;
   LossType loss;
 };
+// softmax multi-class classification
+class SoftmaxMultiClassObj : public IObjFunction {
+ public:
+  explicit SoftmaxMultiClassObj(int output_prob)
+      : output_prob(output_prob) {
+    nclass = 0;
+  }
+  virtual ~SoftmaxMultiClassObj(void) {}
+  virtual void SetParam(const char *name, const char *val) {
+    if (!strcmp("num_class", name)) nclass = atoi(val);
+  }
+  virtual void GetGradient(const std::vector<float> &preds,
+                           const MetaInfo &info,
+                           int iter,
+                           std::vector<bst_gpair> *out_gpair) {
+    utils::Check(nclass != 0, "must set num_class to use softmax");
+    utils::Check(preds.size() == static_cast<size_t>(nclass) * info.labels.size(),
+                 "SoftmaxMultiClassObj: label and pred size do not match");
+    std::vector<bst_gpair> &gpair = *out_gpair;
+    gpair.resize(preds.size());
+    const unsigned ndata = static_cast<unsigned>(info.labels.size());
+    #pragma omp parallel
+    {
+      std::vector<float> rec(nclass);
+      #pragma omp for schedule(static)
+      for (unsigned j = 0; j < ndata; ++j) {
+        for (int k = 0; k < nclass; ++k) {
+          rec[k] = preds[j * nclass + k];
+        }
+        Softmax(&rec);
+        const unsigned label = static_cast<unsigned>(info.labels[j]);
+        utils::Check(label < static_cast<unsigned>(nclass),
+                     "SoftmaxMultiClassObj: label exceed num_class");
+        const float wt = info.GetWeight(j);
+        for (int k = 0; k < nclass; ++k) {
+          const float p = rec[k];
+          // diagonal approximation of the softmax hessian, scaled by 2
+          const float h = 2.0f * p * (1.0f - p) * wt;
+          if (static_cast<int>(label) == k) {
+            gpair[j * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
+          } else {
+            gpair[j * nclass + k] = bst_gpair(p * wt, h);
+          }
+        }
+      }
+    }
+  }
+  virtual void PredTransform(std::vector<float> *io_preds) {
+    this->Transform(io_preds, output_prob);
+  }
+  virtual void EvalTransform(std::vector<float> *io_preds) {
+    this->Transform(io_preds, 0);
+  }
+  virtual const char* DefaultEvalMetric(void) {
+    return "merror";
+  }

+ private:
+  // prob == 0: reduce each row to its argmax label;
+  // otherwise turn each row of margins into probabilities in place
+  inline void Transform(std::vector<float> *io_preds, int prob) {
+    utils::Check(nclass != 0, "must set num_class to use softmax");
+    std::vector<float> &preds = *io_preds;
+    const unsigned ndata = static_cast<unsigned>(preds.size() / nclass);
+    #pragma omp parallel
+    {
+      std::vector<float> rec(nclass);
+      #pragma omp for schedule(static)
+      for (unsigned j = 0; j < ndata; ++j) {
+        for (int k = 0; k < nclass; ++k) {
+          rec[k] = preds[j * nclass + k];
+        }
+        if (prob == 0) {
+          // stash the argmax at the head of its own row; writing it
+          // straight to preds[j] could race with reads of other rows
+          preds[j * nclass] = static_cast<float>(FindMaxIndex(rec));
+        } else {
+          Softmax(&rec);
+          for (int k = 0; k < nclass; ++k) {
+            preds[j * nclass + k] = rec[k];
+          }
+        }
+      }
+    }
+    if (prob == 0) {
+      for (unsigned j = 0; j < ndata; ++j) {
+        preds[j] = preds[j * nclass];
+      }
+      preds.resize(ndata);
+    }
+  }
+  // data fields
+  int nclass;
+  int output_prob;
+};

 }  // namespace learner
 }  // namespace xgboost
 #endif  // XGBOOST_LEARNER_OBJECTIVE_INL_HPP_
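For reference, Softmax and FindMaxIndex come from the newly included helper_utils.h: Softmax rescales a vector of margins into probabilities in place, FindMaxIndex returns the position of the largest entry. The gradient and transform logic above boils down to the following NumPy sketch (the Python function names are illustrative, not part of xgboost):

import numpy as np

def softmax(x):
    # row-wise softmax, what the Softmax helper amounts to
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multiclass_gradient(preds, labels, weights, nclass):
    # mirrors GetGradient: preds is flat and row-major, length ndata * nclass
    p = softmax(preds.reshape(-1, nclass))
    onehot = np.eye(nclass, dtype=p.dtype)[labels.astype(int)]
    w = weights[:, None]
    grad = (p - onehot) * w         # (p - 1) * wt for the true class, p * wt otherwise
    hess = 2.0 * p * (1.0 - p) * w  # same diagonal approximation, including the factor 2
    return grad.ravel(), hess.ravel()

def multiclass_transform(preds, nclass, output_prob):
    # mirrors Transform: argmax labels for multi:softmax, probabilities for multi:softprob
    mat = preds.reshape(-1, nclass)
    if output_prob:
        return softmax(mat).ravel()
    return np.argmax(mat, axis=1).astype(preds.dtype)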

@@ -71,6 +71,8 @@ inline IObjFunction* CreateObjFunction(const char *name) {
   if (!strcmp("reg:logistic", name)) return new RegLossObj(LossType::kLogisticNeglik);
   if (!strcmp("binary:logistic", name)) return new RegLossObj(LossType::kLogisticClassify);
   if (!strcmp("binary:logitraw", name)) return new RegLossObj(LossType::kLogisticRaw);
+  if (!strcmp("multi:softmax", name)) return new SoftmaxMultiClassObj(0);
+  if (!strcmp("multi:softprob", name)) return new SoftmaxMultiClassObj(1);
   utils::Error("unknown objective function type: %s", name);
   return NULL;
 }
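With these two dispatch lines the new objectives are selectable by name. A minimal usage sketch through the Python wrapper, reusing the demo's variables (xg_train, xg_test, num_round, watchlist, test_Y) and its 6 classes; the tree parameters are illustrative:

param = {'max_depth': 6, 'eta': 0.1, 'silent': 1,
         'objective': 'multi:softmax',  # hard labels, the FindMaxIndex path
         'num_class': 6}                # rewritten to num_output_group by the learner
bst = xgb.train(param, xg_train, num_round, watchlist)
pred = bst.predict(xg_test)             # one label per example

param['objective'] = 'multi:softprob'   # per-class probabilities, the Softmax path
bst = xgb.train(param, xg_train, num_round, watchlist)
yprob = bst.predict(xg_test).reshape(test_Y.shape[0], 6)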