add predict leaf indices

This commit is contained in:
tqchen 2014-11-21 09:32:09 -08:00
parent 6ed82edad7
commit 168bb0d0c9
11 changed files with 114 additions and 29 deletions

View File

@ -32,6 +32,8 @@ This is a list of short codes introducing different functionalities of xgboost a
[python](guide-python/cross_validation.py)
[R](../R-package/demo/cross_validation.R)
[Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/cross_validation.jl)
* Predicting leaf indices
[python](guide-python/predict_leaf_indices.py)
Basic Examples by Tasks
====

View File

@ -6,3 +6,4 @@ XGBoost Python Feature Walkthrough
* [Predicting using first n trees](predict_first_ntree.py)
* [Generalized Linear Model](generalized_linear_model.py)
* [Cross validation](cross_validation.py)
* [Predicting leaf indices](predict_leaf_indices.py)

View File

@ -0,0 +1,22 @@
#!/usr/bin/python
import sys
import numpy as np
sys.path.append('../../wrapper')
import xgboost as xgb
### load data in do training
dtrain = xgb.DMatrix('../data/agaricus.txt.train')
dtest = xgb.DMatrix('../data/agaricus.txt.test')
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 3
bst = xgb.train(param, dtrain, num_round, watchlist)
print ('start testing predict the leaf indices')
### predict using first 2 tree
leafindex = bst.predict(dtest, ntree_limit=2, pred_leaf = True)
print leafindex.shape
print leafindex
### predict all trees
leafindex = bst.predict(dtest, pred_leaf = True)
print leafindex.shape

View File

@ -4,4 +4,5 @@ python custom_objective.py
python boost_from_prediction.py
python generalized_linear_model.py
python cross_validation.py
python predict_leaf_index.py
rm -rf *~ *.model *.buffer

View File

@ -135,6 +135,12 @@ class GBLinear : public IGradBooster {
}
}
}
virtual void PredictLeaf(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<float> *out_preds,
unsigned ntree_limit = 0) {
utils::Error("gblinear does not support predict leaf index");
}
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
utils::Error("gblinear does not support dump model");
return std::vector<std::string>();

View File

@ -74,6 +74,20 @@ class IGradBooster {
const BoosterInfo &info,
std::vector<float> *out_preds,
unsigned ntree_limit = 0) = 0;
/*!
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
* this is only valid in gbtree predictor
* \param p_fmat feature matrix
* \param info extra side information that may be needed for prediction
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void PredictLeaf(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<float> *out_preds,
unsigned ntree_limit = 0) = 0;
/*!
* \brief dump the model in text format
* \param fmap feature map that may help give interpretations of feature

View File

@ -126,11 +126,6 @@ class GBTree : public IGradBooster {
for (int i = 0; i < nthread; ++i) {
thread_temp[i].Init(mparam.num_feature);
}
if (tparam.pred_path != 0) {
this->PredPath(p_fmat, info, out_preds);
return;
}
std::vector<float> &preds = *out_preds;
const size_t stride = info.num_row * mparam.num_output_group;
preds.resize(stride * (mparam.size_leaf_vector+1));
@ -158,6 +153,22 @@ class GBTree : public IGradBooster {
}
}
}
virtual void PredictLeaf(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<float> *out_preds,
unsigned ntree_limit) {
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
thread_temp.resize(nthread, tree::RegTree::FVec());
for (int i = 0; i < nthread; ++i) {
thread_temp[i].Init(mparam.num_feature);
}
this->PredPath(p_fmat, info, out_preds, ntree_limit);
}
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
std::vector<std::string> dump;
for (size_t i = 0; i < trees.size(); i++) {
@ -309,9 +320,14 @@ class GBTree : public IGradBooster {
// predict independent leaf index
inline void PredPath(IFMatrix *p_fmat,
const BoosterInfo &info,
std::vector<float> *out_preds) {
std::vector<float> *out_preds,
unsigned ntree_limit) {
// number of valid trees
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = trees.size();
}
std::vector<float> &preds = *out_preds;
preds.resize(info.num_row * mparam.num_trees);
preds.resize(info.num_row * ntree_limit);
// start collecting the prediction
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst();
@ -325,9 +341,9 @@ class GBTree : public IGradBooster {
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
tree::RegTree::FVec &feats = thread_temp[tid];
feats.Fill(batch[i]);
for (size_t j = 0; j < trees.size(); ++j) {
for (unsigned j = 0; j < ntree_limit; ++j) {
int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
preds[ridx * mparam.num_trees + j] = static_cast<float>(tid);
preds[ridx * ntree_limit + j] = static_cast<float>(tid);
}
feats.Drop(batch[i]);
}
@ -344,8 +360,6 @@ class GBTree : public IGradBooster {
* use this option to support boosted random forest
*/
int num_parallel_tree;
/*! \brief predict path in prediction */
int pred_path;
/*! \brief whether updater is already initialized */
int updater_initialized;
/*! \brief tree updater sequence */
@ -356,7 +370,6 @@ class GBTree : public IGradBooster {
updater_seq = "grow_colmaker,prune";
num_parallel_tree = 1;
updater_initialized = 0;
pred_path = 0;
}
inline void SetParam(const char *name, const char *val){
using namespace std;
@ -371,7 +384,6 @@ class GBTree : public IGradBooster {
if (!strcmp(name, "num_parallel_tree")) {
num_parallel_tree = atoi(val);
}
if (!strcmp(name, "pred_path")) pred_path = atoi(val);
}
};
/*! \brief model parameters */

