fix multiclass
This commit is contained in:
parent
1fd6ff817f
commit
58d74861b9
@ -105,19 +105,22 @@ class RegLossObj : public IObjFunction{
|
||||
scale_pos_weight = static_cast<float>(atof(val));
|
||||
}
|
||||
}
|
||||
virtual void GetGradient(const std::vector<float>& preds,
|
||||
virtual void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) {
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"labels are not correctly provided");
|
||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||
gpair.resize(preds.size());
|
||||
// start calculating gradient
|
||||
const unsigned nstep = static_cast<unsigned>(info.labels.size());
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
float p = loss.PredTransform(preds[j]);
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
const unsigned j = i % nstep;
|
||||
float p = loss.PredTransform(preds[i]);
|
||||
float w = info.GetWeight(j);
|
||||
if (info.labels[j] == 1.0f) w *= scale_pos_weight;
|
||||
gpair[j] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
|
||||
@ -155,25 +158,28 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp( "num_class", name )) nclass = atoi(val);
|
||||
}
|
||||
virtual void GetGradient(const std::vector<float>& preds,
|
||||
virtual void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) {
|
||||
utils::Check(nclass != 0, "must set num_class to use softmax");
|
||||
utils::Check(preds.size() == static_cast<size_t>(nclass) * info.labels.size(),
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % (static_cast<size_t>(nclass) * info.labels.size()) == 0,
|
||||
"SoftmaxMultiClassObj: label size and pred size does not match");
|
||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||
gpair.resize(preds.size());
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size() / nclass);
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds[j * nclass + k];
|
||||
rec[k] = preds[i * nclass + k];
|
||||
}
|
||||
Softmax(&rec);
|
||||
const unsigned j = i % nstep;
|
||||
int label = static_cast<int>(info.labels[j]);
|
||||
utils::Check(label < nclass, "SoftmaxMultiClassObj: label exceed num_class");
|
||||
const float wt = info.GetWeight(j);
|
||||
@ -181,9 +187,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
float p = rec[k];
|
||||
const float h = 2.0f * p * (1.0f - p) * wt;
|
||||
if (label == k) {
|
||||
gpair[j * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
|
||||
gpair[i * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
|
||||
} else {
|
||||
gpair[j * nclass + k] = bst_gpair(p* wt, h);
|
||||
gpair[i * nclass + k] = bst_gpair(p* wt, h);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -203,7 +209,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
inline void Transform(std::vector<float> *io_preds, int prob) {
|
||||
utils::Check(nclass != 0, "must set num_class to use softmax");
|
||||
std::vector<float> &preds = *io_preds;
|
||||
std::vector<float> tmp;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size()/nclass);
|
||||
if (prob == 0) tmp.resize(ndata);
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
@ -213,7 +221,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
rec[k] = preds[j * nclass + k];
|
||||
}
|
||||
if (prob == 0) {
|
||||
preds[j] = FindMaxIndex(rec);
|
||||
tmp[j] = FindMaxIndex(rec);
|
||||
} else {
|
||||
Softmax(&rec);
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
@ -222,9 +230,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (prob == 0) {
|
||||
preds.resize(ndata);
|
||||
}
|
||||
if (prob == 0) preds = tmp;
|
||||
}
|
||||
// data field
|
||||
int nclass;
|
||||
@ -245,17 +251,17 @@ class LambdaRankObj : public IObjFunction {
|
||||
if (!strcmp( "fix_list_weight", name)) fix_list_weight = static_cast<float>(atof(val));
|
||||
if (!strcmp( "num_pairsample", name)) num_pairsample = atoi(val);
|
||||
}
|
||||
virtual void GetGradient(const std::vector<float>& preds,
|
||||
virtual void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) {
|
||||
utils::Assert(preds.size() == info.labels.size(), "label size predict size not match");
|
||||
utils::Check(preds.size() == info.labels.size(), "label size predict size not match");
|
||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||
gpair.resize(preds.size());
|
||||
// quick consistency when group is not available
|
||||
std::vector<unsigned> tgptr(2, 0); tgptr[1] = preds.size();
|
||||
std::vector<unsigned> tgptr(2, 0); tgptr[1] = info.labels.size();
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
utils::Check(gptr.size() != 0 && gptr.back() == preds.size(),
|
||||
utils::Check(gptr.size() != 0 && gptr.back() == info.labels.size(),
|
||||
"group structure not consistent with #rows");
|
||||
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
||||
#pragma omp parallel
|
||||
|
||||
@ -27,7 +27,7 @@ class IObjFunction{
|
||||
* \param iter current iteration number
|
||||
* \param out_gpair output of get gradient, saves gradient and second order gradient in
|
||||
*/
|
||||
virtual void GetGradient(const std::vector<float>& preds,
|
||||
virtual void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) = 0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user