diff --git a/Makefile b/Makefile index 0a5fd7047..8c9980ac1 100644 --- a/Makefile +++ b/Makefile @@ -5,17 +5,23 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas # specify tensor path BIN = xgboost OBJ = io.o +SLIB = python/libxgboostwrapper.so .PHONY: clean all -all: $(BIN) $(OBJ) +all: $(BIN) $(OBJ) $(SLIB) export LDFLAGS= -pthread -lm xgboost: src/xgboost_main.cpp io.o src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp io.o: src/io/io.cpp src/data.h src/utils/*.h +# now the wrapper takes in two files. io and wrapper part +python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h $(BIN) : $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) +$(SLIB) : + $(CXX) $(CFLAGS) -fPIC $(LDFLAGS) -shared -o $@ $(filter %.cpp %.o %.c, $^) + $(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) ) diff --git a/README.md b/README.md index f5b64b78a..732e64d7f 100644 --- a/README.md +++ b/README.md @@ -17,19 +17,7 @@ Build * Simply type make * If your compiler does not come with OpenMP support, it will fire an warning telling you that the code will compile into single thread mode, and you will get single thread xgboost * You may get a error: -lgomp is not found, you can remove -fopenmp flag in Makefile to get single thread xgboost, or upgrade your compiler to compile multi-thread version +* Possible way to build using Visual Studio (not tested): + - In principle, you can put src/xgboost.cpp and src/io/io.cpp into the project, and build xgboost. + - For python module, you need python/xgboost_wrapper.cpp and src/io/io.cpp to build a dll. -Project Logical Layout -======= -* Dependency order: io->learner->gbm->tree - - All module depends on data.h -* tree are implementations of tree construction algorithms. -* gbm is gradient boosting interface, that takes trees and other base learner to do boosting. - - gbm only takes gradient as sufficient statistics, it does not compute the gradient. -* learner is learning module that computes gradient for specific object, and pass it to GBM - -File Naming Convention -======= -* The project is templatized, to make it easy to adjust input data structure. -* .h files are data structures and interface, which are needed to use functions in that layer. -* -inl.hpp files are implementations of interface, like cpp file in most project. - - You only need to understand the interface file to understand the usage of that layer diff --git a/python/Makefile b/python/Makefile deleted file mode 100644 index 76dfdcf01..000000000 --- a/python/Makefile +++ /dev/null @@ -1,26 +0,0 @@ -export CC = gcc -export CXX = g++ -export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp - -# specify tensor path -SLIB = libxgboostpy.so -.PHONY: clean all - -all: $(SLIB) -export LDFLAGS= -pthread -lm - -libxgboostpy.so: xgboost_python.cpp ../regrank/*.h ../booster/*.h ../booster/*/*.hpp ../booster/*.hpp - -$(SLIB) : - $(CXX) $(CFLAGS) -fPIC $(LDFLAGS) -shared -o $@ $(filter %.cpp %.o %.c, $^) -$(BIN) : - $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) - -$(OBJ) : - $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) ) - -install: - cp -f -r $(BIN) $(INSTALL_PATH) - -clean: - $(RM) $(OBJ) $(BIN) $(SLIB) *~ diff --git a/python/README.md b/python/README.md index 19d33aa08..cf59ba9ab 100644 --- a/python/README.md +++ b/python/README.md @@ -1,3 +1,5 @@ python wrapper for xgboost using ctypes see example for usage + +to make the python module, type make in the root directory of project diff --git a/python/xgboost.py b/python/xgboost.py index 070fe6593..2e8deefa8 100644 --- a/python/xgboost.py +++ b/python/xgboost.py @@ -8,11 +8,7 @@ import numpy.ctypeslib import scipy.sparse as scp # set this line correctly -XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostpy.so' - -# entry type of sparse matrix -class REntry(ctypes.Structure): - _fields_ = [("findex", ctypes.c_uint), ("fvalue", ctypes.c_float) ] +XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostwrapper.so' # load in xgboost library xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH) diff --git a/python/xgboost_python.cpp b/python/xgboost_python.cpp deleted file mode 100644 index a325a20d4..000000000 --- a/python/xgboost_python.cpp +++ /dev/null @@ -1,297 +0,0 @@ -// implementations in ctypes -#include "xgboost_python.h" -#include "../regrank/xgboost_regrank.h" -#include "../regrank/xgboost_regrank_data.h" - -namespace xgboost{ - namespace python{ - class DMatrix: public regrank::DMatrix{ - public: - // whether column is initialized - bool init_col_; - public: - DMatrix(void){ - init_col_ = false; - } - ~DMatrix(void){} - public: - inline void Load(const char *fname, bool silent){ - this->CacheLoad(fname, silent); - init_col_ = this->data.HaveColAccess(); - } - inline void Clear( void ){ - this->data.Clear(); - this->info.labels.clear(); - this->info.weights.clear(); - this->info.group_ptr.clear(); - } - inline size_t NumRow( void ) const{ - return this->data.NumRow(); - } - inline void AddRow( const XGEntry *data, size_t len ){ - xgboost::booster::FMatrixS &mat = this->data; - mat.row_data_.resize( mat.row_ptr_.back() + len ); - memcpy( &mat.row_data_[mat.row_ptr_.back()], data, sizeof(XGEntry)*len ); - mat.row_ptr_.push_back( mat.row_ptr_.back() + len ); - init_col_ = false; - } - inline const XGEntry* GetRow(unsigned ridx, size_t* len) const{ - const xgboost::booster::FMatrixS &mat = this->data; - - *len = mat.row_ptr_[ridx+1] - mat.row_ptr_[ridx]; - return &mat.row_data_[ mat.row_ptr_[ridx] ]; - } - inline void ParseCSR( const size_t *indptr, - const unsigned *indices, - const float *data, - size_t nindptr, - size_t nelem ){ - xgboost::booster::FMatrixS &mat = this->data; - mat.row_ptr_.resize( nindptr ); - memcpy( &mat.row_ptr_[0], indptr, sizeof(size_t)*nindptr ); - mat.row_data_.resize( nelem ); - for( size_t i = 0; i < nelem; ++ i ){ - mat.row_data_[i] = XGEntry(indices[i], data[i]); - } - this->data.InitData(); - this->init_col_ = true; - } - - inline void ParseMat( const float *data, - size_t nrow, - size_t ncol, - float missing ){ - xgboost::booster::FMatrixS &mat = this->data; - mat.Clear(); - for( size_t i = 0; i < nrow; ++i, data += ncol ){ - size_t nelem = 0; - for( size_t j = 0; j < ncol; ++j ){ - if( data[j] != missing ){ - mat.row_data_.push_back( XGEntry(j, data[j]) ); - ++ nelem; - } - } - mat.row_ptr_.push_back( mat.row_ptr_.back() + nelem ); - } - this->data.InitData(); - this->init_col_ = true; - } - inline void SetLabel( const float *label, size_t len ){ - this->info.labels.resize( len ); - memcpy( &(this->info).labels[0], label, sizeof(float)*len ); - } - inline void SetGroup( const unsigned *group, size_t len ){ - this->info.group_ptr.resize( len + 1 ); - this->info.group_ptr[0] = 0; - for( size_t i = 0; i < len; ++ i ){ - this->info.group_ptr[i+1] = this->info.group_ptr[i]+group[i]; - } - } - inline void SetWeight( const float *weight, size_t len ){ - this->info.weights.resize( len ); - memcpy( &(this->info).weights[0], weight, sizeof(float)*len ); - } - inline const float* GetLabel( size_t* len ) const{ - *len = this->info.labels.size(); - return &(this->info.labels[0]); - } - inline const float* GetWeight( size_t* len ) const{ - *len = this->info.weights.size(); - return &(this->info.weights[0]); - } - inline void CheckInit(void){ - if(!init_col_){ - this->data.InitData(); - init_col_ = true; - } - utils::Assert( this->data.NumRow() == this->info.labels.size(), "DMatrix: number of labels must match number of rows in matrix"); - } - }; - - class Booster: public xgboost::regrank::RegRankBoostLearner{ - private: - bool init_trainer, init_model; - public: - Booster(const std::vector mats){ - silent = 1; - init_trainer = false; - init_model = false; - this->SetCacheData(mats); - } - inline void CheckInit(void){ - if( !init_trainer ){ - this->InitTrainer(); init_trainer = true; - } - if( !init_model ){ - this->InitModel(); init_model = true; - } - } - inline void LoadModel( const char *fname ){ - xgboost::regrank::RegRankBoostLearner::LoadModel(fname); - this->init_model = true; - } - inline void SetParam( const char *name, const char *val ){ - if( !strcmp( name, "seed" ) ) random::Seed(atoi(val)); - xgboost::regrank::RegRankBoostLearner::SetParam( name, val ); - } - const float *Pred( const DMatrix &dmat, size_t *len, int bst_group ){ - this->CheckInit(); - - this->Predict( this->preds_, dmat, bst_group ); - *len = this->preds_.size(); - return &this->preds_[0]; - } - inline void BoostOneIter( const DMatrix &train, - float *grad, float *hess, size_t len, int bst_group ){ - this->grad_.resize( len ); this->hess_.resize( len ); - memcpy( &this->grad_[0], grad, sizeof(float)*len ); - memcpy( &this->hess_[0], hess, sizeof(float)*len ); - - if( grad_.size() == train.Size() ){ - if( bst_group < 0 ) bst_group = 0; - base_gbm.DoBoost(grad_, hess_, train.data, train.info.root_index, bst_group); - }else{ - utils::Assert( bst_group == -1, "must set bst_group to -1 to support all group boosting" ); - int ngroup = base_gbm.NumBoosterGroup(); - utils::Assert( grad_.size() == train.Size() * (size_t)ngroup, "BUG: UpdateOneIter: mclass" ); - std::vector tgrad( train.Size() ), thess( train.Size() ); - for( int g = 0; g < ngroup; ++ g ){ - memcpy( &tgrad[0], &grad_[g*tgrad.size()], sizeof(float)*tgrad.size() ); - memcpy( &thess[0], &hess_[g*tgrad.size()], sizeof(float)*tgrad.size() ); - base_gbm.DoBoost(tgrad, thess, train.data, train.info.root_index, g ); - } - } - } - }; - }; -}; - -using namespace xgboost::python; - - -extern "C"{ - void* XGDMatrixCreate( void ){ - return new DMatrix(); - } - void XGDMatrixFree( void *handle ){ - delete static_cast(handle); - } - void XGDMatrixLoad( void *handle, const char *fname, int silent ){ - static_cast(handle)->Load(fname, silent!=0); - } - void XGDMatrixSaveBinary( void *handle, const char *fname, int silent ){ - static_cast(handle)->SaveBinary(fname, silent!=0); - } - void XGDMatrixParseCSR( void *handle, - const size_t *indptr, - const unsigned *indices, - const float *data, - size_t nindptr, - size_t nelem ){ - static_cast(handle)->ParseCSR(indptr, indices, data, nindptr, nelem); - } - void XGDMatrixParseMat( void *handle, - const float *data, - size_t nrow, - size_t ncol, - float missing ){ - static_cast(handle)->ParseMat(data, nrow, ncol, missing); - } - void XGDMatrixSetLabel( void *handle, const float *label, size_t len ){ - static_cast(handle)->SetLabel(label,len); - } - void XGDMatrixSetWeight( void *handle, const float *weight, size_t len ){ - static_cast(handle)->SetWeight(weight,len); - } - void XGDMatrixSetGroup( void *handle, const unsigned *group, size_t len ){ - static_cast(handle)->SetGroup(group,len); - } - const float* XGDMatrixGetLabel( const void *handle, size_t* len ){ - return static_cast(handle)->GetLabel(len); - } - const float* XGDMatrixGetWeight( const void *handle, size_t* len ){ - return static_cast(handle)->GetWeight(len); - } - void XGDMatrixClear(void *handle){ - static_cast(handle)->Clear(); - } - void XGDMatrixAddRow( void *handle, const XGEntry *data, size_t len ){ - static_cast(handle)->AddRow(data, len); - } - size_t XGDMatrixNumRow(const void *handle){ - return static_cast(handle)->NumRow(); - } - const XGEntry* XGDMatrixGetRow(void *handle, unsigned ridx, size_t* len){ - return static_cast(handle)->GetRow(ridx, len); - } - - // xgboost implementation - void *XGBoosterCreate( void *dmats[], size_t len ){ - std::vector mats; - for( size_t i = 0; i < len; ++i ){ - DMatrix *dtr = static_cast(dmats[i]); - dtr->CheckInit(); - mats.push_back( dtr ); - } - return new Booster( mats ); - } - void XGBoosterFree( void *handle ){ - delete static_cast(handle); - } - void XGBoosterSetParam( void *handle, const char *name, const char *value ){ - static_cast(handle)->SetParam( name, value ); - } - void XGBoosterUpdateOneIter( void *handle, void *dtrain ){ - Booster *bst = static_cast(handle); - DMatrix *dtr = static_cast(dtrain); - bst->CheckInit(); dtr->CheckInit(); - bst->UpdateOneIter( *dtr ); - } - void XGBoosterBoostOneIter( void *handle, void *dtrain, - float *grad, float *hess, size_t len, int bst_group ){ - Booster *bst = static_cast(handle); - DMatrix *dtr = static_cast(dtrain); - bst->CheckInit(); dtr->CheckInit(); - bst->BoostOneIter( *dtr, grad, hess, len, bst_group ); - } - void XGBoosterEvalOneIter( void *handle, int iter, void *dmats[], const char *evnames[], size_t len ){ - Booster *bst = static_cast(handle); - bst->CheckInit(); - - std::vector names; - std::vector mats; - for( size_t i = 0; i < len; ++i ){ - mats.push_back( static_cast(dmats[i]) ); - names.push_back( std::string( evnames[i]) ); - } - bst->EvalOneIter( iter, mats, names, stderr ); - } - const float *XGBoosterPredict( void *handle, void *dmat, size_t *len, int bst_group ){ - return static_cast(handle)->Pred( *static_cast(dmat), len, bst_group ); - } - void XGBoosterLoadModel( void *handle, const char *fname ){ - static_cast(handle)->LoadModel( fname ); - } - void XGBoosterSaveModel( const void *handle, const char *fname ){ - static_cast(handle)->SaveModel( fname ); - } - void XGBoosterDumpModel( void *handle, const char *fname, const char *fmap ){ - using namespace xgboost::utils; - FILE *fo = FopenCheck( fname, "w" ); - FeatMap featmap; - if( strlen(fmap) != 0 ){ - featmap.LoadText( fmap ); - } - static_cast(handle)->DumpModel( fo, featmap, false ); - fclose( fo ); - } - - void XGBoosterUpdateInteract( void *handle, void *dtrain, const char *action ){ - Booster *bst = static_cast(handle); - DMatrix *dtr = static_cast(dtrain); - bst->CheckInit(); dtr->CheckInit(); - std::string act( action ); - bst->UpdateInteract( act, *dtr ); - } -}; - diff --git a/python/xgboost_python.h b/python/xgboost_python.h deleted file mode 100644 index 6c113a108..000000000 --- a/python/xgboost_python.h +++ /dev/null @@ -1,209 +0,0 @@ -#ifndef XGBOOST_PYTHON_H -#define XGBOOST_PYTHON_H -/*! - * \file xgboost_python.h - * \author Tianqi Chen - * \brief python wrapper for xgboost, using ctypes, - * hides everything behind functions - * use c style interface - */ -#include "../booster/xgboost_data.h" -extern "C"{ - /*! \brief type of row entry */ - typedef xgboost::booster::FMatrixS::REntry XGEntry; - - /*! - * \brief create a data matrix - * \return a new data matrix - */ - void* XGDMatrixCreate(void); - /*! - * \brief free space in data matrix - */ - void XGDMatrixFree(void *handle); - /*! - * \brief load a data matrix from text file or buffer(if exists) - * \param handle a instance of data matrix - * \param fname file name - * \param silent print statistics when loading - */ - void XGDMatrixLoad(void *handle, const char *fname, int silent); - /*! - * \brief load a data matrix into binary file - * \param handle a instance of data matrix - * \param fname file name - * \param silent print statistics when saving - */ - void XGDMatrixSaveBinary(void *handle, const char *fname, int silent); - /*! - * \brief set matrix content from csr format - * \param handle a instance of data matrix - * \param indptr pointer to row headers - * \param indices findex - * \param data fvalue - * \param nindptr number of rows in the matix + 1 - * \param nelem number of nonzero elements in the matrix - */ - void XGDMatrixParseCSR( void *handle, - const size_t *indptr, - const unsigned *indices, - const float *data, - size_t nindptr, - size_t nelem ); - /*! - * \brief set matrix content from data content - * \param handle a instance of data matrix - * \param data pointer to the data space - * \param nrow number of rows - * \param ncol number columns - * \param missing which value to represent missing value - */ - void XGDMatrixParseMat( void *handle, - const float *data, - size_t nrow, - size_t ncol, - float missing ); - /*! - * \brief set label of the training matrix - * \param handle a instance of data matrix - * \param label pointer to label - * \param len length of array - */ - void XGDMatrixSetLabel( void *handle, const float *label, size_t len ); - /*! - * \brief set label of the training matrix - * \param handle a instance of data matrix - * \param group pointer to group size - * \param len length of array - */ - void XGDMatrixSetGroup( void *handle, const unsigned *group, size_t len ); - /*! - * \brief set weight of each instacne - * \param handle a instance of data matrix - * \param weight data pointer to weights - * \param len length of array - */ - void XGDMatrixSetWeight( void *handle, const float *weight, size_t len ); - /*! - * \brief get label set from matrix - * \param handle a instance of data matrix - * \param len used to set result length - * \return pointer to the label - */ - const float* XGDMatrixGetLabel( const void *handle, size_t* len ); - /*! - * \brief get weight set from matrix - * \param handle a instance of data matrix - * \param len used to set result length - * \return pointer to the weight - */ - const float* XGDMatrixGetWeight( const void *handle, size_t* len ); - /*! - * \brief clear all the records, including feature matrix and label - * \param handle a instance of data matrix - */ - void XGDMatrixClear(void *handle); - /*! - * \brief return number of rows - */ - size_t XGDMatrixNumRow(const void *handle); - /*! - * \brief add row - * \param handle a instance of data matrix - * \param data array of row content - * \param len length of array - */ - void XGDMatrixAddRow(void *handle, const XGEntry *data, size_t len); - /*! - * \brief get ridx-th row of sparse matrix - * \param handle handle - * \param ridx row index - * \param len used to set result length - * \reurn pointer to the row - */ - const XGEntry* XGDMatrixGetRow(void *handle, unsigned ridx, size_t* len); - - // --- start XGBoost class - /*! - * \brief create xgboost learner - * \param dmats matrices that are set to be cached - * \param create a booster - */ - void *XGBoosterCreate( void* dmats[], size_t len ); - /*! - * \brief free obj in handle - * \param handle handle to be freed - */ - void XGBoosterFree( void* handle ); - /*! - * \brief set parameters - * \param handle handle - * \param name parameter name - * \param val value of parameter - */ - void XGBoosterSetParam( void *handle, const char *name, const char *value ); - /*! - * \brief update the model in one round using dtrain - * \param handle handle - * \param dtrain training data - */ - void XGBoosterUpdateOneIter( void *handle, void *dtrain ); - - /*! - * \brief update the model, by directly specify gradient and second order gradient, - * this can be used to replace UpdateOneIter, to support customized loss function - * \param handle handle - * \param dtrain training data - * \param grad gradient statistics - * \param hess second order gradient statistics - * \param len length of grad/hess array - * \param bst_group boost group we are working at, default = -1 - */ - void XGBoosterBoostOneIter( void *handle, void *dtrain, - float *grad, float *hess, size_t len, int bst_group ); - /*! - * \brief print evaluation statistics to stdout for xgboost - * \param handle handle - * \param iter current iteration rounds - * \param dmats pointers to data to be evaluated - * \param evnames pointers to names of each data - * \param len length of dmats - */ - void XGBoosterEvalOneIter( void *handle, int iter, void *dmats[], const char *evnames[], size_t len ); - /*! - * \brief make prediction based on dmat - * \param handle handle - * \param dmat data matrix - * \param len used to store length of returning result - * \param bst_group booster group, if model contains multiple booster group, default = -1 means predict for all groups - */ - const float *XGBoosterPredict( void *handle, void *dmat, size_t *len, int bst_group ); - /*! - * \brief load model from existing file - * \param handle handle - * \param fname file name - */ - void XGBoosterLoadModel( void *handle, const char *fname ); - /*! - * \brief save model into existing file - * \param handle handle - * \param fname file name - */ - void XGBoosterSaveModel( const void *handle, const char *fname ); - /*! - * \brief dump model into text file - * \param handle handle - * \param fname file name - * \param fmap name to fmap can be empty string - */ - void XGBoosterDumpModel( void *handle, const char *fname, const char *fmap ); - /*! - * \brief interactively update model: beta - * \param handle handle - * \param dtrain training data - * \param action action name - */ - void XGBoosterUpdateInteract( void *handle, void *dtrain, const char* action ); -}; -#endif - diff --git a/python/xgboost_wrapper.cpp b/python/xgboost_wrapper.cpp new file mode 100644 index 000000000..e43095920 --- /dev/null +++ b/python/xgboost_wrapper.cpp @@ -0,0 +1,240 @@ +// implementations in ctypes +#include +#include +#include +#include +#include +#include "./xgboost_wrapper.h" +#include "../src/data.h" +#include "../src/learner/learner-inl.hpp" +#include "../src/io/io.h" +#include "../src/io/simple_dmatrix-inl.hpp" + +using namespace xgboost; +using namespace xgboost::io; + +namespace xgboost { +namespace wrapper { +// booster wrapper class +class Booster: public learner::BoostLearner { + public: + explicit Booster(const std::vector& mats) { + this->silent = 1; + this->SetCacheData(mats); + } + const float *Pred(const DataMatrix &dmat, size_t *len) { + this->Predict(dmat, &this->preds_); + *len = this->preds_.size(); + return &this->preds_[0]; + } + inline void BoostOneIter(const DataMatrix &train, + float *grad, float *hess, size_t len) { + this->gpair_.resize(len); + const unsigned ndata = static_cast(len); + #pragma omp parallel for schedule(static) + for (unsigned j = 0; j < ndata; ++j) { + gpair_[j] = bst_gpair(grad[j], hess[j]); + } + gbm_->DoBoost(gpair_, train.fmat, train.info.root_index); + } + inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, size_t *len) { + model_dump = this->DumpModel(fmap, with_stats); + model_dump_cptr.resize(model_dump.size()); + for (size_t i = 0; i < model_dump.size(); ++i) { + model_dump_cptr[i] = model_dump[i].c_str(); + } + *len = model_dump.size(); + return &model_dump_cptr[0]; + } + // temporal fields + // temporal data to save evaluation dump + std::string eval_str; + // temporal space to save model dump + std::vector model_dump; + std::vector model_dump_cptr; +}; +} // namespace wrapper +} // namespace xgboost + +using namespace xgboost::wrapper; + +extern "C"{ + void* XGDMatrixCreateFromFile(const char *fname, int silent) { + return LoadDataMatrix(fname, silent, false); + } + void* XGDMatrixCreateFromCSR(const size_t *indptr, + const unsigned *indices, + const float *data, + size_t nindptr, + size_t nelem) { + DMatrixSimple *p_mat = new DMatrixSimple(); + DMatrixSimple &mat = *p_mat; + mat.row_ptr_.resize(nindptr); + memcpy(&mat.row_ptr_[0], indptr, sizeof(size_t)*nindptr); + mat.row_data_.resize(nelem); + for (size_t i = 0; i < nelem; ++ i) { + mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]); + mat.info.num_col = std::max(mat.info.num_col, + static_cast(indices[i]+1)); + } + mat.info.num_row = nindptr - 1; + return p_mat; + } + void* XGDMatrixCreateFromMat(const float *data, + size_t nrow, + size_t ncol, + float missing) { + DMatrixSimple *p_mat = new DMatrixSimple(); + DMatrixSimple &mat = *p_mat; + mat.info.num_row = nrow; + mat.info.num_col = ncol; + for (size_t i = 0; i < nrow; ++i, data += ncol) { + size_t nelem = 0; + for (size_t j = 0; j < ncol; ++j) { + if (data[j] != missing) { + mat.row_data_.push_back(SparseBatch::Entry(j, data[j])); + ++nelem; + } + } + mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem); + } + return p_mat; + } + void* XGDMatrixSliceDMatrix(void *handle, + const int *idxset, + size_t len) { + DMatrixSimple tmp; + DataMatrix &dsrc = *static_cast(handle); + if (dsrc.magic != DMatrixSimple::kMagic) { + tmp.CopyFrom(dsrc); + } + DataMatrix &src = (dsrc.magic == DMatrixSimple::kMagic ? + *static_cast(handle): tmp); + DMatrixSimple *p_ret = new DMatrixSimple(); + DMatrixSimple &ret = *p_ret; + + utils::Check(src.info.group_ptr.size() == 0, + "slice does not support group structure"); + ret.Clear(); + ret.info.num_row = len; + ret.info.num_col = src.info.num_col; + + utils::IIterator *iter = src.fmat.RowIterator(); + iter->BeforeFirst(); + utils::Assert(iter->Next(), "slice"); + const SparseBatch &batch = iter->Value(); + for(size_t i = 0; i < len; ++i) { + const int ridx = idxset[i]; + SparseBatch::Inst inst = batch[ridx]; + utils::Check(ridx < batch.size, "slice index exceed number of rows"); + ret.row_data_.resize(ret.row_data_.size() + inst.length); + memcpy(&ret.row_data_[ret.row_ptr_.back()], inst.data, + sizeof(SparseBatch::Entry) * inst.length); + ret.row_ptr_.push_back(ret.row_ptr_.back() + inst.length); + if (src.info.labels.size() != 0) { + ret.info.labels.push_back(src.info.labels[ridx]); + } + if (src.info.weights.size() != 0) { + ret.info.weights.push_back(src.info.weights[ridx]); + } + if (src.info.root_index.size() != 0) { + ret.info.weights.push_back(src.info.root_index[ridx]); + } + } + return p_ret; + } + void XGDMatrixFree(void *handle) { + delete static_cast(handle); + } + void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) { + SaveDataMatrix(*static_cast(handle), fname, silent); + } + void XGDMatrixSetLabel(void *handle, const float *label, size_t len) { + DataMatrix *pmat = static_cast(handle); + pmat->info.labels.resize(len); + memcpy(&(pmat->info).labels[0], label, sizeof(float) * len); + } + void XGDMatrixSetWeight(void *handle, const float *weight, size_t len) { + DataMatrix *pmat = static_cast(handle); + pmat->info.weights.resize(len); + memcpy(&(pmat->info).weights[0], weight, sizeof(float) * len); + } + void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len){ + DataMatrix *pmat = static_cast(handle); + pmat->info.group_ptr.resize(len + 1); + pmat->info.group_ptr[0] = 0; + for (size_t i = 0; i < len; ++ i) { + pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i]+group[i]; + } + } + const float* XGDMatrixGetLabel(const void *handle, size_t* len) { + const DataMatrix *pmat = static_cast(handle); + *len = pmat->info.labels.size(); + return &(pmat->info.labels[0]); + } + const float* XGDMatrixGetWeight(const void *handle, size_t* len) { + const DataMatrix *pmat = static_cast(handle); + *len = pmat->info.weights.size(); + return &(pmat->info.weights[0]); + } + size_t XGDMatrixNumRow(const void *handle) { + return static_cast(handle)->info.num_row; + } + + // xgboost implementation + void *XGBoosterCreate(void *dmats[], size_t len) { + std::vector mats; + for (size_t i = 0; i < len; ++i) { + DataMatrix *dtr = static_cast(dmats[i]); + mats.push_back(dtr); + } + return new Booster(mats); + } + void XGBoosterFree(void *handle) { + delete static_cast(handle); + } + void XGBoosterSetParam(void *handle, const char *name, const char *value) { + static_cast(handle)->SetParam(name, value); + } + void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain) { + Booster *bst = static_cast(handle); + DataMatrix *dtr = static_cast(dtrain); + bst->CheckInit(dtr); + bst->UpdateOneIter(iter, *dtr); + } + void XGBoosterBoostOneIter(void *handle, void *dtrain, + float *grad, float *hess, size_t len) { + Booster *bst = static_cast(handle); + DataMatrix *dtr = static_cast(dtrain); + bst->CheckInit(dtr); + bst->BoostOneIter(*dtr, grad, hess, len); + } + const char* XGBoosterEvalOneIter(void *handle, int iter, void *dmats[], const char *evnames[], size_t len) { + Booster *bst = static_cast(handle); + std::vector names; + std::vector mats; + for (size_t i = 0; i < len; ++i) { + mats.push_back(static_cast(dmats[i])); + names.push_back(std::string(evnames[i])); + } + bst->eval_str = bst->EvalOneIter(iter, mats, names); + return bst->eval_str.c_str(); + } + const float *XGBoosterPredict(void *handle, void *dmat, size_t *len) { + return static_cast(handle)->Pred(*static_cast(dmat), len); + } + void XGBoosterLoadModel(void *handle, const char *fname) { + static_cast(handle)->LoadModel(fname); + } + void XGBoosterSaveModel( const void *handle, const char *fname) { + static_cast(handle)->SaveModel(fname); + } + const char** XGBoosterDumpModel(void *handle, const char *fmap, size_t *len){ + using namespace xgboost::utils; + FeatMap featmap; + if(strlen(fmap) != 0) { + featmap.LoadText(fmap); + } + return static_cast(handle)->GetModelDump(featmap, false, len); + } +}; diff --git a/python/xgboost_wrapper.h b/python/xgboost_wrapper.h new file mode 100644 index 000000000..16b8fecd7 --- /dev/null +++ b/python/xgboost_wrapper.h @@ -0,0 +1,182 @@ +#ifndef XGBOOST_WRAPPER_H_ +#define XGBOOST_WRAPPER_H_ +/*! + * \file xgboost_wrapperh + * \author Tianqi Chen + * \brief a C style wrapper of xgboost + * can be used to create wrapper of other languages + */ +#include + +extern "C" { + /*! + * \brief load a data matrix + * \return a loaded data matrix + */ + void* XGDMatrixCreateFromFile(const char *fname, int silent); + /*! + * \brief create a matrix content from csr format + * \param handle a instance of data matrix + * \param indptr pointer to row headers + * \param indices findex + * \param data fvalue + * \param nindptr number of rows in the matix + 1 + * \param nelem number of nonzero elements in the matrix + * \return created dmatrix + */ + void* XGDMatrixCreateFromCSR(const size_t *indptr, + const unsigned *indices, + const float *data, + size_t nindptr, + size_t nelem); + /*! + * \brief create matrix content from dense matrix + * \param handle a instance of data matrix + * \param data pointer to the data space + * \param nrow number of rows + * \param ncol number columns + * \param missing which value to represent missing value + * \return created dmatrix + */ + void* XGDMatrixCreateFromMat(const float *data, + size_t nrow, + size_t ncol, + float missing); + /*! + * \brief create a new dmatrix from sliced content of existing matrix + * \param handle instance of data matrix to be sliced + * \param idxset index set + * \param len length of index set + * \return a sliced new matrix + */ + void* XGDMatrixSliceDMatrix(void *handle, + const int *idxset, + size_t len); + /*! + * \brief free space in data matrix + */ + void XGDMatrixFree(void *handle); + /*! + * \brief load a data matrix into binary file + * \param handle a instance of data matrix + * \param fname file name + * \param silent print statistics when saving + */ + void XGDMatrixSaveBinary(void *handle, const char *fname, int silent); + /*! + * \brief set label of the training matrix + * \param handle a instance of data matrix + * \param label pointer to label + * \param len length of array + */ + void XGDMatrixSetLabel(void *handle, const float *label, size_t len); + /*! + * \brief set weight of each instance + * \param handle a instance of data matrix + * \param weight data pointer to weights + * \param len length of array + */ + void XGDMatrixSetWeight(void *handle, const float *weight, size_t len); + /*! + * \brief set label of the training matrix + * \param handle a instance of data matrix + * \param group pointer to group size + * \param len length of array + */ + void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len); + /*! + * \brief get label set from matrix + * \param handle a instance of data matrix + * \param len used to set result length + * \return pointer to the label + */ + const float* XGDMatrixGetLabel(const void *handle, size_t* out_len); + /*! + * \brief get weight set from matrix + * \param handle a instance of data matrix + * \param len used to set result length + * \return pointer to the weight + */ + const float* XGDMatrixGetWeight(const void *handle, size_t* out_len); + /*! + * \brief return number of rows + */ + size_t XGDMatrixNumRow(const void *handle); + // --- start XGBoost class + /*! + * \brief create xgboost learner + * \param dmats matrices that are set to be cached + * \param len length of dmats + */ + void *XGBoosterCreate(void* dmats[], size_t len); + /*! + * \brief free obj in handle + * \param handle handle to be freed + */ + void XGBoosterFree(void* handle); + /*! + * \brief set parameters + * \param handle handle + * \param name parameter name + * \param val value of parameter + */ + void XGBoosterSetParam(void *handle, const char *name, const char *value); + /*! + * \brief update the model in one round using dtrain + * \param handle handle + * \param iter current iteration rounds + * \param dtrain training data + */ + void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain); + /*! + * \brief update the model, by directly specify gradient and second order gradient, + * this can be used to replace UpdateOneIter, to support customized loss function + * \param handle handle + * \param dtrain training data + * \param grad gradient statistics + * \param hess second order gradient statistics + * \param len length of grad/hess array + */ + void XGBoosterBoostOneIter(void *handle, void *dtrain, + float *grad, float *hess, size_t len); + /*! + * \brief get evaluation statistics for xgboost + * \param handle handle + * \param iter current iteration rounds + * \param dmats pointers to data to be evaluated + * \param evnames pointers to names of each data + * \param len length of dmats + * \return the string containing evaluation stati + */ + const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[], + const char *evnames[], size_t len); + /*! + * \brief make prediction based on dmat + * \param handle handle + * \param dmat data matrix + * \param len used to store length of returning result + */ + const float *XGBoosterPredict(void *handle, void *dmat, size_t *len); + /*! + * \brief load model from existing file + * \param handle handle + * \param fname file name + */ + void XGBoosterLoadModel(void *handle, const char *fname); + /*! + * \brief save model into existing file + * \param handle handle + * \param fname file name + */ + void XGBoosterSaveModel(const void *handle, const char *fname); + /*! + * \brief dump model, return array of strings representing model dump + * \param handle handle + * \param fmap name to fmap can be empty string + * \param out_len length of output array + * \return char *data[], representing dump of each model + */ + const char** XGBoosterDumpModel(void *handle, const char *fmap, + size_t *out_len); +}; +#endif // XGBOOST_WRAPPER_H_ diff --git a/src/README.md b/src/README.md new file mode 100644 index 000000000..35d9b08e8 --- /dev/null +++ b/src/README.md @@ -0,0 +1,25 @@ +Coding Guide +====== + +Project Logical Layout +======= +* Dependency order: io->learner->gbm->tree + - All module depends on data.h +* tree are implementations of tree construction algorithms. +* gbm is gradient boosting interface, that takes trees and other base learner to do boosting. + - gbm only takes gradient as sufficient statistics, it does not compute the gradient. +* learner is learning module that computes gradient for specific object, and pass it to GBM + +File Naming Convention +======= +* The project is templatized, to make it easy to adjust input data structure. +* .h files are data structures and interface, which are needed to use functions in that layer. +* -inl.hpp files are implementations of interface, like cpp file in most project. + - You only need to understand the interface file to understand the usage of that layer + +How to Hack the Code +====== +* Add objective function: add to learner/objective-inl.hpp and register it in learner/objective.h ```CreateObjFunction``` + - You can also directly do it in python +* Add new evaluation metric: add to learner/evaluation-inl.hpp and register it in learner/evaluation.h ```CreateEvaluator``` +* Add wrapper for a new language, most likely you can do it by taking the functions in python/xgboost_wrapper.h, which is purely C based, and call these C functions to use xgboost diff --git a/src/data.h b/src/data.h index fe81b4dad..c60b58b8a 100644 --- a/src/data.h +++ b/src/data.h @@ -226,8 +226,12 @@ class FMatrixS : public FMatrixInterface{ if (this->HaveColAccess()) return; this->InitColData(max_nrow); } - /*! \brief get the row iterator associated with FMatrix */ + /*! + * \brief get the row iterator associated with FMatrix + * this function is not threadsafe, returns iterator stored in FMatrixS + */ inline utils::IIterator* RowIterator(void) const { + iter_->BeforeFirst(); return iter_; } /*! \brief set iterator */ diff --git a/src/io/io.cpp b/src/io/io.cpp index 2cf42aadf..4ddf61eb0 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -2,6 +2,7 @@ #define _CRT_SECURE_NO_DEPRECATE #include #include "./io.h" +#include "../utils/utils.h" #include "simple_dmatrix-inl.hpp" // implements data loads using dmatrix simple for now @@ -12,5 +13,10 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { dmat->CacheLoad(fname, silent, savebuffer); return dmat; } + +void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { + utils::Error("not implemented"); +} + } // namespace io } // namespace xgboost diff --git a/src/io/io.h b/src/io/io.h index d6d280d5e..211893509 100644 --- a/src/io/io.h +++ b/src/io/io.h @@ -28,8 +28,9 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent = false, bool savebuff * SaveDMatrix will choose the best way to materialize the dmatrix. * \param dmat the dmatrix to be saved * \param fname file name to be savd + * \param silent whether print message during saving */ -void SaveDMatrix(const DataMatrix &dmat, const char *fname); +void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent = false); } // namespace io } // namespace xgboost diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp index 5da6d1c0b..f996b8d8c 100644 --- a/src/io/simple_dmatrix-inl.hpp +++ b/src/io/simple_dmatrix-inl.hpp @@ -23,7 +23,7 @@ namespace io { class DMatrixSimple : public DataMatrix { public: // constructor - DMatrixSimple(void) { + DMatrixSimple(void) : DataMatrix(kMagic) { this->fmat.set_iter(new OneBatchIter(this)); this->Clear(); } @@ -36,6 +36,24 @@ class DMatrixSimple : public DataMatrix { row_data_.clear(); info.Clear(); } + /*! \brief copy content data from source matrix */ + inline void CopyFrom(const DataMatrix &src) { + this->info = src.info; + this->Clear(); + // clone data content in thos matrix + utils::IIterator *iter = src.fmat.RowIterator(); + iter->BeforeFirst(); + while (iter->Next()) { + const SparseBatch &batch = iter->Value(); + for (size_t i = 0; i < batch.size; ++i) { + SparseBatch::Inst inst = batch[i]; + row_data_.resize(row_data_.size() + inst.length); + memcpy(&row_data_[row_ptr_.back()], inst.data, + sizeof(SparseBatch::Entry) * inst.length); + row_ptr_.push_back(row_ptr_.back() + inst.length); + } + } + } /*! * \brief add a row to the matrix * \param feats features @@ -183,7 +201,7 @@ class DMatrixSimple : public DataMatrix { protected: // one batch iterator that return content in the matrix struct OneBatchIter: utils::IIterator { - OneBatchIter(DMatrixSimple *parent) + explicit OneBatchIter(DMatrixSimple *parent) : at_first_(true), parent_(parent) {} virtual ~OneBatchIter(void) {} virtual void BeforeFirst(void) { diff --git a/src/learner/dmatrix.h b/src/learner/dmatrix.h index 88a865399..b558b070b 100644 --- a/src/learner/dmatrix.h +++ b/src/learner/dmatrix.h @@ -6,6 +6,7 @@ * used for regression/classification/ranking * \author Tianqi Chen */ +#include #include "../data.h" namespace xgboost { @@ -43,7 +44,7 @@ struct MetaInfo { } /*! \brief get weight of each instances */ inline float GetWeight(size_t i) const { - if(weights.size() != 0) { + if (weights.size() != 0) { return weights[i]; } else { return 1.0f; @@ -51,7 +52,7 @@ struct MetaInfo { } /*! \brief get root index of i-th instance */ inline float GetRoot(size_t i) const { - if(root_index.size() != 0) { + if (root_index.size() != 0) { return static_cast(root_index[i]); } else { return 0; @@ -76,7 +77,7 @@ struct MetaInfo { // try to load group information from file, if exists inline bool TryLoadGroup(const char* fname, bool silent = false) { FILE *fi = fopen64(fname, "r"); - if (fi == NULL) return false; + if (fi == NULL) return false; group_ptr.push_back(0); unsigned nline; while (fscanf(fi, "%u", &nline) == 1) { @@ -110,6 +111,11 @@ struct MetaInfo { */ template struct DMatrix { + /*! + * \brief magic number associated with this object + * used to check if it is specific instance + */ + const int magic; /*! \brief meta information about the dataset */ MetaInfo info; /*! \brief feature matrix about data content */ @@ -120,7 +126,7 @@ struct DMatrix { */ void *cache_learner_ptr_; /*! \brief default constructor */ - DMatrix(void) : cache_learner_ptr_(NULL) {} + explicit DMatrix(int magic) : magic(magic), cache_learner_ptr_(NULL) {} // virtual destructor virtual ~DMatrix(void){} }; diff --git a/src/learner/evaluation.h b/src/learner/evaluation.h index d51e5b767..fa25aa7d7 100644 --- a/src/learner/evaluation.h +++ b/src/learner/evaluation.h @@ -39,7 +39,7 @@ inline IEvaluator* CreateEvaluator(const char *name) { if (!strcmp(name, "merror")) return new EvalMatchError(); if (!strcmp(name, "logloss")) return new EvalLogLoss(); if (!strcmp(name, "auc")) return new EvalAuc(); - if (!strncmp(name, "ams@",4)) return new EvalAMS(name); + if (!strncmp(name, "ams@", 4)) return new EvalAMS(name); if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name); if (!strncmp(name, "map", 3)) return new EvalMAP(name); if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name); diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 3c04837c3..d7ad3f71d 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -78,6 +78,7 @@ class BoostLearner { inline void SetParam(const char *name, const char *val) { if (!strcmp(name, "silent")) silent = atoi(val); if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val); + if (!strcmp("seed", name)) random::Seed(atoi(val)); if (gbm_ == NULL) { if (!strcmp(name, "objective")) name_obj_ = val; if (!strcmp(name, "booster")) name_gbm_ = val; @@ -132,16 +133,24 @@ class BoostLearner { utils::FileStream fo(utils::FopenCheck(fname, "wb")); this->SaveModel(fo); fo.Close(); - } + } + /*! + * \brief check if data matrix is ready to be used by training, + * if not intialize it + * \param p_train pointer to the matrix used by training + */ + inline void CheckInit(DMatrix *p_train) const { + p_train->fmat.InitColAccess(); + } /*! * \brief update the model for one iteration * \param iter current iteration number * \param p_train pointer to the data matrix */ - inline void UpdateOneIter(int iter, DMatrix *p_train) { - this->PredictRaw(*p_train, &preds_); - obj_->GetGradient(preds_, p_train->info, iter, &gpair_); - gbm_->DoBoost(gpair_, p_train->fmat, p_train->info.root_index); + inline void UpdateOneIter(int iter, const DMatrix &train) { + this->PredictRaw(train, &preds_); + obj_->GetGradient(preds_, train.info, iter, &gpair_); + gbm_->DoBoost(gpair_, train.fmat, train.info.root_index); } /*! * \brief evaluate the model for specific iteration diff --git a/src/xgboost_main.cpp b/src/xgboost_main.cpp index 16139f0d8..f3fc9201d 100644 --- a/src/xgboost_main.cpp +++ b/src/xgboost_main.cpp @@ -48,7 +48,6 @@ class BoostLearnTask{ inline void SetParam(const char *name, const char *val) { if (!strcmp("silent", name)) silent = atoi(val); if (!strcmp("use_buffer", name)) use_buffer = atoi(val); - if (!strcmp("seed", name)) random::Seed(atoi(val)); if (!strcmp("num_round", name)) num_round = atoi(val); if (!strcmp("save_period", name)) save_period = atoi(val); if (!strcmp("eval_train", name)) eval_train = atoi(val); @@ -103,9 +102,6 @@ class BoostLearnTask{ } else { // training data = io::LoadDataMatrix(train_path.c_str(), silent != 0, use_buffer != 0); - {// intialize column access - data->fmat.InitColAccess(); - } utils::Assert(eval_data_names.size() == eval_data_paths.size(), "BUG"); for (size_t i = 0; i < eval_data_names.size(); ++i) { deval.push_back(io::LoadDataMatrix(eval_data_paths[i].c_str(), silent != 0, use_buffer != 0)); @@ -139,10 +135,11 @@ class BoostLearnTask{ inline void TaskTrain(void) { const time_t start = time(NULL); unsigned long elapsed = 0; + learner.CheckInit(data); for (int i = 0; i < num_round; ++i) { elapsed = (unsigned long)(time(NULL) - start); if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed); - learner.UpdateOneIter(i,data); + learner.UpdateOneIter(i, *data); std::string res = learner.EvalOneIter(i, devalall, eval_data_names); fprintf(stderr, "%s\n", res.c_str()); if (save_period != 0 && (i + 1) % save_period == 0) {