Merge pull request #51 from tqchen/unity

merge unity into master, R package ready
This commit is contained in:
Tianqi Chen 2014-08-27 21:13:38 -07:00
commit 582e4e3d8c
46 changed files with 865 additions and 1408 deletions

View File

@ -1,32 +1,32 @@
export CC = gcc
export CXX = g++
export LDFLAGS= -pthread -lm
# note for R module
# add include path to Rinternals.h here
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC
ifeq ($(no_omp),1)
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -DDISABLE_OPENMP
else
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp
CFLAGS += -DDISABLE_OPENMP
else
CFLAGS += -fopenmp
endif
# expose these flags to R CMD SHLIB
export PKG_CPPFLAGS = $(CFLAGS) -DXGBOOST_CUSTOMIZE_ERROR_
# specify tensor path
BIN = xgboost
OBJ =
OBJ = updater.o gbm.o io.o
SLIB = wrapper/libxgboostwrapper.so
RLIB = wrapper/libxgboostR.so
.PHONY: clean all R
all: $(BIN) wrapper/libxgboostwrapper.so
R: wrapper/libxgboostR.so
.PHONY: clean all python
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
all: $(BIN) $(OBJ) $(SLIB)
python: wrapper/libxgboostwrapper.so
# now the wrapper takes in two files. io and wrapper part
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
wrapper/libxgboostR.so: wrapper/xgboost_wrapper.cpp wrapper/xgboost_R.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp $(OBJ)
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
xgboost: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ)
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ)
$(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
@ -34,9 +34,6 @@ $(BIN) :
$(SLIB) :
$(CXX) $(CFLAGS) -fPIC $(LDFLAGS) -shared -o $@ $(filter %.cpp %.o %.c, $^)
$(RLIB) :
R CMD SHLIB -c -o $@ $(filter %.cpp %.o %.c, $^)
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
@ -44,4 +41,4 @@ install:
cp -f -r $(BIN) $(INSTALL_PATH)
clean:
$(RM) $(OBJ) $(BIN) $(SLIB) $(RLIB) *~ */*~ */*/*~
$(RM) $(OBJ) $(BIN) $(SLIB) *.o *~ */*~ */*/*~

View File

@ -8,4 +8,5 @@ export(xgb.train)
export(xgb.save)
export(xgb.load)
export(xgb.dump)
export(xgb.Booster)
export(xgb.DMatrix.save)

View File

@ -1,5 +1,4 @@
# Main function for xgboost-package
xgboost <- function(data = NULL, label = NULL, params = list(), nrounds = 10,
verbose = 1, ...) {
inClass <- class(data)

View File

@ -93,20 +93,22 @@ print(paste("error=", err))
############################ Save and load model to hard disk
# save model to binary local file
xgb.save(bst, "model.save")
xgb.save(bst, "xgboost.model")
# load binary model to R
bst <- xgb.load("model.save")
bst <- xgb.load("xgboost.model")
pred <- predict(bst, test.x)
# save model to text file
xgb.dump(bst, "model.dump")
xgb.dump(bst, "dump.raw.txt")
# save model to text file, with feature map
xgb.dump(bst, "dump.nice.txt", "featmap.txt")
# save a DMatrix object to hard disk
xgb.DMatrix.save(dtrain, "dtrain.save")
xgb.DMatrix.save(dtrain, "dtrain.buffer")
# load a DMatrix object to R
dtrain <- xgb.DMatrix("dtrain.save")
dtrain <- xgb.DMatrix("dtrain.buffer")
############################ More flexible training function xgb.train

View File

@ -10,7 +10,7 @@ ifeq ($(no_omp),1)
PKG_CPPFLAGS += -DDISABLE_OPENMP
endif
CXXOBJ= xgboost_wrapper.o xgboost_io.o
CXXOBJ= xgboost_wrapper.o xgboost_io.o xgboost_gbm.o xgboost_updater.o
OBJECTS= xgboost_R.o $(CXXOBJ)
.PHONY: all clean
@ -18,7 +18,9 @@ all: $(SHLIB)
$(SHLIB): $(OBJECTS)
xgboost_wrapper.o: ../../wrapper/xgboost_wrapper.cpp
xgboost_io.o: ../../src/io/io.cpp
xgboost_io.o: ../../src/io/io.cpp
xgboost_gbm.o: ../../src/gbm/gbm.cpp
xgboost_updater.o: ../../src/tree/updater.cpp
$(CXXOBJ) :
$(CXX) -c $(PKG_CPPFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )

View File

@ -15,7 +15,7 @@ ifeq ($(no_omp),1)
PKG_CPPFLAGS += -DDISABLE_OPENMP
endif
CXXOBJ= xgboost_wrapper.o xgboost_io.o
CXXOBJ= xgboost_wrapper.o xgboost_io.o xgboost_gbm.o xgboost_updater.o
OBJECTS= xgboost_R.o $(CXXOBJ)
.PHONY: all clean
@ -24,6 +24,8 @@ $(SHLIB): $(OBJECTS)
xgboost_wrapper.o: ../../wrapper/xgboost_wrapper.cpp
xgboost_io.o: ../../src/io/io.cpp
xgboost_gbm.o: ../../src/gbm/gbm.cpp
xgboost_updater.o: ../../src/tree/updater.cpp
$(CXXOBJ) :
$(CXX) -c $(PKG_CPPFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )

View File

@ -1,5 +1,6 @@
# include xgboost library, must set chdir=TRURE
source("../../wrapper/xgboost.R", chdir=TRUE)
# install xgboost package, see R-package in root folder
require(xgboost)
require(methods)
modelfile <- "higgs.model"
outfile <- "higgs.pred.csv"

View File

@ -1,5 +1,7 @@
# include xgboost library, must set chdir=TRURE
source("../../wrapper/xgboost.R", chdir=TRUE)
# install xgboost package, see R-package in root folder
require(xgboost)
require(methods)
testsize <- 550000
dtrain <- read.csv("data/training.csv", header=TRUE)
@ -12,7 +14,7 @@ sumwpos <- sum(weight * (label==1.0))
sumwneg <- sum(weight * (label==0.0))
print(paste("weight statistics: wpos=", sumwpos, "wneg=", sumwneg, "ratio=", sumwneg / sumwpos))
xgmat <- xgb.DMatrix(data, info = list(label=label, weight=weight), missing = -999.0)
xgmat <- xgb.DMatrix(data, label = label, weight = weight, missing = -999.0)
param <- list("objective" = "binary:logitraw",
"scale_pos_weight" = sumwneg / sumwpos,
"bst:eta" = 0.1,

View File

@ -13,10 +13,10 @@ Project Logical Layout
File Naming Convention
=======
* The project is templatized, to make it easy to adjust input data structure.
* .h files are data structures and interface, which are needed to use functions in that layer.
* -inl.hpp files are implementations of interface, like cpp file in most project.
- You only need to understand the interface file to understand the usage of that layer
* In each folder, there can be a .cpp file, that compiles the module of that layer
How to Hack the Code
======

View File

@ -7,16 +7,8 @@
*/
#include <cstdio>
#include <vector>
#include <limits>
#include <climits>
#include <cstring>
#include <algorithm>
#include "utils/io.h"
#include "utils/omp.h"
#include "utils/utils.h"
#include "utils/iterator.h"
#include "utils/random.h"
#include "utils/matrix_csr.h"
namespace xgboost {
/*!
@ -70,12 +62,12 @@ struct SparseBatch {
/*! \brief an entry of sparse vector */
struct Entry {
/*! \brief feature index */
bst_uint findex;
bst_uint index;
/*! \brief feature value */
bst_float fvalue;
// default constructor
Entry(void) {}
Entry(bst_uint findex, bst_float fvalue) : findex(findex), fvalue(fvalue) {}
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
/*! \brief reversely compare feature values */
inline static bool CmpValue(const Entry &a, const Entry &b) {
return a.fvalue < b.fvalue;
@ -86,7 +78,7 @@ struct SparseBatch {
/*! \brief pointer to the elements*/
const Entry *data;
/*! \brief length of the instance */
const bst_uint length;
bst_uint length;
/*! \brief constructor */
Inst(const Entry *data, bst_uint length) : data(data), length(length) {}
/*! \brief get i-th pair in the sparse vector*/
@ -96,298 +88,72 @@ struct SparseBatch {
};
/*! \brief batch size */
size_t size;
};
/*! \brief read-only row batch, used to access row continuously */
struct RowBatch : public SparseBatch {
/*! \brief the offset of rowid of this batch */
size_t base_rowid;
/*! \brief array[size+1], row pointer of each of the elements */
const size_t *row_ptr;
/*! \brief array[row_ptr.back()], content of the sparse element */
const size_t *ind_ptr;
/*! \brief array[ind_ptr.back()], content of the sparse element */
const Entry *data_ptr;
/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return Inst(data_ptr + row_ptr[i], static_cast<bst_uint>(row_ptr[i+1] - row_ptr[i]));
return Inst(data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i+1] - ind_ptr[i]));
}
};
/**
* \brief This is a interface convention via template, defining the way to access features,
* column access rule is defined by template, for efficiency purpose,
* row access is defined by iterator of sparse batches
* \tparam Derived type of actual implementation
/*!
* \brief read-only column batch, used to access columns,
* the columns are not required to be continuous
*/
template<typename Derived>
class FMatrixInterface {
struct ColBatch : public SparseBatch {
/*! \brief column index of each columns in the data */
const bst_uint *col_index;
/*! \brief pointer to the column data */
const Inst *col_data;
/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return col_data[i];
}
};
/**
* \brief interface of feature matrix, needed for tree construction
* this interface defines two way to access features,
* row access is defined by iterator of RowBatch
* col access is optional, checked by HaveColAccess, and defined by iterator of ColBatch
*/
class IFMatrix {
public:
/*! \brief example iterator over one column */
struct ColIter{
/*!
* \brief move to next position
* \return whether there is element in next position
*/
inline bool Next(void);
/*! \return row index of current position */
inline bst_uint rindex(void) const;
/*! \return feature value in current position */
inline bst_float fvalue(void) const;
};
/*! \brief backward iterator over column */
struct ColBackIter : public ColIter {};
public:
// column access is needed by some of tree construction algorithms
// the interface only need to ganrantee row iter
// column iter is active, when ColIterator is called, row_iter can be disabled
/*! \brief get the row iterator associated with FMatrix */
virtual utils::IIterator<RowBatch> *RowIterator(void) = 0;
/*!\brief get column iterator */
virtual utils::IIterator<ColBatch> *ColIterator(void) = 0;
/*!
* \brief get column iterator, the columns must be sorted by feature value
* \param cidx column index
* \return column iterator
* \brief get the column iterator associated with FMatrix with subset of column features
* \param fset is the list of column index set that must be contained in the returning Column iterator
* \return the column iterator, initialized so that it reads the elements in fset
*/
inline ColIter GetSortedCol(size_t cidx) const;
/*!
* \brief get column backward iterator, starts from biggest fvalue, and iterator back
* \param cidx column index
* \return reverse column iterator
*/
inline ColBackIter GetReverseSortedCol(size_t cidx) const;
/*!
* \brief get number of columns
* \return number of columns
*/
inline size_t NumCol(void) const;
virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) = 0;
/*!
* \brief check if column access is supported, if not, initialize column access
* \param max_rows maximum number of rows allowed in constructor
* \param subsample subsample ratio when generating column access
*/
inline void InitColAccess(void);
virtual void InitColAccess(float subsample) = 0;
// the following are column meta data, should be able to answer them fast
/*! \return whether column access is enabled */
inline bool HaveColAccess(void) const;
/*! \breif return #entries-in-col */
inline size_t GetColSize(size_t cidx) const;
/*!
* \breif return #entries-in-col / #rows
* \param cidx column index
* this function is used to help speedup,
* doese not necessarily implement it if not sure, return 0.0;
* \return column density
*/
inline float GetColDensity(size_t cidx) const;
/*! \brief get the row iterator associated with FMatrix */
inline utils::IIterator<SparseBatch>* RowIterator(void) const;
};
/*!
* \brief sparse matrix that support column access, CSC
*/
class FMatrixS : public FMatrixInterface<FMatrixS>{
public:
typedef SparseBatch::Entry Entry;
/*! \brief row iterator */
struct ColIter{
const Entry *dptr_, *end_;
ColIter(const Entry* begin, const Entry* end)
:dptr_(begin), end_(end) {}
inline bool Next(void) {
if (dptr_ == end_) {
return false;
} else {
++dptr_; return true;
}
}
inline bst_uint rindex(void) const {
return dptr_->findex;
}
inline bst_float fvalue(void) const {
return dptr_->fvalue;
}
};
/*! \brief reverse column iterator */
struct ColBackIter : public ColIter {
ColBackIter(const Entry* dptr, const Entry* end) : ColIter(dptr, end) {}
// shadows ColIter::Next
inline bool Next(void) {
if (dptr_ == end_) {
return false;
} else {
--dptr_; return true;
}
}
};
/*! \brief constructor */
FMatrixS(void) {
iter_ = NULL;
}
// destructor
~FMatrixS(void) {
if (iter_ != NULL) delete iter_;
}
/*! \return whether column access is enabled */
inline bool HaveColAccess(void) const {
return col_ptr_.size() != 0;
}
/*! \brief get number of colmuns */
inline size_t NumCol(void) const {
utils::Check(this->HaveColAccess(), "NumCol:need column access");
return col_ptr_.size() - 1;
}
/*! \brief get number of buffered rows */
inline const std::vector<bst_uint> buffered_rowset(void) const {
return buffered_rowset_;
}
/*! \brief get col sorted iterator */
inline ColIter GetSortedCol(size_t cidx) const {
utils::Assert(cidx < this->NumCol(), "col id exceed bound");
return ColIter(&col_data_[0] + col_ptr_[cidx] - 1,
&col_data_[0] + col_ptr_[cidx + 1] - 1);
}
/*!
* \brief get reversed col iterator,
* this function will be deprecated at some point
*/
inline ColBackIter GetReverseSortedCol(size_t cidx) const {
utils::Assert(cidx < this->NumCol(), "col id exceed bound");
return ColBackIter(&col_data_[0] + col_ptr_[cidx + 1],
&col_data_[0] + col_ptr_[cidx]);
}
/*! \brief get col size */
inline size_t GetColSize(size_t cidx) const {
return col_ptr_[cidx+1] - col_ptr_[cidx];
}
/*! \brief get column density */
inline float GetColDensity(size_t cidx) const {
size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
}
inline void InitColAccess(float pkeep = 1.0f) {
if (this->HaveColAccess()) return;
this->InitColData(pkeep);
}
/*!
* \brief get the row iterator associated with FMatrix
* this function is not threadsafe, returns iterator stored in FMatrixS
*/
inline utils::IIterator<SparseBatch>* RowIterator(void) const {
iter_->BeforeFirst();
return iter_;
}
/*! \brief set iterator */
inline void set_iter(utils::IIterator<SparseBatch> *iter) {
this->iter_ = iter;
}
/*!
* \brief save column access data into stream
* \param fo output stream to save to
*/
inline void SaveColAccess(utils::IStream &fo) const {
fo.Write(buffered_rowset_);
if (buffered_rowset_.size() != 0) {
SaveBinary(fo, col_ptr_, col_data_);
}
}
/*!
* \brief load column access data from stream
* \param fo output stream to load from
*/
inline void LoadColAccess(utils::IStream &fi) {
utils::Check(fi.Read(&buffered_rowset_), "invalid input file format");
if (buffered_rowset_.size() != 0) {
LoadBinary(fi, &col_ptr_, &col_data_);
}
}
/*!
* \brief save data to binary stream
* \param fo output stream
* \param ptr pointer data
* \param data data content
*/
inline static void SaveBinary(utils::IStream &fo,
const std::vector<size_t> &ptr,
const std::vector<SparseBatch::Entry> &data) {
size_t nrow = ptr.size() - 1;
fo.Write(&nrow, sizeof(size_t));
fo.Write(&ptr[0], ptr.size() * sizeof(size_t));
if (data.size() != 0) {
fo.Write(&data[0], data.size() * sizeof(SparseBatch::Entry));
}
}
/*!
* \brief load data from binary stream
* \param fi input stream
* \param out_ptr pointer data
* \param out_data data content
*/
inline static void LoadBinary(utils::IStream &fi,
std::vector<size_t> *out_ptr,
std::vector<SparseBatch::Entry> *out_data) {
size_t nrow;
utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format");
out_ptr->resize(nrow + 1);
utils::Check(fi.Read(&(*out_ptr)[0], out_ptr->size() * sizeof(size_t)) != 0,
"invalid input file format");
out_data->resize(out_ptr->back());
if (out_data->size() != 0) {
utils::Assert(fi.Read(&(*out_data)[0], out_data->size() * sizeof(SparseBatch::Entry)) != 0,
"invalid input file format");
}
}
protected:
/*!
* \brief intialize column data
* \param pkeep probability to keep a row
*/
inline void InitColData(float pkeep) {
buffered_rowset_.clear();
// note: this part of code is serial, todo, parallelize this transformer
utils::SparseCSRMBuilder<SparseBatch::Entry> builder(col_ptr_, col_data_);
builder.InitBudget(0);
// start working
iter_->BeforeFirst();
while (iter_->Next()) {
const SparseBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
buffered_rowset_.push_back(static_cast<bst_uint>(batch.base_rowid+i));
SparseBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].findex);
}
}
}
}
builder.InitStorage();
iter_->BeforeFirst();
size_t ktop = 0;
while (iter_->Next()) {
const SparseBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (ktop < buffered_rowset_.size() &&
buffered_rowset_[ktop] == batch.base_rowid+i) {
++ktop;
SparseBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.PushElem(inst[j].findex,
Entry((bst_uint)(batch.base_rowid+i),
inst[j].fvalue));
}
}
}
}
// sort columns
bst_omp_uint ncol = static_cast<bst_omp_uint>(this->NumCol());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ncol; ++i) {
std::sort(&col_data_[0] + col_ptr_[i],
&col_data_[0] + col_ptr_[i + 1], Entry::CmpValue);
}
}
private:
// --- data structure used to support InitColAccess --
utils::IIterator<SparseBatch> *iter_;
/*! \brief list of row index that are buffered */
std::vector<bst_uint> buffered_rowset_;
/*! \brief column pointer of CSC format */
std::vector<size_t> col_ptr_;
/*! \brief column datas in CSC format */
std::vector<SparseBatch::Entry> col_data_;
virtual bool HaveColAccess(void) const = 0;
/*! \return number of columns in the FMatrix */
virtual size_t NumCol(void) const = 0;
/*! \brief get number of non-missing entries in column */
virtual size_t GetColSize(size_t cidx) const = 0;
/*! \brief get column density */
virtual float GetColDensity(size_t cidx) const = 0;
/*! \brief reference of buffered rowset */
virtual const std::vector<bst_uint> &buffered_rowset(void) const = 0;
// virtual destructor
virtual ~IFMatrix(void){}
};
} // namespace xgboost
#endif // XGBOOST_DATA_H

