[GBM] Remove the need to call InitModel explicitly; rename SaveModel/LoadModel to Save/Load
Parent: 82ceb4de0a
Commit: 4b4b36d047
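In short, a GradientBooster implementation no longer exposes InitModel(): it initializes its model lazily on the first DoBoost call, and the streaming methods are renamed from LoadModel/SaveModel to Load/Save. Below is a minimal, self-contained sketch of that pattern; DemoBooster and the std::iostream-based serialization are hypothetical stand-ins for illustration only, while the real interface uses dmlc::Stream as shown in the diff that follows.

#include <cstddef>
#include <iostream>
#include <vector>

// Sketch only: mirrors the lazy-init + Load/Save pattern this commit adopts.
class DemoBooster {
 public:
  // was LoadModel(fi) before this commit
  void Load(std::istream& fi) {
    std::size_t n = 0;
    fi.read(reinterpret_cast<char*>(&n), sizeof(n));
    weight.resize(n);
    fi.read(reinterpret_cast<char*>(weight.data()), n * sizeof(float));
  }
  // was SaveModel(fo) before this commit
  void Save(std::ostream& fo) const {
    std::size_t n = weight.size();
    fo.write(reinterpret_cast<const char*>(&n), sizeof(n));
    fo.write(reinterpret_cast<const char*>(weight.data()), n * sizeof(float));
  }
  void DoBoost(std::size_t num_feature) {
    // no separate InitModel() step: initialize lazily on the first boost,
    // just like the "if (model.weight.size() == 0) model.InitModel();" check below.
    if (weight.empty()) weight.assign(num_feature + 1, 0.0f);
    // ... gradient update would follow here ...
  }

 private:
  std::vector<float> weight;
};

The point of the lazy check is that a freshly constructed booster and one restored via Load() go through the same code path, so callers never have to remember a separate initialization call.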
@@ -32,21 +32,16 @@ class GradientBooster {
    * \param cfg configurations on both training and model parameters.
    */
   virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
-  /*!
-   * \brief Initialize the model.
-   *  User need to call Configure before calling InitModel.
-   */
-  virtual void InitModel() = 0;
   /*!
    * \brief load model from stream
    * \param fi input stream.
    */
-  virtual void LoadModel(dmlc::Stream* fi) = 0;
+  virtual void Load(dmlc::Stream* fi) = 0;
   /*!
    * \brief save model to stream.
    * \param fo output stream
    */
-  virtual void SaveModel(dmlc::Stream* fo) const = 0;
+  virtual void Save(dmlc::Stream* fo) const = 0;
   /*!
    * \brief reset the predict buffer size.
    *  This will invalidate all the previous cached results
@@ -304,7 +304,7 @@ class TreeModel {
    * \brief load model from stream
    * \param fi input stream
    */
-  inline void LoadModel(dmlc::Stream* fi) {
+  inline void Load(dmlc::Stream* fi) {
     CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
     nodes.resize(param.num_nodes);
     stats.resize(param.num_nodes);
@@ -327,7 +327,7 @@ class TreeModel {
    * \brief save model to stream
    * \param fo output stream
    */
-  inline void SaveModel(dmlc::Stream* fo) const {
+  inline void Save(dmlc::Stream* fo) const {
     CHECK_EQ(param.num_nodes, static_cast<int>(nodes.size()));
     CHECK_EQ(param.num_nodes, static_cast<int>(stats.size()));
     fo->Write(&param, sizeof(TreeParam));
@@ -90,18 +90,20 @@ class GBLinear : public GradientBooster {
     }
     param.InitAllowUnknown(cfg);
   }
-  void LoadModel(dmlc::Stream* fi) override {
-    model.LoadModel(fi);
+  void Load(dmlc::Stream* fi) override {
+    model.Load(fi);
   }
-  void SaveModel(dmlc::Stream* fo) const override {
-    model.SaveModel(fo);
-  }
-  void InitModel() override {
-    model.InitModel();
+  void Save(dmlc::Stream* fo) const override {
+    model.Save(fo);
   }
   virtual void DoBoost(DMatrix *p_fmat,
                        int64_t buffer_offset,
                        std::vector<bst_gpair> *in_gpair) {
+    // lazily initialize the model when not ready.
+    if (model.weight.size() == 0) {
+      model.InitModel();
+    }
+
     std::vector<bst_gpair> &gpair = *in_gpair;
     const int ngroup = model.param.num_output_group;
     const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
@@ -248,12 +250,12 @@ class GBLinear : public GradientBooster {
       std::fill(weight.begin(), weight.end(), 0.0f);
     }
     // save the model to file
-    inline void SaveModel(dmlc::Stream* fo) const {
+    inline void Save(dmlc::Stream* fo) const {
       fo->Write(&param, sizeof(param));
       fo->Write(weight);
     }
     // load model from file
-    inline void LoadModel(dmlc::Stream* fi) {
+    inline void Load(dmlc::Stream* fi) {
       CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param));
       fi->Read(&weight);
     }
@@ -52,8 +52,8 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
   int num_feature;
   /*! \brief pad this space, for backward compatiblity reason.*/
   int pad_32bit;
-  /*! \brief size of prediction buffer allocated used for buffering */
-  int64_t num_pbuffer;
+  /*! \brief deprecated padding space. */
+  int64_t num_pbuffer_deprecated;
   /*!
    * \brief how many output group a single instance can produce
    *  this affects the behavior of number of output we have:
@@ -82,24 +82,13 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
     DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
         .describe("Reserved option for vector tree.");
   }
-  /*! \return size of prediction buffer actually needed */
-  inline size_t PredBufferSize() const {
-    return num_output_group * num_pbuffer * (size_leaf_vector + 1);
-  }
-  /*!
-   * \brief get the buffer offset given a buffer index and group id
-   * \return calculated buffer offset
-   */
-  inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
-    if (buffer_index < 0) return -1;
-    CHECK_LT(buffer_index, num_pbuffer);
-    return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
-  }
 };

 // gradient boosted trees
 class GBTree : public GradientBooster {
  public:
+  GBTree() : num_pbuffer(0) {}
+
   void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
     this->cfg = cfg;
     // initialize model parameters if not yet been initialized.
@@ -118,13 +107,13 @@ class GBTree : public GradientBooster {
     }
   }

-  void LoadModel(dmlc::Stream* fi) override {
+  void Load(dmlc::Stream* fi) override {
     CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
         << "GBTree: invalid model file";
     trees.clear();
     for (int i = 0; i < mparam.num_trees; ++i) {
       std::unique_ptr<RegTree> ptr(new RegTree());
-      ptr->LoadModel(fi);
+      ptr->Load(fi);
       trees.push_back(std::move(ptr));
     }
     tree_info.resize(mparam.num_trees);
@@ -132,38 +121,27 @@ class GBTree : public GradientBooster {
       CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
               sizeof(int) * mparam.num_trees);
     }
-    this->ResetPredBuffer(0);
+    // clear the predict buffer.
+    this->ResetPredBuffer(num_pbuffer);
   }

-  void SaveModel(dmlc::Stream* fo) const override {
+  void Save(dmlc::Stream* fo) const override {
     CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
-    // not save predict buffer.
-    GBTreeModelParam p = mparam;
-    p.num_pbuffer = 0;
-    fo->Write(&p, sizeof(p));
+    fo->Write(&mparam, sizeof(mparam));
     for (size_t i = 0; i < trees.size(); ++i) {
-      trees[i]->SaveModel(fo);
+      trees[i]->Save(fo);
     }
     if (tree_info.size() != 0) {
       fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
     }
   }

-  void InitModel() override {
-    CHECK(mparam.num_trees == 0 && trees.size() == 0)
-        << "Model has already been initialized.";
-    pred_buffer.clear();
-    pred_counter.clear();
-    pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
-    pred_counter.resize(mparam.PredBufferSize(), 0);
-  }
-
   void ResetPredBuffer(size_t num_pbuffer) override {
-    mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer);
+    this->num_pbuffer = num_pbuffer;
     pred_buffer.clear();
     pred_counter.clear();
-    pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
-    pred_counter.resize(mparam.PredBufferSize(), 0);
+    pred_buffer.resize(this->PredBufferSize(), 0.0f);
+    pred_counter.resize(this->PredBufferSize(), 0);
   }

   bool AllowLazyCheckPoint() const override {
@@ -348,7 +326,7 @@ class GBTree : public GradientBooster {
     #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < ndata; ++i) {
       const bst_uint ridx = rowset[i];
-      const int64_t bid = mparam.BufferOffset(buffer_offset + ridx, bst_group);
+      const int64_t bid = this->BufferOffset(buffer_offset + ridx, bst_group);
       const int tid = leaf_position[ridx];
       CHECK_EQ(pred_counter[bid], trees.size());
       CHECK_GE(tid, 0);
@@ -372,7 +350,7 @@ class GBTree : public GradientBooster {
     float psum = 0.0f;
     // sum of leaf vector
     std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
-    const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
+    const int64_t bid = this->BufferOffset(buffer_index, bst_group);
     // number of valid trees
     unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
     // load buffered results if any
@@ -452,6 +430,20 @@ class GBTree : public GradientBooster {
       }
     }
   }
+  /*! \return size of prediction buffer actually needed */
+  inline size_t PredBufferSize() const {
+    return mparam.num_output_group * num_pbuffer * (mparam.size_leaf_vector + 1);
+  }
+  /*!
+   * \brief get the buffer offset given a buffer index and group id
+   * \return calculated buffer offset
+   */
+  inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
+    if (buffer_index < 0) return -1;
+    size_t bidx = static_cast<size_t>(buffer_index);
+    CHECK_LT(bidx, num_pbuffer);
+    return (bidx + num_pbuffer * bst_group) * (mparam.size_leaf_vector + 1);
+  }

   // --- data structure ---
   // training parameter
@@ -462,6 +454,8 @@ class GBTree : public GradientBooster {
   std::vector<std::unique_ptr<RegTree> > trees;
   /*! \brief some information indicator of the tree, reserved */
   std::vector<int> tree_info;
+  /*! \brief predict buffer size */
+  size_t num_pbuffer;
   /*! \brief prediction buffer */
   std::vector<float> pred_buffer;
   /*! \brief prediction buffer counter, remember the prediction */
@@ -29,13 +29,13 @@ class TreeSyncher: public TreeUpdater {
     int rank = rabit::GetRank();
     if (rank == 0) {
       for (size_t i = 0; i < trees.size(); ++i) {
-        trees[i]->SaveModel(&fs);
+        trees[i]->Save(&fs);
       }
     }
     fs.Seek(0);
     rabit::Broadcast(&s_model, 0);
     for (size_t i = 0; i < trees.size(); ++i) {
-      trees[i]->LoadModel(&fs);
+      trees[i]->Load(&fs);
     }
   }
 };
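For context, the prediction buffer that used to be sized through the serialized GBTreeModelParam is now a runtime concern of GBTree: num_pbuffer_deprecated keeps the old binary layout readable, while ResetPredBuffer() sets the in-memory num_pbuffer, and the PredBufferSize()/BufferOffset() helpers added above compute the layout. The standalone sketch below reproduces that arithmetic; the concrete numbers are made up purely for illustration.

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values only.
  const std::int64_t num_pbuffer = 1000;  // rows with buffered predictions
  const int num_output_group = 3;         // e.g. a 3-class softmax model
  const int size_leaf_vector = 0;         // scalar leaf values

  // Matches PredBufferSize(): one slot per (row, group, leaf-vector entry).
  const std::int64_t total = num_output_group * num_pbuffer * (size_leaf_vector + 1);

  // Matches BufferOffset(): slot of buffered row 42 for output group 2.
  const std::int64_t bid = (42 + num_pbuffer * 2) * (size_leaf_vector + 1);

  std::printf("buffer size = %lld floats, offset = %lld\n",
              static_cast<long long>(total), static_cast<long long>(bid));
  return 0;
}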