diff --git a/Makefile b/Makefile
index ca58f0eb3..25e61ca7b 100644
--- a/Makefile
+++ b/Makefile
@@ -10,14 +10,13 @@ endif
 
 # specify tensor path
 BIN = xgboost
-OBJ = io.o
+OBJ =
 SLIB = python/libxgboostwrapper.so
 .PHONY: clean all
 
 all: $(BIN) $(OBJ) $(SLIB)
 
-xgboost: src/xgboost_main.cpp io.o src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
-io.o: src/io/io.cpp src/data.h src/utils/*.h
+xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
 
 # now the wrapper takes in two files: the io part and the wrapper part
 python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
@@ -34,4 +33,4 @@ install:
 	cp -f -r $(BIN) $(INSTALL_PATH)
 
 clean:
-	$(RM) $(OBJ) $(BIN) *~ */*~ */*/*~
+	$(RM) $(OBJ) $(BIN) $(SLIB) *~ */*~ */*/*~
diff --git a/README.md b/README.md
index 106757471..61472aa44 100644
--- a/README.md
+++ b/README.md
@@ -21,8 +21,8 @@ Features
 
 xgboost-unity
 =======
-* experimental branch(not usable yet): refactor xgboost, cleaner code, more flexibility
-* This version of xgboost is not backward compatible with 0.2*, due to huge change in code structure
+* Experimental branch (not usable yet): refactoring xgboost for cleaner code and more flexibility
+* This version of xgboost is not compatible with 0.2x, due to the large number of changes in code structure
   - This means the model and buffer files of previous versions cannot be loaded in xgboost-unity
 
 Build
diff --git a/demo/kaggle-higgs/README.md b/demo/kaggle-higgs/README.md
index 28472a848..9e535ef1e 100644
--- a/demo/kaggle-higgs/README.md
+++ b/demo/kaggle-higgs/README.md
@@ -7,7 +7,7 @@ This script will achieve about 3.600 AMS score on the public leaderboard. To get started:
 1. Compile the XGBoost python lib
 ```bash
-cd ../../python
+cd ../..
 make
 ```
 
 2. Put training.csv and test.csv in the folder './data' (you can create a symbolic link)
diff --git a/python/example/demo.py b/python/example/demo.py
index a099f56bf..231640d91 100755
--- a/python/example/demo.py
+++ b/python/example/demo.py
@@ -90,3 +90,22 @@ def evalerror(preds, dtrain):
 # training with customized objective; we can also do step-by-step training
 # simply look at xgboost.py's implementation of train
 bst = xgb.train(param, dtrain, num_round, evallist, logregobj, evalerror)
+
+
+###
+# advanced: start from an initial base prediction
+#
+print('start running example to start from an initial prediction')
+# specify parameters via map; definitions are the same as in the C++ version
+param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logistic' }
+# train xgboost for 1 round
+bst = xgb.train(param, dtrain, 1, evallist)
+# Note: set_base_margin needs the margin value, not the transformed prediction
+# predicting with output_margin=True always returns margin values before the logistic transformation
+ptrain = bst.predict(dtrain, output_margin=True)
+ptest = bst.predict(dtest, output_margin=True)
+dtrain.set_base_margin(ptrain)
+dtest.set_base_margin(ptest)
+
+print('this is the result of running from the initial prediction')
+bst = xgb.train(param, dtrain, 1, evallist)
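Note: the demo works because `predict(output_margin=True)` and `set_base_margin` speak the same pre-transformation scale. A minimal sketch of that relationship, assuming `bst` and `dtrain` from the demo above (trained with `binary:logistic`) and numpy available as `np` as at the top of demo.py; `logit` is a hypothetical helper, not part of the wrapper:

```python
def logit(p):
    # inverse of the logistic transform that predict() applies
    return np.log(p / (1.0 - p))

prob = bst.predict(dtrain)                        # transformed output in (0, 1)
margin = bst.predict(dtrain, output_margin=True)  # raw pre-sigmoid score
assert np.allclose(margin, logit(prob), atol=1e-4)
# a stored probability therefore has to go through logit()
# before it can serve as a base margin
```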
diff --git a/python/xgboost.py b/python/xgboost.py
index 2e5aeceba..badeebed9 100644
--- a/python/xgboost.py
+++ b/python/xgboost.py
@@ -18,8 +18,7 @@ xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
 xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
 xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
 xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
-xglib.XGDMatrixGetLabel.restype = ctypes.POINTER(ctypes.c_float)
-xglib.XGDMatrixGetWeight.restype = ctypes.POINTER(ctypes.c_float)
+xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
 xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
 
 xglib.XGBoosterCreate.restype = ctypes.c_void_p
@@ -77,28 +76,46 @@ class DMatrix:
     # destructor
     def __del__(self):
         xglib.XGDMatrixFree(self.handle)
-    # load data from file
+    def __get_float_info(self, field):
+        length = ctypes.c_ulong()
+        ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
+                                          ctypes.byref(length))
+        return ctypes2numpy(ret, length.value)
+    def __set_float_info(self, field, data):
+        xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
+                                    (ctypes.c_float*len(data))(*data), len(data))
+    # save DMatrix to a binary file
     def save_binary(self, fname, silent=True):
         xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
     # set label of dmatrix
     def set_label(self, label):
-        xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label))
+        self.__set_float_info('label', label)
+    # set weight of each instance
+    def set_weight(self, weight):
+        self.__set_float_info('weight', weight)
+    # set initial margin prediction
+    def set_base_margin(self, margin):
+        """
+        set the base margin that the booster starts from;
+        this can be used to supply the prediction of an
+        existing model as the starting point for boosting.
+        Note that the margin is needed, not the transformed prediction:
+        e.g. for logistic regression, pass the value before the logistic transformation
+        see also example/demo.py
+        """
+        self.__set_float_info('base_margin', margin)
     # set group size of dmatrix, used for ranking
     def set_group(self, group):
         xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
-    # set weight of each instances
-    def set_weight(self, weight):
-        xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_float*len(weight))(*weight), len(weight))
     # get label from dmatrix
     def get_label(self):
-        length = ctypes.c_ulong()
-        labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
-        return ctypes2numpy(labels, length.value)
+        return self.__get_float_info('label')
     # get weight from dmatrix
     def get_weight(self):
-        length = ctypes.c_ulong()
-        weights = xglib.XGDMatrixGetWeight(self.handle, ctypes.byref(length))
-        return ctypes2numpy(weights, length.value)
+        return self.__get_float_info('weight')
+    # get base_margin from dmatrix
+    def get_base_margin(self):
+        return self.__get_float_info('base_margin')
     def num_row(self):
         return xglib.XGDMatrixNumRow(self.handle)
     # slice the DMatrix to return a new DMatrix that only contains rindex
@@ -161,9 +178,15 @@ class Booster:
         return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
     def eval(self, mat, name = 'eval', it = 0):
         return self.eval_set( [(mat,name)], it)
-    def predict(self, data):
+    def predict(self, data, output_margin=False):
+        """
+        predict with data
+        data: the dmatrix storing the input
+        output_margin: whether to output the raw, untransformed margin value
+        """
         length = ctypes.c_ulong()
-        preds = xglib.XGBoosterPredict(self.handle, data.handle, ctypes.byref(length))
+        preds = xglib.XGBoosterPredict(self.handle, data.handle,
+                                       int(output_margin), ctypes.byref(length))
         return ctypes2numpy(preds, length.value)
     def save_model(self, fname):
         """ save model to file """
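Note: with labels, weights, and base margins unified behind `__set_float_info`/`__get_float_info`, every field round-trips the same way. A quick sketch, assuming the `DMatrix` constructor accepts a numpy array plus labels as elsewhere in this branch (shapes and values are illustrative):

```python
import numpy as np
import xgboost as xgb  # the xgboost.py from this patch

dtrain = xgb.DMatrix(np.random.rand(100, 10),
                     label=np.random.randint(0, 2, 100))

w = np.arange(100, dtype='float32')
dtrain.set_weight(w)                        # routes through __set_float_info('weight', ...)
assert np.allclose(dtrain.get_weight(), w)  # reads back via __get_float_info('weight')

dtrain.set_base_margin(np.zeros(100))       # same path, field 'base_margin'
assert dtrain.get_base_margin().shape == (100,)
```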
diff --git a/python/xgboost_wrapper.cpp b/python/xgboost_wrapper.cpp
index edda96c29..7f2365ba3 100644
--- a/python/xgboost_wrapper.cpp
+++ b/python/xgboost_wrapper.cpp
@@ -23,9 +23,9 @@ class Booster: public learner::BoostLearner {
       this->init_model = false;
       this->SetCacheData(mats);
     }
-    const float *Pred(const DataMatrix &dmat, size_t *len) {
+    const float *Pred(const DataMatrix &dmat, int output_margin, size_t *len) {
      this->CheckInitModel();
-      this->Predict(dmat, &this->preds_);
+      this->Predict(dmat, output_margin != 0, &this->preds_);
      *len = this->preds_.size();
      return &this->preds_[0];
    }
@@ -163,15 +163,11 @@ extern "C"{
   void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) {
     SaveDataMatrix(*static_cast<DataMatrix*>(handle), fname, silent);
   }
-  void XGDMatrixSetLabel(void *handle, const float *label, size_t len) {
-    DataMatrix *pmat = static_cast<DataMatrix*>(handle);
-    pmat->info.labels.resize(len);
-    memcpy(&(pmat->info).labels[0], label, sizeof(float) * len);
-  }
-  void XGDMatrixSetWeight(void *handle, const float *weight, size_t len) {
-    DataMatrix *pmat = static_cast<DataMatrix*>(handle);
-    pmat->info.weights.resize(len);
-    memcpy(&(pmat->info).weights[0], weight, sizeof(float) * len);
+  void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) {
+    std::vector<float> &vec =
+        static_cast<DataMatrix*>(handle)->info.GetInfo(field);
+    vec.resize(len);
+    memcpy(&vec[0], info, sizeof(float) * len);
   }
   void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
     DataMatrix *pmat = static_cast<DataMatrix*>(handle);
@@ -181,15 +177,11 @@ extern "C"{
       pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i] + group[i];
     }
   }
-  const float* XGDMatrixGetLabel(const void *handle, size_t* len) {
-    const DataMatrix *pmat = static_cast<const DataMatrix*>(handle);
-    *len = pmat->info.labels.size();
-    return &(pmat->info.labels[0]);
-  }
-  const float* XGDMatrixGetWeight(const void *handle, size_t* len) {
-    const DataMatrix *pmat = static_cast<const DataMatrix*>(handle);
-    *len = pmat->info.weights.size();
-    return &(pmat->info.weights[0]);
+  const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) {
+    const std::vector<float> &vec =
+        static_cast<const DataMatrix*>(handle)->info.GetInfo(field);
+    *len = vec.size();
+    return &vec[0];
   }
   size_t XGDMatrixNumRow(const void *handle) {
     return static_cast<const DataMatrix*>(handle)->info.num_row;
@@ -238,8 +230,8 @@ extern "C"{
     bst->eval_str = bst->EvalOneIter(iter, mats, names);
     return bst->eval_str.c_str();
   }
-  const float *XGBoosterPredict(void *handle, void *dmat, size_t *len) {
-    return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), len);
+  const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, size_t *len) {
+    return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, len);
   }
   void XGBoosterLoadModel(void *handle, const char *fname) {
     static_cast<Booster*>(handle)->LoadModel(fname);
diff --git a/python/xgboost_wrapper.h b/python/xgboost_wrapper.h
index 16b8fecd7..1b6805c61 100644
--- a/python/xgboost_wrapper.h
+++ b/python/xgboost_wrapper.h
@@ -64,19 +64,13 @@ extern "C" {
    */
   void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
   /*!
-   * \brief set label of the training matrix
+   * \brief set float vector content of a field in the matrix info
    * \param handle an instance of data matrix
-   * \param label pointer to label
+   * \param field field name, can be label, weight or base_margin
+   * \param array pointer to float vector
    * \param len length of array
    */
-  void XGDMatrixSetLabel(void *handle, const float *label, size_t len);
-  /*!
-   * \brief set weight of each instance
-   * \param handle a instance of data matrix
-   * \param weight data pointer to weights
-   * \param len length of array
-   */
-  void XGDMatrixSetWeight(void *handle, const float *weight, size_t len);
+  void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len);
   /*!
-   * \brief set label of the training matrix
+   * \brief set the group sizes of the training matrix, used for ranking
    * \param handle an instance of data matrix
@@ -85,19 +79,13 @@ extern "C" {
    */
   void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len);
   /*!
-   * \brief get label set from matrix
+   * \brief get float info vector from matrix
    * \param handle an instance of data matrix
+   * \param field field name
    * \param out_len used to set result length
-   * \return pointer to the label
+   * \return pointer to the requested float vector
    */
-  const float* XGDMatrixGetLabel(const void *handle, size_t* out_len);
-  /*!
-   * \brief get weight set from matrix
-   * \param handle a instance of data matrix
-   * \param len used to set result length
-   * \return pointer to the weight
-   */
-  const float* XGDMatrixGetWeight(const void *handle, size_t* out_len);
+  const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
   /*!
    * \brief return number of rows
    */
@@ -154,9 +142,10 @@ extern "C" {
    * \brief make prediction based on dmat
    * \param handle handle
    * \param dmat data matrix
+   * \param output_margin whether to output only the raw untransformed margin value
    * \param len used to store the length of the returned result
    */
-  const float *XGBoosterPredict(void *handle, void *dmat, size_t *len);
+  const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, size_t *len);
  /*!
   * \brief load model from existing file
   * \param handle handle
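Note: the header above is exactly what xgboost.py binds with ctypes, and the new API can also be driven directly. A standalone sketch (the .so path and data file are illustrative; the shared library is the SLIB target of the Makefile, e.g. `make python/libxgboostwrapper.so`):

```python
import ctypes
import numpy as np

lib = ctypes.cdll.LoadLibrary('./python/libxgboostwrapper.so')
lib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
lib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)

handle = ctypes.c_void_p(lib.XGDMatrixCreateFromFile(b'train.txt', 1))

# XGDMatrixSetFloatInfo(handle, field, array, len)
vals = (ctypes.c_float * 3)(0.5, -1.25, 2.0)
lib.XGDMatrixSetFloatInfo(handle, b'base_margin', vals, 3)

# XGDMatrixGetFloatInfo(handle, field, out_len) returns a float pointer
length = ctypes.c_ulong()
ptr = lib.XGDMatrixGetFloatInfo(handle, b'base_margin', ctypes.byref(length))
print(np.array(ptr[:length.value]))  # [ 0.5  -1.25  2.  ]

lib.XGDMatrixFree(handle)
```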
diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp
index b0bd0f99a..216240b74 100644
--- a/src/gbm/gbtree-inl.hpp
+++ b/src/gbm/gbtree-inl.hpp
@@ -233,7 +233,7 @@ class GBTree : public IGradBooster {
         pred_counter[bid] = static_cast<unsigned>(trees.size());
         pred_buffer[bid] = psum;
       }
-      return psum + mparam.base_score;
+      return psum;
     }
     // initialize thread local space for prediction
     inline void InitThreadTemp(int nthread) {
@@ -296,8 +296,6 @@ class GBTree : public IGradBooster {
   };
   /*! \brief model parameters */
   struct ModelParam {
-    /*! \brief base prediction score of everything */
-    float base_score;
     /*! \brief number of trees */
     int num_trees;
     /*! \brief number of roots: default 0, means single tree */
@@ -316,7 +314,6 @@ class GBTree : public IGradBooster {
     int reserved[32];
     /*! \brief constructor */
     ModelParam(void) {
-      base_score = 0.0f;
       num_trees = 0;
       num_roots = num_feature = 0;
       num_pbuffer = 0;
@@ -329,7 +326,6 @@ class GBTree : public IGradBooster {
     * \param val value of the parameter
     */
    inline void SetParam(const char *name, const char *val) {
-      if (!strcmp("base_score", name)) base_score = static_cast<float>(atof(val));
      if (!strcmp("num_pbuffer", name)) num_pbuffer = atol(val);
      if (!strcmp("num_output_group", name)) num_output_group = atol(val);
      if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);
diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp
index bc0e3c2bd..c0b98b789 100644
--- a/src/io/simple_dmatrix-inl.hpp
+++ b/src/io/simple_dmatrix-inl.hpp
@@ -110,10 +110,13 @@ class DMatrixSimple : public DataMatrix {
                    "DMatrix: group data does not match the number of rows in features");
     }
     std::string wname = name + ".weight";
-    if (info.TryLoadWeight(wname.c_str(), silent)) {
+    if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
       utils::Check(info.weights.size() == info.num_row,
                    "DMatrix: weight data does not match the number of rows in features");
     }
+    std::string mname = name + ".base_margin";
+    // base margin is optional; no size check is enforced on it here
+    info.TryLoadFloatInfo("base_margin", mname.c_str(), silent);
   }
diff --git a/src/learner/dmatrix.h b/src/learner/dmatrix.h
index 144b1a44e..f7dbcb639 100644
--- a/src/learner/dmatrix.h
+++ b/src/learner/dmatrix.h
@@ -33,6 +33,15 @@ struct MetaInfo {
    * can be used for multi-task setting
    */
   std::vector<unsigned> root_index;
+  /*!
+   * \brief initialized margins;
+   * if specified, xgboost will start from this initial margin,
+   * which can be used to supply the initial prediction to boost from
+   */
+  std::vector<float> base_margin;
+  /*! \brief version flag, used to check the version of this info */
+  static const int kVersion = 0;
+  // constructor
   MetaInfo(void) : num_row(0), num_col(0) {}
   /*! \brief clear all the information */
   inline void Clear(void) {
@@ -40,6 +49,7 @@ struct MetaInfo {
     group_ptr.clear();
     weights.clear();
     root_index.clear();
+    base_margin.clear();
     num_row = num_col = 0;
   }
   /*! \brief get weight of each instance */
@@ -59,20 +69,26 @@ struct MetaInfo {
     }
   }
   inline void SaveBinary(utils::IStream &fo) const {
+    int version = kVersion;
+    fo.Write(&version, sizeof(version));
     fo.Write(&num_row, sizeof(num_row));
     fo.Write(&num_col, sizeof(num_col));
     fo.Write(labels);
     fo.Write(group_ptr);
     fo.Write(weights);
     fo.Write(root_index);
+    fo.Write(base_margin);
   }
   inline void LoadBinary(utils::IStream &fi) {
+    int version;
+    utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format");
     utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format");
     utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format");
     utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
     utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
     utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
     utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
+    utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
   }
   // try to load group information from file, if it exists
   inline bool TryLoadGroup(const char* fname, bool silent = false) {
@@ -89,8 +105,19 @@ struct MetaInfo {
     fclose(fi);
     return true;
   }
+  inline std::vector<float>& GetInfo(const char *field) {
+    if (!strcmp(field, "label")) return labels;
+    if (!strcmp(field, "weight")) return weights;
+    if (!strcmp(field, "base_margin")) return base_margin;
+    utils::Error("unknown field %s", field);
+    return labels;
+  }
+  inline const std::vector<float>& GetInfo(const char *field) const {
+    return const_cast<MetaInfo*>(this)->GetInfo(field);
+  }
   // try to load float info from file, if it exists
-  inline bool TryLoadWeight(const char* fname, bool silent = false) {
+  inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) {
+    std::vector<float> &weights = this->GetInfo(field);  // named 'weights' so the read loop below works for any field
     FILE *fi = fopen64(fname, "r");
     if (fi == NULL) return false;
     float wt;
@@ -98,7 +125,7 @@ struct MetaInfo {
       weights.push_back(wt);
     }
     if (!silent) {
-      printf("loading weight from %s\n", fname);
+      printf("loading %s from %s\n", field, fname);
     }
     fclose(fi);
     return true;
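Note: `TryLoadFloatInfo` gives base margins the same optional side-file convention that weights already use: for a data file `train.txt`, the loader looks for `train.txt.base_margin` and reads whitespace-separated floats with `fscanf("%f", ...)`. A sketch that writes such a file (file name and values are illustrative):

```python
import numpy as np

margins = np.full(100, -0.5, dtype='float32')  # one value per training row
with open('train.txt.base_margin', 'w') as f:
    f.write('\n'.join('%g' % m for m in margins) + '\n')
# on the next load of train.txt, MetaInfo.base_margin is filled automatically
```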
diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp
index 09167d8bf..4d227f488 100644
--- a/src/learner/learner-inl.hpp
+++ b/src/learner/learner-inl.hpp
@@ -97,9 +97,6 @@ class BoostLearner {
     this->InitObjGBM();
     // reset the base score
     mparam.base_score = obj_->ProbToMargin(mparam.base_score);
-    char tmp[32];
-    snprintf(tmp, sizeof(tmp), "%g", mparam.base_score);
-    this->SetParam("base_score", tmp);
     // initialize GBM model
     gbm_->InitModel();
   }
@@ -199,12 +196,16 @@ class BoostLearner {
   /*!
    * \brief get prediction
    * \param data input data
+   * \param output_margin whether to output only the margin value instead of the transformed prediction
    * \param out_preds output vector that stores the prediction
    */
   inline void Predict(const DMatrix &data,
+                      bool output_margin,
                       std::vector<float> *out_preds) const {
     this->PredictRaw(data, out_preds);
-    obj_->PredTransform(out_preds);
+    if (!output_margin) {
+      obj_->PredTransform(out_preds);
+    }
   }
   /*! \brief dump model out */
   inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
@@ -236,6 +237,22 @@ class BoostLearner {
                          std::vector<float> *out_preds) const {
     gbm_->Predict(data.fmat, this->FindBufferOffset(data),
                   data.info.root_index, out_preds);
+    // add the base margin: per-instance if supplied, otherwise the global base score
+    std::vector<float> &preds = *out_preds;
+    const unsigned ndata = static_cast<unsigned>(preds.size());
+    if (data.info.base_margin.size() != 0) {
+      utils::Check(preds.size() == data.info.base_margin.size(),
+                   "base_margin.size does not match with prediction size");
+      #pragma omp parallel for schedule(static)
+      for (unsigned j = 0; j < ndata; ++j) {
+        preds[j] += data.info.base_margin[j];
+      }
+    } else {
+      #pragma omp parallel for schedule(static)
+      for (unsigned j = 0; j < ndata; ++j) {
+        preds[j] += mparam.base_score;
+      }
+    }
   }
 
   /*! \brief training parameter for regression */
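Note: the effect of the new branch in PredictRaw, restated in numpy as a sketch (the function name and defaults are mine, not the C++ API; 0.5 is the usual base_score default):

```python
import numpy as np

def add_base(tree_sum, base_margin=None, base_score=0.5):
    # per-instance base_margin, when present, replaces the global
    # base_score offset added to the summed tree predictions
    tree_sum = np.asarray(tree_sum, dtype='float32')
    if base_margin is not None:
        base_margin = np.asarray(base_margin, dtype='float32')
        # mirrors the utils::Check in PredictRaw
        assert base_margin.shape == tree_sum.shape, \
            "base_margin.size does not match with prediction size"
        return tree_sum + base_margin  # per-instance offset
    return tree_sum + base_score       # global offset

print(add_base([0.1, -0.2]))                          # falls back to base_score
print(add_base([0.1, -0.2], base_margin=[1.0, 2.0]))  # uses the stored margins
```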
diff --git a/src/xgboost_main.cpp b/src/xgboost_main.cpp
index f3fc9201d..c807df15a 100644
--- a/src/xgboost_main.cpp
+++ b/src/xgboost_main.cpp
@@ -49,6 +49,7 @@ class BoostLearnTask{
     if (!strcmp("silent", name)) silent = atoi(val);
     if (!strcmp("use_buffer", name)) use_buffer = atoi(val);
     if (!strcmp("num_round", name)) num_round = atoi(val);
+    if (!strcmp("pred_margin", name)) pred_margin = atoi(val);
     if (!strcmp("save_period", name)) save_period = atoi(val);
     if (!strcmp("eval_train", name)) eval_train = atoi(val);
     if (!strcmp("task", name)) task = val;
@@ -77,6 +78,7 @@ class BoostLearnTask{
     num_round = 10;
     save_period = 0;
     eval_train = 0;
+    pred_margin = 0;
     dump_model_stats = 0;
     task = "train";
     model_in = "NULL";
@@ -184,7 +186,7 @@ class BoostLearnTask{
   inline void TaskPred(void) {
     std::vector<float> preds;
     if (!silent) printf("start prediction...\n");
-    learner.Predict(*data, &preds);
+    learner.Predict(*data, pred_margin != 0, &preds);
     if (!silent) printf("writing prediction to %s\n", name_pred.c_str());
     FILE *fo = utils::FopenCheck(name_pred.c_str(), "w");
     for (size_t i = 0; i < preds.size(); i++) {
@@ -193,37 +195,39 @@ class BoostLearnTask{
     fclose(fo);
   }
  private:
-  /* \brief whether silent */
+  /*! \brief whether to run silently */
   int silent;
-  /* \brief whether use auto binary buffer */
+  /*! \brief whether to use the auto binary buffer */
   int use_buffer;
-  /* \brief whether evaluate training statistics */
+  /*! \brief whether to evaluate training statistics */
   int eval_train;
-  /* \brief number of boosting iterations */
+  /*! \brief number of boosting iterations */
   int num_round;
-  /* \brief the period to save the model, 0 means only save the final round model */
+  /*! \brief the period to save the model; 0 means only save the final round model */
   int save_period;
-  /* \brief the path of training/test data set */
+  /*! \brief the path of the training/test data set */
   std::string train_path, test_path;
-  /* \brief the path of test model file, or file to restart training */
+  /*! \brief the path of the test model file, or the file to restart training from */
   std::string model_in;
-  /* \brief the path of final model file, to be saved */
+  /*! \brief the path of the final model file to be saved */
   std::string model_out;
-  /* \brief the path of directory containing the saved models */
+  /*! \brief the path of the directory containing the saved models */
   std::string model_dir_path;
-  /* \brief task to perform */
+  /*! \brief task to perform */
   std::string task;
-  /* \brief name of predict file */
+  /*! \brief name of the prediction output file */
   std::string name_pred;
-  /* \brief whether dump statistics along with model */
+  /*! \brief whether to directly output the margin value */
+  int pred_margin;
+  /*! \brief whether to dump statistics along with the model */
   int dump_model_stats;
-  /* \brief name of feature map */
+  /*! \brief name of the feature map */
   std::string name_fmap;
-  /* \brief name of dump file */
+  /*! \brief name of the dump file */
   std::string name_dump;
-  /* \brief the paths of validation data sets */
+  /*! \brief the paths of the validation data sets */
   std::vector<std::string> eval_data_paths;
-  /* \brief the names of the evaluation data used in output log */
+  /*! \brief the names of the evaluation data sets used in the output log */
   std::vector<std::string> eval_data_names;
  private:
   io::DataMatrix* data;
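Note: with `pred_margin` wired through `SetParam`, the CLI can emit raw margins too; an illustrative invocation (the config and model names are placeholders):

```bash
# any name=value pair on the command line goes through SetParam above
./xgboost mushroom.conf task=pred model_in=0010.model pred_margin=1
# the file named by name_pred then holds raw margins
# instead of transformed predictions
```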