refresher test
This commit is contained in:
parent
762b360739
commit
91e70c76ff
@ -199,6 +199,10 @@ class FMatrixS : public FMatrixInterface<FMatrixS>{
|
|||||||
utils::Check(this->HaveColAccess(), "NumCol:need column access");
|
utils::Check(this->HaveColAccess(), "NumCol:need column access");
|
||||||
return col_ptr_.size() - 1;
|
return col_ptr_.size() - 1;
|
||||||
}
|
}
|
||||||
|
/*! \brief get number of buffered rows */
|
||||||
|
inline size_t NumBufferedRow(void) const {
|
||||||
|
return num_buffered_row_;
|
||||||
|
}
|
||||||
/*! \brief get col sorted iterator */
|
/*! \brief get col sorted iterator */
|
||||||
inline ColIter GetSortedCol(size_t cidx) const {
|
inline ColIter GetSortedCol(size_t cidx) const {
|
||||||
utils::Assert(cidx < this->NumCol(), "col id exceed bound");
|
utils::Assert(cidx < this->NumCol(), "col id exceed bound");
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <limits>
|
||||||
#include "./objective.h"
|
#include "./objective.h"
|
||||||
#include "./evaluation.h"
|
#include "./evaluation.h"
|
||||||
#include "../gbm/gbm.h"
|
#include "../gbm/gbm.h"
|
||||||
@ -28,6 +29,8 @@ class BoostLearner {
|
|||||||
gbm_ = NULL;
|
gbm_ = NULL;
|
||||||
name_obj_ = "reg:linear";
|
name_obj_ = "reg:linear";
|
||||||
name_gbm_ = "gbtree";
|
name_gbm_ = "gbtree";
|
||||||
|
silent= 0;
|
||||||
|
max_buffer_row = std::numeric_limits<size_t>::max();
|
||||||
}
|
}
|
||||||
~BoostLearner(void) {
|
~BoostLearner(void) {
|
||||||
if (obj_ != NULL) delete obj_;
|
if (obj_ != NULL) delete obj_;
|
||||||
@ -77,6 +80,7 @@ class BoostLearner {
|
|||||||
*/
|
*/
|
||||||
inline void SetParam(const char *name, const char *val) {
|
inline void SetParam(const char *name, const char *val) {
|
||||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
if (!strcmp(name, "silent")) silent = atoi(val);
|
||||||
|
if (!strcmp(name, "max_buffer_row")) sscanf(val, "%lu", &max_buffer_row);
|
||||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
||||||
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
||||||
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
||||||
@ -87,7 +91,9 @@ class BoostLearner {
|
|||||||
}
|
}
|
||||||
if (gbm_ != NULL) gbm_->SetParam(name, val);
|
if (gbm_ != NULL) gbm_->SetParam(name, val);
|
||||||
if (obj_ != NULL) obj_->SetParam(name, val);
|
if (obj_ != NULL) obj_->SetParam(name, val);
|
||||||
cfg_.push_back(std::make_pair(std::string(name), std::string(val)));
|
if (gbm_ == NULL || obj_ == NULL) {
|
||||||
|
cfg_.push_back(std::make_pair(std::string(name), std::string(val)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize the model
|
* \brief initialize the model
|
||||||
@ -144,8 +150,8 @@ class BoostLearner {
|
|||||||
* if not intialize it
|
* if not intialize it
|
||||||
* \param p_train pointer to the matrix used by training
|
* \param p_train pointer to the matrix used by training
|
||||||
*/
|
*/
|
||||||
inline void CheckInit(DMatrix<FMatrix> *p_train) const {
|
inline void CheckInit(DMatrix<FMatrix> *p_train) {
|
||||||
p_train->fmat.InitColAccess();
|
p_train->fmat.InitColAccess(max_buffer_row);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief update the model for one iteration
|
* \brief update the model for one iteration
|
||||||
@ -286,6 +292,8 @@ class BoostLearner {
|
|||||||
// data fields
|
// data fields
|
||||||
// silent during training
|
// silent during training
|
||||||
int silent;
|
int silent;
|
||||||
|
// maximum buffred row value
|
||||||
|
size_t max_buffer_row;
|
||||||
// evaluation set
|
// evaluation set
|
||||||
EvalSet evaluator_;
|
EvalSet evaluator_;
|
||||||
// model parameter
|
// model parameter
|
||||||
|
|||||||
@ -110,22 +110,22 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
const std::vector<unsigned> &root_index, const RegTree &tree) {
|
const std::vector<unsigned> &root_index, const RegTree &tree) {
|
||||||
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
|
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
|
||||||
{// setup position
|
{// setup position
|
||||||
position.resize(gpair.size());
|
position.resize(fmat.NumBufferedRow());
|
||||||
if (root_index.size() == 0) {
|
if (root_index.size() == 0) {
|
||||||
std::fill(position.begin(), position.end(), 0);
|
std::fill(position.begin(), position.end(), 0);
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < root_index.size(); ++i) {
|
for (size_t i = 0; i < position.size(); ++i) {
|
||||||
position[i] = root_index[i];
|
position[i] = root_index[i];
|
||||||
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting");
|
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// mark delete for the deleted datas
|
// mark delete for the deleted datas
|
||||||
for (size_t i = 0; i < gpair.size(); ++i) {
|
for (size_t i = 0; i < position.size(); ++i) {
|
||||||
if (gpair[i].hess < 0.0f) position[i] = -1;
|
if (gpair[i].hess < 0.0f) position[i] = -1;
|
||||||
}
|
}
|
||||||
// mark subsample
|
// mark subsample
|
||||||
if (param.subsample < 1.0f) {
|
if (param.subsample < 1.0f) {
|
||||||
for (size_t i = 0; i < gpair.size(); ++i) {
|
for (size_t i = 0; i < position.size(); ++i) {
|
||||||
if (gpair[i].hess < 0.0f) continue;
|
if (gpair[i].hess < 0.0f) continue;
|
||||||
if (random::SampleBinary(param.subsample) == 0) position[i] = -1;
|
if (random::SampleBinary(param.subsample) == 0) position[i] = -1;
|
||||||
}
|
}
|
||||||
@ -271,7 +271,9 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
}
|
}
|
||||||
// start enumeration
|
// start enumeration
|
||||||
const unsigned nsize = static_cast<unsigned>(feat_set.size());
|
const unsigned nsize = static_cast<unsigned>(feat_set.size());
|
||||||
|
#if defined(_OPENMP)
|
||||||
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
|
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
|
||||||
|
#endif
|
||||||
#pragma omp parallel for schedule(dynamic, batch_size)
|
#pragma omp parallel for schedule(dynamic, batch_size)
|
||||||
for (unsigned i = 0; i < nsize; ++i) {
|
for (unsigned i = 0; i < nsize; ++i) {
|
||||||
const unsigned fid = feat_set[i];
|
const unsigned fid = feat_set[i];
|
||||||
|
|||||||
@ -20,7 +20,6 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
// set training parameter
|
// set training parameter
|
||||||
virtual void SetParam(const char *name, const char *val) {
|
virtual void SetParam(const char *name, const char *val) {
|
||||||
param.SetParam(name, val);
|
param.SetParam(name, val);
|
||||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
|
||||||
}
|
}
|
||||||
// update the tree, do pruning
|
// update the tree, do pruning
|
||||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||||
@ -127,8 +126,6 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
}
|
}
|
||||||
// number of thread in the data
|
// number of thread in the data
|
||||||
int nthread;
|
int nthread;
|
||||||
// shutup
|
|
||||||
int silent;
|
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user