add base_margin
This commit is contained in:
parent
46fed899ab
commit
9da2ced8a2
7
Makefile
7
Makefile
@ -10,14 +10,13 @@ endif
|
||||
|
||||
# specify tensor path
|
||||
BIN = xgboost
|
||||
OBJ = io.o
|
||||
OBJ =
|
||||
SLIB = python/libxgboostwrapper.so
|
||||
.PHONY: clean all
|
||||
|
||||
all: $(BIN) $(OBJ) $(SLIB)
|
||||
|
||||
xgboost: src/xgboost_main.cpp io.o src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
|
||||
io.o: src/io/io.cpp src/data.h src/utils/*.h
|
||||
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
|
||||
# now the wrapper takes in two files. io and wrapper part
|
||||
python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||
|
||||
@ -34,4 +33,4 @@ install:
|
||||
cp -f -r $(BIN) $(INSTALL_PATH)
|
||||
|
||||
clean:
|
||||
$(RM) $(OBJ) $(BIN) *~ */*~ */*/*~
|
||||
$(RM) $(OBJ) $(BIN) $(SLIB) *~ */*~ */*/*~
|
||||
|
||||
@ -21,8 +21,8 @@ Features
|
||||
|
||||
xgboost-unity
|
||||
=======
|
||||
* experimental branch(not usable yet): refactor xgboost, cleaner code, more flexibility
|
||||
* This version of xgboost is not backward compatible with 0.2*, due to huge change in code structure
|
||||
* Experimental branch(not usable yet): refactor xgboost, cleaner code, more flexibility
|
||||
* This version of xgboost is not compatible with 0.2x, due to huge amount of changes in code structure
|
||||
- This means the model and buffer file of previous version can not be loaded in xgboost-unity
|
||||
|
||||
Build
|
||||
|
||||
@ -7,7 +7,7 @@ This script will achieve about 3.600 AMS score in public leadboard. To get start
|
||||
|
||||
1. Compile the XGBoost python lib
|
||||
```bash
|
||||
cd ../../python
|
||||
cd ../..
|
||||
make
|
||||
```
|
||||
2. Put training.csv test.csv on folder './data' (you can create a symbolic link)
|
||||
|
||||
@ -90,3 +90,22 @@ def evalerror(preds, dtrain):
|
||||
# training with customized objective, we can also do step by step training
|
||||
# simply look at xgboost.py's implementation of train
|
||||
bst = xgb.train(param, dtrain, num_round, evallist, logregobj, evalerror)
|
||||
|
||||
|
||||
###
|
||||
# advanced: start from a initial base prediction
|
||||
#
|
||||
print ('start running example to start from a initial prediction')
|
||||
# specify parameters via map, definition are same as c++ version
|
||||
param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||
# train xgboost for 1 round
|
||||
bst = xgb.train( param, dtrain, 1, evallist )
|
||||
# Note: we need the margin value instead of transformed prediction in set_base_margin
|
||||
# do predict with output_margin=True, will always give you margin values before logistic transformation
|
||||
ptrain = bst.predict(dtrain, output_margin=True)
|
||||
ptest = bst.predict(dtest, output_margin=True)
|
||||
dtrain.set_base_margin(ptrain)
|
||||
dtest.set_base_margin(ptest)
|
||||
|
||||
print ('this is result of running from initial prediction')
|
||||
bst = xgb.train( param, dtrain, 1, evallist )
|
||||
|
||||
@ -18,8 +18,7 @@ xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
|
||||
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
|
||||
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
||||
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
||||
xglib.XGDMatrixGetLabel.restype = ctypes.POINTER(ctypes.c_float)
|
||||
xglib.XGDMatrixGetWeight.restype = ctypes.POINTER(ctypes.c_float)
|
||||
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
||||
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
|
||||
|
||||
xglib.XGBoosterCreate.restype = ctypes.c_void_p
|
||||
@ -77,28 +76,46 @@ class DMatrix:
|
||||
# destructor
|
||||
def __del__(self):
|
||||
xglib.XGDMatrixFree(self.handle)
|
||||
def __get_float_info(self, field):
|
||||
length = ctypes.c_ulong()
|
||||
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
||||
ctypes.byref(length))
|
||||
return ctypes2numpy(ret, length.value)
|
||||
def __set_float_info(self, field, data):
|
||||
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
|
||||
(ctypes.c_float*len(data))(*data), len(data))
|
||||
# load data from file
|
||||
def save_binary(self, fname, silent=True):
|
||||
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
|
||||
# set label of dmatrix
|
||||
def set_label(self, label):
|
||||
xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label))
|
||||
self.__set_float_info('label', label)
|
||||
# set weight of each instances
|
||||
def set_weight(self, weight):
|
||||
self.__set_float_info('weight', label)
|
||||
# set initialized margin prediction
|
||||
def set_base_margin(self, margin):
|
||||
"""
|
||||
set base margin of booster to start from
|
||||
this can be used to specify a prediction value of
|
||||
existing model to be base_margin
|
||||
However, remember margin is needed, instead of transformed prediction
|
||||
e.g. for logistic regression: need to put in value before logistic transformation
|
||||
see also example/demo.py
|
||||
"""
|
||||
self.__set_float_info('base_margin', margin)
|
||||
# set group size of dmatrix, used for rank
|
||||
def set_group(self, group):
|
||||
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
|
||||
# set weight of each instances
|
||||
def set_weight(self, weight):
|
||||
xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_float*len(weight))(*weight), len(weight))
|
||||
# get label from dmatrix
|
||||
def get_label(self):
|
||||
length = ctypes.c_ulong()
|
||||
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
|
||||
return ctypes2numpy(labels, length.value)
|
||||
return self.__get_float_info('label')
|
||||
# get weight from dmatrix
|
||||
def get_weight(self):
|
||||
length = ctypes.c_ulong()
|
||||
weights = xglib.XGDMatrixGetWeight(self.handle, ctypes.byref(length))
|
||||
return ctypes2numpy(weights, length.value)
|
||||
return self.__get_float_info('weight')
|
||||
# get base_margin from dmatrix
|
||||
def get_base_margin(self):
|
||||
return self.__get_float_info('base_margin')
|
||||
def num_row(self):
|
||||
return xglib.XGDMatrixNumRow(self.handle)
|
||||
# slice the DMatrix to return a new DMatrix that only contains rindex
|
||||
@ -161,9 +178,15 @@ class Booster:
|
||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||
def eval(self, mat, name = 'eval', it = 0):
|
||||
return self.eval_set( [(mat,name)], it)
|
||||
def predict(self, data):
|
||||
def predict(self, data, output_margin=False):
|
||||
"""
|
||||
predict with data
|
||||
data: the dmatrix storing the input
|
||||
output_margin: whether output raw margin value that is untransformed
|
||||
"""
|
||||
length = ctypes.c_ulong()
|
||||
preds = xglib.XGBoosterPredict(self.handle, data.handle, ctypes.byref(length))
|
||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||
int(output_margin), ctypes.byref(length))
|
||||
return ctypes2numpy(preds, length.value)
|
||||
def save_model(self, fname):
|
||||
""" save model to file """
|
||||
|
||||
@ -23,9 +23,9 @@ class Booster: public learner::BoostLearner<FMatrixS> {
|
||||
this->init_model = false;
|
||||
this->SetCacheData(mats);
|
||||
}
|
||||
const float *Pred(const DataMatrix &dmat, size_t *len) {
|
||||
const float *Pred(const DataMatrix &dmat, int output_margin, size_t *len) {
|
||||
this->CheckInitModel();
|
||||
this->Predict(dmat, &this->preds_);
|
||||
this->Predict(dmat, output_margin, &this->preds_);
|
||||
*len = this->preds_.size();
|
||||
return &this->preds_[0];
|
||||
}
|
||||
@ -163,15 +163,11 @@ extern "C"{
|
||||
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) {
|
||||
SaveDataMatrix(*static_cast<DataMatrix*>(handle), fname, silent);
|
||||
}
|
||||
void XGDMatrixSetLabel(void *handle, const float *label, size_t len) {
|
||||
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
|
||||
pmat->info.labels.resize(len);
|
||||
memcpy(&(pmat->info).labels[0], label, sizeof(float) * len);
|
||||
}
|
||||
void XGDMatrixSetWeight(void *handle, const float *weight, size_t len) {
|
||||
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
|
||||
pmat->info.weights.resize(len);
|
||||
memcpy(&(pmat->info).weights[0], weight, sizeof(float) * len);
|
||||
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) {
|
||||
std::vector<float> &vec =
|
||||
static_cast<DataMatrix*>(handle)->info.GetInfo(field);
|
||||
vec.resize(len);
|
||||
memcpy(&vec[0], info, sizeof(float) * len);
|
||||
}
|
||||
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
|
||||
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
|
||||
@ -181,15 +177,11 @@ extern "C"{
|
||||
pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i]+group[i];
|
||||
}
|
||||
}
|
||||
const float* XGDMatrixGetLabel(const void *handle, size_t* len) {
|
||||
const DataMatrix *pmat = static_cast<const DataMatrix*>(handle);
|
||||
*len = pmat->info.labels.size();
|
||||
return &(pmat->info.labels[0]);
|
||||
}
|
||||
const float* XGDMatrixGetWeight(const void *handle, size_t* len) {
|
||||
const DataMatrix *pmat = static_cast<const DataMatrix*>(handle);
|
||||
*len = pmat->info.weights.size();
|
||||
return &(pmat->info.weights[0]);
|
||||
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) {
|
||||
const std::vector<float> &vec =
|
||||
static_cast<const DataMatrix*>(handle)->info.GetInfo(field);
|
||||
*len = vec.size();
|
||||
return &vec[0];
|
||||
}
|
||||
size_t XGDMatrixNumRow(const void *handle) {
|
||||
return static_cast<const DataMatrix*>(handle)->info.num_row;
|
||||
@ -238,8 +230,8 @@ extern "C"{
|
||||
bst->eval_str = bst->EvalOneIter(iter, mats, names);
|
||||
return bst->eval_str.c_str();
|
||||
}
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, size_t *len) {
|
||||
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), len);
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, size_t *len) {
|
||||
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, len);
|
||||
}
|
||||
void XGBoosterLoadModel(void *handle, const char *fname) {
|
||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||
|
||||
@ -64,19 +64,13 @@ extern "C" {
|
||||
*/
|
||||
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
|
||||
/*!
|
||||
* \brief set label of the training matrix
|
||||
* \brief set float vector to a content in info
|
||||
* \param handle a instance of data matrix
|
||||
* \param label pointer to label
|
||||
* \param field field name, can be label, weight
|
||||
* \param array pointer to float vector
|
||||
* \param len length of array
|
||||
*/
|
||||
void XGDMatrixSetLabel(void *handle, const float *label, size_t len);
|
||||
/*!
|
||||
* \brief set weight of each instance
|
||||
* \param handle a instance of data matrix
|
||||
* \param weight data pointer to weights
|
||||
* \param len length of array
|
||||
*/
|
||||
void XGDMatrixSetWeight(void *handle, const float *weight, size_t len);
|
||||
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len);
|
||||
/*!
|
||||
* \brief set label of the training matrix
|
||||
* \param handle a instance of data matrix
|
||||
@ -85,19 +79,13 @@ extern "C" {
|
||||
*/
|
||||
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len);
|
||||
/*!
|
||||
* \brief get label set from matrix
|
||||
* \brief get float info vector from matrix
|
||||
* \param handle a instance of data matrix
|
||||
* \param len used to set result length
|
||||
* \param field field name
|
||||
* \return pointer to the label
|
||||
*/
|
||||
const float* XGDMatrixGetLabel(const void *handle, size_t* out_len);
|
||||
/*!
|
||||
* \brief get weight set from matrix
|
||||
* \param handle a instance of data matrix
|
||||
* \param len used to set result length
|
||||
* \return pointer to the weight
|
||||
*/
|
||||
const float* XGDMatrixGetWeight(const void *handle, size_t* out_len);
|
||||
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
|
||||
/*!
|
||||
* \brief return number of rows
|
||||
*/
|
||||
@ -154,9 +142,10 @@ extern "C" {
|
||||
* \brief make prediction based on dmat
|
||||
* \param handle handle
|
||||
* \param dmat data matrix
|
||||
* \param output_margin whether only output raw margin value
|
||||
* \param len used to store length of returning result
|
||||
*/
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, size_t *len);
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, size_t *len);
|
||||
/*!
|
||||
* \brief load model from existing file
|
||||
* \param handle handle
|
||||
|
||||
@ -233,7 +233,7 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
pred_counter[bid] = static_cast<unsigned>(trees.size());
|
||||
pred_buffer[bid] = psum;
|
||||
}
|
||||
return psum + mparam.base_score;
|
||||
return psum;
|
||||
}
|
||||
// initialize thread local space for prediction
|
||||
inline void InitThreadTemp(int nthread) {
|
||||
@ -296,8 +296,6 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
};
|
||||
/*! \brief model parameters */
|
||||
struct ModelParam {
|
||||
/*! \brief base prediction score of everything */
|
||||
float base_score;
|
||||
/*! \brief number of trees */
|
||||
int num_trees;
|
||||
/*! \brief number of root: default 0, means single tree */
|
||||
@ -316,7 +314,6 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
int reserved[32];
|
||||
/*! \brief constructor */
|
||||
ModelParam(void) {
|
||||
base_score = 0.0f;
|
||||
num_trees = 0;
|
||||
num_roots = num_feature = 0;
|
||||
num_pbuffer = 0;
|
||||
@ -329,7 +326,6 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp("base_score", name)) base_score = static_cast<float>(atof(val));
|
||||
if (!strcmp("num_pbuffer", name)) num_pbuffer = atol(val);
|
||||
if (!strcmp("num_output_group", name)) num_output_group = atol(val);
|
||||
if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);
|
||||
|
||||
@ -110,10 +110,13 @@ class DMatrixSimple : public DataMatrix {
|
||||
"DMatrix: group data does not match the number of rows in features");
|
||||
}
|
||||
std::string wname = name + ".weight";
|
||||
if (info.TryLoadWeight(wname.c_str(), silent)) {
|
||||
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
||||
utils::Check(info.weights.size() == info.num_row,
|
||||
"DMatrix: weight data does not match the number of rows in features");
|
||||
}
|
||||
std::string mname = name + ".base_margin";
|
||||
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief load from binary file
|
||||
|
||||
@ -33,6 +33,15 @@ struct MetaInfo {
|
||||
* can be used for multi task setting
|
||||
*/
|
||||
std::vector<unsigned> root_index;
|
||||
/*!
|
||||
* \brief initialized margins,
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
*/
|
||||
std::vector<float> base_margin;
|
||||
/*! \brief version flag, used to check version of this info */
|
||||
static const int kVersion = 0;
|
||||
// constructor
|
||||
MetaInfo(void) : num_row(0), num_col(0) {}
|
||||
/*! \brief clear all the information */
|
||||
inline void Clear(void) {
|
||||
@ -40,6 +49,7 @@ struct MetaInfo {
|
||||
group_ptr.clear();
|
||||
weights.clear();
|
||||
root_index.clear();
|
||||
base_margin.clear();
|
||||
num_row = num_col = 0;
|
||||
}
|
||||
/*! \brief get weight of each instances */
|
||||
@ -59,20 +69,26 @@ struct MetaInfo {
|
||||
}
|
||||
}
|
||||
inline void SaveBinary(utils::IStream &fo) const {
|
||||
int version = kVersion;
|
||||
fo.Write(&version, sizeof(version));
|
||||
fo.Write(&num_row, sizeof(num_row));
|
||||
fo.Write(&num_col, sizeof(num_col));
|
||||
fo.Write(labels);
|
||||
fo.Write(group_ptr);
|
||||
fo.Write(weights);
|
||||
fo.Write(root_index);
|
||||
fo.Write(base_margin);
|
||||
}
|
||||
inline void LoadBinary(utils::IStream &fi) {
|
||||
int version;
|
||||
utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
|
||||
}
|
||||
// try to load group information from file, if exists
|
||||
inline bool TryLoadGroup(const char* fname, bool silent = false) {
|
||||
@ -89,8 +105,19 @@ struct MetaInfo {
|
||||
fclose(fi);
|
||||
return true;
|
||||
}
|
||||
inline std::vector<float>& GetInfo(const char *field) {
|
||||
if (!strcmp(field, "label")) return labels;
|
||||
if (!strcmp(field, "weight")) return weights;
|
||||
if (!strcmp(field, "base_margin")) return base_margin;
|
||||
utils::Error("unknown field %s", field);
|
||||
return labels;
|
||||
}
|
||||
inline const std::vector<float>& GetInfo(const char *field) const {
|
||||
return ((MetaInfo*)this)->GetInfo(field);
|
||||
}
|
||||
// try to load weight information from file, if exists
|
||||
inline bool TryLoadWeight(const char* fname, bool silent = false) {
|
||||
inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) {
|
||||
std::vector<float> &weights = this->GetInfo(field);
|
||||
FILE *fi = fopen64(fname, "r");
|
||||
if (fi == NULL) return false;
|
||||
float wt;
|
||||
@ -98,7 +125,7 @@ struct MetaInfo {
|
||||
weights.push_back(wt);
|
||||
}
|
||||
if (!silent) {
|
||||
printf("loading weight from %s\n", fname);
|
||||
printf("loading %s from %s\n", field, fname);
|
||||
}
|
||||
fclose(fi);
|
||||
return true;
|
||||
|
||||
@ -97,9 +97,6 @@ class BoostLearner {
|
||||
this->InitObjGBM();
|
||||
// reset the base score
|
||||
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
|
||||
char tmp[32];
|
||||
snprintf(tmp, sizeof(tmp), "%g", mparam.base_score);
|
||||
this->SetParam("base_score", tmp);
|
||||
// initialize GBM model
|
||||
gbm_->InitModel();
|
||||
}
|
||||
@ -199,12 +196,16 @@ class BoostLearner {
|
||||
/*!
|
||||
* \brief get prediction
|
||||
* \param data input data
|
||||
* \param output_margin whether to only predict margin value instead of transformed prediction
|
||||
* \param out_preds output vector that stores the prediction
|
||||
*/
|
||||
inline void Predict(const DMatrix<FMatrix> &data,
|
||||
bool output_margin,
|
||||
std::vector<float> *out_preds) const {
|
||||
this->PredictRaw(data, out_preds);
|
||||
obj_->PredTransform(out_preds);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
}
|
||||
/*! \brief dump model out */
|
||||
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||
@ -236,6 +237,22 @@ class BoostLearner {
|
||||
std::vector<float> *out_preds) const {
|
||||
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
|
||||
data.info.root_index, out_preds);
|
||||
// add base margin
|
||||
std::vector<float> &preds = *out_preds;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
if (data.info.base_margin.size() != 0) {
|
||||
utils::Check(preds.size() == data.info.base_margin.size(),
|
||||
"base_margin.size does not match with prediction size");
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
preds[j] += data.info.base_margin[j];
|
||||
}
|
||||
} else {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
preds[j] += mparam.base_score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief training parameter for regression */
|
||||
|
||||
@ -49,6 +49,7 @@ class BoostLearnTask{
|
||||
if (!strcmp("silent", name)) silent = atoi(val);
|
||||
if (!strcmp("use_buffer", name)) use_buffer = atoi(val);
|
||||
if (!strcmp("num_round", name)) num_round = atoi(val);
|
||||
if (!strcmp("pred_margin", name)) pred_margin = atoi(val);
|
||||
if (!strcmp("save_period", name)) save_period = atoi(val);
|
||||
if (!strcmp("eval_train", name)) eval_train = atoi(val);
|
||||
if (!strcmp("task", name)) task = val;
|
||||
@ -77,6 +78,7 @@ class BoostLearnTask{
|
||||
num_round = 10;
|
||||
save_period = 0;
|
||||
eval_train = 0;
|
||||
pred_margin = 0;
|
||||
dump_model_stats = 0;
|
||||
task = "train";
|
||||
model_in = "NULL";
|
||||
@ -184,7 +186,7 @@ class BoostLearnTask{
|
||||
inline void TaskPred(void) {
|
||||
std::vector<float> preds;
|
||||
if (!silent) printf("start prediction...\n");
|
||||
learner.Predict(*data, &preds);
|
||||
learner.Predict(*data, pred_margin != 0, &preds);
|
||||
if (!silent) printf("writing prediction to %s\n", name_pred.c_str());
|
||||
FILE *fo = utils::FopenCheck(name_pred.c_str(), "w");
|
||||
for (size_t i = 0; i < preds.size(); i++) {
|
||||
@ -193,37 +195,39 @@ class BoostLearnTask{
|
||||
fclose(fo);
|
||||
}
|
||||
private:
|
||||
/* \brief whether silent */
|
||||
/*! \brief whether silent */
|
||||
int silent;
|
||||
/* \brief whether use auto binary buffer */
|
||||
/*! \brief whether use auto binary buffer */
|
||||
int use_buffer;
|
||||
/* \brief whether evaluate training statistics */
|
||||
/*! \brief whether evaluate training statistics */
|
||||
int eval_train;
|
||||
/* \brief number of boosting iterations */
|
||||
/*! \brief number of boosting iterations */
|
||||
int num_round;
|
||||
/* \brief the period to save the model, 0 means only save the final round model */
|
||||
/*! \brief the period to save the model, 0 means only save the final round model */
|
||||
int save_period;
|
||||
/* \brief the path of training/test data set */
|
||||
/*! \brief the path of training/test data set */
|
||||
std::string train_path, test_path;
|
||||
/* \brief the path of test model file, or file to restart training */
|
||||
/*! \brief the path of test model file, or file to restart training */
|
||||
std::string model_in;
|
||||
/* \brief the path of final model file, to be saved */
|
||||
/*! \brief the path of final model file, to be saved */
|
||||
std::string model_out;
|
||||
/* \brief the path of directory containing the saved models */
|
||||
/*! \brief the path of directory containing the saved models */
|
||||
std::string model_dir_path;
|
||||
/* \brief task to perform */
|
||||
/*! \brief task to perform */
|
||||
std::string task;
|
||||
/* \brief name of predict file */
|
||||
/*! \brief name of predict file */
|
||||
std::string name_pred;
|
||||
/* \brief whether dump statistics along with model */
|
||||
/*!\brief whether to directly output margin value */
|
||||
int pred_margin;
|
||||
/*! \brief whether dump statistics along with model */
|
||||
int dump_model_stats;
|
||||
/* \brief name of feature map */
|
||||
/*! \brief name of feature map */
|
||||
std::string name_fmap;
|
||||
/* \brief name of dump file */
|
||||
/*! \brief name of dump file */
|
||||
std::string name_dump;
|
||||
/* \brief the paths of validation data sets */
|
||||
/*! \brief the paths of validation data sets */
|
||||
std::vector<std::string> eval_data_paths;
|
||||
/* \brief the names of the evaluation data used in output log */
|
||||
/*! \brief the names of the evaluation data used in output log */
|
||||
std::vector<std::string> eval_data_names;
|
||||
private:
|
||||
io::DataMatrix* data;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user