[GBM] remove need to explicitly call InitModel, rename save/load

This commit is contained in:
tqchen 2016-01-03 05:31:00 -08:00
parent 82ceb4de0a
commit 4b4b36d047
5 changed files with 50 additions and 59 deletions

View File

@ -32,21 +32,16 @@ class GradientBooster {
* \param cfg configurations on both training and model parameters. * \param cfg configurations on both training and model parameters.
*/ */
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0; virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
/*!
* \brief Initialize the model.
* User need to call Configure before calling InitModel.
*/
virtual void InitModel() = 0;
/*! /*!
* \brief load model from stream * \brief load model from stream
* \param fi input stream. * \param fi input stream.
*/ */
virtual void LoadModel(dmlc::Stream* fi) = 0; virtual void Load(dmlc::Stream* fi) = 0;
/*! /*!
* \brief save model to stream. * \brief save model to stream.
* \param fo output stream * \param fo output stream
*/ */
virtual void SaveModel(dmlc::Stream* fo) const = 0; virtual void Save(dmlc::Stream* fo) const = 0;
/*! /*!
* \brief reset the predict buffer size. * \brief reset the predict buffer size.
* This will invalidate all the previous cached results * This will invalidate all the previous cached results

View File

@ -304,7 +304,7 @@ class TreeModel {
* \brief load model from stream * \brief load model from stream
* \param fi input stream * \param fi input stream
*/ */
inline void LoadModel(dmlc::Stream* fi) { inline void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam)); CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
nodes.resize(param.num_nodes); nodes.resize(param.num_nodes);
stats.resize(param.num_nodes); stats.resize(param.num_nodes);
@ -327,7 +327,7 @@ class TreeModel {
* \brief save model to stream * \brief save model to stream
* \param fo output stream * \param fo output stream
*/ */
inline void SaveModel(dmlc::Stream* fo) const { inline void Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes.size())); CHECK_EQ(param.num_nodes, static_cast<int>(nodes.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats.size())); CHECK_EQ(param.num_nodes, static_cast<int>(stats.size()));
fo->Write(&param, sizeof(TreeParam)); fo->Write(&param, sizeof(TreeParam));

View File

@ -90,18 +90,20 @@ class GBLinear : public GradientBooster {
} }
param.InitAllowUnknown(cfg); param.InitAllowUnknown(cfg);
} }
void LoadModel(dmlc::Stream* fi) override { void Load(dmlc::Stream* fi) override {
model.LoadModel(fi); model.Load(fi);
} }
void SaveModel(dmlc::Stream* fo) const override { void Save(dmlc::Stream* fo) const override {
model.SaveModel(fo); model.Save(fo);
}
void InitModel() override {
model.InitModel();
} }
virtual void DoBoost(DMatrix *p_fmat, virtual void DoBoost(DMatrix *p_fmat,
int64_t buffer_offset, int64_t buffer_offset,
std::vector<bst_gpair> *in_gpair) { std::vector<bst_gpair> *in_gpair) {
// lazily initialize the model when not ready.
if (model.weight.size() == 0) {
model.InitModel();
}
std::vector<bst_gpair> &gpair = *in_gpair; std::vector<bst_gpair> &gpair = *in_gpair;
const int ngroup = model.param.num_output_group; const int ngroup = model.param.num_output_group;
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset(); const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
@ -248,12 +250,12 @@ class GBLinear : public GradientBooster {
std::fill(weight.begin(), weight.end(), 0.0f); std::fill(weight.begin(), weight.end(), 0.0f);
} }
// save the model to file // save the model to file
inline void SaveModel(dmlc::Stream* fo) const { inline void Save(dmlc::Stream* fo) const {
fo->Write(&param, sizeof(param)); fo->Write(&param, sizeof(param));
fo->Write(weight); fo->Write(weight);
} }
// load model from file // load model from file
inline void LoadModel(dmlc::Stream* fi) { inline void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param)); CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param));
fi->Read(&weight); fi->Read(&weight);
} }

View File