View File

@ -280,12 +280,18 @@ class BoostLearner {
inline void Predict(const DMatrix &data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0) const {
unsigned ntree_limit = 0,
bool pred_leaf = false
) const {
if (pred_leaf) {
gbm_->PredictLeaf(data.fmat(), data.info.info, out_preds, ntree_limit);
} else {
this->PredictRaw(data, out_preds, ntree_limit);
if (!output_margin) {
obj_->PredTransform(out_preds);
}
}
}
/*! \brief dump model out */
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
return gbm_->DumpModel(fmap, option);

View File

@ -333,7 +333,7 @@ class Booster:
return res
def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it)
def predict(self, data, output_margin=False, ntree_limit=0):
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
"""
predict with data
Args:
@ -343,13 +343,28 @@ class Booster:
whether output raw margin value that is untransformed
ntree_limit: int
limit number of trees in prediction, default to 0, 0 means using all the trees
pred_leaf: bool
when this option is on, the output will be a matrix of (nsample, ntrees)
with each record indicate the predicted leaf index of each sample in each tree
Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0
Returns:
numpy array of prediction
"""
option_mask = 0
if output_margin:
option_mask += 1
if pred_leaf:
option_mask += 2
length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ntree_limit, ctypes.byref(length))
return ctypes2numpy(preds, length.value, 'float32')
option_mask, ntree_limit, ctypes.byref(length))
preds = ctypes2numpy(preds, length.value, 'float32')
if pred_leaf:
preds = preds.astype('int32')
nrow = data.num_row()
if preds.size != nrow and preds.size % nrow == 0:
preds = preds.reshape(nrow, preds.size / nrow)
return preds
def save_model(self, fname):
""" save model to file
Args:

View File

@ -30,9 +30,9 @@ class Booster: public learner::BoostLearner {
this->init_model = false;
this->SetCacheData(mats);
}
inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
inline const float *Pred(const DataMatrix &dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) {
this->CheckInitModel();
this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
this->Predict(dmat, (option_mask&1) != 0, &this->preds_, ntree_limit, (option_mask&2) != 0);
*len = static_cast<bst_ulong>(this->preds_.size());
return BeginPtr(this->preds_);
}
@ -284,8 +284,8 @@ extern "C"{
bst->eval_str = bst->EvalOneIter(iter, mats, names);
return bst->eval_str.c_str();
}
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, ntree_limit, len);
const float *XGBoosterPredict(void *handle, void *dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), option_mask, ntree_limit, len);
}
void XGBoosterLoadModel(void *handle, const char *fname) {
static_cast<Booster*>(handle)->LoadModel(fname);

View File

@ -178,12 +178,18 @@ extern "C" {
* \brief make prediction based on dmat
* \param handle handle
* \param dmat data matrix
* \param output_margin whether only output raw margin value
* \param option_mask bit-mask of options taken in prediction, possible values
* 0:normal prediction
* 1:output margin instead of transformed value
* 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
* when the parameter is set to 0, we will use all the trees
* \param len used to store length of returning result
*/
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len);
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat,
int option_mask,
unsigned ntree_limit,
bst_ulong *len);
/*!
* \brief load model from existing file
* \param handle handle