[LEARNER] Init learner interface

This commit is contained in:
tqchen
2016-01-03 05:16:05 -08:00
parent 084f5f4715
commit 82ceb4de0a
6 changed files with 191 additions and 7 deletions

View File

@@ -206,7 +206,7 @@ class GBLinear : public GradientBooster {
LOG(FATAL) << "gblinear does not support predict leaf index";
}
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) override {
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
std::stringstream fo("");
fo << "bias:\n";
for (int i = 0; i < model.param.num_output_group; ++i) {
@@ -258,13 +258,19 @@ class GBLinear : public GradientBooster {
fi->Read(&weight);
}
// model bias
inline float* bias(void) {
inline float* bias() {
return &weight[param.num_feature * param.num_output_group];
}
inline const float* bias() const {
return &weight[param.num_feature * param.num_output_group];
}
// get i-th weight
inline float* operator[](size_t i) {
return &weight[i * param.num_output_group];
}
inline const float* operator[](size_t i) const {
return &weight[i * param.num_output_group];
}
};
// model field
Model model;

View File

@@ -113,7 +113,11 @@ class GBTree : public GradientBooster {
for (const auto& up : updaters) {
up->Init(cfg);
}
if (tparam.nthread != 0) {
omp_set_num_threads(tparam.nthread);
}
}
void LoadModel(dmlc::Stream* fi) override {
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
<< "GBTree: invalid model file";
@@ -130,6 +134,7 @@ class GBTree : public GradientBooster {
}
this->ResetPredBuffer(0);
}
void SaveModel(dmlc::Stream* fo) const override {
CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
// not save predict buffer.
@@ -143,6 +148,7 @@ class GBTree : public GradientBooster {
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
}
}
void InitModel() override {
CHECK(mparam.num_trees == 0 && trees.size() == 0)
<< "Model has already been initialized.";
@@ -151,6 +157,7 @@ class GBTree : public GradientBooster {
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0);
}
void ResetPredBuffer(size_t num_pbuffer) override {
mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer);
pred_buffer.clear();
@@ -158,10 +165,12 @@ class GBTree : public GradientBooster {
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0);
}
bool AllowLazyCheckPoint() const override {
return mparam.num_output_group == 1 ||
tparam.updater_seq.find("distcol") != std::string::npos;
}
void DoBoost(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<bst_gpair>* in_gpair) override {
@@ -191,6 +200,7 @@ class GBTree : public GradientBooster {
this->CommitModel(std::move(new_trees[gid]), gid);
}
}
void Predict(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<float>* out_preds,
@@ -230,6 +240,7 @@ class GBTree : public GradientBooster {
}
}
}
void Predict(const SparseBatch::Inst& inst,
std::vector<float>* out_preds,
unsigned ntree_limit,
@@ -246,9 +257,10 @@ class GBTree : public GradientBooster {
ntree_limit);
}
}
void PredictLeaf(DMatrix* p_fmat,
std::vector<float>* out_preds,
unsigned ntree_limit) {
unsigned ntree_limit) override {
int nthread;
#pragma omp parallel
{
@@ -257,7 +269,8 @@ class GBTree : public GradientBooster {
InitThreadTemp(nthread);
this->PredPath(p_fmat, out_preds, ntree_limit);
}
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) {
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
std::vector<std::string> dump;
for (size_t i = 0; i < trees.size(); i++) {
dump.push_back(trees[i]->Dump2Text(fmap, option & 1));