rename SparseBatch to RowBatch

This commit is contained in:
tqchen@graphlab.com
2014-08-27 10:56:55 -07:00
parent d5a5e0a42a
commit a59f8945dc
7 changed files with 56 additions and 46 deletions

View File

@@ -106,11 +106,11 @@ class GBLinear : public IGradBooster<FMatrix> {
std::vector<float> &preds = *out_preds;
preds.resize(0);
// start collecting the prediction
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
utils::IIterator<RowBatch> *iter = fmat.RowIterator();
iter->BeforeFirst();
const int ngroup = model.param.num_output_group;
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
const RowBatch &batch = iter->Value();
utils::Assert(batch.base_rowid * ngroup == preds.size(),
"base_rowid is not set correctly");
// output convention: nrow * k, where nrow is number of rows
@@ -146,7 +146,7 @@ class GBLinear : public IGradBooster<FMatrix> {
}
random::Shuffle(feat_index);
}
inline void Pred(const SparseBatch::Inst &inst, float *preds) {
inline void Pred(const RowBatch::Inst &inst, float *preds) {
for (int gid = 0; gid < model.param.num_output_group; ++gid) {
float psum = model.bias()[gid];
for (bst_uint i = 0; i < inst.length; ++i) {

View File

@@ -121,10 +121,10 @@ class GBTree : public IGradBooster<FMatrix> {
const size_t stride = info.num_row * mparam.num_output_group;
preds.resize(stride * (mparam.size_leaf_vector+1));
// start collecting the prediction
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
utils::IIterator<RowBatch> *iter = fmat.RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
const RowBatch &batch = iter->Value();
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
@@ -208,7 +208,7 @@ class GBTree : public IGradBooster<FMatrix> {
mparam.num_trees += tparam.num_parallel_tree;
}
// make a prediction for a single instance
inline void Pred(const SparseBatch::Inst &inst,
inline void Pred(const RowBatch::Inst &inst,
int64_t buffer_index,
int bst_group,
unsigned root_index,