[OBJ] Add basic objective function and registry

This commit is contained in:
tqchen
2015-12-30 18:22:15 -08:00
parent 46bcba7173
commit dedd87662b
16 changed files with 965 additions and 836 deletions

View File

@@ -7,6 +7,7 @@
#define XGBOOST_BASE_H_
#include <dmlc/base.h>
#include <dmlc/omp.h>
namespace xgboost {
/*!
@@ -17,10 +18,23 @@ typedef uint32_t bst_uint;
/*! \brief float type, used for storing statistics */
typedef float bst_float;
/*!
 * \brief gradient statistics pair usually needed in gradient boosting
 *
 * Bundles the first order (grad) and second order (hess) statistics.
 */
struct bst_gpair {
  /*! \brief gradient statistics */
  bst_float grad;
  /*! \brief second order gradient statistics */
  bst_float hess;
  /*!
   * \brief default constructor, sets both statistics to zero so that a
   *  default-constructed pair never holds indeterminate values
   */
  bst_gpair() : grad(0.0f), hess(0.0f) {}
  /*!
   * \brief construct a pair from explicit statistics
   * \param grad gradient statistics
   * \param hess second order gradient statistics
   */
  bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {}
};
/*! \brief small epsilon constant, 1e-5 */
const float rt_eps = 1e-5f;
/*! \brief min gap between feature values to allow a split happen, twice rt_eps */
const float rt_2eps = 2.0f * rt_eps;
/*!
 * \brief alias of dmlc::omp_ulong
 * NOTE(review): presumably the index type for OpenMP loops; confirm against dmlc/omp.h
 */
typedef dmlc::omp_ulong omp_ulong;
/*! \brief alias of dmlc::omp_uint, exposed under the name bst_omp_uint */
typedef dmlc::omp_uint bst_omp_uint;
/*!
* \brief define compatible keywords in g++
* Used to support g++-4.6 and g++4.7

View File

@@ -10,6 +10,7 @@
#include <dmlc/base.h>
#include <dmlc/data.h>
#include <memory>
#include <vector>
#include "./base.h"
namespace xgboost {
@@ -261,7 +262,7 @@ class DMatrix {
* \return a Created DMatrix.
*/
static DMatrix* Create(std::unique_ptr<DataSource>&& source,
const char* cache_prefix=nullptr);
const char* cache_prefix = nullptr);
/*!
* \brief Create a DMatrix by loading data from parser.
* Parser can later be deleted after the DMatrix is created.
@@ -275,7 +276,7 @@ class DMatrix {
* \return A created DMatrix.
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
const char* cache_prefix=nullptr);
const char* cache_prefix = nullptr);
};
} // namespace xgboost

View File

@@ -0,0 +1,96 @@
/*!
* Copyright 2014 by Contributors
* \file objective.h
* \brief interface of objective function used by xgboost.
* \author Tianqi Chen, Kailong Chen
*/
#ifndef XGBOOST_OBJECTIVE_H_
#define XGBOOST_OBJECTIVE_H_
#include <dmlc/registry.h>
#include <vector>
#include <utility>
#include <string>
#include <functional>
#include "./data.h"
#include "./base.h"
namespace xgboost {
/*!
 * \brief interface of objective function
 *
 * Init, GetGradient and DefaultEvalMetric are pure virtual and must be
 * provided by every objective; the other virtual members come with
 * default implementations.
 */
class ObjFunction {
 public:
  /*! \brief virtual destructor */
  virtual ~ObjFunction() {}
  /*!
   * \brief Initialize the objective with the specified parameters.
   * \param args arguments to the objective function, passed as pairs of strings.
   */
  virtual void Init(
      const std::vector<std::pair<std::string, std::string> >& args) = 0;
  /*!
   * \brief Get gradient over each of predictions, given existing information.
   * \param preds prediction of current round
   * \param info information about labels, weights, groups in rank
   * \param iteration current iteration number.
   * \param out_gpair output vector in which the gradient and
   *  second order gradient are saved
   */
  virtual void GetGradient(const std::vector<float>& preds,
                           const MetaInfo& info,
                           int iteration,
                           std::vector<bst_gpair>* out_gpair) = 0;
  /*! \return the default evaluation metric for the objective */
  virtual const char* DefaultEvalMetric() const = 0;

  // The members below are optional to override;
  // most of the time the default implementation is good enough.

  /*!
   * \brief transform prediction values, this is only called when Prediction is called.
   *  The default implementation leaves the values untouched.
   * \param io_preds prediction values, saves to this vector as well
   */
  virtual void PredTransform(std::vector<float>* io_preds) {}
  /*!
   * \brief transform prediction values, this is only called when Eval is called.
   *  The default implementation redirects to PredTransform.
   * \param io_preds prediction values, saves to this vector as well
   */
  virtual void EvalTransform(std::vector<float>* io_preds) {
    PredTransform(io_preds);
  }
  /*!
   * \brief transform probability value back to margin,
   *  this is used to transform user-set base_score back to margin
   *  used by gradient boosting.
   *  The default implementation returns base_score unchanged.
   * \param base_score the user-set base score
   * \return transformed value
   */
  virtual float ProbToMargin(float base_score) const {
    return base_score;
  }
  /*!
   * \brief Create an objective function according to name.
   * \param name Name of the objective.
   * \return pointer to the objective function.
   *  NOTE(review): the definition is not part of this header; confirm there
   *  who owns the returned pointer and what happens for an unknown name.
   */
  static ObjFunction* Create(const char* name);
};
/*!
 * \brief Registry entry for objective function factory functions.
 *  The function type held by an entry takes no arguments and returns ObjFunction*.
 */
struct ObjFunctionReg
    : public dmlc::FunctionRegEntryBase<ObjFunctionReg, std::function<ObjFunction*()> > {};
/*!
 * \brief Macro to register objective
 *
 * \param UniqueId identifier token pasted into the name of the static
 *  registry-entry reference that the macro declares, so it has to be
 *  unique within a translation unit.
 * \param Name name of the objective, given as a string literal such as
 *  "reg:linear"; it is handed to the registry unchanged.
 *
 * \code
 * // example of registering an objective
 * XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
 * .describe("Linear regression objective")
 * .set_body([]() {
 *     return new RegLossObj(LossType::kLinearSquare);
 *   });
 * \endcode
 */
// Name is forwarded as-is instead of being stringized: the documented usage
// already passes a string literal, and #Name would turn "reg:linear" into a
// name that contains the quote characters themselves.
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name)                                      \
  static ::xgboost::ObjFunctionReg & __make_ ## ObjFunctionReg ## _ ## UniqueId ## __ = \
      ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
} // namespace xgboost
#endif // XGBOOST_OBJECTIVE_H_