From 168bb0d0c9bf6a5251d5a60de37ff9993e22789b Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 21 Nov 2014 09:32:09 -0800 Subject: [PATCH] add predict leaf indices --- demo/README.md | 2 ++ demo/guide-python/README.md | 1 + demo/guide-python/predict_leaf_indices.py | 22 +++++++++++++ demo/guide-python/runall.sh | 1 + src/gbm/gblinear-inl.hpp | 6 ++++ src/gbm/gbm.h | 14 +++++++++ src/gbm/gbtree-inl.hpp | 38 +++++++++++++++-------- src/learner/learner-inl.hpp | 14 ++++++--- wrapper/xgboost.py | 27 ++++++++++++---- wrapper/xgboost_wrapper.cpp | 8 ++--- wrapper/xgboost_wrapper.h | 10 ++++-- 11 files changed, 114 insertions(+), 29 deletions(-) create mode 100755 demo/guide-python/predict_leaf_indices.py diff --git a/demo/README.md b/demo/README.md index bcc356712..56915a32e 100644 --- a/demo/README.md +++ b/demo/README.md @@ -32,6 +32,8 @@ This is a list of short codes introducing different functionalities of xgboost a [python](guide-python/cross_validation.py) [R](../R-package/demo/cross_validation.R) [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/cross_validation.jl) +* Predicting leaf indices + [python](guide-python/predict_leaf_indices.py) Basic Examples by Tasks ==== diff --git a/demo/guide-python/README.md b/demo/guide-python/README.md index 3625c40f5..bc1c219d0 100644 --- a/demo/guide-python/README.md +++ b/demo/guide-python/README.md @@ -6,3 +6,4 @@ XGBoost Python Feature Walkthrough * [Predicting using first n trees](predict_first_ntree.py) * [Generalized Linear Model](generalized_linear_model.py) * [Cross validation](cross_validation.py) +* [Predicting leaf indices](predict_leaf_indices.py) diff --git a/demo/guide-python/predict_leaf_indices.py b/demo/guide-python/predict_leaf_indices.py new file mode 100755 index 000000000..291ad1ee7 --- /dev/null +++ b/demo/guide-python/predict_leaf_indices.py @@ -0,0 +1,22 @@ +#!/usr/bin/python +import sys +import numpy as np +sys.path.append('../../wrapper') +import xgboost as xgb + +### load data in do training +dtrain = xgb.DMatrix('../data/agaricus.txt.train') +dtest = xgb.DMatrix('../data/agaricus.txt.test') +param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' } +watchlist = [(dtest,'eval'), (dtrain,'train')] +num_round = 3 +bst = xgb.train(param, dtrain, num_round, watchlist) + +print ('start testing predict the leaf indices') +### predict using first 2 tree +leafindex = bst.predict(dtest, ntree_limit=2, pred_leaf = True) +print leafindex.shape +print leafindex +### predict all trees +leafindex = bst.predict(dtest, pred_leaf = True) +print leafindex.shape diff --git a/demo/guide-python/runall.sh b/demo/guide-python/runall.sh index 2dd2c20b0..5317186d5 100755 --- a/demo/guide-python/runall.sh +++ b/demo/guide-python/runall.sh @@ -4,4 +4,5 @@ python custom_objective.py python boost_from_prediction.py python generalized_linear_model.py python cross_validation.py +python predict_leaf_index.py rm -rf *~ *.model *.buffer \ No newline at end of file diff --git a/src/gbm/gblinear-inl.hpp b/src/gbm/gblinear-inl.hpp index cae5cf4f3..6d507ac6e 100644 --- a/src/gbm/gblinear-inl.hpp +++ b/src/gbm/gblinear-inl.hpp @@ -135,6 +135,12 @@ class GBLinear : public IGradBooster { } } } + virtual void PredictLeaf(IFMatrix *p_fmat, + const BoosterInfo &info, + std::vector *out_preds, + unsigned ntree_limit = 0) { + utils::Error("gblinear does not support predict leaf index"); + } virtual std::vector DumpModel(const utils::FeatMap& fmap, int option) { utils::Error("gblinear does not support dump model"); return std::vector(); diff --git a/src/gbm/gbm.h b/src/gbm/gbm.h index 28b370c48..f8eae6dbb 100644 --- a/src/gbm/gbm.h +++ b/src/gbm/gbm.h @@ -74,6 +74,20 @@ class IGradBooster { const BoosterInfo &info, std::vector *out_preds, unsigned ntree_limit = 0) = 0; + + /*! + * \brief predict the leaf index of each tree, the output will be nsample * ntree vector + * this is only valid in gbtree predictor + * \param p_fmat feature matrix + * \param info extra side information that may be needed for prediction + * \param out_preds output vector to hold the predictions + * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means + * we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear + */ + virtual void PredictLeaf(IFMatrix *p_fmat, + const BoosterInfo &info, + std::vector *out_preds, + unsigned ntree_limit = 0) = 0; /*! * \brief dump the model in text format * \param fmap feature map that may help give interpretations of feature diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp index d334296c8..b20acd48e 100644 --- a/src/gbm/gbtree-inl.hpp +++ b/src/gbm/gbtree-inl.hpp @@ -126,11 +126,6 @@ class GBTree : public IGradBooster { for (int i = 0; i < nthread; ++i) { thread_temp[i].Init(mparam.num_feature); } - if (tparam.pred_path != 0) { - this->PredPath(p_fmat, info, out_preds); - return; - } - std::vector &preds = *out_preds; const size_t stride = info.num_row * mparam.num_output_group; preds.resize(stride * (mparam.size_leaf_vector+1)); @@ -158,6 +153,22 @@ class GBTree : public IGradBooster { } } } + virtual void PredictLeaf(IFMatrix *p_fmat, + const BoosterInfo &info, + std::vector *out_preds, + unsigned ntree_limit) { + int nthread; + #pragma omp parallel + { + nthread = omp_get_num_threads(); + } + thread_temp.resize(nthread, tree::RegTree::FVec()); + for (int i = 0; i < nthread; ++i) { + thread_temp[i].Init(mparam.num_feature); + } + this->PredPath(p_fmat, info, out_preds, ntree_limit); + + } virtual std::vector DumpModel(const utils::FeatMap& fmap, int option) { std::vector dump; for (size_t i = 0; i < trees.size(); i++) { @@ -309,9 +320,14 @@ class GBTree : public IGradBooster { // predict independent leaf index inline void PredPath(IFMatrix *p_fmat, const BoosterInfo &info, - std::vector *out_preds) { + std::vector *out_preds, + unsigned ntree_limit) { + // number of valid trees + if (ntree_limit == 0 || ntree_limit > trees.size()) { + ntree_limit = trees.size(); + } std::vector &preds = *out_preds; - preds.resize(info.num_row * mparam.num_trees); + preds.resize(info.num_row * ntree_limit); // start collecting the prediction utils::IIterator *iter = p_fmat->RowIterator(); iter->BeforeFirst(); @@ -325,9 +341,9 @@ class GBTree : public IGradBooster { int64_t ridx = static_cast(batch.base_rowid + i); tree::RegTree::FVec &feats = thread_temp[tid]; feats.Fill(batch[i]); - for (size_t j = 0; j < trees.size(); ++j) { + for (unsigned j = 0; j < ntree_limit; ++j) { int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx)); - preds[ridx * mparam.num_trees + j] = static_cast(tid); + preds[ridx * ntree_limit + j] = static_cast(tid); } feats.Drop(batch[i]); } @@ -344,8 +360,6 @@ class GBTree : public IGradBooster { * use this option to support boosted random forest */ int num_parallel_tree; - /*! \brief predict path in prediction */ - int pred_path; /*! \brief whether updater is already initialized */ int updater_initialized; /*! \brief tree updater sequence */ @@ -356,7 +370,6 @@ class GBTree : public IGradBooster { updater_seq = "grow_colmaker,prune"; num_parallel_tree = 1; updater_initialized = 0; - pred_path = 0; } inline void SetParam(const char *name, const char *val){ using namespace std; @@ -371,7 +384,6 @@ class GBTree : public IGradBooster { if (!strcmp(name, "num_parallel_tree")) { num_parallel_tree = atoi(val); } - if (!strcmp(name, "pred_path")) pred_path = atoi(val); } }; /*! \brief model parameters */ diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 6b7440239..d16986e83 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -280,10 +280,16 @@ class BoostLearner { inline void Predict(const DMatrix &data, bool output_margin, std::vector *out_preds, - unsigned ntree_limit = 0) const { - this->PredictRaw(data, out_preds, ntree_limit); - if (!output_margin) { - obj_->PredTransform(out_preds); + unsigned ntree_limit = 0, + bool pred_leaf = false + ) const { + if (pred_leaf) { + gbm_->PredictLeaf(data.fmat(), data.info.info, out_preds, ntree_limit); + } else { + this->PredictRaw(data, out_preds, ntree_limit); + if (!output_margin) { + obj_->PredTransform(out_preds); + } } } /*! \brief dump model out */ diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index b549ddd8b..08aacb90e 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -333,23 +333,38 @@ class Booster: return res def eval(self, mat, name = 'eval', it = 0): return self.eval_set( [(mat,name)], it) - def predict(self, data, output_margin=False, ntree_limit=0): + def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False): """ predict with data Args: data: DMatrix - the dmatrix storing the input + the dmatrix storing the input output_margin: bool - whether output raw margin value that is untransformed + whether output raw margin value that is untransformed ntree_limit: int - limit number of trees in prediction, default to 0, 0 means using all the trees + limit number of trees in prediction, default to 0, 0 means using all the trees + pred_leaf: bool + when this option is on, the output will be a matrix of (nsample, ntrees) + with each record indicate the predicted leaf index of each sample in each tree + Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0 Returns: numpy array of prediction """ + option_mask = 0 + if output_margin: + option_mask += 1 + if pred_leaf: + option_mask += 2 length = ctypes.c_ulong() preds = xglib.XGBoosterPredict(self.handle, data.handle, - int(output_margin), ntree_limit, ctypes.byref(length)) - return ctypes2numpy(preds, length.value, 'float32') + option_mask, ntree_limit, ctypes.byref(length)) + preds = ctypes2numpy(preds, length.value, 'float32') + if pred_leaf: + preds = preds.astype('int32') + nrow = data.num_row() + if preds.size != nrow and preds.size % nrow == 0: + preds = preds.reshape(nrow, preds.size / nrow) + return preds def save_model(self, fname): """ save model to file Args: diff --git a/wrapper/xgboost_wrapper.cpp b/wrapper/xgboost_wrapper.cpp index ac054090c..d0efc4bd0 100644 --- a/wrapper/xgboost_wrapper.cpp +++ b/wrapper/xgboost_wrapper.cpp @@ -30,9 +30,9 @@ class Booster: public learner::BoostLearner { this->init_model = false; this->SetCacheData(mats); } - inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) { + inline const float *Pred(const DataMatrix &dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) { this->CheckInitModel(); - this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit); + this->Predict(dmat, (option_mask&1) != 0, &this->preds_, ntree_limit, (option_mask&2) != 0); *len = static_cast(this->preds_.size()); return BeginPtr(this->preds_); } @@ -284,8 +284,8 @@ extern "C"{ bst->eval_str = bst->EvalOneIter(iter, mats, names); return bst->eval_str.c_str(); } - const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) { - return static_cast(handle)->Pred(*static_cast(dmat), output_margin, ntree_limit, len); + const float *XGBoosterPredict(void *handle, void *dmat, int option_mask, unsigned ntree_limit, bst_ulong *len) { + return static_cast(handle)->Pred(*static_cast(dmat), option_mask, ntree_limit, len); } void XGBoosterLoadModel(void *handle, const char *fname) { static_cast(handle)->LoadModel(fname); diff --git a/wrapper/xgboost_wrapper.h b/wrapper/xgboost_wrapper.h index 2ae70f026..16d54f62b 100644 --- a/wrapper/xgboost_wrapper.h +++ b/wrapper/xgboost_wrapper.h @@ -178,12 +178,18 @@ extern "C" { * \brief make prediction based on dmat * \param handle handle * \param dmat data matrix - * \param output_margin whether only output raw margin value + * \param option_mask bit-mask of options taken in prediction, possible values + * 0:normal prediction + * 1:output margin instead of transformed value + * 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree * \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees * when the parameter is set to 0, we will use all the trees * \param len used to store length of returning result */ - XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len); + XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, + int option_mask, + unsigned ntree_limit, + bst_ulong *len); /*! * \brief load model from existing file * \param handle handle