From 9df8bb13973b6f8cd37902dd4c69ff3129e7f647 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 17 Aug 2014 19:16:17 -0700
Subject: [PATCH] check in softmax multiclass

---
 demo/multiclass_classification/train.py |  7 +-
 python/xgboost_wrapper.cpp              | 38 +++++------
 src/learner/learner-inl.hpp             |  1 +
 src/learner/objective-inl.hpp           | 90 +++++++++++++++++++++++++
 src/learner/objective.h                 |  2 +
 5 files changed, 116 insertions(+), 22 deletions(-)

diff --git a/demo/multiclass_classification/train.py b/demo/multiclass_classification/train.py
index fabc43c45..69214a6c8 100755
--- a/demo/multiclass_classification/train.py
+++ b/demo/multiclass_classification/train.py
@@ -42,8 +42,9 @@ print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_Y[i] f
 # do the same thing again, but output probabilities
 param['objective'] = 'multi:softprob'
 bst = xgb.train(param, xg_train, num_round, watchlist );
-# get prediction, this is in 1D array, need reshape to (nclass, ndata)
-yprob = bst.predict( xg_test ).reshape( 6, test_Y.shape[0] )
-ylabel = np.argmax( yprob, axis=0)
+# Note: this convention has been changed since xgboost-unity
+# get prediction, this is in 1D array, need reshape to (ndata, nclass)
+yprob = bst.predict( xg_test ).reshape( test_Y.shape[0], 6 )
+ylabel = np.argmax(yprob, axis=1)
 print ('predicting, classification error=%f' % (sum( int(ylabel[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) ))
 
diff --git a/python/xgboost_wrapper.cpp b/python/xgboost_wrapper.cpp
index 478d74936..8b89d1d25 100644
--- a/python/xgboost_wrapper.cpp
+++ b/python/xgboost_wrapper.cpp
@@ -32,7 +32,7 @@ class Booster: public learner::BoostLearner {
   inline void BoostOneIter(const DataMatrix &train,
                            float *grad, float *hess, size_t len) {
     this->gpair_.resize(len);
-    const unsigned ndata = static_cast<unsigned>(len); 
+    const unsigned ndata = static_cast<unsigned>(len);
     #pragma omp parallel for schedule(static)
     for (unsigned j = 0; j < ndata; ++j) {
       gpair_[j] = bst_gpair(grad[j], hess[j]);
@@ -42,7 +42,7 @@
   inline void CheckInitModel(void) {
     if (!init_model) {
       this->InitModel(); init_model = true;
-    } 
+    }
   }
   inline void LoadModel(const char *fname) {
     learner::BoostLearner::LoadModel(fname);
@@ -50,7 +50,7 @@
   }
   inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, size_t *len) {
     model_dump = this->DumpModel(fmap, with_stats);
-    model_dump_cptr.resize(model_dump.size()); 
+    model_dump_cptr.resize(model_dump.size());
     for (size_t i = 0; i < model_dump.size(); ++i) {
       model_dump_cptr[i] = model_dump[i].c_str();
     }
@@ -82,11 +82,11 @@ extern "C"{
                                    size_t nindptr,
                                    size_t nelem) {
     DMatrixSimple *p_mat = new DMatrixSimple();
-    DMatrixSimple &mat = *p_mat; 
+    DMatrixSimple &mat = *p_mat;
     mat.row_ptr_.resize(nindptr);
     memcpy(&mat.row_ptr_[0], indptr, sizeof(size_t)*nindptr);
     mat.row_data_.resize(nelem);
-    for (size_t i = 0; i < nelem; ++ i) {
+    for (size_t i = 0; i < nelem; ++i) {
       mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]);
       mat.info.num_col = std::max(mat.info.num_col,
                                   static_cast<size_t>(indices[i]+1));
@@ -133,15 +133,15 @@ extern "C"{
     ret.info.num_row = len;
     ret.info.num_col = src.info.num_col;
 
-    utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator(); 
+    utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator();
     iter->BeforeFirst();
     utils::Assert(iter->Next(), "slice");
     const SparseBatch &batch = iter->Value();
-    for(size_t i = 0; i < len; ++i) {
+    for (size_t i = 0; i < len; ++i) {
       const int ridx = idxset[i];
      SparseBatch::Inst inst = batch[ridx];
       utils::Check(ridx < batch.size, "slice index exceed number of rows");
-      ret.row_data_.resize(ret.row_data_.size() + inst.length); 
+      ret.row_data_.resize(ret.row_data_.size() + inst.length);
       memcpy(&ret.row_data_[ret.row_ptr_.back()], inst.data,
              sizeof(SparseBatch::Entry) * inst.length);
       ret.row_ptr_.push_back(ret.row_ptr_.back() + inst.length);
@@ -160,9 +160,9 @@ extern "C"{
   void XGDMatrixFree(void *handle) {
     delete static_cast<DataMatrix*>(handle);
   }
-  void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) { 
+  void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) {
     SaveDataMatrix(*static_cast<DataMatrix*>(handle), fname, silent);
-  } 
+  }
   void XGDMatrixSetLabel(void *handle, const float *label, size_t len) {
     DataMatrix *pmat = static_cast<DataMatrix*>(handle);
     pmat->info.labels.resize(len);
@@ -173,11 +173,11 @@ extern "C"{
     pmat->info.weights.resize(len);
     memcpy(&(pmat->info).weights[0], weight, sizeof(float) * len);
   }
-  void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len){
+  void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
    DataMatrix *pmat = static_cast<DataMatrix*>(handle);
     pmat->info.group_ptr.resize(len + 1);
     pmat->info.group_ptr[0] = 0;
-    for (size_t i = 0; i < len; ++ i) {
+    for (size_t i = 0; i < len; ++i) {
       pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i]+group[i];
     }
   }
@@ -217,7 +217,7 @@ extern "C"{
     bst->CheckInit(dtr);
     bst->UpdateOneIter(iter, *dtr);
   }
-  void XGBoosterBoostOneIter(void *handle, void *dtrain, 
+  void XGBoosterBoostOneIter(void *handle, void *dtrain,
                              float *grad, float *hess, size_t len) {
     Booster *bst = static_cast<Booster*>(handle);
     DataMatrix *dtr = static_cast<DataMatrix*>(dtrain);
@@ -225,8 +225,9 @@ extern "C"{
     bst->CheckInit(dtr);
     bst->BoostOneIter(*dtr, grad, hess, len);
   }
-  const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[], const char *evnames[], size_t len) {
-    Booster *bst = static_cast<Booster*>(handle); 
+  const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
+                                   const char *evnames[], size_t len) {
+    Booster *bst = static_cast<Booster*>(handle);
     std::vector<std::string> names;
     std::vector<const DataMatrix*> mats;
     for (size_t i = 0; i < len; ++i) {
@@ -243,13 +244,12 @@ extern "C"{
   void XGBoosterLoadModel(void *handle, const char *fname) {
     static_cast<Booster*>(handle)->LoadModel(fname);
   }
-  void XGBoosterSaveModel( const void *handle, const char *fname) {
+  void XGBoosterSaveModel(const void *handle, const char *fname) {
     static_cast<const Booster*>(handle)->SaveModel(fname);
   }
   const char** XGBoosterDumpModel(void *handle, const char *fmap, size_t *len){
-    using namespace xgboost::utils;
-    FeatMap featmap;
-    if(strlen(fmap) != 0) {
+    utils::FeatMap featmap;
+    if (strlen(fmap) != 0) {
       featmap.LoadText(fmap);
     }
     return static_cast<Booster*>(handle)->GetModelDump(featmap, false, len);
diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp
index d7ad3f71d..a183e904a 100644
--- a/src/learner/learner-inl.hpp
+++ b/src/learner/learner-inl.hpp
@@ -79,6 +79,7 @@ class BoostLearner {
     if (!strcmp(name, "silent")) silent = atoi(val);
     if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
     if (!strcmp("seed", name)) random::Seed(atoi(val));
+    if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
     if (gbm_ == NULL) {
       if (!strcmp(name, "objective")) name_obj_ = val;
       if (!strcmp(name, "booster")) name_gbm_ = val;
diff --git a/src/learner/objective-inl.hpp b/src/learner/objective-inl.hpp
index d5cc97fcf..5f23e3b00 100644
--- a/src/learner/objective-inl.hpp
+++ b/src/learner/objective-inl.hpp
@@ -7,7 +7,9 @@
  */
 #include <vector>
 #include <cmath>
+#include "../data.h"
"../data.h" #include "./objective.h" +#include "./helper_utils.h" namespace xgboost { namespace learner { @@ -133,6 +135,94 @@ class RegLossObj : public IObjFunction{ float scale_pos_weight; LossType loss; }; + +// softmax multi-class classification +class SoftmaxMultiClassObj : public IObjFunction { + public: + explicit SoftmaxMultiClassObj(int output_prob) + : output_prob(output_prob) { + nclass = 0; + } + virtual ~SoftmaxMultiClassObj(void) {} + virtual void SetParam(const char *name, const char *val) { + if (!strcmp( "num_class", name )) nclass = atoi(val); + } + virtual void GetGradient(const std::vector& preds, + const MetaInfo &info, + int iter, + std::vector *out_gpair) { + utils::Check(nclass != 0, "must set num_class to use softmax"); + utils::Check(preds.size() == static_cast(nclass) * info.labels.size(), + "SoftmaxMultiClassObj: label size and pred size does not match"); + std::vector &gpair = *out_gpair; + gpair.resize(preds.size()); + const unsigned ndata = static_cast(info.labels.size()); + #pragma omp parallel + { + std::vector rec(nclass); + #pragma omp for schedule(static) + for (unsigned j = 0; j < ndata; ++j) { + for (int k = 0; k < nclass; ++k) { + rec[k] = preds[j * nclass + k]; + } + Softmax(&rec); + unsigned label = static_cast(info.labels[j]); + utils::Check(label < nclass, "SoftmaxMultiClassObj: label exceed num_class"); + const float wt = info.GetWeight(j); + for (int k = 0; k < nclass; ++k) { + float p = rec[k]; + const float h = 2.0f * p * (1.0f - p) * wt; + if (label == k) { + gpair[j * nclass + k] = bst_gpair((p - 1.0f) * wt, h); + } else { + gpair[j * nclass + k] = bst_gpair(p* wt, h); + } + } + } + } + } + virtual void PredTransform(std::vector *io_preds) { + this->Transform(io_preds, output_prob); + } + virtual void EvalTransform(std::vector *io_preds) { + this->Transform(io_preds, 0); + } + virtual const char* DefaultEvalMetric(void) { + return "merror"; + } + + private: + inline void Transform(std::vector *io_preds, int prob) { + utils::Check(nclass != 0, "must set num_class to use softmax"); + std::vector &preds = *io_preds; + const unsigned ndata = static_cast(preds.size()/nclass); + #pragma omp parallel + { + std::vector rec(nclass); + #pragma omp for schedule(static) + for (unsigned j = 0; j < ndata; ++j) { + for (int k = 0; k < nclass; ++k) { + rec[k] = preds[j * nclass + k]; + } + if (prob == 0) { + preds[j] = FindMaxIndex(rec); + } else { + Softmax(&rec); + for (int k = 0; k < nclass; ++k) { + preds[j * nclass + k] = rec[k]; + } + } + } + } + if (prob == 0) { + preds.resize(ndata); + } + } + // data field + int nclass; + int output_prob; +}; + } // namespace learner } // namespace xgboost #endif // XGBOOST_LEARNER_OBJECTIVE_INL_HPP_ diff --git a/src/learner/objective.h b/src/learner/objective.h index e38f7cfe4..bca035854 100644 --- a/src/learner/objective.h +++ b/src/learner/objective.h @@ -71,6 +71,8 @@ inline IObjFunction* CreateObjFunction(const char *name) { if (!strcmp("reg:logistic", name)) return new RegLossObj(LossType::kLogisticNeglik); if (!strcmp("binary:logistic", name)) return new RegLossObj(LossType::kLogisticClassify); if (!strcmp("binary:logitraw", name)) return new RegLossObj(LossType::kLogisticRaw); + if (!strcmp("multi:softmax", name)) return new SoftmaxMultiClassObj(0); + if (!strcmp("multi:softprob", name)) return new SoftmaxMultiClassObj(1); utils::Error("unknown objective function type: %s", name); return NULL; }