add predict leaf indices
This commit is contained in:
parent
6ed82edad7
commit
168bb0d0c9
@ -32,6 +32,8 @@ This is a list of short codes introducing different functionalities of xgboost a
|
||||
[python](guide-python/cross_validation.py)
|
||||
[R](../R-package/demo/cross_validation.R)
|
||||
[Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/cross_validation.jl)
|
||||
* Predicting leaf indices
|
||||
[python](guide-python/predict_leaf_indices.py)
|
||||
|
||||
Basic Examples by Tasks
|
||||
====
|
||||
|
||||
@ -6,3 +6,4 @@ XGBoost Python Feature Walkthrough
|
||||
* [Predicting using first n trees](predict_first_ntree.py)
|
||||
* [Generalized Linear Model](generalized_linear_model.py)
|
||||
* [Cross validation](cross_validation.py)
|
||||
* [Predicting leaf indices](predict_leaf_indices.py)
|
||||
|
||||
22
demo/guide-python/predict_leaf_indices.py
Executable file
22
demo/guide-python/predict_leaf_indices.py
Executable file
@ -0,0 +1,22 @@
|
||||
#!/usr/bin/python
|
||||
import sys
|
||||
import numpy as np
|
||||
sys.path.append('../../wrapper')
|
||||
import xgboost as xgb
|
||||
|
||||
### load data in do training
|
||||
dtrain = xgb.DMatrix('../data/agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('../data/agaricus.txt.test')
|
||||
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||
num_round = 3
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||
|
||||
print ('start testing predict the leaf indices')
|
||||
### predict using first 2 tree
|
||||
leafindex = bst.predict(dtest, ntree_limit=2, pred_leaf = True)
|
||||
print leafindex.shape
|
||||
print leafindex
|
||||
### predict all trees
|
||||
leafindex = bst.predict(dtest, pred_leaf = True)
|
||||
print leafindex.shape
|
||||
@ -4,4 +4,5 @@ python custom_objective.py
|
||||
python boost_from_prediction.py
|
||||
python generalized_linear_model.py
|
||||
python cross_validation.py
|
||||
python predict_leaf_index.py
|
||||
rm -rf *~ *.model *.buffer
|
||||
@ -135,6 +135,12 @@ class GBLinear : public IGradBooster {
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual void PredictLeaf(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0) {
|
||||
utils::Error("gblinear does not support predict leaf index");
|
||||
}
|
||||
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||
utils::Error("gblinear does not support dump model");
|
||||
return std::vector<std::string>();
|
||||
|
||||
@ -74,6 +74,20 @@ class IGradBooster {
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/*!
|
||||
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
|
||||
* this is only valid in gbtree predictor
|
||||
* \param p_fmat feature matrix
|
||||
* \param info extra side information that may be needed for prediction
|
||||
* \param out_preds output vector to hold the predictions
|
||||
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
|
||||
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
||||
*/
|
||||
virtual void PredictLeaf(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
/*!
|
||||
* \brief dump the model in text format
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
|
||||
@ -126,11 +126,6 @@ class GBTree : public IGradBooster {
|
||||
for (int i = 0; i < nthread; ++i) {
|
||||
thread_temp[i].Init(mparam.num_feature);
|
||||
}
|
||||
if (tparam.pred_path != 0) {
|
||||
this->PredPath(p_fmat, info, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<float> &preds = *out_preds;
|
||||
const size_t stride = info.num_row * mparam.num_output_group;
|
||||
preds.resize(stride * (mparam.size_leaf_vector+1));
|
||||
@ -158,6 +153,22 @@ class GBTree : public IGradBooster {
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual void PredictLeaf(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit) {
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
thread_temp.resize(nthread, tree::RegTree::FVec());
|
||||
for (int i = 0; i < nthread; ++i) {
|
||||
thread_temp[i].Init(mparam.num_feature);
|
||||
}
|
||||
this->PredPath(p_fmat, info, out_preds, ntree_limit);
|
||||
|
||||
}
|
||||
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||
std::vector<std::string> dump;
|
||||
for (size_t i = 0; i < trees.size(); i++) {
|
||||
@ -309,9 +320,14 @@ class GBTree : public IGradBooster {
|
||||
// predict independent leaf index
|
||||
inline void PredPath(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds) {
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit) {
|
||||
// number of valid trees
|
||||
if (ntree_limit == 0 || ntree_limit > trees.size()) {
|
||||
ntree_limit = trees.size();
|
||||
}
|
||||
std::vector<float> &preds = *out_preds;
|
||||
preds.resize(info.num_row * mparam.num_trees);
|
||||
preds.resize(info.num_row * ntree_limit);
|
||||
// start collecting the prediction
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
@ -325,9 +341,9 @@ class GBTree : public IGradBooster {
|
||||
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||
feats.Fill(batch[i]);
|
||||
for (size_t j = 0; j < trees.size(); ++j) {
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
|
||||
preds[ridx * mparam.num_trees + j] = static_cast<float>(tid);
|
||||
preds[ridx * ntree_limit + j] = static_cast<float>(tid);
|
||||
}
|
||||
feats.Drop(batch[i]);
|
||||
}
|
||||
@ -344,8 +360,6 @@ class GBTree : public IGradBooster {
|
||||
* use this option to support boosted random forest
|
||||
*/
|
||||
int num_parallel_tree;
|
||||
/*! \brief predict path in prediction */
|
||||
int pred_path;
|
||||
/*! \brief whether updater is already initialized */
|
||||
int updater_initialized;
|
||||
/*! \brief tree updater sequence */
|
||||
@ -356,7 +370,6 @@ class GBTree : public IGradBooster {
|
||||
updater_seq = "grow_colmaker,prune";
|
||||
num_parallel_tree = 1;
|
||||
updater_initialized = 0;
|
||||
pred_path = 0;
|
||||
}
|
||||
inline void SetParam(const char *name, const char *val){
|
||||
using namespace std;
|
||||
@ -371,7 +384,6 @@ class GBTree : public IGradBooster {
|
||||
if (!strcmp(name, "num_parallel_tree")) {
|
||||
num_parallel_tree = atoi(val);
|
||||
}
|
||||
if (!strcmp(name, "pred_path")) pred_path = atoi(val);
|
||||
}
|
||||
};
|
||||
/*! \brief model parameters */
|
||||
|
||||
@ -280,10 +280,16 @@ class BoostLearner {
|
||||
inline void Predict(const DMatrix &data,
|
||||
bool output_margin,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
this->PredictRaw(data, out_preds, ntree_limit);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
unsigned ntree_limit = 0,
|
||||
bool pred_leaf = false
|
||||
) const {
|
||||
if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data.fmat(), data.info.info, out_preds, ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, ntree_limit);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief dump model out */
|
||||
|
||||
@ -333,23 +333,38 @@ class Booster:
|
||||
return res
|
||||
def eval(self, mat, name = 'eval', it = 0):
|
||||
return self.eval_set( [(mat,name)], it)
|
||||
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
|
||||
"""
|
||||
predict with data
|
||||
Args:
|
||||
data: DMatrix
|
||||
the dmatrix storing the input
|
||||
the dmatrix storing the input
|
||||
output_margin: bool
|
||||
whether output raw margin value that is untransformed
|
||||
whether output raw margin value that is untransformed
|
||||
ntree_limit: int
|
||||
limit number of trees in prediction, default to 0, 0 means using all the trees
|
||||
limit number of trees in prediction, default to 0, 0 means using all the trees
|
||||
pred_leaf: bool
|
||||
when this option is on, the output will be a matrix of (nsample, ntrees)
|
||||
with each record indicate the predicted leaf index of each sample in each tree
|
||||
Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0
|
||||
Returns:
|
||||
numpy array of prediction
|
||||
"""
|
||||
option_mask = 0
|
||||
if output_margin:
|
||||
option_mask += 1
|
||||
if pred_leaf:
|
||||
option_mask += 2
|
||||
length = ctypes.c_ulong()
|
||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||
int(output_margin), ntree_limit, ctypes.byref(length))
|
||||
return ctypes2numpy(preds, length.value, 'float32')
|
||||
option_mask, ntree_limit, ctypes.byref(length))
|
||||
preds = ctypes2numpy(preds, length.value, 'float32')
|
||||
if pred_leaf:
|
||||
preds = preds.astype('int32')
|
||||
nrow = data.num_row()
|
||||
if preds.size != nrow and preds.size % nrow == 0:
|
||||
preds = preds.reshape(nrow, preds.size / nrow)
|
||||
return preds
|
||||
def save_model(self, fname):
|
||||
""" save model to file
|
||||
Args:
|
||||
|
||||
@ -30,9 +30,9 @@ class Booster: public learner::BoostLearner {
|
||||
this->init_model = false;
|
||||
this->SetCacheData(mats);
|
||||
}
|
||||
inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
|
||||
inline const float *Pred(const DataMatrix &dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) {
|
||||
this->CheckInitModel();
|
||||
this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
|
||||
this->Predict(dmat, (option_mask&1) != 0, &this->preds_, ntree_limit, (option_mask&2) != 0);
|
||||
*len = static_cast<bst_ulong>(this->preds_.size());
|
||||
return BeginPtr(this->preds_);
|
||||
}
|
||||
@ -284,8 +284,8 @@ extern "C"{
|
||||
bst->eval_str = bst->EvalOneIter(iter, mats, names);
|
||||
return bst->eval_str.c_str();
|
||||
}
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
|
||||
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, ntree_limit, len);
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) {
|
||||
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), option_mask, ntree_limit, len);
|
||||
}
|
||||
void XGBoosterLoadModel(void *handle, const char *fname) {
|
||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||
|
||||
@ -178,12 +178,18 @@ extern "C" {
|
||||
* \brief make prediction based on dmat
|
||||
* \param handle handle
|
||||
* \param dmat data matrix
|
||||
* \param output_margin whether only output raw margin value
|
||||
* \param option_mask bit-mask of options taken in prediction, possible values
|
||||
* 0:normal prediction
|
||||
* 1:output margin instead of transformed value
|
||||
* 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
|
||||
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
|
||||
* when the parameter is set to 0, we will use all the trees
|
||||
* \param len used to store length of returning result
|
||||
*/
|
||||
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len);
|
||||
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat,
|
||||
int option_mask,
|
||||
unsigned ntree_limit,
|
||||
bst_ulong *len);
|
||||
/*!
|
||||
* \brief load model from existing file
|
||||
* \param handle handle
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user