check in softmax multiclass
This commit is contained in:
@@ -79,6 +79,7 @@ class BoostLearner {
|
||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
||||
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
||||
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
||||
if (gbm_ == NULL) {
|
||||
if (!strcmp(name, "objective")) name_obj_ = val;
|
||||
if (!strcmp(name, "booster")) name_gbm_ = val;
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "../data.h"
|
||||
#include "./objective.h"
|
||||
#include "./helper_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace learner {
|
||||
@@ -133,6 +135,94 @@ class RegLossObj : public IObjFunction{
|
||||
float scale_pos_weight;
|
||||
LossType loss;
|
||||
};
|
||||
|
||||
// softmax multi-class classification
|
||||
class SoftmaxMultiClassObj : public IObjFunction {
|
||||
public:
|
||||
explicit SoftmaxMultiClassObj(int output_prob)
|
||||
: output_prob(output_prob) {
|
||||
nclass = 0;
|
||||
}
|
||||
virtual ~SoftmaxMultiClassObj(void) {}
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp( "num_class", name )) nclass = atoi(val);
|
||||
}
|
||||
virtual void GetGradient(const std::vector<float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) {
|
||||
utils::Check(nclass != 0, "must set num_class to use softmax");
|
||||
utils::Check(preds.size() == static_cast<size_t>(nclass) * info.labels.size(),
|
||||
"SoftmaxMultiClassObj: label size and pred size does not match");
|
||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||
gpair.resize(preds.size());
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds[j * nclass + k];
|
||||
}
|
||||
Softmax(&rec);
|
||||
unsigned label = static_cast<unsigned>(info.labels[j]);
|
||||
utils::Check(label < nclass, "SoftmaxMultiClassObj: label exceed num_class");
|
||||
const float wt = info.GetWeight(j);
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
float p = rec[k];
|
||||
const float h = 2.0f * p * (1.0f - p) * wt;
|
||||
if (label == k) {
|
||||
gpair[j * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
|
||||
} else {
|
||||
gpair[j * nclass + k] = bst_gpair(p* wt, h);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual void PredTransform(std::vector<float> *io_preds) {
|
||||
this->Transform(io_preds, output_prob);
|
||||
}
|
||||
virtual void EvalTransform(std::vector<float> *io_preds) {
|
||||
this->Transform(io_preds, 0);
|
||||
}
|
||||
virtual const char* DefaultEvalMetric(void) {
|
||||
return "merror";
|
||||
}
|
||||
|
||||
private:
|
||||
inline void Transform(std::vector<float> *io_preds, int prob) {
|
||||
utils::Check(nclass != 0, "must set num_class to use softmax");
|
||||
std::vector<float> &preds = *io_preds;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size()/nclass);
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds[j * nclass + k];
|
||||
}
|
||||
if (prob == 0) {
|
||||
preds[j] = FindMaxIndex(rec);
|
||||
} else {
|
||||
Softmax(&rec);
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
preds[j * nclass + k] = rec[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (prob == 0) {
|
||||
preds.resize(ndata);
|
||||
}
|
||||
}
|
||||
// data field
|
||||
int nclass;
|
||||
int output_prob;
|
||||
};
|
||||
|
||||
} // namespace learner
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LEARNER_OBJECTIVE_INL_HPP_
|
||||
|
||||
@@ -71,6 +71,8 @@ inline IObjFunction* CreateObjFunction(const char *name) {
|
||||
if (!strcmp("reg:logistic", name)) return new RegLossObj(LossType::kLogisticNeglik);
|
||||
if (!strcmp("binary:logistic", name)) return new RegLossObj(LossType::kLogisticClassify);
|
||||
if (!strcmp("binary:logitraw", name)) return new RegLossObj(LossType::kLogisticRaw);
|
||||
if (!strcmp("multi:softmax", name)) return new SoftmaxMultiClassObj(0);
|
||||
if (!strcmp("multi:softprob", name)) return new SoftmaxMultiClassObj(1);
|
||||
utils::Error("unknown objective function type: %s", name);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user