add base_margin

This commit is contained in:
tqchen@graphlab.com
2014-08-18 12:20:13 -07:00
parent 46fed899ab
commit 9da2ced8a2
12 changed files with 162 additions and 93 deletions

View File

@@ -33,6 +33,15 @@ struct MetaInfo {
* can be used for multi task setting
*/
std::vector<unsigned> root_index;
/*!
* \brief initial margin of each instance;
* if specified, xgboost starts boosting from these margins
* instead of the global base score, so it can be used
* to supply an initial prediction to boost from
*/
std::vector<float> base_margin;
/*! \brief version flag, used to check version of this info */
static const int kVersion = 0;
/*! \brief default constructor, starts with zero rows and zero columns */
MetaInfo() : num_row(0), num_col(0) {}
/*! \brief clear all the information */
inline void Clear(void) {
@@ -40,6 +49,7 @@ struct MetaInfo {
group_ptr.clear();
weights.clear();
root_index.clear();
base_margin.clear();
num_row = num_col = 0;
}
/*! \brief get weight of each instances */
@@ -59,20 +69,26 @@ struct MetaInfo {
}
}
/*!
 * \brief write this meta info to a binary stream
 * \param fo output stream; receives the version tag first, then
 *        num_row, num_col, labels, group_ptr, weights, root_index
 *        and base_margin, in that order
 */
inline void SaveBinary(utils::IStream &fo) const {
  // Write is given a pointer, so the version constant is copied into a local first
  const int format_version = kVersion;
  fo.Write(&format_version, sizeof(format_version));

  // fixed-size header fields
  fo.Write(&num_row, sizeof(num_row));
  fo.Write(&num_col, sizeof(num_col));

  // vector-valued fields; LoadBinary reads them back in the same order
  fo.Write(labels);
  fo.Write(group_ptr);
  fo.Write(weights);
  fo.Write(root_index);
  fo.Write(base_margin);
}
/*!
 * \brief load meta info from a binary stream in the layout written by SaveBinary
 * \param fi input stream; fields are read in the order
 *        version, num_row, num_col, labels, group_ptr, weights,
 *        root_index, base_margin
 *
 * Every failed read, and a version tag different from kVersion,
 * is reported through utils::Check.
 */
inline void LoadBinary(utils::IStream &fi) {
  int version = 0;
  utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format");
  // reject data written with a different layout instead of misreading the fields below
  utils::Check(version == kVersion, "MetaInfo: unsupported version");
  utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format");
  utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format");
  utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
  utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
  utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
  utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
  utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
}
// try to load group information from file, if exists
inline bool TryLoadGroup(const char* fname, bool silent = false) {
@@ -89,8 +105,19 @@ struct MetaInfo {
fclose(fi);
return true;
}
/*!
 * \brief look up a float-valued info vector by field name
 * \param field one of "label", "weight" or "base_margin"
 * \return mutable reference to the matching vector; for any other name
 *         utils::Error is called with the offending field
 */
inline std::vector<float>& GetInfo(const char *field) {
  if (strcmp(field, "label") == 0) {
    return labels;
  } else if (strcmp(field, "weight") == 0) {
    return weights;
  } else if (strcmp(field, "base_margin") == 0) {
    return base_margin;
  }
  utils::Error("unknown field %s", field);
  // fallback so that every path yields a reference
  return labels;
}
inline const std::vector<float>& GetInfo(const char *field) const {
return ((MetaInfo*)this)->GetInfo(field);
}
// try to load weight information from file, if exists
inline bool TryLoadWeight(const char* fname, bool silent = false) {
inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) {
std::vector<float> &weights = this->GetInfo(field);
FILE *fi = fopen64(fname, "r");
if (fi == NULL) return false;
float wt;
@@ -98,7 +125,7 @@ struct MetaInfo {
weights.push_back(wt);
}
if (!silent) {
printf("loading weight from %s\n", fname);
printf("loading %s from %s\n", field, fname);
}
fclose(fi);
return true;

View File

@@ -97,9 +97,6 @@ class BoostLearner {
this->InitObjGBM();
// reset the base score
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
char tmp[32];
snprintf(tmp, sizeof(tmp), "%g", mparam.base_score);
this->SetParam("base_score", tmp);
// initialize GBM model
gbm_->InitModel();
}
@@ -199,12 +196,16 @@ class BoostLearner {
/*!
 * \brief get prediction
 * \param data input data
 * \param output_margin if true, leave the raw margin produced by PredictRaw
 *        in out_preds; if false, additionally apply the objective's PredTransform
 * \param out_preds output vector that stores the prediction
 */
inline void Predict(const DMatrix<FMatrix> &data,
                    bool output_margin,
                    std::vector<float> *out_preds) const {
  this->PredictRaw(data, out_preds);
  // transform exactly once, and only when the caller did not ask for the margin
  if (!output_margin) {
    obj_->PredTransform(out_preds);
  }
}
/*! \brief dump model out */
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
@@ -236,6 +237,22 @@ class BoostLearner {
std::vector<float> *out_preds) const {
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
data.info.root_index, out_preds);
// add base margin
std::vector<float> &preds = *out_preds;
const unsigned ndata = static_cast<unsigned>(preds.size());
if (data.info.base_margin.size() != 0) {
utils::Check(preds.size() == data.info.base_margin.size(),
"base_margin.size does not match with prediction size");
#pragma omp parallel for schedule(static)
for (unsigned j = 0; j < ndata; ++j) {
preds[j] += data.info.base_margin[j];
}
} else {
#pragma omp parallel for schedule(static)
for (unsigned j = 0; j < ndata; ++j) {
preds[j] += mparam.base_score;
}
}
}
/*! \brief training parameter for regression */