Implementation of hinge loss for binary classification (#3477)
This commit is contained in:
parent
44811f2330
commit
69454d9487
@ -20,6 +20,7 @@
|
|||||||
#include "../src/objective/regression_obj.cc"
|
#include "../src/objective/regression_obj.cc"
|
||||||
#include "../src/objective/multiclass_obj.cc"
|
#include "../src/objective/multiclass_obj.cc"
|
||||||
#include "../src/objective/rank_obj.cc"
|
#include "../src/objective/rank_obj.cc"
|
||||||
|
#include "../src/objective/hinge.cc"
|
||||||
|
|
||||||
// gbms
|
// gbms
|
||||||
#include "../src/gbm/gbm.cc"
|
#include "../src/gbm/gbm.cc"
|
||||||
|
|||||||
@ -248,6 +248,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
|||||||
- ``reg:logistic``: logistic regression
|
- ``reg:logistic``: logistic regression
|
||||||
- ``binary:logistic``: logistic regression for binary classification, output probability
|
- ``binary:logistic``: logistic regression for binary classification, output probability
|
||||||
- ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
|
- ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
|
||||||
|
- ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.
|
||||||
- ``gpu:reg:linear``, ``gpu:reg:logistic``, ``gpu:binary:logistic``, ``gpu:binary:logitraw``: versions
|
- ``gpu:reg:linear``, ``gpu:reg:logistic``, ``gpu:binary:logistic``, ``gpu:binary:logitraw``: versions
|
||||||
of the corresponding objective functions evaluated on the GPU; note that like the GPU histogram algorithm,
|
of the corresponding objective functions evaluated on the GPU; note that like the GPU histogram algorithm,
|
||||||
they can only be used when the entire training session uses the same dataset
|
they can only be used when the entire training session uses the same dataset
|
||||||
|
|||||||
71
src/objective/hinge.cc
Normal file
71
src/objective/hinge.cc
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2018 by Contributors
|
||||||
|
* \file hinge.cc
|
||||||
|
* \brief Provides an implementation of the hinge loss function
|
||||||
|
* \author Henry Gouk
|
||||||
|
*/
|
||||||
|
#include <xgboost/objective.h>
|
||||||
|
#include "../common/math.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace obj {
|
||||||
|
|
||||||
|
DMLC_REGISTRY_FILE_TAG(hinge);
|
||||||
|
|
||||||
|
class HingeObj : public ObjFunction {
|
||||||
|
public:
|
||||||
|
HingeObj() = default;
|
||||||
|
|
||||||
|
void Configure(
|
||||||
|
const std::vector<std::pair<std::string, std::string> > &args) override {
|
||||||
|
// This objective does not take any parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetGradient(HostDeviceVector<bst_float> *preds,
|
||||||
|
const MetaInfo &info,
|
||||||
|
int iter,
|
||||||
|
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||||
|
CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty";
|
||||||
|
CHECK_EQ(preds->Size(), info.labels_.size())
|
||||||
|
<< "labels are not correctly provided"
|
||||||
|
<< "preds.size=" << preds->Size()
|
||||||
|
<< ", label.size=" << info.labels_.size();
|
||||||
|
auto& preds_h = preds->HostVector();
|
||||||
|
|
||||||
|
out_gpair->Resize(preds_h.size());
|
||||||
|
auto& gpair = out_gpair->HostVector();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < preds_h.size(); ++i) {
|
||||||
|
auto y = info.labels_[i] * 2.0 - 1.0;
|
||||||
|
bst_float p = preds_h[i];
|
||||||
|
bst_float w = info.GetWeight(i);
|
||||||
|
bst_float g, h;
|
||||||
|
if (p * y < 1.0) {
|
||||||
|
g = -y * w;
|
||||||
|
h = w;
|
||||||
|
} else {
|
||||||
|
g = 0.0;
|
||||||
|
h = std::numeric_limits<bst_float>::min();
|
||||||
|
}
|
||||||
|
gpair[i] = GradientPair(g, h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||||
|
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||||
|
for (auto& p : preds) {
|
||||||
|
p = p > 0.0 ? 1.0 : 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* DefaultEvalMetric() const override {
|
||||||
|
return "error";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
|
||||||
|
.describe("Hinge loss. Expects labels to be in [0,1f]")
|
||||||
|
.set_body([]() { return new HingeObj(); });
|
||||||
|
|
||||||
|
} // namespace obj
|
||||||
|
} // namespace xgboost
|
||||||
@ -36,5 +36,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj);
|
|||||||
#endif
|
#endif
|
||||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
||||||
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
||||||
|
DMLC_REGISTRY_LINK_TAG(hinge);
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
20
tests/cpp/objective/test_hinge.cc
Normal file
20
tests/cpp/objective/test_hinge.cc
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright by Contributors
|
||||||
|
#include <xgboost/objective.h>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
TEST(Objective, HingeObj) {
|
||||||
|
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge");
|
||||||
|
std::vector<std::pair<std::string, std::string> > args;
|
||||||
|
obj->Configure(args);
|
||||||
|
xgboost::bst_float eps = std::numeric_limits<xgboost::bst_float>::min();
|
||||||
|
CheckObjFunction(obj,
|
||||||
|
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
|
||||||
|
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{ 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
|
||||||
|
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
|
||||||
|
|
||||||
|
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user