check in softmax multiclass

parent: e77df13815
commit: 9df8bb1397
@@ -42,8 +42,9 @@ print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_Y[i] f
 # do the same thing again, but output probabilities
 param['objective'] = 'multi:softprob'
 bst = xgb.train(param, xg_train, num_round, watchlist );
-# get prediction, this is in 1D array, need reshape to (nclass, ndata)
-yprob = bst.predict( xg_test ).reshape( 6, test_Y.shape[0] )
-ylabel = np.argmax( yprob, axis=0)
+# Note: this convention has been changed since xgboost-unity
+# get prediction, this is in 1D array, need reshape to (ndata, nclass)
+yprob = bst.predict( xg_test ).reshape( test_Y.shape[0], 6 )
+ylabel = np.argmax(yprob, axis=1)
 
 print ('predicting, classification error=%f' % (sum( int(ylabel[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) ))
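The demo change above tracks the new prediction layout in xgboost-unity: `multi:softprob` now returns a flat array in row-major `(ndata, nclass)` order instead of the old `(nclass, ndata)`, so both the reshape and the argmax axis flip together. A minimal numpy sketch of why decoding the same flat buffer under the wrong layout gives different labels (shapes are illustrative, with `nclass=6` as in the demo):

```python
import numpy as np

nclass, ndata = 6, 4
# stand-in for the flat array returned by bst.predict(xg_test)
flat = np.random.default_rng(0).random(nclass * ndata).astype(np.float32)

# old convention: class-major (nclass, ndata) layout, argmax over axis 0
labels_old = np.argmax(flat.reshape(nclass, ndata), axis=0)

# new convention (xgboost-unity): row-major (ndata, nclass), argmax over axis 1
labels_new = np.argmax(flat.reshape(ndata, nclass), axis=1)

# the same buffer decodes differently under the two layouts, which is
# why the demo changes the reshape and the argmax axis in the same commit
print(labels_old, labels_new)
```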
@@ -86,7 +86,7 @@ extern "C"{
     mat.row_ptr_.resize(nindptr);
     memcpy(&mat.row_ptr_[0], indptr, sizeof(size_t)*nindptr);
     mat.row_data_.resize(nelem);
-    for (size_t i = 0; i < nelem; ++ i) {
+    for (size_t i = 0; i < nelem; ++i) {
       mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]);
       mat.info.num_col = std::max(mat.info.num_col,
                                   static_cast<size_t>(indices[i]+1));
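For context, this hunk sits in the CSR ingestion path of the C wrapper: it copies `indptr`/`indices`/`data` into the internal row storage and infers `num_col` as the largest feature index plus one. A hedged Python sketch of that inference (scipy is used only to cross-check the resulting shape; it is not what the wrapper calls):

```python
import numpy as np
from scipy.sparse import csr_matrix

indptr = np.array([0, 2, 3, 5])       # 3 rows
indices = np.array([0, 4, 1, 0, 2])   # feature index of each stored entry
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# mirrors the C++ loop: num_col = max(num_col, indices[i] + 1)
num_col = 0
for idx in indices:
    num_col = max(num_col, int(idx) + 1)

mat = csr_matrix((data, indices, indptr), shape=(len(indptr) - 1, num_col))
print(num_col, mat.shape)  # 5 (3, 5)
```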
@@ -137,7 +137,7 @@ extern "C"{
     iter->BeforeFirst();
     utils::Assert(iter->Next(), "slice");
     const SparseBatch &batch = iter->Value();
-    for(size_t i = 0; i < len; ++i) {
+    for (size_t i = 0; i < len; ++i) {
       const int ridx = idxset[i];
       SparseBatch::Inst inst = batch[ridx];
       utils::Check(ridx < batch.size, "slice index exceed number of rows");
@@ -173,11 +173,11 @@ extern "C"{
       pmat->info.weights.resize(len);
       memcpy(&(pmat->info).weights[0], weight, sizeof(float) * len);
   }
-  void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len){
+  void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
     DataMatrix *pmat = static_cast<DataMatrix*>(handle);
     pmat->info.group_ptr.resize(len + 1);
     pmat->info.group_ptr[0] = 0;
-    for (size_t i = 0; i < len; ++ i) {
+    for (size_t i = 0; i < len; ++i) {
      pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i]+group[i];
    }
  }
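`XGDMatrixSetGroup` converts per-group counts into cumulative boundaries, `group_ptr[i+1] = group_ptr[i] + group[i]`, so the rows of group `i` are the half-open range `[group_ptr[i], group_ptr[i+1])`, the usual CSR-style layout used by ranking objectives. The same prefix sum in Python:

```python
group = [3, 2, 4]            # e.g. three queries with 3, 2 and 4 documents

group_ptr = [0]
for g in group:
    group_ptr.append(group_ptr[-1] + g)

# rows [group_ptr[i], group_ptr[i+1]) belong to group i
print(group_ptr)             # [0, 3, 5, 9]
```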
@@ -225,7 +225,8 @@ extern "C"{
     bst->CheckInit(dtr);
     bst->BoostOneIter(*dtr, grad, hess, len);
   }
-  const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[], const char *evnames[], size_t len) {
+  const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
+                                   const char *evnames[], size_t len) {
     Booster *bst = static_cast<Booster*>(handle);
     std::vector<std::string> names;
     std::vector<const DataMatrix*> mats;
@@ -243,13 +244,12 @@ extern "C"{
   void XGBoosterLoadModel(void *handle, const char *fname) {
     static_cast<Booster*>(handle)->LoadModel(fname);
   }
-  void XGBoosterSaveModel( const void *handle, const char *fname) {
+  void XGBoosterSaveModel(const void *handle, const char *fname) {
     static_cast<const Booster*>(handle)->SaveModel(fname);
   }
   const char** XGBoosterDumpModel(void *handle, const char *fmap, size_t *len){
-    using namespace xgboost::utils;
-    FeatMap featmap;
-    if(strlen(fmap) != 0) {
+    utils::FeatMap featmap;
+    if (strlen(fmap) != 0) {
       featmap.LoadText(fmap);
     }
     return static_cast<Booster*>(handle)->GetModelDump(featmap, false, len);
@@ -79,6 +79,7 @@ class BoostLearner {
     if (!strcmp(name, "silent")) silent = atoi(val);
     if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
     if (!strcmp("seed", name)) random::Seed(atoi(val));
+    if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
     if (gbm_ == NULL) {
       if (!strcmp(name, "objective")) name_obj_ = val;
       if (!strcmp(name, "booster")) name_gbm_ = val;
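The added line makes `num_class` an alias for the booster parameter `num_output_group`, so the user sets a single parameter and the gradient booster emits `num_class` scores per row (one tree per class per round). A sketch of the resulting user-facing configuration (values are illustrative, not from the commit):

```python
# the user only sets num_class; SetParam forwards it to num_output_group
param = {
    'objective': 'multi:softmax',  # or 'multi:softprob' for probabilities
    'num_class': 6,                # forwarded internally to num_output_group
    'eta': 0.1,                    # illustrative values
    'max_depth': 6,
}
```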
@@ -7,7 +7,9 @@
  */
 #include <vector>
 #include <cmath>
+#include "../data.h"
 #include "./objective.h"
+#include "./helper_utils.h"
 
 namespace xgboost {
 namespace learner {
@@ -133,6 +135,94 @@ class RegLossObj : public IObjFunction{
   float scale_pos_weight;
   LossType loss;
 };
+
+// softmax multi-class classification
+class SoftmaxMultiClassObj : public IObjFunction {
+ public:
+  explicit SoftmaxMultiClassObj(int output_prob)
+      : output_prob(output_prob) {
+    nclass = 0;
+  }
+  virtual ~SoftmaxMultiClassObj(void) {}
+  virtual void SetParam(const char *name, const char *val) {
+    if (!strcmp("num_class", name)) nclass = atoi(val);
+  }
+  virtual void GetGradient(const std::vector<float>& preds,
+                           const MetaInfo &info,
+                           int iter,
+                           std::vector<bst_gpair> *out_gpair) {
+    utils::Check(nclass != 0, "must set num_class to use softmax");
+    utils::Check(preds.size() == static_cast<size_t>(nclass) * info.labels.size(),
+                 "SoftmaxMultiClassObj: label size and pred size does not match");
+    std::vector<bst_gpair> &gpair = *out_gpair;
+    gpair.resize(preds.size());
+    const unsigned ndata = static_cast<unsigned>(info.labels.size());
+    #pragma omp parallel
+    {
+      std::vector<float> rec(nclass);
+      #pragma omp for schedule(static)
+      for (unsigned j = 0; j < ndata; ++j) {
+        for (int k = 0; k < nclass; ++k) {
+          rec[k] = preds[j * nclass + k];
+        }
+        Softmax(&rec);
+        unsigned label = static_cast<unsigned>(info.labels[j]);
+        utils::Check(label < nclass, "SoftmaxMultiClassObj: label exceed num_class");
+        const float wt = info.GetWeight(j);
+        for (int k = 0; k < nclass; ++k) {
+          float p = rec[k];
+          const float h = 2.0f * p * (1.0f - p) * wt;
+          if (label == k) {
+            gpair[j * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
+          } else {
+            gpair[j * nclass + k] = bst_gpair(p * wt, h);
+          }
+        }
+      }
+    }
+  }
+  virtual void PredTransform(std::vector<float> *io_preds) {
+    this->Transform(io_preds, output_prob);
+  }
+  virtual void EvalTransform(std::vector<float> *io_preds) {
+    this->Transform(io_preds, 0);
+  }
+  virtual const char* DefaultEvalMetric(void) {
+    return "merror";
+  }
+
+ private:
+  inline void Transform(std::vector<float> *io_preds, int prob) {
+    utils::Check(nclass != 0, "must set num_class to use softmax");
+    std::vector<float> &preds = *io_preds;
+    const unsigned ndata = static_cast<unsigned>(preds.size() / nclass);
+    #pragma omp parallel
+    {
+      std::vector<float> rec(nclass);
+      #pragma omp for schedule(static)
+      for (unsigned j = 0; j < ndata; ++j) {
+        for (int k = 0; k < nclass; ++k) {
+          rec[k] = preds[j * nclass + k];
+        }
+        if (prob == 0) {
+          preds[j] = FindMaxIndex(rec);
+        } else {
+          Softmax(&rec);
+          for (int k = 0; k < nclass; ++k) {
+            preds[j * nclass + k] = rec[k];
+          }
+        }
+      }
+    }
+    if (prob == 0) {
+      preds.resize(ndata);
+    }
+  }
+  // data field
+  int nclass;
+  int output_prob;
+};
+
 } // namespace learner
 } // namespace xgboost
 #endif  // XGBOOST_LEARNER_OBJECTIVE_INL_HPP_
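The heart of the new objective is `GetGradient`: for each row it softmaxes the `nclass` raw scores, then emits gradient `p_k - 1[k = label]` and hessian `2 p_k (1 - p_k)`, both scaled by the instance weight. A numpy re-derivation for intuition (a sketch under those formulas, not the code path xgboost runs):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()

def softmax_gpair(scores, label, wt=1.0):
    """Per-row gradient/hessian of softmax cross entropy, as in GetGradient."""
    p = softmax(scores)
    grad = p.copy()
    grad[label] -= 1.0            # p_k - 1 for the true class, p_k otherwise
    hess = 2.0 * p * (1.0 - p)    # the scaled diagonal hessian the C++ uses
    return grad * wt, hess * wt

grad, hess = softmax_gpair(np.array([0.5, 1.2, -0.3]), label=1)
print(grad, hess)
```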
@@ -71,6 +71,8 @@ inline IObjFunction* CreateObjFunction(const char *name) {
   if (!strcmp("reg:logistic", name)) return new RegLossObj(LossType::kLogisticNeglik);
   if (!strcmp("binary:logistic", name)) return new RegLossObj(LossType::kLogisticClassify);
   if (!strcmp("binary:logitraw", name)) return new RegLossObj(LossType::kLogisticRaw);
+  if (!strcmp("multi:softmax", name)) return new SoftmaxMultiClassObj(0);
+  if (!strcmp("multi:softprob", name)) return new SoftmaxMultiClassObj(1);
   utils::Error("unknown objective function type: %s", name);
   return NULL;
 }
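The only difference between the two registrations is the `output_prob` flag passed to the constructor: `multi:softmax` (0) collapses each row to an argmax class index in `PredTransform`, while `multi:softprob` (1) keeps the full probability vector; `EvalTransform` always uses the argmax path, so `merror` works for both. A small sketch mimicking `Transform` on already-softmaxed scores (the function name is mine, not xgboost's):

```python
import numpy as np

def transform(preds, nclass, output_prob):
    """Mimics SoftmaxMultiClassObj::Transform on already-softmaxed scores."""
    rows = preds.reshape(-1, nclass)
    if output_prob:                     # multi:softprob: keep all probabilities
        return rows.reshape(-1)
    return np.argmax(rows, axis=1).astype(np.float32)  # multi:softmax: labels

scores = np.array([0.1, 0.7, 0.2, 0.5, 0.2, 0.3], dtype=np.float32)
print(transform(scores, 3, output_prob=1))  # 6 probabilities, shape (6,)
print(transform(scores, 3, output_prob=0))  # 2 class labels, shape (2,)
```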
Loading…
x
Reference in New Issue
Block a user