add predict leaf indices
This commit is contained in:
parent
6ed82edad7
commit
168bb0d0c9
@ -32,6 +32,8 @@ This is a list of short codes introducing different functionalities of xgboost a
|
|||||||
[python](guide-python/cross_validation.py)
|
[python](guide-python/cross_validation.py)
|
||||||
[R](../R-package/demo/cross_validation.R)
|
[R](../R-package/demo/cross_validation.R)
|
||||||
[Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/cross_validation.jl)
|
[Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/cross_validation.jl)
|
||||||
|
* Predicting leaf indices
|
||||||
|
[python](guide-python/predict_leaf_indices.py)
|
||||||
|
|
||||||
Basic Examples by Tasks
|
Basic Examples by Tasks
|
||||||
====
|
====
|
||||||
|
|||||||
@ -6,3 +6,4 @@ XGBoost Python Feature Walkthrough
|
|||||||
* [Predicting using first n trees](predict_first_ntree.py)
|
* [Predicting using first n trees](predict_first_ntree.py)
|
||||||
* [Generalized Linear Model](generalized_linear_model.py)
|
* [Generalized Linear Model](generalized_linear_model.py)
|
||||||
* [Cross validation](cross_validation.py)
|
* [Cross validation](cross_validation.py)
|
||||||
|
* [Predicting leaf indices](predict_leaf_indices.py)
|
||||||
|
|||||||
22
demo/guide-python/predict_leaf_indices.py
Executable file
22
demo/guide-python/predict_leaf_indices.py
Executable file
@ -0,0 +1,22 @@
|
|||||||
|
#!/usr/bin/python
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
sys.path.append('../../wrapper')
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
### load data in do training
|
||||||
|
dtrain = xgb.DMatrix('../data/agaricus.txt.train')
|
||||||
|
dtest = xgb.DMatrix('../data/agaricus.txt.test')
|
||||||
|
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||||
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
|
num_round = 3
|
||||||
|
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
|
||||||
|
print ('start testing predict the leaf indices')
|
||||||
|
### predict using first 2 tree
|
||||||
|
leafindex = bst.predict(dtest, ntree_limit=2, pred_leaf = True)
|
||||||
|
print leafindex.shape
|
||||||
|
print leafindex
|
||||||
|
### predict all trees
|
||||||
|
leafindex = bst.predict(dtest, pred_leaf = True)
|
||||||
|
print leafindex.shape
|
||||||
@ -4,4 +4,5 @@ python custom_objective.py
|
|||||||
python boost_from_prediction.py
|
python boost_from_prediction.py
|
||||||
python generalized_linear_model.py
|
python generalized_linear_model.py
|
||||||
python cross_validation.py
|
python cross_validation.py
|
||||||
|
python predict_leaf_index.py
|
||||||
rm -rf *~ *.model *.buffer
|
rm -rf *~ *.model *.buffer
|
||||||
@ -135,6 +135,12 @@ class GBLinear : public IGradBooster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
virtual void PredictLeaf(IFMatrix *p_fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit = 0) {
|
||||||
|
utils::Error("gblinear does not support predict leaf index");
|
||||||
|
}
|
||||||
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||||
utils::Error("gblinear does not support dump model");
|
utils::Error("gblinear does not support dump model");
|
||||||
return std::vector<std::string>();
|
return std::vector<std::string>();
|
||||||
|
|||||||
@ -74,6 +74,20 @@ class IGradBooster {
|
|||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
std::vector<float> *out_preds,
|
std::vector<float> *out_preds,
|
||||||
unsigned ntree_limit = 0) = 0;
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
|
||||||
|
* this is only valid in gbtree predictor
|
||||||
|
* \param p_fmat feature matrix
|
||||||
|
* \param info extra side information that may be needed for prediction
|
||||||
|
* \param out_preds output vector to hold the predictions
|
||||||
|
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
|
||||||
|
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
||||||
|
*/
|
||||||
|
virtual void PredictLeaf(IFMatrix *p_fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit = 0) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief dump the model in text format
|
* \brief dump the model in text format
|
||||||
* \param fmap feature map that may help give interpretations of feature
|
* \param fmap feature map that may help give interpretations of feature
|
||||||
|
|||||||
@ -126,11 +126,6 @@ class GBTree : public IGradBooster {
|
|||||||
for (int i = 0; i < nthread; ++i) {
|
for (int i = 0; i < nthread; ++i) {
|
||||||
thread_temp[i].Init(mparam.num_feature);
|
thread_temp[i].Init(mparam.num_feature);
|
||||||
}
|
}
|
||||||
if (tparam.pred_path != 0) {
|
|
||||||
this->PredPath(p_fmat, info, out_preds);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> &preds = *out_preds;
|
std::vector<float> &preds = *out_preds;
|
||||||
const size_t stride = info.num_row * mparam.num_output_group;
|
const size_t stride = info.num_row * mparam.num_output_group;
|
||||||
preds.resize(stride * (mparam.size_leaf_vector+1));
|
preds.resize(stride * (mparam.size_leaf_vector+1));
|
||||||
@ -158,6 +153,22 @@ class GBTree : public IGradBooster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
virtual void PredictLeaf(IFMatrix *p_fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit) {
|
||||||
|
int nthread;
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
nthread = omp_get_num_threads();
|
||||||
|
}
|
||||||
|
thread_temp.resize(nthread, tree::RegTree::FVec());
|
||||||
|
for (int i = 0; i < nthread; ++i) {
|
||||||
|
thread_temp[i].Init(mparam.num_feature);
|
||||||
|
}
|
||||||
|
this->PredPath(p_fmat, info, out_preds, ntree_limit);
|
||||||
|
|
||||||
|
}
|
||||||
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||||
std::vector<std::string> dump;
|
std::vector<std::string> dump;
|
||||||
for (size_t i = 0; i < trees.size(); i++) {
|
for (size_t i = 0; i < trees.size(); i++) {
|
||||||
@ -309,9 +320,14 @@ class GBTree : public IGradBooster {
|
|||||||
// predict independent leaf index
|
// predict independent leaf index
|
||||||
inline void PredPath(IFMatrix *p_fmat,
|
inline void PredPath(IFMatrix *p_fmat,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
std::vector<float> *out_preds) {
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit) {
|
||||||
|
// number of valid trees
|
||||||
|
if (ntree_limit == 0 || ntree_limit > trees.size()) {
|
||||||
|
ntree_limit = trees.size();
|
||||||
|
}
|
||||||
std::vector<float> &preds = *out_preds;
|
std::vector<float> &preds = *out_preds;
|
||||||
preds.resize(info.num_row * mparam.num_trees);
|
preds.resize(info.num_row * ntree_limit);
|
||||||
// start collecting the prediction
|
// start collecting the prediction
|
||||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
@ -325,9 +341,9 @@ class GBTree : public IGradBooster {
|
|||||||
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||||
tree::RegTree::FVec &feats = thread_temp[tid];
|
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||||
feats.Fill(batch[i]);
|
feats.Fill(batch[i]);
|
||||||
for (size_t j = 0; j < trees.size(); ++j) {
|
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||||
int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
|
int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
|
||||||
preds[ridx * mparam.num_trees + j] = static_cast<float>(tid);
|
preds[ridx * ntree_limit + j] = static_cast<float>(tid);
|
||||||
}
|
}
|
||||||
feats.Drop(batch[i]);
|
feats.Drop(batch[i]);
|
||||||
}
|
}
|
||||||
@ -344,8 +360,6 @@ class GBTree : public IGradBooster {
|
|||||||
* use this option to support boosted random forest
|
* use this option to support boosted random forest
|
||||||
*/
|
*/
|
||||||
int num_parallel_tree;
|
int num_parallel_tree;
|
||||||
/*! \brief predict path in prediction */
|
|
||||||
int pred_path;
|
|
||||||
/*! \brief whether updater is already initialized */
|
/*! \brief whether updater is already initialized */
|
||||||
int updater_initialized;
|
int updater_initialized;
|
||||||
/*! \brief tree updater sequence */
|
/*! \brief tree updater sequence */
|
||||||
@ -356,7 +370,6 @@ class GBTree : public IGradBooster {
|
|||||||
updater_seq = "grow_colmaker,prune";
|
updater_seq = "grow_colmaker,prune";
|
||||||
num_parallel_tree = 1;
|
num_parallel_tree = 1;
|
||||||
updater_initialized = 0;
|
updater_initialized = 0;
|
||||||
pred_path = 0;
|
|
||||||
}
|
}
|
||||||
inline void SetParam(const char *name, const char *val){
|
inline void SetParam(const char *name, const char *val){
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -371,7 +384,6 @@ class GBTree : public IGradBooster {
|
|||||||
if (!strcmp(name, "num_parallel_tree")) {
|
if (!strcmp(name, "num_parallel_tree")) {
|
||||||
num_parallel_tree = atoi(val);
|
num_parallel_tree = atoi(val);
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "pred_path")) pred_path = atoi(val);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
/*! \brief model parameters */
|
/*! \brief model parameters */
|
||||||
|
|||||||
@ -280,12 +280,18 @@ class BoostLearner {
|
|||||||
inline void Predict(const DMatrix &data,
|
inline void Predict(const DMatrix &data,
|
||||||
bool output_margin,
|
bool output_margin,
|
||||||
std::vector<float> *out_preds,
|
std::vector<float> *out_preds,
|
||||||
unsigned ntree_limit = 0) const {
|
unsigned ntree_limit = 0,
|
||||||
|
bool pred_leaf = false
|
||||||
|
) const {
|
||||||
|
if (pred_leaf) {
|
||||||
|
gbm_->PredictLeaf(data.fmat(), data.info.info, out_preds, ntree_limit);
|
||||||
|
} else {
|
||||||
this->PredictRaw(data, out_preds, ntree_limit);
|
this->PredictRaw(data, out_preds, ntree_limit);
|
||||||
if (!output_margin) {
|
if (!output_margin) {
|
||||||
obj_->PredTransform(out_preds);
|
obj_->PredTransform(out_preds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
/*! \brief dump model out */
|
/*! \brief dump model out */
|
||||||
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||||
return gbm_->DumpModel(fmap, option);
|
return gbm_->DumpModel(fmap, option);
|
||||||
|
|||||||
@ -333,7 +333,7 @@ class Booster:
|
|||||||
return res
|
return res
|
||||||
def eval(self, mat, name = 'eval', it = 0):
|
def eval(self, mat, name = 'eval', it = 0):
|
||||||
return self.eval_set( [(mat,name)], it)
|
return self.eval_set( [(mat,name)], it)
|
||||||
def predict(self, data, output_margin=False, ntree_limit=0):
|
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
|
||||||
"""
|
"""
|
||||||
predict with data
|
predict with data
|
||||||
Args:
|
Args:
|
||||||
@ -343,13 +343,28 @@ class Booster:
|
|||||||
whether output raw margin value that is untransformed
|
whether output raw margin value that is untransformed
|
||||||
ntree_limit: int
|
ntree_limit: int
|
||||||
limit number of trees in prediction, default to 0, 0 means using all the trees
|
limit number of trees in prediction, default to 0, 0 means using all the trees
|
||||||
|
pred_leaf: bool
|
||||||
|
when this option is on, the output will be a matrix of (nsample, ntrees)
|
||||||
|
with each record indicate the predicted leaf index of each sample in each tree
|
||||||
|
Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0
|
||||||
Returns:
|
Returns:
|
||||||
numpy array of prediction
|
numpy array of prediction
|
||||||
"""
|
"""
|
||||||
|
option_mask = 0
|
||||||
|
if output_margin:
|
||||||
|
option_mask += 1
|
||||||
|
if pred_leaf:
|
||||||
|
option_mask += 2
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||||
int(output_margin), ntree_limit, ctypes.byref(length))
|
option_mask, ntree_limit, ctypes.byref(length))
|
||||||
return ctypes2numpy(preds, length.value, 'float32')
|
preds = ctypes2numpy(preds, length.value, 'float32')
|
||||||
|
if pred_leaf:
|
||||||
|
preds = preds.astype('int32')
|
||||||
|
nrow = data.num_row()
|
||||||
|
if preds.size != nrow and preds.size % nrow == 0:
|
||||||
|
preds = preds.reshape(nrow, preds.size / nrow)
|
||||||
|
return preds
|
||||||
def save_model(self, fname):
|
def save_model(self, fname):
|
||||||
""" save model to file
|
""" save model to file
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -30,9 +30,9 @@ class Booster: public learner::BoostLearner {
|
|||||||
this->init_model = false;
|
this->init_model = false;
|
||||||
this->SetCacheData(mats);
|
this->SetCacheData(mats);
|
||||||
}
|
}
|
||||||
inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
|
inline const float *Pred(const DataMatrix &dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) {
|
||||||
this->CheckInitModel();
|
this->CheckInitModel();
|
||||||
this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
|
this->Predict(dmat, (option_mask&1) != 0, &this->preds_, ntree_limit, (option_mask&2) != 0);
|
||||||
*len = static_cast<bst_ulong>(this->preds_.size());
|
*len = static_cast<bst_ulong>(this->preds_.size());
|
||||||
return BeginPtr(this->preds_);
|
return BeginPtr(this->preds_);
|
||||||
}
|
}
|
||||||
@ -284,8 +284,8 @@ extern "C"{
|
|||||||
bst->eval_str = bst->EvalOneIter(iter, mats, names);
|
bst->eval_str = bst->EvalOneIter(iter, mats, names);
|
||||||
return bst->eval_str.c_str();
|
return bst->eval_str.c_str();
|
||||||
}
|
}
|
||||||
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
|
const float *XGBoosterPredict(void *handle, void *dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) {
|
||||||
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, ntree_limit, len);
|
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), option_mask, ntree_limit, len);
|
||||||
}
|
}
|
||||||
void XGBoosterLoadModel(void *handle, const char *fname) {
|
void XGBoosterLoadModel(void *handle, const char *fname) {
|
||||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||||
|
|||||||
@ -178,12 +178,18 @@ extern "C" {
|
|||||||
* \brief make prediction based on dmat
|
* \brief make prediction based on dmat
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \param dmat data matrix
|
* \param dmat data matrix
|
||||||
* \param output_margin whether only output raw margin value
|
* \param option_mask bit-mask of options taken in prediction, possible values
|
||||||
|
* 0:normal prediction
|
||||||
|
* 1:output margin instead of transformed value
|
||||||
|
* 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
|
||||||
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
|
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
|
||||||
* when the parameter is set to 0, we will use all the trees
|
* when the parameter is set to 0, we will use all the trees
|
||||||
* \param len used to store length of returning result
|
* \param len used to store length of returning result
|
||||||
*/
|
*/
|
||||||
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len);
|
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat,
|
||||||
|
int option_mask,
|
||||||
|
unsigned ntree_limit,
|
||||||
|
bst_ulong *len);
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from existing file
|
* \brief load model from existing file
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user