complete refactor of data.h, now relies on iterators to access columns

tqchen@graphlab.com 2014-08-27 17:00:21 -07:00
parent a59f8945dc
commit 605269133e
15 changed files with 216 additions and 492 deletions

View File

@ -10,23 +10,22 @@ else
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp -funroll-loops
endif
# expose these flags to R CMD SHLIB
export PKG_CPPFLAGS = $(CFLAGS) -DXGBOOST_CUSTOMIZE_ERROR_
# specify tensor path
BIN = xgboost
OBJ =
SLIB = wrapper/libxgboostwrapper.so
RLIB = wrapper/libxgboostR.so
BIN =
OBJ = updater.o gbm.o xgboost_main.o
#SLIB = wrapper/libxgboostwrapper.so
#RLIB = wrapper/libxgboostR.so
.PHONY: clean all R python
all: $(BIN) wrapper/libxgboostwrapper.so
R: wrapper/libxgboostR.so
python: wrapper/libxgboostwrapper.so
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
all: $(BIN) $(OBJ)
#python: wrapper/libxgboostwrapper.so
#xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
# now the wrapper takes in two files: the io part and the wrapper part
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
wrapper/libxgboostR.so: wrapper/xgboost_wrapper.cpp wrapper/xgboost_R.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
#wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
updater.o: src/tree/updater.cpp
gbm.o: src/gbm/gbm.cpp
xgboost_main.o: src/xgboost_main.cpp
$(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
@ -34,9 +33,6 @@ $(BIN) :
$(SLIB) :
$(CXX) $(CFLAGS) -fPIC $(LDFLAGS) -shared -o $@ $(filter %.cpp %.o %.c, $^)
$(RLIB) :
R CMD SHLIB -c -o $@ $(filter %.cpp %.o %.c, $^)
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )

View File

@ -7,16 +7,8 @@
*/
#include <cstdio>
#include <vector>
#include <limits>
#include <climits>
#include <cstring>
#include <algorithm>
#include "utils/io.h"
#include "utils/omp.h"
#include "utils/utils.h"
#include "utils/iterator.h"
#include "utils/random.h"
#include "utils/matrix_csr.h"
namespace xgboost {
/*!
@ -96,308 +88,72 @@ struct SparseBatch {
};
/*! \brief batch size */
size_t size;
/*! \brief array[size+1], row pointer of each of the elements */
const size_t *row_ptr;
/*! \brief array[row_ptr.back()], content of the sparse element */
const Entry *data_ptr;
/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return Inst(data_ptr + row_ptr[i], static_cast<bst_uint>(row_ptr[i+1] - row_ptr[i]));
}
};
/*! \brief read-only row batch, used to access row continuously */
struct RowBatch : public SparseBatch {
/*! \brief the offset of rowid of this batch */
size_t base_rowid;
/*! \brief array[size+1], row pointer of each of the elements */
const size_t *ind_ptr;
/*! \brief array[ind_ptr.back()], content of the sparse element */
const Entry *data_ptr;
/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return Inst(data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i+1] - ind_ptr[i]));
}
};
/*!
* \brief read-only column batch, used to access columns,
* the columns are not required to be continuous
*/
struct ColBatch : public RowBatch {
struct ColBatch : public SparseBatch {
/*! \brief column index of each column in the data */
bst_uint *col_index;
const bst_uint *col_index;
/*! \brief pointer to the column data */
const Inst *col_data;
/*! \brief get i-th column from the batch */
inline Inst operator[](size_t i) const {
return col_data[i];
}
};
/**
* \brief This is an interface convention via templates, defining the way to access features;
* the column access rule is defined by the template for efficiency,
* row access is defined by an iterator over sparse batches
* \tparam Derived type of the actual implementation
* \brief interface of the feature matrix, needed for tree construction;
* this interface defines two ways to access features:
* row access is defined by the iterator of RowBatch,
* column access is optional, checked by HaveColAccess, and defined by the iterator of ColBatch
*/
template<typename Derived>
class FMatrixInterface {
class IFMatrix {
public:
/*! \brief example iterator over one column */
struct ColIter{
/*!
* \brief move to next position
* \return whether there is an element in the next position
*/
inline bool Next(void);
/*! \return row index of current position */
inline bst_uint rindex(void) const;
/*! \return feature value in current position */
inline bst_float fvalue(void) const;
};
/*! \brief backward iterator over column */
struct ColBackIter : public ColIter {};
public:
// column access is needed by some of the tree construction algorithms
// the interface only needs to guarantee the row iterator
// the column iterator becomes active when ColIterator is called; the row iterator can be disabled then
/*! \brief get the row iterator associated with FMatrix */
virtual utils::IIterator<RowBatch> *RowIterator(void) = 0;
/*!\brief get column iterator */
virtual utils::IIterator<ColBatch> *ColIterator(void) = 0;
/*!
* \brief get column iterator, the columns must be sorted by feature value
* \param cidx column index
* \return column iterator
* \brief get the column iterator associated with FMatrix, restricted to a subset of column features
* \param fset the list of column indices that must be contained in the returned column iterator
* \return the column iterator, initialized so that it reads the elements in fset
*/
inline ColIter GetSortedCol(size_t cidx) const;
/*!
* \brief get column backward iterator; it starts from the biggest fvalue and iterates backward
* \param cidx column index
* \return reverse column iterator
*/
inline ColBackIter GetReverseSortedCol(size_t cidx) const;
/*!
* \brief get number of columns
* \return number of columns
*/
inline size_t NumCol(void) const;
virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) = 0;
/*!
* \brief check whether column access is supported; if not, initialize column access
* \param max_rows maximum number of rows allowed in constructor
* \param subsample subsample ratio when generating column access
*/
inline void InitColAccess(void);
virtual void InitColAccess(float subsample) = 0;
// the following are column meta data, should be able to answer them fast
/*! \return whether column access is enabled */
inline bool HaveColAccess(void) const;
/*! \brief return #entries-in-col */
inline size_t GetColSize(size_t cidx) const;
/*!
* \brief return #entries-in-col / #rows
* \param cidx column index
* this function is used to help speed things up,
* it does not have to be implemented exactly; if unsure, return 0.0
* \return column density
*/
inline float GetColDensity(size_t cidx) const;
/*! \brief get the row iterator associated with FMatrix */
inline utils::IIterator<RowBatch>* RowIterator(void) const;
};
/*!
* \brief sparse matrix that support column access, CSC
*/
class FMatrixS : public FMatrixInterface<FMatrixS>{
public:
typedef RowBatch::Entry Entry;
/*! \brief column iterator */
struct ColIter{
const Entry *dptr_, *end_;
ColIter(const Entry* begin, const Entry* end)
:dptr_(begin), end_(end) {}
inline bool Next(void) {
if (dptr_ == end_) {
return false;
} else {
++dptr_; return true;
}
}
inline bst_uint rindex(void) const {
return dptr_->index;
}
inline bst_float fvalue(void) const {
return dptr_->fvalue;
}
};
/*! \brief reverse column iterator */
struct ColBackIter : public ColIter {
ColBackIter(const Entry* dptr, const Entry* end) : ColIter(dptr, end) {}
// shadows ColIter::Next
inline bool Next(void) {
if (dptr_ == end_) {
return false;
} else {
--dptr_; return true;
}
}
};
/*! \brief constructor */
FMatrixS(void) {
iter_ = NULL;
}
// destructor
~FMatrixS(void) {
if (iter_ != NULL) delete iter_;
}
/*! \return whether column access is enabled */
inline bool HaveColAccess(void) const {
return col_ptr_.size() != 0;
}
/*! \brief get number of columns */
inline size_t NumCol(void) const {
utils::Check(this->HaveColAccess(), "NumCol:need column access");
return col_ptr_.size() - 1;
}
/*! \brief get number of buffered rows */
inline const std::vector<bst_uint> buffered_rowset(void) const {
return buffered_rowset_;
}
/*! \brief get col sorted iterator */
inline ColIter GetSortedCol(size_t cidx) const {
utils::Assert(cidx < this->NumCol(), "col id exceed bound");
return ColIter(&col_data_[0] + col_ptr_[cidx] - 1,
&col_data_[0] + col_ptr_[cidx + 1] - 1);
}
/*!
* \brief get reversed col iterator,
* this function will be deprecated at some point
*/
inline ColBackIter GetReverseSortedCol(size_t cidx) const {
utils::Assert(cidx < this->NumCol(), "col id exceed bound");
return ColBackIter(&col_data_[0] + col_ptr_[cidx + 1],
&col_data_[0] + col_ptr_[cidx]);
}
/*! \brief get col size */
inline size_t GetColSize(size_t cidx) const {
return col_ptr_[cidx+1] - col_ptr_[cidx];
}
/*! \brief get column density */
inline float GetColDensity(size_t cidx) const {
size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
}
inline void InitColAccess(float pkeep = 1.0f) {
if (this->HaveColAccess()) return;
this->InitColData(pkeep);
}
/*!
* \brief get the row iterator associated with FMatrix
* this function is not threadsafe, returns iterator stored in FMatrixS
*/
inline utils::IIterator<RowBatch>* RowIterator(void) const {
iter_->BeforeFirst();
return iter_;
}
/*! \brief set iterator */
inline void set_iter(utils::IIterator<RowBatch> *iter) {
this->iter_ = iter;
}
/*!
* \brief save column access data into stream
* \param fo output stream to save to
*/
inline void SaveColAccess(utils::IStream &fo) const {
fo.Write(buffered_rowset_);
if (buffered_rowset_.size() != 0) {
SaveBinary(fo, col_ptr_, col_data_);
}
}
/*!
* \brief load column access data from stream
* \param fi input stream to load from
*/
inline void LoadColAccess(utils::IStream &fi) {
utils::Check(fi.Read(&buffered_rowset_), "invalid input file format");
if (buffered_rowset_.size() != 0) {
LoadBinary(fi, &col_ptr_, &col_data_);
}
}
/*!
* \brief save data to binary stream
* \param fo output stream
* \param ptr pointer data
* \param data data content
*/
inline static void SaveBinary(utils::IStream &fo,
const std::vector<size_t> &ptr,
const std::vector<RowBatch::Entry> &data) {
size_t nrow = ptr.size() - 1;
fo.Write(&nrow, sizeof(size_t));
fo.Write(&ptr[0], ptr.size() * sizeof(size_t));
if (data.size() != 0) {
fo.Write(&data[0], data.size() * sizeof(RowBatch::Entry));
}
}
/*!
* \brief load data from binary stream
* \param fi input stream
* \param out_ptr pointer data
* \param out_data data content
*/
inline static void LoadBinary(utils::IStream &fi,
std::vector<size_t> *out_ptr,
std::vector<RowBatch::Entry> *out_data) {
size_t nrow;
utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format");
out_ptr->resize(nrow + 1);
utils::Check(fi.Read(&(*out_ptr)[0], out_ptr->size() * sizeof(size_t)) != 0,
"invalid input file format");
out_data->resize(out_ptr->back());
if (out_data->size() != 0) {
utils::Assert(fi.Read(&(*out_data)[0], out_data->size() * sizeof(RowBatch::Entry)) != 0,
"invalid input file format");
}
}
protected:
/*!
* \brief initialize column data
* \param pkeep probability to keep a row
*/
inline void InitColData(float pkeep) {
buffered_rowset_.clear();
// note: this part of the code is serial; TODO: parallelize this transformer
utils::SparseCSRMBuilder<RowBatch::Entry> builder(col_ptr_, col_data_);
builder.InitBudget(0);
// start working
iter_->BeforeFirst();
while (iter_->Next()) {
const RowBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
buffered_rowset_.push_back(static_cast<bst_uint>(batch.base_rowid+i));
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].index);
}
}
}
}
builder.InitStorage();
iter_->BeforeFirst();
size_t ktop = 0;
while (iter_->Next()) {
const RowBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (ktop < buffered_rowset_.size() &&
buffered_rowset_[ktop] == batch.base_rowid+i) {
++ktop;
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.PushElem(inst[j].index,
Entry((bst_uint)(batch.base_rowid+i),
inst[j].fvalue));
}
}
}
}
// sort columns
bst_omp_uint ncol = static_cast<bst_omp_uint>(this->NumCol());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ncol; ++i) {
std::sort(&col_data_[0] + col_ptr_[i],
&col_data_[0] + col_ptr_[i + 1], Entry::CmpValue);
}
}
private:
// --- data structure used to support InitColAccess --
utils::IIterator<RowBatch> *iter_;
/*! \brief list of row index that are buffered */
std::vector<bst_uint> buffered_rowset_;
/*! \brief column pointer of CSC format */
std::vector<size_t> col_ptr_;
/*! \brief column datas in CSC format */
std::vector<RowBatch::Entry> col_data_;
virtual bool HaveColAccess(void) const = 0;
/*! \return number of columns in the FMatrix */
virtual size_t NumCol(void) const = 0;
/*! \brief get number of non-missing entries in column */
virtual size_t GetColSize(size_t cidx) const = 0;
/*! \brief get column density */
virtual float GetColDensity(size_t cidx) const = 0;
/*! \brief reference to the buffered rowset */
virtual const std::vector<bst_uint> &buffered_rowset(void) const = 0;
// virtual destructor
virtual ~IFMatrix(void){}
};
} // namespace xgboost
#endif // XGBOOST_DATA_H
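For orientation, here is a minimal sketch of how a consumer is expected to drive the refactored interface; the fmat pointer and the fset contents are hypothetical, only the IFMatrix, RowBatch, and ColBatch names come from the header above:

// hedged usage sketch: fmat is assumed to be a valid IFMatrix* owned elsewhere
utils::IIterator<RowBatch> *row_iter = fmat->RowIterator();
row_iter->BeforeFirst();
while (row_iter->Next()) {
  const RowBatch &batch = row_iter->Value();
  for (size_t i = 0; i < batch.size; ++i) {
    RowBatch::Inst inst = batch[i];
    for (bst_uint j = 0; j < inst.length; ++j) {
      // inst[j].index is the feature id, inst[j].fvalue the feature value
    }
  }
}
// column access must be initialized once before any ColIterator call
if (!fmat->HaveColAccess()) fmat->InitColAccess(1.0f);
std::vector<bst_uint> fset(1, 0);  // request column 0 only
utils::IIterator<ColBatch> *col_iter = fmat->ColIterator(fset);
while (col_iter->Next()) {
  const ColBatch &cbatch = col_iter->Value();
  for (size_t i = 0; i < cbatch.size; ++i) {
    ColBatch::Inst col = cbatch[i];  // entries of column cbatch.col_index[i]
  }
}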

View File

@ -18,8 +18,7 @@ namespace gbm {
* \brief gradient boosted linear model
* \tparam FMatrix the data type the updater takes
*/
template<typename FMatrix>
class GBLinear : public IGradBooster<FMatrix> {
class GBLinear : public IGradBooster {
public:
virtual ~GBLinear(void) {
}
@ -41,13 +40,12 @@ class GBLinear : public IGradBooster<FMatrix> {
virtual void InitModel(void) {
model.InitModel();
}
virtual void DoBoost(const FMatrix &fmat,
virtual void DoBoost(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<bst_gpair> *in_gpair) {
this->InitFeatIndex(fmat);
std::vector<bst_gpair> &gpair = *in_gpair;
const int ngroup = model.param.num_output_group;
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// for all the output group
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
@ -72,42 +70,46 @@ class GBLinear : public IGradBooster<FMatrix> {
}
}
}
// number of features
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(feat_index.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const bst_uint fid = feat_index[i];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
const float v = it.fvalue();
bst_gpair &p = gpair[it.rindex() * ngroup + gid];
if (p.hess < 0.0f) continue;
sum_grad += p.grad * v;
sum_hess += p.hess * v * v;
}
float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate * param.CalcDelta(sum_grad, sum_hess, w));
w += dw;
// update grad value
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
bst_gpair &p = gpair[it.rindex() * ngroup + gid];
if (p.hess < 0.0f) continue;
p.grad += p.hess * it.fvalue() * dw;
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
while (iter->Next()) {
// number of features
const ColBatch &batch = iter->Value();
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const bst_uint fid = batch.col_index[i];
ColBatch::Inst col = batch[i];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {
const float v = col[j].fvalue;
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue;
sum_grad += p.grad * v;
sum_hess += p.hess * v * v;
}
float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate * param.CalcDelta(sum_grad, sum_hess, w));
w += dw;
// update grad value
for (bst_uint j = 0; j < col.length; ++j) {
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue;
p.grad += p.hess * col[j].fvalue * dw;
}
}
}
}
}
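For reference, the per-feature update computed above is a second-order (Newton) coordinate step. Under the assumption that param.CalcDelta applies a plain L2 penalty $\lambda$ (the actual parameter class may also add an L1 term and clip the step), the weight change for feature $f$ amounts to

$\Delta w_f = -\eta \,\frac{\sum_{i \in \mathrm{col}\, f} g_i x_{if} + \lambda w_f}{\sum_{i \in \mathrm{col}\, f} h_i x_{if}^2 + \lambda}$

where $g_i, h_i$ come from gpair, $x_{if}$ are the column's feature values, and $\eta$ is param.learning_rate; sum_grad and sum_hess in the loop are exactly the two sums.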
virtual void Predict(const FMatrix &fmat,
virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float> *out_preds) {
std::vector<float> &preds = *out_preds;
preds.resize(0);
// start collecting the prediction
utils::IIterator<RowBatch> *iter = fmat.RowIterator();
iter->BeforeFirst();
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
const int ngroup = model.param.num_output_group;
while (iter->Next()) {
const RowBatch &batch = iter->Value();
@ -134,18 +136,6 @@ class GBLinear : public IGradBooster<FMatrix> {
}
protected:
inline void InitFeatIndex(const FMatrix &fmat) {
if (feat_index.size() != 0) return;
// initialize feature index
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
feat_index.reserve(ncol);
for (unsigned i = 0; i < ncol; ++i) {
if (fmat.GetColSize(i) != 0) {
feat_index.push_back(i);
}
}
random::Shuffle(feat_index);
}
inline void Pred(const RowBatch::Inst &inst, float *preds) {
for (int gid = 0; gid < model.param.num_output_group; ++gid) {
float psum = model.bias()[gid];

View File

@ -7,6 +7,7 @@
*/
#include <vector>
#include "../data.h"
#include "../utils/io.h"
#include "../utils/fmap.h"
namespace xgboost {
@ -14,9 +15,7 @@ namespace xgboost {
namespace gbm {
/*!
* \brief interface of gradient boosting model
* \tparam FMatrix the data type the updater takes
*/
template<typename FMatrix>
class IGradBooster {
public:
/*!
@ -41,17 +40,17 @@ class IGradBooster {
virtual void InitModel(void) = 0;
/*!
* \brief perform an update to the model (boosting)
* \param fmat feature matrix that provides access to features
* \param p_fmat feature matrix that provides access to features
* \param info meta information about training
* \param in_gpair address of the gradient pair statistics of the data
* the booster may change content of gpair
*/
virtual void DoBoost(const FMatrix &fmat,
virtual void DoBoost(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<bst_gpair> *in_gpair) = 0;
/*!
* \brief generate predictions for given feature matrix
* \param fmat feature matrix
* \param p_fmat feature matrix
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeated prediction
@ -59,7 +58,7 @@ class IGradBooster {
* \param info extra side information that may be needed for prediction
* \param out_preds output vector to hold the predictions
*/
virtual void Predict(const FMatrix &fmat,
virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float> *out_preds) = 0;
@ -73,21 +72,11 @@ class IGradBooster {
// destructor
virtual ~IGradBooster(void){}
};
} // namespace gbm
} // namespace xgboost
#include "gbtree-inl.hpp"
#include "gblinear-inl.hpp"
namespace xgboost {
namespace gbm {
template<typename FMatrix>
inline IGradBooster<FMatrix>* CreateGradBooster(const char *name) {
if (!strcmp("gbtree", name)) return new GBTree<FMatrix>();
if (!strcmp("gblinear", name)) return new GBLinear<FMatrix>();
utils::Error("unknown booster type: %s", name);
return NULL;
}
/*!
* \brief create a gradient booster from a given name
* \param name name of gradient booster
*/
IGradBooster* CreateGradBooster(const char *name);
} // namespace gbm
} // namespace xgboost
#endif // XGBOOST_GBM_GBM_H_
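Since the factory is now a plain function declared here (presumably defined in a new gbm.cpp translation unit), a caller no longer names the matrix type; a minimal hedged sketch, where the parameter values are illustrative:

// hedged usage sketch of the relocated factory
gbm::IGradBooster *bst = gbm::CreateGradBooster("gbtree");
bst->SetParam("eta", "0.3");  // illustrative parameter
bst->InitModel();
// ... bst->DoBoost(p_fmat, info, &gpair) and bst->Predict(p_fmat, ...) ...
delete bst;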

View File

@ -9,16 +9,15 @@
#include <utility>
#include <string>
#include "./gbm.h"
#include "../utils/omp.h"
#include "../tree/updater.h"
namespace xgboost {
namespace gbm {
/*!
* \brief gradient boosted tree
* \tparam FMatrix the data type the updater takes
*/
template<typename FMatrix>
class GBTree : public IGradBooster<FMatrix> {
class GBTree : public IGradBooster {
public:
virtual ~GBTree(void) {
this->Clear();
@ -82,12 +81,12 @@ class GBTree : public IGradBooster<FMatrix> {
utils::Assert(mparam.num_trees == 0, "GBTree: model already initialized");
utils::Assert(trees.size() == 0, "GBTree: model already initialized");
}
virtual void DoBoost(const FMatrix &fmat,
virtual void DoBoost(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<bst_gpair> *in_gpair) {
const std::vector<bst_gpair> &gpair = *in_gpair;
if (mparam.num_output_group == 1) {
this->BoostNewTrees(gpair, fmat, info, 0);
this->BoostNewTrees(gpair, p_fmat, info, 0);
} else {
const int ngroup = mparam.num_output_group;
utils::Check(gpair.size() % ngroup == 0,
@ -99,11 +98,11 @@ class GBTree : public IGradBooster<FMatrix> {
for (bst_omp_uint i = 0; i < nsize; ++i) {
tmp[i] = gpair[i * ngroup + gid];
}
this->BoostNewTrees(tmp, fmat, info, gid);
this->BoostNewTrees(tmp, p_fmat, info, gid);
}
}
}
virtual void Predict(const FMatrix &fmat,
virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float> *out_preds) {
@ -121,7 +120,7 @@ class GBTree : public IGradBooster<FMatrix> {
const size_t stride = info.num_row * mparam.num_output_group;
preds.resize(stride * (mparam.size_leaf_vector+1));
// start collecting the prediction
utils::IIterator<RowBatch> *iter = fmat.RowIterator();
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
@ -172,7 +171,7 @@ class GBTree : public IGradBooster<FMatrix> {
char *pstr;
pstr = strtok(&tval[0], ",");
while (pstr != NULL) {
updaters.push_back(tree::CreateUpdater<FMatrix>(pstr));
updaters.push_back(tree::CreateUpdater(pstr));
for (size_t j = 0; j < cfg.size(); ++j) {
// set parameters
updaters.back()->SetParam(cfg[j].first.c_str(), cfg[j].second.c_str());
@ -183,7 +182,7 @@ class GBTree : public IGradBooster<FMatrix> {
}
// do boosting for a specific group
inline void BoostNewTrees(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
int bst_group) {
this->InitUpdater();
@ -198,7 +197,7 @@ class GBTree : public IGradBooster<FMatrix> {
}
// update the trees
for (size_t i = 0; i < updaters.size(); ++i) {
updaters[i]->Update(gpair, fmat, info, new_trees);
updaters[i]->Update(gpair, p_fmat, info, new_trees);
}
// push back to model
for (size_t i = 0; i < new_trees.size(); ++i) {
@ -361,7 +360,7 @@ class GBTree : public IGradBooster<FMatrix> {
// temporal storage for per thread
std::vector<tree::RegTree::FVec> thread_temp;
// the updaters that can be applied to each of tree
std::vector< tree::IUpdater<FMatrix>* > updaters;
std::vector<tree::IUpdater*> updaters;
};
} // namespace gbm

View File

@ -13,7 +13,7 @@ namespace xgboost {
/*! \brief namespace related to data format */
namespace io {
/*! \brief DMatrix object that I/O module support save/load */
typedef learner::DMatrix<FMatrixS> DataMatrix;
typedef learner::DMatrix DataMatrix;
/*!
* \brief load DataMatrix from stream
* \param fname file name to be loaded

View File

@ -229,7 +229,7 @@ class DMatrixSimple : public DataMatrix {
at_first_ = false;
batch_.size = parent_->row_ptr_.size() - 1;
batch_.base_rowid = 0;
batch_.row_ptr = &parent_->row_ptr_[0];
batch_.ind_ptr = &parent_->row_ptr_[0];
batch_.data_ptr = &parent_->row_data_[0];
return true;
}

View File

@ -8,7 +8,7 @@
*/
#include <vector>
#include "../data.h"
#include "../utils/io.h"
namespace xgboost {
namespace learner {
/*!
@ -142,7 +142,6 @@ struct MetaInfo {
* \brief data object used for learning,
* \tparam FMatrix type of feature data source
*/
template<typename FMatrix>
struct DMatrix {
/*!
* \brief magic number associated with this object
@ -152,7 +151,7 @@ struct DMatrix {
/*! \brief meta information about the dataset */
MetaInfo info;
/*! \brief feature matrix about data content */
FMatrix fmat;
IFMatrix *fmat;
/*!
* \brief cache pointer to verify if the data structure is cached in some learner
* used to verify if DMatrix is cached
@ -161,7 +160,9 @@ struct DMatrix {
/*! \brief default constructor */
explicit DMatrix(int magic) : magic(magic), cache_learner_ptr_(NULL) {}
// virtual destructor
virtual ~DMatrix(void){}
virtual ~DMatrix(void){
delete fmat;
}
};
} // namespace learner
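Note that DMatrix now owns its feature matrix through a raw pointer and frees it in the destructor, so each concrete DMatrix is expected to heap-allocate the IFMatrix it wraps; a hedged sketch, with MyFMatrix standing in for some IFMatrix implementation and the magic value chosen arbitrarily:

// hedged ownership sketch; MyDMatrix, MyFMatrix, and the magic number are hypothetical
struct MyDMatrix : public learner::DMatrix {
  MyDMatrix(void) : learner::DMatrix(0xabcd) {
    this->fmat = new MyFMatrix();  // deleted later by ~DMatrix
  }
};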

View File

@ -21,7 +21,6 @@ namespace learner {
* \brief learner that performs gradient boosting on specific objective functions
* and does training and prediction
*/
template<typename FMatrix>
class BoostLearner {
public:
BoostLearner(void) {
@ -44,7 +43,7 @@ class BoostLearner {
* data matrices to continue training, otherwise it will cause an error
* \param mats array of pointers to matrices whose prediction results need to be cached
*/
inline void SetCacheData(const std::vector<DMatrix<FMatrix>*>& mats) {
inline void SetCacheData(const std::vector<DMatrix*>& mats) {
// estimate feature bound
unsigned num_feature = 0;
// assign buffer index
@ -158,15 +157,15 @@ class BoostLearner {
* if not, initialize it
* \param p_train pointer to the matrix used by training
*/
inline void CheckInit(DMatrix<FMatrix> *p_train) {
p_train->fmat.InitColAccess(prob_buffer_row);
inline void CheckInit(DMatrix *p_train) {
p_train->fmat->InitColAccess(prob_buffer_row);
}
/*!
* \brief update the model for one iteration
* \param iter current iteration number
* \param train the training data matrix
*/
inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) {
inline void UpdateOneIter(int iter, const DMatrix &train) {
this->PredictRaw(train, &preds_);
obj_->GetGradient(preds_, train.info, iter, &gpair_);
gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
@ -179,7 +178,7 @@ class BoostLearner {
* \return a string corresponding to the evaluation result
*/
inline std::string EvalOneIter(int iter,
const std::vector<const DMatrix<FMatrix>*> &evals,
const std::vector<const DMatrix*> &evals,
const std::vector<std::string> &evname) {
std::string res;
char tmp[256];
@ -198,7 +197,7 @@ class BoostLearner {
* \param metric name of metric
* \return a pair of <evaluation name, result>
*/
std::pair<std::string, float> Evaluate(const DMatrix<FMatrix> &data, std::string metric) {
std::pair<std::string, float> Evaluate(const DMatrix &data, std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
IEvaluator *ev = CreateEvaluator(metric.c_str());
this->PredictRaw(data, &preds_);
@ -213,7 +212,7 @@ class BoostLearner {
* \param output_margin whether to only predict margin value instead of transformed prediction
* \param out_preds output vector that stores the prediction
*/
inline void Predict(const DMatrix<FMatrix> &data,
inline void Predict(const DMatrix &data,
bool output_margin,
std::vector<float> *out_preds) const {
this->PredictRaw(data, out_preds);
@ -235,7 +234,7 @@ class BoostLearner {
if (obj_ != NULL) return;
utils::Assert(gbm_ == NULL, "GBM and obj should be NULL");
obj_ = CreateObjFunction(name_obj_.c_str());
gbm_ = gbm::CreateGradBooster<FMatrix>(name_gbm_.c_str());
gbm_ = gbm::CreateGradBooster(name_gbm_.c_str());
for (size_t i = 0; i < cfg_.size(); ++i) {
obj_->SetParam(cfg_[i].first.c_str(), cfg_[i].second.c_str());
gbm_->SetParam(cfg_[i].first.c_str(), cfg_[i].second.c_str());
@ -247,7 +246,7 @@ class BoostLearner {
* \param data training data matrix
* \param out_preds output vector that stores the prediction
*/
inline void PredictRaw(const DMatrix<FMatrix> &data,
inline void PredictRaw(const DMatrix &data,
std::vector<float> *out_preds) const {
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
data.info.info, out_preds);
@ -307,7 +306,7 @@ class BoostLearner {
// model parameter
ModelParam mparam;
// gbm model that backs everything
gbm::IGradBooster<FMatrix> *gbm_;
gbm::IGradBooster *gbm_;
// name of gbm model used for training
std::string name_gbm_;
// objective function
@ -324,14 +323,14 @@ class BoostLearner {
private:
// cache entry object that helps handle feature caching
struct CacheEntry {
const DMatrix<FMatrix> *mat_;
const DMatrix *mat_;
size_t buffer_offset_;
size_t num_row_;
CacheEntry(const DMatrix<FMatrix> *mat, size_t buffer_offset, size_t num_row)
CacheEntry(const DMatrix *mat, size_t buffer_offset, size_t num_row)
:mat_(mat), buffer_offset_(buffer_offset), num_row_(num_row) {}
};
// find the internal buffer offset for a certain matrix; if it does not exist, return -1
inline int64_t FindBufferOffset(const DMatrix<FMatrix> &mat) const {
inline int64_t FindBufferOffset(const DMatrix &mat) const {
for (size_t i = 0; i < cache_.size(); ++i) {
if (cache_[i].mat_ == &mat && mat.cache_learner_ptr_ == this) {
if (cache_[i].num_row_ == mat.info.num_row()) {

View File

@ -14,9 +14,7 @@ namespace xgboost {
namespace tree {
/*!
* \brief interface of the tree update module, which performs updates to a tree
* \tparam FMatrix the data type the updater takes
*/
template<typename FMatrix>
class IUpdater {
public:
/*!
@ -28,7 +26,7 @@ class IUpdater {
/*!
* \brief perform an update to the tree models
* \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provides access to features
* \param p_fmat feature matrix that provides access to features
* \param info extra side information that may be needed, such as root index
* \param trees pointer to the trees to be updated; the updater will change the content of the trees
* note: all the trees in the vector are updated, with the same statistics,
@ -36,36 +34,18 @@ class IUpdater {
* there can be multiple trees when we train a random forest style model
*/
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) = 0;
// destructor
virtual ~IUpdater(void) {}
};
} // namespace tree
} // namespace xgboost
#include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp"
namespace xgboost {
namespace tree {
/*!
* \brief create an updater based on its name
* \param name name of the updater
* \return the created updater instance
*/
template<typename FMatrix>
inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
if (!strcmp(name, "refresh")) return new TreeRefresher<FMatrix, GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
utils::Error("unknown updater:%s", name);
return NULL;
}
IUpdater* CreateUpdater(const char *name);
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_H_
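The factory body that used to live here as a template presumably moves into the new src/tree/updater.cpp compiled by the Makefile above; reconstructed from the deleted template version, its definition likely looks close to this sketch:

// sketch of src/tree/updater.cpp, inferred from the removed template code
#include <cstring>
#include "./updater.h"
#include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp"

namespace xgboost {
namespace tree {
IUpdater* CreateUpdater(const char *name) {
  if (!strcmp(name, "prune")) return new TreePruner();
  if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
  if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
  utils::Error("unknown updater:%s", name);
  return NULL;
}
}  // namespace tree
}  // namespace xgboost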

View File

@ -15,8 +15,8 @@
namespace xgboost {
namespace tree {
/*! \brief column-wise tree grower that enumerates splits over sorted feature columns */
template<typename FMatrix, typename TStats>
class ColMaker: public IUpdater<FMatrix> {
template<typename TStats>
class ColMaker: public IUpdater {
public:
virtual ~ColMaker(void) {}
// set training parameter
@ -24,7 +24,7 @@ class ColMaker: public IUpdater<FMatrix> {
param.SetParam(name, val);
}
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
TStats::CheckInfo(info);
@ -34,7 +34,7 @@ class ColMaker: public IUpdater<FMatrix> {
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, fmat, info, trees[i]);
builder.Update(gpair, p_fmat, info, trees[i]);
}
param.learning_rate = lr;
}
@ -77,16 +77,16 @@ class ColMaker: public IUpdater<FMatrix> {
explicit Builder(const TrainParam &param) : param(param) {}
// update one tree, growing
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
RegTree *p_tree) {
this->InitData(gpair, fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, *p_fmat, info, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
this->ResetPosition(this->qexpand, fmat, *p_tree);
this->FindSplit(depth, this->qexpand, gpair, p_fmat, info, p_tree);
this->ResetPosition(this->qexpand, p_fmat, *p_tree);
this->UpdateQueueExpand(*p_tree, &this->qexpand);
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
this->InitNewNode(qexpand, gpair, *p_fmat, info, *p_tree);
// if nothing left to be expand, break
if (qexpand.size() == 0) break;
}
@ -107,7 +107,7 @@ class ColMaker: public IUpdater<FMatrix> {
private:
// initialize temp data structure
inline void InitData(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const IFMatrix &fmat,
const std::vector<unsigned> &root_index, const RegTree &tree) {
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
@ -137,8 +137,7 @@ class ColMaker: public IUpdater<FMatrix> {
if (random::SampleBinary(param.subsample) == 0) position[ridx] = -1;
}
}
}
}
{
// initialize feature index
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
@ -175,7 +174,7 @@ class ColMaker: public IUpdater<FMatrix> {
/*! \brief initialize the base_weight, root_gain, and NodeEntry for all the new nodes in qexpand */
inline void InitNewNode(const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const IFMatrix &fmat,
const BoosterInfo &info,
const RegTree &tree) {
{// setup statistics space for each tree node
@ -222,24 +221,25 @@ class ColMaker: public IUpdater<FMatrix> {
qexpand = newnodes;
}
// enumerate the split values of specific feature
template<typename Iter>
inline void EnumerateSplit(Iter it, unsigned fid,
inline void EnumerateSplit(const ColBatch::Entry *begin,
const ColBatch::Entry *end,
int d_step,
bst_uint fid,
const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
std::vector<ThreadEntry> &temp,
bool is_forward_search) {
std::vector<ThreadEntry> &temp) {
// clear all the temp statistics
for (size_t j = 0; j < qexpand.size(); ++j) {
temp[qexpand[j]].stats.Clear();
}
// left statistics
TStats c(param);
while (it.Next()) {
const bst_uint ridx = it.rindex();
for (const ColBatch::Entry *it = begin; it != end; it += d_step) {
const bst_uint ridx = it->index;
const int nid = position[ridx];
if (nid < 0) continue;
// start working
const float fvalue = it.fvalue();
const float fvalue = it->fvalue;
// get the statistics of nid
ThreadEntry &e = temp[nid];
// test if first hit, this is fine, because we set 0 during init
@ -252,7 +252,7 @@ class ColMaker: public IUpdater<FMatrix> {
c.SetSubstract(snode[nid].stats, e.stats);
if (c.sum_hess >= param.min_child_weight) {
bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
}
}
// update the statistics
@ -267,38 +267,46 @@ class ColMaker: public IUpdater<FMatrix> {
c.SetSubstract(snode[nid].stats, e.stats);
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
const float delta = is_forward_search ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
const float delta = d_step == +1 ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
}
}
}
// find splits at current level, do split per level
inline void FindSplit(int depth, const std::vector<int> &qexpand,
inline void FindSplit(int depth,
const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
RegTree *p_tree) {
std::vector<unsigned> feat_set = feat_index;
std::vector<bst_uint> feat_set = feat_index;
if (param.colsample_bylevel != 1.0f) {
random::Shuffle(feat_set);
unsigned n = static_cast<unsigned>(param.colsample_bylevel * feat_index.size());
utils::Check(n > 0, "colsample_bylevel is too small, no feature can be included");
feat_set.resize(n);
}
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(feat_set.size());
#if defined(_OPENMP)
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
#endif
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const unsigned fid = feat_set[i];
const int tid = omp_get_thread_num();
if (param.need_forward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true);
}
if (param.need_backward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false);
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(feat_set);
while (iter->Next()) {
const ColBatch &batch = iter->Value();
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#if defined(_OPENMP)
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
#endif
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const bst_uint fid = batch.col_index[i];
const int tid = omp_get_thread_num();
const ColBatch::Inst c = batch[i];
if (param.need_forward_search(p_fmat->GetColDensity(fid))) {
this->EnumerateSplit(c.data, c.data + c.length, +1,
fid, gpair, info, stemp[tid]);
}
if (param.need_backward_search(p_fmat->GetColDensity(fid))) {
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
fid, gpair, info, stemp[tid]);
}
}
}
// after this, each thread's stemp will hold the best candidates; aggregate the results
@ -318,8 +326,8 @@ class ColMaker: public IUpdater<FMatrix> {
}
}
// reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, const FMatrix &fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// step 1, set default direct nodes to default, and leaf nodes to -1
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static)
@ -343,22 +351,28 @@ class ColMaker: public IUpdater<FMatrix> {
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
// start putting things into the right place
const bst_omp_uint nfeats = static_cast<bst_omp_uint>(fsplits.size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nfeats; ++i) {
const unsigned fid = fsplits[i];
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
const bst_uint ridx = it.rindex();
int nid = position[ridx];
if (nid == -1) continue;
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {
if (it.fvalue() < tree[nid].split_cond()) {
position[ridx] = tree[nid].cleft();
} else {
position[ridx] = tree[nid].cright();
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fsplits);
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
int nid = position[ridx];
if (nid == -1) continue;
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {
if (fvalue < tree[nid].split_cond()) {
position[ridx] = tree[nid].cleft();
} else {
position[ridx] = tree[nid].cright();
}
}
}
}
@ -369,7 +383,7 @@ class ColMaker: public IUpdater<FMatrix> {
// number of omp threads used during training
int nthread;
// Per feature: shuffled list of feature indices
std::vector<unsigned> feat_index;
std::vector<bst_uint> feat_index;
// Instance Data: current node position in the tree of each instance
std::vector<int> position;
// PerThread x PerTreeNode: statistics for per thread construction

View File

@ -12,8 +12,7 @@
namespace xgboost {
namespace tree {
/*! \brief pruner that prunes a tree after growing finishes */
template<typename FMatrix>
class TreePruner: public IUpdater<FMatrix> {
class TreePruner: public IUpdater {
public:
virtual ~TreePruner(void) {}
// set training parameter
@ -23,7 +22,7 @@ class TreePruner: public IUpdater<FMatrix> {
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees
@ -75,7 +74,6 @@ class TreePruner: public IUpdater<FMatrix> {
// training parameter
TrainParam param;
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_

View File

@ -9,12 +9,13 @@
#include <limits>
#include "./param.h"
#include "./updater.h"
#include "../utils/omp.h"
namespace xgboost {
namespace tree {
/*! \brief updater that refreshes the statistics and leaf values of a tree */
template<typename FMatrix, typename TStats>
class TreeRefresher: public IUpdater<FMatrix> {
template<typename TStats>
class TreeRefresher: public IUpdater {
public:
virtual ~TreeRefresher(void) {}
// set training parameter
@ -23,7 +24,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return;
@ -50,7 +51,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
fvec_temp[tid].Init(trees[0]->param.num_feature);
}
// start accumulating statistics
utils::IIterator<RowBatch> *iter = fmat.RowIterator();
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();

View File

@ -8,6 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <cstdarg>
#include <string>
#include <cstdlib>
#ifdef _MSC_VER
#define fopen64 fopen

View File

@ -234,7 +234,7 @@ class BoostLearnTask{
std::vector<io::DataMatrix*> deval;
std::vector<const io::DataMatrix*> devalall;
utils::FeatMap fmap;
learner::BoostLearner<FMatrixS> learner;
learner::BoostLearner learner;
};
}