View File

@ -18,8 +18,7 @@ namespace gbm {
* \brief gradient boosted linear model
* \tparam FMatrix the data type updater taking
*/
template<typename FMatrix>
class GBLinear : public IGradBooster<FMatrix> {
class GBLinear : public IGradBooster {
public:
virtual ~GBLinear(void) {
}
@ -41,13 +40,12 @@ class GBLinear : public IGradBooster<FMatrix> {
virtual void InitModel(void) {
model.InitModel();
}
virtual void DoBoost(const FMatrix &fmat,
virtual void DoBoost(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<bst_gpair> *in_gpair) {
this->InitFeatIndex(fmat);
std::vector<bst_gpair> &gpair = *in_gpair;
const int ngroup = model.param.num_output_group;
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// for all the output group
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
@ -72,45 +70,49 @@ class GBLinear : public IGradBooster<FMatrix> {
}
}
}
// number of features
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(feat_index.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const bst_uint fid = feat_index[i];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
const float v = it.fvalue();
bst_gpair &p = gpair[it.rindex() * ngroup + gid];
if (p.hess < 0.0f) continue;
sum_grad += p.grad * v;
sum_hess += p.hess * v * v;
}
float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate * param.CalcDelta(sum_grad, sum_hess, w));
w += dw;
// update grad value
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
bst_gpair &p = gpair[it.rindex() * ngroup + gid];
if (p.hess < 0.0f) continue;
p.grad += p.hess * it.fvalue() * dw;
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
while (iter->Next()) {
// number of features
const ColBatch &batch = iter->Value();
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const bst_uint fid = batch.col_index[i];
ColBatch::Inst col = batch[i];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {
const float v = col[j].fvalue;
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue;
sum_grad += p.grad * v;
sum_hess += p.hess * v * v;
}
float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate * param.CalcDelta(sum_grad, sum_hess, w));
w += dw;
// update grad value
for (bst_uint j = 0; j < col.length; ++j) {
bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue;
p.grad += p.hess * col[j].fvalue * dw;
}
}
}
}
}
virtual void Predict(const FMatrix &fmat,
virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float> *out_preds) {
std::vector<float> &preds = *out_preds;
preds.resize(0);
// start collecting the prediction
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
iter->BeforeFirst();
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
const int ngroup = model.param.num_output_group;
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
const RowBatch &batch = iter->Value();
utils::Assert(batch.base_rowid * ngroup == preds.size(),
"base_rowid is not set correctly");
// output convention: nrow * k, where nrow is number of rows
@ -134,23 +136,11 @@ class GBLinear : public IGradBooster<FMatrix> {
}
protected:
inline void InitFeatIndex(const FMatrix &fmat) {
if (feat_index.size() != 0) return;
// initialize feature index
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
feat_index.reserve(ncol);
for (unsigned i = 0; i < ncol; ++i) {
if (fmat.GetColSize(i) != 0) {
feat_index.push_back(i);
}
}
random::Shuffle(feat_index);
}
inline void Pred(const SparseBatch::Inst &inst, float *preds) {
inline void Pred(const RowBatch::Inst &inst, float *preds) {
for (int gid = 0; gid < model.param.num_output_group; ++gid) {
float psum = model.bias()[gid];
for (bst_uint i = 0; i < inst.length; ++i) {
psum += inst[i].fvalue * model[inst[i].findex][gid];
psum += inst[i].fvalue * model[inst[i].index][gid];
}
preds[gid] = psum;
}

18
src/gbm/gbm.cpp Normal file
View File

@ -0,0 +1,18 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#include <cstring>
#include "./gbm.h"
#include "./gbtree-inl.hpp"
#include "./gblinear-inl.hpp"
namespace xgboost {
namespace gbm {
IGradBooster* CreateGradBooster(const char *name) {
if (!strcmp("gbtree", name)) return new GBTree();
if (!strcmp("gblinear", name)) return new GBLinear();
utils::Error("unknown booster type: %s", name);
return NULL;
}
} // namespace gbm
} // namespace xgboost

View File

@ -7,6 +7,7 @@
*/
#include <vector>
#include "../data.h"
#include "../utils/io.h"
#include "../utils/fmap.h"
namespace xgboost {
@ -14,9 +15,7 @@ namespace xgboost {
namespace gbm {
/*!
* \brief interface of gradient boosting model
* \tparam FMatrix the data type updater taking
*/
template<typename FMatrix>
class IGradBooster {
public:
/*!
@ -41,17 +40,17 @@ class IGradBooster {
virtual void InitModel(void) = 0;
/*!
* \brief peform update to the model(boosting)
* \param fmat feature matrix that provide access to features
* \param p_fmat feature matrix that provide access to features
* \param info meta information about training
* \param in_gpair address of the gradient pair statistics of the data
* the booster may change content of gpair
*/
virtual void DoBoost(const FMatrix &fmat,
virtual void DoBoost(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<bst_gpair> *in_gpair) = 0;
/*!
* \brief generate predictions for given feature matrix
* \param fmat feature matrix
* \param p_fmat feature matrix
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeative prediction
@ -59,7 +58,7 @@ class IGradBooster {
* \param info extra side information that may be needed for prediction
* \param out_preds output vector to hold the predictions
*/
virtual void Predict(const FMatrix &fmat,
virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float> *out_preds) = 0;
@ -73,21 +72,11 @@ class IGradBooster {
// destrcutor
virtual ~IGradBooster(void){}
};
} // namespace gbm
} // namespace xgboost
#include "gbtree-inl.hpp"
#include "gblinear-inl.hpp"
namespace xgboost {
namespace gbm {
template<typename FMatrix>
inline IGradBooster<FMatrix>* CreateGradBooster(const char *name) {
if (!strcmp("gbtree", name)) return new GBTree<FMatrix>();
if (!strcmp("gblinear", name)) return new GBLinear<FMatrix>();
utils::Error("unknown booster type: %s", name);
return NULL;
}
/*!
* \breif create a gradient booster from given name
* \param name name of gradient booster
*/
IGradBooster* CreateGradBooster(const char *name);
} // namespace gbm
} // namespace xgboost
#endif // XGBOOST_GBM_GBM_H_

View File

@ -9,16 +9,15 @@
#include <utility>
#include <string>
#include "./gbm.h"
#include "../utils/omp.h"
#include "../tree/updater.h"
namespace xgboost {
namespace gbm {
/*!
* \brief gradient boosted tree
* \tparam FMatrix the data type updater taking
*/
template<typename FMatrix>
class GBTree : public IGradBooster<FMatrix> {
class GBTree : public IGradBooster {
public:
virtual ~GBTree(void) {
this->Clear();
@ -82,12 +81,12 @@ class GBTree : public IGradBooster<FMatrix> {
utils::Assert(mparam.num_trees == 0, "GBTree: model already initialized");
utils::Assert(trees.size() == 0, "GBTree: model already initialized");
}
virtual void DoBoost(const FMatrix &fmat,
virtual void DoBoost(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<bst_gpair> *in_gpair) {
const std::vector<bst_gpair> &gpair = *in_gpair;
if (mparam.num_output_group == 1) {
this->BoostNewTrees(gpair, fmat, info, 0);
this->BoostNewTrees(gpair, p_fmat, info, 0);
} else {
const int ngroup = mparam.num_output_group;
utils::Check(gpair.size() % ngroup == 0,
@ -99,11 +98,11 @@ class GBTree : public IGradBooster<FMatrix> {
for (bst_omp_uint i = 0; i < nsize; ++i) {
tmp[i] = gpair[i * ngroup + gid];
}
this->BoostNewTrees(tmp, fmat, info, gid);
this->BoostNewTrees(tmp, p_fmat, info, gid);
}
}
}
virtual void Predict(const FMatrix &fmat,
virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float> *out_preds) {
@ -118,17 +117,13 @@ class GBTree : public IGradBooster<FMatrix> {
}
std::vector<float> &preds = *out_preds;
preds.resize(0);
const size_t stride = info.num_row * mparam.num_output_group;
preds.resize(stride * (mparam.size_leaf_vector+1));
// start collecting the prediction
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
utils::Assert(batch.base_rowid * mparam.num_output_group == preds.size(),
"base_rowid is not set correctly");
// output convention: nrow * k, where nrow is number of rows
// k is number of group
preds.resize(preds.size() + batch.size * mparam.num_output_group);
const RowBatch &batch = iter->Value();
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
@ -136,13 +131,13 @@ class GBTree : public IGradBooster<FMatrix> {
const int tid = omp_get_thread_num();
tree::RegTree::FVec &feats = thread_temp[tid];
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
const unsigned root_idx = info.GetRoot(ridx);
utils::Assert(static_cast<size_t>(ridx) < info.num_row, "data row index exceed bound");
// loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
preds[ridx * mparam.num_output_group + gid] =
this->Pred(batch[i],
buffer_offset < 0 ? -1 : buffer_offset+ridx,
gid, root_idx, &feats);
this->Pred(batch[i],
buffer_offset < 0 ? -1 : buffer_offset + ridx,
gid, info.GetRoot(ridx), &feats,
&preds[ridx * mparam.num_output_group + gid], stride);
}
}
}
@ -176,7 +171,7 @@ class GBTree : public IGradBooster<FMatrix> {
char *pstr;
pstr = strtok(&tval[0], ",");
while (pstr != NULL) {
updaters.push_back(tree::CreateUpdater<FMatrix>(pstr));
updaters.push_back(tree::CreateUpdater(pstr));
for (size_t j = 0; j < cfg.size(); ++j) {
// set parameters
updaters.back()->SetParam(cfg[j].first.c_str(), cfg[j].second.c_str());
@ -187,7 +182,7 @@ class GBTree : public IGradBooster<FMatrix> {
}
// do group specific group
inline void BoostNewTrees(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
int bst_group) {
this->InitUpdater();
@ -202,7 +197,7 @@ class GBTree : public IGradBooster<FMatrix> {
}
// update the trees
for (size_t i = 0; i < updaters.size(); ++i) {
updaters[i]->Update(gpair, fmat, info, new_trees);
updaters[i]->Update(gpair, p_fmat, info, new_trees);
}
// push back to model
for (size_t i = 0; i < new_trees.size(); ++i) {
@ -212,24 +207,34 @@ class GBTree : public IGradBooster<FMatrix> {
mparam.num_trees += tparam.num_parallel_tree;
}
// make a prediction for a single instance
inline float Pred(const SparseBatch::Inst &inst,
int64_t buffer_index,
int bst_group,
unsigned root_index,
tree::RegTree::FVec *p_feats) {
inline void Pred(const RowBatch::Inst &inst,
int64_t buffer_index,
int bst_group,
unsigned root_index,
tree::RegTree::FVec *p_feats,
float *out_pred, size_t stride) {
size_t itop = 0;
float psum = 0.0f;
// sum of leaf vector
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
// load buffered results if any
if (bid >= 0) {
itop = pred_counter[bid];
psum = pred_buffer[bid];
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
vec_psum[i] = pred_buffer[bid + i + 1];
}
}
if (itop != trees.size()) {
p_feats->Fill(inst);
for (size_t i = itop; i < trees.size(); ++i) {
if (tree_info[i] == bst_group) {
psum += trees[i]->Predict(*p_feats, root_index);
int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
psum += (*trees[i])[tid].leaf_value();
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
vec_psum[j] += trees[i]->leafvec(tid)[j];
}
}
}
p_feats->Drop(inst);
@ -238,8 +243,14 @@ class GBTree : public IGradBooster<FMatrix> {
if (bid >= 0) {
pred_counter[bid] = static_cast<unsigned>(trees.size());
pred_buffer[bid] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
pred_buffer[bid + i + 1] = vec_psum[i];
}
}
out_pred[0] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
out_pred[stride * (i + 1)] = vec_psum[i];
}
return psum;
}
// --- data structure ---
/*! \brief training parameters */
@ -292,14 +303,17 @@ class GBTree : public IGradBooster<FMatrix> {
* suppose we have n instance and k group, output will be k*n
*/
int num_output_group;
/*! \brief size of leaf vector needed in tree */
int size_leaf_vector;
/*! \brief reserved parameters */
int reserved[32];
int reserved[31];
/*! \brief constructor */
ModelParam(void) {
num_trees = 0;
num_roots = num_feature = 0;
num_pbuffer = 0;
num_output_group = 1;
size_leaf_vector = 0;
memset(reserved, 0, sizeof(reserved));
}
/*!
@ -312,10 +326,11 @@ class GBTree : public IGradBooster<FMatrix> {
if (!strcmp("num_output_group", name)) num_output_group = atol(val);
if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);
if (!strcmp("bst:num_feature", name)) num_feature = atoi(val);
if (!strcmp("bst:size_leaf_vector", name)) size_leaf_vector = atoi(val);
}
/*! \return size of prediction buffer actually needed */
inline size_t PredBufferSize(void) const {
return num_output_group * num_pbuffer;
return num_output_group * num_pbuffer * (size_leaf_vector + 1);
}
/*!
* \brief get the buffer offset given a buffer index and group id
@ -324,7 +339,7 @@ class GBTree : public IGradBooster<FMatrix> {
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
if (buffer_index < 0) return -1;
utils::Check(buffer_index < num_pbuffer, "buffer_index exceed num_pbuffer");
return buffer_index + num_pbuffer * bst_group;
return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
}
};
// training parameter
@ -345,7 +360,7 @@ class GBTree : public IGradBooster<FMatrix> {
// temporal storage for per thread
std::vector<tree::RegTree::FVec> thread_temp;
// the updaters that can be applied to each of tree
std::vector< tree::IUpdater<FMatrix>* > updaters;
std::vector<tree::IUpdater*> updaters;
};
} // namespace gbm

View File

@ -13,7 +13,7 @@ namespace xgboost {
/*! \brief namespace related to data format */
namespace io {
/*! \brief DMatrix object that I/O module support save/load */
typedef learner::DMatrix<FMatrixS> DataMatrix;
typedef learner::DMatrix DataMatrix;
/*!
* \brief load DataMatrix from stream
* \param fname file name to be loaded

View File

@ -16,6 +16,7 @@
#include "../utils/utils.h"
#include "../learner/dmatrix.h"
#include "./io.h"
#include "./simple_fmatrix-inl.hpp"
namespace xgboost {
namespace io {
@ -24,11 +25,16 @@ class DMatrixSimple : public DataMatrix {
public:
// constructor
DMatrixSimple(void) : DataMatrix(kMagic) {
this->fmat.set_iter(new OneBatchIter(this));
fmat_ = new FMatrixS(new OneBatchIter(this));
this->Clear();
}
// virtual destructor
virtual ~DMatrixSimple(void) {}
virtual ~DMatrixSimple(void) {
delete fmat_;
}
virtual IFMatrix *fmat(void) const {
return fmat_;
}
/*! \brief clear the storage */
inline void Clear(void) {
row_ptr_.clear();
@ -41,15 +47,15 @@ class DMatrixSimple : public DataMatrix {
this->info = src.info;
this->Clear();
// clone data content in thos matrix
utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator();
utils::IIterator<RowBatch> *iter = src.fmat()->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
SparseBatch::Inst inst = batch[i];
RowBatch::Inst inst = batch[i];
row_data_.resize(row_data_.size() + inst.length);
memcpy(&row_data_[row_ptr_.back()], inst.data,
sizeof(SparseBatch::Entry) * inst.length);
sizeof(RowBatch::Entry) * inst.length);
row_ptr_.push_back(row_ptr_.back() + inst.length);
}
}
@ -59,10 +65,10 @@ class DMatrixSimple : public DataMatrix {
* \param feats features
* \return the index of added row
*/
inline size_t AddRow(const std::vector<SparseBatch::Entry> &feats) {
inline size_t AddRow(const std::vector<RowBatch::Entry> &feats) {
for (size_t i = 0; i < feats.size(); ++i) {
row_data_.push_back(feats[i]);
info.info.num_col = std::max(info.info.num_col, static_cast<size_t>(feats[i].findex+1));
info.info.num_col = std::max(info.info.num_col, static_cast<size_t>(feats[i].index+1));
}
row_ptr_.push_back(row_ptr_.back() + feats.size());
info.info.num_row += 1;
@ -78,10 +84,10 @@ class DMatrixSimple : public DataMatrix {
FILE* file = utils::FopenCheck(fname, "r");
float label; bool init = true;
char tmp[1024];
std::vector<SparseBatch::Entry> feats;
std::vector<RowBatch::Entry> feats;
while (fscanf(file, "%s", tmp) == 1) {
SparseBatch::Entry e;
if (sscanf(tmp, "%u:%f", &e.findex, &e.fvalue) == 2) {
RowBatch::Entry e;
if (sscanf(tmp, "%u:%f", &e.index, &e.fvalue) == 2) {
feats.push_back(e);
} else {
if (!init) {
@ -145,7 +151,7 @@ class DMatrixSimple : public DataMatrix {
info.LoadBinary(fs);
FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_);
fmat.LoadColAccess(fs);
fmat_->LoadColAccess(fs);
if (!silent) {
printf("%lux%lu matrix with %lu entries is loaded",
@ -172,7 +178,7 @@ class DMatrixSimple : public DataMatrix {
info.SaveBinary(fs);
FMatrixS::SaveBinary(fs, row_ptr_, row_data_);
fmat.SaveColAccess(fs);
fmat_->SaveColAccess(fs);
fs.Close();
if (!silent) {
@ -211,13 +217,15 @@ class DMatrixSimple : public DataMatrix {
/*! \brief row pointer of CSR sparse storage */
std::vector<size_t> row_ptr_;
/*! \brief data in the row */
std::vector<SparseBatch::Entry> row_data_;
std::vector<RowBatch::Entry> row_data_;
/*! \brief the real fmatrix */
FMatrixS *fmat_;
/*! \brief magic number used to identify DMatrix */
static const int kMagic = 0xffffab01;
protected:
// one batch iterator that return content in the matrix
struct OneBatchIter: utils::IIterator<SparseBatch> {
struct OneBatchIter: utils::IIterator<RowBatch> {
explicit OneBatchIter(DMatrixSimple *parent)
: at_first_(true), parent_(parent) {}
virtual ~OneBatchIter(void) {}
@ -229,11 +237,11 @@ class DMatrixSimple : public DataMatrix {
at_first_ = false;
batch_.size = parent_->row_ptr_.size() - 1;
batch_.base_rowid = 0;
batch_.row_ptr = &parent_->row_ptr_[0];
batch_.ind_ptr = &parent_->row_ptr_[0];
batch_.data_ptr = &parent_->row_data_[0];
return true;
}
virtual const SparseBatch &Value(void) const {
virtual const RowBatch &Value(void) const {
return batch_;
}
@ -243,8 +251,8 @@ class DMatrixSimple : public DataMatrix {
// pointer to parient
DMatrixSimple *parent_;
// temporal space for batch
SparseBatch batch_;
};
RowBatch batch_;
};
};
} // namespace io
} // namespace xgboost

View File

@ -0,0 +1,242 @@
#ifndef XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP
#define XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP
/*!
* \file simple_fmatrix-inl.hpp
* \brief the input data structure for gradient boosting
* \author Tianqi Chen
*/
#include "../data.h"
#include "../utils/utils.h"
#include "../utils/random.h"
#include "../utils/omp.h"
#include "../utils/matrix_csr.h"
namespace xgboost {
namespace io {
/*!
* \brief sparse matrix that support column access, CSC
*/
class FMatrixS : public IFMatrix{
public:
typedef SparseBatch::Entry Entry;
/*! \brief constructor */
FMatrixS(utils::IIterator<RowBatch> *iter) {
this->iter_ = iter;
}
// destructor
virtual ~FMatrixS(void) {
if (iter_ != NULL) delete iter_;
}
/*! \return whether column access is enabled */
virtual bool HaveColAccess(void) const {
return col_ptr_.size() != 0;
}
/*! \brief get number of colmuns */
virtual size_t NumCol(void) const {
utils::Check(this->HaveColAccess(), "NumCol:need column access");
return col_ptr_.size() - 1;
}
/*! \brief get number of buffered rows */
virtual const std::vector<bst_uint> &buffered_rowset(void) const {
return buffered_rowset_;
}
/*! \brief get column size */
virtual size_t GetColSize(size_t cidx) const {
return col_ptr_[cidx+1] - col_ptr_[cidx];
}
/*! \brief get column density */
virtual float GetColDensity(size_t cidx) const {
size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
}
virtual void InitColAccess(float pkeep = 1.0f) {
if (this->HaveColAccess()) return;
this->InitColData(pkeep);
}
/*!
* \brief get the row iterator associated with FMatrix
*/
virtual utils::IIterator<RowBatch>* RowIterator(void) {
iter_->BeforeFirst();
return iter_;
}
/*!
* \brief get the column based iterator
*/
virtual utils::IIterator<ColBatch>* ColIterator(void) {
size_t ncol = this->NumCol();
col_iter_.col_index_.resize(ncol);
for (size_t i = 0; i < ncol; ++i) {
col_iter_.col_index_[i] = static_cast<bst_uint>(i);
}
col_iter_.SetBatch(col_ptr_, col_data_);
return &col_iter_;
}
/*!
* \brief colmun based iterator
*/
virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) {
col_iter_.col_index_ = fset;
col_iter_.SetBatch(col_ptr_, col_data_);
return &col_iter_;
}
/*!
* \brief save column access data into stream
* \param fo output stream to save to
*/
inline void SaveColAccess(utils::IStream &fo) const {
fo.Write(buffered_rowset_);
if (buffered_rowset_.size() != 0) {
SaveBinary(fo, col_ptr_, col_data_);
}
}
/*!
* \brief load column access data from stream
* \param fo output stream to load from
*/
inline void LoadColAccess(utils::IStream &fi) {
utils::Check(fi.Read(&buffered_rowset_), "invalid input file format");
if (buffered_rowset_.size() != 0) {
LoadBinary(fi, &col_ptr_, &col_data_);
}
}
/*!
* \brief save data to binary stream
* \param fo output stream
* \param ptr pointer data
* \param data data content
*/
inline static void SaveBinary(utils::IStream &fo,
const std::vector<size_t> &ptr,
const std::vector<RowBatch::Entry> &data) {
size_t nrow = ptr.size() - 1;
fo.Write(&nrow, sizeof(size_t));
fo.Write(&ptr[0], ptr.size() * sizeof(size_t));
if (data.size() != 0) {
fo.Write(&data[0], data.size() * sizeof(RowBatch::Entry));
}
}
/*!
* \brief load data from binary stream
* \param fi input stream
* \param out_ptr pointer data
* \param out_data data content
*/
inline static void LoadBinary(utils::IStream &fi,
std::vector<size_t> *out_ptr,
std::vector<RowBatch::Entry> *out_data) {
size_t nrow;
utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format");
out_ptr->resize(nrow + 1);
utils::Check(fi.Read(&(*out_ptr)[0], out_ptr->size() * sizeof(size_t)) != 0,
"invalid input file format");
out_data->resize(out_ptr->back());
if (out_data->size() != 0) {
utils::Assert(fi.Read(&(*out_data)[0], out_data->size() * sizeof(RowBatch::Entry)) != 0,
"invalid input file format");
}
}
protected:
/*!
* \brief intialize column data
* \param pkeep probability to keep a row
*/
inline void InitColData(float pkeep) {
buffered_rowset_.clear();
// note: this part of code is serial, todo, parallelize this transformer
utils::SparseCSRMBuilder<RowBatch::Entry> builder(col_ptr_, col_data_);
builder.InitBudget(0);
// start working
iter_->BeforeFirst();
while (iter_->Next()) {
const RowBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
buffered_rowset_.push_back(static_cast<bst_uint>(batch.base_rowid+i));
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].index);
}
}
}
}
builder.InitStorage();
iter_->BeforeFirst();
size_t ktop = 0;
while (iter_->Next()) {
const RowBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (ktop < buffered_rowset_.size() &&
buffered_rowset_[ktop] == batch.base_rowid+i) {
++ktop;
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.PushElem(inst[j].index,
Entry((bst_uint)(batch.base_rowid+i),
inst[j].fvalue));
}
}
}
}
// sort columns
bst_omp_uint ncol = static_cast<bst_omp_uint>(this->NumCol());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ncol; ++i) {
std::sort(&col_data_[0] + col_ptr_[i],
&col_data_[0] + col_ptr_[i + 1], Entry::CmpValue);
}
}
private:
// one batch iterator that return content in the matrix
struct OneBatchIter: utils::IIterator<ColBatch> {
OneBatchIter(void) : at_first_(true){}
virtual ~OneBatchIter(void) {}
virtual void BeforeFirst(void) {
at_first_ = true;
}
virtual bool Next(void) {
if (!at_first_) return false;
at_first_ = false;
return true;
}
virtual const ColBatch &Value(void) const {
return batch_;
}
inline void SetBatch(const std::vector<size_t> &ptr,
const std::vector<ColBatch::Entry> &data) {
batch_.size = col_index_.size();
col_data_.resize(col_index_.size(), SparseBatch::Inst(NULL,0));
for (size_t i = 0; i < col_data_.size(); ++i) {
const bst_uint ridx = col_index_[i];
col_data_[i] = SparseBatch::Inst(&data[0] + ptr[ridx],
static_cast<bst_uint>(ptr[ridx+1] - ptr[ridx]));
}
batch_.col_index = &col_index_[0];
batch_.col_data = &col_data_[0];
this->BeforeFirst();
}
// data content
std::vector<bst_uint> col_index_;
std::vector<ColBatch::Inst> col_data_;
// whether is at first
bool at_first_;
// temporal space for batch
ColBatch batch_;
};
// --- data structure used to support InitColAccess --
// column iterator
OneBatchIter col_iter_;
// row iterator
utils::IIterator<RowBatch> *iter_;
/*! \brief list of row index that are buffered */
std::vector<bst_uint> buffered_rowset_;
/*! \brief column pointer of CSC format */
std::vector<size_t> col_ptr_;
/*! \brief column datas in CSC format */
std::vector<ColBatch::Entry> col_data_;
};
} // namespace io
} // namespace xgboost
#endif // XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP

View File

@ -7,8 +7,9 @@
* \author Tianqi Chen
*/
#include <vector>
#include <cstring>
#include "../data.h"
#include "../utils/io.h"
namespace xgboost {
namespace learner {
/*!
@ -142,7 +143,6 @@ struct MetaInfo {
* \brief data object used for learning,
* \tparam FMatrix type of feature data source
*/
template<typename FMatrix>
struct DMatrix {
/*!
* \brief magic number associated with this object
@ -151,8 +151,6 @@ struct DMatrix {
const int magic;
/*! \brief meta information about the dataset */
MetaInfo info;
/*! \brief feature matrix about data content */
FMatrix fmat;
/*!
* \brief cache pointer to verify if the data structure is cached in some learner
* used to verify if DMatrix is cached
@ -160,6 +158,8 @@ struct DMatrix {
void *cache_learner_ptr_;
/*! \brief default constructor */
explicit DMatrix(int magic) : magic(magic), cache_learner_ptr_(NULL) {}
/*! \brief get feature matrix about data content */
virtual IFMatrix *fmat(void) const = 0;
// virtual destructor
virtual ~DMatrix(void){}
};

View File

@ -24,9 +24,12 @@ template<typename Derived>
struct EvalEWiseBase : public IEvaluator {
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info) const {
utils::Check(preds.size() == info.labels.size(),
utils::Check(info.labels.size() != 0, "label set cannot be empty");
utils::Check(preds.size() % info.labels.size() == 0,
"label and prediction size not match");
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
float sum = 0.0, wsum = 0.0;
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
@ -99,6 +102,45 @@ struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
}
};
/*! \brief ctest */
struct EvalCTest: public IEvaluator {
EvalCTest(IEvaluator *base, const char *name)
: base_(base), name_(name) {}
virtual ~EvalCTest(void) {
delete base_;
}
virtual const char *Name(void) const {
return name_.c_str();
}
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info) const {
utils::Check(preds.size() % info.labels.size() == 0,
"label and prediction size not match");
size_t ngroup = preds.size() / info.labels.size() - 1;
const unsigned ndata = static_cast<unsigned>(info.labels.size());
utils::Check(ngroup > 1, "pred size does not meet requirement");
utils::Check(ndata == info.info.fold_index.size(), "need fold index");
double wsum = 0.0;
for (size_t k = 0; k < ngroup; ++k) {
std::vector<float> tpred;
MetaInfo tinfo;
for (unsigned i = 0; i < ndata; ++i) {
if (info.info.fold_index[i] == k) {
tpred.push_back(preds[i + (k + 1) * ndata]);
tinfo.labels.push_back(info.labels[i]);
tinfo.weights.push_back(info.GetWeight(i));
}
}
wsum += base_->Eval(tpred, tinfo);
}
return static_cast<float>(wsum / ngroup);
}
private:
IEvaluator *base_;
std::string name_;
};
/*! \brief AMS: also records best threshold */
struct EvalAMS : public IEvaluator {
public:
@ -109,7 +151,8 @@ struct EvalAMS : public IEvaluator {
}
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info) const {
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
utils::Check(info.weights.size() == ndata, "we need weight to evaluate ams");
std::vector< std::pair<float, unsigned> > rec(ndata);
@ -168,9 +211,11 @@ struct EvalPrecisionRatio : public IEvaluator{
}
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info) const {
utils::Assert(preds.size() == info.labels.size(), "label size predict size not match");
utils::Check(info.labels.size() != 0, "label set cannot be empty");
utils::Assert(preds.size() % info.labels.size() == 0,
"label size predict size not match");
std::vector< std::pair<float, unsigned> > rec;
for (size_t j = 0; j < preds.size(); ++j) {
for (size_t j = 0; j < info.labels.size(); ++j) {
rec.push_back(std::make_pair(preds[j], static_cast<unsigned>(j)));
}
std::sort(rec.begin(), rec.end(), CmpFirst);
@ -206,10 +251,14 @@ struct EvalPrecisionRatio : public IEvaluator{
struct EvalAuc : public IEvaluator {
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info) const {
utils::Check(preds.size() == info.labels.size(), "label size predict size not match");
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(preds.size());
utils::Check(info.labels.size() != 0, "label set cannot be empty");
utils::Check(preds.size() % info.labels.size() == 0,
"label size predict size not match");
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(info.labels.size());
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
utils::Check(gptr.back() == preds.size(),
utils::Check(gptr.back() == info.labels.size(),
"EvalAuc: group structure must match number of prediction");
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
// sum statictis

View File

@ -45,7 +45,9 @@ inline IEvaluator* CreateEvaluator(const char *name) {
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
if (!strncmp(name, "ndcg", 4)) return new EvalNDCG(name);
if (!strncmp(name, "ct-", 3)) return new EvalCTest(CreateEvaluator(name+3), name);
utils::Error("unknown evaluation metric type: %s", name);
return NULL;
}

View File

@ -21,7 +21,6 @@ namespace learner {
* \brief learner that takes do gradient boosting on specific objective functions
* and do training and prediction
*/
template<typename FMatrix>
class BoostLearner {
public:
BoostLearner(void) {
@ -44,7 +43,7 @@ class BoostLearner {
* data matrices to continue training otherwise it will cause error
* \param mats array of pointers to matrix whose prediction result need to be cached
*/
inline void SetCacheData(const std::vector<DMatrix<FMatrix>*>& mats) {
inline void SetCacheData(const std::vector<DMatrix*>& mats) {
// estimate feature bound
unsigned num_feature = 0;
// assign buffer index
@ -158,18 +157,18 @@ class BoostLearner {
* if not intialize it
* \param p_train pointer to the matrix used by training
*/
inline void CheckInit(DMatrix<FMatrix> *p_train) {
p_train->fmat.InitColAccess(prob_buffer_row);
inline void CheckInit(DMatrix *p_train) {
p_train->fmat()->InitColAccess(prob_buffer_row);
}
/*!
* \brief update the model for one iteration
* \param iter current iteration number
* \param p_train pointer to the data matrix
*/
inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) {
inline void UpdateOneIter(int iter, const DMatrix &train) {
this->PredictRaw(train, &preds_);
obj_->GetGradient(preds_, train.info, iter, &gpair_);
gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
gbm_->DoBoost(train.fmat(), train.info.info, &gpair_);
}
/*!
* \brief evaluate the model for specific iteration
@ -179,7 +178,7 @@ class BoostLearner {
* \return a string corresponding to the evaluation result
*/
inline std::string EvalOneIter(int iter,
const std::vector<const DMatrix<FMatrix>*> &evals,
const std::vector<const DMatrix*> &evals,
const std::vector<std::string> &evname) {
std::string res;
char tmp[256];
@ -198,7 +197,7 @@ class BoostLearner {
* \param metric name of metric
* \return a pair of <evaluation name, result>
*/
std::pair<std::string, float> Evaluate(const DMatrix<FMatrix> &data, std::string metric) {
std::pair<std::string, float> Evaluate(const DMatrix &data, std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
IEvaluator *ev = CreateEvaluator(metric.c_str());
this->PredictRaw(data, &preds_);
@ -213,7 +212,7 @@ class BoostLearner {
* \param output_margin whether to only predict margin value instead of transformed prediction
* \param out_preds output vector that stores the prediction
*/
inline void Predict(const DMatrix<FMatrix> &data,
inline void Predict(const DMatrix &data,
bool output_margin,
std::vector<float> *out_preds) const {
this->PredictRaw(data, out_preds);
@ -235,7 +234,7 @@ class BoostLearner {
if (obj_ != NULL) return;
utils::Assert(gbm_ == NULL, "GBM and obj should be NULL");
obj_ = CreateObjFunction(name_obj_.c_str());
gbm_ = gbm::CreateGradBooster<FMatrix>(name_gbm_.c_str());
gbm_ = gbm::CreateGradBooster(name_gbm_.c_str());
for (size_t i = 0; i < cfg_.size(); ++i) {
obj_->SetParam(cfg_[i].first.c_str(), cfg_[i].second.c_str());
gbm_->SetParam(cfg_[i].first.c_str(), cfg_[i].second.c_str());
@ -247,9 +246,9 @@ class BoostLearner {
* \param data training data matrix
* \param out_preds output vector that stores the prediction
*/
inline void PredictRaw(const DMatrix<FMatrix> &data,
inline void PredictRaw(const DMatrix &data,
std::vector<float> *out_preds) const {
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
gbm_->Predict(data.fmat(), this->FindBufferOffset(data),
data.info.info, out_preds);
// add base margin
std::vector<float> &preds = *out_preds;
@ -307,7 +306,7 @@ class BoostLearner {
// model parameter
ModelParam mparam;
// gbm model that back everything
gbm::IGradBooster<FMatrix> *gbm_;
gbm::IGradBooster *gbm_;
// name of gbm model used for training
std::string name_gbm_;
// objective fnction
@ -324,14 +323,14 @@ class BoostLearner {
private:
// cache entry object that helps handle feature caching
struct CacheEntry {
const DMatrix<FMatrix> *mat_;
const DMatrix *mat_;
size_t buffer_offset_;
size_t num_row_;
CacheEntry(const DMatrix<FMatrix> *mat, size_t buffer_offset, size_t num_row)
CacheEntry(const DMatrix *mat, size_t buffer_offset, size_t num_row)
:mat_(mat), buffer_offset_(buffer_offset), num_row_(num_row) {}
};
// find internal bufer offset for certain matrix, if not exist, return -1
inline int64_t FindBufferOffset(const DMatrix<FMatrix> &mat) const {
inline int64_t FindBufferOffset(const DMatrix &mat) const {
for (size_t i = 0; i < cache_.size(); ++i) {
if (cache_[i].mat_ == &mat && mat.cache_learner_ptr_ == this) {
if (cache_[i].num_row_ == mat.info.num_row()) {

View File

@ -123,7 +123,7 @@ class RegLossObj : public IObjFunction{
float p = loss.PredTransform(preds[i]);
float w = info.GetWeight(j);
if (info.labels[j] == 1.0f) w *= scale_pos_weight;
gpair[j] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
gpair[i] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
loss.SecondOrderGradient(p, info.labels[j]) * w);
}
}

View File

@ -272,6 +272,7 @@ class TreeModel {
param.num_nodes = param.num_roots;
nodes.resize(param.num_nodes);
stats.resize(param.num_nodes);
leaf_vector.resize(param.num_nodes * param.size_leaf_vector, 0.0f);
for (int i = 0; i < param.num_nodes; i ++) {
nodes[i].set_leaf(0.0f);
nodes[i].set_parent(-1);
@ -289,6 +290,9 @@ class TreeModel {
"TreeModel: wrong format");
utils::Check(fi.Read(&stats[0], sizeof(NodeStat) * stats.size()) > 0,
"TreeModel: wrong format");
if (param.size_leaf_vector != 0) {
utils::Check(fi.Read(&leaf_vector), "TreeModel: wrong format");
}
// chg deleted nodes
deleted_nodes.resize(0);
for (int i = param.num_roots; i < param.num_nodes; i ++) {
@ -309,6 +313,7 @@ class TreeModel {
fo.Write(&param, sizeof(Param));
fo.Write(&nodes[0], sizeof(Node) * nodes.size());
fo.Write(&stats[0], sizeof(NodeStat) * nodes.size());
if (param.size_leaf_vector != 0) fo.Write(leaf_vector);
}
/*!
* \brief add child nodes to node
@ -486,15 +491,15 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat>{
std::fill(data.begin(), data.end(), e);
}
/*! \brief fill the vector with sparse vector */
inline void Fill(const SparseBatch::Inst &inst) {
inline void Fill(const RowBatch::Inst &inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
data[inst[i].findex].fvalue = inst[i].fvalue;
data[inst[i].index].fvalue = inst[i].fvalue;
}
}
/*! \brief drop the trace after fill, must be called after fill */
inline void Drop(const SparseBatch::Inst &inst) {
inline void Drop(const RowBatch::Inst &inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
data[inst[i].findex].flag = -1;
data[inst[i].index].flag = -1;
}
}
/*! \brief get ith value */

View File

@ -22,10 +22,10 @@ struct TrainParam{
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
// weight decay parameter used to control leaf fitting
// L2 regularization factor
float reg_lambda;
// reg method
int reg_method;
// L1 regularization factor
float reg_alpha;
// default direction choice
int default_direction;
// whether we want to do subsample
@ -36,6 +36,8 @@ struct TrainParam{
float colsample_bytree;
// speed optimization for dense column
float opt_dense_col;
// leaf vector size
int size_leaf_vector;
// number of threads to be used for tree construction,
// if OpenMP is enabled, if equals 0, use system default
int nthread;
@ -45,13 +47,14 @@ struct TrainParam{
min_child_weight = 1.0f;
max_depth = 6;
reg_lambda = 1.0f;
reg_method = 2;
reg_alpha = 0.0f;
default_direction = 0;
subsample = 1.0f;
colsample_bytree = 1.0f;
colsample_bylevel = 1.0f;
opt_dense_col = 1.0f;
nthread = 0;
size_leaf_vector = 0;
}
/*!
* \brief set parameters from outside
@ -63,15 +66,17 @@ struct TrainParam{
if (!strcmp(name, "gamma")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "eta")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "alpha")) reg_alpha = static_cast<float>(atof(val));
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "reg_method")) reg_method = atoi(val);
if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
if (!strcmp(name, "max_depth")) max_depth = atoi(val);
if (!strcmp(name, "nthread")) nthread = atoi(val);
if (!strcmp(name, "default_direction")) {
@ -82,31 +87,31 @@ struct TrainParam{
}
// calculate the cost of loss function
inline double CalcGain(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) {
return 0.0;
if (sum_hess < min_child_weight) return 0.0;
if (reg_alpha == 0.0f) {
return Sqr(sum_grad) / (sum_hess + reg_lambda);
} else {
return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
}
switch (reg_method) {
case 1 : return Sqr(ThresholdL1(sum_grad, reg_lambda)) / sum_hess;
case 2 : return Sqr(sum_grad) / (sum_hess + reg_lambda);
case 3 : return
Sqr(ThresholdL1(sum_grad, 0.5 * reg_lambda)) /
(sum_hess + 0.5 * reg_lambda);
default: return Sqr(sum_grad) / sum_hess;
}
// calculate cost of loss function with four stati
inline double CalcGain(double sum_grad, double sum_hess,
double test_grad, double test_hess) const {
double w = CalcWeight(sum_grad, sum_hess);
double ret = test_grad * w + 0.5 * (test_hess + reg_lambda) * Sqr(w);
if (reg_alpha == 0.0f) {
return - 2.0 * ret;
} else {
return - 2.0 * (ret + reg_alpha * std::abs(w));
}
}
// calculate weight given the statistics
inline double CalcWeight(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) {
return 0.0;
if (sum_hess < min_child_weight) return 0.0;
if (reg_alpha == 0.0f) {
return -sum_grad / (sum_hess + reg_lambda);
} else {
switch (reg_method) {
case 1: return - ThresholdL1(sum_grad, reg_lambda) / sum_hess;
case 2: return - sum_grad / (sum_hess + reg_lambda);
case 3: return
- ThresholdL1(sum_grad, 0.5 * reg_lambda) /
(sum_hess + 0.5 * reg_lambda);
default: return - sum_grad / sum_hess;
}
return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
}
}
/*! \brief whether need forward small to big search: default right */
@ -153,6 +158,9 @@ struct GradStats {
inline void Clear(void) {
sum_grad = sum_hess = 0.0f;
}
/*! \brief check if necessary information is ready */
inline static void CheckInfo(const BoosterInfo &info) {
}
/*!
* \brief accumulate statistics,
* \param gpair the vector storing the gradient statistics
@ -188,14 +196,88 @@ struct GradStats {
}
/*! \brief set leaf vector value based on statistics */
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const{
}
protected:
}
// constructor to allow inheritance
GradStats(void) {}
/*! \brief add statistics to the data */
inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess;
}
};
/*! \brief vectorized cv statistics */
template<unsigned vsize>
struct CVGradStats : public GradStats {
// additional statistics
GradStats train[vsize], valid[vsize];
// constructor
explicit CVGradStats(const TrainParam &param) {
utils::Check(param.size_leaf_vector == vsize,
"CVGradStats: vsize must match size_leaf_vector");
this->Clear();
}
/*! \brief check if necessary information is ready */
inline static void CheckInfo(const BoosterInfo &info) {
utils::Check(info.fold_index.size() != 0,
"CVGradStats: require fold_index");
}
/*! \brief clear the statistics */
inline void Clear(void) {
GradStats::Clear();
for (unsigned i = 0; i < vsize; ++i) {
train[i].Clear(); valid[i].Clear();
}
}
inline void Add(const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
bst_uint ridx) {
GradStats::Add(gpair[ridx].grad, gpair[ridx].hess);
const size_t step = info.fold_index.size();
for (unsigned i = 0; i < vsize; ++i) {
const bst_gpair &b = gpair[(i + 1) * step + ridx];
if (info.fold_index[ridx] == i) {
valid[i].Add(b.grad, b.hess);
} else {
train[i].Add(b.grad, b.hess);
}
}
}
/*! \brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const {
double ret = 0.0;
for (unsigned i = 0; i < vsize; ++i) {
ret += param.CalcGain(train[i].sum_grad,
train[i].sum_hess,
vsize * valid[i].sum_grad,
vsize * valid[i].sum_hess);
}
return ret / vsize;
}
/*! \brief add statistics to the data */
inline void Add(const CVGradStats &b) {
GradStats::Add(b);
for (unsigned i = 0; i < vsize; ++i) {
train[i].Add(b.train[i]);
valid[i].Add(b.valid[i]);
}
}
/*! \brief set current value to a - b */
inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) {
GradStats::SetSubstract(a, b);
for (int i = 0; i < vsize; ++i) {
train[i].SetSubstract(a.train[i], b.train[i]);
valid[i].SetSubstract(a.valid[i], b.valid[i]);
}
}
/*! \brief set leaf vector value based on statistics */
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const{
for (int i = 0; i < vsize; ++i) {
vec[i] = param.learning_rate *
param.CalcWeight(train[i].sum_grad, train[i].sum_hess);
}
}
};
/*!
* \brief statistics that is helpful to store
* and represent a split solution for the tree

20
src/tree/updater.cpp Normal file
View File

@ -0,0 +1,20 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#include <cstring>
#include "./updater.h"
#include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp"
namespace xgboost {
namespace tree {
IUpdater* CreateUpdater(const char *name) {
if (!strcmp(name, "prune")) return new TreePruner();
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
utils::Error("unknown updater:%s", name);
return NULL;
}
} // namespace tree
} // namespace xgboost

View File

@ -14,9 +14,7 @@ namespace xgboost {
namespace tree {
/*!
* \brief interface of tree update module, that performs update of a tree
* \tparam FMatrix the data type updater taking
*/
template<typename FMatrix>
class IUpdater {
public:
/*!
@ -28,7 +26,7 @@ class IUpdater {
/*!
* \brief peform update to the tree models
* \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provide access to features
* \param p_fmat feature matrix that provide access to features
* \param info extra side information that may be need, such as root index
* \param trees pointer to the trese to be updated, upater will change the content of the tree
* note: all the trees in the vector are updated, with the same statistics,
@ -36,36 +34,18 @@ class IUpdater {
* there can be multiple trees when we train random forest style model
*/
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) = 0;
// destructor
virtual ~IUpdater(void) {}
};
} // namespace tree
} // namespace xgboost
#include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp"
namespace xgboost {
namespace tree {
/*!
* \brief create a updater based on name
* \param name name of updater
* \return return the updater instance
*/
template<typename FMatrix>
inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
if (!strcmp(name, "refresh")) return new TreeRefresher<FMatrix, GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
utils::Error("unknown updater:%s", name);
return NULL;
}
IUpdater* CreateUpdater(const char *name);
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_H_

View File

@ -15,8 +15,8 @@
namespace xgboost {
namespace tree {
/*! \brief pruner that prunes a tree after growing finishs */
template<typename FMatrix, typename TStats>
class ColMaker: public IUpdater<FMatrix> {
template<typename TStats>
class ColMaker: public IUpdater {
public:
virtual ~ColMaker(void) {}
// set training parameter
@ -24,16 +24,17 @@ class ColMaker: public IUpdater<FMatrix> {
param.SetParam(name, val);
}
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
TStats::CheckInfo(info);
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, fmat, info, trees[i]);
builder.Update(gpair, p_fmat, info, trees[i]);
}
param.learning_rate = lr;
}
@ -76,17 +77,16 @@ class ColMaker: public IUpdater<FMatrix> {
explicit Builder(const TrainParam &param) : param(param) {}
// update one tree, growing
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
RegTree *p_tree) {
this->InitData(gpair, fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, *p_fmat, info, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
this->ResetPosition(this->qexpand, fmat, *p_tree);
this->FindSplit(depth, this->qexpand, gpair, p_fmat, info, p_tree);
this->ResetPosition(this->qexpand, p_fmat, *p_tree);
this->UpdateQueueExpand(*p_tree, &this->qexpand);
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
this->InitNewNode(qexpand, gpair, *p_fmat, info, *p_tree);
// if nothing left to be expand, break
if (qexpand.size() == 0) break;
}
@ -107,7 +107,7 @@ class ColMaker: public IUpdater<FMatrix> {
private:
// initialize temp data structure
inline void InitData(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const IFMatrix &fmat,
const std::vector<unsigned> &root_index, const RegTree &tree) {
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
@ -137,8 +137,7 @@ class ColMaker: public IUpdater<FMatrix> {
if (random::SampleBinary(param.subsample) == 0) position[ridx] = -1;
}
}
}
}
{
// initialize feature index
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
@ -175,7 +174,7 @@ class ColMaker: public IUpdater<FMatrix> {
/*! \brief initialize the base_weight, root_gain, and NodeEntry for all the new nodes in qexpand */
inline void InitNewNode(const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const IFMatrix &fmat,
const BoosterInfo &info,
const RegTree &tree) {
{// setup statistics space for each tree node
@ -222,24 +221,25 @@ class ColMaker: public IUpdater<FMatrix> {
qexpand = newnodes;
}
// enumerate the split values of specific feature
template<typename Iter>
inline void EnumerateSplit(Iter it, unsigned fid,
inline void EnumerateSplit(const ColBatch::Entry *begin,
const ColBatch::Entry *end,
int d_step,
bst_uint fid,
const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
std::vector<ThreadEntry> &temp,
bool is_forward_search) {
std::vector<ThreadEntry> &temp) {
// clear all the temp statistics
for (size_t j = 0; j < qexpand.size(); ++j) {
temp[qexpand[j]].stats.Clear();
}
// left statistics
TStats c(param);
while (it.Next()) {
const bst_uint ridx = it.rindex();
for(const ColBatch::Entry *it = begin; it != end; it += d_step) {
const bst_uint ridx = it->index;
const int nid = position[ridx];
if (nid < 0) continue;
// start working
const float fvalue = it.fvalue();
const float fvalue = it->fvalue;
// get the statistics of nid
ThreadEntry &e = temp[nid];
// test if first hit, this is fine, because we set 0 during init
@ -252,7 +252,7 @@ class ColMaker: public IUpdater<FMatrix> {
c.SetSubstract(snode[nid].stats, e.stats);
if (c.sum_hess >= param.min_child_weight) {
bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
}
}
// update the statistics
@ -267,38 +267,46 @@ class ColMaker: public IUpdater<FMatrix> {
c.SetSubstract(snode[nid].stats, e.stats);
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
const float delta = is_forward_search ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
const float delta = d_step == +1 ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
}
}
}
// find splits at current level, do split per level
inline void FindSplit(int depth, const std::vector<int> &qexpand,
inline void FindSplit(int depth,
const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
RegTree *p_tree) {
std::vector<unsigned> feat_set = feat_index;
std::vector<bst_uint> feat_set = feat_index;
if (param.colsample_bylevel != 1.0f) {
random::Shuffle(feat_set);
unsigned n = static_cast<unsigned>(param.colsample_bylevel * feat_index.size());
utils::Check(n > 0, "colsample_bylevel is too small that no feature can be included");
feat_set.resize(n);
}
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(feat_set.size());
#if defined(_OPENMP)
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
#endif
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const unsigned fid = feat_set[i];
const int tid = omp_get_thread_num();
if (param.need_forward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true);
}
if (param.need_backward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false);
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(feat_set);
while (iter->Next()) {
const ColBatch &batch = iter->Value();
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#if defined(_OPENMP)
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
#endif
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const bst_uint fid = batch.col_index[i];
const int tid = omp_get_thread_num();
const ColBatch::Inst c = batch[i];
if (param.need_forward_search(p_fmat->GetColDensity(fid))) {
this->EnumerateSplit(c.data, c.data + c.length, +1,
fid, gpair, info, stemp[tid]);
}
if (param.need_backward_search(p_fmat->GetColDensity(fid))) {
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
fid, gpair, info, stemp[tid]);
}
}
}
// after this each thread's stemp will get the best candidates, aggregate results
@ -318,8 +326,8 @@ class ColMaker: public IUpdater<FMatrix> {
}
}
// reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, const FMatrix &fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// step 1, set default direct nodes to default, and leaf nodes to -1
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static)
@ -343,22 +351,28 @@ class ColMaker: public IUpdater<FMatrix> {
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
// start put things into right place
const bst_omp_uint nfeats = static_cast<bst_omp_uint>(fsplits.size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nfeats; ++i) {
const unsigned fid = fsplits[i];
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
const bst_uint ridx = it.rindex();
int nid = position[ridx];
if (nid == -1) continue;
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {
if (it.fvalue() < tree[nid].split_cond()) {
position[ridx] = tree[nid].cleft();
} else {
position[ridx] = tree[nid].cright();
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fsplits);
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
int nid = position[ridx];
if (nid == -1) continue;
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {
if (fvalue < tree[nid].split_cond()) {
position[ridx] = tree[nid].cleft();
} else {
position[ridx] = tree[nid].cright();
}
}
}
}
@ -369,7 +383,7 @@ class ColMaker: public IUpdater<FMatrix> {
// number of omp thread used during training
int nthread;
// Per feature: shuffle index of each feature index
std::vector<unsigned> feat_index;
std::vector<bst_uint> feat_index;
// Instance Data: current node position in the tree of each instance
std::vector<int> position;
// PerThread x PerTreeNode: statistics for per thread construction

View File

@ -12,8 +12,7 @@
namespace xgboost {
namespace tree {
/*! \brief pruner that prunes a tree after growing finishs */
template<typename FMatrix>
class TreePruner: public IUpdater<FMatrix> {
class TreePruner: public IUpdater {
public:
virtual ~TreePruner(void) {}
// set training parameter
@ -23,7 +22,7 @@ class TreePruner: public IUpdater<FMatrix> {
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees
@ -75,7 +74,6 @@ class TreePruner: public IUpdater<FMatrix> {
// training parameter
TrainParam param;
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_

View File

@ -9,12 +9,13 @@
#include <limits>
#include "./param.h"
#include "./updater.h"
#include "../utils/omp.h"
namespace xgboost {
namespace tree {
/*! \brief pruner that prunes a tree after growing finishs */
template<typename FMatrix, typename TStats>
class TreeRefresher: public IUpdater<FMatrix> {
template<typename TStats>
class TreeRefresher: public IUpdater {
public:
virtual ~TreeRefresher(void) {}
// set training parameter
@ -23,7 +24,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return;
@ -50,16 +51,16 @@ class TreeRefresher: public IUpdater<FMatrix> {
fvec_temp[tid].Init(trees[0]->param.num_feature);
}
// start accumulating statistics
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
const RowBatch &batch = iter->Value();
utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
"too large batch size ");
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
SparseBatch::Inst inst = batch[i];
RowBatch::Inst inst = batch[i];
const int tid = omp_get_thread_num();
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
RegTree::FVec &feats = fvec_temp[tid];

View File

@ -8,6 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <cstdarg>
#include <string>
#include <cstdlib>
#ifdef _MSC_VER
#define fopen64 fopen

View File

@ -234,7 +234,7 @@ class BoostLearnTask{
std::vector<io::DataMatrix*> deval;
std::vector<const io::DataMatrix*> devalall;
utils::FeatMap fmap;
learner::BoostLearner<FMatrixS> learner;
learner::BoostLearner learner;
};
}

View File

@ -1,4 +1,4 @@
The solution has been created with Visual Studio Express 2013.
The solution has been created with Visual Studio Express 2010.
Make sure to compile the Release version, unless you need to debug the code
(and in the latter case modify the path in xgboost.py from release to test).
Note that you have two projects in one solution and they need to be compiled to use the standalone executable from the command line

View File

@ -1,11 +1,9 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Express 2013 for Windows Desktop
VisualStudioVersion = 12.0.30723.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboost", "xgboost\xgboost.vcxproj", "{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}"
Microsoft Visual Studio Solution File, Format Version 11.00
# Visual Studio 2010
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboost", "xgboost\xgboost.vcxproj", "{19766C3F-7508-49D0-BAAC-0988FCC9970C}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboost_wrapper", "xgboost_wrapper\xgboost_wrapper.vcxproj", "{2E1AF937-28BB-4832-B916-309C9A0F6C4F}"
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboost_wrapper", "xgboost_wrapper\xgboost_wrapper.vcxproj", "{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@ -15,22 +13,21 @@ Global
Release|x64 = Release|x64
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Debug|Win32.ActiveCfg = Debug|Win32
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Debug|Win32.Build.0 = Debug|Win32
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Debug|x64.ActiveCfg = Debug|x64
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Debug|x64.Build.0 = Debug|x64
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Release|Win32.ActiveCfg = Release|Win32
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Release|Win32.Build.0 = Release|Win32
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Release|x64.ActiveCfg = Release|x64
{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}.Release|x64.Build.0 = Release|x64
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Debug|Win32.ActiveCfg = Debug|Win32
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Debug|Win32.Build.0 = Debug|Win32
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Debug|x64.ActiveCfg = Debug|x64
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Debug|x64.Build.0 = Debug|x64
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Release|Win32.ActiveCfg = Release|Win32
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Release|Win32.Build.0 = Release|Win32
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Release|x64.ActiveCfg = Release|x64
{2E1AF937-28BB-4832-B916-309C9A0F6C4F}.Release|x64.Build.0 = Release|x64
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Debug|Win32.ActiveCfg = Debug|Win32
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Debug|Win32.Build.0 = Debug|Win32
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Debug|x64.ActiveCfg = Release|x64
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Debug|x64.Build.0 = Release|x64
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Release|Win32.ActiveCfg = Release|Win32
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Release|Win32.Build.0 = Release|Win32
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Release|x64.ActiveCfg = Release|x64
{19766C3F-7508-49D0-BAAC-0988FCC9970C}.Release|x64.Build.0 = Release|x64
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Debug|Win32.ActiveCfg = Debug|Win32
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Debug|Win32.Build.0 = Debug|Win32
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Debug|x64.ActiveCfg = Debug|Win32
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Release|Win32.ActiveCfg = Release|Win32
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Release|Win32.Build.0 = Release|Win32
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Release|x64.ActiveCfg = Release|x64
{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE

View File

@ -1,5 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
@ -18,8 +18,14 @@
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\src\gbm\gbm.cpp" />
<ClCompile Include="..\..\src\io\io.cpp" />
<ClCompile Include="..\..\src\tree\updater.cpp" />
<ClCompile Include="..\..\src\xgboost_main.cpp" />
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{1D6A56A5-5557-4D20-9D50-3DE4C30BE00C}</ProjectGuid>
<ProjectGuid>{19766C3F-7508-49D0-BAAC-0988FCC9970C}</ProjectGuid>
<RootNamespace>xgboost</RootNamespace>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
@ -27,27 +33,23 @@
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<CharacterSet>MultiByte</CharacterSet>
<PlatformToolset>v120</PlatformToolset>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<CharacterSet>MultiByte</CharacterSet>
<PlatformToolset>v120</PlatformToolset>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
<PlatformToolset>v120</PlatformToolset>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
<PlatformToolset>v120</PlatformToolset>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
@ -111,10 +113,6 @@
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClCompile Include="..\..\src\io\io.cpp" />
<ClCompile Include="..\..\src\xgboost_main.cpp" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>

View File

@ -1,5 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
@ -18,40 +18,38 @@
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\src\gbm\gbm.cpp" />
<ClCompile Include="..\..\src\io\io.cpp" />
<ClCompile Include="..\..\src\tree\updater.cpp" />
<ClCompile Include="..\..\wrapper\xgboost_wrapper.cpp" />
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{2E1AF937-28BB-4832-B916-309C9A0F6C4F}</ProjectGuid>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<Keyword>ManagedCProj</Keyword>
<ProjectGuid>{B0E22ADD-7849-4D3A-BDC6-0932C5F11ED5}</ProjectGuid>
<RootNamespace>xgboost_wrapper</RootNamespace>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CLRSupport>true</CLRSupport>
<CharacterSet>Unicode</CharacterSet>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CLRSupport>true</CLRSupport>
<CharacterSet>Unicode</CharacterSet>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CLRSupport>true</CLRSupport>
<CharacterSet>Unicode</CharacterSet>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CLRSupport>true</CLRSupport>
<CharacterSet>Unicode</CharacterSet>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
@ -69,85 +67,53 @@
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<PropertyGroup />
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
</ClCompile>
<Link>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies />
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<OpenMPSupport>true</OpenMPSupport>
</ClCompile>
<Link>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>
</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PreprocessorDefinitions>WIN32;NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PrecompiledHeader>Use</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<OpenMPSupport>true</OpenMPSupport>
</ClCompile>
<Link>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies />
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PreprocessorDefinitions>WIN32;NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<OpenMPSupport>true</OpenMPSupport>
</ClCompile>
<Link>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>
</AdditionalDependencies>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.Data" />
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\..\src\io\io.h" />
<ClInclude Include="..\..\src\io\simple_dmatrix-inl.hpp" />
<ClInclude Include="..\..\wrapper\xgboost_wrapper.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\src\io\io.cpp" />
<ClCompile Include="..\..\wrapper\xgboost_wrapper.cpp" />
</ItemGroup>
<ItemGroup>
<None Include="..\..\wrapper\xgboost.py" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>

View File

@ -1,126 +0,0 @@
# include xgboost library, must set chdir=TRURE
source("../xgboost.R", chdir=TRUE)
# helper function to read libsvm format
# this is very badly written, load in dense, and convert to sparse
# use this only for demo purpose
# adopted from https://github.com/zygmuntz/r-libsvm-format-read-write/blob/master/f_read.libsvm.r
read.libsvm <- function(fname, maxcol) {
content <- readLines(fname)
nline <- length(content)
label <- numeric(nline)
mat <- matrix(0, nline, maxcol+1)
for (i in 1:nline) {
arr <- as.vector(strsplit(content[i], " ")[[1]])
label[i] <- as.numeric(arr[[1]])
for (j in 2:length(arr)) {
kv <- strsplit(arr[j], ":")[[1]]
# to avoid 0 index
findex <- as.integer(kv[1]) + 1
fvalue <- as.numeric(kv[2])
mat[i,findex] <- fvalue
}
}
mat <- as(mat, "sparseMatrix")
return(list(label=label, data=mat))
}
# test code here
dtrain <- xgb.DMatrix("agaricus.txt.train")
dtest <- xgb.DMatrix("agaricus.txt.test")
param = list("bst:max_depth"=2, "bst:eta"=1, "silent"=1, "objective"="binary:logistic")
watchlist <- list("eval"=dtest,"train"=dtrain)
# training xgboost model
bst <- xgb.train(param, dtrain, nround=2, watchlist=watchlist)
# make prediction
preds <- xgb.predict(bst, dtest)
labels <- xgb.getinfo(dtest, "label")
err <- as.numeric(sum(as.integer(preds > 0.5) != labels)) / length(labels)
# print error rate
print(paste("error=",err))
# dump model
xgb.dump(bst, "dump.raw.txt")
# dump model with feature map
xgb.dump(bst, "dump.nice.txt", "featmap.txt")
# save dmatrix into binary buffer
succ <- xgb.save(dtest, "dtest.buffer")
# save model into file
succ <- xgb.save(bst, "xgb.model")
# load model and data in
bst2 <- xgb.Booster(modelfile="xgb.model")
dtest2 <- xgb.DMatrix("dtest.buffer")
preds2 <- xgb.predict(bst2, dtest2)
# assert they are the same
stopifnot(sum(abs(preds2-preds)) == 0)
###
# build dmatrix from sparseMatrix
###
print ('start running example of build DMatrix from R.sparseMatrix')
csc <- read.libsvm("agaricus.txt.train", 126)
label <- csc$label
data <- csc$data
dtrain <- xgb.DMatrix(data, info=list(label=label) )
watchlist <- list("eval"=dtest,"train"=dtrain)
bst <- xgb.train(param, dtrain, nround=2, watchlist=watchlist)
###
# build dmatrix from dense matrix
###
print ('start running example of build DMatrix from R.Matrix')
mat = as.matrix(data)
dtrain <- xgb.DMatrix(mat, info=list(label=label) )
watchlist <- list("eval"=dtest,"train"=dtrain)
bst <- xgb.train(param, dtrain, nround=2, watchlist=watchlist)
###
# advanced: cutomsized loss function
#
print("start running example to used cutomized objective function")
# note: for customized objective function, we leave objective as default
# note: what we are getting is margin value in prediction
# you must know what you are doing
param <- list("bst:max_depth" = 2, "bst:eta" = 1, "silent" =1)
# user define objective function, given prediction, return gradient and second order gradient
# this is loglikelihood loss
logregobj <- function(preds, dtrain) {
labels <- xgb.getinfo(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels
hess <- preds * (1.0-preds)
return(list(grad=grad, hess=hess))
}
# user defined evaluation function, return a list(metric="metric-name", value="metric-value")
# NOTE: when you do customized loss function, the default prediction value is margin
# this may make buildin evalution metric not function properly
# for example, we are doing logistic loss, the prediction is score before logistic transformation
# the buildin evaluation error assumes input is after logistic transformation
# Take this in mind when you use the customization, and maybe you need write customized evaluation function
evalerror <- function(preds, dtrain) {
labels <- xgb.getinfo(dtrain, "label")
err <- as.numeric(sum(labels != (preds > 0.0))) / length(labels)
return(list(metric="error", value=err))
}
# training with customized objective, we can also do step by step training
# simply look at xgboost.py"s implementation of train
bst <- xgb.train(param, dtrain, nround=2, watchlist, logregobj, evalerror)
###
# advanced: start from a initial base prediction
#
print ("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
param = list("bst:max_depth"=2, "bst:eta"=1, "silent"=1, "objective"="binary:logistic")
# train xgboost for 1 round
bst <- xgb.train( param, dtrain, 1, watchlist )
# Note: we need the margin value instead of transformed prediction in set_base_margin
# do predict with output_margin=True, will always give you margin values before logistic transformation
ptrain <- xgb.predict(bst, dtrain, outputmargin=TRUE)
ptest <- xgb.predict(bst, dtest, outputmargin=TRUE)
succ <- xgb.setinfo(dtrain, "base_margin", ptrain)
succ <- xgb.setinfo(dtest, "base_margin", ptest)
print ("this is result of running from initial prediction")
bst <- xgb.train( param, dtrain, 1, watchlist )

View File

@ -10,6 +10,4 @@ Python
R
=====
* To make the R wrapper, type ```make R``` in the root directory of project
* R module need Rinternals.h, find the path in your system and add it to CPLUS_INCLUDE_PATH in Makefile
* Refer to the walk through example in [R-example/demo.R](R-example/demo.R)
* See ../R-package

View File

@ -1,222 +0,0 @@
# depends on matrix
succ <- require("Matrix")
if (!succ) {
stop("xgboost depends on Matrix library")
}
# load in library
dyn.load("./libxgboostR.so")
# constructing DMatrix
xgb.DMatrix <- function(data, info=list(), missing=0.0) {
if (typeof(data) == "character") {
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE))
} else if(is.matrix(data)) {
handle <- .Call("XGDMatrixCreateFromMat_R", data, missing)
} else if(class(data) == "dgCMatrix") {
handle <- .Call("XGDMatrixCreateFromCSC_R", data@p, data@i, data@x)
} else {
stop(paste("xgb.DMatrix: does not support to construct from ", typeof(data)))
}
dmat <- structure(handle, class="xgb.DMatrix")
if (length(info) != 0) {
for (i in 1:length(info)) {
p <- info[i]
xgb.setinfo(dmat, names(p), p[[1]])
}
}
return(dmat)
}
# get information from dmatrix
xgb.getinfo <- function(dmat, name) {
if (typeof(name) != "character") {
stop("xgb.getinfo: name must be character")
}
if (class(dmat) != "xgb.DMatrix") {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix");
}
if (name != "label" &&
name != "weight" &&
name != "base_margin" ) {
stop(paste("xgb.getinfo: unknown info name", name))
}
ret <- .Call("XGDMatrixGetInfo_R", dmat, name)
return(ret)
}
# set information into dmatrix, this mutate dmatrix
xgb.setinfo <- function(dmat, name, info) {
if (class(dmat) != "xgb.DMatrix") {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix");
}
if (name == "label") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info))
return(TRUE)
}
if (name == "weight") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info))
return(TRUE)
}
if (name == "base_margin") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info))
return(TRUE)
}
if (name == "group") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info))
return(TRUE)
}
stop(pase("xgb.setinfo: unknown info name", name))
return(FALSE)
}
# construct a Booster from cachelist
xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
if (typeof(cachelist) != "list") {
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
}
for (dm in cachelist) {
if (class(dm) != "xgb.DMatrix") {
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
}
}
handle <- .Call("XGBoosterCreate_R", cachelist)
.Call("XGBoosterSetParam_R", handle, "seed", "0")
if (length(params) != 0) {
for (i in 1:length(params)) {
p <- params[i]
.Call("XGBoosterSetParam_R", handle, names(p), as.character(p))
}
}
if (!is.null(modelfile)) {
if (typeof(modelfile) != "character"){
stop("xgb.Booster: modelfile must be character");
}
.Call("XGBoosterLoadModel_R", handle, modelfile)
}
return(structure(handle, class="xgb.Booster"))
}
# train a model using given parameters
xgb.train <- function(params, dtrain, nrounds=10, watchlist=list(), obj=NULL, feval=NULL) {
if (typeof(params) != "list") {
stop("xgb.train: first argument params must be list");
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.train: second argument dtrain must be xgb.DMatrix");
}
bst <- xgb.Booster(params, append(watchlist,dtrain))
for (i in 1:nrounds) {
if (is.null(obj)) {
succ <- xgb.iter.update(bst, dtrain, i-1)
} else {
pred <- xgb.predict(bst, dtrain)
gpair <- obj(pred, dtrain)
succ <- xgb.iter.boost(bst, dtrain, gpair)
}
if (length(watchlist) != 0) {
if (is.null(feval)) {
msg <- xgb.iter.eval(bst, watchlist, i-1)
cat(msg); cat("\n")
} else {
cat("["); cat(i); cat("]");
for (j in 1:length(watchlist)) {
w <- watchlist[j]
if (length(names(w)) == 0) {
stop("xgb.eval: name tag must be presented for every elements in watchlist")
}
ret <- feval(xgb.predict(bst, w[[1]]), w[[1]])
cat("\t"); cat(names(w)); cat("-"); cat(ret$metric);
cat(":"); cat(ret$value)
}
cat("\n")
}
}
}
return(bst)
}
# save model or DMatrix to file
xgb.save <- function(handle, fname) {
if (typeof(fname) != "character") {
stop("xgb.save: fname must be character");
}
if (class(handle) == "xgb.Booster") {
.Call("XGBoosterSaveModel_R", handle, fname);
return(TRUE)
}
if (class(handle) == "xgb.DMatrix") {
.Call("XGDMatrixSaveBinary_R", handle, fname, as.integer(FALSE))
return(TRUE)
}
stop("xgb.save: the input must be either xgb.DMatrix or xgb.Booster")
return(FALSE)
}
# predict
xgb.predict <- function(booster, dmat, outputmargin = FALSE) {
if (class(booster) != "xgb.Booster") {
stop("xgb.predict: first argument must be type xgb.Booster")
}
if (class(dmat) != "xgb.DMatrix") {
stop("xgb.predict: second argument must be type xgb.DMatrix")
}
ret <- .Call("XGBoosterPredict_R", booster, dmat, as.integer(outputmargin))
return(ret)
}
# dump model
xgb.dump <- function(booster, fname, fmap = "") {
if (class(booster) != "xgb.Booster") {
stop("xgb.dump: first argument must be type xgb.Booster")
}
if (typeof(fname) != "character"){
stop("xgb.dump: second argument must be type character")
}
.Call("XGBoosterDumpModel_R", booster, fname, fmap)
return(TRUE)
}
##--------------------------------------
# the following are low level iteratively function, not needed
# if you do not want to use them
#---------------------------------------
# iteratively update booster with dtrain
xgb.iter.update <- function(booster, dtrain, iter) {
if (class(booster) != "xgb.Booster") {
stop("xgb.iter.update: first argument must be type xgb.Booster")
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
}
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain)
return(TRUE)
}
# iteratively update booster with customized statistics
xgb.iter.boost <- function(booster, dtrain, gpair) {
if (class(booster) != "xgb.Booster") {
stop("xgb.iter.update: first argument must be type xgb.Booster")
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
}
.Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess)
return(TRUE)
}
# iteratively evaluate one iteration
xgb.iter.eval <- function(booster, watchlist, iter) {
if (class(booster) != "xgb.Booster") {
stop("xgb.eval: first argument must be type xgb.Booster")
}
if (typeof(watchlist) != "list") {
stop("xgb.eval: only accepts list of DMatrix as watchlist")
}
for (w in watchlist) {
if (class(w) != "xgb.DMatrix") {
stop("xgb.eval: watch list can only contain xgb.DMatrix")
}
}
evnames <- list()
if (length(watchlist) != 0) {
for (i in 1:length(watchlist)) {
w <- watchlist[i]
if (length(names(w)) == 0) {
stop("xgb.eval: name tag must be presented for every elements in watchlist")
}
evnames <- append(evnames, names(w))
}
}
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames)
return(msg)
}

View File

@ -12,7 +12,7 @@ import scipy.sparse as scp
if os.name == 'nt':
XGBOOST_PATH = os.path.dirname(__file__)+'/../windows/x64/Release/xgboost_wrapper.dll'
else:
XGBOOST_PATH = os.path.dirname(__file__)+'/../libxgboostwrapper.so'
XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostwrapper.so'
# load in xgboost library
xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH)

View File

@ -1,221 +0,0 @@
#include <vector>
#include <string>
#include <utility>
#include <cstring>
#include "xgboost_R.h"
#include "xgboost_wrapper.h"
#include "../src/utils/utils.h"
#include "../src/utils/omp.h"
#include "../src/utils/matrix_csr.h"
using namespace xgboost;
// implements error handling
namespace xgboost {
namespace utils {
void HandleAssertError(const char *msg) {
error("%s", msg);
}
void HandleCheckError(const char *msg) {
error("%s", msg);
}
} // namespace utils
} // namespace xgboost
extern "C" {
void _DMatrixFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return;
XGDMatrixFree(R_ExternalPtrAddr(ext));
R_ClearExternalPtr(ext);
}
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
SEXP XGDMatrixCreateFromMat_R(SEXP mat,
SEXP missing) {
SEXP dim = getAttrib(mat, R_DimSymbol);
int nrow = INTEGER(dim)[0];
int ncol = INTEGER(dim)[1];
double *din = REAL(mat);
std::vector<float> data(nrow * ncol);
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
data[i * ncol +j] = din[i + nrow * j];
}
}
void *handle = XGDMatrixCreateFromMat(&data[0], nrow, ncol, asReal(missing));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
SEXP indices,
SEXP data) {
const int *col_ptr = INTEGER(indptr);
const int *row_index = INTEGER(indices);
const double *col_data = REAL(data);
int ncol = length(indptr) - 1;
int ndata = length(data);
// transform into CSR format
std::vector<bst_ulong> row_ptr;
std::vector< std::pair<unsigned, float> > csr_data;
utils::SparseCSRMBuilder<std::pair<unsigned,float>, false, bst_ulong> builder(row_ptr, csr_data);
builder.InitBudget();
for (int i = 0; i < ncol; ++i) {
for (int j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
builder.AddBudget(row_index[j]);
}
}
builder.InitStorage();
for (int i = 0; i < ncol; ++i) {
for (int j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
builder.PushElem(row_index[j], std::make_pair(i, col_data[j]));
}
}
utils::Assert(csr_data.size() == static_cast<size_t>(ndata), "BUG CreateFromCSC");
std::vector<float> row_data(ndata);
std::vector<unsigned> col_index(ndata);
#pragma omp parallel for schedule(static)
for (int i = 0; i < ndata; ++i) {
col_index[i] = csr_data[i].first;
row_data[i] = csr_data[i].second;
}
void *handle = XGDMatrixCreateFromCSR(&row_ptr[0], &col_index[0], &row_data[0], row_ptr.size(), ndata );
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)), asInteger(silent));
}
void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
int len = length(array);
const char *name = CHAR(asChar(field));
if (!strcmp("group", name)) {
std::vector<unsigned> vec(len);
#pragma omp parallel for schedule(static)
for (int i = 0; i < len; ++i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
}
XGDMatrixSetGroup(R_ExternalPtrAddr(handle), &vec[0], len);
return;
}
{
std::vector<float> vec(len);
#pragma omp parallel for schedule(static)
for (int i = 0; i < len; ++i) {
vec[i] = REAL(array)[i];
}
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
&vec[0], len);
}
}
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
bst_ulong olen;
const float *res = XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)), &olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];
}
UNPROTECT(1);
return ret;
}
// functions related to booster
void _BoosterFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return;
XGBoosterFree(R_ExternalPtrAddr(ext));
R_ClearExternalPtr(ext);
}
SEXP XGBoosterCreate_R(SEXP dmats) {
int len = length(dmats);
std::vector<void*> dvec;
for (int i = 0; i < len; ++i){
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
void *handle = XGBoosterCreate(&dvec[0], dvec.size());
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val)));
}
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain));
}
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
utils::Check(length(grad) == length(hess), "gradient and hess must have same length");
int len = length(grad);
std::vector<float> tgrad(len), thess(len);
#pragma omp parallel for schedule(static)
for (int j = 0; j < len; ++j) {
tgrad[j] = REAL(grad)[j];
thess[j] = REAL(hess)[j];
}
XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain),
&tgrad[0], &thess[0], len);
}
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
int len = length(dmats);
std::vector<void*> vec_dmats;
std::vector<std::string> vec_names;
std::vector<const char*> vec_sptr;
for (int i = 0; i < len; ++i) {
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
}
for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str());
}
return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
&vec_dmats[0], &vec_sptr[0], len));
}
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
bst_ulong olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(output_margin),
&olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];
}
UNPROTECT(1);
return ret;
}
void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
}
void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
}
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap) {
bst_ulong olen;
const char **res = XGBoosterDumpModel(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)),
&olen);
FILE *fo = utils::FopenCheck(CHAR(asChar(fname)), "w");
for (size_t i = 0; i < olen; ++i) {
fprintf(fo, "booster[%u]:\n", static_cast<unsigned>(i));
fprintf(fo, "%s", res[i]);
}
fclose(fo);
}
}

View File

@ -1,124 +0,0 @@
#ifndef XGBOOST_WRAPPER_R_H_
#define XGBOOST_WRAPPER_R_H_
/*!
* \file xgboost_wrapper_R.h
* \author Tianqi Chen
* \brief R wrapper of xgboost
*/
extern "C" {
#include <Rinternals.h>
}
extern "C" {
/*!
* \brief load a data matrix
* \param fname name of the content
* \param silent whether print messages
* \return a loaded data matrix
*/
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
/*!
* \brief create matrix content from dense matrix
* This assumes the matrix is stored in column major format
* \param data R Matrix object
* \param missing which value to represent missing value
* \return created dmatrix
*/
SEXP XGDMatrixCreateFromMat_R(SEXP mat,
SEXP missing);
/*!
* \brief create a matrix content from CSC format
* \param indptr pointer to column headers
* \param indices row indices
* \param data content of the data
* \return created dmatrix
*/
SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
SEXP indices,
SEXP data);
/*!
* \brief load a data matrix into binary file
* \param handle a instance of data matrix
* \param fname file name
* \param silent print statistics when saving
*/
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
/*!
* \brief set information to dmatrix
* \param handle a instance of data matrix
* \param field field name, can be label, weight
* \param array pointer to float vector
*/
void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
/*!
* \brief get info vector from matrix
* \param handle a instance of data matrix
* \param field field name
* \return info vector
*/
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field);
/*!
* \brief create xgboost learner
* \param dmats a list of dmatrix handles that will be cached
*/
SEXP XGBoosterCreate_R(SEXP dmats);
/*!
* \brief set parameters
* \param handle handle
* \param name parameter name
* \param val value of parameter
*/
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
/*!
* \brief update the model in one round using dtrain
* \param handle handle
* \param iter current iteration rounds
* \param dtrain training data
*/
void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
/*!
* \brief update the model, by directly specify gradient and second order gradient,
* this can be used to replace UpdateOneIter, to support customized loss function
* \param handle handle
* \param dtrain training data
* \param grad gradient statistics
* \param hess second order gradient statistics
*/
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);
/*!
* \brief get evaluation statistics for xgboost
* \param handle handle
* \param iter current iteration rounds
* \param dmats list of handles to dmatrices
* \param evname name of evaluation
* \return the string containing evaluation stati
*/
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);
/*!
* \brief make prediction based on dmat
* \param handle handle
* \param dmat data matrix
* \param output_margin whether only output raw margin value
*/
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin);
/*!
* \brief load model from existing file
* \param handle handle
* \param fname file name
*/
void XGBoosterLoadModel_R(SEXP handle, SEXP fname);
/*!
* \brief save model into existing file
* \param handle handle
* \param fname file name
*/
void XGBoosterSaveModel_R(SEXP handle, SEXP fname);
/*!
* \brief dump model into text file
* \param handle handle
* \param fname file name of model that can be dumped into
* \param fmap name to fmap can be empty string
*/
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap);
};
#endif // XGBOOST_WRAPPER_R_H_

View File

@ -1,4 +1,6 @@
// implementations in ctypes
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#include <cstdio>
#include <vector>
#include <string>
@ -16,7 +18,7 @@ using namespace xgboost::io;
namespace xgboost {
namespace wrapper {
// booster wrapper class
class Booster: public learner::BoostLearner<FMatrixS> {
class Booster: public learner::BoostLearner {
public:
explicit Booster(const std::vector<DataMatrix*>& mats) {
this->silent = 1;
@ -25,8 +27,8 @@ class Booster: public learner::BoostLearner<FMatrixS> {
}
const float *Pred(const DataMatrix &dmat, int output_margin, bst_ulong *len) {
this->CheckInitModel();
this->Predict(dmat, output_margin, &this->preds_);
*len = this->preds_.size();
this->Predict(dmat, output_margin != 0, &this->preds_);
*len = static_cast<bst_ulong>(this->preds_.size());
return &this->preds_[0];
}
inline void BoostOneIter(const DataMatrix &train,
@ -37,7 +39,7 @@ class Booster: public learner::BoostLearner<FMatrixS> {
for (bst_omp_uint j = 0; j < ndata; ++j) {
gpair_[j] = bst_gpair(grad[j], hess[j]);
}
gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
gbm_->DoBoost(train.fmat(), train.info.info, &gpair_);
}
inline void CheckInitModel(void) {
if (!init_model) {
@ -45,7 +47,7 @@ class Booster: public learner::BoostLearner<FMatrixS> {
}
}
inline void LoadModel(const char *fname) {
learner::BoostLearner<FMatrixS>::LoadModel(fname);
learner::BoostLearner::LoadModel(fname);
this->init_model = true;
}
inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, bst_ulong *len) {
@ -54,7 +56,7 @@ class Booster: public learner::BoostLearner<FMatrixS> {
for (size_t i = 0; i < model_dump.size(); ++i) {
model_dump_cptr[i] = model_dump[i].c_str();
}
*len = model_dump.size();
*len = static_cast<bst_ulong>(model_dump.size());
return &model_dump_cptr[0];
}
// temporal fields
@ -74,7 +76,7 @@ using namespace xgboost::wrapper;
extern "C"{
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
return LoadDataMatrix(fname, silent, false);
return LoadDataMatrix(fname, silent != 0, false);
}
void* XGDMatrixCreateFromCSR(const bst_ulong *indptr,
const unsigned *indices,
@ -89,7 +91,7 @@ extern "C"{
}
mat.row_data_.resize(nelem);
for (bst_ulong i = 0; i < nelem; ++i) {
mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]);
mat.row_data_[i] = RowBatch::Entry(indices[i], data[i]);
mat.info.info.num_col = std::max(mat.info.info.num_col,
static_cast<size_t>(indices[i]+1));
}
@ -108,7 +110,7 @@ extern "C"{
bst_ulong nelem = 0;
for (bst_ulong j = 0; j < ncol; ++j) {
if (data[j] != missing) {
mat.row_data_.push_back(SparseBatch::Entry(j, data[j]));
mat.row_data_.push_back(RowBatch::Entry(j, data[j]));
++nelem;
}
}
@ -135,17 +137,17 @@ extern "C"{
ret.info.info.num_row = len;
ret.info.info.num_col = src.info.num_col();
utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator();
utils::IIterator<RowBatch> *iter = src.fmat()->RowIterator();
iter->BeforeFirst();
utils::Assert(iter->Next(), "slice");
const SparseBatch &batch = iter->Value();
const RowBatch &batch = iter->Value();
for (bst_ulong i = 0; i < len; ++i) {
const int ridx = idxset[i];
SparseBatch::Inst inst = batch[ridx];
RowBatch::Inst inst = batch[ridx];
utils::Check(static_cast<bst_ulong>(ridx) < batch.size, "slice index exceed number of rows");
ret.row_data_.resize(ret.row_data_.size() + inst.length);
memcpy(&ret.row_data_[ret.row_ptr_.back()], inst.data,
sizeof(SparseBatch::Entry) * inst.length);
sizeof(RowBatch::Entry) * inst.length);
ret.row_ptr_.push_back(ret.row_ptr_.back() + inst.length);
if (src.info.labels.size() != 0) {
ret.info.labels.push_back(src.info.labels[ridx]);
@ -156,6 +158,9 @@ extern "C"{
if (src.info.info.root_index.size() != 0) {
ret.info.info.root_index.push_back(src.info.info.root_index[ridx]);
}
if (src.info.info.fold_index.size() != 0) {
ret.info.info.fold_index.push_back(src.info.info.fold_index[ridx]);
}
}
return p_ret;
}
@ -163,7 +168,7 @@ extern "C"{
delete static_cast<DataMatrix*>(handle);
}
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent) {
SaveDataMatrix(*static_cast<DataMatrix*>(handle), fname, silent);
SaveDataMatrix(*static_cast<DataMatrix*>(handle), fname, silent != 0);
}
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, bst_ulong len) {
std::vector<float> &vec =
@ -181,24 +186,24 @@ extern "C"{
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
pmat->info.group_ptr.resize(len + 1);
pmat->info.group_ptr[0] = 0;
for (bst_ulong i = 0; i < len; ++i) {
for (uint64_t i = 0; i < len; ++i) {
pmat->info.group_ptr[i+1] = pmat->info.group_ptr[i]+group[i];
}
}
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, bst_ulong* len) {
const std::vector<float> &vec =
static_cast<const DataMatrix*>(handle)->info.GetFloatInfo(field);
*len = vec.size();
*len = static_cast<bst_ulong>(vec.size());
return &vec[0];
}
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, bst_ulong* len) {
const std::vector<unsigned> &vec =
static_cast<const DataMatrix*>(handle)->info.GetUIntInfo(field);
*len = vec.size();
*len = static_cast<bst_ulong>(vec.size());
return &vec[0];
}
bst_ulong XGDMatrixNumRow(const void *handle) {
return static_cast<const DataMatrix*>(handle)->info.num_row();
return static_cast<bst_ulong>(static_cast<const DataMatrix*>(handle)->info.num_row());
}
// xgboost implementation

View File

@ -15,6 +15,7 @@
// manually define unsign long
typedef unsigned long bst_ulong;
extern "C" {
/*!
* \brief load a data matrix