[GBM] remove need to explicit InitModel, rename save/load
This commit is contained in:
parent
82ceb4de0a
commit
4b4b36d047
@ -32,21 +32,16 @@ class GradientBooster {
|
||||
* \param cfg configurations on both training and model parameters.
|
||||
*/
|
||||
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
|
||||
/*!
|
||||
* \brief Initialize the model.
|
||||
* User need to call Configure before calling InitModel.
|
||||
*/
|
||||
virtual void InitModel() = 0;
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream.
|
||||
*/
|
||||
virtual void LoadModel(dmlc::Stream* fi) = 0;
|
||||
virtual void Load(dmlc::Stream* fi) = 0;
|
||||
/*!
|
||||
* \brief save model to stream.
|
||||
* \param fo output stream
|
||||
*/
|
||||
virtual void SaveModel(dmlc::Stream* fo) const = 0;
|
||||
virtual void Save(dmlc::Stream* fo) const = 0;
|
||||
/*!
|
||||
* \brief reset the predict buffer size.
|
||||
* This will invalidate all the previous cached results
|
||||
|
||||
@ -304,7 +304,7 @@ class TreeModel {
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
*/
|
||||
inline void LoadModel(dmlc::Stream* fi) {
|
||||
inline void Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
nodes.resize(param.num_nodes);
|
||||
stats.resize(param.num_nodes);
|
||||
@ -327,7 +327,7 @@ class TreeModel {
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
inline void SaveModel(dmlc::Stream* fo) const {
|
||||
inline void Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(nodes.size()));
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats.size()));
|
||||
fo->Write(¶m, sizeof(TreeParam));
|
||||
|
||||
@ -90,18 +90,20 @@ class GBLinear : public GradientBooster {
|
||||
}
|
||||
param.InitAllowUnknown(cfg);
|
||||
}
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
model.LoadModel(fi);
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
model.Load(fi);
|
||||
}
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
model.SaveModel(fo);
|
||||
}
|
||||
void InitModel() override {
|
||||
model.InitModel();
|
||||
void Save(dmlc::Stream* fo) const override {
|
||||
model.Save(fo);
|
||||
}
|
||||
virtual void DoBoost(DMatrix *p_fmat,
|
||||
int64_t buffer_offset,
|
||||
std::vector<bst_gpair> *in_gpair) {
|
||||
// lazily initialize the model when not ready.
|
||||
if (model.weight.size() == 0) {
|
||||
model.InitModel();
|
||||
}
|
||||
|
||||
std::vector<bst_gpair> &gpair = *in_gpair;
|
||||
const int ngroup = model.param.num_output_group;
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
@ -248,12 +250,12 @@ class GBLinear : public GradientBooster {
|
||||
std::fill(weight.begin(), weight.end(), 0.0f);
|
||||
}
|
||||
// save the model to file
|
||||
inline void SaveModel(dmlc::Stream* fo) const {
|
||||
inline void Save(dmlc::Stream* fo) const {
|
||||
fo->Write(¶m, sizeof(param));
|
||||
fo->Write(weight);
|
||||
}
|
||||
// load model from file
|
||||
inline void LoadModel(dmlc::Stream* fi) {
|
||||
inline void Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param));
|
||||
fi->Read(&weight);
|
||||
}
|
||||
|
||||
@ -52,8 +52,8 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
int num_feature;
|
||||
/*! \brief pad this space, for backward compatiblity reason.*/
|
||||
int pad_32bit;
|
||||
/*! \brief size of prediction buffer allocated used for buffering */
|
||||
int64_t num_pbuffer;
|
||||
/*! \brief deprecated padding space. */
|
||||
int64_t num_pbuffer_deprecated;
|
||||
/*!
|
||||
* \brief how many output group a single instance can produce
|
||||
* this affects the behavior of number of output we have:
|
||||
@ -82,24 +82,13 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
|
||||
.describe("Reserved option for vector tree.");
|
||||
}
|
||||
/*! \return size of prediction buffer actually needed */
|
||||
inline size_t PredBufferSize() const {
|
||||
return num_output_group * num_pbuffer * (size_leaf_vector + 1);
|
||||
}
|
||||
/*!
|
||||
* \brief get the buffer offset given a buffer index and group id
|
||||
* \return calculated buffer offset
|
||||
*/
|
||||
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
|
||||
if (buffer_index < 0) return -1;
|
||||
CHECK_LT(buffer_index, num_pbuffer);
|
||||
return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
|
||||
}
|
||||
};
|
||||
|
||||
// gradient boosted trees
|
||||
class GBTree : public GradientBooster {
|
||||
public:
|
||||
GBTree() : num_pbuffer(0) {}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
|
||||
this->cfg = cfg;
|
||||
// initialize model parameters if not yet been initialized.
|
||||
@ -118,13 +107,13 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
||||
<< "GBTree: invalid model file";
|
||||
trees.clear();
|
||||
for (int i = 0; i < mparam.num_trees; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->LoadModel(fi);
|
||||
ptr->Load(fi);
|
||||
trees.push_back(std::move(ptr));
|
||||
}
|
||||
tree_info.resize(mparam.num_trees);
|
||||
@ -132,38 +121,27 @@ class GBTree : public GradientBooster {
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
|
||||
sizeof(int) * mparam.num_trees);
|
||||
}
|
||||
this->ResetPredBuffer(0);
|
||||
// clear the predict buffer.
|
||||
this->ResetPredBuffer(num_pbuffer);
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
void Save(dmlc::Stream* fo) const override {
|
||||
CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
|
||||
// not save predict buffer.
|
||||
GBTreeModelParam p = mparam;
|
||||
p.num_pbuffer = 0;
|
||||
fo->Write(&p, sizeof(p));
|
||||
fo->Write(&mparam, sizeof(mparam));
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->SaveModel(fo);
|
||||
trees[i]->Save(fo);
|
||||
}
|
||||
if (tree_info.size() != 0) {
|
||||
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
|
||||
}
|
||||
}
|
||||
|
||||
void InitModel() override {
|
||||
CHECK(mparam.num_trees == 0 && trees.size() == 0)
|
||||
<< "Model has already been initialized.";
|
||||
pred_buffer.clear();
|
||||
pred_counter.clear();
|
||||
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
|
||||
pred_counter.resize(mparam.PredBufferSize(), 0);
|
||||
}
|
||||
|
||||
void ResetPredBuffer(size_t num_pbuffer) override {
|
||||
mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer);
|
||||
this->num_pbuffer = num_pbuffer;
|
||||
pred_buffer.clear();
|
||||
pred_counter.clear();
|
||||
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
|
||||
pred_counter.resize(mparam.PredBufferSize(), 0);
|
||||
pred_buffer.resize(this->PredBufferSize(), 0.0f);
|
||||
pred_counter.resize(this->PredBufferSize(), 0);
|
||||
}
|
||||
|
||||
bool AllowLazyCheckPoint() const override {
|
||||
@ -348,7 +326,7 @@ class GBTree : public GradientBooster {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
const int64_t bid = mparam.BufferOffset(buffer_offset + ridx, bst_group);
|
||||
const int64_t bid = this->BufferOffset(buffer_offset + ridx, bst_group);
|
||||
const int tid = leaf_position[ridx];
|
||||
CHECK_EQ(pred_counter[bid], trees.size());
|
||||
CHECK_GE(tid, 0);
|
||||
@ -372,7 +350,7 @@ class GBTree : public GradientBooster {
|
||||
float psum = 0.0f;
|
||||
// sum of leaf vector
|
||||
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
|
||||
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
|
||||
const int64_t bid = this->BufferOffset(buffer_index, bst_group);
|
||||
// number of valid trees
|
||||
unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
|
||||
// load buffered results if any
|
||||
@ -452,6 +430,20 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \return size of prediction buffer actually needed */
|
||||
inline size_t PredBufferSize() const {
|
||||
return mparam.num_output_group * num_pbuffer * (mparam.size_leaf_vector + 1);
|
||||
}
|
||||
/*!
|
||||
* \brief get the buffer offset given a buffer index and group id
|
||||
* \return calculated buffer offset
|
||||
*/
|
||||
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
|
||||
if (buffer_index < 0) return -1;
|
||||
size_t bidx = static_cast<size_t>(buffer_index);
|
||||
CHECK_LT(bidx, num_pbuffer);
|
||||
return (bidx + num_pbuffer * bst_group) * (mparam.size_leaf_vector + 1);
|
||||
}
|
||||
|
||||
// --- data structure ---
|
||||
// training parameter
|
||||
@ -462,6 +454,8 @@ class GBTree : public GradientBooster {
|
||||
std::vector<std::unique_ptr<RegTree> > trees;
|
||||
/*! \brief some information indicator of the tree, reserved */
|
||||
std::vector<int> tree_info;
|
||||
/*! \brief predict buffer size */
|
||||
size_t num_pbuffer;
|
||||
/*! \brief prediction buffer */
|
||||
std::vector<float> pred_buffer;
|
||||
/*! \brief prediction buffer counter, remember the prediction */
|
||||
|
||||
@ -29,13 +29,13 @@ class TreeSyncher: public TreeUpdater {
|
||||
int rank = rabit::GetRank();
|
||||
if (rank == 0) {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->SaveModel(&fs);
|
||||
trees[i]->Save(&fs);
|
||||
}
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->LoadModel(&fs);
|
||||
trees[i]->Load(&fs);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user