[LEARNER] refactor learner
This commit is contained in:
parent
4b4b36d047
commit
0d95e863c9
@ -14,6 +14,9 @@
|
|||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
// forward declare learner.
|
||||||
|
class LearnerImpl;
|
||||||
|
|
||||||
/*! \brief data type accepted by xgboost interface */
|
/*! \brief data type accepted by xgboost interface */
|
||||||
enum DataType {
|
enum DataType {
|
||||||
kFloat32 = 1,
|
kFloat32 = 1,
|
||||||
@ -199,6 +202,8 @@ class DataSource : public dmlc::DataIter<RowBatch> {
|
|||||||
*/
|
*/
|
||||||
class DMatrix {
|
class DMatrix {
|
||||||
public:
|
public:
|
||||||
|
/*! \brief default constructor */
|
||||||
|
DMatrix() : cache_learner_ptr_(nullptr) {}
|
||||||
/*! \brief meta information of the dataset */
|
/*! \brief meta information of the dataset */
|
||||||
virtual MetaInfo& info() = 0;
|
virtual MetaInfo& info() = 0;
|
||||||
/*! \brief meta information of the dataset */
|
/*! \brief meta information of the dataset */
|
||||||
@ -222,6 +227,7 @@ class DMatrix {
|
|||||||
* \param subsample subsample ratio when generating column access.
|
* \param subsample subsample ratio when generating column access.
|
||||||
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
|
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
|
||||||
* this is a hint information that can be ignored by the implementation.
|
* this is a hint information that can be ignored by the implementation.
|
||||||
|
* \return Number of column blocks in the column access.
|
||||||
*/
|
*/
|
||||||
virtual void InitColAccess(const std::vector<bool>& enabled,
|
virtual void InitColAccess(const std::vector<bool>& enabled,
|
||||||
float subsample,
|
float subsample,
|
||||||
@ -229,6 +235,8 @@ class DMatrix {
|
|||||||
// the following are column meta data, should be able to answer them fast.
|
// the following are column meta data, should be able to answer them fast.
|
||||||
/*! \return whether column access is enabled */
|
/*! \return whether column access is enabled */
|
||||||
virtual bool HaveColAccess() const = 0;
|
virtual bool HaveColAccess() const = 0;
|
||||||
|
/*! \return Whether the data columns form a single column block. */
|
||||||
|
virtual bool SingleColBlock() const = 0;
|
||||||
/*! \brief get number of non-missing entries in column */
|
/*! \brief get number of non-missing entries in column */
|
||||||
virtual size_t GetColSize(size_t cidx) const = 0;
|
virtual size_t GetColSize(size_t cidx) const = 0;
|
||||||
/*! \brief get column density */
|
/*! \brief get column density */
|
||||||
@ -279,6 +287,12 @@ class DMatrix {
|
|||||||
*/
|
*/
|
||||||
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
|
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
|
||||||
const char* cache_prefix = nullptr);
|
const char* cache_prefix = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// allow learner class to access this field.
|
||||||
|
friend class LearnerImpl;
|
||||||
|
/*! \brief back reference to the learner that caches this matrix (private, set by the friend LearnerImpl). */
|
||||||
|
LearnerImpl* cache_learner_ptr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -25,6 +25,14 @@ class GradientBooster {
|
|||||||
public:
|
public:
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
virtual ~GradientBooster() {}
|
virtual ~GradientBooster() {}
|
||||||
|
/*!
|
||||||
|
* \brief set configuration from pair iterators.
|
||||||
|
* \param begin The beginning iterator.
|
||||||
|
* \param end The end iterator.
|
||||||
|
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||||
|
*/
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void Configure(PairIter begin, PairIter end);
|
||||||
/*!
|
/*!
|
||||||
* \brief Set the configuration of gradient boosting.
|
* \brief Set the configuration of gradient boosting.
|
||||||
* User must call configure once before InitModel and Training.
|
* User must call configure once before InitModel and Training.
|
||||||
@ -123,9 +131,16 @@ class GradientBooster {
|
|||||||
* \brief create a gradient booster from given name
|
* \brief create a gradient booster from given name
|
||||||
* \param name name of gradient booster
|
* \param name name of gradient booster
|
||||||
*/
|
*/
|
||||||
static GradientBooster* Create(const char *name);
|
static GradientBooster* Create(const std::string& name);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// implementing configure.
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void GradientBooster::Configure(PairIter begin, PairIter end) {
|
||||||
|
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||||
|
this->Configure(vec);
|
||||||
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Registry entry for tree updater.
|
* \brief Registry entry for tree updater.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -14,7 +14,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
#include "./gbm.h"
|
#include "./gbm.h"
|
||||||
#include "./meric.h"
|
#include "./metric.h"
|
||||||
#include "./objective.h"
|
#include "./objective.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -36,6 +36,14 @@ namespace xgboost {
|
|||||||
*/
|
*/
|
||||||
class Learner : public rabit::Serializable {
|
class Learner : public rabit::Serializable {
|
||||||
public:
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief set configuration from pair iterators.
|
||||||
|
* \param begin The beginning iterator.
|
||||||
|
* \param end The end iterator.
|
||||||
|
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||||
|
*/
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void Configure(PairIter begin, PairIter end);
|
||||||
/*!
|
/*!
|
||||||
* \brief Set the configuration of gradient boosting.
|
* \brief Set the configuration of gradient boosting.
|
||||||
* User must call configure once before InitModel and Training.
|
* User must call configure once before InitModel and Training.
|
||||||
@ -59,7 +67,7 @@ class Learner : public rabit::Serializable {
|
|||||||
* \param iter current iteration number
|
* \param iter current iteration number
|
||||||
* \param train reference to the data matrix.
|
* \param train reference to the data matrix.
|
||||||
*/
|
*/
|
||||||
void UpdateOneIter(int iter, DMatrix* train);
|
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief Do customized gradient boosting with in_gpair.
|
* \brief Do customized gradient boosting with in_gpair.
|
||||||
* in_gpair can be mutated after this call.
|
* in_gpair can be mutated after this call.
|
||||||
@ -67,9 +75,9 @@ class Learner : public rabit::Serializable {
|
|||||||
* \param train reference to the data matrix.
|
* \param train reference to the data matrix.
|
||||||
* \param in_gpair The input gradient statistics.
|
* \param in_gpair The input gradient statistics.
|
||||||
*/
|
*/
|
||||||
void BoostOneIter(int iter,
|
virtual void BoostOneIter(int iter,
|
||||||
DMatrix* train,
|
DMatrix* train,
|
||||||
std::vector<bst_gpair>* in_gpair);
|
std::vector<bst_gpair>* in_gpair) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief evaluate the model for specific iteration using the configured metrics.
|
* \brief evaluate the model for specific iteration using the configured metrics.
|
||||||
* \param iter iteration number
|
* \param iter iteration number
|
||||||
@ -77,9 +85,9 @@ class Learner : public rabit::Serializable {
|
|||||||
* \param data_names name of each dataset
|
* \param data_names name of each dataset
|
||||||
* \return a string corresponding to the evaluation result
|
* \return a string corresponding to the evaluation result
|
||||||
*/
|
*/
|
||||||
std::string EvalOneIter(int iter,
|
virtual std::string EvalOneIter(int iter,
|
||||||
const std::vector<DMatrix*>& data_sets,
|
const std::vector<DMatrix*>& data_sets,
|
||||||
const std::vector<std::string>& data_names);
|
const std::vector<std::string>& data_names) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief get prediction given the model.
|
* \brief get prediction given the model.
|
||||||
* \param data input data
|
* \param data input data
|
||||||
@ -89,11 +97,11 @@ class Learner : public rabit::Serializable {
|
|||||||
* predictor, when it equals 0, this means we are using all the trees
|
* predictor, when it equals 0, this means we are using all the trees
|
||||||
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
|
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
|
||||||
*/
|
*/
|
||||||
void Predict(DMatrix* data,
|
virtual void Predict(DMatrix* data,
|
||||||
bool output_margin,
|
bool output_margin,
|
||||||
std::vector<float> *out_preds,
|
std::vector<float> *out_preds,
|
||||||
unsigned ntree_limit = 0,
|
unsigned ntree_limit = 0,
|
||||||
bool pred_leaf = false) const;
|
bool pred_leaf = false) const = 0;
|
||||||
/*!
|
/*!
|
||||||
* \return whether the model allows lazy checkpoint in rabit.
|
* \return whether the model allows lazy checkpoint in rabit.
|
||||||
*/
|
*/
|
||||||
@ -151,5 +159,13 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
|
|||||||
obj_->PredTransform(out_preds);
|
obj_->PredTransform(out_preds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementing configure.
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void Learner::Configure(PairIter begin, PairIter end) {
|
||||||
|
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||||
|
this->Configure(vec);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_LEARNER_H_
|
#endif // XGBOOST_LEARNER_H_
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "./data.h"
|
#include "./data.h"
|
||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
@ -42,7 +43,7 @@ class Metric {
|
|||||||
* and the name will be matched in the registry.
|
* and the name will be matched in the registry.
|
||||||
* \return the created metric.
|
* \return the created metric.
|
||||||
*/
|
*/
|
||||||
static Metric* Create(const char *name);
|
static Metric* Create(const std::string& name);
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -22,10 +22,18 @@ class ObjFunction {
|
|||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
virtual ~ObjFunction() {}
|
virtual ~ObjFunction() {}
|
||||||
/*!
|
/*!
|
||||||
* \brief Initialize the objective with the specified parameters.
|
* \brief set configuration from pair iterators.
|
||||||
|
* \param begin The beginning iterator.
|
||||||
|
* \param end The end iterator.
|
||||||
|
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||||
|
*/
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void Configure(PairIter begin, PairIter end);
|
||||||
|
/*!
|
||||||
|
* \brief Configure the objective with the specified parameters.
|
||||||
* \param args arguments to the objective function.
|
* \param args arguments to the objective function.
|
||||||
*/
|
*/
|
||||||
virtual void Init(const std::vector<std::pair<std::string, std::string> >& args) = 0;
|
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief Get gradient over each of predictions, given existing information.
|
* \brief Get gradient over each of predictions, given existing information.
|
||||||
* \param preds prediction of current round
|
* \param preds prediction of current round
|
||||||
@ -66,9 +74,16 @@ class ObjFunction {
|
|||||||
* \brief Create an objective function according to name.
|
* \brief Create an objective function according to name.
|
||||||
* \param name Name of the objective.
|
* \param name Name of the objective.
|
||||||
*/
|
*/
|
||||||
static ObjFunction* Create(const char* name);
|
static ObjFunction* Create(const std::string& name);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// implementing configure.
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void ObjFunction::Configure(PairIter begin, PairIter end) {
|
||||||
|
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||||
|
this->Configure(vec);
|
||||||
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Registry entry for objective factory functions.
|
* \brief Registry entry for objective factory functions.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -54,7 +54,7 @@ class TreeUpdater {
|
|||||||
* \brief Create a tree updater given name
|
* \brief Create a tree updater given name
|
||||||
* \param name Name of the tree updater.
|
* \param name Name of the tree updater.
|
||||||
*/
|
*/
|
||||||
static TreeUpdater* Create(const char* name);
|
static TreeUpdater* Create(const std::string& name);
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -9,12 +9,67 @@
|
|||||||
#define XGBOOST_COMMON_IO_H_
|
#define XGBOOST_COMMON_IO_H_
|
||||||
|
|
||||||
#include <dmlc/io.h>
|
#include <dmlc/io.h>
|
||||||
|
#include <string>
|
||||||
|
#include <cstring>
|
||||||
#include "./sync.h"
|
#include "./sync.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
||||||
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Input stream that supports an additional PeekRead
|
||||||
|
* operation, besides read.
|
||||||
|
*/
|
||||||
|
class PeekableInStream : public dmlc::Stream {
|
||||||
|
public:
|
||||||
|
explicit PeekableInStream(dmlc::Stream* strm)
|
||||||
|
: strm_(strm), buffer_ptr_(0) {}
|
||||||
|
|
||||||
|
size_t Read(void* dptr, size_t size) override {
|
||||||
|
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||||
|
if (nbuffer == 0) return strm_->Read(dptr, size);
|
||||||
|
if (nbuffer < size) {
|
||||||
|
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
|
||||||
|
buffer_ptr_ += nbuffer;
|
||||||
|
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
|
||||||
|
size - nbuffer);
|
||||||
|
} else {
|
||||||
|
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||||
|
buffer_ptr_ += size;
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t PeekRead(void* dptr, size_t size) {
|
||||||
|
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||||
|
if (nbuffer < size) {
|
||||||
|
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
|
||||||
|
buffer_ptr_ = 0;
|
||||||
|
buffer_.resize(size);
|
||||||
|
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
|
||||||
|
buffer_.resize(nbuffer + nadd);
|
||||||
|
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
|
||||||
|
return buffer_.length();
|
||||||
|
} else {
|
||||||
|
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Write(const void* dptr, size_t size) override {
|
||||||
|
LOG(FATAL) << "Not implemented";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief input stream */
|
||||||
|
dmlc::Stream *strm_;
|
||||||
|
/*! \brief current buffer pointer */
|
||||||
|
size_t buffer_ptr_;
|
||||||
|
/*! \brief internal buffer */
|
||||||
|
std::string buffer_;
|
||||||
|
};
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_COMMON_IO_H_
|
#endif // XGBOOST_COMMON_IO_H_
|
||||||
|
|||||||
@ -1,53 +0,0 @@
|
|||||||
/*!
|
|
||||||
* Copyright 2015 by Contributors
|
|
||||||
* \file metric_set.h
|
|
||||||
* \brief additional math utils
|
|
||||||
* \author Tianqi Chen
|
|
||||||
*/
|
|
||||||
#ifndef XGBOOST_COMMON_METRIC_SET_H_
|
|
||||||
#define XGBOOST_COMMON_METRIC_SET_H_
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace xgboost {
|
|
||||||
namespace common {
|
|
||||||
|
|
||||||
/*! \brief helper util to create a set of metrics */
|
|
||||||
class MetricSet {
|
|
||||||
inline void AddEval(const char *name) {
|
|
||||||
using namespace std;
|
|
||||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
|
||||||
if (!strcmp(name, evals_[i]->Name())) return;
|
|
||||||
}
|
|
||||||
evals_.push_back(CreateEvaluator(name));
|
|
||||||
}
|
|
||||||
~EvalSet(void) {
|
|
||||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
|
||||||
delete evals_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inline std::string Eval(const char *evname,
|
|
||||||
const std::vector<float> &preds,
|
|
||||||
const MetaInfo &info,
|
|
||||||
bool distributed = false) {
|
|
||||||
std::string result = "";
|
|
||||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
|
||||||
float res = evals_[i]->Eval(preds, info, distributed);
|
|
||||||
char tmp[1024];
|
|
||||||
utils::SPrintf(tmp, sizeof(tmp), "\t%s-%s:%f", evname, evals_[i]->Name(), res);
|
|
||||||
result += tmp;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
inline size_t Size(void) const {
|
|
||||||
return evals_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<const IEvaluator*> evals_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace common
|
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_COMMON_METRIC_SET_H_
|
|
||||||
@ -39,6 +39,8 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
|||||||
" This option is used to support boosted random forest");
|
" This option is used to support boosted random forest");
|
||||||
DMLC_DECLARE_FIELD(updater_seq).set_default("grow_colmaker,prune")
|
DMLC_DECLARE_FIELD(updater_seq).set_default("grow_colmaker,prune")
|
||||||
.describe("Tree updater sequence.");
|
.describe("Tree updater sequence.");
|
||||||
|
// add alias
|
||||||
|
DMLC_DECLARE_ALIAS(updater_seq, updater);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// implement factory functions
|
// implement factory functions
|
||||||
ObjFunction* ObjFunction::Create(const char* name) {
|
ObjFunction* ObjFunction::Create(const std::string& name) {
|
||||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
||||||
if (e == nullptr) {
|
if (e == nullptr) {
|
||||||
LOG(FATAL) << "Unknown objective function " << name;
|
LOG(FATAL) << "Unknown objective function " << name;
|
||||||
@ -27,7 +27,7 @@ ObjFunction* ObjFunction::Create(const char* name) {
|
|||||||
return (e->body)();
|
return (e->body)();
|
||||||
}
|
}
|
||||||
|
|
||||||
Metric* Metric::Create(const char* name) {
|
Metric* Metric::Create(const std::string& name) {
|
||||||
std::string buf = name;
|
std::string buf = name;
|
||||||
std::string prefix = name;
|
std::string prefix = name;
|
||||||
auto pos = buf.find('@');
|
auto pos = buf.find('@');
|
||||||
@ -47,7 +47,7 @@ Metric* Metric::Create(const char* name) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TreeUpdater* TreeUpdater::Create(const char* name) {
|
TreeUpdater* TreeUpdater::Create(const std::string& name) {
|
||||||
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||||
if (e == nullptr) {
|
if (e == nullptr) {
|
||||||
LOG(FATAL) << "Unknown tree updater " << name;
|
LOG(FATAL) << "Unknown tree updater " << name;
|
||||||
@ -55,7 +55,7 @@ TreeUpdater* TreeUpdater::Create(const char* name) {
|
|||||||
return (e->body)();
|
return (e->body)();
|
||||||
}
|
}
|
||||||
|
|
||||||
GradientBooster* GradientBooster::Create(const char* name) {
|
GradientBooster* GradientBooster::Create(const std::string& name) {
|
||||||
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
|
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
|
||||||
if (e == nullptr) {
|
if (e == nullptr) {
|
||||||
LOG(FATAL) << "Unknown gbm type " << name;
|
LOG(FATAL) << "Unknown gbm type " << name;
|
||||||
|
|||||||
762
src/learner.cc
762
src/learner.cc
@ -1,366 +1,301 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2014 by Contributors
|
* Copyright 2014 by Contributors
|
||||||
* \file learner-inl.hpp
|
* \file learner.cc
|
||||||
* \brief learning algorithm
|
* \brief Implementation of learning algorithm.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_LEARNER_LEARNER_INL_HPP_
|
#include <xgboost/learner.h>
|
||||||
#define XGBOOST_LEARNER_LEARNER_INL_HPP_
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include "../sync/sync.h"
|
#include "./common/io.h"
|
||||||
#include "../utils/io.h"
|
#include "./common/random.h"
|
||||||
#include "./objective.h"
|
|
||||||
#include "./evaluation.h"
|
|
||||||
#include "../gbm/gbm.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/*! \brief namespace for learning algorithm */
|
// implementation of base learner.
|
||||||
namespace learner {
|
bool Learner::AllowLazyCheckPoint() const {
|
||||||
|
|
||||||
inline bool Learner::AllowLazyCheckPoint() const {
|
|
||||||
return gbm_->AllowLazyCheckPoint();
|
return gbm_->AllowLazyCheckPoint();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::vector<std::string>
|
std::vector<std::string>
|
||||||
Learner::Dump2Text(const FeatureMap& fmap, int option) const {
|
Learner::Dump2Text(const FeatureMap& fmap, int option) const {
|
||||||
return gbm_->Dump2Text(fmap, option);
|
return gbm_->Dump2Text(fmap, option);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// simple routine to convert any data to string
|
||||||
|
template<typename T>
|
||||||
|
inline std::string ToString(const T& data) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << data;
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief training parameter for regression */
|
||||||
|
struct LearnerModelParam
|
||||||
|
: public dmlc::Parameter<LearnerModelParam> {
|
||||||
|
/* \brief global bias */
|
||||||
|
float base_score;
|
||||||
|
/* \brief number of features */
|
||||||
|
unsigned num_feature;
|
||||||
|
/* \brief number of classes, if it is multi-class classification */
|
||||||
|
int num_class;
|
||||||
|
/*! \brief reserved field */
|
||||||
|
int reserved[31];
|
||||||
|
/*! \brief constructor */
|
||||||
|
LearnerModelParam() {
|
||||||
|
std::memset(this, 0, sizeof(LearnerModelParam));
|
||||||
|
base_score = 0.5f;
|
||||||
|
}
|
||||||
|
// declare parameters
|
||||||
|
DMLC_DECLARE_PARAMETER(LearnerModelParam) {
|
||||||
|
DMLC_DECLARE_FIELD(base_score).set_default(0.5f)
|
||||||
|
.describe("Global bias of the model.");
|
||||||
|
DMLC_DECLARE_FIELD(num_feature).set_default(0)
|
||||||
|
.describe("Number of features in training data,"\
|
||||||
|
" this parameter will be automatically detected by learner.");
|
||||||
|
DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0)
|
||||||
|
.describe("Number of class option for multi-class classifier. "\
|
||||||
|
" By default equals 0 and corresponds to binary classifier.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct LearnerTrainParam
|
||||||
|
: public dmlc::Parameter<LearnerTrainParam> {
|
||||||
|
// stored random seed
|
||||||
|
int seed;
|
||||||
|
// whether seed the PRNG each iteration
|
||||||
|
bool seed_per_iteration;
|
||||||
|
// data split mode, can be row, col, or none.
|
||||||
|
int dsplit;
|
||||||
|
// internal test flag
|
||||||
|
std::string test_flag;
|
||||||
|
// maximum buffered row value
|
||||||
|
float prob_buffer_row;
|
||||||
|
// declare parameters
|
||||||
|
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
|
||||||
|
DMLC_DECLARE_FIELD(seed).set_default(0)
|
||||||
|
.describe("Random number seed during training.");
|
||||||
|
DMLC_DECLARE_FIELD(seed_per_iteration).set_default(false)
|
||||||
|
.describe("Seed PRNG determnisticly via iterator number, "\
|
||||||
|
"this option will be switched on automatically on distributed mode.");
|
||||||
|
DMLC_DECLARE_FIELD(dsplit).set_default(0)
|
||||||
|
.add_enum("auto", 0)
|
||||||
|
.add_enum("col", 1)
|
||||||
|
.add_enum("row", 2)
|
||||||
|
.describe("Data split mode for distributed trainig. ");
|
||||||
|
DMLC_DECLARE_FIELD(test_flag).set_default("")
|
||||||
|
.describe("Internal test flag");
|
||||||
|
DMLC_DECLARE_FIELD(prob_buffer_row).set_default(1.0f).set_range(0.0f, 1.0f)
|
||||||
|
.describe("Maximum buffered row portion");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief learner that performs gradient boosting for a specific objective function.
|
* \brief learner that performs gradient boosting for a specific objective function.
|
||||||
* It does training and prediction.
|
* It does training and prediction.
|
||||||
*/
|
*/
|
||||||
class BoostLearner : public rabit::Serializable {
|
class LearnerImpl : public Learner {
|
||||||
public:
|
public:
|
||||||
BoostLearner(void) {
|
explicit LearnerImpl(const std::vector<DMatrix*>& cache_mats)
|
||||||
obj_ = NULL;
|
noexcept(false) {
|
||||||
gbm_ = NULL;
|
// setup the cache setting in constructor.
|
||||||
|
CHECK_EQ(cache_.size(), 0);
|
||||||
|
size_t buffer_size = 0;
|
||||||
|
for (auto it = cache_mats.begin(); it != cache_mats.end(); ++it) {
|
||||||
|
// avoid duplication.
|
||||||
|
if (std::find(cache_mats.begin(), it, *it) != it) continue;
|
||||||
|
DMatrix* pmat = *it;
|
||||||
|
pmat->cache_learner_ptr_ = this;
|
||||||
|
cache_.push_back(CacheEntry(pmat, buffer_size, pmat->info().num_row));
|
||||||
|
buffer_size += pmat->info().num_row;
|
||||||
|
}
|
||||||
|
pred_buffer_size_ = buffer_size;
|
||||||
|
// boosted tree
|
||||||
name_obj_ = "reg:linear";
|
name_obj_ = "reg:linear";
|
||||||
name_gbm_ = "gbtree";
|
name_gbm_ = "gbtree";
|
||||||
silent = 0;
|
|
||||||
prob_buffer_row = 1.0f;
|
|
||||||
distributed_mode = 0;
|
|
||||||
updater_mode = 0;
|
|
||||||
pred_buffer_size = 0;
|
|
||||||
seed_per_iteration = 0;
|
|
||||||
seed = 0;
|
|
||||||
save_base64 = 0;
|
|
||||||
}
|
}
|
||||||
virtual ~BoostLearner(void) {
|
|
||||||
if (obj_ != NULL) delete obj_;
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
if (gbm_ != NULL) delete gbm_;
|
tparam.InitAllowUnknown(args);
|
||||||
}
|
// add to configurations
|
||||||
/*!
|
cfg_.clear();
|
||||||
* \brief add internal cache space for mat, this can speedup prediction for matrix,
|
for (const auto& kv : args) {
|
||||||
* please cache prediction for training and eval data
|
if (kv.first == "eval_metric") {
|
||||||
* warning: if the model is loaded from file from some previous training history
|
// check duplication
|
||||||
* set cache data must be called with exactly SAME
|
auto dup_check = [&kv](const std::unique_ptr<Metric>&m) {
|
||||||
* data matrices to continue training otherwise it will cause error
|
return m->Name() != kv.second;
|
||||||
* \param mats array of pointers to matrix whose prediction result need to be cached
|
};
|
||||||
*/
|
if (std::all_of(metrics_.begin(), metrics_.end(), dup_check)) {
|
||||||
inline void SetCacheData(const std::vector<DMatrix*>& mats) {
|
metrics_.emplace_back(Metric::Create(kv.second));
|
||||||
utils::Assert(cache_.size() == 0, "can only call cache data once");
|
}
|
||||||
// assign buffer index
|
|
||||||
size_t buffer_size = 0;
|
|
||||||
for (size_t i = 0; i < mats.size(); ++i) {
|
|
||||||
bool dupilicate = false;
|
|
||||||
for (size_t j = 0; j < i; ++j) {
|
|
||||||
if (mats[i] == mats[j]) dupilicate = true;
|
|
||||||
}
|
|
||||||
if (dupilicate) continue;
|
|
||||||
// set mats[i]'s cache learner pointer to this
|
|
||||||
mats[i]->cache_learner_ptr_ = this;
|
|
||||||
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row()));
|
|
||||||
buffer_size += mats[i]->info.num_row();
|
|
||||||
}
|
|
||||||
char str_temp[25];
|
|
||||||
utils::SPrintf(str_temp, sizeof(str_temp), "%lu",
|
|
||||||
static_cast<unsigned long>(buffer_size)); // NOLINT(*)
|
|
||||||
this->SetParam("num_pbuffer", str_temp);
|
|
||||||
this->pred_buffer_size = buffer_size;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set parameters from outside
|
|
||||||
* \param name name of the parameter
|
|
||||||
* \param val value of the parameter
|
|
||||||
*/
|
|
||||||
inline void SetParam(const char *name, const char *val) {
|
|
||||||
using namespace std;
|
|
||||||
// in this version, bst: prefix is no longer required
|
|
||||||
if (strncmp(name, "bst:", 4) != 0) {
|
|
||||||
std::string n = "bst:"; n += name;
|
|
||||||
this->SetParam(n.c_str(), val);
|
|
||||||
}
|
|
||||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
|
||||||
if (!strcmp(name, "dsplit")) {
|
|
||||||
if (!strcmp(val, "col")) {
|
|
||||||
this->SetParam("updater", "distcol");
|
|
||||||
distributed_mode = 1;
|
|
||||||
} else if (!strcmp(val, "row")) {
|
|
||||||
this->SetParam("updater", "grow_histmaker,prune");
|
|
||||||
distributed_mode = 2;
|
|
||||||
} else {
|
} else {
|
||||||
utils::Error("%s is invalid value for dsplit, should be row or col", val);
|
cfg_[kv.first] = kv.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "updater_mode")) updater_mode = atoi(val);
|
// add additional parameter
|
||||||
if (!strcmp(name, "prob_buffer_row")) {
|
// These are constraints that need to be satisfied.
|
||||||
prob_buffer_row = static_cast<float>(atof(val));
|
if (tparam.dsplit == 0 && rabit::IsDistributed()) {
|
||||||
utils::Check(distributed_mode == 0,
|
tparam.dsplit = 2;
|
||||||
"prob_buffer_row can only be used in single node mode so far");
|
|
||||||
this->SetParam("updater", "grow_colmaker,refresh,prune");
|
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
|
||||||
if (!strcmp("seed", name)) {
|
if (cfg_.count("num_class") != 0) {
|
||||||
seed = atoi(val); random::Seed(seed);
|
cfg_["num_output_group"] = cfg_["num_class"];
|
||||||
}
|
}
|
||||||
if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val);
|
|
||||||
if (!strcmp("save_base64", name)) save_base64 = atoi(val);
|
if (cfg_.count("max_delta_step") == 0 &&
|
||||||
if (!strcmp(name, "num_class")) {
|
cfg_.count("objective") != 0 &&
|
||||||
this->SetParam("num_output_group", val);
|
cfg_["objective"] == "count:poisson") {
|
||||||
|
cfg_["max_delta_step"] = "0.7";
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "nthread")) {
|
|
||||||
omp_set_num_threads(atoi(val));
|
if (cfg_.count("updater") == 0) {
|
||||||
}
|
if (tparam.dsplit == 1) {
|
||||||
if (gbm_ == NULL) {
|
cfg_["updater"] = "distcol";
|
||||||
if (!strcmp(name, "objective")) name_obj_ = val;
|
} else if (tparam.dsplit == 2) {
|
||||||
if (!strcmp(name, "booster")) name_gbm_ = val;
|
cfg_["updater"] = "grow_histmaker,prune";
|
||||||
mparam.SetParam(name, val);
|
}
|
||||||
}
|
if (tparam.prob_buffer_row != 1.0f) {
|
||||||
if (gbm_ != NULL) gbm_->SetParam(name, val);
|
cfg_["updater"] = "grow_histmaker,refresh,prune";
|
||||||
if (obj_ != NULL) obj_->SetParam(name, val);
|
|
||||||
if (gbm_ == NULL || obj_ == NULL) {
|
|
||||||
cfg_.push_back(std::make_pair(std::string(name), std::string(val)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// this is an internal function
|
|
||||||
// initialize the trainer, called at InitModel and LoadModel
|
|
||||||
inline void InitTrainer(bool calc_num_feature = true) {
|
|
||||||
if (calc_num_feature) {
|
|
||||||
// estimate feature bound
|
|
||||||
unsigned num_feature = 0;
|
|
||||||
for (size_t i = 0; i < cache_.size(); ++i) {
|
|
||||||
num_feature = std::max(num_feature,
|
|
||||||
static_cast<unsigned>(cache_[i].mat_->info.num_col()));
|
|
||||||
}
|
}
|
||||||
// run allreduce on num_feature to find the maximum value
|
|
||||||
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
|
|
||||||
if (num_feature > mparam.num_feature) mparam.num_feature = num_feature;
|
|
||||||
}
|
}
|
||||||
char str_temp[25];
|
if (cfg_.count("objective") == 0) {
|
||||||
utils::SPrintf(str_temp, sizeof(str_temp), "%d", mparam.num_feature);
|
cfg_["objective"] = "reg:linear";
|
||||||
this->SetParam("bst:num_feature", str_temp);
|
}
|
||||||
|
if (cfg_.count("booster") == 0) {
|
||||||
|
cfg_["booster"] = "gbtree";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!this->ModelInitialized()) {
|
||||||
|
mparam.InitAllowUnknown(args);
|
||||||
|
name_obj_ = cfg_["objective"];
|
||||||
|
name_gbm_ = cfg_["booster"];
|
||||||
|
}
|
||||||
|
|
||||||
|
common::GlobalRandom().seed(tparam.seed);
|
||||||
|
|
||||||
|
// set number of features correctly.
|
||||||
|
cfg_["num_feature"] = ToString(mparam.num_feature);
|
||||||
|
if (gbm_.get() != nullptr) {
|
||||||
|
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
}
|
||||||
|
if (obj_.get() != nullptr) {
|
||||||
|
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief initialize the model
|
void Load(dmlc::Stream* fi) override {
|
||||||
*/
|
// TODO(tqchen) mark deprecation of old format.
|
||||||
inline void InitModel(void) {
|
common::PeekableInStream fp(fi);
|
||||||
this->InitTrainer();
|
// backward compatible header check.
|
||||||
// initialize model
|
std::string header;
|
||||||
this->InitObjGBM();
|
header.resize(4);
|
||||||
// reset the base score
|
if (fp.PeekRead(&header[0], 4) == 4) {
|
||||||
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
|
CHECK_NE(header, "bs64")
|
||||||
// initialize GBM model
|
<< "Base64 format is no longer supported in brick.";
|
||||||
gbm_->InitModel();
|
if (header == "binf") {
|
||||||
}
|
CHECK_EQ(fp.Read(&header[0], 4), 4);
|
||||||
/*!
|
}
|
||||||
* \brief load model from stream
|
}
|
||||||
* \param fi input stream
|
// use the peekable reader.
|
||||||
* \param calc_num_feature whether call InitTrainer with calc_num_feature
|
fi = &fp;
|
||||||
*/
|
std::string name_gbm, name_obj;
|
||||||
inline void LoadModel(utils::IStream &fi, // NOLINT(*)
|
// read parameter
|
||||||
bool calc_num_feature = true) {
|
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
||||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
<< "BoostLearner: wrong model format";
|
||||||
"BoostLearner: wrong model format");
|
|
||||||
{
|
{
|
||||||
// backward compatibility code for compatible with old model type
|
// backward compatibility code for compatible with old model type
|
||||||
// for new model, Read(&name_obj_) is suffice
|
// for new model, Read(&name_obj_) is suffice
|
||||||
uint64_t len;
|
uint64_t len;
|
||||||
utils::Check(fi.Read(&len, sizeof(len)) != 0, "BoostLearner: wrong model format");
|
CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len));
|
||||||
if (len >= std::numeric_limits<unsigned>::max()) {
|
if (len >= std::numeric_limits<unsigned>::max()) {
|
||||||
int gap;
|
int gap;
|
||||||
utils::Check(fi.Read(&gap, sizeof(gap)) != 0, "BoostLearner: wrong model format");
|
CHECK_EQ(fi->Read(&gap, sizeof(gap)), sizeof(gap))
|
||||||
|
<< "BoostLearner: wrong model format";
|
||||||
len = len >> static_cast<uint64_t>(32UL);
|
len = len >> static_cast<uint64_t>(32UL);
|
||||||
}
|
}
|
||||||
if (len != 0) {
|
if (len != 0) {
|
||||||
name_obj_.resize(len);
|
name_obj.resize(len);
|
||||||
utils::Check(fi.Read(&name_obj_[0], len) != 0, "BoostLearner: wrong model format");
|
CHECK_EQ(fi->Read(&name_obj_[0], len), len)
|
||||||
|
<<"BoostLearner: wrong model format";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::Check(fi.Read(&name_gbm_), "BoostLearner: wrong model format");
|
CHECK(fi->Read(&name_gbm_))
|
||||||
// delete existing gbm if any
|
<< "BoostLearner: wrong model format";
|
||||||
if (obj_ != NULL) delete obj_;
|
// duplicated code with LazyInitModel
|
||||||
if (gbm_ != NULL) delete gbm_;
|
obj_.reset(ObjFunction::Create(cfg_.at(name_obj_)));
|
||||||
this->InitTrainer(calc_num_feature);
|
gbm_.reset(GradientBooster::Create(cfg_.at(name_gbm_)));
|
||||||
this->InitObjGBM();
|
if (metrics_.size() == 0) {
|
||||||
char tmp[32];
|
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||||
utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class);
|
|
||||||
obj_->SetParam("num_class", tmp);
|
|
||||||
gbm_->LoadModel(fi, mparam.saved_with_pbuffer != 0);
|
|
||||||
if (mparam.saved_with_pbuffer == 0) {
|
|
||||||
gbm_->ResetPredBuffer(pred_buffer_size);
|
|
||||||
}
|
}
|
||||||
|
this->base_score_ = mparam.base_score;
|
||||||
|
gbm_->ResetPredBuffer(pred_buffer_size_);
|
||||||
|
cfg_["num_class"] = ToString(mparam.num_class);
|
||||||
|
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
// rabit load model from rabit checkpoint
|
|
||||||
virtual void Load(rabit::Stream *fi) {
|
|
||||||
// for row split, we should not keep pbuffer
|
|
||||||
this->LoadModel(*fi, false);
|
|
||||||
}
|
|
||||||
// rabit save model to rabit checkpoint
|
// rabit save model to rabit checkpoint
|
||||||
virtual void Save(rabit::Stream *fo) const {
|
void Save(dmlc::Stream *fo) const override {
|
||||||
// for row split, we should not keep pbuffer
|
fo->Write(&mparam, sizeof(LearnerModelParam));
|
||||||
this->SaveModel(*fo, distributed_mode != 2);
|
fo->Write(name_obj_);
|
||||||
|
fo->Write(name_gbm_);
|
||||||
|
gbm_->Save(fo);
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief load model from file
|
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||||
* \param fname file name
|
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
|
||||||
*/
|
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
|
||||||
inline void LoadModel(const char *fname) {
|
|
||||||
utils::IStream *fi = utils::IStream::Create(fname, "r");
|
|
||||||
std::string header; header.resize(4);
|
|
||||||
// check header for different binary encode
|
|
||||||
// can be base64 or binary
|
|
||||||
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
|
|
||||||
// base64 format
|
|
||||||
if (header == "bs64") {
|
|
||||||
utils::Base64InStream bsin(fi);
|
|
||||||
bsin.InitPosition();
|
|
||||||
this->LoadModel(bsin, true);
|
|
||||||
} else if (header == "binf") {
|
|
||||||
this->LoadModel(*fi, true);
|
|
||||||
} else {
|
|
||||||
delete fi;
|
|
||||||
fi = utils::IStream::Create(fname, "r");
|
|
||||||
this->LoadModel(*fi, true);
|
|
||||||
}
|
|
||||||
delete fi;
|
|
||||||
}
|
|
||||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer) const { // NOLINT(*)
|
|
||||||
ModelParam p = mparam;
|
|
||||||
p.saved_with_pbuffer = static_cast<int>(with_pbuffer);
|
|
||||||
fo.Write(&p, sizeof(ModelParam));
|
|
||||||
fo.Write(name_obj_);
|
|
||||||
fo.Write(name_gbm_);
|
|
||||||
gbm_->SaveModel(fo, with_pbuffer);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief save model into file
|
|
||||||
* \param fname file name
|
|
||||||
* \param with_pbuffer whether save pbuffer together
|
|
||||||
*/
|
|
||||||
inline void SaveModel(const char *fname, bool with_pbuffer) const {
|
|
||||||
utils::IStream *fo = utils::IStream::Create(fname, "w");
|
|
||||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
|
||||||
fo->Write("bs64\t", 5);
|
|
||||||
utils::Base64OutStream bout(fo);
|
|
||||||
this->SaveModel(bout, with_pbuffer);
|
|
||||||
bout.Finish('\n');
|
|
||||||
} else {
|
|
||||||
fo->Write("binf", 4);
|
|
||||||
this->SaveModel(*fo, with_pbuffer);
|
|
||||||
}
|
|
||||||
delete fo;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief check if data matrix is ready to be used by training,
|
|
||||||
* if not initialize it
|
|
||||||
* \param p_train pointer to the matrix used by training
|
|
||||||
*/
|
|
||||||
inline void CheckInit(DMatrix *p_train) {
|
|
||||||
int ncol = static_cast<int>(p_train->info.info.num_col);
|
|
||||||
std::vector<bool> enabled(ncol, true);
|
|
||||||
// set max row per batch to limited value
|
|
||||||
// in distributed mode, use safe choice otherwise
|
|
||||||
size_t max_row_perbatch = std::numeric_limits<size_t>::max();
|
|
||||||
if (updater_mode != 0 || distributed_mode == 2) {
|
|
||||||
max_row_perbatch = 32UL << 10UL;
|
|
||||||
}
|
|
||||||
// initialize column access
|
|
||||||
p_train->fmat()->InitColAccess(enabled,
|
|
||||||
prob_buffer_row,
|
|
||||||
max_row_perbatch);
|
|
||||||
const int kMagicPage = 0xffffab02;
|
|
||||||
// check, if it is DMatrixPage, then use hist maker
|
|
||||||
if (p_train->magic == kMagicPage) {
|
|
||||||
this->SetParam("updater", "grow_histmaker,prune");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief update the model for one iteration
|
|
||||||
* \param iter current iteration number
|
|
||||||
* \param train reference to the data matrix
|
|
||||||
*/
|
|
||||||
inline void UpdateOneIter(int iter, const DMatrix &train) {
|
|
||||||
if (seed_per_iteration != 0 || rabit::IsDistributed()) {
|
|
||||||
random::Seed(this->seed * kRandSeedMagic + iter);
|
|
||||||
}
|
}
|
||||||
|
this->LazyInitDMatrix(train);
|
||||||
|
this->LazyInitModel();
|
||||||
this->PredictRaw(train, &preds_);
|
this->PredictRaw(train, &preds_);
|
||||||
obj_->GetGradient(preds_, train.info, iter, &gpair_);
|
obj_->GetGradient(preds_, train->info(), iter, &gpair_);
|
||||||
gbm_->DoBoost(train.fmat(), this->FindBufferOffset(train), train.info.info, &gpair_);
|
gbm_->DoBoost(train, this->FindBufferOffset(train), &gpair_);
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief whether model allow lazy checkpoint
|
void BoostOneIter(int iter,
|
||||||
*/
|
DMatrix* train,
|
||||||
inline bool AllowLazyCheckPoint(void) const {
|
std::vector<bst_gpair>* in_gpair) override {
|
||||||
return gbm_->AllowLazyCheckPoint();
|
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
|
||||||
}
|
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
|
||||||
/*!
|
|
||||||
* \brief evaluate the model for specific iteration
|
|
||||||
* \param iter iteration number
|
|
||||||
* \param evals datas i want to evaluate
|
|
||||||
* \param evname name of each dataset
|
|
||||||
* \return a string corresponding to the evaluation result
|
|
||||||
*/
|
|
||||||
inline std::string EvalOneIter(int iter,
|
|
||||||
const std::vector<const DMatrix*> &evals,
|
|
||||||
const std::vector<std::string> &evname) {
|
|
||||||
std::string res;
|
|
||||||
char tmp[256];
|
|
||||||
utils::SPrintf(tmp, sizeof(tmp), "[%d]", iter);
|
|
||||||
res = tmp;
|
|
||||||
for (size_t i = 0; i < evals.size(); ++i) {
|
|
||||||
this->PredictRaw(*evals[i], &preds_);
|
|
||||||
obj_->EvalTransform(&preds_);
|
|
||||||
res += evaluator_.Eval(evname[i].c_str(), preds_, evals[i]->info, distributed_mode == 2);
|
|
||||||
}
|
}
|
||||||
return res;
|
gbm_->DoBoost(train, this->FindBufferOffset(train), in_gpair);
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief simple evaluation function, with a specified metric
|
std::string EvalOneIter(int iter,
|
||||||
* \param data input data
|
const std::vector<DMatrix*>& data_sets,
|
||||||
* \param metric name of metric
|
const std::vector<std::string>& data_names) override {
|
||||||
* \return a pair of <evaluation name, result>
|
std::ostringstream os;
|
||||||
*/
|
os << '[' << iter << ']';
|
||||||
std::pair<std::string, float> Evaluate(const DMatrix &data, std::string metric) {
|
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||||
|
this->PredictRaw(data_sets[i], &preds_);
|
||||||
|
obj_->EvalTransform(&preds_);
|
||||||
|
for (auto& ev : metrics_) {
|
||||||
|
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||||
|
<< ev->Eval(preds_, data_sets[i]->info(), tparam.dsplit == 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
|
||||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||||
IEvaluator *ev = CreateEvaluator(metric.c_str());
|
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||||
this->PredictRaw(data, &preds_);
|
this->PredictRaw(data, &preds_);
|
||||||
obj_->EvalTransform(&preds_);
|
obj_->EvalTransform(&preds_);
|
||||||
float res = ev->Eval(preds_, data.info);
|
return std::make_pair(metric, ev->Eval(preds_, data->info(), tparam.dsplit == 2));
|
||||||
delete ev;
|
|
||||||
return std::make_pair(metric, res);
|
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief get prediction
|
void Predict(DMatrix* data,
|
||||||
* \param data input data
|
bool output_margin,
|
||||||
* \param output_margin whether to only predict margin value instead of transformed prediction
|
std::vector<float> *out_preds,
|
||||||
* \param out_preds output vector that stores the prediction
|
unsigned ntree_limit,
|
||||||
* \param ntree_limit limit number of trees used for boosted tree
|
bool pred_leaf) const override {
|
||||||
* predictor, when it equals 0, this means we are using all the trees
|
|
||||||
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
|
|
||||||
*/
|
|
||||||
inline void Predict(const DMatrix &data,
|
|
||||||
bool output_margin,
|
|
||||||
std::vector<float> *out_preds,
|
|
||||||
unsigned ntree_limit = 0,
|
|
||||||
bool pred_leaf = false) const {
|
|
||||||
if (pred_leaf) {
|
if (pred_leaf) {
|
||||||
gbm_->PredictLeaf(data.fmat(), data.info.info, out_preds, ntree_limit);
|
gbm_->PredictLeaf(data, out_preds, ntree_limit);
|
||||||
} else {
|
} else {
|
||||||
this->PredictRaw(data, out_preds, ntree_limit);
|
this->PredictRaw(data, out_preds, ntree_limit);
|
||||||
if (!output_margin) {
|
if (!output_margin) {
|
||||||
@ -368,63 +303,65 @@ class BoostLearner : public rabit::Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief online prediction function, predict score for one instance at a time
|
|
||||||
* NOTE: use the batch prediction interface if possible, batch prediction is usually
|
|
||||||
* more efficient than online prediction
|
|
||||||
* This function is NOT threadsafe, make sure you only call from one thread
|
|
||||||
*
|
|
||||||
* \param inst the instance you want to predict
|
|
||||||
* \param output_margin whether to only predict margin value instead of transformed prediction
|
|
||||||
* \param out_preds output vector to hold the predictions
|
|
||||||
* \param ntree_limit limit the number of trees used in prediction
|
|
||||||
* \sa Predict
|
|
||||||
*/
|
|
||||||
inline void Predict(const SparseBatch::Inst &inst,
|
|
||||||
bool output_margin,
|
|
||||||
std::vector<float> *out_preds,
|
|
||||||
unsigned ntree_limit = 0) const {
|
|
||||||
gbm_->Predict(inst, out_preds, ntree_limit);
|
|
||||||
if (out_preds->size() == 1) {
|
|
||||||
(*out_preds)[0] += mparam.base_score;
|
|
||||||
}
|
|
||||||
if (!output_margin) {
|
|
||||||
obj_->PredTransform(out_preds);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*! \brief dump model out */
|
|
||||||
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
|
||||||
return gbm_->DumpModel(fmap, option);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/*!
|
// check if p_train is ready to used by training.
|
||||||
* \brief initialize the objective function and GBM,
|
// if not, initialize the column access.
|
||||||
* if not yet done
|
inline void LazyInitDMatrix(DMatrix *p_train) {
|
||||||
*/
|
if (p_train->HaveColAccess()) return;
|
||||||
inline void InitObjGBM(void) {
|
int ncol = static_cast<int>(p_train->info().num_col);
|
||||||
if (obj_ != NULL) return;
|
std::vector<bool> enabled(ncol, true);
|
||||||
utils::Assert(gbm_ == NULL, "GBM and obj should be NULL");
|
// set max row per batch to limited value
|
||||||
obj_ = CreateObjFunction(name_obj_.c_str());
|
// in distributed mode, use safe choice otherwise
|
||||||
gbm_ = gbm::CreateGradBooster(name_gbm_.c_str());
|
size_t max_row_perbatch = std::numeric_limits<size_t>::max();
|
||||||
this->InitAdditionDefaultParam();
|
if (tparam.test_flag == "block" || tparam.dsplit == 2) {
|
||||||
// set parameters
|
max_row_perbatch = 32UL << 10UL;
|
||||||
for (size_t i = 0; i < cfg_.size(); ++i) {
|
|
||||||
obj_->SetParam(cfg_[i].first.c_str(), cfg_[i].second.c_str());
|
|
||||||
gbm_->SetParam(cfg_[i].first.c_str(), cfg_[i].second.c_str());
|
|
||||||
}
|
}
|
||||||
if (evaluator_.Size() == 0) {
|
// initialize column access
|
||||||
evaluator_.AddEval(obj_->DefaultEvalMetric());
|
p_train->InitColAccess(enabled,
|
||||||
|
tparam.prob_buffer_row,
|
||||||
|
max_row_perbatch);
|
||||||
|
if (!p_train->SingleColBlock() && cfg_.count("updater") == 0) {
|
||||||
|
cfg_["updater"] = "grow_histmaker,prune";
|
||||||
|
if (gbm_.get() != nullptr) {
|
||||||
|
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
|
||||||
* \brief additional default value for specific objs
|
// return whether model is already initialized.
|
||||||
*/
|
inline bool ModelInitialized() const {
|
||||||
inline void InitAdditionDefaultParam(void) {
|
return gbm_.get() != nullptr;
|
||||||
if (name_obj_ == "count:poisson") {
|
}
|
||||||
obj_->SetParam("max_delta_step", "0.7");
|
// lazily initialize the model if it haven't yet been initialized.
|
||||||
gbm_->SetParam("max_delta_step", "0.7");
|
inline void LazyInitModel() {
|
||||||
|
if (this->ModelInitialized()) return;
|
||||||
|
// estimate feature bound
|
||||||
|
unsigned num_feature = 0;
|
||||||
|
for (size_t i = 0; i < cache_.size(); ++i) {
|
||||||
|
num_feature = std::max(num_feature,
|
||||||
|
static_cast<unsigned>(cache_[i].mat_->info().num_col));
|
||||||
}
|
}
|
||||||
|
// run allreduce on num_feature to find the maximum value
|
||||||
|
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
|
||||||
|
if (num_feature > mparam.num_feature) {
|
||||||
|
mparam.num_feature = num_feature;
|
||||||
|
}
|
||||||
|
// reset the base score
|
||||||
|
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
|
||||||
|
|
||||||
|
// setup
|
||||||
|
cfg_["num_feature"] = ToString(mparam.num_feature);
|
||||||
|
CHECK(obj_.get() == nullptr && gbm_.get() == nullptr);
|
||||||
|
obj_.reset(ObjFunction::Create(name_obj_));
|
||||||
|
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||||
|
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
if (metrics_.size() == 0) {
|
||||||
|
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||||
|
}
|
||||||
|
this->base_score_ = mparam.base_score;
|
||||||
|
gbm_->ResetPredBuffer(pred_buffer_size_);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief get un-transformed prediction
|
* \brief get un-transformed prediction
|
||||||
@ -433,125 +370,76 @@ class BoostLearner : public rabit::Serializable {
|
|||||||
* \param ntree_limit limit number of trees used for boosted tree
|
* \param ntree_limit limit number of trees used for boosted tree
|
||||||
* predictor, when it equals 0, this means we are using all the trees
|
* predictor, when it equals 0, this means we are using all the trees
|
||||||
*/
|
*/
|
||||||
inline void PredictRaw(const DMatrix &data,
|
inline void PredictRaw(DMatrix* data,
|
||||||
std::vector<float> *out_preds,
|
std::vector<float>* out_preds,
|
||||||
unsigned ntree_limit = 0) const {
|
unsigned ntree_limit = 0) const {
|
||||||
gbm_->Predict(data.fmat(), this->FindBufferOffset(data),
|
gbm_->Predict(data,
|
||||||
data.info.info, out_preds, ntree_limit);
|
this->FindBufferOffset(data),
|
||||||
|
out_preds,
|
||||||
|
ntree_limit);
|
||||||
// add base margin
|
// add base margin
|
||||||
std::vector<float> &preds = *out_preds;
|
std::vector<float>& preds = *out_preds;
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||||
if (data.info.base_margin.size() != 0) {
|
const std::vector<bst_float>& base_margin = data->info().base_margin;
|
||||||
utils::Check(preds.size() == data.info.base_margin.size(),
|
if (base_margin.size() != 0) {
|
||||||
"base_margin.size does not match with prediction size");
|
CHECK_EQ(preds.size(), base_margin.size())
|
||||||
|
<< "base_margin.size does not match with prediction size";
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||||
preds[j] += data.info.base_margin[j];
|
preds[j] += base_margin[j];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||||
preds[j] += mparam.base_score;
|
preds[j] += this->base_score_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief training parameter for regression */
|
|
||||||
struct ModelParam{
|
|
||||||
/* \brief global bias */
|
|
||||||
float base_score;
|
|
||||||
/* \brief number of features */
|
|
||||||
unsigned num_feature;
|
|
||||||
/* \brief number of classes, if it is multi-class classification */
|
|
||||||
int num_class;
|
|
||||||
/*! \brief whether the model itself is saved with pbuffer */
|
|
||||||
int saved_with_pbuffer;
|
|
||||||
/*! \brief reserved field */
|
|
||||||
int reserved[30];
|
|
||||||
/*! \brief constructor */
|
|
||||||
ModelParam(void) {
|
|
||||||
std::memset(this, 0, sizeof(ModelParam));
|
|
||||||
base_score = 0.5f;
|
|
||||||
num_feature = 0;
|
|
||||||
num_class = 0;
|
|
||||||
saved_with_pbuffer = 0;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set parameters from outside
|
|
||||||
* \param name name of the parameter
|
|
||||||
* \param val value of the parameter
|
|
||||||
*/
|
|
||||||
inline void SetParam(const char *name, const char *val) {
|
|
||||||
using namespace std;
|
|
||||||
if (!strcmp("base_score", name)) base_score = static_cast<float>(atof(val));
|
|
||||||
if (!strcmp("num_class", name)) num_class = atoi(val);
|
|
||||||
if (!strcmp("bst:num_feature", name)) num_feature = atoi(val);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// data fields
|
|
||||||
// stored random seed
|
|
||||||
int seed;
|
|
||||||
// whether seed the PRNG each iteration
|
|
||||||
// this is important for restart from existing iterations
|
|
||||||
// default set to no, but will auto switch on in distributed mode
|
|
||||||
int seed_per_iteration;
|
|
||||||
// save model in base64 encoding
|
|
||||||
int save_base64;
|
|
||||||
// silent during training
|
|
||||||
int silent;
|
|
||||||
// distributed learning mode, if any, 0:none, 1:col, 2:row
|
|
||||||
int distributed_mode;
|
|
||||||
// updater mode, 0:normal, reserved for internal test
|
|
||||||
int updater_mode;
|
|
||||||
// cached size of predict buffer
|
// cached size of predict buffer
|
||||||
size_t pred_buffer_size;
|
size_t pred_buffer_size_;
|
||||||
// maximum buffered row value
|
|
||||||
float prob_buffer_row;
|
|
||||||
// evaluation set
|
|
||||||
EvalSet evaluator_;
|
|
||||||
// model parameter
|
// model parameter
|
||||||
ModelParam mparam;
|
LearnerModelParam mparam;
|
||||||
// gbm model that back everything
|
// training parameter
|
||||||
gbm::IGradBooster *gbm_;
|
LearnerTrainParam tparam;
|
||||||
// name of gbm model used for training
|
|
||||||
std::string name_gbm_;
|
|
||||||
// objective function
|
|
||||||
IObjFunction *obj_;
|
|
||||||
// name of objective function
|
|
||||||
std::string name_obj_;
|
|
||||||
// configurations
|
// configurations
|
||||||
std::vector< std::pair<std::string, std::string> > cfg_;
|
std::map<std::string, std::string> cfg_;
|
||||||
|
// name of gbm
|
||||||
|
std::string name_gbm_;
|
||||||
|
// name of objective functon
|
||||||
|
std::string name_obj_;
|
||||||
// temporal storages for prediction
|
// temporal storages for prediction
|
||||||
std::vector<float> preds_;
|
std::vector<float> preds_;
|
||||||
// gradient pairs
|
// gradient pairs
|
||||||
std::vector<bst_gpair> gpair_;
|
std::vector<bst_gpair> gpair_;
|
||||||
|
|
||||||
protected:
|
private:
|
||||||
// magic number to transform random seed
|
/*! \brief random number transformation seed. */
|
||||||
static const int kRandSeedMagic = 127;
|
static const int kRandSeedMagic = 127;
|
||||||
// cache entry object that helps handle feature caching
|
// cache entry object that helps handle feature caching
|
||||||
struct CacheEntry {
|
struct CacheEntry {
|
||||||
const DMatrix *mat_;
|
const DMatrix* mat_;
|
||||||
size_t buffer_offset_;
|
size_t buffer_offset_;
|
||||||
size_t num_row_;
|
size_t num_row_;
|
||||||
CacheEntry(const DMatrix *mat, size_t buffer_offset, size_t num_row)
|
CacheEntry(const DMatrix* mat, size_t buffer_offset, size_t num_row)
|
||||||
:mat_(mat), buffer_offset_(buffer_offset), num_row_(num_row) {}
|
:mat_(mat), buffer_offset_(buffer_offset), num_row_(num_row) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// find internal buffer offset for certain matrix, if not exist, return -1
|
// find internal buffer offset for certain matrix, if not exist, return -1
|
||||||
inline int64_t FindBufferOffset(const DMatrix &mat) const {
|
inline int64_t FindBufferOffset(const DMatrix* mat) const {
|
||||||
for (size_t i = 0; i < cache_.size(); ++i) {
|
for (size_t i = 0; i < cache_.size(); ++i) {
|
||||||
if (cache_[i].mat_ == &mat && mat.cache_learner_ptr_ == this) {
|
if (cache_[i].mat_ == mat && mat->cache_learner_ptr_ == this) {
|
||||||
if (cache_[i].num_row_ == mat.info.num_row()) {
|
if (cache_[i].num_row_ == mat->info().num_row) {
|
||||||
return static_cast<int64_t>(cache_[i].buffer_offset_);
|
return static_cast<int64_t>(cache_[i].buffer_offset_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
// data structure field
|
|
||||||
/*! \brief the entries indicates that we have internal prediction cache */
|
/*! \brief the entries indicates that we have internal prediction cache */
|
||||||
std::vector<CacheEntry> cache_;
|
std::vector<CacheEntry> cache_;
|
||||||
};
|
};
|
||||||
} // namespace learner
|
|
||||||
|
Learner* Learner::Create(const std::vector<DMatrix*>& cache_data) {
|
||||||
|
return new LearnerImpl(cache_data);
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_LEARNER_LEARNER_INL_HPP_
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
explicit SoftmaxMultiClassObj(bool output_prob)
|
explicit SoftmaxMultiClassObj(bool output_prob)
|
||||||
: output_prob_(output_prob) {
|
: output_prob_(output_prob) {
|
||||||
}
|
}
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
}
|
}
|
||||||
void GetGradient(const std::vector<float>& preds,
|
void GetGradient(const std::vector<float>& preds,
|
||||||
|
|||||||
@ -32,7 +32,7 @@ struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
|
|||||||
// objective for lambda rank
|
// objective for lambda rank
|
||||||
class LambdaRankObj : public ObjFunction {
|
class LambdaRankObj : public ObjFunction {
|
||||||
public:
|
public:
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
}
|
}
|
||||||
void GetGradient(const std::vector<float>& preds,
|
void GetGradient(const std::vector<float>& preds,
|
||||||
|
|||||||
@ -76,7 +76,7 @@ struct RegLossParam : public dmlc::Parameter<RegLossParam> {
|
|||||||
template<typename Loss>
|
template<typename Loss>
|
||||||
class RegLossObj : public ObjFunction {
|
class RegLossObj : public ObjFunction {
|
||||||
public:
|
public:
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
}
|
}
|
||||||
void GetGradient(const std::vector<float> &preds,
|
void GetGradient(const std::vector<float> &preds,
|
||||||
@ -155,7 +155,7 @@ struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
|
|||||||
class PoissonRegression : public ObjFunction {
|
class PoissonRegression : public ObjFunction {
|
||||||
public:
|
public:
|
||||||
// declare functions
|
// declare functions
|
||||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user