@ -52,8 +52,8 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
int num_feature; int num_feature;
/*! \brief pad this space, for backward compatibility reason.*/ /*! \brief pad this space, for backward compatibility reason.*/
int pad_32bit; int pad_32bit;
/*! \brief size of prediction buffer allocated used for buffering */ /*! \brief deprecated padding space. */
int64_t num_pbuffer; int64_t num_pbuffer_deprecated;
/*! /*!
* \brief how many output group a single instance can produce * \brief how many output group a single instance can produce
* this affects the behavior of number of output we have: * this affects the behavior of number of output we have:
@ -82,24 +82,13 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
.describe("Reserved option for vector tree."); .describe("Reserved option for vector tree.");
} }
/*! \return size of prediction buffer actually needed */
inline size_t PredBufferSize() const {
return num_output_group * num_pbuffer * (size_leaf_vector + 1);
}
/*!
* \brief get the buffer offset given a buffer index and group id
* \return calculated buffer offset
*/
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
if (buffer_index < 0) return -1;
CHECK_LT(buffer_index, num_pbuffer);
return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
}
}; };
// gradient boosted trees // gradient boosted trees
class GBTree : public GradientBooster { class GBTree : public GradientBooster {
public: public:
GBTree() : num_pbuffer(0) {}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override { void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
this->cfg = cfg; this->cfg = cfg;
// initialize model parameters if not yet been initialized. // initialize model parameters if not yet been initialized.
@ -118,13 +107,13 @@ class GBTree : public GradientBooster {
} }
} }
void LoadModel(dmlc::Stream* fi) override { void Load(dmlc::Stream* fi) override {
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam)) CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
<< "GBTree: invalid model file"; << "GBTree: invalid model file";
trees.clear(); trees.clear();
for (int i = 0; i < mparam.num_trees; ++i) { for (int i = 0; i < mparam.num_trees; ++i) {
std::unique_ptr<RegTree> ptr(new RegTree()); std::unique_ptr<RegTree> ptr(new RegTree());
ptr->LoadModel(fi); ptr->Load(fi);
trees.push_back(std::move(ptr)); trees.push_back(std::move(ptr));
} }
tree_info.resize(mparam.num_trees); tree_info.resize(mparam.num_trees);
@ -132,38 +121,27 @@ class GBTree : public GradientBooster {
CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees), CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
sizeof(int) * mparam.num_trees); sizeof(int) * mparam.num_trees);
} }
this->ResetPredBuffer(0); // clear the predict buffer.
this->ResetPredBuffer(num_pbuffer);
} }
void SaveModel(dmlc::Stream* fo) const override { void Save(dmlc::Stream* fo) const override {
CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size())); CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
// not save predict buffer. fo->Write(&mparam, sizeof(mparam));
GBTreeModelParam p = mparam;
p.num_pbuffer = 0;
fo->Write(&p, sizeof(p));
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fo); trees[i]->Save(fo);
} }
if (tree_info.size() != 0) { if (tree_info.size() != 0) {
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size()); fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
} }
} }
void InitModel() override {
CHECK(mparam.num_trees == 0 && trees.size() == 0)
<< "Model has already been initialized.";
pred_buffer.clear();
pred_counter.clear();
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0);
}
void ResetPredBuffer(size_t num_pbuffer) override { void ResetPredBuffer(size_t num_pbuffer) override {
mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer); this->num_pbuffer = num_pbuffer;
pred_buffer.clear(); pred_buffer.clear();
pred_counter.clear(); pred_counter.clear();
pred_buffer.resize(mparam.PredBufferSize(), 0.0f); pred_buffer.resize(this->PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0); pred_counter.resize(this->PredBufferSize(), 0);
} }
bool AllowLazyCheckPoint() const override { bool AllowLazyCheckPoint() const override {
@ -348,7 +326,7 @@ class GBTree : public GradientBooster {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
const int64_t bid = mparam.BufferOffset(buffer_offset + ridx, bst_group); const int64_t bid = this->BufferOffset(buffer_offset + ridx, bst_group);
const int tid = leaf_position[ridx]; const int tid = leaf_position[ridx];
CHECK_EQ(pred_counter[bid], trees.size()); CHECK_EQ(pred_counter[bid], trees.size());
CHECK_GE(tid, 0); CHECK_GE(tid, 0);
@ -372,7 +350,7 @@ class GBTree : public GradientBooster {
float psum = 0.0f; float psum = 0.0f;
// sum of leaf vector // sum of leaf vector
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f); std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group); const int64_t bid = this->BufferOffset(buffer_index, bst_group);
// number of valid trees // number of valid trees
unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit; unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
// load buffered results if any // load buffered results if any
@ -452,6 +430,20 @@ class GBTree : public GradientBooster {
} }
} }
} }
/*! \return size of prediction buffer actually needed */
inline size_t PredBufferSize() const {
return mparam.num_output_group * num_pbuffer * (mparam.size_leaf_vector + 1);
}
/*!
* \brief get the buffer offset given a buffer index and group id
* \return calculated buffer offset
*/
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
if (buffer_index < 0) return -1;
size_t bidx = static_cast<size_t>(buffer_index);
CHECK_LT(bidx, num_pbuffer);
return (bidx + num_pbuffer * bst_group) * (mparam.size_leaf_vector + 1);
}
// --- data structure --- // --- data structure ---
// training parameter // training parameter
@ -462,6 +454,8 @@ class GBTree : public GradientBooster {
std::vector<std::unique_ptr<RegTree> > trees; std::vector<std::unique_ptr<RegTree> > trees;
/*! \brief some information indicator of the tree, reserved */ /*! \brief some information indicator of the tree, reserved */
std::vector<int> tree_info; std::vector<int> tree_info;
/*! \brief predict buffer size */
size_t num_pbuffer;
/*! \brief prediction buffer */ /*! \brief prediction buffer */
std::vector<float> pred_buffer; std::vector<float> pred_buffer;
/*! \brief prediction buffer counter, remember the prediction */ /*! \brief prediction buffer counter, remember the prediction */

View File

@ -29,13 +29,13 @@ class TreeSyncher: public TreeUpdater {
int rank = rabit::GetRank(); int rank = rabit::GetRank();
if (rank == 0) { if (rank == 0) {
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(&fs); trees[i]->Save(&fs);
} }
} }
fs.Seek(0); fs.Seek(0);
rabit::Broadcast(&s_model, 0); rabit::Broadcast(&s_model, 0);
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->LoadModel(&fs); trees[i]->Load(&fs);
} }
} }
}; };