add base_margin

This commit is contained in:
tqchen@graphlab.com 2014-08-18 12:20:13 -07:00
parent 46fed899ab
commit 9da2ced8a2
12 changed files with 162 additions and 93 deletions

View File

@ -10,14 +10,13 @@ endif
# specify tensor path
BIN = xgboost
OBJ = io.o
OBJ =
SLIB = python/libxgboostwrapper.so
.PHONY: clean all
all: $(BIN) $(OBJ) $(SLIB)
xgboost: src/xgboost_main.cpp io.o src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
io.o: src/io/io.cpp src/data.h src/utils/*.h
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
# now the wrapper builds from two parts: io and the wrapper code
python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
@ -34,4 +33,4 @@ install:
cp -f -r $(BIN) $(INSTALL_PATH)
clean:
$(RM) $(OBJ) $(BIN) *~ */*~ */*/*~
$(RM) $(OBJ) $(BIN) $(SLIB) *~ */*~ */*/*~

View File

@ -21,8 +21,8 @@ Features
xgboost-unity
=======
* experimental branch(not usable yet): refactor xgboost, cleaner code, more flexibility
* This version of xgboost is not backward compatible with 0.2*, due to huge change in code structure
* Experimental branch (not usable yet): refactored xgboost, cleaner code, more flexibility
* This version of xgboost is not compatible with 0.2x, due to large changes in the code structure
- This means model and buffer files from previous versions cannot be loaded in xgboost-unity
Build

View File

@ -7,7 +7,7 @@ This script will achieve about 3.600 AMS score in the public leaderboard. To get start
1. Compile the XGBoost python lib
```bash
cd ../../python
cd ../..
make
```
2. Put training.csv and test.csv in the folder './data' (you can create a symbolic link)

View File

@ -90,3 +90,22 @@ def evalerror(preds, dtrain):
# training with customized objective, we can also do step by step training
# simply look at xgboost.py's implementation of train
bst = xgb.train(param, dtrain, num_round, evallist, logregobj, evalerror)
###
# advanced: start from an initial base prediction
#
print ('start running example to start from an initial prediction')
# specify parameters via map; definitions are the same as the C++ version
param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logistic' }
# train xgboost for 1 round
bst = xgb.train( param, dtrain, 1, evallist )
# Note: set_base_margin needs the margin value, not the transformed prediction
# predicting with output_margin=True always returns margin values before the logistic transformation
ptrain = bst.predict(dtrain, output_margin=True)
ptest = bst.predict(dtest, output_margin=True)
dtrain.set_base_margin(ptrain)
dtest.set_base_margin(ptest)
print ('this is the result of running from the initial prediction')
bst = xgb.train( param, dtrain, 1, evallist )
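
A quick way to see what set_base_margin changes: once a margin is stored on a DMatrix, even a booster with zero trees predicts the logistic transform of that margin, because the margin replaces the global base score. A minimal sketch (hypothetical; assumes numpy is available and that a zero-round train call is allowed by xgboost.py's train loop):

```python
import numpy as np

# with base margins set on dtrain above, a 0-tree booster's prediction
# should be exactly the sigmoid of the stored margin values
bst0 = xgb.train(param, dtrain, 0, evallist)
out = bst0.predict(dtrain)
expected = 1.0 / (1.0 + np.exp(-ptrain))  # sigmoid of the stored margin
print('max absolute difference:', np.max(np.abs(out - expected)))
```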

View File

@ -18,8 +18,7 @@ xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
xglib.XGDMatrixGetLabel.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixGetWeight.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
xglib.XGBoosterCreate.restype = ctypes.c_void_p
@ -77,28 +76,46 @@ class DMatrix:
# destructor
def __del__(self):
xglib.XGDMatrixFree(self.handle)
def __get_float_info(self, field):
length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
ctypes.byref(length))
return ctypes2numpy(ret, length.value)
def __set_float_info(self, field, data):
xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
(ctypes.c_float*len(data))(*data), len(data))
# save data to binary file
def save_binary(self, fname, silent=True):
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
# set label of dmatrix
def set_label(self, label):
xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label))
self.__set_float_info('label', label)
# set weight of each instance
def set_weight(self, weight):
self.__set_float_info('weight', weight)
# set initialized margin prediction
def set_base_margin(self, margin):
"""
set the base margin of the booster to start from;
this can be used to feed the prediction value of an
existing model in as the base_margin.
However, remember that the margin is needed, not the transformed prediction:
e.g. for logistic regression, supply the value before the logistic transformation
see also example/demo.py
"""
self.__set_float_info('base_margin', margin)
# set group size of dmatrix, used for rank
def set_group(self, group):
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
# set weight of each instances
def set_weight(self, weight):
xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_float*len(weight))(*weight), len(weight))
# get label from dmatrix
def get_label(self):
length = ctypes.c_ulong()
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
return ctypes2numpy(labels, length.value)
return self.__get_float_info('label')
# get weight from dmatrix
def get_weight(self):
length = ctypes.c_ulong()
weights = xglib.XGDMatrixGetWeight(self.handle, ctypes.byref(length))
return ctypes2numpy(weights, length.value)
return self.__get_float_info('weight')
# get base_margin from dmatrix
def get_base_margin(self):
return self.__get_float_info('base_margin')
def num_row(self):
return xglib.XGDMatrixNumRow(self.handle)
# slice the DMatrix to return a new DMatrix that only contains rindex
@ -161,9 +178,15 @@ class Booster:
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it)
def predict(self, data):
def predict(self, data, output_margin=False):
"""
predict with data
data: the dmatrix storing the input
output_margin: whether to output the raw margin value that is untransformed
"""
length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, ctypes.byref(length))
preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ctypes.byref(length))
return ctypes2numpy(preds, length.value)
def save_model(self, fname):
""" save model to file """

View File

@ -23,9 +23,9 @@ class Booster: public learner::BoostLearner<FMatrixS> {
this->init_model = false;
this->SetCacheData(mats);
}
const float *Pred(const DataMatrix &dmat, size_t *len) {
const float *Pred(const DataMatrix &dmat, int output_margin, size_t *len) {
this->CheckInitModel();
this->Predict(dmat, &this->preds_);
this->Predict(dmat, output_margin, &this->preds_);
*len = this->preds_.size();
return &this->preds_[0];
}
@ -163,15 +163,11 @@ extern "C"{
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) {
SaveDataMatrix(*static_cast<DataMatrix*>(handle), fname, silent);
}
void XGDMatrixSetLabel(void *handle, const float *label, size_t len) {
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
pmat->info.labels.resize(len);
memcpy(&(pmat->info).labels[0], label, sizeof(float) * len);
}
void XGDMatrixSetWeight(void *handle, const float *weight, size_t len) {
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
pmat->info.weights.resize(len);
memcpy(&(pmat->info).weights[0], weight, sizeof(float) * len);
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) {
std::vector<float> &vec =
static_cast<DataMatrix*>(handle)->info.GetInfo(field);
vec.resize(len);
memcpy(&vec[0], info, sizeof(float) * len);
}
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
@ -181,15 +177,11 @@ extern "C"{
pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i]+group[i];
}
}
const float* XGDMatrixGetLabel(const void *handle, size_t* len) {
const DataMatrix *pmat = static_cast<const DataMatrix*>(handle);
*len = pmat->info.labels.size();
return &(pmat->info.labels[0]);
}
const float* XGDMatrixGetWeight(const void *handle, size_t* len) {
const DataMatrix *pmat = static_cast<const DataMatrix*>(handle);
*len = pmat->info.weights.size();
return &(pmat->info.weights[0]);
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) {
const std::vector<float> &vec =
static_cast<const DataMatrix*>(handle)->info.GetInfo(field);
*len = vec.size();
return &vec[0];
}
size_t XGDMatrixNumRow(const void *handle) {
return static_cast<const DataMatrix*>(handle)->info.num_row;
@ -238,8 +230,8 @@ extern "C"{
bst->eval_str = bst->EvalOneIter(iter, mats, names);
return bst->eval_str.c_str();
}
const float *XGBoosterPredict(void *handle, void *dmat, size_t *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), len);
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, size_t *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, len);
}
void XGBoosterLoadModel(void *handle, const char *fname) {
static_cast<Booster*>(handle)->LoadModel(fname);
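
The same consolidation happens at the C API level: XGDMatrixSetLabel/SetWeight and XGDMatrixGetLabel/GetWeight collapse into XGDMatrixSetFloatInfo/XGDMatrixGetFloatInfo keyed by a field name. A minimal ctypes sketch of calling the new entry points directly, mirroring what python/xgboost.py does (the library path is an assumption):

```python
import ctypes

lib = ctypes.cdll.LoadLibrary('./python/libxgboostwrapper.so')
lib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)

def set_float_info(handle, field, values):
    # field must be one MetaInfo::GetInfo knows: label, weight, base_margin
    arr = (ctypes.c_float * len(values))(*values)
    lib.XGDMatrixSetFloatInfo(handle, ctypes.c_char_p(field.encode('utf-8')),
                              arr, len(values))

def get_float_info(handle, field):
    length = ctypes.c_ulong()
    ptr = lib.XGDMatrixGetFloatInfo(handle, ctypes.c_char_p(field.encode('utf-8')),
                                    ctypes.byref(length))
    return [ptr[i] for i in range(length.value)]
```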

View File

@ -64,19 +64,13 @@ extern "C" {
*/
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
/*!
* \brief set label of the training matrix
* \brief set float vector to a field in info
* \param handle an instance of data matrix
* \param label pointer to label
* \param field field name; can be label, weight, or base_margin
* \param array pointer to float vector
* \param len length of array
*/
void XGDMatrixSetLabel(void *handle, const float *label, size_t len);
/*!
* \brief set weight of each instance
* \param handle a instance of data matrix
* \param weight data pointer to weights
* \param len length of array
*/
void XGDMatrixSetWeight(void *handle, const float *weight, size_t len);
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len);
/*!
* \brief set label of the training matrix
* \param handle a instance of data matrix
@ -85,19 +79,13 @@ extern "C" {
*/
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len);
/*!
* \brief get label set from matrix
* \brief get float info vector from matrix
* \param handle an instance of data matrix
* \param field field name
* \param out_len used to set result length
* \return pointer to the requested float info
*/
const float* XGDMatrixGetLabel(const void *handle, size_t* out_len);
/*!
* \brief get weight set from matrix
* \param handle a instance of data matrix
* \param len used to set result length
* \return pointer to the weight
*/
const float* XGDMatrixGetWeight(const void *handle, size_t* out_len);
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
/*!
* \brief return number of rows
*/
@ -154,9 +142,10 @@ extern "C" {
* \brief make prediction based on dmat
* \param handle handle
* \param dmat data matrix
* \param output_margin whether to only output the raw margin value
* \param len used to store length of returning result
*/
const float *XGBoosterPredict(void *handle, void *dmat, size_t *len);
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, size_t *len);
/*!
* \brief load model from existing file
* \param handle handle

View File

@ -233,7 +233,7 @@ class GBTree : public IGradBooster<FMatrix> {
pred_counter[bid] = static_cast<unsigned>(trees.size());
pred_buffer[bid] = psum;
}
return psum + mparam.base_score;
return psum;
}
// initialize thread local space for prediction
inline void InitThreadTemp(int nthread) {
@ -296,8 +296,6 @@ class GBTree : public IGradBooster<FMatrix> {
};
/*! \brief model parameters */
struct ModelParam {
/*! \brief base prediction score of everything */
float base_score;
/*! \brief number of trees */
int num_trees;
/*! \brief number of root: default 0, means single tree */
@ -316,7 +314,6 @@ class GBTree : public IGradBooster<FMatrix> {
int reserved[32];
/*! \brief constructor */
ModelParam(void) {
base_score = 0.0f;
num_trees = 0;
num_roots = num_feature = 0;
num_pbuffer = 0;
@ -329,7 +326,6 @@ class GBTree : public IGradBooster<FMatrix> {
* \param val value of the parameter
*/
inline void SetParam(const char *name, const char *val) {
if (!strcmp("base_score", name)) base_score = static_cast<float>(atof(val));
if (!strcmp("num_pbuffer", name)) num_pbuffer = atol(val);
if (!strcmp("num_output_group", name)) num_output_group = atol(val);
if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);

View File

@ -110,10 +110,13 @@ class DMatrixSimple : public DataMatrix {
"DMatrix: group data does not match the number of rows in features");
}
std::string wname = name + ".weight";
if (info.TryLoadWeight(wname.c_str(), silent)) {
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
utils::Check(info.weights.size() == info.num_row,
"DMatrix: weight data does not match the number of rows in features");
}
std::string mname = name + ".base_margin";
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
}
}
/*!
* \brief load from binary file
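
The loader above also looks for a `<data>.base_margin` sidecar file next to the training data, read by the same TryLoadFloatInfo path as weights. Judging from the float-by-float read loop in that function, the expected format is plain whitespace-separated floats, one per row; a hypothetical sketch of producing such a file (file names are placeholders):

```python
# write one margin value per row of the training file 'train.txt';
# xgboost will then pick up 'train.txt.base_margin' automatically
margins = [0.5, -1.25, 2.0]
with open('train.txt.base_margin', 'w') as f:
    for m in margins:
        f.write('%g\n' % m)
```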

View File

@ -33,6 +33,15 @@ struct MetaInfo {
* can be used for multi task setting
*/
std::vector<unsigned> root_index;
/*!
* \brief initialized margins;
* if specified, xgboost will start from this initial margin,
* which can be used to supply the prediction of an existing model to boost from
*/
std::vector<float> base_margin;
/*! \brief version flag, used to check version of this info */
static const int kVersion = 0;
// constructor
MetaInfo(void) : num_row(0), num_col(0) {}
/*! \brief clear all the information */
inline void Clear(void) {
@ -40,6 +49,7 @@ struct MetaInfo {
group_ptr.clear();
weights.clear();
root_index.clear();
base_margin.clear();
num_row = num_col = 0;
}
/*! \brief get weight of each instances */
@ -59,20 +69,26 @@ struct MetaInfo {
}
}
inline void SaveBinary(utils::IStream &fo) const {
int version = kVersion;
fo.Write(&version, sizeof(version));
fo.Write(&num_row, sizeof(num_row));
fo.Write(&num_col, sizeof(num_col));
fo.Write(labels);
fo.Write(group_ptr);
fo.Write(weights);
fo.Write(root_index);
fo.Write(base_margin);
}
inline void LoadBinary(utils::IStream &fi) {
int version;
utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format");
utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format");
utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format");
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
}
// try to load group information from file, if exists
inline bool TryLoadGroup(const char* fname, bool silent = false) {
@ -89,8 +105,19 @@ struct MetaInfo {
fclose(fi);
return true;
}
inline std::vector<float>& GetInfo(const char *field) {
if (!strcmp(field, "label")) return labels;
if (!strcmp(field, "weight")) return weights;
if (!strcmp(field, "base_margin")) return base_margin;
utils::Error("unknown field %s", field);
return labels;
}
inline const std::vector<float>& GetInfo(const char *field) const {
return ((MetaInfo*)this)->GetInfo(field);
}
// try to load float info (e.g. weight, base_margin) from file, if it exists
inline bool TryLoadWeight(const char* fname, bool silent = false) {
inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) {
std::vector<float> &weights = this->GetInfo(field);
FILE *fi = fopen64(fname, "r");
if (fi == NULL) return false;
float wt;
@ -98,7 +125,7 @@ struct MetaInfo {
weights.push_back(wt);
}
if (!silent) {
printf("loading weight from %s\n", fname);
printf("loading %s from %s\n", field, fname);
}
fclose(fi);
return true;
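
MetaInfo::GetInfo is the single dispatch point that maps a field name to its backing vector, which is what lets the C API stay fixed as fields are added. The same idea in Python, purely illustrative (the info attributes are hypothetical stand-ins for the MetaInfo members):

```python
def get_info(info, field):
    # single lookup point: adding a new float field means touching only this map
    fields = {'label': info.labels, 'weight': info.weights,
              'base_margin': info.base_margin}
    if field not in fields:
        raise ValueError('unknown field %s' % field)
    return fields[field]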

View File

@ -97,9 +97,6 @@ class BoostLearner {
this->InitObjGBM();
// reset the base score
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
char tmp[32];
snprintf(tmp, sizeof(tmp), "%g", mparam.base_score);
this->SetParam("base_score", tmp);
// initialize GBM model
gbm_->InitModel();
}
@ -199,13 +196,17 @@ class BoostLearner {
/*!
* \brief get prediction
* \param data input data
* \param output_margin whether to only predict margin value instead of transformed prediction
* \param out_preds output vector that stores the prediction
*/
inline void Predict(const DMatrix<FMatrix> &data,
bool output_margin,
std::vector<float> *out_preds) const {
this->PredictRaw(data, out_preds);
if (!output_margin) {
obj_->PredTransform(out_preds);
}
}
/*! \brief dump model out */
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
return gbm_->DumpModel(fmap, option);
@ -236,6 +237,22 @@ class BoostLearner {
std::vector<float> *out_preds) const {
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
data.info.root_index, out_preds);
// add base margin
std::vector<float> &preds = *out_preds;
const unsigned ndata = static_cast<unsigned>(preds.size());
if (data.info.base_margin.size() != 0) {
utils::Check(preds.size() == data.info.base_margin.size(),
"base_margin.size does not match with prediction size");
#pragma omp parallel for schedule(static)
for (unsigned j = 0; j < ndata; ++j) {
preds[j] += data.info.base_margin[j];
}
} else {
#pragma omp parallel for schedule(static)
for (unsigned j = 0; j < ndata; ++j) {
preds[j] += mparam.base_score;
}
}
}
/*! \brief training parameter for regression */
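
Note the design choice here: base_score moves out of GBTree (previous file) into the learner, so that a per-instance base_margin can replace the single global intercept at prediction time. The composition implemented in PredictRaw above, sketched in numpy (illustrative only; sigmoid stands in for the objective's PredTransform, and base_score is assumed already converted to margin space by ProbToMargin):

```python
import numpy as np

def predict(tree_sum, base_margin=None, base_score=0.0, output_margin=False):
    # per-instance base_margin, when present, replaces the global base_score
    raw = tree_sum + (base_margin if base_margin is not None else base_score)
    return raw if output_margin else 1.0 / (1.0 + np.exp(-raw))
```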

View File

@ -49,6 +49,7 @@ class BoostLearnTask{
if (!strcmp("silent", name)) silent = atoi(val);
if (!strcmp("use_buffer", name)) use_buffer = atoi(val);
if (!strcmp("num_round", name)) num_round = atoi(val);
if (!strcmp("pred_margin", name)) pred_margin = atoi(val);
if (!strcmp("save_period", name)) save_period = atoi(val);
if (!strcmp("eval_train", name)) eval_train = atoi(val);
if (!strcmp("task", name)) task = val;
@ -77,6 +78,7 @@ class BoostLearnTask{
num_round = 10;
save_period = 0;
eval_train = 0;
pred_margin = 0;
dump_model_stats = 0;
task = "train";
model_in = "NULL";
@ -184,7 +186,7 @@ class BoostLearnTask{
inline void TaskPred(void) {
std::vector<float> preds;
if (!silent) printf("start prediction...\n");
learner.Predict(*data, &preds);
learner.Predict(*data, pred_margin != 0, &preds);
if (!silent) printf("writing prediction to %s\n", name_pred.c_str());
FILE *fo = utils::FopenCheck(name_pred.c_str(), "w");
for (size_t i = 0; i < preds.size(); i++) {
@ -193,37 +195,39 @@ class BoostLearnTask{
fclose(fo);
}
private:
/* \brief whether silent */
/*! \brief whether silent */
int silent;
/* \brief whether use auto binary buffer */
/*! \brief whether use auto binary buffer */
int use_buffer;
/* \brief whether evaluate training statistics */
/*! \brief whether evaluate training statistics */
int eval_train;
/* \brief number of boosting iterations */
/*! \brief number of boosting iterations */
int num_round;
/* \brief the period to save the model, 0 means only save the final round model */
/*! \brief the period to save the model, 0 means only save the final round model */
int save_period;
/* \brief the path of training/test data set */
/*! \brief the path of training/test data set */
std::string train_path, test_path;
/* \brief the path of test model file, or file to restart training */
/*! \brief the path of test model file, or file to restart training */
std::string model_in;
/* \brief the path of final model file, to be saved */
/*! \brief the path of final model file, to be saved */
std::string model_out;
/* \brief the path of directory containing the saved models */
/*! \brief the path of directory containing the saved models */
std::string model_dir_path;
/* \brief task to perform */
/*! \brief task to perform */
std::string task;
/* \brief name of predict file */
/*! \brief name of predict file */
std::string name_pred;
/* \brief whether dump statistics along with model */
/*! \brief whether to directly output margin value */
int pred_margin;
/*! \brief whether dump statistics along with model */
int dump_model_stats;
/* \brief name of feature map */
/*! \brief name of feature map */
std::string name_fmap;
/* \brief name of dump file */
/*! \brief name of dump file */
std::string name_dump;
/* \brief the paths of validation data sets */
/*! \brief the paths of validation data sets */
std::vector<std::string> eval_data_paths;
/* \brief the names of the evaluation data used in output log */
/*! \brief the names of the evaluation data used in output log */
std::vector<std::string> eval_data_names;
private:
io::DataMatrix* data;
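
Finally, the new pred_margin parameter is exposed on the command line like any other name=value setting parsed by SetParam. A hypothetical invocation (config and model file names are placeholders):

```python
import subprocess

# task=pred with pred_margin=1 writes raw margins instead of transformed predictions
subprocess.check_call(['./xgboost', 'mushroom.conf',
                       'task=pred', 'model_in=0002.model', 'pred_margin=1'])
```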