fix base score, and print message

This commit is contained in:
tqchen@graphlab.com
2014-08-18 10:53:15 -07:00
parent 04e04ec5a0
commit f6c763a2a7
7 changed files with 39 additions and 14 deletions

View File

@@ -233,7 +233,7 @@ class GBTree : public IGradBooster<FMatrix> {
pred_counter[bid] = static_cast<unsigned>(trees.size());
pred_buffer[bid] = psum;
}
return psum;
return psum + mparam.base_score;
}
// initialize thread local space for prediction
inline void InitThreadTemp(int nthread) {
@@ -296,6 +296,8 @@ class GBTree : public IGradBooster<FMatrix> {
};
/*! \brief model parameters */
struct ModelParam {
/*! \brief base prediction score of everything */
float base_score;
/*! \brief number of trees */
int num_trees;
/*! \brief number of root: default 0, means single tree */
@@ -314,6 +316,7 @@ class GBTree : public IGradBooster<FMatrix> {
int reserved[32];
/*! \brief constructor */
ModelParam(void) {
base_score = 0.0f;
num_trees = 0;
num_roots = num_feature = 0;
num_pbuffer = 0;
@@ -326,6 +329,7 @@ class GBTree : public IGradBooster<FMatrix> {
* \param val value of the parameter
*/
inline void SetParam(const char *name, const char *val) {
if (!strcmp("base_score", name)) base_score = static_cast<float>(atof(val));
if (!strcmp("num_pbuffer", name)) num_pbuffer = atol(val);
if (!strcmp("num_output_group", name)) num_output_group = atol(val);
if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);