Use view for SparsePage exclusively. (#6590)
commit f2f7dd87b8 (parent 78f2cd83d7)
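This commit applies a single pattern throughout the code base: instead of indexing a SparsePage directly through the removed SparsePage::operator[] (which fetched the host vectors on every call), callers now materialize a HostSparsePageView once via GetView() and index the view. A minimal usage sketch of the new pattern (the helper name is ours, not part of the diff):

#include "xgboost/data.h"

// Sum the row lengths of a batch using the view introduced by this commit.
size_t SumRowSizes(xgboost::SparsePage const& batch) {
  auto page = batch.GetView();  // sync to host once, then cheap span indexing
  size_t n = 0;
  for (size_t i = 0; i < batch.Size(); ++i) {
    n += page[i].size();
  }
  return n;
}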
@@ -252,15 +252,6 @@ class SparsePage {
   /*! \brief an instance of sparse vector in the batch */
   using Inst = common::Span<Entry const>;
 
-  /*! \brief get i-th row from the batch */
-  inline Inst operator[](size_t i) const {
-    const auto& data_vec = data.HostVector();
-    const auto& offset_vec = offset.HostVector();
-    size_t size = offset_vec[i + 1] - offset_vec[i];
-    return {data_vec.data() + offset_vec[i],
-            static_cast<Inst::index_type>(size)};
-  }
-
   HostSparsePageView GetView() const {
     return {offset.ConstHostSpan(), data.ConstHostSpan()};
   }
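For context, a rough sketch of what the HostSparsePageView returned by GetView() has to provide, inferred from the GetView() body and the deleted operator[] above; the real declaration lives elsewhere in data.h and may differ in detail:

// Inferred sketch, not the verbatim declaration.
class HostSparsePageView {
 public:
  common::Span<bst_row_t const> offset;
  common::Span<Entry const> data;

  SparsePage::Inst operator[](size_t i) const {
    auto size = offset[i + 1] - offset[i];
    return {data.data() + offset[i],
            static_cast<SparsePage::Inst::index_type>(size)};
  }
  size_t Size() const { return offset.empty() ? 0 : offset.size() - 1; }
};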
@@ -78,6 +78,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
     const size_t batch_threads = std::max(
         size_t(1),
         std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
+    auto page = batch.GetView();
     MemStackAllocator<size_t, 128> partial_sums(batch_threads);
     size_t* p_part = partial_sums.Get();
 
@@ -92,7 +93,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
 
       size_t sum = 0;
       for (size_t i = ibegin; i < iend; ++i) {
-        sum += batch[i].size();
+        sum += page[i].size();
         row_ptr[rbegin + 1 + i] = sum;
       }
     }
@@ -825,19 +825,20 @@ SparsePage SparsePage::GetTranspose(int num_columns) const {
   const int nthread = omp_get_max_threads();
   builder.InitBudget(num_columns, nthread);
   long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
-#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
+  auto page = this->GetView();
+#pragma omp parallel for default(none) shared(batch_size, builder, page) schedule(static)
   for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
     int tid = omp_get_thread_num();
-    auto inst = (*this)[i];
+    auto inst = page[i];
     for (const auto& entry : inst) {
       builder.AddBudget(entry.index, tid);
     }
   }
   builder.InitStorage();
-#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static)
+#pragma omp parallel for default(none) shared(batch_size, builder, page) schedule(static)
   for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
     int tid = omp_get_thread_num();
-    auto inst = (*this)[i];
+    auto inst = page[i];
     for (const auto& entry : inst) {
       builder.Push(
           entry.index,
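A note on the GetTranspose hunk above: its parallel regions use default(none), so every variable referenced inside them must appear in an explicit data-sharing clause, which is why the new page view is added to shared(...). A standalone illustration of that rule (plain C++/OpenMP, not xgboost code):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> page(16, 1);
  long n = static_cast<long>(page.size());
  long total = 0;
  // With default(none), omitting `page` or `n` below is a compile-time error
  // instead of an implicitly shared variable.
#pragma omp parallel for default(none) shared(page, n) reduction(+ : total) schedule(static)
  for (long i = 0; i < n; ++i) {
    total += page[i];
  }
  std::printf("total = %ld\n", total);
  return 0;
}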
@@ -28,13 +28,12 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
   auto out = new SimpleDMatrix;
   SparsePage& out_page = out->sparse_page_;
   for (auto const &page : this->GetBatches<SparsePage>()) {
-    page.data.HostVector();
-    page.offset.HostVector();
+    auto batch = page.GetView();
     auto& h_data = out_page.data.HostVector();
     auto& h_offset = out_page.offset.HostVector();
     size_t rptr{0};
     for (auto ridx : ridxs) {
-      auto inst = page[ridx];
+      auto inst = batch[ridx];
       rptr += inst.size();
       std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
       h_offset.emplace_back(rptr);
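In the Slice hunk above, the deleted HostVector() calls existed only for their side effect of pulling the page's data to the host. GetView() performs that synchronization itself, since the ConstHostSpan() calls it makes copy device data to host when needed, and the spans it returns alias the host buffers. A sketch of the resulting access pattern (the helper is illustrative, not from the commit):

#include "xgboost/data.h"

// Fetch one row of a page through the view.
xgboost::SparsePage::Inst GetRow(xgboost::SparsePage const& page, size_t ridx) {
  auto batch = page.GetView();  // syncs offset/data to host, returns spans
  return batch[ridx];           // span arithmetic only; aliases page's buffers
}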
@@ -173,9 +173,10 @@ class GBLinear : public GradientBooster {
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
       // parallel over local batch
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
+      auto page = batch.GetView();
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nsize; ++i) {
-        auto inst = batch[i];
+        auto inst = page[i];
         auto row_idx = static_cast<size_t>(batch.base_rowid + i);
         // loop over output groups
         for (int gid = 0; gid < ngroup; ++gid) {
@@ -678,6 +678,7 @@ class Dart : public GBTree {
     CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
     // start collecting the prediction
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
+      auto page = batch.GetView();
       constexpr int kUnroll = 8;
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
       const bst_omp_uint rest = nsize % kUnroll;
@@ -692,7 +693,7 @@ class Dart : public GBTree {
           ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
         }
         for (int k = 0; k < kUnroll; ++k) {
-          inst[k] = batch[i + k];
+          inst[k] = page[i + k];
         }
         for (int k = 0; k < kUnroll; ++k) {
           for (int gid = 0; gid < num_group; ++gid) {
@@ -707,7 +708,7 @@ class Dart : public GBTree {
       for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
         RegTree::FVec& feats = thread_temp_[0];
         const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
-        const SparsePage::Inst inst = batch[i];
+        const SparsePage::Inst inst = page[i];
         for (int gid = 0; gid < num_group; ++gid) {
           const size_t offset = ridx * num_group + gid;
           preds[offset] +=
@@ -82,7 +82,8 @@ inline std::pair<double, double> GetGradient(int group_idx, int num_group, int f
                                              DMatrix *p_fmat) {
   double sum_grad = 0.0, sum_hess = 0.0;
   for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
-    auto col = batch[fidx];
+    auto page = batch.GetView();
+    auto col = page[fidx];
     const auto ndata = static_cast<bst_omp_uint>(col.size());
     for (bst_omp_uint j = 0; j < ndata; ++j) {
       const bst_float v = col[j].fvalue;
@@ -111,7 +112,8 @@ inline std::pair<double, double> GetGradientParallel(int group_idx, int num_grou
                                                      DMatrix *p_fmat) {
   double sum_grad = 0.0, sum_hess = 0.0;
   for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
-    auto col = batch[fidx];
+    auto page = batch.GetView();
+    auto col = page[fidx];
     const auto ndata = static_cast<bst_omp_uint>(col.size());
 #pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
     for (bst_omp_uint j = 0; j < ndata; ++j) {
@@ -166,7 +168,8 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
                                    DMatrix *p_fmat) {
   if (dw == 0.0f) return;
   for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
-    auto col = batch[fidx];
+    auto page = batch.GetView();
+    auto col = page[fidx];
     // update grad value
     const auto num_row = static_cast<bst_omp_uint>(col.size());
 #pragma omp parallel for schedule(static)
@@ -334,9 +337,10 @@ class GreedyFeatureSelector : public FeatureSelector {
     // Calculate univariate gradient sums
     std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
     for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
+      auto page = batch.GetView();
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nfeat; ++i) {
-        const auto col = batch[i];
+        const auto col = page[i];
         const bst_uint ndata = col.size();
         auto &sums = gpair_sums_[group_idx * nfeat + i];
         for (bst_uint j = 0u; j < ndata; ++j) {
@@ -399,10 +403,11 @@ class ThriftyFeatureSelector : public FeatureSelector {
     // Calculate univariate gradient sums
     std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
     for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
-      // column-parallel is usually faster than row-parallel
+      auto page = batch.GetView();
+      // column-parallel is usually faster than row-parallel
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nfeat; ++i) {
-        const auto col = batch[i];
+        const auto col = page[i];
         const bst_uint ndata = col.size();
         for (bst_uint gid = 0u; gid < ngroup; ++gid) {
           auto &sums = gpair_sums_[gid * nfeat + i];
@@ -60,6 +60,7 @@ class GPUCoordinateUpdater : public LinearUpdater {  // NOLINT
 
     CHECK(p_fmat->SingleColBlock());
     SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
+    auto page = batch.GetView();
 
     if (IsEmpty()) {
       return;
@@ -72,7 +73,7 @@ class GPUCoordinateUpdater : public LinearUpdater {  // NOLINT
     row_ptr_ = {0};
     // iterate through columns
     for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
-      common::Span<Entry const> col = batch[fidx];
+      common::Span<Entry const> col = page[fidx];
       auto cmp = [](Entry e1, Entry e2) {
         return e1.index < e2.index;
       };
@@ -89,7 +90,7 @@ class GPUCoordinateUpdater : public LinearUpdater {  // NOLINT
     data_.resize(row_ptr_.back());
     gpair_.resize(num_row_ * model_param.num_output_group);
     for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
-      auto col = batch[fidx];
+      auto col = page[fidx];
       auto seg = column_segments[fidx];
       dh::safe_cuda(cudaMemcpy(
           data_.data().get() + row_ptr_[fidx],
@@ -52,6 +52,7 @@ class ShotgunUpdater : public LinearUpdater {
     selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
                      param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0);
     for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
+      auto page = batch.GetView();
       const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nfeat; ++i) {
@@ -60,7 +61,7 @@ class ShotgunUpdater : public LinearUpdater {
                                    param_.reg_lambda_denorm);
         if (ii < 0) continue;
         const bst_uint fid = ii;
-        auto col = batch[ii];
+        auto col = page[ii];
         for (int gid = 0; gid < ngroup; ++gid) {
           double sum_grad = 0.0, sum_hess = 0.0;
           for (auto& c : col) {
@@ -360,18 +360,19 @@ class CPUPredictor : public Predictor {
     // start collecting the prediction
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
       // parallel over local batch
+      auto page = batch.GetView();
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nsize; ++i) {
         const int tid = omp_get_thread_num();
         auto ridx = static_cast<size_t>(batch.base_rowid + i);
         RegTree::FVec &feats = thread_temp_[tid];
-        feats.Fill(batch[i]);
+        feats.Fill(page[i]);
         for (unsigned j = 0; j < ntree_limit; ++j) {
           int tid = model.trees[j]->GetLeafIndex(feats);
           preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
         }
-        feats.Drop(batch[i]);
+        feats.Drop(page[i]);
       }
     }
   }
@@ -407,6 +408,7 @@ class CPUPredictor : public Predictor {
     const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
     // start collecting the contributions
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
+      auto page = batch.GetView();
       // parallel over local batch
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
 #pragma omp parallel for schedule(static)
@@ -417,7 +419,7 @@ class CPUPredictor : public Predictor {
         // loop over all classes
         for (int gid = 0; gid < ngroup; ++gid) {
           bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
-          feats.Fill(batch[i]);
+          feats.Fill(page[i]);
           // calculate contributions
           for (unsigned j = 0; j < ntree_limit; ++j) {
             std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
@@ -435,7 +437,7 @@ class CPUPredictor : public Predictor {
                 (tree_weights == nullptr ? 1 : (*tree_weights)[j]);
             }
           }
-          feats.Drop(batch[i]);
+          feats.Drop(page[i]);
           // add base margin to BIAS
           if (base_margin.size() != 0) {
             p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid];
@@ -59,8 +59,9 @@ class BaseMaker: public TreeUpdater {
                    -std::numeric_limits<bst_float>::max());
     // start accumulating statistics
     for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
+      auto page = batch.GetView();
       for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
-        auto c = batch[fid];
+        auto c = page[fid];
         if (c.size() != 0) {
           CHECK_LT(fid * 2, fminmax_.size());
           fminmax_[fid * 2 + 0] =
@@ -249,8 +250,9 @@ class BaseMaker: public TreeUpdater {
   inline void CorrectNonDefaultPositionByBatch(
       const SparsePage &batch, const std::vector<bst_uint> &sorted_split_set,
       const RegTree &tree) {
+    auto page = batch.GetView();
     for (size_t fid = 0; fid < batch.Size(); ++fid) {
-      auto col = batch[fid];
+      auto col = page[fid];
       auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
 
       if (it != sorted_split_set.end() && *it == fid) {
@@ -308,8 +310,9 @@ class BaseMaker: public TreeUpdater {
     std::vector<unsigned> fsplits;
     this->GetSplitSet(nodes, tree, &fsplits);
     for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
+      auto page = batch.GetView();
       for (auto fid : fsplits) {
-        auto col = batch[fid];
+        auto col = page[fid];
         const auto ndata = static_cast<bst_omp_uint>(col.size());
 #pragma omp parallel for schedule(static)
         for (bst_omp_uint j = 0; j < ndata; ++j) {
@@ -77,8 +77,9 @@ class ColMaker: public TreeUpdater {
     if (column_densities_.empty()) {
       std::vector<size_t> column_size(dmat->Info().num_col_);
       for (const auto &batch : dmat->GetBatches<SortedCSCPage>()) {
+        auto page = batch.GetView();
         for (auto i = 0u; i < batch.Size(); i++) {
-          column_size[i] += batch[i].size();
+          column_size[i] += page[i].size();
         }
       }
       column_densities_.resize(column_size.size());
@@ -447,13 +448,14 @@ class ColMaker: public TreeUpdater {
 #endif  // defined(_OPENMP)
     {
       dmlc::OMPException omp_handler;
+      auto page = batch.GetView();
 #pragma omp parallel for schedule(dynamic, batch_size)
       for (bst_omp_uint i = 0; i < num_features; ++i) {
         omp_handler.Run([&]() {
           auto evaluator = tree_evaluator_.GetEvaluator();
           bst_feature_t const fid = feat_set[i];
           int32_t const tid = omp_get_thread_num();
-          auto c = batch[fid];
+          auto c = page[fid];
           const bool ind =
               c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
           if (colmaker_train_param_.NeedForwardSearch(
@@ -562,8 +564,9 @@ class ColMaker: public TreeUpdater {
     std::sort(fsplits.begin(), fsplits.end());
     fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
     for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
+      auto page = batch.GetView();
       for (auto fid : fsplits) {
-        auto col = batch[fid];
+        auto col = page[fid];
         const auto ndata = static_cast<bst_omp_uint>(col.size());
 #pragma omp parallel for schedule(static)
         for (bst_omp_uint j = 0; j < ndata; ++j) {
@@ -338,6 +338,7 @@ class CQHistMaker: public HistMaker {
     thread_hist_.resize(omp_get_max_threads());
     // start accumulating statistics
     for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
+      auto page = batch.GetView();
       // start enumeration
       const auto nsize = static_cast<bst_omp_uint>(fset.size());
 #pragma omp parallel for schedule(dynamic, 1)
@@ -345,7 +346,7 @@ class CQHistMaker: public HistMaker {
         int fid = fset[i];
         int offset = feat2workindex_[fid];
         if (offset >= 0) {
-          this->UpdateHistCol(gpair, batch[fid], info, tree,
+          this->UpdateHistCol(gpair, page[fid], info, tree,
                               fset, offset,
                               &thread_hist_[omp_get_thread_num()]);
         }
@@ -413,7 +414,7 @@ class CQHistMaker: public HistMaker {
     for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
       // TWOPASS: use the real set + split set in the column iteration.
       this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
-
+      auto page = batch.GetView();
       // start enumeration
       const auto nsize = static_cast<bst_omp_uint>(work_set_.size());
 #pragma omp parallel for schedule(dynamic, 1)
@@ -421,7 +422,7 @@ class CQHistMaker: public HistMaker {
         int fid = work_set_[i];
         int offset = feat2workindex_[fid];
         if (offset >= 0) {
-          this->UpdateSketchCol(gpair, batch[fid], tree,
+          this->UpdateSketchCol(gpair, page[fid], tree,
                                 work_set_size, offset,
                                 &thread_sketch_[omp_get_thread_num()]);
         }
@@ -696,6 +697,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
     for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
       // TWOPASS: use the real set + split set in the column iteration.
       this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
+      auto page = batch.GetView();
 
       // start enumeration
       const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
@@ -704,7 +706,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
         int fid = this->work_set_[i];
         int offset = this->feat2workindex_[fid];
         if (offset >= 0) {
-          this->UpdateHistCol(gpair, batch[fid], info, tree,
+          this->UpdateHistCol(gpair, page[fid], info, tree,
                               fset, offset,
                               &this->thread_hist_[omp_get_thread_num()]);
         }
@@ -69,11 +69,12 @@ class TreeRefresher: public TreeUpdater {
     const MetaInfo &info = p_fmat->Info();
     // start accumulating statistics
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
+      auto page = batch.GetView();
       CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
       const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nbatch; ++i) {
-        SparsePage::Inst inst = batch[i];
+        SparsePage::Inst inst = page[i];
         const int tid = omp_get_thread_num();
         const auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
         RegTree::FVec &feats = fvec_temp[tid];
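The updater and predictor hunks above all follow one motif: create the view once, outside the OpenMP region, then index it from worker threads. Constructing the view is where any lazy device-to-host copy happens, so hoisting it avoids repeating that work per row and keeps the synchronization out of the parallel loop. A hedged sketch of the motif (function name ours):

#include "xgboost/data.h"

// Sum all feature values of a batch; the shape mirrors the hunks above.
double SumValues(xgboost::SparsePage const& batch) {
  auto page = batch.GetView();  // one host synchronization, before threads start
  long const n = static_cast<long>(batch.Size());
  double total = 0.0;
#pragma omp parallel for schedule(static) reduction(+ : total)
  for (long i = 0; i < n; ++i) {
    for (auto const& entry : page[i]) {  // per-row access is span arithmetic
      total += entry.fvalue;
    }
  }
  return total;
}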
@@ -30,10 +30,11 @@ TEST(CAPI, XGDMatrixCreateFromMatDT) {
   ASSERT_EQ(info.num_nonzero_, 6ul);
 
   for (const auto &batch : (*dmat)->GetBatches<xgboost::SparsePage>()) {
-    ASSERT_EQ(batch[0][0].fvalue, 0.0f);
-    ASSERT_EQ(batch[0][1].fvalue, -4.0f);
-    ASSERT_EQ(batch[2][0].fvalue, 3.0f);
-    ASSERT_EQ(batch[2][1].fvalue, 0.0f);
+    auto page = batch.GetView();
+    ASSERT_EQ(page[0][0].fvalue, 0.0f);
+    ASSERT_EQ(page[0][1].fvalue, -4.0f);
+    ASSERT_EQ(page[2][0].fvalue, 3.0f);
+    ASSERT_EQ(page[2][1].fvalue, 0.0f);
   }
 
   delete dmat;
@@ -62,8 +63,9 @@ TEST(CAPI, XGDMatrixCreateFromMatOmp) {
   ASSERT_EQ(info.num_nonzero_, num_cols * row - num_missing);
 
   for (const auto &batch : (*dmat)->GetBatches<xgboost::SparsePage>()) {
+    auto page = batch.GetView();
     for (size_t i = 0; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto e : inst) {
         ASSERT_EQ(e.fvalue, 1.5);
       }
@@ -176,9 +176,10 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
   // Collect data into columns
   std::vector<std::vector<float>> columns(dmat->Info().num_col_);
   for (auto& batch : dmat->GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     ASSERT_GT(batch.Size(), 0ul);
     for (auto i = 0ull; i < batch.Size(); i++) {
-      for (auto e : batch[i]) {
+      for (auto e : page[i]) {
         columns[e.index].push_back(e.fvalue);
       }
     }
@@ -47,7 +47,8 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
   EXPECT_EQ(dmat.Info().num_nonzero_, 8);
 
   auto &batch = *dmat.GetBatches<SparsePage>().begin();
-  auto inst = batch[0];
+  auto page = batch.GetView();
+  auto inst = page[0];
   EXPECT_EQ(inst[0].fvalue, 1);
   EXPECT_EQ(inst[0].index, 0);
   EXPECT_EQ(inst[1].fvalue, 3);
@@ -57,7 +58,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
   EXPECT_EQ(inst[3].fvalue, 7);
   EXPECT_EQ(inst[3].index, 3);
 
-  inst = batch[1];
+  inst = page[1];
   EXPECT_EQ(inst[0].fvalue, 2);
   EXPECT_EQ(inst[0].index, 0);
   EXPECT_EQ(inst[1].fvalue, 4);
@@ -11,9 +11,9 @@ namespace xgboost {
 TEST(SparsePage, PushCSC) {
   std::vector<bst_row_t> offset {0};
   std::vector<Entry> data;
-  SparsePage page;
-  page.offset.HostVector() = offset;
-  page.data.HostVector() = data;
+  SparsePage batch;
+  batch.offset.HostVector() = offset;
+  batch.data.HostVector() = data;
 
   offset = {0, 1, 4};
   for (size_t i = 0; i < offset.back(); ++i) {
@@ -24,25 +24,26 @@ TEST(SparsePage, PushCSC) {
   other.offset.HostVector() = offset;
   other.data.HostVector() = data;
 
-  page.PushCSC(other);
+  batch.PushCSC(other);
 
-  ASSERT_EQ(page.offset.HostVector().size(), offset.size());
-  ASSERT_EQ(page.data.HostVector().size(), data.size());
+  ASSERT_EQ(batch.offset.HostVector().size(), offset.size());
+  ASSERT_EQ(batch.data.HostVector().size(), data.size());
   for (size_t i = 0; i < offset.size(); ++i) {
-    ASSERT_EQ(page.offset.HostVector()[i], offset[i]);
+    ASSERT_EQ(batch.offset.HostVector()[i], offset[i]);
   }
   for (size_t i = 0; i < data.size(); ++i) {
-    ASSERT_EQ(page.data.HostVector()[i].index, data[i].index);
+    ASSERT_EQ(batch.data.HostVector()[i].index, data[i].index);
   }
 
-  page.PushCSC(other);
-  ASSERT_EQ(page.offset.HostVector().size(), offset.size());
-  ASSERT_EQ(page.data.Size(), data.size() * 2);
+  batch.PushCSC(other);
+  ASSERT_EQ(batch.offset.HostVector().size(), offset.size());
+  ASSERT_EQ(batch.data.Size(), data.size() * 2);
 
   for (size_t i = 0; i < offset.size(); ++i) {
-    ASSERT_EQ(page.offset.HostVector()[i], offset[i] * 2);
+    ASSERT_EQ(batch.offset.HostVector()[i], offset[i] * 2);
   }
 
+  auto page = batch.GetView();
   auto inst = page[0];
   ASSERT_EQ(inst.size(), 2ul);
   for (auto entry : inst) {
@@ -78,7 +79,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
   // The feature value for a feature in each row should be identical, as that is
   // how the dmatrix has been created
   for (size_t i = 0; i < page.Size(); ++i) {
-    auto inst = page[i];
+    auto inst = page.GetView()[i];
     for (size_t j = 1; j < inst.size(); ++j) {
       ASSERT_EQ(inst[0].fvalue, inst[j].fvalue);
     }
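One detail in the PushCSCAfterTranspose hunk above: page.GetView() is created as a temporary, yet indexing it is safe because the spans it hands out alias the SparsePage's own host buffers rather than owning storage. A sketch of why that is well-defined (helper name ours):

#include "xgboost/data.h"

// The temporary view dies at the end of the first statement, but `inst`
// still points into `page`'s host data, which outlives the call.
float FirstValue(xgboost::SparsePage const& page, size_t i) {
  auto inst = page.GetView()[i];
  return inst.empty() ? 0.0f : inst[0].fvalue;
}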
@@ -39,7 +39,8 @@ TEST(SimpleDMatrix, RowAccess) {
   EXPECT_EQ(row_count, dmat->Info().num_row_);
   // Test the data read into the first row
   auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
-  auto first_row = batch[0];
+  auto page = batch.GetView();
+  auto first_row = page[0];
   ASSERT_EQ(first_row.size(), 3);
   EXPECT_EQ(first_row[2].index, 2);
   EXPECT_EQ(first_row[2].fvalue, 20);
@@ -143,8 +144,9 @@ TEST(SimpleDMatrix, FromDense) {
   EXPECT_EQ(dmat.Info().num_nonzero_, 6);
 
   for (auto &batch : dmat.GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto j = 0ull; j < inst.size(); j++) {
         EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
         EXPECT_EQ(inst[j].index, j);
@@ -165,19 +167,20 @@ TEST(SimpleDMatrix, FromCSC) {
   EXPECT_EQ(dmat.Info().num_nonzero_, 5);
 
   auto &batch = *dmat.GetBatches<SparsePage>().begin();
-  auto inst = batch[0];
+  auto page = batch.GetView();
+  auto inst = page[0];
   EXPECT_EQ(inst[0].fvalue, 1);
   EXPECT_EQ(inst[0].index, 0);
   EXPECT_EQ(inst[1].fvalue, 2);
   EXPECT_EQ(inst[1].index, 1);
 
-  inst = batch[1];
+  inst = page[1];
   EXPECT_EQ(inst[0].fvalue, 3);
   EXPECT_EQ(inst[0].index, 0);
   EXPECT_EQ(inst[1].fvalue, 4);
   EXPECT_EQ(inst[1].index, 1);
 
-  inst = batch[2];
+  inst = page[2];
   EXPECT_EQ(inst[0].fvalue, 5);
   EXPECT_EQ(inst[0].index, 1);
 }
@@ -194,11 +197,12 @@ TEST(SimpleDMatrix, FromFile) {
   std::unique_ptr<dmlc::Parser<uint32_t>> parser(
       dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
 
-  auto verify_batch = [kExpectedNumRow](SparsePage const &batch) {
+  auto verify_batch = [kExpectedNumRow](SparsePage const &page) {
+    auto batch = page.GetView();
     EXPECT_EQ(batch.Size(), kExpectedNumRow);
-    EXPECT_EQ(batch.offset.HostVector(),
+    EXPECT_EQ(page.offset.HostVector(),
               std::vector<bst_row_t>({0, 3, 6, 9, 12, 15, 15}));
-    EXPECT_EQ(batch.base_rowid, 0);
+    EXPECT_EQ(page.base_rowid, 0);
 
     for (auto i = 0ull; i < batch.Size() - 1; i++) {
       if (i % 2 == 0) {
@@ -251,8 +255,10 @@ TEST(SimpleDMatrix, Slice) {
   ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size());
   ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses);
 
-  for (auto const& in_page : p_m->GetBatches<SparsePage>()) {
-    for (auto const &out_page : out->GetBatches<SparsePage>()) {
+  for (auto const& in_batch : p_m->GetBatches<SparsePage>()) {
+    auto in_page = in_batch.GetView();
+    for (auto const &out_batch : out->GetBatches<SparsePage>()) {
+      auto out_page = out_batch.GetView();
       for (size_t i = 0; i < ridxs.size(); ++i) {
         auto ridx = ridxs[i];
         auto out_inst = out_page[i];
@@ -305,8 +311,8 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
   auto row_iter = dmat->GetBatches<xgboost::SparsePage>().begin();
   auto row_iter_read = dmat_read->GetBatches<xgboost::SparsePage>().begin();
   // Test the data read into the first row
-  auto first_row = (*row_iter)[0];
-  auto first_row_read = (*row_iter_read)[0];
+  auto first_row = (*row_iter).GetView()[0];
+  auto first_row_read = (*row_iter_read).GetView()[0];
   EXPECT_EQ(first_row.size(), first_row_read.size());
   EXPECT_EQ(first_row[2].index, first_row_read[2].index);
   EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue);
@@ -35,8 +35,9 @@ TEST(SimpleDMatrix, FromColumnarDenseBasic) {
 
 void TestDenseColumn(DMatrix* dmat, size_t n_rows, size_t n_cols) {
   for (auto& batch : dmat->GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto j = 0ull; j < inst.size(); j++) {
         EXPECT_EQ(inst[j].fvalue, i * 2);
         EXPECT_EQ(inst[j].index, j);
@@ -162,8 +163,9 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) {
                               -1);
 
   for (auto& batch : dmat.GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
      for (auto j = 0ull; j < inst.size(); j++) {
         EXPECT_EQ(inst[j].fvalue, i);
         EXPECT_EQ(inst[j].index, j);
@@ -257,8 +259,9 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
   data::CudfAdapter adapter(str);
   data::SimpleDMatrix dmat(&adapter, 2.0, -1);
   for (auto& batch : dmat.GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto e : inst) {
         ASSERT_NE(e.fvalue, 2.0);
       }
@@ -304,8 +307,9 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) {
   EXPECT_EQ(dmat.Info().num_nonzero_, 32);
 
   for (auto& batch : dmat.GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto j = 0ull; j < inst.size(); j++) {
         EXPECT_EQ(inst[j].fvalue, i * 2);
         EXPECT_EQ(inst[j].index, j);
@@ -329,8 +333,9 @@ TEST(SimpleDMatrix, FromCupy){
   EXPECT_EQ(dmat.Info().num_nonzero_, rows*cols);
 
   for (auto& batch : dmat.GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto j = 0ull; j < inst.size(); j++) {
         EXPECT_EQ(inst[j].fvalue, i * cols + j);
         EXPECT_EQ(inst[j].index, j);
@@ -354,12 +359,14 @@ TEST(SimpleDMatrix, FromCupySparse){
   EXPECT_EQ(dmat.Info().num_row_, rows);
   EXPECT_EQ(dmat.Info().num_nonzero_, rows * cols - 2);
   auto& batch = *dmat.GetBatches<SparsePage>().begin();
-  auto inst0 = batch[0];
-  auto inst1 = batch[1];
-  EXPECT_EQ(batch[0].size(), 1);
-  EXPECT_EQ(batch[1].size(), 1);
-  EXPECT_EQ(batch[0][0].fvalue, 0.0f);
-  EXPECT_EQ(batch[0][0].index, 0);
-  EXPECT_EQ(batch[1][0].fvalue, 3.0f);
-  EXPECT_EQ(batch[1][0].index, 1);
+  auto page = batch.GetView();
+
+  auto inst0 = page[0];
+  auto inst1 = page[1];
+  EXPECT_EQ(page[0].size(), 1);
+  EXPECT_EQ(page[1].size(), 1);
+  EXPECT_EQ(page[0][0].fvalue, 0.0f);
+  EXPECT_EQ(page[0][0].index, 0);
+  EXPECT_EQ(page[1][0].fvalue, 3.0f);
+  EXPECT_EQ(page[1][0].index, 1);
 }
@@ -39,7 +39,8 @@ TEST(SparsePageDMatrix, RowAccess) {
 
   // Test the data read into the first row
   auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
-  auto first_row = batch[0];
+  auto page = batch.GetView();
+  auto first_row = page[0];
   ASSERT_EQ(first_row.size(), 3ul);
   EXPECT_EQ(first_row[2].index, 2u);
   EXPECT_EQ(first_row[2].fvalue, 20);
@@ -54,16 +55,18 @@ TEST(SparsePageDMatrix, ColAccess) {
 
   // Loop over the batches and assert the data is as expected
   for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
-    EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
-    EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
-    EXPECT_EQ(col_batch[1].size(), 1);
+    auto col_page = col_batch.GetView();
+    EXPECT_EQ(col_page.Size(), dmat->Info().num_col_);
+    EXPECT_EQ(col_page[1][0].fvalue, 10.0f);
+    EXPECT_EQ(col_page[1].size(), 1);
   }
 
   // Loop over the batches and assert the data is as expected
   for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
-    EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
-    EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
-    EXPECT_EQ(col_batch[1].size(), 1);
+    auto col_page = col_batch.GetView();
+    EXPECT_EQ(col_page.Size(), dmat->Info().num_col_);
+    EXPECT_EQ(col_page[1][0].fvalue, 10.0f);
+    EXPECT_EQ(col_page[1].size(), 1);
   }
 
   EXPECT_TRUE(FileExists(tmp_file + ".cache"));
@@ -238,8 +241,9 @@ TEST(SparsePageDMatrix, FromDense) {
   EXPECT_EQ(dmat.Info().num_nonzero_, 6);
 
   for (auto &batch : dmat.GetBatches<SparsePage>()) {
+    auto page = batch.GetView();
     for (auto i = 0ull; i < batch.Size(); i++) {
-      auto inst = batch[i];
+      auto inst = page[i];
       for (auto j = 0ull; j < inst.size(); j++) {
         EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
         EXPECT_EQ(inst[j].index, j);
@@ -262,19 +266,20 @@ TEST(SparsePageDMatrix, FromCSC) {
   EXPECT_EQ(dmat.Info().num_nonzero_, 5);
 
   auto &batch = *dmat.GetBatches<SparsePage>().begin();
-  auto inst = batch[0];
+  auto page = batch.GetView();
+  auto inst = page[0];
   EXPECT_EQ(inst[0].fvalue, 1);
   EXPECT_EQ(inst[0].index, 0);
   EXPECT_EQ(inst[1].fvalue, 2);
   EXPECT_EQ(inst[1].index, 1);
 
-  inst = batch[1];
+  inst = page[1];
   EXPECT_EQ(inst[0].fvalue, 3);
   EXPECT_EQ(inst[0].index, 0);
   EXPECT_EQ(inst[1].fvalue, 4);
   EXPECT_EQ(inst[1].index, 1);
 
-  inst = batch[2];
+  inst = page[2];
   EXPECT_EQ(inst[0].fvalue, 5);
   EXPECT_EQ(inst[0].index, 1);
 }
@@ -294,19 +299,20 @@ TEST(SparsePageDMatrix, FromFile) {
 
   for (auto &batch : dmat.GetBatches<SparsePage>()) {
     std::vector<bst_row_t> expected_offset(batch.Size() + 1);
+    auto page = batch.GetView();
     int n = -3;
     std::generate(expected_offset.begin(), expected_offset.end(),
                   [&n] { return n += 3; });
     EXPECT_EQ(batch.offset.HostVector(), expected_offset);
 
     if (batch.base_rowid % 2 == 0) {
-      EXPECT_EQ(batch[0][0].index, 0);
-      EXPECT_EQ(batch[0][1].index, 1);
-      EXPECT_EQ(batch[0][2].index, 2);
+      EXPECT_EQ(page[0][0].index, 0);
+      EXPECT_EQ(page[0][1].index, 1);
+      EXPECT_EQ(page[0][2].index, 2);
     } else {
-      EXPECT_EQ(batch[0][0].index, 0);
-      EXPECT_EQ(batch[0][1].index, 3);
-      EXPECT_EQ(batch[0][2].index, 4);
+      EXPECT_EQ(page[0][0].index, 0);
+      EXPECT_EQ(page[0][1].index, 3);
+      EXPECT_EQ(page[0][2].index, 4);
     }
   }
 }
@@ -39,9 +39,10 @@ TEST(CpuPredictor, Basic) {
 
   // Test predict instance
   auto const &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
+  auto page = batch.GetView();
   for (size_t i = 0; i < batch.Size(); i++) {
     std::vector<float> instance_out_predictions;
-    cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
+    cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model);
     ASSERT_EQ(instance_out_predictions[0], 1.5);
   }
 
@@ -72,12 +72,13 @@ class QuantileHistMock : public QuantileHistMaker {
     ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
              gmat.cut.Ptrs().back());
     for (const auto& batch : p_fmat->GetBatches<xgboost::SparsePage>()) {
+      auto page = batch.GetView();
      for (size_t i = 0; i < batch.Size(); ++i) {
        const size_t rid = batch.base_rowid + i;
        ASSERT_LT(rid, num_row);
        const size_t gmat_row_offset = gmat.row_ptr[rid];
        ASSERT_LT(gmat_row_offset, gmat.index.Size());
-        SparsePage::Inst inst = batch[i];
+        SparsePage::Inst inst = page[i];
        ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
        for (size_t j = 0; j < inst.size(); ++j) {
          // Each entry of GHistIndexMatrix represents a bin ID