Use bst_float consistently throughout (#1824)

* Fix various typos

* Add override to functions that are overridden

gcc gives warnings about functions that override virtual functions but are not
marked with the override specifier. This fixes those warnings.
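
As a minimal sketch of the pattern being fixed (the class and method names here
are hypothetical, not taken from the diff; the warning in question is e.g. gcc's
-Wsuggest-override):

#include <vector>

// Hypothetical base class with a virtual prediction method.
class BoosterBase {
 public:
  virtual ~BoosterBase() = default;
  virtual void Predict(const std::vector<float>& inst) = 0;
};

// Marking the overriding function with 'override' documents the intent and
// lets the compiler reject accidental signature mismatches; without it,
// gcc emits a warning when -Wsuggest-override is enabled.
class LinearBooster : public BoosterBase {
 public:
  void Predict(const std::vector<float>& inst) override {
    // ... prediction logic would go here ...
  }
};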

* Use bst_float consistently

Use bst_float for all variables that hold weights, leaf values, gradients,
hessians, gains, loss_chg, predictions, base_margin, and feature values.

In cases where additions and similar accumulations can produce larger
values, double is used instead.

This keeps type conversions to a minimum and reduces loss of precision.
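
A minimal sketch of the convention, assuming the usual single-precision typedef
from include/xgboost/base.h; the WeightedMean helper below is illustrative only,
not a function from this commit:

#include <cstddef>
#include <vector>

namespace xgboost {
/*! \brief float type used for weights, gradients, hessians, predictions, ... */
typedef float bst_float;
}  // namespace xgboost

// Per-element quantities stay in bst_float; sums over many elements are
// accumulated in double to limit round-off, then narrowed back once.
xgboost::bst_float WeightedMean(const std::vector<xgboost::bst_float>& values,
                                const std::vector<xgboost::bst_float>& weights) {
  double sum = 0.0, wsum = 0.0;
  for (std::size_t i = 0; i < values.size(); ++i) {
    sum += static_cast<double>(values[i]) * weights[i];
    wsum += weights[i];
  }
  return static_cast<xgboost::bst_float>(sum / wsum);
}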
AbdealiJK 2016-11-30 23:32:10 +05:30 committed by Tianqi Chen
parent da2556f58a
commit 6f16f0ef58
50 changed files with 392 additions and 389 deletions

View File

@ -2,8 +2,8 @@ Contributors of DMLC/XGBoost
============================
XGBoost has been developed and used by a group of active community. Everyone is more than welcomed to is a great way to make the project better and more accessible to more users.
Comitters
---------
Committers
----------
Committers are people who have made substantial contribution to the project and granted write access to the project.
* [Tianqi Chen](https://github.com/tqchen), University of Washington
- Tianqi is a PhD working on large-scale machine learning, he is the creator of the project.
@ -16,14 +16,14 @@ Committers are people who have made substantial contribution to the project and
* [Yuan Tang](https://github.com/terrytangyuan)
- Yuan is a data scientist in Chicago, US. He contributed mostly in R and Python packages.
Become a Comitter
-----------------
XGBoost is a opensource project and we are actively looking for new comitters who are willing to help maintaining and lead the project.
Become a Committer
------------------
XGBoost is a opensource project and we are actively looking for new committers who are willing to help maintaining and lead the project.
Committers comes from contributors who:
* Made substantial contribution to the project.
* Willing to spent time on maintaining and lead the project.
New committers will be proposed by current comitter memembers, with support from more than two of current comitters.
New committers will be proposed by current committer members, with support from more than two of current committers.
List of Contributors
--------------------
@ -44,7 +44,7 @@ List of Contributors
* [Giulio](https://github.com/giuliohome)
- Giulio is the creator of windows project of xgboost
* [Jamie Hall](https://github.com/nerdcha)
- Jamie is the initial creator of xgboost sklearn modue.
- Jamie is the initial creator of xgboost sklearn module.
* [Yen-Ying Lee](https://github.com/white1033)
* [Masaaki Horikoshi](https://github.com/sinhrks)
- Masaaki is the initial creator of xgboost python plotting module.

View File

@ -6,7 +6,7 @@
"source": [
"# XGBoost Model Analysis\n",
"\n",
"This notebook can be used to load and anlysis model learnt from all xgboost bindings, including distributed training. "
"This notebook can be used to load and analysis model learnt from all xgboost bindings, including distributed training. "
]
},
{

View File

@ -27,9 +27,9 @@ def logregobj(preds, dtrain):
# user defined evaluation function, return a pair metric_name, result
# NOTE: when you do customized loss function, the default prediction value is margin
# this may make buildin evalution metric not function properly
# this may make builtin evaluation metric not function properly
# for example, we are doing logistic loss, the prediction is score before logistic transformation
# the buildin evaluation error assumes input is after logistic transformation
# the builtin evaluation error assumes input is after logistic transformation
# Take this in mind when you use the customization, and maybe you need write customized evaluation function
def evalerror(preds, dtrain):
labels = dtrain.get_label()

View File

@ -44,7 +44,7 @@ param['nthread'] = 16
plst = list(param.items())+[('eval_metric', 'ams@0.15')]
watchlist = [ (xgmat,'train') ]
# boost 120 tres
# boost 120 trees
num_round = 120
print ('loading data end, start to boost trees')
bst = xgb.train( plst, xgmat, num_round, watchlist );

View File

@ -42,7 +42,7 @@ param['nthread'] = 4
plst = param.items()+[('eval_metric', 'ams@0.15')]
watchlist = [ (xgmat,'train') ]
# boost 10 tres
# boost 10 trees
num_round = 10
print ('loading data end, start to boost trees')
print ("training GBM from sklearn")

View File

@ -22,7 +22,7 @@ param <- list("objective" = "multi:softprob",
"num_class" = 9,
"nthread" = 8)
# Run Cross Valication
# Run Cross Validation
cv.nround = 50
bst.cv = xgb.cv(param=param, data = x[trind,], label = y,
nfold = 3, nrounds=cv.nround)

View File

@ -16,7 +16,7 @@ Introduction
While XGBoost is known for its fast speed and accurate predictive power, it also comes with various functions to help you understand the model.
The purpose of this RMarkdown document is to demonstrate how easily we can leverage the functions already implemented in **XGBoost R** package. Of course, everything showed below can be applied to the dataset you may have to manipulate at work or wherever!
First we will prepare the **Otto** dataset and train a model, then we will generate two vizualisations to get a clue of what is important to the model, finally, we will see how we can leverage these information.
First we will prepare the **Otto** dataset and train a model, then we will generate two visualisations to get a clue of what is important to the model, finally, we will see how we can leverage these information.
Preparation of the data
=======================

View File

@ -42,7 +42,7 @@
/*! \brief namespace of xgboost*/
namespace xgboost {
/*!
* \brief unsigned interger type used in boost,
* \brief unsigned integer type used in boost,
* used for feature index and row index.
*/
typedef uint32_t bst_uint;
@ -62,7 +62,7 @@ struct bst_gpair {
};
/*! \brief small eps gap for minimum split decision. */
const float rt_eps = 1e-6f;
const bst_float rt_eps = 1e-6f;
/*! \brief define unsigned long for openmp loop */
typedef dmlc::omp_ulong omp_ulong;

View File

@ -23,9 +23,10 @@ XGB_EXTERN_C {
#define XGB_DLL XGB_EXTERN_C
#endif
// manually define unsign long
// manually define unsigned long
typedef uint64_t bst_ulong; // NOLINT(*)
/*! \brief handle to DMatrix */
typedef void *DMatrixHandle;
/*! \brief handle to Booster */
@ -86,11 +87,11 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext(
* \brief get string message of the last error
*
* all function in this file will return 0 when success
* and -1 when an error occured,
* and -1 when an error occurred,
* XGBGetLastError can be called to retrieve the error
*
* this function is thread safe and can be called by different thread
* \return const char* error inforomation
* \return const char* error information
*/
XGB_DLL const char *XGBGetLastError();
@ -124,7 +125,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
* \param out created dmatrix
@ -143,7 +144,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param out created dmatrix
* \return 0 when success, -1 when failure happens
@ -159,7 +160,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(const bst_ulong *indptr,
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param out created dmatrix
@ -178,7 +179,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param out created dmatrix
* \return 0 when success, -1 when failure happens

View File

@ -65,7 +65,7 @@ struct MetaInfo {
* \param i Instance index.
* \return The weight.
*/
inline float GetWeight(size_t i) const {
inline bst_float GetWeight(size_t i) const {
return weights.size() != 0 ? weights[i] : 1.0f;
}
/*!
@ -253,7 +253,7 @@ class DMatrix {
* \brief check if column access is supported, if not, initialize column access.
* \param enabled whether certain feature should be included in column access.
* \param subsample subsample ratio when generating column access.
* \param max_row_perbatch auxilary information, maximum row used in each column batch.
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \return Number of column blocks in the column access.
*/
@ -304,7 +304,7 @@ class DMatrix {
static DMatrix* Create(std::unique_ptr<DataSource>&& source,
const std::string& cache_prefix = "");
/*!
* \brief Create a DMatrix by loaidng data from parser.
* \brief Create a DMatrix by loading data from parser.
* Parser can later be deleted after the DMatrix i created.
* \param parser The input data parser
* \param cache_prefix The path to prefix of temporary cache file of the DMatrix when used in external memory mode.

View File

@ -78,7 +78,7 @@ class GradientBooster {
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void Predict(DMatrix* dmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
* \brief online prediction function, predict score for one instance at a time
@ -93,7 +93,7 @@ class GradientBooster {
* \sa Predict
*/
virtual void Predict(const SparseBatch::Inst& inst,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0,
unsigned root_index = 0) = 0;
/*!
@ -105,7 +105,7 @@ class GradientBooster {
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void PredictLeaf(DMatrix* dmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
* \brief dump the model in the requested format
@ -127,7 +127,7 @@ class GradientBooster {
static GradientBooster* Create(
const std::string& name,
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
float base_margin);
bst_float base_margin);
};
// implementing configure.
@ -144,7 +144,7 @@ struct GradientBoosterReg
: public dmlc::FunctionRegEntryBase<
GradientBoosterReg,
std::function<GradientBooster* (const std::vector<std::shared_ptr<DMatrix> > &cached_mats,
float base_margin)> > {
bst_float base_margin)> > {
};
/*!

View File

@ -106,7 +106,7 @@ class Learner : public rabit::Serializable {
*/
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const = 0;
/*!
@ -162,7 +162,7 @@ class Learner : public rabit::Serializable {
*/
inline void Predict(const SparseBatch::Inst &inst,
bool output_margin,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit = 0) const;
/*!
* \brief Create a new instance of learner.
@ -185,7 +185,7 @@ class Learner : public rabit::Serializable {
// implementation of inline functions.
inline void Learner::Predict(const SparseBatch::Inst& inst,
bool output_margin,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit) const {
gbm_->Predict(inst, out_preds, ntree_limit);
if (out_preds->size() == 1) {

View File

@ -29,7 +29,7 @@ class Metric {
* the average statistics across all the node,
* this is only supported by some metrics
*/
virtual float Eval(const std::vector<float>& preds,
virtual bst_float Eval(const std::vector<bst_float>& preds,
const MetaInfo& info,
bool distributed) const = 0;
/*! \return name of metric */

View File

@ -41,7 +41,7 @@ class ObjFunction {
* \param iteration current iteration number.
* \param out_gpair output of get gradient, saves gradient and second order gradient in
*/
virtual void GetGradient(const std::vector<float>& preds,
virtual void GetGradient(const std::vector<bst_float>& preds,
const MetaInfo& info,
int iteration,
std::vector<bst_gpair>* out_gpair) = 0;
@ -52,13 +52,13 @@ class ObjFunction {
* \brief transform prediction values, this is only called when Prediction is called
* \param io_preds prediction values, saves to this vector as well
*/
virtual void PredTransform(std::vector<float> *io_preds) {}
virtual void PredTransform(std::vector<bst_float> *io_preds) {}
/*!
* \brief transform prediction values, this is only called when Eval is called,
* usually it redirect to PredTransform
* \param io_preds prediction values, saves to this vector as well
*/
virtual void EvalTransform(std::vector<float> *io_preds) {
virtual void EvalTransform(std::vector<bst_float> *io_preds) {
this->PredTransform(io_preds);
}
/*!
@ -67,7 +67,7 @@ class ObjFunction {
* used by gradient boosting
* \return transformed value
*/
virtual float ProbToMargin(float base_score) const {
virtual bst_float ProbToMargin(bst_float base_score) const {
return base_score;
}
/*!

View File

@ -106,7 +106,7 @@ class TreeModel {
return cleft_ == -1;
}
/*! \return get leaf value of leaf node */
inline float leaf_value() const {
inline bst_float leaf_value() const {
return (this->info_).leaf_value;
}
/*! \return get split condition of the node */
@ -154,7 +154,7 @@ class TreeModel {
* \param right right index, could be used to store
* additional information
*/
inline void set_leaf(float value, int right = -1) {
inline void set_leaf(bst_float value, int right = -1) {
(this->info_).leaf_value = value;
this->cleft_ = -1;
this->cright_ = right;
@ -171,7 +171,7 @@ class TreeModel {
* we have split condition
*/
union Info{
float leaf_value;
bst_float leaf_value;
TSplitCond split_cond;
};
// pointer to parent, highest bit is used to
@ -230,7 +230,7 @@ class TreeModel {
* \param rid node id of the node
* \param value new leaf value
*/
inline void ChangeToLeaf(int rid, float value) {
inline void ChangeToLeaf(int rid, bst_float value) {
CHECK(nodes[nodes[rid].cleft() ].is_leaf());
CHECK(nodes[nodes[rid].cright()].is_leaf());
this->DeleteNode(nodes[rid].cleft());
@ -242,7 +242,7 @@ class TreeModel {
* \param rid node id of the node
* \param value new leaf value
*/
inline void CollapseToLeaf(int rid, float value) {
inline void CollapseToLeaf(int rid, bst_float value) {
if (nodes[rid].is_leaf()) return;
if (!nodes[nodes[rid].cleft() ].is_leaf()) {
CollapseToLeaf(nodes[rid].cleft(), 0.0f);
@ -338,7 +338,7 @@ class TreeModel {
}
/*!
* \brief add child nodes to node
* \param nid node id to add childs
* \param nid node id to add children to
*/
inline void AddChilds(int nid) {
int pleft = this->AllocNode();
@ -398,11 +398,11 @@ class TreeModel {
/*! \brief node statistics used in regression tree */
struct RTreeNodeStat {
/*! \brief loss change caused by current split */
float loss_chg;
bst_float loss_chg;
/*! \brief sum of hessian values, used to measure coverage of data */
float sum_hess;
bst_float sum_hess;
/*! \brief weight of current node */
float base_weight;
bst_float base_weight;
/*! \brief number of child that is leaf node known up to now */
int leaf_child_cnt;
};
@ -426,12 +426,12 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
inline void Init(size_t size);
/*!
* \brief fill the vector with sparse vector
* \param inst The sparse instance to fil.
* \param inst The sparse instance to fill.
*/
inline void Fill(const RowBatch::Inst& inst);
/*!
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instanc to drop.
* \param inst The sparse instance to drop.
*/
inline void Drop(const RowBatch::Inst& inst);
/*!
@ -439,7 +439,7 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \param i feature index.
* \return the i-th feature value
*/
inline float fvalue(size_t i) const;
inline bst_float fvalue(size_t i) const;
/*!
* \brief check whether i-th entry is missing
* \param i feature index.
@ -453,7 +453,7 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* when flag == -1, this indicate the value is missing
*/
union Entry {
float fvalue;
bst_float fvalue;
int flag;
};
std::vector<Entry> data;
@ -471,14 +471,14 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \param root_id starting root index of the instance
* \return the leaf index of the given feature
*/
inline float Predict(const FVec& feat, unsigned root_id = 0) const;
inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
* \param fvalue feature value if not missing.
* \param is_unknown Whether current required feature is missing.
*/
inline int GetNext(int pid, float fvalue, bool is_unknown) const;
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
/*!
* \brief dump the model in the requested format as a text string
* \param fmap feature map that may help give interpretations of feature
@ -513,7 +513,7 @@ inline void RegTree::FVec::Drop(const RowBatch::Inst& inst) {
}
}
inline float RegTree::FVec::fvalue(size_t i) const {
inline bst_float RegTree::FVec::fvalue(size_t i) const {
return data[i].fvalue;
}
@ -530,14 +530,14 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) co
return pid;
}
inline float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const {
inline bst_float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const {
int pid = this->GetLeafIndex(feat, root_id);
return (*this)[pid].leaf_value();
}
/*! \brief get next position of the tree given current pid */
inline int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const {
float split_value = (*this)[pid].split_cond();
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
bst_float split_value = (*this)[pid].split_cond();
if (is_unknown) {
return (*this)[pid].cdefault();
} else {

View File

@ -28,7 +28,7 @@ class DensifyParser : public dmlc::Parser<IndexType> {
LOG(INFO) << batch.size;
dense_index_.resize(num_col_ * batch.size);
dense_value_.resize(num_col_ * batch.size);
std::fill(dense_value_.begin(), dense_value_.end(), 0.0f);
std::fill(dense_value_.begin(), dense_value_.end(), 0.0);
offset_.resize(batch.size + 1);
offset_[0] = 0;
@ -66,7 +66,7 @@ class DensifyParser : public dmlc::Parser<IndexType> {
uint32_t num_col_;
std::vector<size_t> offset_;
std::vector<IndexType> dense_index_;
std::vector<float> dense_value_;
std::vector<xgboost::bst_float> dense_value_;
};
template<typename IndexType>

View File

@ -33,35 +33,35 @@ class MyLogistic : public ObjFunction {
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float> &preds,
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
out_gpair->resize(preds.size());
for (size_t i = 0; i < preds.size(); ++i) {
float w = info.GetWeight(i);
bst_float w = info.GetWeight(i);
// scale the negative examples!
if (info.labels[i] == 0.0f) w *= param_.scale_neg_weight;
// logistic transoformation
float p = 1.0f / (1.0f + expf(-preds[i]));
// logistic transformation
bst_float p = 1.0f / (1.0f + std::exp(-preds[i]));
// this is the gradient
float grad = (p - info.labels[i]) * w;
bst_float grad = (p - info.labels[i]) * w;
// this is the second order gradient
float hess = p * (1.0f - p) * w;
bst_float hess = p * (1.0f - p) * w;
out_gpair->at(i) = bst_gpair(grad, hess);
}
}
const char* DefaultEvalMetric() const override {
return "error";
}
void PredTransform(std::vector<float> *io_preds) override {
void PredTransform(std::vector<bst_float> *io_preds) override {
// transform margin value to probability.
std::vector<float> &preds = *io_preds;
std::vector<bst_float> &preds = *io_preds;
for (size_t i = 0; i < preds.size(); ++i) {
preds[i] = 1.0f / (1.0f + expf(-preds[i]));
preds[i] = 1.0f / (1.0f + std::exp(-preds[i]));
}
}
float ProbToMargin(float base_score) const override {
bst_float ProbToMargin(bst_float base_score) const override {
// transform probability to margin value
return -std::log(1.0f / base_score - 1.0f);
}

View File

@ -20,7 +20,7 @@ stable version, please install using pip:
- Note for windows users: this pip installation may not work on some
windows environment, and it may cause unexpected errors. pip
installation on windows is currently disabled for further
invesigation, please install from github.
investigation, please install from github.
For up-to-date version, please install from github.

View File

@ -18,7 +18,7 @@ Linux platform (also Mac OS X in general)
**Solution 0**: Please check if you have:
* installed the latest C++ compilers and `make`, for example `g++` and `gcc` (Linux) or `clang LLVM` (Mac OS X). Recommended compilers are `g++-5` or newer (Linux and Mac), or `clang` comes with Xcode in Mac OS X. For installting compilers, please refer to your system package management commands, e.g. `apt-get` `yum` or `brew`(Mac).
* installed the latest C++ compilers and `make`, for example `g++` and `gcc` (Linux) or `clang LLVM` (Mac OS X). Recommended compilers are `g++-5` or newer (Linux and Mac), or `clang` comes with Xcode in Mac OS X. For installing compilers, please refer to your system package management commands, e.g. `apt-get` `yum` or `brew`(Mac).
* compilers in your `$PATH`. Try typing `gcc` and see if your have it in your path.
* Do you use other shells than `bash` and install from `pip`? In some old version of pip installation, the shell script used `pushd` for changing directory and triggering the build process, which may failed some shells without `pushd` command. Please update to the latest version by removing the old installation and redo `pip install xgboost`
* Some outdated `make` may not recognize the recent changes in the `Makefile` and gives this error, please update to the latest `make`:

View File

@ -64,7 +64,7 @@ try:
except ImportError:
SKLEARN_INSTALLED = False
# used for compatiblity without sklearn
# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object

View File

@ -19,7 +19,7 @@ from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
class XGBoostError(Exception):
"""Error throwed by xgboost trainer."""
"""Error thrown by xgboost trainer."""
pass
@ -980,11 +980,11 @@ class Booster(object):
def save_raw(self):
"""
Save the model to a in memory buffer represetation
Save the model to a in memory buffer representation
Returns
-------
a in memory buffer represetation of the model
a in memory buffer representation of the model
"""
length = ctypes.c_ulong()
cptr = ctypes.POINTER(ctypes.c_char)()

View File

@ -7,7 +7,7 @@ import sys
class XGBoostLibraryNotFound(Exception):
"""Error throwed by when xgboost is not found"""
"""Error thrown by when xgboost is not found"""
pass

View File

@ -157,7 +157,7 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000', **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the
returned graphiz instance. Otherwise, you shoud call .render() method
returned graphiz instance. Otherwise, you should call .render() method
of the returned graphiz instance.
Parameters
@ -169,9 +169,9 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
rankdir : str, default "UT"
Passed to graphiz via graph_attr
yes_color : str, default '#0000FF'
Edge color when meets the node condigion.
Edge color when meets the node condition.
no_color : str, default '#FF0000'
Edge color when doesn't meet the node condigion.
Edge color when doesn't meet the node condition.
kwargs :
Other keywords passed to graphviz graph_attr

View File

@ -12,7 +12,7 @@ from .compat import pickle
def _init_rabit():
"""internal libary initializer."""
"""internal library initializer."""
if _LIB is not None:
_LIB.RabitGetRank.restype = ctypes.c_int
_LIB.RabitGetWorldSize.restype = ctypes.c_int
@ -21,7 +21,7 @@ def _init_rabit():
def init(args=None):
"""Initialize the rabit libary with arguments"""
"""Initialize the rabit library with arguments"""
if args is None:
args = []
arr = (ctypes.c_char_p * len(args))()
@ -156,7 +156,7 @@ def allreduce(data, op, prepare_fun=None):
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: function
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to intialize the data
will be called by the function before performing allreduce, to initialize the data
If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called

View File

@ -142,7 +142,7 @@ class XGBModel(XGBModelBase):
self._Booster = None
def __setstate__(self, state):
# backward compatiblity code
# backward compatibility code
# load booster from raw if it is raw
# the booster now support pickle
bst = state["_Booster"]

View File

@ -148,7 +148,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist.
Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
a paramater containing ('eval_metric': 'logloss')
a parameter containing ('eval_metric': 'logloss')
Returns: {'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}}
verbose_eval : bool or int
@ -291,7 +291,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
seed=0, callbacks=None):
# pylint: disable = invalid-name
"""Cross-validation with given paramaters.
"""Cross-validation with given parameters.
Parameters
----------

View File

@ -191,7 +191,7 @@ struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<float> ret_vec_float;
std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<bst_gpair> tmp_gpair;
};
@ -229,7 +229,7 @@ int XGDMatrixCreateFromDataIter(
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
const unsigned* indices,
const float* data,
const bst_float* data,
size_t nindptr,
size_t nelem,
size_t num_col,
@ -260,7 +260,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
XGB_DLL int XGDMatrixCreateFromCSR(const xgboost::bst_ulong* indptr,
const unsigned *indices,
const float* data,
const bst_float* data,
xgboost::bst_ulong nindptr,
xgboost::bst_ulong nelem,
DMatrixHandle* out) {
@ -274,7 +274,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(const xgboost::bst_ulong* indptr,
XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
const unsigned* indices,
const float* data,
const bst_float* data,
size_t nindptr,
size_t nelem,
size_t num_row,
@ -321,7 +321,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
XGB_DLL int XGDMatrixCreateFromCSC(const xgboost::bst_ulong* col_ptr,
const unsigned* indices,
const float* data,
const bst_float* data,
xgboost::bst_ulong nindptr,
xgboost::bst_ulong nelem,
DMatrixHandle* out) {
@ -333,10 +333,10 @@ XGB_DLL int XGDMatrixCreateFromCSC(const xgboost::bst_ulong* col_ptr,
static_cast<size_t>(nindptr), static_cast<size_t>(nelem), 0, out);
}
XGB_DLL int XGDMatrixCreateFromMat(const float* data,
XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
xgboost::bst_ulong nrow,
xgboost::bst_ulong ncol,
float missing,
bst_float missing,
DMatrixHandle* out) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
@ -428,7 +428,7 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const float* info,
const bst_float* info,
xgboost::bst_ulong len) {
API_BEGIN();
static_cast<std::shared_ptr<DMatrix>*>(handle)
@ -463,10 +463,10 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
const char* field,
xgboost::bst_ulong* out_len,
const float** out_dptr) {
const bst_float** out_dptr) {
API_BEGIN();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info();
const std::vector<float>* vec = nullptr;
const std::vector<bst_float>* vec = nullptr;
if (!std::strcmp(field, "label")) {
vec = &info.labels;
} else if (!std::strcmp(field, "weight")) {
@ -556,8 +556,8 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
DMatrixHandle dtrain,
float *grad,
float *hess,
bst_float *grad,
bst_float *hess,
xgboost::bst_ulong len) {
std::vector<bst_gpair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
API_BEGIN();
@ -602,8 +602,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
int option_mask,
unsigned ntree_limit,
xgboost::bst_ulong *len,
const float **out_result) {
std::vector<float>& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
const bst_float **out_result) {
std::vector<bst_float>& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN();
Booster *bst = static_cast<Booster*>(handle);
bst->LazyInit();

View File

@ -28,7 +28,7 @@
*/
void XGBAPISetLastError(const char* msg);
/*!
* \brief handle exception throwed out
* \brief handle exception thrown out
* \param e the exception
* \return the return value of API after exception is handled
*/

View File

@ -182,7 +182,7 @@ void CLITrain(const CLIParam& param) {
std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
int version = rabit::LoadCheckPoint(learner.get());
if (version == 0) {
// initializ the model if needed.
// initialize the model if needed.
if (param.model_in != "NULL") {
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
@ -320,7 +320,7 @@ void CLIPredict(const CLIParam& param) {
if (param.silent == 0) {
LOG(CONSOLE) << "start prediction...";
}
std::vector<float> preds;
std::vector<bst_float> preds;
learner->Predict(dtest.get(), param.pred_margin, &preds, param.ntree_limit);
if (param.silent == 0) {
LOG(CONSOLE) << "writing prediction to " << param.name_pred;
@ -328,7 +328,7 @@ void CLIPredict(const CLIParam& param) {
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(param.name_pred.c_str(), "w"));
dmlc::ostream os(fo.get());
for (float p : preds) {
for (bst_float p : preds) {
os << p << '\n';
}
// force flush before fo destruct.

View File

@ -77,12 +77,12 @@ inline bool MetaTryLoadGroup(const std::string& fname,
// try to load weight information from file, if exists
inline bool MetaTryLoadFloatInfo(const std::string& fname,
std::vector<float>* data) {
std::vector<bst_float>* data) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi.get() == nullptr) return false;
dmlc::istream is(fi.get());
data->clear();
float value;
bst_float value;
while (is >> value) {
data->push_back(value);
}

View File

@ -44,7 +44,7 @@ class SimpleDMatrix : public DMatrix {
return buffered_rowset_;
}
size_t GetColSize(size_t cidx) const {
size_t GetColSize(size_t cidx) const override {
return col_size_[cidx];
}

View File

@ -205,7 +205,7 @@ class SparsePage::Writer {
* \brief Push a write job to the writer.
* This function won't block,
* writing is done by another thread inside writer.
* \param page The page to be wriiten
* \param page The page to be written
*/
void PushWrite(std::unique_ptr<SparsePage>&& page);
/*!

View File

@ -48,7 +48,7 @@ class SparsePageDMatrix : public DMatrix {
return buffered_rowset_;
}
size_t GetColSize(size_t cidx) const {
size_t GetColSize(size_t cidx) const override {
return col_size_[cidx];
}
@ -111,7 +111,7 @@ class SparsePageDMatrix : public DMatrix {
std::vector<SparseBatch::Inst> col_data_;
};
/*!
* \brief Try to intitialize column data.
* \brief Try to initialize column data.
* \return true if data already exists, false if they do not.
*/
bool TryInitColData();

View File

@ -87,7 +87,7 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
*/
class GBLinear : public GradientBooster {
public:
explicit GBLinear(float base_margin)
explicit GBLinear(bst_float base_margin)
: base_margin_(base_margin) {
}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
@ -149,13 +149,13 @@ class GBLinear : public GradientBooster {
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {
const float v = col[j].fvalue;
const bst_float v = col[j].fvalue;
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue;
sum_grad += p.grad * v;
sum_hess += p.hess * v * v;
}
float &w = model[fid][gid];
bst_float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate *
param.CalcDelta(sum_grad, sum_hess, w));
w += dw;
@ -171,14 +171,14 @@ class GBLinear : public GradientBooster {
}
void Predict(DMatrix *p_fmat,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit) override {
if (model.weight.size() == 0) {
model.InitModel();
}
CHECK_EQ(ntree_limit, 0)
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
std::vector<float> &preds = *out_preds;
std::vector<bst_float> &preds = *out_preds;
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
if (base_margin.size() != 0) {
CHECK_EQ(preds.size(), base_margin.size())
@ -201,7 +201,7 @@ class GBLinear : public GradientBooster {
const size_t ridx = batch.base_rowid + i;
// loop over output groups
for (int gid = 0; gid < ngroup; ++gid) {
float margin = (base_margin.size() != 0) ?
bst_float margin = (base_margin.size() != 0) ?
base_margin[ridx * ngroup + gid] : base_margin_;
this->Pred(batch[i], &preds[ridx * ngroup], gid, margin);
}
@ -210,7 +210,7 @@ class GBLinear : public GradientBooster {
}
// add base margin
void Predict(const SparseBatch::Inst &inst,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit,
unsigned root_index) override {
const int ngroup = model.param.num_output_group;
@ -219,7 +219,7 @@ class GBLinear : public GradientBooster {
}
}
void PredictLeaf(DMatrix *p_fmat,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit) override {
LOG(FATAL) << "gblinear does not support predict leaf index";
}
@ -261,8 +261,8 @@ class GBLinear : public GradientBooster {
}
protected:
inline void Pred(const RowBatch::Inst &inst, float *preds, int gid, float base) {
float psum = model.bias()[gid] + base;
inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid, bst_float base) {
bst_float psum = model.bias()[gid] + base;
for (bst_uint i = 0; i < inst.length; ++i) {
if (inst[i].index >= model.param.num_feature) continue;
psum += inst[i].fvalue * model[inst[i].index][gid];
@ -275,7 +275,7 @@ class GBLinear : public GradientBooster {
// parameter
GBLinearModelParam param;
// weight for each of feature, bias is the last one
std::vector<float> weight;
std::vector<bst_float> weight;
// initialize the model parameter
inline void InitModel(void) {
// bias is the last weight
@ -293,22 +293,22 @@ class GBLinear : public GradientBooster {
fi->Read(&weight);
}
// model bias
inline float* bias() {
inline bst_float* bias() {
return &weight[param.num_feature * param.num_output_group];
}
inline const float* bias() const {
inline const bst_float* bias() const {
return &weight[param.num_feature * param.num_output_group];
}
// get i-th weight
inline float* operator[](size_t i) {
inline bst_float* operator[](size_t i) {
return &weight[i * param.num_output_group];
}
inline const float* operator[](size_t i) const {
inline const bst_float* operator[](size_t i) const {
return &weight[i * param.num_output_group];
}
};
// biase margin score
float base_margin_;
bst_float base_margin_;
// model field
Model model;
// training parameter
@ -317,13 +317,13 @@ class GBLinear : public GradientBooster {
std::vector<bst_uint> feat_index;
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(GBLinearModelParam);
DMLC_REGISTER_PARAMETER(GBLinearTrainParam);
XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
.describe("Linear booster, implement generalized linear model.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> >&cache, float base_margin) {
.set_body([](const std::vector<std::shared_ptr<DMatrix> >&cache, bst_float base_margin) {
return new GBLinear(base_margin);
});
} // namespace gbm

View File

@ -14,7 +14,7 @@ namespace xgboost {
GradientBooster* GradientBooster::Create(
const std::string& name,
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
float base_margin) {
bst_float base_margin) {
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown gbm type " << name;

View File

@ -91,7 +91,7 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
int num_roots;
/*! \brief number of features to be used by trees */
int num_feature;
/*! \brief pad this space, for backward compatiblity reason.*/
/*! \brief pad this space, for backward compatibility reason.*/
int pad_32bit;
/*! \brief deprecated padding space. */
int64_t num_pbuffer_deprecated;
@ -128,13 +128,13 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
// cache entry
struct CacheEntry {
std::shared_ptr<DMatrix> data;
std::vector<float> predictions;
std::vector<bst_float> predictions;
};
// gradient boosted trees
class GBTree : public GradientBooster {
public:
explicit GBTree(float base_margin) : base_margin_(base_margin) {}
explicit GBTree(bst_float base_margin) : base_margin_(base_margin) {}
void InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache) {
for (const std::shared_ptr<DMatrix>& d : cache) {
@ -225,13 +225,13 @@ class GBTree : public GradientBooster {
}
void Predict(DMatrix* p_fmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit) override {
if (ntree_limit == 0 ||
ntree_limit * mparam.num_output_group >= trees.size()) {
auto it = cache_.find(p_fmat);
if (it != cache_.end()) {
std::vector<float>& y = it->second.predictions;
std::vector<bst_float>& y = it->second.predictions;
if (y.size() != 0) {
out_preds->resize(y.size());
std::copy(y.begin(), y.end(), out_preds->begin());
@ -243,7 +243,7 @@ class GBTree : public GradientBooster {
}
void Predict(const SparseBatch::Inst& inst,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
unsigned root_index) override {
if (thread_temp.size() == 0) {
@ -264,7 +264,7 @@ class GBTree : public GradientBooster {
}
void PredictLeaf(DMatrix* p_fmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit) override {
int nthread;
#pragma omp parallel
@ -291,7 +291,7 @@ class GBTree : public GradientBooster {
template<typename Derived>
inline void PredLoopInternal(
DMatrix* p_fmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned tree_begin,
unsigned ntree_limit,
bool init_out_preds) {
@ -303,7 +303,7 @@ class GBTree : public GradientBooster {
if (init_out_preds) {
size_t n = num_group * p_fmat->info().num_row;
const std::vector<float>& base_margin = p_fmat->info().base_margin;
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
out_preds->resize(n);
if (base_margin.size() != 0) {
CHECK_EQ(out_preds->size(), n);
@ -325,7 +325,7 @@ class GBTree : public GradientBooster {
template<typename Derived>
inline void PredLoopSpecalize(
DMatrix* p_fmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
int num_group,
unsigned tree_begin,
unsigned tree_end) {
@ -337,7 +337,7 @@ class GBTree : public GradientBooster {
}
CHECK_EQ(num_group, mparam.num_output_group);
InitThreadTemp(nthread);
std::vector<float> &preds = *out_preds;
std::vector<bst_float> &preds = *out_preds;
CHECK_EQ(mparam.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";
CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
@ -424,13 +424,13 @@ class GBTree : public GradientBooster {
}
// make a prediction for a single instance
inline float PredValue(const RowBatch::Inst &inst,
inline bst_float PredValue(const RowBatch::Inst &inst,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
unsigned tree_begin,
unsigned tree_end) {
float psum = 0.0f;
bst_float psum = 0.0f;
p_feats->Fill(inst);
for (size_t i = tree_begin; i < tree_end; ++i) {
if (tree_info[i] == bst_group) {
@ -443,7 +443,7 @@ class GBTree : public GradientBooster {
}
// predict independent leaf index
inline void PredPath(DMatrix *p_fmat,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit) {
const MetaInfo& info = p_fmat->info();
// number of valid trees
@ -451,7 +451,7 @@ class GBTree : public GradientBooster {
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = static_cast<unsigned>(trees.size());
}
std::vector<float>& preds = *out_preds;
std::vector<bst_float>& preds = *out_preds;
preds.resize(info.num_row * ntree_limit);
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
@ -468,7 +468,7 @@ class GBTree : public GradientBooster {
feats.Fill(batch[i]);
for (unsigned j = 0; j < ntree_limit; ++j) {
int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
preds[ridx * ntree_limit + j] = static_cast<float>(tid);
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
}
feats.Drop(batch[i]);
}
@ -486,7 +486,7 @@ class GBTree : public GradientBooster {
}
// --- data structure ---
// base margin
float base_margin_;
bst_float base_margin_;
// training parameter
GBTreeTrainParam tparam;
// model parameter
@ -508,7 +508,7 @@ class GBTree : public GradientBooster {
// dart
class Dart : public GBTree {
public:
explicit Dart(float base_margin) : GBTree(base_margin) {}
explicit Dart(bst_float base_margin) : GBTree(base_margin) {}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
GBTree::Configure(cfg);
@ -534,14 +534,14 @@ class Dart : public GBTree {
// predict the leaf scores with dropout if ntree_limit = 0
void Predict(DMatrix* p_fmat,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit) override {
DropTrees(ntree_limit);
PredLoopInternal<Dart>(p_fmat, out_preds, 0, ntree_limit, true);
}
void Predict(const SparseBatch::Inst& inst,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
unsigned root_index) override {
DropTrees(1);
@ -579,13 +579,13 @@ class Dart : public GBTree {
}
}
// predict the leaf scores without dropped trees
inline float PredValue(const RowBatch::Inst &inst,
inline bst_float PredValue(const RowBatch::Inst &inst,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
unsigned tree_begin,
unsigned tree_end) {
float psum = 0.0f;
bst_float psum = 0.0f;
p_feats->Fill(inst);
for (size_t i = tree_begin; i < tree_end; ++i) {
if (tree_info[i] == bst_group) {
@ -611,7 +611,7 @@ class Dart : public GBTree {
if (dparam.skip_drop > 0.0) skip = (runif(rnd) < dparam.skip_drop);
if (ntree_limit_drop == 0 && !skip) {
if (dparam.sample_type == 1) {
float sum_weight = 0.0;
bst_float sum_weight = 0.0;
for (size_t i = 0; i < weight_drop.size(); ++i) {
sum_weight += weight_drop[i];
}
@ -667,26 +667,26 @@ class Dart : public GBTree {
// training parameter
DartTrainParam dparam;
/*! \brief prediction buffer */
std::vector<float> weight_drop;
std::vector<bst_float> weight_drop;
// indexes of dropped trees
std::vector<size_t> idx_drop;
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(GBTreeModelParam);
DMLC_REGISTER_PARAMETER(GBTreeTrainParam);
DMLC_REGISTER_PARAMETER(DartTrainParam);
XGBOOST_REGISTER_GBM(GBTree, "gbtree")
.describe("Tree booster, gradient boosted trees.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, float base_margin) {
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, bst_float base_margin) {
GBTree* p = new GBTree(base_margin);
p->InitCache(cached_mats);
return p;
});
XGBOOST_REGISTER_GBM(Dart, "dart")
.describe("Tree booster, dart.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, float base_margin) {
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, bst_float base_margin) {
GBTree* p = new Dart(base_margin);
return p;
});

View File

@ -36,7 +36,7 @@ Learner::DumpModel(const FeatureMap& fmap,
struct LearnerModelParam
: public dmlc::Parameter<LearnerModelParam> {
/* \brief global bias */
float base_score;
bst_float base_score;
/* \brief number of features */
unsigned num_feature;
/* \brief number of classes, if it is multi-class classification */
@ -353,7 +353,7 @@ class LearnerImpl : public Learner {
return out;
}
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
std::pair<std::string, bst_float> Evaluate(DMatrix* data, std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
this->PredictRaw(data, &preds_);
@ -363,7 +363,7 @@ class LearnerImpl : public Learner {
void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
std::vector<bst_float> *out_preds,
unsigned ntree_limit,
bool pred_leaf) const override {
if (pred_leaf) {
@ -460,7 +460,7 @@ class LearnerImpl : public Learner {
* predictor, when it equals 0, this means we are using all the trees
*/
inline void PredictRaw(DMatrix* data,
std::vector<float>* out_preds,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) const {
CHECK(gbm_.get() != nullptr)
<< "Predict must happen after Load or InitModel";
@ -478,10 +478,10 @@ class LearnerImpl : public Learner {
std::map<std::string, std::string> attributes_;
// name of gbm
std::string name_gbm_;
// name of objective functon
// name of objective function
std::string name_obj_;
// temporal storages for prediction
std::vector<float> preds_;
std::vector<bst_float> preds_;
// gradient pairs
std::vector<bst_gpair> gpair_;

View File

@ -21,7 +21,7 @@ DMLC_REGISTRY_FILE_TAG(elementwise_metric);
*/
template<typename Derived>
struct EvalEWiseBase : public Metric {
float Eval(const std::vector<float>& preds,
bst_float Eval(const std::vector<bst_float>& preds,
const MetaInfo& info,
bool distributed) const override {
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
@ -32,7 +32,7 @@ struct EvalEWiseBase : public Metric {
double sum = 0.0, wsum = 0.0;
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) {
const float wt = info.GetWeight(i);
const bst_float wt = info.GetWeight(i);
sum += static_cast<const Derived*>(this)->EvalRow(info.labels[i], preds[i]) * wt;
wsum += wt;
}
@ -48,13 +48,13 @@ struct EvalEWiseBase : public Metric {
* \param label label of current instance
* \param pred prediction value of current instance
*/
inline float EvalRow(float label, float pred) const;
inline bst_float EvalRow(bst_float label, bst_float pred) const;
/*!
* \brief to be overridden by subclass, final transformation
* \param esum the sum statistics returned by EvalRow
* \param wsum sum of weight
*/
inline static float GetFinal(float esum, float wsum) {
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
return esum / wsum;
}
};
@ -63,11 +63,11 @@ struct EvalRMSE : public EvalEWiseBase<EvalRMSE> {
const char *Name() const override {
return "rmse";
}
inline float EvalRow(float label, float pred) const {
float diff = label - pred;
inline bst_float EvalRow(bst_float label, bst_float pred) const {
bst_float diff = label - pred;
return diff * diff;
}
inline static float GetFinal(float esum, float wsum) {
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
return std::sqrt(esum / wsum);
}
};
@ -76,7 +76,7 @@ struct EvalMAE : public EvalEWiseBase<EvalMAE> {
const char *Name() const override {
return "mae";
}
inline float EvalRow(float label, float pred) const {
inline bst_float EvalRow(bst_float label, bst_float pred) const {
return std::abs(label - pred);
}
};
@ -85,9 +85,9 @@ struct EvalLogLoss : public EvalEWiseBase<EvalLogLoss> {
const char *Name() const override {
return "logloss";
}
inline float EvalRow(float y, float py) const {
const float eps = 1e-16f;
const float pneg = 1.0f - py;
inline bst_float EvalRow(bst_float y, bst_float py) const {
const bst_float eps = 1e-16f;
const bst_float pneg = 1.0f - py;
if (py < eps) {
return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
} else if (pneg < eps) {
@ -115,12 +115,12 @@ struct EvalError : public EvalEWiseBase<EvalError> {
const char *Name() const override {
return name_.c_str();
}
inline float EvalRow(float label, float pred) const {
inline bst_float EvalRow(bst_float label, bst_float pred) const {
// assume label is in [0,1]
return pred > threshold_ ? 1.0f - label : label;
}
protected:
float threshold_;
bst_float threshold_;
std::string name_;
};
@ -128,8 +128,8 @@ struct EvalPoissionNegLogLik : public EvalEWiseBase<EvalPoissionNegLogLik> {
const char *Name() const override {
return "poisson-nloglik";
}
inline float EvalRow(float y, float py) const {
const float eps = 1e-16f;
inline bst_float EvalRow(bst_float y, bst_float py) const {
const bst_float eps = 1e-16f;
if (py < eps) py = eps;
return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
}
@ -139,12 +139,12 @@ struct EvalGammaDeviance : public EvalEWiseBase<EvalGammaDeviance> {
const char *Name() const override {
return "gamma-deviance";
}
inline float EvalRow(float label, float pred) const {
float epsilon = 1.0e-9;
float tmp = label / (pred + epsilon);
inline bst_float EvalRow(bst_float label, bst_float pred) const {
bst_float epsilon = 1.0e-9;
bst_float tmp = label / (pred + epsilon);
return tmp - std::log(tmp) - 1;
}
inline static float GetFinal(float esum, float wsum) {
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
return 2 * esum;
}
};
@ -153,12 +153,12 @@ struct EvalGammaNLogLik: public EvalEWiseBase<EvalGammaNLogLik> {
const char *Name() const override {
return "gamma-nloglik";
}
inline float EvalRow(float y, float py) const {
float psi = 1.0;
float theta = -1. / py;
float a = psi;
float b = -std::log(-theta);
float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
inline bst_float EvalRow(bst_float y, bst_float py) const {
bst_float psi = 1.0;
bst_float theta = -1. / py;
bst_float a = psi;
bst_float b = -std::log(-theta);
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
return -((y * theta - b) / a + c);
}
};
@ -177,14 +177,14 @@ struct EvalTweedieNLogLik: public EvalEWiseBase<EvalTweedieNLogLik> {
const char *Name() const override {
return name_.c_str();
}
inline float EvalRow(float y, float p) const {
float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
inline bst_float EvalRow(bst_float y, bst_float p) const {
bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
return -a + b;
}
protected:
std::string name_;
float rho_;
bst_float rho_;
};
XGBOOST_REGISTER_METRIC(RMSE, "rmse")

View File

@ -20,7 +20,7 @@ DMLC_REGISTRY_FILE_TAG(multiclass_metric);
*/
template<typename Derived>
struct EvalMClassBase : public Metric {
float Eval(const std::vector<float> &preds,
bst_float Eval(const std::vector<bst_float> &preds,
const MetaInfo &info,
bool distributed) const override {
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
@ -35,7 +35,7 @@ struct EvalMClassBase : public Metric {
int label_error = 0;
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const float wt = info.GetWeight(i);
const bst_float wt = info.GetWeight(i);
int label = static_cast<int>(info.labels[i]);
if (label >= 0 && label < static_cast<int>(nclass)) {
sum += Derived::EvalRow(label,
@ -63,15 +63,15 @@ struct EvalMClassBase : public Metric {
* \param pred prediction value of current instance
* \param nclass number of class in the prediction
*/
inline static float EvalRow(int label,
const float *pred,
inline static bst_float EvalRow(int label,
const bst_float *pred,
size_t nclass);
/*!
* \brief to be overridden by subclass, final transformation
* \param esum the sum statistics returned by EvalRow
* \param wsum sum of weight
*/
inline static float GetFinal(float esum, float wsum) {
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
return esum / wsum;
}
// used to store error message
@ -83,8 +83,8 @@ struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
const char* Name() const override {
return "merror";
}
inline static float EvalRow(int label,
const float *pred,
inline static bst_float EvalRow(int label,
const bst_float *pred,
size_t nclass) {
return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast<int>(label);
}
@ -95,10 +95,10 @@ struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
const char* Name() const override {
return "mlogloss";
}
inline static float EvalRow(int label,
const float *pred,
inline static bst_float EvalRow(int label,
const bst_float *pred,
size_t nclass) {
const float eps = 1e-16f;
const bst_float eps = 1e-16f;
size_t k = static_cast<size_t>(label);
if (pred[k] > eps) {
return -std::log(pred[k]);

View File

@ -26,7 +26,7 @@ struct EvalAMS : public Metric {
os << "ams@" << ratio_;
name_ = os.str();
}
float Eval(const std::vector<float> &preds,
bst_float Eval(const std::vector<bst_float> &preds,
const MetaInfo &info,
bool distributed) const override {
CHECK(!distributed) << "metric AMS do not support distributed evaluation";
@ -34,7 +34,7 @@ struct EvalAMS : public Metric {
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
CHECK_EQ(info.weights.size(), ndata) << "we need weight to evaluate ams";
std::vector<std::pair<float, unsigned> > rec(ndata);
std::vector<std::pair<bst_float, unsigned> > rec(ndata);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
@ -48,7 +48,7 @@ struct EvalAMS : public Metric {
double s_tp = 0.0, b_fp = 0.0, tams = 0.0;
for (unsigned i = 0; i < static_cast<unsigned>(ndata-1) && i < ntop; ++i) {
const unsigned ridx = rec[i].second;
const float wt = info.weights[ridx];
const bst_float wt = info.weights[ridx];
if (info.labels[ridx] > 0.5f) {
s_tp += wt;
} else {
@ -63,10 +63,10 @@ struct EvalAMS : public Metric {
}
}
if (ntop == ndata) {
LOG(INFO) << "best-ams-ratio=" << static_cast<float>(thresindex) / ndata;
return static_cast<float>(tams);
LOG(INFO) << "best-ams-ratio=" << static_cast<bst_float>(thresindex) / ndata;
return static_cast<bst_float>(tams);
} else {
return static_cast<float>(
return static_cast<bst_float>(
sqrt(2 * ((s_tp + b_fp + br) * log(1.0 + s_tp/(b_fp + br)) - s_tp)));
}
}
@ -82,7 +82,7 @@ struct EvalAMS : public Metric {
/*! \brief Area Under Curve, for both classification and rank */
struct EvalAuc : public Metric {
float Eval(const std::vector<float> &preds,
bst_float Eval(const std::vector<bst_float> &preds,
const MetaInfo &info,
bool distributed) const override {
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
@ -96,12 +96,12 @@ struct EvalAuc : public Metric {
<< "EvalAuc: group structure must match number of prediction";
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
// sum statistics
double sum_auc = 0.0f;
bst_float sum_auc = 0.0f;
int auc_error = 0;
#pragma omp parallel reduction(+:sum_auc)
{
// each thread takes a local rec
std::vector< std::pair<float, unsigned> > rec;
std::vector< std::pair<bst_float, unsigned> > rec;
#pragma omp for schedule(static)
for (bst_omp_uint k = 0; k < ngroup; ++k) {
rec.clear();
@ -113,8 +113,8 @@ struct EvalAuc : public Metric {
double sum_pospair = 0.0;
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
for (size_t j = 0; j < rec.size(); ++j) {
const float wt = info.GetWeight(rec[j].second);
const float ctr = info.labels[rec[j].second];
const bst_float wt = info.GetWeight(rec[j].second);
const bst_float ctr = info.labels[rec[j].second];
// keep bucketing predictions in same bucket
if (j != 0 && rec[j].first != rec[j - 1].first) {
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
@ -140,14 +140,14 @@ struct EvalAuc : public Metric {
CHECK(!auc_error)
<< "AUC: the dataset only contains pos or neg samples";
if (distributed) {
float dat[2];
dat[0] = static_cast<float>(sum_auc);
dat[1] = static_cast<float>(ngroup);
bst_float dat[2];
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup);
// approximately estimate auc using mean
rabit::Allreduce<rabit::op::Sum>(dat, 2);
return dat[0] / dat[1];
} else {
return static_cast<float>(sum_auc) / ngroup;
return static_cast<bst_float>(sum_auc) / ngroup;
}
}
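
The loop above computes a weighted rank-based AUC: predictions are sorted, tied predictions are bucketed, and the statistic is the (weighted) fraction of positive/negative pairs ordered correctly. A minimal unweighted sketch of the same idea, assuming labels in {0,1} and both classes present:

#include <algorithm>
#include <utility>
#include <vector>

// Unweighted AUC: fraction of (positive, negative) pairs in which the positive
// example gets the higher prediction. Ties are not bucketed here, unlike above.
inline double SimpleAuc(std::vector<std::pair<float, int> > rec) {  // (pred, label)
  std::sort(rec.begin(), rec.end());
  double npos = 0.0, nneg = 0.0, correct = 0.0;
  for (size_t i = 0; i < rec.size(); ++i) {
    if (rec[i].second == 1) {
      correct += nneg;  // every negative seen so far scores below this positive
      npos += 1.0;
    } else {
      nneg += 1.0;
    }
  }
  return correct / (npos * nneg);
}
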
const char* Name() const override {
@ -158,7 +158,7 @@ struct EvalAuc : public Metric {
/*! \brief Evaluate rank list */
struct EvalRankList : public Metric {
public:
float Eval(const std::vector<float> &preds,
bst_float Eval(const std::vector<bst_float> &preds,
const MetaInfo &info,
bool distributed) const override {
CHECK_EQ(preds.size(), info.labels.size())
@ -176,7 +176,7 @@ struct EvalRankList : public Metric {
#pragma omp parallel reduction(+:sum_metric)
{
// each thread takes a local rec
std::vector< std::pair<float, unsigned> > rec;
std::vector< std::pair<bst_float, unsigned> > rec;
#pragma omp for schedule(static)
for (bst_omp_uint k = 0; k < ngroup; ++k) {
rec.clear();
@ -187,14 +187,14 @@ struct EvalRankList : public Metric {
}
}
if (distributed) {
float dat[2];
dat[0] = static_cast<float>(sum_metric);
dat[1] = static_cast<float>(ngroup);
bst_float dat[2];
dat[0] = static_cast<bst_float>(sum_metric);
dat[1] = static_cast<bst_float>(ngroup);
// approximately estimate the metric using mean
rabit::Allreduce<rabit::op::Sum>(dat, 2);
return dat[0] / dat[1];
} else {
return static_cast<float>(sum_metric) / ngroup;
return static_cast<bst_float>(sum_metric) / ngroup;
}
}
const char* Name() const override {
@ -221,7 +221,7 @@ struct EvalRankList : public Metric {
}
}
/*! \return evaluation metric, given the pair_sort record, (pred,label) */
virtual float EvalMetric(std::vector<std::pair<float, unsigned> > &pair_sort) const = 0; // NOLINT(*)
virtual bst_float EvalMetric(std::vector<std::pair<bst_float, unsigned> > &pair_sort) const = 0; // NOLINT(*)
protected:
unsigned topn_;
@ -235,14 +235,14 @@ struct EvalPrecision : public EvalRankList{
explicit EvalPrecision(const char *name) : EvalRankList("pre", name) {}
protected:
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const {
virtual bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const {
// calculate Precision
std::sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhit = 0;
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
nhit += (rec[j].second != 0);
}
return static_cast<float>(nhit) / topn_;
return static_cast<bst_float>(nhit) / topn_;
}
};
@ -252,7 +252,7 @@ struct EvalNDCG : public EvalRankList{
explicit EvalNDCG(const char *name) : EvalRankList("ndcg", name) {}
protected:
inline float CalcDCG(const std::vector<std::pair<float, unsigned> > &rec) const {
inline bst_float CalcDCG(const std::vector<std::pair<bst_float, unsigned> > &rec) const {
double sumdcg = 0.0;
for (size_t i = 0; i < rec.size() && i < this->topn_; ++i) {
const unsigned rel = rec[i].second;
@ -260,13 +260,13 @@ struct EvalNDCG : public EvalRankList{
sumdcg += ((1 << rel) - 1) / std::log2(i + 2.0);
}
}
return static_cast<float>(sumdcg);
return sumdcg;
}
virtual float EvalMetric(std::vector<std::pair<float, unsigned> > &rec) const { // NOLINT(*)
virtual bst_float EvalMetric(std::vector<std::pair<bst_float, unsigned> > &rec) const { // NOLINT(*)
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
float dcg = this->CalcDCG(rec);
bst_float dcg = this->CalcDCG(rec);
std::stable_sort(rec.begin(), rec.end(), common::CmpSecond);
float idcg = this->CalcDCG(rec);
bst_float idcg = this->CalcDCG(rec);
if (idcg == 0.0f) {
if (minus_) {
return 0.0f;
@ -284,7 +284,7 @@ struct EvalMAP : public EvalRankList {
explicit EvalMAP(const char *name) : EvalRankList("map", name) {}
protected:
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const {
virtual bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const {
std::sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
@ -292,13 +292,13 @@ struct EvalMAP : public EvalRankList {
if (rec[i].second != 0) {
nhits += 1;
if (i < this->topn_) {
sumap += static_cast<float>(nhits) / (i + 1);
sumap += static_cast<bst_float>(nhits) / (i + 1);
}
}
}
if (nhits != 0) {
sumap /= nhits;
return static_cast<float>(sumap);
return static_cast<bst_float>(sumap);
} else {
if (minus_) {
return 0.0f;
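
Putting the two CalcDCG calls in EvalNDCG together: DCG uses the exponential gain (2^rel - 1) discounted by log2(position + 1), and NDCG divides the DCG of the predicted ordering by the DCG of the ideal, label-sorted ordering. A minimal standalone sketch of NDCG@k on (prediction, relevance) pairs:

#include <algorithm>
#include <cmath>
#include <functional>
#include <utility>
#include <vector>

// DCG@k with gain (2^rel - 1) / log2(position + 1), matching CalcDCG above.
inline double DcgAtK(const std::vector<unsigned>& rels, size_t k) {
  double dcg = 0.0;
  for (size_t i = 0; i < rels.size() && i < k; ++i) {
    dcg += (std::pow(2.0, static_cast<double>(rels[i])) - 1.0) / std::log2(i + 2.0);
  }
  return dcg;
}

// NDCG@k: DCG of the predicted ordering divided by DCG of the ideal ordering.
inline double NdcgAtK(std::vector<std::pair<float, unsigned> > rec, size_t k) {
  std::stable_sort(rec.begin(), rec.end(),
                   [](const std::pair<float, unsigned>& a,
                      const std::pair<float, unsigned>& b) { return a.first > b.first; });
  std::vector<unsigned> by_pred;
  for (size_t i = 0; i < rec.size(); ++i) by_pred.push_back(rec[i].second);
  std::vector<unsigned> ideal(by_pred);
  std::sort(ideal.begin(), ideal.end(), std::greater<unsigned>());
  const double idcg = DcgAtK(ideal, k);
  return idcg == 0.0 ? 0.0 : DcgAtK(by_pred, k) / idcg;
}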


@ -35,7 +35,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float>& preds,
void GetGradient(const std::vector<bst_float>& preds,
const MetaInfo& info,
int iter,
std::vector<bst_gpair>* out_gpair) override {
@ -49,7 +49,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
int label_error = 0;
#pragma omp parallel
{
std::vector<float> rec(nclass);
std::vector<bst_float> rec(nclass);
#pragma omp for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) {
for (int k = 0; k < nclass; ++k) {
@ -60,10 +60,10 @@ class SoftmaxMultiClassObj : public ObjFunction {
if (label < 0 || label >= nclass) {
label_error = label; label = 0;
}
const float wt = info.GetWeight(i);
const bst_float wt = info.GetWeight(i);
for (int k = 0; k < nclass; ++k) {
float p = rec[k];
const float h = 2.0f * p * (1.0f - p) * wt;
bst_float p = rec[k];
const bst_float h = 2.0f * p * (1.0f - p) * wt;
if (label == k) {
out_gpair->at(i * nclass + k) = bst_gpair((p - 1.0f) * wt, h);
} else {
@ -77,10 +77,10 @@ class SoftmaxMultiClassObj : public ObjFunction {
<< " num_class=" << nclass
<< " but found " << label_error << " in label.";
}
void PredTransform(std::vector<float>* io_preds) override {
void PredTransform(std::vector<bst_float>* io_preds) override {
this->Transform(io_preds, output_prob_);
}
void EvalTransform(std::vector<float>* io_preds) override {
void EvalTransform(std::vector<bst_float>* io_preds) override {
this->Transform(io_preds, true);
}
const char* DefaultEvalMetric() const override {
@ -88,23 +88,23 @@ class SoftmaxMultiClassObj : public ObjFunction {
}
private:
inline void Transform(std::vector<float> *io_preds, bool prob) {
std::vector<float> &preds = *io_preds;
std::vector<float> tmp;
inline void Transform(std::vector<bst_float> *io_preds, bool prob) {
std::vector<bst_float> &preds = *io_preds;
std::vector<bst_float> tmp;
const int nclass = param_.num_class;
const omp_ulong ndata = static_cast<omp_ulong>(preds.size() / nclass);
if (!prob) tmp.resize(ndata);
#pragma omp parallel
{
std::vector<float> rec(nclass);
std::vector<bst_float> rec(nclass);
#pragma omp for schedule(static)
for (omp_ulong j = 0; j < ndata; ++j) {
for (int k = 0; k < nclass; ++k) {
rec[k] = preds[j * nclass + k];
}
if (!prob) {
tmp[j] = static_cast<float>(
tmp[j] = static_cast<bst_float>(
common::FindMaxIndex(rec.begin(), rec.end()) - rec.begin());
} else {
common::Softmax(&rec);
@ -122,7 +122,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
SoftmaxMultiClassParam param_;
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax")
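
The per-row update above is the gradient of weighted softmax cross-entropy: with p_k the predicted probability of class k, the gradient is p_k - 1 for the true class and p_k otherwise, while the diagonal Hessian p_k * (1 - p_k) is replaced by the more conservative 2 * p_k * (1 - p_k). A minimal sketch for a single row (SoftmaxInPlace is a hypothetical stand-in for common::Softmax):

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical stand-in for common::Softmax: normalize scores into probabilities.
// Assumes at least one class.
inline void SoftmaxInPlace(std::vector<float>* rec) {
  float wmax = (*rec)[0];
  for (size_t i = 1; i < rec->size(); ++i) wmax = std::max(wmax, (*rec)[i]);
  double wsum = 0.0;
  for (size_t i = 0; i < rec->size(); ++i) {
    (*rec)[i] = std::exp((*rec)[i] - wmax);
    wsum += (*rec)[i];
  }
  for (size_t i = 0; i < rec->size(); ++i) {
    (*rec)[i] = static_cast<float>((*rec)[i] / wsum);
  }
}

// Gradient/Hessian of weighted softmax cross-entropy for one row.
inline void SoftmaxGradRow(std::vector<float> scores, int label, float wt,
                           std::vector<float>* grad, std::vector<float>* hess) {
  SoftmaxInPlace(&scores);
  grad->resize(scores.size());
  hess->resize(scores.size());
  for (size_t k = 0; k < scores.size(); ++k) {
    const float p = scores[k];
    (*grad)[k] = (static_cast<int>(k) == label ? p - 1.0f : p) * wt;
    (*hess)[k] = 2.0f * p * (1.0f - p) * wt;  // same conservative factor as above
  }
}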


@ -37,7 +37,7 @@ class LambdaRankObj : public ObjFunction {
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float>& preds,
void GetGradient(const std::vector<bst_float>& preds,
const MetaInfo& info,
int iter,
std::vector<bst_gpair>* out_gpair) override {
@ -58,7 +58,7 @@ class LambdaRankObj : public ObjFunction {
std::vector<LambdaPair> pairs;
std::vector<ListEntry> lst;
std::vector< std::pair<float, unsigned> > rec;
std::vector< std::pair<bst_float, unsigned> > rec;
#pragma omp for schedule(static)
for (bst_omp_uint k = 0; k < ngroup; ++k) {
lst.clear(); pairs.clear();
@ -103,11 +103,11 @@ class LambdaRankObj : public ObjFunction {
for (size_t i = 0; i < pairs.size(); ++i) {
const ListEntry &pos = lst[pairs[i].pos_index];
const ListEntry &neg = lst[pairs[i].neg_index];
const float w = pairs[i].weight * scale;
const bst_float w = pairs[i].weight * scale;
const float eps = 1e-16f;
float p = common::Sigmoid(pos.pred - neg.pred);
float g = p - 1.0f;
float h = std::max(p * (1.0f - p), eps);
bst_float p = common::Sigmoid(pos.pred - neg.pred);
bst_float g = p - 1.0f;
bst_float h = std::max(p * (1.0f - p), eps);
// accumulate gradient and hessian in both pid, and nid
gpair[pos.rindex].grad += g * w;
gpair[pos.rindex].hess += 2.0f * w * h;
@ -125,13 +125,13 @@ class LambdaRankObj : public ObjFunction {
/*! \brief helper information in a list */
struct ListEntry {
/*! \brief the predict score we in the data */
float pred;
bst_float pred;
/*! \brief the actual label of the entry */
float label;
bst_float label;
/*! \brief row index in the data matrix */
unsigned rindex;
// constructor
ListEntry(float pred, float label, unsigned rindex)
ListEntry(bst_float pred, bst_float label, unsigned rindex)
: pred(pred), label(label), rindex(rindex) {}
// comparator by prediction
inline static bool CmpPred(const ListEntry &a, const ListEntry &b) {
@ -149,7 +149,7 @@ class LambdaRankObj : public ObjFunction {
/*! \brief negative index: this is a position in the list */
unsigned neg_index;
/*! \brief weight to be filled in */
float weight;
bst_float weight;
// constructor
LambdaPair(unsigned pos_index, unsigned neg_index)
: pos_index(pos_index), neg_index(neg_index), weight(1.0f) {}
@ -180,11 +180,11 @@ class LambdaRankObjNDCG : public LambdaRankObj {
std::vector<LambdaPair> &pairs = *io_pairs;
float IDCG;
{
std::vector<float> labels(sorted_list.size());
std::vector<bst_float> labels(sorted_list.size());
for (size_t i = 0; i < sorted_list.size(); ++i) {
labels[i] = sorted_list[i].label;
}
std::sort(labels.begin(), labels.end(), std::greater<float>());
std::sort(labels.begin(), labels.end(), std::greater<bst_float>());
IDCG = CalcDCG(labels);
}
if (IDCG == 0.0) {
@ -200,25 +200,25 @@ class LambdaRankObjNDCG : public LambdaRankObj {
float neg_loginv = 1.0f / std::log(neg_idx + 2.0f);
int pos_label = static_cast<int>(sorted_list[pos_idx].label);
int neg_label = static_cast<int>(sorted_list[neg_idx].label);
float original =
bst_float original =
((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
float changed =
((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
float delta = (original - changed) * IDCG;
bst_float delta = (original - changed) * IDCG;
if (delta < 0.0f) delta = - delta;
pairs[i].weight = delta;
}
}
}
inline static float CalcDCG(const std::vector<float> &labels) {
inline static bst_float CalcDCG(const std::vector<bst_float> &labels) {
double sumdcg = 0.0;
for (size_t i = 0; i < labels.size(); ++i) {
const unsigned rel = static_cast<unsigned>(labels[i]);
if (rel != 0) {
sumdcg += ((1 << rel) - 1) / std::log2(static_cast<float>(i + 2));
sumdcg += ((1 << rel) - 1) / std::log2(static_cast<bst_float>(i + 2));
}
}
return static_cast<float>(sumdcg);
return static_cast<bst_float>(sumdcg);
}
};
@ -250,7 +250,7 @@ class LambdaRankObjMAP : public LambdaRankObj {
* \param index1,index2 the instances switched
* \param map_stats a vector containing the accumulated precisions for each position in a list
*/
inline float GetLambdaMAP(const std::vector<ListEntry> &sorted_list,
inline bst_float GetLambdaMAP(const std::vector<ListEntry> &sorted_list,
int index1, int index2,
std::vector<MAPStats> *p_map_stats) {
std::vector<MAPStats> &map_stats = *p_map_stats;
@ -258,11 +258,11 @@ class LambdaRankObjMAP : public LambdaRankObj {
return 0.0f;
}
if (index1 > index2) std::swap(index1, index2);
float original = map_stats[index2].ap_acc;
bst_float original = map_stats[index2].ap_acc;
if (index1 != 0) original -= map_stats[index1 - 1].ap_acc;
float changed = 0;
float label1 = sorted_list[index1].label > 0.0f ? 1.0f : 0.0f;
float label2 = sorted_list[index2].label > 0.0f ? 1.0f : 0.0f;
bst_float changed = 0;
bst_float label1 = sorted_list[index1].label > 0.0f ? 1.0f : 0.0f;
bst_float label2 = sorted_list[index2].label > 0.0f ? 1.0f : 0.0f;
if (label1 == label2) {
return 0.0;
} else if (label1 < label2) {
@ -272,7 +272,7 @@ class LambdaRankObjMAP : public LambdaRankObj {
changed += map_stats[index2 - 1].ap_acc_miss - map_stats[index1].ap_acc_miss;
changed += map_stats[index2].hits / (index2 + 1);
}
float ans = (changed - original) / (map_stats[map_stats.size() - 1].hits);
bst_float ans = (changed - original) / (map_stats[map_stats.size() - 1].hits);
if (ans < 0) ans = -ans;
return ans;
}
@ -285,7 +285,7 @@ class LambdaRankObjMAP : public LambdaRankObj {
std::vector<MAPStats> *p_map_acc) {
std::vector<MAPStats> &map_acc = *p_map_acc;
map_acc.resize(sorted_list.size());
float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
bst_float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
for (size_t i = 1; i <= sorted_list.size(); ++i) {
if (sorted_list[i - 1].label > 0.0f) {
hit++;
@ -309,7 +309,7 @@ class LambdaRankObjMAP : public LambdaRankObj {
}
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(LambdaRankParam);
XGBOOST_REGISTER_OBJECTIVE(PairwieRankObj, "rank:pairwise")
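
The inner loop above applies the standard LambdaRank pair update: a logistic loss on the score difference of a (better, worse) pair, with the pair weight carrying the |delta NDCG| or |delta MAP| computed by GetLambdaWeight. A minimal sketch of one pair's contribution (Sigmoid is a hypothetical stand-in for common::Sigmoid; the negative side mirrors the positive side with a flipped gradient):

#include <algorithm>
#include <cmath>

// Hypothetical stand-in for common::Sigmoid.
inline float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Accumulate the gradient/Hessian contribution of one ranked pair;
// pos_pred should ideally outrank neg_pred, and w carries the |delta metric|.
inline void LambdaPairUpdate(float pos_pred, float neg_pred, float w,
                             float* pos_grad, float* pos_hess,
                             float* neg_grad, float* neg_hess) {
  const float eps = 1e-16f;
  const float p = Sigmoid(pos_pred - neg_pred);
  const float g = p - 1.0f;                      // push the pair further apart
  const float h = std::max(p * (1.0f - p), eps);
  *pos_grad += g * w;
  *pos_hess += 2.0f * w * h;
  *neg_grad -= g * w;
  *neg_hess += 2.0f * w * h;
}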


@ -20,24 +20,24 @@ DMLC_REGISTRY_FILE_TAG(regression_obj);
// common regressions
// linear regression
struct LinearSquareLoss {
static float PredTransform(float x) { return x; }
static bool CheckLabel(float x) { return true; }
static float FirstOrderGradient(float predt, float label) { return predt - label; }
static float SecondOrderGradient(float predt, float label) { return 1.0f; }
static float ProbToMargin(float base_score) { return base_score; }
static bst_float PredTransform(bst_float x) { return x; }
static bool CheckLabel(bst_float x) { return true; }
static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; }
static bst_float SecondOrderGradient(bst_float predt, bst_float label) { return 1.0f; }
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
static const char* LabelErrorMsg() { return ""; }
static const char* DefaultEvalMetric() { return "rmse"; }
};
// logistic loss for probability regression task
struct LogisticRegression {
static float PredTransform(float x) { return common::Sigmoid(x); }
static bool CheckLabel(float x) { return x >= 0.0f && x <= 1.0f; }
static float FirstOrderGradient(float predt, float label) { return predt - label; }
static float SecondOrderGradient(float predt, float label) {
static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); }
static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; }
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
const float eps = 1e-16f;
return std::max(predt * (1.0f - predt), eps);
}
static float ProbToMargin(float base_score) {
static bst_float ProbToMargin(bst_float base_score) {
CHECK(base_score > 0.0f && base_score < 1.0f)
<< "base_score must be in (0,1) for logistic loss";
return -std::log(1.0f / base_score - 1.0f);
@ -53,12 +53,12 @@ struct LogisticClassification : public LogisticRegression {
};
// logistic loss, but predict un-transformed margin
struct LogisticRaw : public LogisticRegression {
static float PredTransform(float x) { return x; }
static float FirstOrderGradient(float predt, float label) {
static bst_float PredTransform(bst_float x) { return x; }
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
predt = common::Sigmoid(predt);
return predt - label;
}
static float SecondOrderGradient(float predt, float label) {
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
const float eps = 1e-16f;
predt = common::Sigmoid(predt);
return std::max(predt * (1.0f - predt), eps);
@ -75,14 +75,14 @@ struct RegLossParam : public dmlc::Parameter<RegLossParam> {
}
};
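
For the logistic losses above: with raw margin x and p = sigmoid(x), the log loss is -(y*log(p) + (1-y)*log(1-p)), whose first derivative in x is p - y and whose second derivative is p*(1-p); that is exactly what FirstOrderGradient/SecondOrderGradient return, with the Hessian clamped away from zero by eps. A small self-contained check of that derivation by finite differences:

#include <cmath>
#include <cstdio>

// Log loss of label y given raw margin x (p = sigmoid(x)).
inline double LogLoss(double x, double y) {
  const double p = 1.0 / (1.0 + std::exp(-x));
  return -(y * std::log(p) + (1.0 - y) * std::log(1.0 - p));
}

int main() {
  const double x = 0.7, y = 1.0, h = 1e-5;
  const double p = 1.0 / (1.0 + std::exp(-x));
  // Analytic first/second derivatives used by the loss structs above.
  const double grad = p - y;
  const double hess = p * (1.0 - p);
  // Finite-difference check.
  const double grad_fd = (LogLoss(x + h, y) - LogLoss(x - h, y)) / (2 * h);
  const double hess_fd = (LogLoss(x + h, y) - 2 * LogLoss(x, y) + LogLoss(x - h, y)) / (h * h);
  std::printf("grad  analytic=%.6f  fd=%.6f\n", grad, grad_fd);
  std::printf("hess  analytic=%.6f  fd=%.6f\n", hess, hess_fd);
  return 0;
}
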
// regression los function
// regression loss function
template<typename Loss>
class RegLossObj : public ObjFunction {
public:
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float> &preds,
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
@ -97,8 +97,8 @@ class RegLossObj : public ObjFunction {
const omp_ulong ndata = static_cast<omp_ulong>(preds.size());
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) {
float p = Loss::PredTransform(preds[i]);
float w = info.GetWeight(i);
bst_float p = Loss::PredTransform(preds[i]);
bst_float w = info.GetWeight(i);
if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
if (!Loss::CheckLabel(info.labels[i])) label_correct = false;
out_gpair->at(i) = bst_gpair(Loss::FirstOrderGradient(p, info.labels[i]) * w,
@ -111,15 +111,15 @@ class RegLossObj : public ObjFunction {
const char* DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(std::vector<float> *io_preds) override {
std::vector<float> &preds = *io_preds;
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
preds[j] = Loss::PredTransform(preds[j]);
}
}
float ProbToMargin(float base_score) const override {
bst_float ProbToMargin(bst_float base_score) const override {
return Loss::ProbToMargin(base_score);
}
@ -127,7 +127,7 @@ class RegLossObj : public ObjFunction {
RegLossParam param_;
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(RegLossParam);
XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
@ -164,7 +164,7 @@ class PoissonRegression : public ObjFunction {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float> &preds,
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
@ -177,9 +177,9 @@ class PoissonRegression : public ObjFunction {
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
float p = preds[i];
float w = info.GetWeight(i);
float y = info.labels[i];
bst_float p = preds[i];
bst_float w = info.GetWeight(i);
bst_float y = info.labels[i];
if (y >= 0.0f) {
out_gpair->at(i) = bst_gpair((std::exp(p) - y) * w,
std::exp(p + param_.max_delta_step) * w);
@ -189,18 +189,18 @@ class PoissonRegression : public ObjFunction {
}
CHECK(label_correct) << "PoissonRegression: label must be nonnegative";
}
void PredTransform(std::vector<float> *io_preds) override {
std::vector<float> &preds = *io_preds;
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(std::vector<float> *io_preds) override {
void EvalTransform(std::vector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
float ProbToMargin(float base_score) const override {
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric(void) const override {
@ -211,7 +211,7 @@ class PoissonRegression : public ObjFunction {
PoissonRegressionParam param_;
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
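
For reference on the Poisson update above: with a log link the mean is exp(p), and the negative log likelihood is exp(p) - y*p up to a constant, so the gradient is exp(p) - y and the exact Hessian is exp(p); the code inflates the Hessian to exp(p + max_delta_step), which caps the size of the resulting Newton step. A minimal sketch:

#include <cmath>

// Per-row Poisson terms, with p the raw prediction (log of the mean).
inline double PoissonGrad(double p, double y) { return std::exp(p) - y; }
// Conservative Hessian: exp(p + max_delta_step) >= exp(p), shrinking the update.
inline double PoissonHess(double p, double max_delta_step) {
  return std::exp(p + max_delta_step);
}
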
@ -225,7 +225,7 @@ class GammaRegression : public ObjFunction {
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
}
void GetGradient(const std::vector<float> &preds,
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
@ -238,9 +238,9 @@ class GammaRegression : public ObjFunction {
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
float p = preds[i];
float w = info.GetWeight(i);
float y = info.labels[i];
bst_float p = preds[i];
bst_float w = info.GetWeight(i);
bst_float y = info.labels[i];
if (y >= 0.0f) {
out_gpair->at(i) = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
} else {
@ -249,18 +249,18 @@ class GammaRegression : public ObjFunction {
}
CHECK(label_correct) << "GammaRegression: label must be positive";
}
void PredTransform(std::vector<float> *io_preds) override {
std::vector<float> &preds = *io_preds;
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(std::vector<float> *io_preds) override {
void EvalTransform(std::vector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
float ProbToMargin(float base_score) const override {
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric(void) const override {
@ -268,7 +268,7 @@ class GammaRegression : public ObjFunction {
}
};
// register the ojective functions
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
.describe("Gamma regression for severity data.")
.set_body([]() { return new GammaRegression(); });
@ -290,7 +290,7 @@ class TweedieRegression : public ObjFunction {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float> &preds,
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
@ -303,13 +303,14 @@ class TweedieRegression : public ObjFunction {
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
float p = preds[i];
float w = info.GetWeight(i);
float y = info.labels[i];
bst_float p = preds[i];
bst_float w = info.GetWeight(i);
bst_float y = info.labels[i];
float rho = param_.tweedie_variance_power;
if (y >= 0.0f) {
float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
float hess = -y * (1 - rho) * std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p);
bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
bst_float hess = -y * (1 - rho) * \
std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p);
out_gpair->at(i) = bst_gpair(grad * w, hess * w);
} else {
label_correct = false;
@ -317,8 +318,8 @@ class TweedieRegression : public ObjFunction {
}
CHECK(label_correct) << "TweedieRegression: label must be nonnegative";
}
void PredTransform(std::vector<float> *io_preds) override {
std::vector<float> &preds = *io_preds;
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
@ -336,7 +337,7 @@ class TweedieRegression : public ObjFunction {
TweedieRegressionParam param_;
};
// register the ojective functions
// register the objective functions
DMLC_REGISTER_PARAMETER(TweedieRegressionParam);
XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
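
The Tweedie gradient and Hessian above follow from the Tweedie deviance with a log link: up to constants the per-row loss is -y*exp((1-rho)*p)/(1-rho) + exp((2-rho)*p)/(2-rho), and differentiating once and twice in p gives exactly the grad/hess expressions in the loop (rho is tweedie_variance_power, typically in (1, 2)). A minimal sketch:

#include <cmath>

// Per-row Tweedie terms with log link; rho = tweedie_variance_power.
inline double TweedieGrad(double p, double y, double rho) {
  return -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
}
inline double TweedieHess(double p, double y, double rho) {
  return -y * (1 - rho) * std::exp((1 - rho) * p)
         + (2 - rho) * std::exp((2 - rho) * p);
}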


@ -408,7 +408,7 @@ struct SplitEntry {
/*! \brief split index */
unsigned sindex;
/*! \brief split value */
float split_value;
bst_float split_value;
/*! \brief constructor */
SplitEntry() : loss_chg(0.0f), sindex(0), split_value(0.0f) {}
/*!
@ -452,7 +452,7 @@ struct SplitEntry {
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(bst_float new_loss_chg, unsigned split_index,
float new_split_value, bool default_left) {
bst_float new_split_value, bool default_left) {
if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg;
if (default_left)


@ -68,13 +68,13 @@ void DumpRegTree(std::stringstream& fo, // NOLINT(*)
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.name(split_index) << "\""
<< ", \"split_condition\": " << int(float(cond) + 1.0f)
<< ", \"split_condition\": " << int(cond + 1.0)
<< ", \"yes\": " << tree[nid].cleft()
<< ", \"no\": " << tree[nid].cright()
<< ", \"missing\": " << tree[nid].cdefault();
} else {
fo << nid << ":[" << fmap.name(split_index) << "<"
<< int(float(cond)+1.0f)
<< int(cond + 1.0)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
@ -87,12 +87,12 @@ void DumpRegTree(std::stringstream& fo, // NOLINT(*)
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.name(split_index) << "\""
<< ", \"split_condition\": " << float(cond)
<< ", \"split_condition\": " << cond
<< ", \"yes\": " << tree[nid].cleft()
<< ", \"no\": " << tree[nid].cright()
<< ", \"missing\": " << tree[nid].cdefault();
} else {
fo << nid << ":[" << fmap.name(split_index) << "<" << float(cond)
fo << nid << ":[" << fmap.name(split_index) << "<" << cond
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
@ -106,12 +106,12 @@ void DumpRegTree(std::stringstream& fo, // NOLINT(*)
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": " << split_index
<< ", \"split_condition\": " << float(cond)
<< ", \"split_condition\": " << cond
<< ", \"yes\": " << tree[nid].cleft()
<< ", \"no\": " << tree[nid].cright()
<< ", \"missing\": " << tree[nid].cdefault();
} else {
fo << nid << ":[f" << split_index << "<"<< float(cond)
fo << nid << ":[f" << split_index << "<"<< cond
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();


@ -267,7 +267,7 @@ class BaseMaker: public TreeUpdater {
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
const bst_float fvalue = col[j].fvalue;
const int nid = this->DecodePosition(ridx);
CHECK(tree[nid].is_leaf());
int pid = tree[nid].parent();
@ -327,7 +327,7 @@ class BaseMaker: public TreeUpdater {
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
const bst_float fvalue = col[j].fvalue;
const int nid = this->DecodePosition(ridx);
// go back to parent, correct those who are not default
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {


@ -53,9 +53,9 @@ class ColMaker: public TreeUpdater {
/*! \brief extra statistics of data */
TStats stats_extra;
/*! \brief last feature value scanned */
float last_fvalue;
bst_float last_fvalue;
/*! \brief first feature value scanned */
float first_fvalue;
bst_float first_fvalue;
/*! \brief current best solution */
SplitEntry best;
// constructor
@ -69,7 +69,7 @@ class ColMaker: public TreeUpdater {
/*! \brief loss of this node, without split */
bst_float root_gain;
/*! \brief weight calculated related to current data */
float weight;
bst_float weight;
/*! \brief current best solution */
SplitEntry best;
// constructor
@ -284,7 +284,7 @@ class ColMaker: public TreeUpdater {
const bst_uint ridx = col[i].index;
const int nid = position[ridx];
if (nid < 0) continue;
const float fvalue = col[i].fvalue;
const bst_float fvalue = col[i].fvalue;
if (temp[nid].stats.Empty()) {
temp[nid].first_fvalue = fvalue;
}
@ -309,7 +309,7 @@ class ColMaker: public TreeUpdater {
for (int tid = 0; tid < nthread; ++tid) {
stemp[tid][nid].stats_extra = sum;
ThreadEntry &e = stemp[tid][nid];
float fsplit;
bst_float fsplit;
if (tid != 0) {
if (stemp[tid - 1][nid].last_fvalue != e.first_fvalue) {
fsplit = (stemp[tid - 1][nid].last_fvalue + e.first_fvalue) * 0.5f;
@ -364,7 +364,7 @@ class ColMaker: public TreeUpdater {
const bst_uint ridx = col[i].index;
const int nid = position[ridx];
if (nid < 0) continue;
const float fvalue = col[i].fvalue;
const bst_float fvalue = col[i].fvalue;
// get the statistics of nid
ThreadEntry &e = temp[nid];
if (e.stats.Empty()) {
@ -403,7 +403,7 @@ class ColMaker: public TreeUpdater {
}
// update enumeration solution
inline void UpdateEnumeration(int nid, bst_gpair gstats,
float fvalue, int d_step, bst_uint fid,
bst_float fvalue, int d_step, bst_uint fid,
TStats &c, std::vector<ThreadEntry> &temp) { // NOLINT(*)
// get the statistics of nid
ThreadEntry &e = temp[nid];
@ -503,8 +503,8 @@ class ColMaker: public TreeUpdater {
loss_chg = static_cast<bst_float>(
constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
}
const float gap = std::abs(e.last_fvalue) + rt_eps;
const float delta = d_step == +1 ? gap: -gap;
const bst_float gap = std::abs(e.last_fvalue) + rt_eps;
const bst_float delta = d_step == +1 ? gap: -gap;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
}
}
@ -535,7 +535,7 @@ class ColMaker: public TreeUpdater {
const int nid = position[ridx];
if (nid < 0) continue;
// start working
const float fvalue = it->fvalue;
const bst_float fvalue = it->fvalue;
// get the statistics of nid
ThreadEntry &e = temp[nid];
// test if first hit, this is fine, because we set 0 during init
@ -580,8 +580,8 @@ class ColMaker: public TreeUpdater {
loss_chg = static_cast<bst_float>(
constraints_[nid].CalcSplitGain(param, fid, e.stats, c) - snode[nid].root_gain);
}
const float gap = std::abs(e.last_fvalue) + rt_eps;
const float delta = d_step == +1 ? gap: -gap;
const bst_float gap = std::abs(e.last_fvalue) + rt_eps;
const bst_float delta = d_step == +1 ? gap: -gap;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
}
}
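
The gap/delta lines above place the candidate threshold strictly outside the range of scanned feature values, so every observed value falls on the intended side of the fvalue < split_cond test when all instances go one way. A minimal sketch of that adjustment (rt_eps is the small constant from the training parameters; 1e-6f here is an assumed value):

#include <cmath>

// Push the split threshold just past the last scanned value; d_step is +1 for a
// forward scan (threshold above all values) and -1 for a backward scan.
inline float BoundarySplitValue(float last_fvalue, int d_step,
                                float rt_eps = 1e-6f) {  // assumed constant
  const float gap = std::abs(last_fvalue) + rt_eps;
  return last_fvalue + (d_step == +1 ? gap : -gap);
}
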
@ -730,7 +730,7 @@ class ColMaker: public TreeUpdater {
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const int nid = this->DecodePosition(ridx);
const float fvalue = col[j].fvalue;
const bst_float fvalue = col[j].fvalue;
// go back to parent, correct those who are not default
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
if (fvalue < tree[nid].split_cond()) {
@ -864,7 +864,7 @@ class DistColMaker : public ColMaker<TStats, TConstraint> {
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
const bst_float fvalue = col[j].fvalue;
const int nid = this->DecodePosition(ridx);
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
if (fvalue < tree[nid].split_cond()) {
@ -898,7 +898,7 @@ class DistColMaker : public ColMaker<TStats, TConstraint> {
}
}
// synchronize the best solution of each node
virtual void SyncBestSolution(const std::vector<int> &qexpand) {
void SyncBestSolution(const std::vector<int> &qexpand) override {
std::vector<SplitEntry> vec;
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];


@ -191,7 +191,7 @@ class HistMaker: public BaseMaker {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param.min_child_weight) {
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
if (best->Update(static_cast<float>(loss_chg), fid, hist.cut[i], false)) {
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false)) {
*left_sum = s;
}
}
@ -204,7 +204,7 @@ class HistMaker: public BaseMaker {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param.min_child_weight) {
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
if (best->Update(static_cast<float>(loss_chg), fid, hist.cut[i-1], true)) {
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true)) {
*left_sum = c;
}
}
@ -260,8 +260,8 @@ class HistMaker: public BaseMaker {
}
inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) {
p_tree->stat(nid).base_weight = static_cast<float>(node_sum.CalcWeight(param));
p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
p_tree->stat(nid).base_weight = static_cast<bst_float>(node_sum.CalcWeight(param));
p_tree->stat(nid).sum_hess = static_cast<bst_float>(node_sum.sum_hess);
node_sum.SetLeafVec(param, p_tree->leafvec(nid));
}
};


@ -27,7 +27,7 @@ class TreeRefresher: public TreeUpdater {
// update the tree, do pruning
void Update(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) {
const std::vector<RegTree*> &trees) override {
if (trees.size() == 0) return;
// number of threads
// thread temporal space
@ -130,13 +130,13 @@ class TreeRefresher: public TreeUpdater {
inline void Refresh(const TStats *gstats,
int nid, RegTree *p_tree) {
RegTree &tree = *p_tree;
tree.stat(nid).base_weight = static_cast<float>(gstats[nid].CalcWeight(param));
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
tree.stat(nid).base_weight = static_cast<bst_float>(gstats[nid].CalcWeight(param));
tree.stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
if (tree[nid].is_leaf()) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
} else {
tree.stat(nid).loss_chg = static_cast<float>(
tree.stat(nid).loss_chg = static_cast<bst_float>(
gstats[tree[nid].cleft()].CalcGain(param) +
gstats[tree[nid].cright()].CalcGain(param) -
gstats[nid].CalcGain(param));
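
The loss_chg recomputed above is the usual structure-score gain: with G and H the summed gradient and Hessian of a node, CalcWeight is roughly -G / (H + lambda) and CalcGain roughly G^2 / (H + lambda) (ignoring reg_alpha and max_delta_step), and an internal node's loss_chg is gain(left) + gain(right) - gain(parent). A minimal sketch under those simplifying assumptions:

// Simplified structure scores, assuming reg_alpha = 0 and no max_delta_step.
inline double CalcWeight(double sum_grad, double sum_hess, double reg_lambda) {
  return sum_hess <= 0.0 ? 0.0 : -sum_grad / (sum_hess + reg_lambda);
}
inline double CalcGain(double sum_grad, double sum_hess, double reg_lambda) {
  return sum_hess <= 0.0 ? 0.0 : sum_grad * sum_grad / (sum_hess + reg_lambda);
}
// For an internal node: loss_chg = CalcGain(left) + CalcGain(right) - CalcGain(parent).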


@ -60,7 +60,7 @@ class SketchMaker: public BaseMaker {
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
this->SetStats(nid, node_stats[nid], p_tree);
if (!(*p_tree)[nid].is_leaf()) {
p_tree->stat(nid).loss_chg = static_cast<float>(
p_tree->stat(nid).loss_chg = static_cast<bst_float>(
node_stats[(*p_tree)[nid].cleft()].CalcGain(param) +
node_stats[(*p_tree)[nid].cright()].CalcGain(param) -
node_stats[nid].CalcGain(param));
@ -310,8 +310,8 @@ class SketchMaker: public BaseMaker {
}
// set statistics on ptree
inline void SetStats(int nid, const SKStats &node_sum, RegTree *p_tree) {
p_tree->stat(nid).base_weight = static_cast<float>(node_sum.CalcWeight(param));
p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
p_tree->stat(nid).base_weight = static_cast<bst_float>(node_sum.CalcWeight(param));
p_tree->stat(nid).sum_hess = static_cast<bst_float>(node_sum.sum_hess);
node_sum.SetLeafVec(param, p_tree->leafvec(nid));
}
inline void EnumerateSplit(const WXQSketch::Summary &pos_grad,
@ -372,7 +372,8 @@ class SketchMaker: public BaseMaker {
c.sum_hess >= param.min_child_weight) {
bst_float cpt = fsplits.back();
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
best->Update(static_cast<bst_float>(loss_chg), fid, cpt + fabsf(cpt) + 1.0f, false);
best->Update(static_cast<bst_float>(loss_chg),
fid, cpt + std::abs(cpt) + 1.0f, false);
}
}
}