* Clang-tidy static analysis * Modernise checks * Google coding standard checks * Identifier renaming according to Google style
118 lines
3.9 KiB
C++
118 lines
3.9 KiB
C++
/*!
|
|
* Copyright 2014 by Contributors
|
|
* \file objective.h
|
|
* \brief interface of objective function used by xgboost.
|
|
* \author Tianqi Chen, Kailong Chen
|
|
*/
|
|
#ifndef XGBOOST_OBJECTIVE_H_
|
|
#define XGBOOST_OBJECTIVE_H_
|
|
|
|
#include <dmlc/registry.h>
|
|
#include <vector>
|
|
#include <utility>
|
|
#include <string>
|
|
#include <functional>
|
|
#include "./data.h"
|
|
#include "./base.h"
|
|
#include "../../src/common/host_device_vector.h"
|
|
|
|
|
|
namespace xgboost {
|
|
|
|
/*! \brief interface of objective function */
|
|
class ObjFunction {
|
|
public:
|
|
/*! \brief virtual destructor */
|
|
virtual ~ObjFunction() = default;
|
|
/*!
|
|
* \brief set configuration from pair iterators.
|
|
* \param begin The beginning iterator.
|
|
* \param end The end iterator.
|
|
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
|
*/
|
|
template<typename PairIter>
|
|
inline void Configure(PairIter begin, PairIter end);
|
|
/*!
|
|
* \brief Configure the objective with the specified parameters.
|
|
* \param args arguments to the objective function.
|
|
*/
|
|
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
|
|
/*!
|
|
* \brief Get gradient over each of predictions, given existing information.
|
|
* \param preds prediction of current round
|
|
* \param info information about labels, weights, groups in rank
|
|
* \param iteration current iteration number.
|
|
* \param out_gpair output of get gradient, saves gradient and second order gradient in
|
|
*/
|
|
virtual void GetGradient(HostDeviceVector<bst_float>* preds,
|
|
const MetaInfo& info,
|
|
int iteration,
|
|
HostDeviceVector<GradientPair>* out_gpair) = 0;
|
|
|
|
/*! \return the default evaluation metric for the objective */
|
|
virtual const char* DefaultEvalMetric() const = 0;
|
|
// the following functions are optional, most of time default implementation is good enough
|
|
/*!
|
|
* \brief transform prediction values, this is only called when Prediction is called
|
|
* \param io_preds prediction values, saves to this vector as well
|
|
*/
|
|
virtual void PredTransform(HostDeviceVector<bst_float> *io_preds) {}
|
|
|
|
/*!
|
|
* \brief transform prediction values, this is only called when Eval is called,
|
|
* usually it redirect to PredTransform
|
|
* \param io_preds prediction values, saves to this vector as well
|
|
*/
|
|
virtual void EvalTransform(HostDeviceVector<bst_float> *io_preds) {
|
|
this->PredTransform(io_preds);
|
|
}
|
|
/*!
|
|
* \brief transform probability value back to margin
|
|
* this is used to transform user-set base_score back to margin
|
|
* used by gradient boosting
|
|
* \return transformed value
|
|
*/
|
|
virtual bst_float ProbToMargin(bst_float base_score) const {
|
|
return base_score;
|
|
}
|
|
/*!
|
|
* \brief Create an objective function according to name.
|
|
* \param name Name of the objective.
|
|
*/
|
|
static ObjFunction* Create(const std::string& name);
|
|
};
|
|
|
|
// implementing configure.
|
|
template<typename PairIter>
|
|
inline void ObjFunction::Configure(PairIter begin, PairIter end) {
|
|
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
|
this->Configure(vec);
|
|
}
|
|
|
|
/*!
|
|
* \brief Registry entry for objective factory functions.
|
|
*/
|
|
struct ObjFunctionReg
|
|
: public dmlc::FunctionRegEntryBase<ObjFunctionReg,
|
|
std::function<ObjFunction* ()> > {
|
|
};
|
|
|
|
/*!
 * \brief Register an objective function in the global objective registry.
 *
 * Expands to the declaration of a static reference bound to the new registry
 * entry, so the registration happens when that static is initialized. Chained
 * calls written after the macro configure the entry.
 *
 * \param UniqueId Token used to build the static variable's name; it must be
 *        unique among objective registrations within a translation unit.
 * \param Name String name under which the objective is registered.
 *
 * \code
 * // example of registering an objective
 * XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
 * .describe("Linear regression objective")
 * .set_body([]() {
 *     return new RegLossObj(LossType::kLinearSquare);
 *   });
 * \endcode
 */
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name)                    \
  static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg &            \
      __make_ ## ObjFunctionReg ## _ ## UniqueId ## __ =              \
          ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()         \
              ->__REGISTER__(Name)
|
|
} // namespace xgboost
|
|
#endif // XGBOOST_OBJECTIVE_H_
|