[LEARNER] Init learner interface
This commit is contained in:
@@ -206,7 +206,7 @@ class GBLinear : public GradientBooster {
|
||||
LOG(FATAL) << "gblinear does not support predict leaf index";
|
||||
}
|
||||
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) override {
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
|
||||
std::stringstream fo("");
|
||||
fo << "bias:\n";
|
||||
for (int i = 0; i < model.param.num_output_group; ++i) {
|
||||
@@ -258,13 +258,19 @@ class GBLinear : public GradientBooster {
|
||||
fi->Read(&weight);
|
||||
}
|
||||
// model bias
|
||||
inline float* bias(void) {
|
||||
inline float* bias() {
|
||||
return &weight[param.num_feature * param.num_output_group];
|
||||
}
|
||||
inline const float* bias() const {
|
||||
return &weight[param.num_feature * param.num_output_group];
|
||||
}
|
||||
// get i-th weight
|
||||
inline float* operator[](size_t i) {
|
||||
return &weight[i * param.num_output_group];
|
||||
}
|
||||
inline const float* operator[](size_t i) const {
|
||||
return &weight[i * param.num_output_group];
|
||||
}
|
||||
};
|
||||
// model field
|
||||
Model model;
|
||||
|
||||
@@ -113,7 +113,11 @@ class GBTree : public GradientBooster {
|
||||
for (const auto& up : updaters) {
|
||||
up->Init(cfg);
|
||||
}
|
||||
if (tparam.nthread != 0) {
|
||||
omp_set_num_threads(tparam.nthread);
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
||||
<< "GBTree: invalid model file";
|
||||
@@ -130,6 +134,7 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
this->ResetPredBuffer(0);
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
|
||||
// not save predict buffer.
|
||||
@@ -143,6 +148,7 @@ class GBTree : public GradientBooster {
|
||||
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
|
||||
}
|
||||
}
|
||||
|
||||
void InitModel() override {
|
||||
CHECK(mparam.num_trees == 0 && trees.size() == 0)
|
||||
<< "Model has already been initialized.";
|
||||
@@ -151,6 +157,7 @@ class GBTree : public GradientBooster {
|
||||
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
|
||||
pred_counter.resize(mparam.PredBufferSize(), 0);
|
||||
}
|
||||
|
||||
void ResetPredBuffer(size_t num_pbuffer) override {
|
||||
mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer);
|
||||
pred_buffer.clear();
|
||||
@@ -158,10 +165,12 @@ class GBTree : public GradientBooster {
|
||||
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
|
||||
pred_counter.resize(mparam.PredBufferSize(), 0);
|
||||
}
|
||||
|
||||
bool AllowLazyCheckPoint() const override {
|
||||
return mparam.num_output_group == 1 ||
|
||||
tparam.updater_seq.find("distcol") != std::string::npos;
|
||||
}
|
||||
|
||||
void DoBoost(DMatrix* p_fmat,
|
||||
int64_t buffer_offset,
|
||||
std::vector<bst_gpair>* in_gpair) override {
|
||||
@@ -191,6 +200,7 @@ class GBTree : public GradientBooster {
|
||||
this->CommitModel(std::move(new_trees[gid]), gid);
|
||||
}
|
||||
}
|
||||
|
||||
void Predict(DMatrix* p_fmat,
|
||||
int64_t buffer_offset,
|
||||
std::vector<float>* out_preds,
|
||||
@@ -230,6 +240,7 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Predict(const SparseBatch::Inst& inst,
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
@@ -246,9 +257,10 @@ class GBTree : public GradientBooster {
|
||||
ntree_limit);
|
||||
}
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat,
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit) {
|
||||
unsigned ntree_limit) override {
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
@@ -257,7 +269,8 @@ class GBTree : public GradientBooster {
|
||||
InitThreadTemp(nthread);
|
||||
this->PredPath(p_fmat, out_preds, ntree_limit);
|
||||
}
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) {
|
||||
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
|
||||
std::vector<std::string> dump;
|
||||
for (size_t i = 0; i < trees.size(); i++) {
|
||||
dump.push_back(trees[i]->Dump2Text(fmap, option & 1));
|
||||
|
||||
Reference in New Issue
Block a user