Dmatrix refactor stage 1 (#3301)

* Use sparse page as the single CSR matrix representation

* Simplify dmatrix methods

* Reduce statefulness of batch iterators

* BREAKING CHANGE: Remove prob_buffer_row parameter. Users are instead recommended to sample their dataset as a preprocessing step before using XGBoost.
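As a rough illustration of the recommended replacement, here is a minimal sketch, assuming a CSR offset/data layout like SparsePage's; SubsampleRows is a hypothetical helper, and the Bernoulli coin flip mirrors the pkeep logic this commit removes:

#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper (not part of XGBoost): keep each CSR row with probability pkeep.
template <typename Entry>
void SubsampleRows(const std::vector<std::size_t>& offset,  // CSR row pointers, offset[0] == 0
                   const std::vector<Entry>& data,          // CSR entries
                   float pkeep,
                   std::vector<std::size_t>* out_offset,
                   std::vector<Entry>* out_data) {
  std::mt19937 rnd(0);  // fixed seed for a reproducible sample
  std::bernoulli_distribution coin_flip(pkeep);
  out_offset->assign(1, 0);
  out_data->clear();
  for (std::size_t i = 0; i + 1 < offset.size(); ++i) {
    if (!coin_flip(rnd)) continue;  // drop this row
    out_data->insert(out_data->end(),
                     data.begin() + offset[i], data.begin() + offset[i + 1]);
    out_offset->push_back(out_data->size());  // close the kept row
  }
}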
Rory Mitchell 2018-06-07 10:25:58 +12:00 committed by GitHub
parent 286dccb8e8
commit a96039141a
47 changed files with 650 additions and 1036 deletions

View File

@ -87,8 +87,7 @@ Parameters for Tree Booster
- 'refresh': refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- 'prune': prunes the splits where loss < min_split_loss (or gamma).
- In a distributed setting, the implicit updater sequence value would be adjusted as follows:
- 'grow_histmaker,prune' when dsplit='row' (or default) and prob_buffer_row == 1 (or default); or when data has multiple sparse pages
- 'grow_histmaker,refresh,prune' when dsplit='row' and prob_buffer_row < 1
- 'grow_histmaker,prune' when dsplit='row' (or default); or when data has multiple sparse pages
- 'distcol' when dsplit='col'
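As a hedged illustration, the sequence can also be pinned explicitly rather than relying on the implicit choice, e.g. by setting updater='grow_histmaker,prune' together with dsplit='row' in the booster parameters.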
* refresh_leaf, [default=1]
- This is a parameter of the 'refresh' updater plugin. When this flag is true, tree leaves as well as tree nodes' stats are updated. When it is false, only node stats are updated.

View File

@ -9,10 +9,11 @@
#include <dmlc/base.h>
#include <dmlc/data.h>
#include <cstring>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "./base.h"
namespace xgboost {
@ -117,28 +118,36 @@ class MetaInfo {
mutable std::vector<size_t> label_order_cache_;
};
/*! \brief read-only sparse instance batch in CSR format */
struct SparseBatch {
/*! \brief an entry of sparse vector */
/*! \brief Element from a sparse vector */
struct Entry {
/*! \brief feature index */
bst_uint index;
/*! \brief feature value */
bst_float fvalue;
/*! \brief default constructor */
XGBOOST_DEVICE Entry() {}
Entry() = default;
/*!
* \brief constructor with index and value
* \param index The feature or row index.
* \param fvalue The feature value.
*/
XGBOOST_DEVICE Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
/*! \brief reversely compare feature values */
XGBOOST_DEVICE inline static bool CmpValue(const Entry& a, const Entry& b) {
inline static bool CmpValue(const Entry& a, const Entry& b) {
return a.fvalue < b.fvalue;
}
};
/*!
* \brief in-memory storage unit of sparse batch
*/
class SparsePage {
public:
std::vector<size_t> offset;
/*! \brief the data of the segments */
std::vector<Entry> data;
size_t base_rowid;
/*! \brief an instance of sparse vector in the batch */
struct Inst {
/*! \brief pointer to the elements*/
@ -154,39 +163,84 @@ struct SparseBatch {
}
};
/*! \brief batch size */
size_t size;
};
/*! \brief read-only row batch, used to access row continuously */
struct RowBatch : public SparseBatch {
/*! \brief the offset of rowid of this batch */
size_t base_rowid;
/*! \brief array[size+1], row pointer of each of the elements */
const size_t *ind_ptr;
/*! \brief array[ind_ptr.back()], content of the sparse element */
const Entry *data_ptr;
/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return {data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i + 1] - ind_ptr[i])};
return {data.data() + offset[i], static_cast<bst_uint>(offset[i + 1] - offset[i])};
}
/*! \brief constructor */
SparsePage() {
this->Clear();
}
/*! \return number of instances in the page */
inline size_t Size() const {
return offset.size() - 1;
}
/*! \return estimation of memory cost of this page */
inline size_t MemCostBytes() const {
return offset.size() * sizeof(size_t) + data.size() * sizeof(Entry);
}
/*! \brief clear the page */
inline void Clear() {
base_rowid = 0;
offset.clear();
offset.push_back(0);
data.clear();
}
};
/*!
* \brief read-only column batch, used to access columns,
* the columns are not required to be continuous
* \brief Push row block into the page.
* \param batch the row batch.
*/
struct ColBatch : public SparseBatch {
/*! \brief column index of each columns in the data */
const bst_uint *col_index;
/*! \brief pointer to the column data */
const Inst *col_data;
/*! \brief get i-th column from the batch */
inline Inst operator[](size_t i) const {
return col_data[i];
inline void Push(const dmlc::RowBlock<uint32_t>& batch) {
data.reserve(data.size() + batch.offset[batch.size] - batch.offset[0]);
offset.reserve(offset.size() + batch.size);
CHECK(batch.index != nullptr);
for (size_t i = 0; i < batch.size; ++i) {
offset.push_back(offset.back() + batch.offset[i + 1] - batch.offset[i]);
}
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
uint32_t index = batch.index[i];
bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
data.emplace_back(index, fvalue);
}
CHECK_EQ(offset.back(), data.size());
}
/*!
* \brief Push a sparse page
* \param batch the row page
*/
inline void Push(const SparsePage &batch) {
size_t top = offset.back();
data.resize(top + batch.data.size());
std::memcpy(dmlc::BeginPtr(data) + top,
dmlc::BeginPtr(batch.data),
sizeof(Entry) * batch.data.size());
size_t begin = offset.size();
offset.resize(begin + batch.Size());
for (size_t i = 0; i < batch.Size(); ++i) {
offset[i + begin] = top + batch.offset[i + 1];
}
}
/*!
* \brief Push one instance into page
* \param inst an instance row
*/
inline void Push(const Inst &inst) {
offset.push_back(offset.back() + inst.length);
size_t begin = data.size();
data.resize(begin + inst.length);
if (inst.length != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
sizeof(Entry) * inst.length);
}
}
size_t Size() { return offset.size() - 1; }
};
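A minimal usage sketch of the consolidated page (the two rows, feature indices, and values below are illustrative):

xgboost::SparsePage page;                 // offset == {0} after the constructor's Clear()
page.data.emplace_back(0, 1.5f);          // row 0: feature 0 -> 1.5
page.data.emplace_back(2, 3.0f);          // row 0: feature 2 -> 3.0
page.offset.push_back(page.data.size());  // close row 0
page.data.emplace_back(1, 2.0f);          // row 1: feature 1 -> 2.0
page.offset.push_back(page.data.size());  // close row 1
auto inst = page[0];                      // Inst view over row 0
for (xgboost::bst_uint j = 0; j < inst.length; ++j) {
  // inst[j].index and inst[j].fvalue walk the row's entries
}
// page.Size() == 2; page.MemCostBytes() accounts for both vectors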
/*!
* \brief This is a data structure that the user can pass to DMatrix::Create
* to create a DMatrix for training; the user can create this data structure
@ -194,7 +248,7 @@ struct ColBatch : public SparseBatch {
*
* In a distributed setting, a customized dmlc::Parser is usually needed instead.
*/
class DataSource : public dmlc::DataIter<RowBatch> {
class DataSource : public dmlc::DataIter<SparsePage> {
public:
/*!
* \brief Meta information about the dataset
@ -260,28 +314,17 @@ class DMatrix {
* \brief get the row iterator, reset to beginning position
* \note Only one of RowIterator and ColIterator can be active at a time.
*/
virtual dmlc::DataIter<RowBatch>* RowIterator() = 0;
virtual dmlc::DataIter<SparsePage>* RowIterator() = 0;
/*!\brief get column iterator, reset to the beginning position */
virtual dmlc::DataIter<ColBatch>* ColIterator() = 0;
/*!
* \brief get the column iterator associated with a subset of column features.
* \param fset the set of column indices that the returned column iterator must contain
* \return the column iterator, initialized so that it reads the elements in fset
*/
virtual dmlc::DataIter<ColBatch>* ColIterator(const std::vector<bst_uint>& fset) = 0;
virtual dmlc::DataIter<SparsePage>* ColIterator() = 0;
/*!
* \brief check whether column access is supported; if not, initialize column access.
* \param enabled whether a given feature should be included in column access.
* \param subsample subsample ratio to use when generating column access.
* \param max_row_perbatch auxiliary information, the maximum number of rows used in each column batch;
* this is a hint that the implementation may ignore.
* \param sorted If column features should be in sorted order
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(const std::vector<bool>& enabled,
float subsample,
size_t max_row_perbatch, bool sorted) = 0;
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;
// the following are column meta data, should be able to answer them fast.
/*! \return whether column access is enabled */
virtual bool HaveColAccess(bool sorted) const = 0;
@ -388,7 +431,7 @@ inline bool RowSet::Load(dmlc::Stream* fi) {
} // namespace xgboost
namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, xgboost::SparseBatch::Entry, true);
DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true);
DMLC_DECLARE_TRAITS(has_saveload, xgboost::RowSet, true);
}
#endif // XGBOOST_DATA_H_

View File

@ -94,7 +94,7 @@ class GradientBooster {
* \param root_index the root index
* \sa Predict
*/
virtual void PredictInstance(const SparseBatch::Inst& inst,
virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0,
unsigned root_index = 0) = 0;

View File

@ -167,7 +167,7 @@ class Learner : public rabit::Serializable {
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction
*/
inline void Predict(const SparseBatch::Inst &inst,
inline void Predict(const SparsePage::Inst &inst,
bool output_margin,
HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit = 0) const;
@ -190,7 +190,7 @@ class Learner : public rabit::Serializable {
};
// implementation of inline functions.
inline void Learner::Predict(const SparseBatch::Inst& inst,
inline void Learner::Predict(const SparsePage::Inst& inst,
bool output_margin,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) const {

View File

@ -88,7 +88,7 @@ class Predictor {
int num_new_trees) = 0;
/**
* \fn virtual void Predictor::PredictInstance( const SparseBatch::Inst&
* \fn virtual void Predictor::PredictInstance( const SparsePage::Inst&
* inst, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model,
* unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
*
@ -104,7 +104,7 @@ class Predictor {
* \param root_index (Optional) Zero-based index of the root.
*/
virtual void PredictInstance(const SparseBatch::Inst& inst,
virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0,

View File

@ -447,12 +447,12 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \brief fill the vector with sparse vector
* \param inst The sparse instance to fill.
*/
inline void Fill(const RowBatch::Inst& inst);
inline void Fill(const SparsePage::Inst& inst);
/*!
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop.
*/
inline void Drop(const RowBatch::Inst& inst);
inline void Drop(const SparsePage::Inst& inst);
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
@ -573,14 +573,14 @@ inline void RegTree::FVec::Init(size_t size) {
std::fill(data_.begin(), data_.end(), e);
}
inline void RegTree::FVec::Fill(const RowBatch::Inst& inst) {
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].fvalue = inst[i].fvalue;
}
}
inline void RegTree::FVec::Drop(const RowBatch::Inst& inst) {
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].flag = -1;

View File

@ -9,7 +9,7 @@
#include <dmlc/parameter.h>
#include <lz4.h>
#include <lz4hc.h>
#include "../../src/data/sparse_batch_page.h"
#include "../../src/data/sparse_page_writer.h"
namespace xgboost {
namespace data {
@ -155,7 +155,7 @@ inline void CompressArray<DType>::Write(dmlc::Stream* fo) {
}
template<typename StorageIndex>
class SparsePageLZ4Format : public SparsePage::Format {
class SparsePageLZ4Format : public SparsePageFormat {
public:
explicit SparsePageLZ4Format(bool use_lz4_hc)
: use_lz4_hc_(use_lz4_hc) {
@ -185,7 +185,7 @@ class SparsePageLZ4Format : public SparsePage::Format {
CHECK_EQ(index_.data.size(), value_.data.size());
CHECK_EQ(index_.data.size(), page->data.size());
for (size_t i = 0; i < page->data.size(); ++i) {
page->data[i] = SparseBatch::Entry(index_.data[i] + min_index_, value_.data[i]);
page->data[i] = Entry(index_.data[i] + min_index_, value_.data[i]);
}
return true;
}
@ -212,7 +212,7 @@ class SparsePageLZ4Format : public SparsePage::Format {
size_t src_begin = disk_offset_[cid];
size_t num = disk_offset_[cid + 1] - disk_offset_[cid];
for (size_t j = 0; j < num; ++j) {
page->data[dst_begin + j] = SparseBatch::Entry(
page->data[dst_begin + j] = Entry(
index_.data[src_begin + j] + min_index_, value_.data[src_begin + j]);
}
}
@ -223,7 +223,7 @@ class SparsePageLZ4Format : public SparsePage::Format {
CHECK(page.offset.size() != 0 && page.offset[0] == 0);
CHECK_EQ(page.offset.back(), page.data.size());
fo->Write(page.offset);
min_index_ = page.min_index;
min_index_ = page.base_rowid;
fo->Write(&min_index_, sizeof(min_index_));
index_.data.resize(page.data.size());
value_.data.resize(page.data.size());

View File

@ -238,20 +238,20 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
API_BEGIN();
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.reserve(nindptr);
mat.row_data_.reserve(nelem);
mat.row_ptr_.resize(1);
mat.row_ptr_[0] = 0;
mat.page_.offset.reserve(nindptr);
mat.page_.data.reserve(nelem);
mat.page_.offset.resize(1);
mat.page_.offset[0] = 0;
size_t num_column = 0;
for (size_t i = 1; i < nindptr; ++i) {
for (size_t j = indptr[i - 1]; j < indptr[i]; ++j) {
if (!common::CheckNAN(data[j])) {
// automatically skip nan.
mat.row_data_.emplace_back(RowBatch::Entry(indices[j], data[j]));
mat.page_.data.emplace_back(Entry(indices[j], data[j]));
num_column = std::max(num_column, static_cast<size_t>(indices[j] + 1));
}
}
mat.row_ptr_.push_back(mat.row_data_.size());
mat.page_.offset.push_back(mat.page_.data.size());
}
mat.info.num_col_ = num_column;
@ -261,7 +261,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
mat.info.num_col_ = num_col;
}
mat.info.num_row_ = nindptr - 1;
mat.info.num_nonzero_ = mat.row_data_.size();
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
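For context, a hedged sketch of calling this entry point through the C API; the 2x3 matrix is illustrative, and passing num_col = 0 lets the maximum observed index determine the column count:

/* rows: {f0=1, f2=2} and {f1=3} */
const size_t indptr[] = {0, 2, 3};
const unsigned indices[] = {0, 2, 1};
const float data[] = {1.0f, 2.0f, 3.0f};
DMatrixHandle dmat;
XGDMatrixCreateFromCSREx(indptr, indices, data,
                         3 /* nindptr */, 3 /* nelem */, 0 /* num_col */, &dmat);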
@ -293,7 +293,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
// FIXME: User should be able to control number of threads
const int nthread = omp_get_max_threads();
data::SimpleCSRSource& mat = *source;
common::ParallelGroupBuilder<RowBatch::Entry> builder(&mat.row_ptr_, &mat.row_data_);
common::ParallelGroupBuilder<Entry> builder(&mat.page_.offset, &mat.page_.data);
builder.InitBudget(0, nthread);
size_t ncol = nindptr - 1; // NOLINT(*)
#pragma omp parallel for schedule(static)
@ -312,12 +312,12 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
for (size_t j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
if (!common::CheckNAN(data[j])) {
builder.Push(indices[j],
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
Entry(static_cast<bst_uint>(i), data[j]),
tid);
}
}
}
mat.info.num_row_ = mat.row_ptr_.size() - 1;
mat.info.num_row_ = mat.page_.offset.size() - 1;
if (num_row > 0) {
CHECK_LE(mat.info.num_row_, num_row);
mat.info.num_row_ = num_row;
@ -351,7 +351,7 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
API_BEGIN();
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(1+nrow);
mat.page_.offset.resize(1+nrow);
bool nan_missing = common::CheckNAN(missing);
mat.info.num_row_ = nrow;
mat.info.num_col_ = ncol;
@ -371,9 +371,9 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
}
}
}
mat.row_ptr_[i+1] = mat.row_ptr_[i] + nelem;
mat.page_.offset[i+1] = mat.page_.offset[i] + nelem;
}
mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());
data = data0;
for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) {
@ -382,14 +382,14 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
if (common::CheckNAN(data[j])) {
} else {
if (nan_missing || data[j] != missing) {
mat.row_data_[mat.row_ptr_[i] + matj] = RowBatch::Entry(j, data[j]);
mat.page_.data[mat.page_.offset[i] + matj] = Entry(j, data[j]);
++matj;
}
}
}
}
mat.info.num_nonzero_ = mat.row_data_.size();
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
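The dense entry point follows the same pattern; a hedged sketch using NaN (from <math.h>) as the missing marker:

const float values[] = {1.0f, NAN, 3.0f,
                        NAN,  5.0f, NAN};
DMatrixHandle dmat;
XGDMatrixCreateFromMat(values, 2 /* nrow */, 3 /* ncol */, NAN /* missing */, &dmat);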
@ -443,7 +443,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(1+nrow);
mat.page_.offset.resize(1+nrow);
mat.info.num_row_ = nrow;
mat.info.num_col_ = ncol;
@ -469,7 +469,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
++nelem;
}
}
mat.row_ptr_[i+1] = nelem;
mat.page_.offset[i+1] = nelem;
}
}
// Inform about any NaNs and resize data matrix
@ -478,8 +478,8 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
}
// do cumulative sum (to avoid an extra copy)
PrefixSum(&mat.row_ptr_[0], mat.row_ptr_.size());
mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
PrefixSum(&mat.page_.offset[0], mat.page_.offset.size());
mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());
// Fill data matrix (now that the size is known, no need for slow push_back())
#pragma omp parallel num_threads(nthread)
@ -490,15 +490,15 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
if (common::CheckNAN(data[ncol * i + j])) {
} else if (nan_missing || data[ncol * i + j] != missing) {
mat.row_data_[mat.row_ptr_[i] + matj] =
RowBatch::Entry(j, data[ncol * i + j]);
mat.page_.data[mat.page_.offset[i] + matj] =
Entry(j, data[ncol * i + j]);
++matj;
}
}
}
}
mat.info.num_nonzero_ = mat.row_data_.size();
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
@ -521,18 +521,18 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
ret.info.num_row_ = len;
ret.info.num_col_ = src.info.num_col_;
dmlc::DataIter<RowBatch>* iter = &src;
auto iter = &src;
iter->BeforeFirst();
CHECK(iter->Next());
const RowBatch& batch = iter->Value();
const auto& batch = iter->Value();
for (xgboost::bst_ulong i = 0; i < len; ++i) {
const int ridx = idxset[i];
RowBatch::Inst inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.size);
ret.row_data_.insert(ret.row_data_.end(), inst.data,
auto inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
ret.page_.data.insert(ret.page_.data.end(), inst.data,
inst.data + inst.length);
ret.row_ptr_.push_back(ret.row_ptr_.back() + inst.length);
ret.page_.offset.push_back(ret.page_.offset.back() + inst.length);
ret.info.num_nonzero_ += inst.length;
if (src.info.labels_.size() != 0) {

View File

@ -33,19 +33,19 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
}
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
#pragma omp parallel num_threads(nthread)
{
CHECK_EQ(nthread, omp_get_num_threads());
auto tid = static_cast<unsigned>(omp_get_thread_num());
unsigned begin = std::min(nstep * tid, ncol);
unsigned end = std::min(nstep * (tid + 1), ncol);
for (size_t i = 0; i < batch.size; ++i) { // NOLINT(*)
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t ridx = batch.base_rowid + i;
RowBatch::Inst inst = batch[i];
SparsePage::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
@ -106,7 +106,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
void GHistIndexMatrix::Init(DMatrix* p_fmat) {
CHECK(cut != nullptr); // NOLINT
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
const int nthread = omp_get_max_threads();
const uint32_t nbins = cut->row_ptr.back();
@ -116,9 +116,9 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
iter->BeforeFirst();
row_ptr.push_back(0);
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
const size_t rbegin = row_ptr.size() - 1;
for (size_t i = 0; i < batch.size; ++i) {
for (size_t i = 0; i < batch.Size(); ++i) {
row_ptr.push_back(batch[i].length + row_ptr.back());
}
index.resize(row_ptr.back());
@ -126,13 +126,13 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
CHECK_GT(cut->cut.size(), 0U);
CHECK_EQ(cut->row_ptr.back(), cut->cut.size());
auto bsize = static_cast<omp_ulong>(batch.size);
auto bsize = static_cast<omp_ulong>(batch.Size());
#pragma omp parallel for num_threads(nthread) schedule(static)
for (omp_ulong i = 0; i < bsize; ++i) { // NOLINT(*)
const int tid = omp_get_thread_num();
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
RowBatch::Inst inst = batch[i];
SparsePage::Inst inst = batch[i];
CHECK_EQ(ibegin + inst.length, iend);
for (bst_uint j = 0; j < inst.length; ++j) {
unsigned fid = inst[j].index;

View File

@ -6,7 +6,7 @@
#include <xgboost/logging.h>
#include <dmlc/registry.h>
#include <cstring>
#include "./sparse_batch_page.h"
#include "./sparse_page_writer.h"
#include "./simple_dmatrix.h"
#include "./simple_csr_source.h"
#include "../common/common.h"
@ -278,8 +278,7 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource>&& source,
} // namespace xgboost
namespace xgboost {
namespace data {
SparsePage::Format* SparsePage::Format::Create(const std::string& name) {
data::SparsePageFormat* data::SparsePageFormat::Create(const std::string& name) {
auto *e = ::dmlc::Registry< ::xgboost::data::SparsePageFormatReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown format type " << name;
@ -288,7 +287,7 @@ SparsePage::Format* SparsePage::Format::Create(const std::string& name) {
}
std::pair<std::string, std::string>
SparsePage::Format::DecideFormat(const std::string& cache_prefix) {
data::SparsePageFormat::DecideFormat(const std::string& cache_prefix) {
size_t pos = cache_prefix.rfind(".fmt-");
if (pos != std::string::npos) {
@ -305,6 +304,7 @@ SparsePage::Format::DecideFormat(const std::string& cache_prefix) {
}
}
namespace data {
// List of files that will be force-linked in static builds.
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
} // namespace data

View File

@ -10,24 +10,18 @@ namespace xgboost {
namespace data {
void SimpleCSRSource::Clear() {
row_data_.clear();
row_ptr_.resize(1);
row_ptr_[0] = 0;
page_.Clear();
this->info.Clear();
}
void SimpleCSRSource::CopyFrom(DMatrix* src) {
this->Clear();
this->info = src->Info();
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
auto iter = src->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
RowBatch::Inst inst = batch[i];
row_data_.insert(row_data_.end(), inst.data, inst.data + inst.length);
row_ptr_.push_back(row_ptr_.back() + inst.length);
}
const auto &batch = iter->Value();
page_.Push(batch);
}
}
@ -53,16 +47,16 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
uint32_t index = batch.index[i];
bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
row_data_.emplace_back(index, fvalue);
page_.data.emplace_back(index, fvalue);
this->info.num_col_ = std::max(this->info.num_col_,
static_cast<uint64_t>(index + 1));
}
size_t top = row_ptr_.size();
size_t top = page_.offset.size();
for (size_t i = 0; i < batch.size; ++i) {
row_ptr_.push_back(row_ptr_[top - 1] + batch.offset[i + 1] - batch.offset[0]);
page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]);
}
}
this->info.num_nonzero_ = static_cast<uint64_t>(row_data_.size());
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.size());
}
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
@ -70,16 +64,16 @@ void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format";
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
info.LoadBinary(fi);
fi->Read(&row_ptr_);
fi->Read(&row_data_);
fi->Read(&page_.offset);
fi->Read(&page_.data);
}
void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
info.SaveBinary(fo);
fo->Write(row_ptr_);
fo->Write(row_data_);
fo->Write(page_.offset);
fo->Write(page_.data);
}
void SimpleCSRSource::BeforeFirst() {
@ -89,15 +83,11 @@ void SimpleCSRSource::BeforeFirst() {
bool SimpleCSRSource::Next() {
if (!at_first_) return false;
at_first_ = false;
batch_.size = row_ptr_.size() - 1;
batch_.base_rowid = 0;
batch_.ind_ptr = dmlc::BeginPtr(row_ptr_);
batch_.data_ptr = dmlc::BeginPtr(row_data_);
return true;
}
const RowBatch& SimpleCSRSource::Value() const {
return batch_;
const SparsePage& SimpleCSRSource::Value() const {
return page_;
}
} // namespace data

View File

@ -29,13 +29,9 @@ class SimpleCSRSource : public DataSource {
public:
// public data members
// MetaInfo info; // inherited from DataSource
/*! \brief row pointer of CSR sparse storage */
std::vector<size_t> row_ptr_;
/*! \brief data in the CSR sparse storage */
std::vector<RowBatch::Entry> row_data_;
// functions
SparsePage page_;
/*! \brief default constructor */
SimpleCSRSource() : row_ptr_(1, 0) {}
SimpleCSRSource() = default;
/*! \brief destructor */
~SimpleCSRSource() override = default;
/*! \brief clear the data structure */
@ -66,15 +62,13 @@ class SimpleCSRSource : public DataSource {
// implement BeforeFirst
void BeforeFirst() override;
// implement Value
const RowBatch &Value() const override;
const SparsePage &Value() const override;
/*! \brief magic number used to identify SimpleCSRSource */
static const int kMagic = 0xffffab01;
private:
/*! \brief internal variable, used to support iterator interface */
bool at_first_{true};
/*! \brief temporary batch used to support the iterator interface */
RowBatch batch_;
};
} // namespace data
} // namespace xgboost

View File

@ -16,238 +16,89 @@ namespace xgboost {
namespace data {
bool SimpleDMatrix::ColBatchIter::Next() {
if (data_ptr_ >= cpages_.size()) return false;
data_ptr_ += 1;
SparsePage* pcol = cpages_[data_ptr_ - 1].get();
batch_.size = col_index_.size();
col_data_.resize(col_index_.size(), SparseBatch::Inst(nullptr, 0));
for (size_t i = 0; i < col_data_.size(); ++i) {
const bst_uint ridx = col_index_[i];
col_data_[i] = SparseBatch::Inst
(dmlc::BeginPtr(pcol->data) + pcol->offset[ridx],
static_cast<bst_uint>(pcol->offset[ridx + 1] - pcol->offset[ridx]));
}
batch_.col_index = dmlc::BeginPtr(col_index_);
batch_.col_data = dmlc::BeginPtr(col_data_);
if (data_ >= 1) return false;
data_ += 1;
return true;
}
dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator() {
size_t ncol = this->Info().num_col_;
col_iter_.col_index_.resize(ncol);
for (size_t i = 0; i < ncol; ++i) {
col_iter_.col_index_[i] = static_cast<bst_uint>(i);
}
dmlc::DataIter<SparsePage>* SimpleDMatrix::ColIterator() {
col_iter_.BeforeFirst();
return &col_iter_;
}
dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator(const std::vector<bst_uint>&fset) {
size_t ncol = this->Info().num_col_;
col_iter_.col_index_.resize(0);
for (auto fidx : fset) {
if (fidx < ncol) col_iter_.col_index_.push_back(fidx);
}
col_iter_.BeforeFirst();
return &col_iter_;
}
void SimpleDMatrix::InitColAccess(const std::vector<bool> &enabled,
float pkeep,
void SimpleDMatrix::InitColAccess(
size_t max_row_perbatch, bool sorted) {
if (this->HaveColAccess(sorted)) return;
col_iter_.sorted_ = sorted;
col_iter_.cpages_.clear();
if (Info().num_row_ < max_row_perbatch) {
std::unique_ptr<SparsePage> page(new SparsePage());
this->MakeOneBatch(enabled, pkeep, page.get(), sorted);
col_iter_.cpages_.push_back(std::move(page));
} else {
this->MakeManyBatch(enabled, pkeep, max_row_perbatch, sorted);
}
// setup col-size
col_size_.resize(Info().num_col_);
std::fill(col_size_.begin(), col_size_.end(), 0);
for (auto & cpage : col_iter_.cpages_) {
SparsePage *pcol = cpage.get();
for (size_t j = 0; j < pcol->Size(); ++j) {
col_size_[j] += pcol->offset[j + 1] - pcol->offset[j];
}
}
col_iter_.column_page_.reset(new SparsePage());
this->MakeOneBatch(col_iter_.column_page_.get(), sorted);
}
// internal function to make one batch from row iter.
void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled, float pkeep,
SparsePage* pcol, bool sorted) {
void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) {
// clear rowset
buffered_rowset_.Clear();
// bit map
const int nthread = omp_get_max_threads();
std::vector<bool> bmap;
pcol->Clear();
common::ParallelGroupBuilder<SparseBatch::Entry>
common::ParallelGroupBuilder<Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(Info().num_col_, nthread);
// start working
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
auto iter = this->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
bmap.resize(bmap.size() + batch.size, true);
std::bernoulli_distribution coin_flip(pkeep);
auto& rnd = common::GlobalRandom();
long batch_size = static_cast<long>(batch.size); // NOLINT(*)
const auto& batch = iter->Value();
long batch_size = static_cast<long>(batch.Size()); // NOLINT(*)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || coin_flip(rnd)) {
buffered_rowset_.PushBack(ridx);
} else {
bmap[i] = false;
}
}
#pragma omp parallel for schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (bmap[ridx]) {
RowBatch::Inst inst = batch[i];
auto inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (enabled[inst[j].index]) {
builder.AddBudget(inst[j].index, tid);
}
}
}
}
}
builder.InitStorage();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
#pragma omp parallel for schedule(static)
for (long i = 0; i < static_cast<long>(batch.size); ++i) { // NOLINT(*)
for (long i = 0; i < static_cast<long>(batch.Size()); ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (bmap[ridx]) {
RowBatch::Inst inst = batch[i];
auto inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (enabled[inst[j].index]) {
builder.Push(inst[j].index,
SparseBatch::Entry(static_cast<bst_uint>(batch.base_rowid+i),
inst[j].fvalue), tid);
}
}
}
}
}
CHECK_EQ(pcol->Size(), Info().num_col_);
if (sorted) {
// sort columns
auto ncol = static_cast<bst_omp_uint>(pcol->Size());
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
}
}
}
}
void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
float pkeep,
size_t max_row_perbatch, bool sorted) {
size_t btop = 0;
std::bernoulli_distribution coin_flip(pkeep);
auto& rnd = common::GlobalRandom();
buffered_rowset_.Clear();
// internal temp cache
SparsePage tmp; tmp.Clear();
// start working
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || coin_flip(rnd)) {
buffered_rowset_.PushBack(ridx);
tmp.Push(batch[i]);
}
if (tmp.Size() >= max_row_perbatch) {
std::unique_ptr<SparsePage> page(new SparsePage());
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted);
col_iter_.cpages_.push_back(std::move(page));
btop = buffered_rowset_.Size();
tmp.Clear();
}
}
}
if (tmp.Size() != 0) {
std::unique_ptr<SparsePage> page(new SparsePage());
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted);
col_iter_.cpages_.push_back(std::move(page));
}
}
// make column page from a subset of row batches
void SimpleDMatrix::MakeColPage(const RowBatch& batch,
size_t buffer_begin,
const std::vector<bool>& enabled,
SparsePage* pcol, bool sorted) {
const int nthread = std::min(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 2, 1));
pcol->Clear();
common::ParallelGroupBuilder<SparseBatch::Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(Info().num_col_, nthread);
bst_omp_uint ndata = static_cast<bst_uint>(batch.size);
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
const SparseBatch::Entry &e = inst[j];
if (enabled[e.index]) {
builder.AddBudget(e.index, tid);
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
const SparseBatch::Entry &e = inst[j];
builder.Push(
e.index,
SparseBatch::Entry(buffered_rowset_[i + buffer_begin], e.fvalue),
inst[j].index,
Entry(static_cast<bst_uint>(batch.base_rowid + i), inst[j].fvalue),
tid);
}
}
}
CHECK_EQ(pcol->Size(), Info().num_col_);
// sort columns
if (sorted) {
// sort columns
auto ncol = static_cast<bst_omp_uint>(pcol->Size());
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
Entry::CmpValue);
}
}
}
}
bool SimpleDMatrix::SingleColBlock() const {
return col_iter_.cpages_.size() <= 1;
return true;
}
} // namespace data
} // namespace xgboost

View File

@ -12,7 +12,6 @@
#include <vector>
#include <algorithm>
#include <cstring>
#include "./sparse_batch_page.h"
namespace xgboost {
namespace data {
@ -30,14 +29,14 @@ class SimpleDMatrix : public DMatrix {
return source_->info;
}
dmlc::DataIter<RowBatch>* RowIterator() override {
dmlc::DataIter<RowBatch>* iter = source_.get();
dmlc::DataIter<SparsePage>* RowIterator() override {
auto iter = source_.get();
iter->BeforeFirst();
return iter;
}
bool HaveColAccess(bool sorted) const override {
return col_size_.size() != 0 && col_iter_.sorted_ == sorted;
return col_iter_.sorted_ == sorted && col_iter_.column_page_!= nullptr;
}
const RowSet& BufferedRowset() const override {
@ -45,50 +44,42 @@ class SimpleDMatrix : public DMatrix {
}
size_t GetColSize(size_t cidx) const override {
return col_size_[cidx];
auto& batch = *col_iter_.column_page_;
return batch[cidx].length;
}
float GetColDensity(size_t cidx) const override {
size_t nmiss = buffered_rowset_.Size() - col_size_[cidx];
size_t nmiss = buffered_rowset_.Size() - GetColSize(cidx);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.Size();
}
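For example, with 100 buffered rows and a column page holding 75 entries for column cidx, nmiss = 100 - 75 = 25 and GetColDensity returns 1 - 25/100 = 0.75.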
dmlc::DataIter<ColBatch>* ColIterator() override;
dmlc::DataIter<SparsePage>* ColIterator() override;
dmlc::DataIter<ColBatch>* ColIterator(const std::vector<bst_uint>& fset) override;
void InitColAccess(const std::vector<bool>& enabled,
float subsample,
void InitColAccess(
size_t max_row_perbatch, bool sorted) override;
bool SingleColBlock() const override;
private:
// in-memory column batch iterator.
struct ColBatchIter: dmlc::DataIter<ColBatch> {
struct ColBatchIter: dmlc::DataIter<SparsePage> {
public:
ColBatchIter() = default;
void BeforeFirst() override {
data_ptr_ = 0;
data_ = 0;
}
const ColBatch &Value() const override {
return batch_;
const SparsePage &Value() const override {
return *column_page_;
}
bool Next() override;
private:
// allow SimpleDMatrix to access it.
friend class SimpleDMatrix;
// data content
std::vector<bst_uint> col_index_;
// column content
std::vector<ColBatch::Inst> col_data_;
// column sparse pages
std::vector<std::unique_ptr<SparsePage> > cpages_;
// column sparse page
std::unique_ptr<SparsePage> column_page_;
// data pointer
size_t data_ptr_{0};
// temporal space for batch
ColBatch batch_;
size_t data_{0};
// Is column sorted?
bool sorted_{false};
};
@ -99,21 +90,9 @@ class SimpleDMatrix : public DMatrix {
ColBatchIter col_iter_;
// list of row index that are buffered.
RowSet buffered_rowset_;
/*! \brief size of each column's data */
std::vector<size_t> col_size_;
// internal function to make one batch from row iter.
void MakeOneBatch(const std::vector<bool>& enabled,
float pkeep,
SparsePage *pcol, bool sorted);
void MakeManyBatch(const std::vector<bool>& enabled,
float pkeep,
size_t max_row_perbatch, bool sorted);
void MakeColPage(const RowBatch& batch,
size_t buffer_begin,
const std::vector<bool>& enabled,
void MakeOneBatch(
SparsePage *pcol, bool sorted);
};
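A hedged sketch of the simplified column-access flow this header now exposes (ScanColumns and kMaxRowPerBatch are illustrative, not part of the codebase):

// Hypothetical usage: column access after the refactor.
void ScanColumns(xgboost::DMatrix* dmat) {
  const size_t kMaxRowPerBatch = 32UL << 10;  // hypothetical per-batch row limit
  dmat->InitColAccess(kMaxRowPerBatch, /*sorted=*/true);
  auto* it = dmat->ColIterator();  // dmlc::DataIter<SparsePage>*
  it->BeforeFirst();
  while (it->Next()) {
    const xgboost::SparsePage& cols = it->Value();
    for (size_t cidx = 0; cidx < cols.Size(); ++cidx) {
      auto col = cols[cidx];  // Inst over column cidx, value-sorted because sorted == true
      (void)col;
    }
  }
}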
} // namespace data

View File

@ -1,255 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file sparse_batch_page.h
* content holder of sparse batches that can be saved to disk;
* the representation can be used effectively
* in external-memory computation
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
#define XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
#include <xgboost/data.h>
#include <dmlc/io.h>
#include <vector>
#include <algorithm>
#include <cstring>
#include <string>
#include <utility>
#include <memory>
#include <functional>
#if DMLC_ENABLE_STD_THREAD
#include <dmlc/concurrency.h>
#include <thread>
#endif
namespace xgboost {
namespace data {
/*!
* \brief in-memory storage unit of sparse batch
*/
class SparsePage {
public:
/*! \brief Format of the sparse page. */
class Format;
/*! \brief Writer to write the sparse page to files. */
class Writer;
/*! \brief minimum of all indices, used as a hint for compression. */
bst_uint min_index;
/*! \brief offset of the segments */
std::vector<size_t> offset;
/*! \brief the data of the segments */
std::vector<SparseBatch::Entry> data;
/*! \brief constructor */
SparsePage() {
this->Clear();
}
/*! \return number of instances in the page */
inline size_t Size() const {
return offset.size() - 1;
}
/*! \return estimation of memory cost of this page */
inline size_t MemCostBytes() const {
return offset.size() * sizeof(size_t) + data.size() * sizeof(SparseBatch::Entry);
}
/*! \brief clear the page */
inline void Clear() {
min_index = 0;
offset.clear();
offset.push_back(0);
data.clear();
}
/*!
* \brief Push row batch into the page
* \param batch the row batch
*/
inline void Push(const RowBatch &batch) {
data.resize(offset.back() + batch.ind_ptr[batch.size]);
std::memcpy(dmlc::BeginPtr(data) + offset.back(),
batch.data_ptr + batch.ind_ptr[0],
sizeof(SparseBatch::Entry) * batch.ind_ptr[batch.size]);
size_t top = offset.back();
size_t begin = offset.size();
offset.resize(offset.size() + batch.size);
for (size_t i = 0; i < batch.size; ++i) {
offset[i + begin] = top + batch.ind_ptr[i + 1] - batch.ind_ptr[0];
}
}
/*!
* \brief Push row block into the page.
* \param batch the row batch.
*/
inline void Push(const dmlc::RowBlock<uint32_t>& batch) {
data.reserve(data.size() + batch.offset[batch.size] - batch.offset[0]);
offset.reserve(offset.size() + batch.size);
CHECK(batch.index != nullptr);
for (size_t i = 0; i < batch.size; ++i) {
offset.push_back(offset.back() + batch.offset[i + 1] - batch.offset[i]);
}
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
uint32_t index = batch.index[i];
bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
data.emplace_back(index, fvalue);
}
CHECK_EQ(offset.back(), data.size());
}
/*!
* \brief Push a sparse page
* \param batch the row page
*/
inline void Push(const SparsePage &batch) {
size_t top = offset.back();
data.resize(top + batch.data.size());
std::memcpy(dmlc::BeginPtr(data) + top,
dmlc::BeginPtr(batch.data),
sizeof(SparseBatch::Entry) * batch.data.size());
size_t begin = offset.size();
offset.resize(begin + batch.Size());
for (size_t i = 0; i < batch.Size(); ++i) {
offset[i + begin] = top + batch.offset[i + 1];
}
}
/*!
* \brief Push one instance into page
* \param inst an instance row
*/
inline void Push(const SparseBatch::Inst &inst) {
offset.push_back(offset.back() + inst.length);
size_t begin = data.size();
data.resize(begin + inst.length);
if (inst.length != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
sizeof(SparseBatch::Entry) * inst.length);
}
}
/*!
* \param base_rowid base_rowid of the data
* \return row batch representation of the page
*/
inline RowBatch GetRowBatch(size_t base_rowid) const {
RowBatch out;
out.base_rowid = base_rowid;
out.ind_ptr = dmlc::BeginPtr(offset);
out.data_ptr = dmlc::BeginPtr(data);
out.size = offset.size() - 1;
return out;
}
};
/*!
* \brief Format specification of SparsePage.
*/
class SparsePage::Format {
public:
/*! \brief virtual destructor */
virtual ~Format() = default;
/*!
* \brief Load all the segments into page, advance fi to end of the block.
* \param page The data to read page into.
* \param fi the input stream of the file
* \return true if the loading was successful, false if the end of file was reached
*/
virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0;
/*!
* \brief read only the segments we are interested in, advance fi to end of the block.
* \param page The page to load the data into.
* \param fi the input stream of the file
* \param sorted_index_set sorted index of segments we are interested in
* \return true if the loading was successful, false if the end of file was reached
*/
virtual bool Read(SparsePage* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) = 0;
/*!
* \brief write the page to fo.
* \param fo output stream
*/
virtual void Write(const SparsePage& page, dmlc::Stream* fo) = 0;
/*!
* \brief Create a sparse page format of the given name.
* \return The created format functor.
*/
static Format* Create(const std::string& name);
/*!
* \brief decide the format from the cache prefix.
* \return pair of (row format, column format) for the cache prefix.
*/
static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
};
#if DMLC_ENABLE_STD_THREAD
/*!
* \brief A threaded writer that writes sparse batch pages to sharded files.
*/
class SparsePage::Writer {
public:
/*!
* \brief constructor
* \param name_shards name of shard files.
* \param format_shards format of each shard.
* \param extra_buffer_capacity Extra buffer capacity before block.
*/
explicit Writer(
const std::vector<std::string>& name_shards,
const std::vector<std::string>& format_shards,
size_t extra_buffer_capacity);
/*! \brief destructor, will close the files automatically */
~Writer();
/*!
* \brief Push a write job to the writer.
* This function won't block,
* writing is done by another thread inside writer.
* \param page The page to be written
*/
void PushWrite(std::shared_ptr<SparsePage>&& page);
/*!
* \brief Allocate a page to store results.
* This function can block when the writer is too slow and buffer pages
* have not yet been recycled.
* \param out_page Used to store the allocated pages.
*/
void Alloc(std::shared_ptr<SparsePage>* out_page);
private:
/*! \brief number of allocated pages */
size_t num_free_buffer_;
/*! \brief clock_pointer */
size_t clock_ptr_;
/*! \brief writer threads */
std::vector<std::unique_ptr<std::thread> > workers_;
/*! \brief recycler queue */
dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > qrecycle_;
/*! \brief worker threads */
std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > > qworkers_;
};
#endif // DMLC_ENABLE_STD_THREAD
/*!
* \brief Registry entry for sparse page format.
*/
struct SparsePageFormatReg
: public dmlc::FunctionRegEntryBase<SparsePageFormatReg,
std::function<SparsePage::Format* ()> > {
};
/*!
* \brief Macro to register sparse page format.
*
* \code
* // example of registering a sparse page format
* XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
* .describe("Raw binary data format.")
* .set_body([]() {
* return new RawFormat();
* });
* \endcode
*/
#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \
DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name)
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
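For reference, the writer pattern survives under the new free-standing names; a minimal sketch matching the call sites elsewhere in this commit (the shard name is illustrative):

std::vector<std::string> name_shards{"cache.row.page"};  // illustrative shard file name
std::vector<std::string> format_shards{"raw"};           // a registered format name
xgboost::data::SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<xgboost::SparsePage> page;
writer.Alloc(&page);
page->Clear();
// ... fill *page via Push(...), then hand it off to a writer thread:
writer.PushWrite(std::move(page));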

View File

@ -29,8 +29,8 @@ SparsePageDMatrix::ColPageIter::ColPageIter(
dmlc::SeekStream* fi = files_[i].get();
std::string format;
CHECK(fi->Read(&format)) << "Invalid page format";
formats_[i].reset(SparsePage::Format::Create(format));
SparsePage::Format* fmt = formats_[i].get();
formats_[i].reset(SparsePageFormat::Create(format));
SparsePageFormat* fmt = formats_[i].get();
size_t fbegin = fi->Tell();
prefetchers_[i].reset(new dmlc::ThreadedIter<SparsePage>(4));
prefetchers_[i]->Init([this, fi, fmt] (SparsePage** dptr) {
@ -61,15 +61,6 @@ bool SparsePageDMatrix::ColPageIter::Next() {
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
}
if (prefetchers_[clock_ptr_]->Next(&page_)) {
out_.col_index = dmlc::BeginPtr(index_set_);
col_data_.resize(page_->offset.size() - 1, SparseBatch::Inst(nullptr, 0));
for (size_t i = 0; i < col_data_.size(); ++i) {
col_data_[i] = SparseBatch::Inst
(dmlc::BeginPtr(page_->data) + page_->offset[i],
static_cast<bst_uint>(page_->offset[i + 1] - page_->offset[i]));
}
out_.col_data = dmlc::BeginPtr(col_data_);
out_.size = col_data_.size();
// advance clock
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
return true;
@ -85,40 +76,22 @@ void SparsePageDMatrix::ColPageIter::BeforeFirst() {
}
}
void SparsePageDMatrix::ColPageIter::Init(const std::vector<bst_uint>& index_set,
bool load_all) {
void SparsePageDMatrix::ColPageIter::Init(
const std::vector<bst_uint>& index_set) {
set_index_set_ = index_set;
set_load_all_ = load_all;
set_load_all_ = true;
std::sort(set_index_set_.begin(), set_index_set_.end());
this->BeforeFirst();
}
dmlc::DataIter<ColBatch>* SparsePageDMatrix::ColIterator() {
dmlc::DataIter<SparsePage>* SparsePageDMatrix::ColIterator() {
CHECK(col_iter_ != nullptr);
std::vector<bst_uint> col_index;
size_t ncol = this->Info().num_col_;
for (size_t i = 0; i < ncol; ++i) {
col_index.push_back(static_cast<bst_uint>(i));
}
col_iter_->Init(col_index, true);
std::iota(col_index.begin(), col_index.end(), bst_uint(0));
col_iter_->Init(col_index);
return col_iter_.get();
}
dmlc::DataIter<ColBatch>* SparsePageDMatrix::
ColIterator(const std::vector<bst_uint>& fset) {
CHECK(col_iter_ != nullptr);
std::vector<bst_uint> col_index;
size_t ncol = this->Info().num_col_;
for (auto fidx : fset) {
if (fidx < ncol) {
col_index.push_back(fidx);
}
}
col_iter_->Init(col_index, false);
return col_iter_.get();
}
bool SparsePageDMatrix::TryInitColData(bool sorted) {
// load meta data.
std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
@ -145,8 +118,7 @@ bool SparsePageDMatrix::TryInitColData(bool sorted) {
return true;
}
void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
float pkeep,
void SparsePageDMatrix::InitColAccess(
size_t max_row_perbatch, bool sorted) {
if (HaveColAccess(sorted)) return;
if (TryInitColData(sorted)) return;
@ -157,11 +129,9 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
buffered_rowset_.Clear();
col_size_.resize(info.num_col_);
std::fill(col_size_.begin(), col_size_.end(), 0);
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
std::bernoulli_distribution coin_flip(pkeep);
auto iter = this->RowIterator();
size_t batch_ptr = 0, batch_top = 0;
SparsePage tmp;
auto& rnd = common::GlobalRandom();
// function to create the page.
auto make_col_batch = [&] (
@ -169,9 +139,9 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
size_t begin,
SparsePage *pcol) {
pcol->Clear();
pcol->min_index = buffered_rowset_[begin];
pcol->base_rowid = buffered_rowset_[begin];
const int nthread = std::max(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 1, 1));
common::ParallelGroupBuilder<SparseBatch::Entry>
common::ParallelGroupBuilder<Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(info.num_col_, nthread);
bst_omp_uint ndata = static_cast<bst_uint>(prow.Size());
@ -179,20 +149,18 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
if (enabled[e.index]) {
const auto e = prow.data[j];
builder.AddBudget(e.index, tid);
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
const Entry &e = prow.data[j];
builder.Push(e.index,
SparseBatch::Entry(buffered_rowset_[i + begin], e.fvalue),
Entry(buffered_rowset_[i + begin], e.fvalue),
tid);
}
}
@ -205,7 +173,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
Entry::CmpValue);
}
}
}
@ -217,14 +185,12 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
while (true) {
if (batch_ptr != batch_top) {
const RowBatch& batch = iter->Value();
CHECK_EQ(batch_top, batch.size);
auto batch = iter->Value();
CHECK_EQ(batch_top, batch.Size());
for (size_t i = batch_ptr; i < batch_top; ++i) {
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || coin_flip(rnd)) {
buffered_rowset_.PushBack(ridx);
tmp.Push(batch[i]);
}
if (tmp.Size() >= max_row_perbatch ||
tmp.MemCostBytes() >= kPageSize) {
@ -237,7 +203,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
}
if (!iter->Next()) break;
batch_ptr = 0;
batch_top = iter->Value().size;
batch_top = iter->Value().Size();
}
if (tmp.Size() != 0) {
@ -252,11 +218,11 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
std::vector<std::string> name_shards, format_shards;
for (const std::string& prefix : cache_shards) {
name_shards.push_back(prefix + ".col.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).second);
format_shards.push_back(SparsePageFormat::DecideFormat(prefix).second);
}
{
SparsePage::Writer writer(name_shards, format_shards, 6);
SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();

View File

@ -14,8 +14,8 @@
#include <vector>
#include <algorithm>
#include <string>
#include "./sparse_batch_page.h"
#include "../common/common.h"
#include "./sparse_page_writer.h"
namespace xgboost {
namespace data {
@ -35,8 +35,8 @@ class SparsePageDMatrix : public DMatrix {
return source_->info;
}
dmlc::DataIter<RowBatch>* RowIterator() override {
dmlc::DataIter<RowBatch>* iter = source_.get();
dmlc::DataIter<SparsePage>* RowIterator() override {
auto iter = source_.get();
iter->BeforeFirst();
return iter;
}
@ -62,12 +62,9 @@ class SparsePageDMatrix : public DMatrix {
return false;
}
dmlc::DataIter<ColBatch>* ColIterator() override;
dmlc::DataIter<SparsePage>* ColIterator() override;
dmlc::DataIter<ColBatch>* ColIterator(const std::vector<bst_uint>& fset) override;
void InitColAccess(const std::vector<bool>& enabled,
float subsample,
void InitColAccess(
size_t max_row_perbatch, bool sorted) override;
/*! \brief page size 256 MB */
@ -77,17 +74,17 @@ class SparsePageDMatrix : public DMatrix {
private:
// declare the column batch iter.
class ColPageIter : public dmlc::DataIter<ColBatch> {
class ColPageIter : public dmlc::DataIter<SparsePage> {
public:
explicit ColPageIter(std::vector<std::unique_ptr<dmlc::SeekStream> >&& files);
~ColPageIter() override;
void BeforeFirst() override;
const ColBatch &Value() const override {
return out_;
const SparsePage &Value() const override {
return *page_;
}
bool Next() override;
// initialize the column iterator with the specified index set.
void Init(const std::vector<bst_uint>& index_set, bool load_all);
void Init(const std::vector<bst_uint>& index_set);
// If the column features are sorted
bool sorted;
@ -99,7 +96,7 @@ class SparsePageDMatrix : public DMatrix {
// data file pointer.
std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
// page format.
std::vector<std::unique_ptr<SparsePage::Format> > formats_;
std::vector<std::unique_ptr<SparsePageFormat> > formats_;
/*! \brief internal prefetcher. */
std::vector<std::unique_ptr<dmlc::ThreadedIter<SparsePage> > > prefetchers_;
// The index set to be loaded.
@ -108,10 +105,6 @@ class SparsePageDMatrix : public DMatrix {
std::vector<bst_uint> set_index_set_;
// whether to load the entire dataset.
bool set_load_all_, load_all_;
// temporal space for batch
ColBatch out_;
// the pointer data.
std::vector<SparseBatch::Inst> col_data_;
};
/*!
* \brief Try to initialize column data.

View File

@ -5,14 +5,14 @@
*/
#include <xgboost/data.h>
#include <dmlc/registry.h>
#include "./sparse_batch_page.h"
#include "./sparse_page_writer.h"
namespace xgboost {
namespace data {
DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);
class SparsePageRawFormat : public SparsePage::Format {
class SparsePageRawFormat : public SparsePageFormat {
public:
bool Read(SparsePage* page, dmlc::SeekStream* fi) override {
if (!fi->Read(&(page->offset))) return false;
@ -20,8 +20,8 @@ class SparsePageRawFormat : public SparsePage::Format {
page->data.resize(page->offset.back());
if (page->data.size() != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data),
(page->data).size() * sizeof(SparseBatch::Entry)),
(page->data).size() * sizeof(SparseBatch::Entry))
(page->data).size() * sizeof(Entry)),
(page->data).size() * sizeof(Entry))
<< "Invalid SparsePage file";
}
return true;
@ -47,7 +47,7 @@ class SparsePageRawFormat : public SparsePage::Format {
bst_uint fid = sorted_index_set[i];
if (disk_offset_[fid] != curr_offset) {
CHECK_GT(disk_offset_[fid], curr_offset);
fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry));
fi->Seek(begin + disk_offset_[fid] * sizeof(Entry));
curr_offset = disk_offset_[fid];
}
size_t j, size_to_read = 0;
@ -61,8 +61,8 @@ class SparsePageRawFormat : public SparsePage::Format {
if (size_to_read != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data) + page->offset[i],
size_to_read * sizeof(SparseBatch::Entry)),
size_to_read * sizeof(SparseBatch::Entry))
size_to_read * sizeof(Entry)),
size_to_read * sizeof(Entry))
<< "Invalid SparsePage file";
curr_offset += size_to_read;
}
@ -70,7 +70,7 @@ class SparsePageRawFormat : public SparsePage::Format {
}
// seek to end of record
if (curr_offset != disk_offset_.back()) {
fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry));
fi->Seek(begin + disk_offset_.back() * sizeof(Entry));
}
return true;
}
@ -80,7 +80,7 @@ class SparsePageRawFormat : public SparsePage::Format {
CHECK_EQ(page.offset.back(), page.data.size());
fo->Write(page.offset);
if (page.data.size() != 0) {
fo->Write(dmlc::BeginPtr(page.data), page.data.size() * sizeof(SparseBatch::Entry));
fo->Write(dmlc::BeginPtr(page.data), page.data.size() * sizeof(Entry));
}
}

View File

@ -37,8 +37,8 @@ SparsePageSource::SparsePageSource(const std::string& cache_info)
dmlc::SeekStream* fi = files_[i].get();
std::string format;
CHECK(fi->Read(&format)) << "Invalid page format";
formats_[i].reset(SparsePage::Format::Create(format));
SparsePage::Format* fmt = formats_[i].get();
formats_[i].reset(SparsePageFormat::Create(format));
SparsePageFormat* fmt = formats_[i].get();
size_t fbegin = fi->Tell();
prefetchers_[i].reset(new dmlc::ThreadedIter<SparsePage>(4));
prefetchers_[i]->Init([fi, fmt] (SparsePage** dptr) {
@ -61,8 +61,8 @@ bool SparsePageSource::Next() {
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
}
if (prefetchers_[clock_ptr_]->Next(&page_)) {
batch_ = page_->GetRowBatch(base_rowid_);
base_rowid_ += batch_.size;
page_->base_rowid = base_rowid_;
base_rowid_ += page_->Size();
// advance clock
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
return true;
@ -79,8 +79,8 @@ void SparsePageSource::BeforeFirst() {
}
}
const RowBatch& SparsePageSource::Value() const {
return batch_;
const SparsePage& SparsePageSource::Value() const {
return *page_;
}
bool SparsePageSource::CacheExist(const std::string& cache_info) {
@ -108,10 +108,10 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
std::vector<std::string> name_shards, format_shards;
for (const std::string& prefix : cache_shards) {
name_shards.push_back(prefix + ".row.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
}
{
SparsePage::Writer writer(name_shards, format_shards, 6);
SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();
@ -176,17 +176,17 @@ void SparsePageSource::Create(DMatrix* src,
std::vector<std::string> name_shards, format_shards;
for (const std::string& prefix : cache_shards) {
name_shards.push_back(prefix + ".row.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
}
{
SparsePage::Writer writer(name_shards, format_shards, 6);
SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();
MetaInfo info = src->Info();
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
auto iter = src->RowIterator();
while (iter->Next()) {
page->Push(iter->Value());
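
The clock logic in Next() above is easy to miss, so here it is restated compactly. Prefetcher is a hypothetical stand-in for dmlc::ThreadedIter<SparsePage>, and the end-of-data bookkeeping of the original is elided, but the recycle-then-advance order matches the diff:

struct Prefetcher {  // stand-in for dmlc::ThreadedIter<SparsePage>
  virtual bool Next(SparsePage** dptr) = 0;
  virtual void Recycle(SparsePage** dptr) = 0;
  virtual ~Prefetcher() = default;
};

bool NextPage(std::vector<Prefetcher*>* prefetchers, size_t* clock_ptr,
              size_t* base_rowid, SparsePage** page) {
  size_t n = prefetchers->size();
  if (*page != nullptr) {
    // return the previously borrowed page to the shard it came from
    (*prefetchers)[(*clock_ptr + n - 1) % n]->Recycle(page);
  }
  if (!(*prefetchers)[*clock_ptr]->Next(page)) return false;
  (*page)->base_rowid = *base_rowid;  // stamp the batch's starting row id
  *base_rowid += (*page)->Size();
  *clock_ptr = (*clock_ptr + 1) % n;  // advance the round-robin clock
  return true;
}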

View File

@ -13,7 +13,7 @@
#include <vector>
#include <algorithm>
#include <string>
#include "./sparse_batch_page.h"
#include "sparse_page_writer.h"
namespace xgboost {
namespace data {
@ -39,7 +39,7 @@ class SparsePageSource : public DataSource {
// implement BeforeFirst
void BeforeFirst() override;
// implement Value
const RowBatch& Value() const override;
const SparsePage& Value() const override;
/*!
* \brief Create source by taking data from parser.
* \param src source parser.
@ -67,8 +67,6 @@ class SparsePageSource : public DataSource {
private:
/*! \brief number of rows */
size_t base_rowid_;
/*! \brief temp data. */
RowBatch batch_;
/*! \brief page currently on hold. */
SparsePage *page_;
/*! \brief internal clock ptr */
@ -76,7 +74,7 @@ class SparsePageSource : public DataSource {
/*! \brief file pointer to the row blob file. */
std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
/*! \brief Sparse page format file. */
std::vector<std::unique_ptr<SparsePage::Format> > formats_;
std::vector<std::unique_ptr<SparsePageFormat> > formats_;
/*! \brief internal prefetcher. */
std::vector<std::unique_ptr<dmlc::ThreadedIter<SparsePage> > > prefetchers_;
};

View File

@ -5,13 +5,13 @@
*/
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include "./sparse_batch_page.h"
#include "./sparse_page_writer.h"
#if DMLC_ENABLE_STD_THREAD
namespace xgboost {
namespace data {
SparsePage::Writer::Writer(
SparsePageWriter::SparsePageWriter(
const std::vector<std::string>& name_shards,
const std::vector<std::string>& format_shards,
size_t extra_buffer_capacity)
@ -29,8 +29,8 @@ SparsePage::Writer::Writer(
[this, name_shard, format_shard, wqueue] () {
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(name_shard.c_str(), "w"));
std::unique_ptr<SparsePage::Format> fmt(
SparsePage::Format::Create(format_shard));
std::unique_ptr<SparsePageFormat> fmt(
SparsePageFormat::Create(format_shard));
fo->Write(format_shard);
std::shared_ptr<SparsePage> page;
while (wqueue->Pop(&page)) {
@ -44,7 +44,7 @@ SparsePage::Writer::Writer(
}
}
SparsePage::Writer::~Writer() {
SparsePageWriter::~SparsePageWriter() {
for (auto& queue : qworkers_) {
// use nullptr to signal termination.
std::shared_ptr<SparsePage> sig(nullptr);
@ -55,12 +55,12 @@ SparsePage::Writer::~Writer() {
}
}
void SparsePage::Writer::PushWrite(std::shared_ptr<SparsePage>&& page) {
void SparsePageWriter::PushWrite(std::shared_ptr<SparsePage>&& page) {
qworkers_[clock_ptr_].Push(std::move(page));
clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
}
void SparsePage::Writer::Alloc(std::shared_ptr<SparsePage>* out_page) {
void SparsePageWriter::Alloc(std::shared_ptr<SparsePage>* out_page) {
CHECK(*out_page == nullptr);
if (num_free_buffer_ != 0) {
out_page->reset(new SparsePage());

View File

@ -0,0 +1,139 @@
/*!
* Copyright (c) 2014 by Contributors
* \file sparse_page_writer.h
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
#define XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
#include <xgboost/data.h>
#include <dmlc/io.h>
#include <vector>
#include <algorithm>
#include <cstring>
#include <string>
#include <utility>
#include <memory>
#include <functional>
#if DMLC_ENABLE_STD_THREAD
#include <dmlc/concurrency.h>
#include <thread>
#endif
namespace xgboost {
namespace data {
/*!
* \brief Format specification of SparsePage.
*/
class SparsePageFormat {
public:
/*! \brief virtual destructor */
virtual ~SparsePageFormat() = default;
/*!
* \brief Load all the segments into page, advance fi to end of the block.
* \param page The page to read the data into.
* \param fi the input stream of the file
* \return true if the loading was successful, false if the end of file was reached
*/
virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0;
/*!
* \brief read only the segments we are interested in, advance fi to end of the block.
* \param page The page to load the data into.
* \param fi the input stream of the file
* \param sorted_index_set sorted index of segments we are interested in
* \return true if the loading was successful, false if the end of file was reached
*/
virtual bool Read(SparsePage* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) = 0;
/*!
* \brief Save the page to the output stream.
* \param page the page to be written
* \param fo output stream
*/
virtual void Write(const SparsePage& page, dmlc::Stream* fo) = 0;
/*!
* \brief Create a sparse page format of the given name.
* \param name name of the registered format
* \return The created format object.
*/
static SparsePageFormat* Create(const std::string& name);
/*!
* \brief Decide the formats from the cache prefix.
* \return Pair of (row format, column format) for the cache prefix.
*/
static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
};
#if DMLC_ENABLE_STD_THREAD
/*!
* \brief A threaded writer that writes sparse pages to sharded files.
*/
class SparsePageWriter {
public:
/*!
* \brief constructor
* \param name_shards name of shard files.
* \param format_shards format of each shard.
* \param extra_buffer_capacity Number of extra page buffers that can be allocated before Alloc blocks.
*/
explicit SparsePageWriter(
const std::vector<std::string>& name_shards,
const std::vector<std::string>& format_shards,
size_t extra_buffer_capacity);
/*! \brief destructor, will close the files automatically */
~SparsePageWriter();
/*!
* \brief Push a write job to the writer.
* This function does not block;
* the write is performed by a worker thread inside the writer.
* \param page The page to be written
*/
void PushWrite(std::shared_ptr<SparsePage>&& page);
/*!
* \brief Allocate a page to store results.
* This function can block when the writer is too slow and buffer pages
* have not yet been recycled.
* \param out_page Used to store the allocated page.
*/
void Alloc(std::shared_ptr<SparsePage>* out_page);
private:
/*! \brief number of free page buffers available for allocation */
size_t num_free_buffer_;
/*! \brief round-robin clock pointer over the worker queues */
size_t clock_ptr_;
/*! \brief writer threads */
std::vector<std::unique_ptr<std::thread> > workers_;
/*! \brief recycler queue */
dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > qrecycle_;
/*! \brief per-worker job queues */
std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > > qworkers_;
};
#endif // DMLC_ENABLE_STD_THREAD
/*!
* \brief Registry entry for sparse page format.
*/
struct SparsePageFormatReg
: public dmlc::FunctionRegEntryBase<SparsePageFormatReg,
std::function<SparsePageFormat* ()> > {
};
/*!
* \brief Macro to register sparse page format.
*
* \code
* // example of registering a sparse page format
* XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
* .describe("Raw binary data format.")
* .set_body([]() {
* return new RawFormat();
* });
* \endcode
*/
#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \
DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name)
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
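
A minimal usage sketch of the classes declared above, following the Alloc/PushWrite pattern used by SparsePageSource::Create; the shard path here is made up:

void WriteOnePage() {
  std::vector<std::string> name_shards = {"train.cache.row.page"};  // hypothetical
  std::vector<std::string> format_shards = {"raw"};
  xgboost::data::SparsePageWriter writer(name_shards, format_shards,
                                         /*extra_buffer_capacity=*/6);
  std::shared_ptr<xgboost::SparsePage> page;
  writer.Alloc(&page);                // may block until a buffer is recycled
  page->Clear();
  // ... append CSR rows to page->offset / page->data ...
  writer.PushWrite(std::move(page));  // non-blocking; a worker thread writes it
}                                     // ~SparsePageWriter joins the workers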

View File

@ -86,7 +86,7 @@ class GBLinear : public GradientBooster {
if (!p_fmat->HaveColAccess(false)) {
monitor_.Start("InitColAccess");
std::vector<bool> enabled(p_fmat->Info().num_col_, true);
p_fmat->InitColAccess(enabled, 1.0f, param_.max_row_perbatch, false);
p_fmat->InitColAccess(param_.max_row_perbatch, false);
monitor_.Stop("InitColAccess");
}
@ -120,7 +120,7 @@ class GBLinear : public GradientBooster {
monitor_.Stop("PredictBatch");
}
// add base margin
void PredictInstance(const SparseBatch::Inst &inst,
void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds,
unsigned ntree_limit,
unsigned root_index) override {
@ -152,15 +152,15 @@ class GBLinear : public GradientBooster {
// make sure contributions is zeroed, we could be reusing a previously allocated one
std::fill(contribs.begin(), contribs.end(), 0);
// start collecting the contributions
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const RowBatch::Inst &inst = batch[i];
auto inst = batch[i];
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
// loop over output groups
for (int gid = 0; gid < ngroup; ++gid) {
@ -203,15 +203,15 @@ class GBLinear : public GradientBooster {
std::vector<bst_float> &preds = *out_preds;
const std::vector<bst_float>& base_margin = p_fmat->Info().base_margin_;
// start collecting the prediction
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
const int ngroup = model_.param.num_output_group;
preds.resize(p_fmat->Info().num_row_ * ngroup);
while (iter->Next()) {
const RowBatch &batch = iter->Value();
auto batch = iter->Value();
// output convention: nrow * k, where nrow is number of rows
// k is number of group
// parallel over local batch
const auto nsize = static_cast<omp_ulong>(batch.size);
const auto nsize = static_cast<omp_ulong>(batch.Size());
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < nsize; ++i) {
const size_t ridx = batch.base_rowid + i;
@ -265,7 +265,7 @@ class GBLinear : public GradientBooster {
}
}
inline void Pred(const RowBatch::Inst &inst, bst_float *preds, int gid,
inline void Pred(const SparsePage::Inst &inst, bst_float *preds, int gid,
bst_float base) {
bst_float psum = model_.bias()[gid] + base;
for (bst_uint i = 0; i < inst.length; ++i) {
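
All the hunks in this file converge on one iteration idiom: RowIterator() now yields SparsePage batches, batch.Size() replaces the old batch.size field, and batch[i] returns a SparsePage::Inst. A condensed sketch of that pattern, with the per-row work elided:

void ForEachRow(xgboost::DMatrix* p_fmat) {
  auto iter = p_fmat->RowIterator();
  iter->BeforeFirst();
  while (iter->Next()) {
    const auto& batch = iter->Value();
    for (size_t i = 0; i < batch.Size(); ++i) {
      auto inst = batch[i];  // sparse row i; its global row id is batch.base_rowid + i
      for (xgboost::bst_uint j = 0; j < inst.length; ++j) {
        // inst[j].index is the feature id, inst[j].fvalue its value
      }
    }
  }
}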

View File

@ -221,7 +221,7 @@ class GBTree : public GradientBooster {
predictor_->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
}
void PredictInstance(const SparseBatch::Inst& inst,
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
unsigned root_index) override {
@ -361,7 +361,7 @@ class Dart : public GBTree {
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
}
void PredictInstance(const SparseBatch::Inst& inst,
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
unsigned root_index) override {
@ -437,21 +437,21 @@ class Dart : public GBTree {
<< "size_leaf_vector is enforced to 0 so far";
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
auto* self = static_cast<Derived*>(this);
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
auto batch = iter->Value();
// parallel over local batch
constexpr int kUnroll = 8;
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
const bst_omp_uint rest = nsize % kUnroll;
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
const int tid = omp_get_thread_num();
RegTree::FVec& feats = thread_temp_[tid];
int64_t ridx[kUnroll];
RowBatch::Inst inst[kUnroll];
SparsePage::Inst inst[kUnroll];
for (int k = 0; k < kUnroll; ++k) {
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
}
@ -470,7 +470,7 @@ class Dart : public GBTree {
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
RegTree::FVec& feats = thread_temp_[0];
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
const RowBatch::Inst inst = batch[i];
const SparsePage::Inst inst = batch[i];
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx * num_group + gid;
preds[offset] +=
@ -497,7 +497,7 @@ class Dart : public GBTree {
}
// predict the leaf scores without dropped trees
inline bst_float PredValue(const RowBatch::Inst &inst,
inline bst_float PredValue(const SparsePage::Inst &inst,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,

View File

@ -80,8 +80,6 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
int tree_method;
// internal test flag
std::string test_flag;
// maximum buffered row value
float prob_buffer_row;
// maximum row per batch.
size_t max_row_perbatch;
// number of threads to use if OpenMP is enabled
@ -116,10 +114,6 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
.describe("Choice of tree construction method.");
DMLC_DECLARE_FIELD(test_flag).set_default("").describe(
"Internal test flag");
DMLC_DECLARE_FIELD(prob_buffer_row)
.set_default(1.0f)
.set_range(0.0f, 1.0f)
.describe("Maximum buffered row portion");
DMLC_DECLARE_FIELD(max_row_perbatch)
.set_default(std::numeric_limits<size_t>::max())
.describe("maximum row per batch.");
@ -163,9 +157,6 @@ class LearnerImpl : public Learner {
} else if (tparam_.dsplit == 2) {
cfg_["updater"] = "grow_histmaker,prune";
}
if (tparam_.prob_buffer_row != 1.0f) {
cfg_["updater"] = "grow_histmaker,refresh,prune";
}
}
} else if (tparam_.tree_method == 3) {
/* histogram-based algorithm */
@ -496,7 +487,7 @@ class LearnerImpl : public Learner {
max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
}
// initialize column access
p_train->InitColAccess(enabled, tparam_.prob_buffer_row, max_row_perbatch, true);
p_train->InitColAccess(max_row_perbatch, true);
}
if (!p_train->SingleColBlock() && cfg_.count("updater") == 0) {

View File

@ -65,10 +65,10 @@ inline std::pair<double, double> GetGradient(int group_idx, int num_group, int f
const std::vector<GradientPair> &gpair,
DMatrix *p_fmat) {
double sum_grad = 0.0, sum_hess = 0.0;
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
ColBatch::Inst col = batch[0];
auto batch = iter->Value();
auto col = batch[fidx];
const auto ndata = static_cast<bst_omp_uint>(col.length);
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
@ -96,10 +96,10 @@ inline std::pair<double, double> GetGradientParallel(int group_idx, int num_grou
const std::vector<GradientPair> &gpair,
DMatrix *p_fmat) {
double sum_grad = 0.0, sum_hess = 0.0;
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
ColBatch::Inst col = batch[0];
auto batch = iter->Value();
auto col = batch[fidx];
const auto ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
for (bst_omp_uint j = 0; j < ndata; ++j) {
@ -154,10 +154,10 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
float dw, std::vector<GradientPair> *in_gpair,
DMatrix *p_fmat) {
if (dw == 0.0f) return;
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator({static_cast<bst_uint>(fidx)});
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
ColBatch::Inst col = batch[0];
auto batch = iter->Value();
auto col = batch[fidx];
// update grad value
const auto num_row = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
@ -325,12 +325,12 @@ class GreedyFeatureSelector : public FeatureSelector {
const bst_omp_uint nfeat = model.param.num_feature;
// Calculate univariate gradient sums
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
auto batch = iter->Value();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const ColBatch::Inst col = batch[i];
const auto col = batch[i];
const bst_uint ndata = col.length;
auto &sums = gpair_sums_[group_idx * nfeat + i];
for (bst_uint j = 0u; j < ndata; ++j) {
@ -392,13 +392,13 @@ class ThriftyFeatureSelector : public FeatureSelector {
}
// Calculate univariate gradient sums
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
auto batch = iter->Value();
// column-parallel is usually faster than row-parallel
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const ColBatch::Inst col = batch[i];
const auto col = batch[i];
const bst_uint ndata = col.length;
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
auto &sums = gpair_sums_[gid * nfeat + i];
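
The column-side counterpart of the same refactor: ColIterator() no longer takes a feature subset, so callers fetch the whole page and index it by feature id. Condensed, with the statistics accumulation elided:

void ForEachValueOfFeature(xgboost::DMatrix* p_fmat, xgboost::bst_uint fidx) {
  auto iter = p_fmat->ColIterator();
  while (iter->Next()) {
    const auto& batch = iter->Value();
    auto col = batch[fidx];  // CSC column of feature fidx
    for (xgboost::bst_uint j = 0; j < col.length; ++j) {
      xgboost::bst_uint ridx = col[j].index;  // row this value belongs to
      xgboost::bst_float v = col[j].fvalue;
      // ... accumulate per-feature statistics over (ridx, v) ...
    }
  }
}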

View File

@ -81,7 +81,7 @@ struct GPUCoordinateTrainParam
float reg_alpha_denorm;
};
void RescaleIndices(size_t ridx_begin, dh::DVec<SparseBatch::Entry> *data) {
void RescaleIndices(size_t ridx_begin, dh::DVec<Entry> *data) {
auto d_data = data->Data();
dh::LaunchN(data->DeviceIdx(), data->Size(),
[=] __device__(size_t idx) { d_data[idx].index -= ridx_begin; });
@ -92,14 +92,14 @@ class DeviceShard {
int normalised_device_idx_; // Device index counting from param.gpu_id
dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
std::vector<size_t> row_ptr_;
dh::DVec<SparseBatch::Entry> data_;
dh::DVec<Entry> data_;
dh::DVec<GradientPair> gpair_;
dh::CubMemory temp_;
size_t ridx_begin_;
size_t ridx_end_;
public:
DeviceShard(int device_idx, int normalised_device_idx, const ColBatch &batch,
DeviceShard(int device_idx, int normalised_device_idx, const SparsePage &batch,
bst_uint row_begin, bst_uint row_end,
const GPUCoordinateTrainParam &param,
const gbm::GBLinearModelParam &model_param)
@ -112,17 +112,17 @@ class DeviceShard {
// this shard
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
row_ptr_ = {0};
for (auto fidx = 0; fidx < batch.size; fidx++) {
for (auto fidx = 0; fidx < batch.Size(); fidx++) {
auto col = batch[fidx];
auto cmp = [](SparseBatch::Entry e1, SparseBatch::Entry e2) {
auto cmp = [](Entry e1, Entry e2) {
return e1.index < e2.index;
};
auto column_begin =
std::lower_bound(col.data, col.data + col.length,
SparseBatch::Entry(row_begin, 0.0f), cmp);
Entry(row_begin, 0.0f), cmp);
auto column_end =
std::upper_bound(col.data, col.data + col.length,
SparseBatch::Entry(row_end, 0.0f), cmp);
Entry(row_end, 0.0f), cmp);
column_segments.push_back(
std::make_pair(column_begin - col.data, column_end - col.data));
row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
@ -130,8 +130,8 @@ class DeviceShard {
ba_.Allocate(device_idx, param.silent, &data_, row_ptr_.back(), &gpair_,
(row_end - row_begin) * model_param.num_output_group);
for (int fidx = 0; fidx < batch.size; fidx++) {
ColBatch::Inst col = batch[fidx];
for (int fidx = 0; fidx < batch.Size(); fidx++) {
auto col = batch[fidx];
thrust::copy(col.data + column_segments[fidx].first,
col.data + column_segments[fidx].second,
data_.tbegin() + row_ptr_[fidx]);
@ -233,7 +233,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
row_begin = row_end;
}
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
auto iter = p_fmat->ColIterator();
CHECK(p_fmat->SingleColBlock());
iter->Next();
auto batch = iter->Value();
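
The DeviceShard constructor above exploits the fact that CSC columns are sorted by row index: std::lower_bound/std::upper_bound slice each column down to the shard's row range without a scan. Restated on its own (a sketch assuming the xgboost namespace; needs <algorithm> and <utility>):

std::pair<const Entry*, const Entry*> SliceColumn(const SparsePage::Inst& col,
                                                  bst_uint row_begin,
                                                  bst_uint row_end) {
  auto cmp = [](Entry e1, Entry e2) { return e1.index < e2.index; };
  const Entry* begin = std::lower_bound(col.data, col.data + col.length,
                                        Entry(row_begin, 0.0f), cmp);
  const Entry* end = std::upper_bound(col.data, col.data + col.length,
                                      Entry(row_end, 0.0f), cmp);
  return {begin, end};  // entries whose row index lies in [row_begin, row_end]
}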

View File

@ -79,17 +79,17 @@ class ShotgunUpdater : public LinearUpdater {
// lock-free parallel updates of weights
selector_->Setup(*model, in_gpair->HostVector(), p_fmat,
param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0);
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
const auto nfeat = static_cast<bst_omp_uint>(batch.size);
auto batch = iter->Value();
const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
int ii = selector_->NextFeature(i, *model, 0, in_gpair->HostVector(), p_fmat,
param_.reg_alpha_denorm, param_.reg_lambda_denorm);
if (ii < 0) continue;
const bst_uint fid = batch.col_index[ii];
ColBatch::Inst col = batch[ii];
const bst_uint fid = ii;
auto col = batch[ii];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {

View File

@ -14,7 +14,7 @@ DMLC_REGISTRY_FILE_TAG(cpu_predictor);
class CPUPredictor : public Predictor {
protected:
static bst_float PredValue(const RowBatch::Inst& inst,
static bst_float PredValue(const SparsePage::Inst& inst,
const std::vector<std::unique_ptr<RegTree>>& trees,
const std::vector<int>& tree_info, int bst_group,
unsigned root_index, RegTree::FVec* p_feats,
@ -53,20 +53,20 @@ class CPUPredictor : public Predictor {
<< "size_leaf_vector is enforced to 0 so far";
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
const auto& batch = iter->Value();
// parallel over local batch
constexpr int kUnroll = 8;
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
const bst_omp_uint rest = nsize % kUnroll;
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
const int tid = omp_get_thread_num();
RegTree::FVec& feats = thread_temp[tid];
int64_t ridx[kUnroll];
RowBatch::Inst inst[kUnroll];
SparsePage::Inst inst[kUnroll];
for (int k = 0; k < kUnroll; ++k) {
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
}
@ -85,7 +85,7 @@ class CPUPredictor : public Predictor {
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
RegTree::FVec& feats = thread_temp[0];
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
const RowBatch::Inst inst = batch[i];
auto inst = batch[i];
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx * num_group + gid;
preds[offset] +=
@ -183,7 +183,7 @@ class CPUPredictor : public Predictor {
}
}
void PredictInstance(const SparseBatch::Inst& inst,
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit,
unsigned root_index) override {
@ -218,12 +218,12 @@ class CPUPredictor : public Predictor {
std::vector<bst_float>& preds = *out_preds;
preds.resize(info.num_row_ * ntree_limit);
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const int tid = omp_get_thread_num();
@ -266,13 +266,13 @@ class CPUPredictor : public Predictor {
model.trees[i]->FillNodeMeanValues();
}
// start collecting the contributions
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
const std::vector<bst_float>& base_margin = info.base_margin_;
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
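
The kUnroll = 8 blocks above are a small software-pipelining trick: gather a block of rows first, then score them, with a scalar tail loop for the remainder. The skeleton of that scheme, with the prediction work elided (type names as in the surrounding file):

void PredictBatchUnrolled(const SparsePage& batch) {
  constexpr int kUnroll = 8;
  const auto nsize = static_cast<bst_omp_uint>(batch.Size());
  const bst_omp_uint rest = nsize % kUnroll;
#pragma omp parallel for schedule(static)
  for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
    SparsePage::Inst inst[kUnroll];
    for (int k = 0; k < kUnroll; ++k) {
      inst[k] = batch[i + k];  // gather the block up front
    }
    for (int k = 0; k < kUnroll; ++k) {
      // ... fill a thread-local FVec from inst[k] and run the trees ...
    }
  }
  for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
    // ... scalar tail: score batch[i] ...
  }
}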

View File

@ -52,7 +52,7 @@ struct DeviceMatrix {
DMatrix* p_mat; // Pointer to the original matrix on the host
dh::BulkAllocator<dh::MemoryType::kDevice> ba;
dh::DVec<size_t> row_ptr;
dh::DVec<SparseBatch::Entry> data;
dh::DVec<Entry> data;
thrust::device_vector<float> predictions;
DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) {
@ -66,17 +66,17 @@ struct DeviceMatrix {
while (iter->Next()) {
auto batch = iter->Value();
// Copy row ptr
thrust::copy(batch.ind_ptr, batch.ind_ptr + batch.size + 1,
thrust::copy(batch.offset.data(), batch.offset.data() + batch.Size() + 1,
row_ptr.tbegin() + batch.base_rowid);
if (batch.base_rowid > 0) {
auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
auto end_itr = begin_itr + batch.size + 1;
auto end_itr = begin_itr + batch.Size() + 1;
IncrementOffset(begin_itr, end_itr, batch.base_rowid);
}
// Copy data
thrust::copy(batch.data_ptr, batch.data_ptr + batch.ind_ptr[batch.size],
thrust::copy(batch.data.begin(), batch.data.end(),
data.tbegin() + data_offset);
data_offset += batch.ind_ptr[batch.size];
data_offset += batch.data.size();
}
}
};
@ -139,12 +139,12 @@ struct DevicePredictionNode {
struct ElementLoader {
bool use_shared;
size_t* d_row_ptr;
SparseBatch::Entry* d_data;
Entry* d_data;
int num_features;
float* smem;
__device__ ElementLoader(bool use_shared, size_t* row_ptr,
SparseBatch::Entry* entry, int num_features,
Entry* entry, int num_features,
float* smem, int num_rows)
: use_shared(use_shared),
d_row_ptr(row_ptr),
@ -161,7 +161,7 @@ struct ElementLoader {
bst_uint elem_begin = d_row_ptr[global_idx];
bst_uint elem_end = d_row_ptr[global_idx + 1];
for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
SparseBatch::Entry elem = d_data[elem_idx];
Entry elem = d_data[elem_idx];
smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
}
}
@ -175,7 +175,7 @@ struct ElementLoader {
// Binary search
auto begin_ptr = d_data + d_row_ptr[ridx];
auto end_ptr = d_data + d_row_ptr[ridx + 1];
SparseBatch::Entry* previous_middle = nullptr;
Entry* previous_middle = nullptr;
while (end_ptr != begin_ptr) {
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
if (middle == previous_middle) {
@ -221,7 +221,7 @@ template <int BLOCK_THREADS>
__global__ void PredictKernel(const DevicePredictionNode* d_nodes,
float* d_out_predictions, size_t* d_tree_segments,
int* d_tree_group, size_t* d_row_ptr,
SparseBatch::Entry* d_data, size_t tree_begin,
Entry* d_data, size_t tree_begin,
size_t tree_end, size_t num_features,
size_t num_rows, bool use_shared, int num_group) {
extern __shared__ float smem[];
@ -422,7 +422,7 @@ class GPUPredictor : public xgboost::Predictor {
}
}
void PredictInstance(const SparseBatch::Inst& inst,
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit,
unsigned root_index) override {
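
The ElementLoader lookup above deserves a plain statement: within one row's entry segment (sorted by feature index) it binary-searches for a feature, using the previous_middle guard to stop once the range no longer shrinks. A host-side sketch of that logic; GetFvalue and the NaN-for-missing convention are inferred from context, not quoted from the kernel:

#include <limits>

float GetFvalue(const Entry* begin, const Entry* end, int findex) {
  const Entry* previous_middle = nullptr;
  while (end != begin) {
    const Entry* middle = begin + (end - begin) / 2;
    if (middle == previous_middle) {
      break;  // range stopped shrinking: feature absent from this row
    }
    previous_middle = middle;
    if (middle->index == findex) {
      return middle->fvalue;
    } else if (middle->index < findex) {
      begin = middle;
    } else {
      end = middle;
    }
  }
  return std::numeric_limits<float>::quiet_NaN();  // treated as missing
}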

View File

@ -43,13 +43,12 @@ class BaseMaker: public TreeUpdater {
std::fill(fminmax_.begin(), fminmax_.end(),
-std::numeric_limits<bst_float>::max());
// start accumulating statistics
dmlc::DataIter<ColBatch>* iter = p_fmat->ColIterator();
auto iter = p_fmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch& batch = iter->Value();
for (bst_uint i = 0; i < batch.size; ++i) {
const bst_uint fid = batch.col_index[i];
const ColBatch::Inst& c = batch[i];
auto batch = iter->Value();
for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
auto c = batch[fid];
if (c.length != 0) {
fminmax_[fid * 2 + 0] = std::max(-c[0].fvalue, fminmax_[fid * 2 + 0]);
fminmax_[fid * 2 + 1] = std::max(c[c.length - 1].fvalue, fminmax_[fid * 2 + 1]);
@ -104,7 +103,7 @@ class BaseMaker: public TreeUpdater {
// ------static helper functions ------
// helper function to get to next level of the tree
/*! \brief this is helper function for row based data*/
inline static int NextLevel(const RowBatch::Inst &inst, const RegTree &tree, int nid) {
inline static int NextLevel(const SparsePage::Inst &inst, const RegTree &tree, int nid) {
const RegTree::Node &n = tree[nid];
bst_uint findex = n.SplitIndex();
for (unsigned i = 0; i < inst.length; ++i) {
@ -244,12 +243,10 @@ class BaseMaker: public TreeUpdater {
* \param tree the regression tree structure
*/
inline void CorrectNonDefaultPositionByBatch(
const ColBatch& batch,
const std::vector<bst_uint> &sorted_split_set,
const SparsePage &batch, const std::vector<bst_uint> &sorted_split_set,
const RegTree &tree) {
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
for (size_t fid = 0; fid < batch.Size(); ++fid) {
auto col = batch[fid];
auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
if (it != sorted_split_set.end() && *it == fid) {
@ -306,12 +303,11 @@ class BaseMaker: public TreeUpdater {
const RegTree &tree) {
std::vector<unsigned> fsplits;
this->GetSplitSet(nodes, tree, &fsplits);
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
auto batch = iter->Value();
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {

View File

@ -252,7 +252,7 @@ class ColMaker: public TreeUpdater {
}
// parallel find the best split of current fid
// this function does not support nested functions
inline void ParallelFindSplit(const ColBatch::Inst &col,
inline void ParallelFindSplit(const SparsePage::Inst &col,
bst_uint fid,
const DMatrix &fmat,
const std::vector<GradientPair> &gpair) {
@ -439,8 +439,8 @@ class ColMaker: public TreeUpdater {
}
}
// same as EnumerateSplit, with cacheline prefetch optimization
inline void EnumerateSplitCacheOpt(const ColBatch::Entry *begin,
const ColBatch::Entry *end,
inline void EnumerateSplitCacheOpt(const Entry *begin,
const Entry *end,
int d_step,
bst_uint fid,
const std::vector<GradientPair> &gpair,
@ -457,18 +457,18 @@ class ColMaker: public TreeUpdater {
int buf_position[kBuffer] = {};
GradientPair buf_gpair[kBuffer] = {};
// aligned ending position
const ColBatch::Entry *align_end;
const Entry *align_end;
if (d_step > 0) {
align_end = begin + (end - begin) / kBuffer * kBuffer;
} else {
align_end = begin - (begin - end) / kBuffer * kBuffer;
}
int i;
const ColBatch::Entry *it;
const Entry *it;
const int align_step = d_step * kBuffer;
// internal cached loop
for (it = begin; it != align_end; it += align_step) {
const ColBatch::Entry *p;
const Entry *p;
for (i = 0, p = it; i < kBuffer; ++i, p += d_step) {
buf_position[i] = position_[p->index];
buf_gpair[i] = gpair[p->index];
@ -519,8 +519,8 @@ class ColMaker: public TreeUpdater {
}
// enumerate the split values of specific feature
inline void EnumerateSplit(const ColBatch::Entry *begin,
const ColBatch::Entry *end,
inline void EnumerateSplit(const Entry *begin,
const Entry *end,
int d_step,
bst_uint fid,
const std::vector<GradientPair> &gpair,
@ -538,7 +538,7 @@ class ColMaker: public TreeUpdater {
}
// left statistics
TStats c(param_);
for (const ColBatch::Entry *it = begin; it != end; it += d_step) {
for (const Entry *it = begin; it != end; it += d_step) {
const bst_uint ridx = it->index;
const int nid = position_[ridx];
if (nid < 0) continue;
@ -602,25 +602,26 @@ class ColMaker: public TreeUpdater {
}
// update the solution candidate
virtual void UpdateSolution(const ColBatch& batch,
virtual void UpdateSolution(const SparsePage &batch,
const std::vector<bst_uint> &feat_set,
const std::vector<GradientPair> &gpair,
const DMatrix &fmat) {
const MetaInfo& info = fmat.Info();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto num_features = static_cast<bst_omp_uint>(feat_set.size());
#if defined(_OPENMP)
const int batch_size = std::max(static_cast<int>(nsize / this->nthread_ / 32), 1);
const int batch_size = std::max(static_cast<int>(num_features / this->nthread_ / 32), 1);
#endif
int poption = param_.parallel_option;
if (poption == 2) {
poption = static_cast<int>(nsize) * 2 < this->nthread_ ? 1 : 0;
poption = static_cast<int>(num_features) * 2 < this->nthread_ ? 1 : 0;
}
if (poption == 0) {
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const bst_uint fid = batch.col_index[i];
for (bst_omp_uint i = 0; i < num_features; ++i) {
int fid = feat_set[i];
const int tid = omp_get_thread_num();
const ColBatch::Inst c = batch[i];
auto c = batch[fid];
const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue;
if (param_.NeedForwardSearch(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data, c.data + c.length, +1,
@ -632,8 +633,8 @@ class ColMaker: public TreeUpdater {
}
}
} else {
for (bst_omp_uint i = 0; i < nsize; ++i) {
this->ParallelFindSplit(batch[i], batch.col_index[i],
for (bst_omp_uint fid = 0; fid < num_features; ++fid) {
this->ParallelFindSplit(batch[fid], fid,
fmat, gpair);
}
}
@ -653,9 +654,9 @@ class ColMaker: public TreeUpdater {
<< "colsample_bylevel cannot be zero.";
feat_set.resize(n);
}
dmlc::DataIter<ColBatch>* iter = p_fmat->ColIterator(feat_set);
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
this->UpdateSolution(iter->Value(), gpair, *p_fmat);
this->UpdateSolution(iter->Value(), feat_set, gpair, *p_fmat);
}
// after this each thread's stemp will get the best candidates, aggregate results
this->SyncBestSolution(qexpand);
@ -730,12 +731,11 @@ class ColMaker: public TreeUpdater {
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
auto batch = iter->Value();
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
@ -859,12 +859,11 @@ class DistColMaker : public ColMaker<TStats, TConstraint> {
boolmap_[j] = 0;
}
}
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
auto iter = p_fmat->ColIterator();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
auto batch = iter->Value();
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
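
EnumerateSplitCacheOpt above hides memory latency by staging the indirect position/gradient lookups for kBuffer entries before consuming them. The shape of that double loop, with the split-statistics update elided; position and gpair stand for the node-position and gradient arrays, and kBuffer's actual value is not shown in the hunk:

void BufferedEnumerate(const Entry* begin, const Entry* align_end, int d_step,
                       const int* position, const GradientPair* gpair) {
  constexpr int kBuffer = 32;  // assumed; the real constant is defined elsewhere
  int buf_position[kBuffer];
  GradientPair buf_gpair[kBuffer];
  const int align_step = d_step * kBuffer;
  for (const Entry* it = begin; it != align_end; it += align_step) {
    const Entry* p = it;
    for (int i = 0; i < kBuffer; ++i, p += d_step) {
      buf_position[i] = position[p->index];  // stage the indirect loads
      buf_gpair[i] = gpair[p->index];
    }
    p = it;
    for (int i = 0; i < kBuffer; ++i, p += d_step) {
      // ... consume (buf_position[i], buf_gpair[i], p->fvalue) to update
      //     running statistics and evaluate candidate splits ...
    }
  }
  // entries in [align_end, end) are handled by a plain scalar loop
}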

View File

@ -661,16 +661,15 @@ class GPUMaker : public TreeUpdater {
// in case you end up with a DMatrix having no column access
// then make sure to enable that before copying the data!
if (!dmat->HaveColAccess(true)) {
const std::vector<bool> enable(nCols, true);
dmat->InitColAccess(enable, 1, nRows, true);
dmat->InitColAccess(nRows, true);
}
dmlc::DataIter<ColBatch>* iter = dmat->ColIterator();
auto iter = dmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch& batch = iter->Value();
for (int i = 0; i < batch.size; i++) {
const ColBatch::Inst& col = batch[i];
for (const ColBatch::Entry* it = col.data; it != col.data + col.length;
auto batch = iter->Value();
for (int i = 0; i < batch.Size(); i++) {
auto col = batch[i];
for (const Entry* it = col.data; it != col.data + col.length;
it++) {
int inst_id = static_cast<int>(it->index);
fval->push_back(it->fvalue);

View File

@ -250,7 +250,7 @@ __device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
__global__ void compress_bin_ellpack_k
(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
const size_t* __restrict__ row_ptrs,
const RowBatch::Entry* __restrict__ entries,
const Entry* __restrict__ entries,
const float* __restrict__ cuts, const size_t* __restrict__ cut_rows,
size_t base_row, size_t n_rows, size_t row_ptr_begin, size_t row_stride,
unsigned int null_gidx_value) {
@ -261,7 +261,7 @@ __global__ void compress_bin_ellpack_k
int row_size = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
unsigned int bin = null_gidx_value;
if (ifeature < row_size) {
RowBatch::Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
int feature = entry.index;
float fvalue = entry.fvalue;
const float *feature_cuts = &cuts[cut_rows[feature]];
@ -332,7 +332,7 @@ struct DeviceShard {
param(param),
prediction_cache_initialised(false) {}
void Init(const common::HistCutMatrix& hmat, const RowBatch& row_batch) {
void Init(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
// copy cuts to the GPU
dh::safe_cuda(cudaSetDevice(device_idx));
thrust::device_vector<float> cuts_d(hmat.cut);
@ -340,7 +340,7 @@ struct DeviceShard {
// find the maximum row size
thrust::device_vector<size_t> row_ptr_d(
row_batch.ind_ptr + row_begin_idx, row_batch.ind_ptr + row_end_idx + 1);
&row_batch.offset[row_begin_idx], &row_batch.offset[row_end_idx + 1]);
auto row_iter = row_ptr_d.begin();
auto get_size = [=] __device__(size_t row) {
@ -369,11 +369,11 @@ struct DeviceShard {
// bin and compress entries in batches of rows
// use no more than 1/16th of GPU memory per batch
size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
(16 * row_stride * sizeof(RowBatch::Entry));
(16 * row_stride * sizeof(Entry));
if (gpu_batch_nrows > n_rows) {
gpu_batch_nrows = n_rows;
}
thrust::device_vector<RowBatch::Entry> entries_d(gpu_batch_nrows * row_stride);
thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
@ -383,13 +383,13 @@ struct DeviceShard {
}
size_t batch_nrows = batch_row_end - batch_row_begin;
size_t n_entries =
row_batch.ind_ptr[row_begin_idx + batch_row_end] -
row_batch.ind_ptr[row_begin_idx + batch_row_begin];
row_batch.offset[row_begin_idx + batch_row_end] -
row_batch.offset[row_begin_idx + batch_row_begin];
dh::safe_cuda
(cudaMemcpy
(entries_d.data().get(),
&row_batch.data_ptr[row_batch.ind_ptr[row_begin_idx + batch_row_begin]],
n_entries * sizeof(RowBatch::Entry), cudaMemcpyDefault));
&row_batch.data[row_batch.offset[row_begin_idx + batch_row_begin]],
n_entries * sizeof(Entry), cudaMemcpyDefault));
dim3 block3(32, 8, 1);
dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
dh::DivRoundUp(row_stride, block3.y), 1);
@ -398,7 +398,7 @@ struct DeviceShard {
row_ptr_d.data().get() + batch_row_begin,
entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
batch_row_begin, batch_nrows,
row_batch.ind_ptr[row_begin_idx + batch_row_begin],
row_batch.offset[row_begin_idx + batch_row_begin],
row_stride, null_gidx_value);
dh::safe_cuda(cudaGetLastError());
@ -702,10 +702,10 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Start("BinningCompression", device_list_);
{
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next()) << "Empty batches are not supported";
const RowBatch& batch = iter->Value();
const SparsePage& batch = iter->Value();
// Create device shards
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>
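
For reference, the layout compress_bin_ellpack_k produces: every row gets exactly row_stride slots; slot j holds the global bin index of the row's j-th entry, found by an upper_bound over that feature's cut points, with null_gidx_value as padding. A sequential CPU restatement of the kernel body; the real kernel additionally writes through a CompressedBufferWriter, which is omitted here:

// first cut strictly greater than v (helper mirroring the device upper_bound)
inline unsigned UpperBound(const float* cuts, int n, float v) {
  return static_cast<unsigned>(std::upper_bound(cuts, cuts + n, v) - cuts);
}

void BinEllpackCPU(const size_t* row_ptrs, const Entry* entries,
                   const float* cuts, const size_t* cut_rows,
                   size_t n_rows, size_t row_stride,
                   unsigned null_gidx_value, unsigned* binned) {
  for (size_t irow = 0; irow < n_rows; ++irow) {
    size_t row_size = row_ptrs[irow + 1] - row_ptrs[irow];
    for (size_t j = 0; j < row_stride; ++j) {
      unsigned bin = null_gidx_value;  // padding for rows shorter than row_stride
      if (j < row_size) {
        Entry e = entries[row_ptrs[irow] + j];
        const float* feature_cuts = &cuts[cut_rows[e.index]];
        int n_cuts = static_cast<int>(cut_rows[e.index + 1] - cut_rows[e.index]);
        bin = UpperBound(feature_cuts, n_cuts, e.fvalue)  // bin within the feature
              + static_cast<unsigned>(cut_rows[e.index]); // offset to global bins
      }
      binned[irow * row_stride + j] = bin;  // dense, padded ELLPACK layout
    }
  }
}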

View File

@ -344,17 +344,18 @@ class CQHistMaker: public HistMaker<TStats> {
{
thread_hist_.resize(omp_get_max_threads());
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fset);
auto iter = p_fmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
auto batch = iter->Value();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(fset.size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
int offset = feat2workindex_[batch.col_index[i]];
int fid = fset[i];
int offset = feat2workindex_[fid];
if (offset >= 0) {
this->UpdateHistCol(gpair, batch[i], info, tree,
this->UpdateHistCol(gpair, batch[fid], info, tree,
fset, offset,
&thread_hist_[omp_get_thread_num()]);
}
@ -425,20 +426,20 @@ class CQHistMaker: public HistMaker<TStats> {
work_set_.resize(std::unique(work_set_.begin(), work_set_.end()) - work_set_.begin());
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(work_set_);
auto iter = p_fmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
auto batch = iter->Value();
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
int offset = feat2workindex_[batch.col_index[i]];
for (bst_omp_uint fid = 0; fid < nsize; ++fid) {
int offset = feat2workindex_[fid];
if (offset >= 0) {
this->UpdateSketchCol(gpair, batch[i], tree,
this->UpdateSketchCol(gpair, batch[fid], tree,
work_set_size, offset,
&thread_sketch_[omp_get_thread_num()]);
}
@ -494,7 +495,7 @@ class CQHistMaker: public HistMaker<TStats> {
}
inline void UpdateHistCol(const std::vector<GradientPair> &gpair,
const ColBatch::Inst &c,
const SparsePage::Inst &c,
const MetaInfo &info,
const RegTree &tree,
const std::vector<bst_uint> &fset,
@ -546,7 +547,7 @@ class CQHistMaker: public HistMaker<TStats> {
}
}
inline void UpdateSketchCol(const std::vector<GradientPair> &gpair,
const ColBatch::Inst &c,
const SparsePage::Inst &c,
const RegTree &tree,
size_t work_set_size,
bst_uint offset,
@ -712,18 +713,18 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(this->work_set_);
auto iter = p_fmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
auto batch = iter->Value();
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
int offset = this->feat2workindex_[batch.col_index[i]];
int offset = this->feat2workindex_[this->work_set_[i]];
if (offset >= 0) {
this->UpdateHistCol(gpair, batch[i], info, tree,
fset, offset,
@ -769,19 +770,19 @@ class QuantileHistMaker: public HistMaker<TStats> {
sketchs_[i].Init(info.num_row_, this->param_.sketch_eps);
}
// start accumulating statistics
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
auto iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
auto batch = iter->Value();
// parallel convert to column major format
common::ParallelGroupBuilder<SparseBatch::Entry>
common::ParallelGroupBuilder<Entry>
builder(&col_ptr_, &col_data_, &thread_col_ptr_);
builder.InitBudget(tree.param.num_feature, nthread);
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
RowBatch::Inst inst = batch[i];
SparsePage::Inst inst = batch[i];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
int nid = this->position_[ridx];
if (nid >= 0) {
@ -800,13 +801,13 @@ class QuantileHistMaker: public HistMaker<TStats> {
builder.InitStorage();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
RowBatch::Inst inst = batch[i];
SparsePage::Inst inst = batch[i];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
const int nid = this->position_[ridx];
if (nid >= 0) {
for (bst_uint j = 0; j < inst.length; ++j) {
builder.Push(inst[j].index,
SparseBatch::Entry(nid, inst[j].fvalue),
Entry(nid, inst[j].fvalue),
omp_get_thread_num());
}
}
@ -816,7 +817,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint k = 0; k < nfeat; ++k) {
for (size_t i = col_ptr_[k]; i < col_ptr_[k+1]; ++i) {
const SparseBatch::Entry &e = col_data_[i];
const Entry &e = col_data_[i];
const int wid = this->node2workindex_[e.index];
sketchs_[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
}
@ -873,7 +874,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
// local temp column data structure
std::vector<size_t> col_ptr_;
// local storage of column data
std::vector<SparseBatch::Entry> col_data_;
std::vector<Entry> col_data_;
std::vector<std::vector<size_t> > thread_col_ptr_;
// per node, per feature sketch
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs_;

View File

@ -57,15 +57,15 @@ class TreeRefresher: public TreeUpdater {
{
const MetaInfo &info = p_fmat->Info();
// start accumulating statistics
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
auto *iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
CHECK_LT(batch.size, std::numeric_limits<unsigned>::max());
const auto nbatch = static_cast<bst_omp_uint>(batch.size);
auto batch = iter->Value();
CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
RowBatch::Inst inst = batch[i];
SparsePage::Inst inst = batch[i];
const int tid = omp_get_thread_num();
const auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
RegTree::FVec &feats = fvec_temp[tid];

View File

@ -144,18 +144,18 @@ class SketchMaker: public BaseMaker {
// number of rows in the buffered rowset
const size_t nrows = p_fmat->BufferedRowset().Size();
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
auto iter = p_fmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
auto batch = iter->Value();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(batch.size);
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
this->UpdateSketchCol(gpair, batch[i], tree,
for (bst_omp_uint fidx = 0; fidx < nsize; ++fidx) {
this->UpdateSketchCol(gpair, batch[fidx], tree,
node_stats_,
batch.col_index[i],
batch[i].length == nrows,
fidx,
batch[fidx].length == nrows,
&thread_sketch_[omp_get_thread_num()]);
}
}
@ -174,7 +174,7 @@ class SketchMaker: public BaseMaker {
}
// update sketch information in column fid
inline void UpdateSketchCol(const std::vector<GradientPair> &gpair,
const ColBatch::Inst &c,
const SparsePage::Inst &c,
const RegTree &tree,
const std::vector<SKStats> &nstats,
bst_uint fid,

View File

@ -29,7 +29,7 @@ TEST(c_api, XGDMatrixCreateFromMat_omp) {
iter->BeforeFirst();
while (iter->Next()) {
auto batch = iter->Value();
for (int i = 0; i < batch.size; i++) {
for (int i = 0; i < batch.Size(); i++) {
auto inst = batch[i];
for (int j = 0; j < inst.length; j++) {
ASSERT_EQ(inst[j].fvalue, 1.5);

View File

@ -18,13 +18,13 @@ TEST(SimpleCSRSource, SaveLoadBinary) {
EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);
dmlc::DataIter<xgboost::RowBatch> * row_iter = dmat->RowIterator();
dmlc::DataIter<xgboost::RowBatch> * row_iter_read = dmat_read->RowIterator();
auto row_iter = dmat->RowIterator();
auto row_iter_read = dmat_read->RowIterator();
// Test the data read into the first row
row_iter->BeforeFirst(); row_iter->Next();
row_iter_read->BeforeFirst(); row_iter_read->Next();
xgboost::SparseBatch::Inst first_row = row_iter->Value()[0];
xgboost::SparseBatch::Inst first_row_read = row_iter_read->Value()[0];
auto first_row = row_iter->Value()[0];
auto first_row_read = row_iter_read->Value()[0];
EXPECT_EQ(first_row.length, first_row_read.length);
EXPECT_EQ(first_row[2].index, first_row_read[2].index);
EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue);

View File

@ -18,19 +18,19 @@ TEST(SimpleDMatrix, MetaInfo) {
TEST(SimpleDMatrix, RowAccess) {
std::string tmp_file = CreateSimpleTestData();
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, false, false);
std::remove(tmp_file.c_str());
dmlc::DataIter<xgboost::RowBatch> * row_iter = dmat->RowIterator();
auto row_iter = dmat->RowIterator();
// Loop over the batches and count the records
long row_count = 0;
row_iter->BeforeFirst();
while (row_iter->Next()) row_count += row_iter->Value().size;
while (row_iter->Next()) row_count += row_iter->Value().Size();
EXPECT_EQ(row_count, dmat->Info().num_row_);
// Test the data read into the first row
row_iter->BeforeFirst();
row_iter->Next();
xgboost::SparseBatch::Inst first_row = row_iter->Value()[0];
auto first_row = row_iter->Value()[0];
ASSERT_EQ(first_row.length, 3);
EXPECT_EQ(first_row[2].index, 2);
EXPECT_EQ(first_row[2].fvalue, 20);
@ -45,14 +45,14 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
// Unsorted column access
const std::vector<bool> enable(dmat->Info().num_col_, true);
EXPECT_EQ(dmat->HaveColAccess(false), false);
dmat->InitColAccess(enable, 1, dmat->Info().num_row_, false);
dmat->InitColAccess(enable, 0, 0, false); // Calling it again should not change it
dmat->InitColAccess(dmat->Info().num_row_, false);
dmat->InitColAccess(0, false); // Calling it again should not change it
ASSERT_EQ(dmat->HaveColAccess(false), true);
// Sorted column access
EXPECT_EQ(dmat->HaveColAccess(true), false);
dmat->InitColAccess(enable, 1, dmat->Info().num_row_, true);
dmat->InitColAccess(enable, 0, 0, true); // Calling it again should not change it
dmat->InitColAccess(dmat->Info().num_row_, true);
dmat->InitColAccess(0, true); // Calling it again should not change it
ASSERT_EQ(dmat->HaveColAccess(true), true);
EXPECT_EQ(dmat->GetColSize(0), 2);
@ -61,84 +61,19 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
ASSERT_TRUE(dmat->SingleColBlock());
dmlc::DataIter<xgboost::ColBatch> * col_iter = dmat->ColIterator();
auto* col_iter = dmat->ColIterator();
// Loop over the batches and assert the data is as expected
long num_col_batch = 0;
col_iter->BeforeFirst();
while (col_iter->Next()) {
num_col_batch += 1;
EXPECT_EQ(col_iter->Value().size, dmat->Info().num_col_)
EXPECT_EQ(col_iter->Value().Size(), dmat->Info().num_col_)
<< "Expected batch size = number of cells as #batches is 1.";
for (int i = 0; i < static_cast<int>(col_iter->Value().size); ++i) {
for (int i = 0; i < static_cast<int>(col_iter->Value().Size()); ++i) {
EXPECT_EQ(col_iter->Value()[i].length, dmat->GetColSize(i))
<< "Expected length of each colbatch = colsize as #batches is 1.";
}
}
EXPECT_EQ(num_col_batch, 1) << "Expected number of batches to be 1";
col_iter = nullptr;
std::vector<xgboost::bst_uint> sub_feats = {4, 3};
dmlc::DataIter<xgboost::ColBatch> * sub_col_iter = dmat->ColIterator(sub_feats);
// Loop over the batches and assert the data is as expected
sub_col_iter->BeforeFirst();
while (sub_col_iter->Next()) {
EXPECT_EQ(sub_col_iter->Value().size, sub_feats.size())
<< "Expected size of a batch = number of cells in subset as #batches is 1.";
}
sub_col_iter = nullptr;
}
TEST(SimpleDMatrix, ColAccessWithBatches) {
std::string tmp_file = CreateSimpleTestData();
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
std::remove(tmp_file.c_str());
// Unsorted column access
const std::vector<bool> enable(dmat->Info().num_col_, true);
EXPECT_EQ(dmat->HaveColAccess(false), false);
dmat->InitColAccess(enable, 1, 1, false);
dmat->InitColAccess(enable, 0, 0, false); // Calling it again should not change it
ASSERT_EQ(dmat->HaveColAccess(false), true);
// Sorted column access
EXPECT_EQ(dmat->HaveColAccess(true), false);
dmat->InitColAccess(enable, 1, 1, true); // Max 1 row per batch
dmat->InitColAccess(enable, 0, 0, true); // Calling it again should not change it
ASSERT_EQ(dmat->HaveColAccess(true), true);
EXPECT_EQ(dmat->GetColSize(0), 2);
EXPECT_EQ(dmat->GetColSize(1), 1);
EXPECT_EQ(dmat->GetColDensity(0), 1);
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
ASSERT_FALSE(dmat->SingleColBlock());
dmlc::DataIter<xgboost::ColBatch> * col_iter = dmat->ColIterator();
// Loop over the batches and assert the data is as expected
long num_col_batch = 0;
col_iter->BeforeFirst();
while (col_iter->Next()) {
num_col_batch += 1;
EXPECT_EQ(col_iter->Value().size, dmat->Info().num_col_)
<< "Expected batch size = num_cols as max_row_perbatch is 1.";
for (int i = 0; i < static_cast<int>(col_iter->Value().size); ++i) {
EXPECT_LE(col_iter->Value()[i].length, 1)
<< "Expected length of each colbatch <=1 as max_row_perbatch is 1.";
}
}
EXPECT_EQ(num_col_batch, dmat->Info().num_row_)
<< "Expected num batches = num_rows as max_row_perbatch is 1";
col_iter = nullptr;
// The iterator feats should ignore any numbers larger than the num_col
std::vector<xgboost::bst_uint> sub_feats = {
4, 3, static_cast<unsigned int>(dmat->Info().num_col_ + 1)};
dmlc::DataIter<xgboost::ColBatch> * sub_col_iter = dmat->ColIterator(sub_feats);
// Loop over the batches and assert the data is as expected
sub_col_iter->BeforeFirst();
while (sub_col_iter->Next()) {
EXPECT_EQ(sub_col_iter->Value().size, sub_feats.size() - 1)
<< "Expected size of a batch = number of columns in subset "
<< "as max_row_perbatch is 1.";
}
sub_col_iter = nullptr;
}

View File

@ -7,8 +7,9 @@
TEST(SparsePageDMatrix, MetaInfo) {
std::string tmp_file = CreateSimpleTestData();
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false);
tmp_file + "#" + tmp_file + ".cache", false, false);
std::remove(tmp_file.c_str());
std::cout << tmp_file << std::endl;
EXPECT_TRUE(FileExists(tmp_file + ".cache"));
// Test the metadata that was parsed
@ -29,16 +30,16 @@ TEST(SparsePageDMatrix, RowAccess) {
std::remove(tmp_file.c_str());
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
dmlc::DataIter<xgboost::RowBatch> * row_iter = dmat->RowIterator();
auto row_iter = dmat->RowIterator();
// Loop over the batches and count the records
long row_count = 0;
row_iter->BeforeFirst();
while (row_iter->Next()) row_count += row_iter->Value().size;
while (row_iter->Next()) row_count += row_iter->Value().Size();
EXPECT_EQ(row_count, dmat->Info().num_row_);
// Test the data read into the first row
row_iter->BeforeFirst();
row_iter->Next();
xgboost::SparseBatch::Inst first_row = row_iter->Value()[0];
auto first_row = row_iter->Value()[0];
ASSERT_EQ(first_row.length, 3);
EXPECT_EQ(first_row[2].index, 2);
EXPECT_EQ(first_row[2].fvalue, 20);
@ -58,7 +59,7 @@ TEST(SparsePageDMatrix, ColAcess) {
EXPECT_EQ(dmat->HaveColAccess(true), false);
const std::vector<bool> enable(dmat->Info().num_col_, true);
dmat->InitColAccess(enable, 1, 1, true); // Max 1 row per batch
dmat->InitColAccess(1, true); // Max 1 row per batch
ASSERT_EQ(dmat->HaveColAccess(true), true);
EXPECT_TRUE(FileExists(tmp_file + ".cache.col.page"));
@ -67,31 +68,19 @@ TEST(SparsePageDMatrix, ColAcess) {
EXPECT_EQ(dmat->GetColDensity(0), 1);
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
dmlc::DataIter<xgboost::ColBatch> * col_iter = dmat->ColIterator();
auto col_iter = dmat->ColIterator();
// Loop over the batches and assert the data is as expected
long num_col_batch = 0;
col_iter->BeforeFirst();
while (col_iter->Next()) {
num_col_batch += 1;
EXPECT_EQ(col_iter->Value().size, dmat->Info().num_col_)
EXPECT_EQ(col_iter->Value().Size(), dmat->Info().num_col_)
<< "Expected batch size to be same as num_cols as max_row_perbatch is 1.";
}
EXPECT_EQ(num_col_batch, dmat->Info().num_row_)
<< "Expected num batches to be same as num_rows as max_row_perbatch is 1";
col_iter = nullptr;
std::vector<xgboost::bst_uint> sub_feats = {4, 3};
dmlc::DataIter<xgboost::ColBatch> * sub_col_iter = dmat->ColIterator(sub_feats);
// Loop over the batches and assert the data is as expected
sub_col_iter->BeforeFirst();
while (sub_col_iter->Next()) {
EXPECT_EQ(sub_col_iter->Value().size, sub_feats.size())
<< "Expected size of a batch to be same as number of columns "
<< "as max_row_perbatch was set to 1.";
}
sub_col_iter = nullptr;
// Clean up of external memory files
std::remove((tmp_file + ".cache").c_str());
std::remove((tmp_file + ".cache.col.page").c_str());
std::remove((tmp_file + ".cache.row.page").c_str());

View File

@ -3,7 +3,13 @@
#include <random>
std::string TempFileName() {
return std::tmpnam(nullptr);
std::string tmp = std::tmpnam(nullptr);
std::replace(tmp.begin(), tmp.end(), '\\',
'/'); // Replace Windows backslashes with forward slashes
// Remove the drive prefix on Windows
if (tmp.find("C:") != std::string::npos)
tmp.erase(tmp.find("C:"), 2);
return tmp;
}
bool FileExists(const std::string name) {
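As a self-contained sketch, the new helper needs <algorithm> for std::replace and <cstdio> for std::tmpnam; the hunk above only shows the existing <random> include, so the extra includes here are an assumption about the surrounding file:

#include <algorithm>  // std::replace (assumed; not visible in the hunk)
#include <cstdio>     // std::tmpnam
#include <string>

std::string TempFileName() {
  std::string tmp = std::tmpnam(nullptr);
  // Normalise Windows separators so the name can double as a cache URI.
  std::replace(tmp.begin(), tmp.end(), '\\', '/');
  // Strip a leading drive prefix such as "C:" on Windows.
  if (tmp.find("C:") != std::string::npos) tmp.erase(tmp.find("C:"), 2);
  return tmp;
}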

View File

@ -9,7 +9,7 @@ TEST(Linear, shotgun) {
typedef std::pair<std::string, std::string> arg;
auto mat = CreateDMatrix(10, 10, 0);
std::vector<bool> enabled(mat->Info().num_col_, true);
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
mat->InitColAccess(1 << 16, false);
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("shotgun"));
updater->Init({{"eta", "1."}});
@ -28,7 +28,7 @@ TEST(Linear, coordinate) {
typedef std::pair<std::string, std::string> arg;
auto mat = CreateDMatrix(10, 10, 0);
std::vector<bool> enabled(mat->Info().num_col_, true);
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
mat->InitColAccess(1 << 16, false);
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("coord_descent"));
updater->Init({{"eta", "1."}});

View File

@ -33,7 +33,7 @@ TEST(cpu_predictor, Test) {
// Test predict instance
auto batch = dmat->RowIterator()->Value();
for (int i = 0; i < batch.size; i++) {
for (int i = 0; i < batch.Size(); i++) {
std::vector<float> instance_out_predictions;
cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
ASSERT_EQ(instance_out_predictions[0], 1.5);
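A small sketch of the updated per-instance prediction loop; batch.Size() returns size_t, so a size_t index (unlike the int in the hunk above) avoids a signed/unsigned comparison. The dmat, cpu_predictor, and model objects are assumed to come from the test fixture:

auto batch = dmat->RowIterator()->Value();
for (size_t i = 0; i < batch.Size(); ++i) {
  std::vector<float> instance_out_predictions;
  cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
  // instance_out_predictions[0] is then checked against the expected output
}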

View File

@ -46,7 +46,7 @@ TEST(gpu_predictor, Test) {
}
// Test predict instance
auto batch = dmat->RowIterator()->Value();
for (int i = 0; i < batch.size; i++) {
for (int i = 0; i < batch.Size(); i++) {
std::vector<float> gpu_instance_out_predictions;
std::vector<float> cpu_instance_out_predictions;
cpu_predictor->PredictInstance(batch[i], &cpu_instance_out_predictions,

View File

@ -26,10 +26,10 @@ TEST(gpu_hist_experimental, TestSparseShard) {
TrainParam p;
p.max_depth = 6;
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const RowBatch& batch = iter->Value();
const SparsePage& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
shard.Init(hmat, batch);
CHECK(!iter->Next());
@ -67,10 +67,10 @@ TEST(gpu_hist_experimental, TestDenseShard) {
TrainParam p;
p.max_depth = 6;
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const RowBatch& batch = iter->Value();
const SparsePage& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
shard.Init(hmat, batch);
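For clarity, both shard tests rely on the DMatrix materialising exactly one SparsePage; a sketch of that contract using only calls visible above:

dmlc::DataIter<xgboost::SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());                       // first (and only) page
const xgboost::SparsePage& batch = iter->Value();
// shard.Init(hmat, batch) consumes the whole page at once
CHECK(!iter->Next());                      // in-memory DMatrix: no paging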

View File

@ -16,15 +16,16 @@ if [ ${TASK} == "lint" ]; then
cp "$file" "${file/.cu/_tmp.cc}"
done
echo "Running clang tidy..."
header_filter='(xgboost\/src|xgboost\/include)'
for filename in $(find src -name '*.cc'); do
clang-tidy $filename -header-filter=$header_filter -- -Iinclude -Idmlc-core/include -Irabit/include -std=c++11 >> logtidy.txt
done
echo "---------clang-tidy log----------"
cat logtidy.txt
echo "----------------------------"
echo "---------clang-tidy failures----------"
# Fail only on warnings related to XGBoost source files
(cat logtidy.txt|grep -E 'dmlc/xgboost.*warning'|grep -v dmlc-core) && exit -1
(cat logtidy.txt|grep -E 'xgboost.*warning'|grep -v dmlc-core) && exit -1
echo "----------------------------"
exit 0
fi