Implement sketching with adapter. (#8019)
This commit is contained in:
parent
142a208a90
commit
f0c1b842bf
@ -33,6 +33,40 @@ HistogramCuts::HistogramCuts() {
|
|||||||
cut_ptrs_.HostVector().emplace_back(0);
|
cut_ptrs_.HostVector().emplace_back(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief Run quantile sketching over every page of a DMatrix and build histogram cuts.
 *
 * The function first walks all SparsePage batches to count the entries of each column,
 * then feeds the data into a sketch container and asks it to produce the cuts.
 *
 * \param m          Input data.
 * \param max_bins   Forwarded to the sketch container.
 * \param n_threads  Number of threads used for counting and sketching.
 * \param use_sorted When true the SortedCSCPage batches are pushed into a
 *                   SortedSketchContainer, otherwise SparsePage batches are pushed into a
 *                   HostSketchContainer.
 * \param hessian    Forwarded to the container together with each page.
 *
 * \return The cuts produced by the container.
 */
HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, bool use_sorted,
                              Span<float> const hessian) {
  HistogramCuts out;
  auto const &info = m->Info();

  // Per-column entry counts summed over all row pages.
  std::vector<bst_row_t> reduced(info.num_col_, 0);
  for (auto const &page : m->GetBatches<SparsePage>()) {
    data::SparsePageAdapterBatch adapter{page.GetView()};
    // Every stored entry of a SparsePage counts, hence the always-true validity functor.
    auto const &entries_per_column =
        CalcColumnSize(adapter, info.num_col_, n_threads, [](auto) { return true; });
    CHECK_EQ(entries_per_column.size(), info.num_col_);
    size_t fidx = 0;
    for (auto n_entries : entries_per_column) {
      reduced[fidx] += n_entries;
      ++fidx;
    }
  }

  if (use_sorted) {
    // Column-wise sketching on sorted CSC pages.
    SortedSketchContainer container{max_bins, info, reduced, HostSketchContainer::UseGroup(info),
                                    n_threads};
    for (auto const &page : m->GetBatches<SortedCSCPage>()) {
      container.PushColPage(page, info, hessian);
    }
    container.MakeCuts(&out);
  } else {
    // Row-wise sketching on CSR pages.
    HostSketchContainer container(max_bins, info, reduced, HostSketchContainer::UseGroup(info),
                                  n_threads);
    for (auto const &page : m->GetBatches<SparsePage>()) {
      container.PushRowPage(page, info, hessian);
    }
    container.MakeCuts(&out);
  }

  return out;
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief fill a histogram by zeros in range [begin, end)
|
* \brief fill a histogram by zeros in range [begin, end)
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -152,41 +152,8 @@ class HistogramCuts {
|
|||||||
* \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient
|
* \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient
|
||||||
* but consumes more memory.
|
* but consumes more memory.
|
||||||
*/
|
*/
|
||||||
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
|
HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
|
||||||
bool use_sorted = false, Span<float> const hessian = {}) {
|
bool use_sorted = false, Span<float> const hessian = {});
|
||||||
HistogramCuts out;
|
|
||||||
auto const& info = m->Info();
|
|
||||||
std::vector<std::vector<bst_row_t>> column_sizes(n_threads);
|
|
||||||
for (auto& column : column_sizes) {
|
|
||||||
column.resize(info.num_col_, 0);
|
|
||||||
}
|
|
||||||
std::vector<bst_row_t> reduced(info.num_col_, 0);
|
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
|
||||||
auto const& entries_per_column =
|
|
||||||
HostSketchContainer::CalcColumnSize(page, info.num_col_, n_threads);
|
|
||||||
for (size_t i = 0; i < entries_per_column.size(); ++i) {
|
|
||||||
reduced[i] += entries_per_column[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!use_sorted) {
|
|
||||||
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
|
|
||||||
n_threads);
|
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
|
||||||
container.PushRowPage(page, info, hessian);
|
|
||||||
}
|
|
||||||
container.MakeCuts(&out);
|
|
||||||
} else {
|
|
||||||
SortedSketchContainer container{max_bins, m->Info(), reduced,
|
|
||||||
HostSketchContainer::UseGroup(info), n_threads};
|
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
|
||||||
container.PushColPage(page, info, hessian);
|
|
||||||
}
|
|
||||||
container.MakeCuts(&out);
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
enum BinTypeSize : uint8_t {
|
enum BinTypeSize : uint8_t {
|
||||||
kUint8BinsTypeSize = 1,
|
kUint8BinsTypeSize = 1,
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "../data/adapter.h"
|
||||||
#include "categorical.h"
|
#include "categorical.h"
|
||||||
#include "hist_util.h"
|
#include "hist_util.h"
|
||||||
#include "rabit/rabit.h"
|
#include "rabit/rabit.h"
|
||||||
@ -31,72 +32,6 @@ SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> column
|
|||||||
has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{});
|
has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename WQSketch>
|
|
||||||
std::vector<bst_row_t> SketchContainerImpl<WQSketch>::CalcColumnSize(SparsePage const &batch,
|
|
||||||
bst_feature_t const n_columns,
|
|
||||||
size_t const nthreads) {
|
|
||||||
auto page = batch.GetView();
|
|
||||||
std::vector<std::vector<bst_row_t>> column_sizes(nthreads);
|
|
||||||
for (auto &column : column_sizes) {
|
|
||||||
column.resize(n_columns, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParallelFor(page.Size(), nthreads, [&](omp_ulong i) {
|
|
||||||
auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
|
|
||||||
auto row = page[i];
|
|
||||||
auto const *p_row = row.data();
|
|
||||||
for (size_t j = 0; j < row.size(); ++j) {
|
|
||||||
local_column_sizes.at(p_row[j].index)++;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
std::vector<bst_row_t> entries_per_columns(n_columns, 0);
|
|
||||||
ParallelFor(n_columns, nthreads, [&](bst_omp_uint i) {
|
|
||||||
for (auto const &thread : column_sizes) {
|
|
||||||
entries_per_columns[i] += thread[i];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return entries_per_columns;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename WQSketch>
|
|
||||||
std::vector<bst_feature_t> SketchContainerImpl<WQSketch>::LoadBalance(SparsePage const &batch,
|
|
||||||
bst_feature_t n_columns,
|
|
||||||
size_t const nthreads) {
|
|
||||||
/* Some sparse datasets have their mass concentrating on small number of features. To
|
|
||||||
* avoid waiting for a few threads running forever, we here distribute different number
|
|
||||||
* of columns to different threads according to number of entries.
|
|
||||||
*/
|
|
||||||
auto page = batch.GetView();
|
|
||||||
size_t const total_entries = page.data.size();
|
|
||||||
size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
|
|
||||||
|
|
||||||
std::vector<std::vector<bst_row_t>> column_sizes(nthreads);
|
|
||||||
for (auto& column : column_sizes) {
|
|
||||||
column.resize(n_columns, 0);
|
|
||||||
}
|
|
||||||
std::vector<bst_row_t> entries_per_columns =
|
|
||||||
CalcColumnSize(batch, n_columns, nthreads);
|
|
||||||
std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
|
|
||||||
size_t count {0};
|
|
||||||
size_t current_thread {1};
|
|
||||||
|
|
||||||
for (auto col : entries_per_columns) {
|
|
||||||
cols_ptr.at(current_thread)++; // add one column to thread
|
|
||||||
count += col;
|
|
||||||
CHECK_LE(count, total_entries);
|
|
||||||
if (count > entries_per_thread) {
|
|
||||||
current_thread++;
|
|
||||||
count = 0;
|
|
||||||
cols_ptr.at(current_thread) = cols_ptr[current_thread-1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Idle threads.
|
|
||||||
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
|
|
||||||
cols_ptr[current_thread+1] = cols_ptr[current_thread];
|
|
||||||
}
|
|
||||||
return cols_ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Function to merge hessian and sample weights
|
// Function to merge hessian and sample weights
|
||||||
std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian, bool use_group,
|
std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian, bool use_group,
|
||||||
@ -143,54 +78,37 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
|
|||||||
CHECK_EQ(weights.size(), info.num_row_);
|
CHECK_EQ(weights.size(), info.num_row_);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto batch = page.GetView();
|
auto batch = data::SparsePageAdapterBatch{page.GetView()};
|
||||||
// Parallel over columns. Each thread owns a set of consecutive columns.
|
this->PushRowPageImpl(batch, page.base_rowid, OptionalWeights{weights}, page.data.Size(),
|
||||||
auto const ncol = static_cast<bst_feature_t>(info.num_col_);
|
info.num_col_, is_dense, [](auto) { return true; });
|
||||||
auto thread_columns_ptr = LoadBalance(page, info.num_col_, n_threads_);
|
|
||||||
|
|
||||||
dmlc::OMPException exc;
|
|
||||||
#pragma omp parallel num_threads(n_threads_)
|
|
||||||
{
|
|
||||||
exc.Run([&]() {
|
|
||||||
auto tid = static_cast<uint32_t>(omp_get_thread_num());
|
|
||||||
auto const begin = thread_columns_ptr[tid];
|
|
||||||
auto const end = thread_columns_ptr[tid + 1];
|
|
||||||
|
|
||||||
// do not iterate if no columns are assigned to the thread
|
|
||||||
if (begin < end && end <= ncol) {
|
|
||||||
for (size_t i = 0; i < batch.Size(); ++i) {
|
|
||||||
size_t const ridx = page.base_rowid + i;
|
|
||||||
SparsePage::Inst const inst = batch[i];
|
|
||||||
auto w = weights.empty() ? 1.0f : weights[ridx];
|
|
||||||
auto p_inst = inst.data();
|
|
||||||
if (is_dense) {
|
|
||||||
for (size_t ii = begin; ii < end; ii++) {
|
|
||||||
if (IsCat(feature_types_, ii)) {
|
|
||||||
categories_[ii].emplace(p_inst[ii].fvalue);
|
|
||||||
} else {
|
|
||||||
sketches_[ii].Push(p_inst[ii].fvalue, w);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < inst.size(); ++i) {
|
|
||||||
auto const& entry = p_inst[i];
|
|
||||||
if (entry.index >= begin && entry.index < end) {
|
|
||||||
if (IsCat(feature_types_, entry.index)) {
|
|
||||||
categories_[entry.index].emplace(entry.fvalue);
|
|
||||||
} else {
|
|
||||||
sketches_[entry.index].Push(entry.fvalue, w);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
exc.Rethrow();
|
|
||||||
monitor_.Stop(__func__);
|
monitor_.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Batch>
|
||||||
|
void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
|
||||||
|
MetaInfo const &info, size_t nnz, float missing) {
|
||||||
|
auto const &h_weights =
|
||||||
|
(use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());
|
||||||
|
|
||||||
|
auto is_valid = data::IsValidFunctor{missing};
|
||||||
|
auto weights = OptionalWeights{Span<float const>{h_weights}};
|
||||||
|
// the nnz from info is not reliable as sketching might be the first place to go through
|
||||||
|
// the data.
|
||||||
|
auto is_dense = nnz == info.num_col_ * info.num_row_;
|
||||||
|
this->PushRowPageImpl(batch, base_rowid, weights, nnz, info.num_col_, is_dense, is_valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define INSTANTIATE(_type) \
|
||||||
|
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
|
||||||
|
data::_type const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz, \
|
||||||
|
float missing);
|
||||||
|
|
||||||
|
INSTANTIATE(ArrayAdapterBatch)
|
||||||
|
INSTANTIATE(CSRArrayAdapterBatch)
|
||||||
|
INSTANTIATE(CSCAdapterBatch)
|
||||||
|
INSTANTIATE(DataTableAdapterBatch)
|
||||||
|
INSTANTIATE(SparsePageAdapterBatch)
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/**
|
/**
|
||||||
* \brief A view over gathered sketch values.
|
* \brief A view over gathered sketch values.
|
||||||
|
|||||||
@ -8,15 +8,19 @@
|
|||||||
#define XGBOOST_COMMON_QUANTILE_H_
|
#define XGBOOST_COMMON_QUANTILE_H_
|
||||||
|
|
||||||
#include <dmlc/base.h>
|
#include <dmlc/base.h>
|
||||||
#include <xgboost/logging.h>
|
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <cmath>
|
#include <xgboost/logging.h>
|
||||||
#include <vector>
|
|
||||||
#include <cstring>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "categorical.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "threading_utils.h"
|
||||||
#include "timer.h"
|
#include "timer.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -722,6 +726,69 @@ inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
|||||||
|
|
||||||
class HistogramCuts;
|
class HistogramCuts;
|
||||||
|
|
||||||
|
template <typename Batch, typename IsValid>
|
||||||
|
std::vector<bst_row_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_columns,
|
||||||
|
size_t const n_threads, IsValid &&is_valid) {
|
||||||
|
std::vector<std::vector<bst_row_t>> column_sizes_tloc(n_threads);
|
||||||
|
for (auto &column : column_sizes_tloc) {
|
||||||
|
column.resize(n_columns, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParallelFor(batch.Size(), n_threads, [&](omp_ulong i) {
|
||||||
|
auto &local_column_sizes = column_sizes_tloc.at(omp_get_thread_num());
|
||||||
|
auto const &line = batch.GetLine(i);
|
||||||
|
for (size_t j = 0; j < line.Size(); ++j) {
|
||||||
|
auto elem = line.GetElement(j);
|
||||||
|
if (is_valid(elem)) {
|
||||||
|
local_column_sizes[elem.column_idx]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// reduce to first thread
|
||||||
|
auto &entries_per_columns = column_sizes_tloc.front();
|
||||||
|
CHECK_EQ(entries_per_columns.size(), static_cast<size_t>(n_columns));
|
||||||
|
for (size_t i = 1; i < n_threads; ++i) {
|
||||||
|
CHECK_EQ(column_sizes_tloc[i].size(), static_cast<size_t>(n_columns));
|
||||||
|
for (size_t j = 0; j < n_columns; ++j) {
|
||||||
|
entries_per_columns[j] += column_sizes_tloc[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return entries_per_columns;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Batch, typename IsValid>
|
||||||
|
std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
|
||||||
|
size_t const nthreads, IsValid&& is_valid) {
|
||||||
|
/* Some sparse datasets have their mass concentrating on small number of features. To
|
||||||
|
* avoid waiting for a few threads running forever, we here distribute different number
|
||||||
|
* of columns to different threads according to number of entries.
|
||||||
|
*/
|
||||||
|
size_t const total_entries = nnz;
|
||||||
|
size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
|
||||||
|
|
||||||
|
// Need to calculate the size for each batch.
|
||||||
|
std::vector<bst_row_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
|
||||||
|
std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
|
||||||
|
size_t count{0};
|
||||||
|
size_t current_thread{1};
|
||||||
|
|
||||||
|
for (auto col : entries_per_columns) {
|
||||||
|
cols_ptr.at(current_thread)++; // add one column to thread
|
||||||
|
count += col;
|
||||||
|
CHECK_LE(count, total_entries);
|
||||||
|
if (count > entries_per_thread) {
|
||||||
|
current_thread++;
|
||||||
|
count = 0;
|
||||||
|
cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Idle threads.
|
||||||
|
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
|
||||||
|
cols_ptr[current_thread + 1] = cols_ptr[current_thread];
|
||||||
|
}
|
||||||
|
return cols_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* A sketch matrix storing sketches for each feature.
|
* A sketch matrix storing sketches for each feature.
|
||||||
*/
|
*/
|
||||||
@ -759,14 +826,6 @@ class SketchContainerImpl {
|
|||||||
return use_group_ind;
|
return use_group_ind;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<bst_row_t> CalcColumnSize(SparsePage const &page,
|
|
||||||
bst_feature_t const n_columns,
|
|
||||||
size_t const nthreads);
|
|
||||||
|
|
||||||
static std::vector<bst_feature_t> LoadBalance(SparsePage const &page,
|
|
||||||
bst_feature_t n_columns,
|
|
||||||
size_t const nthreads);
|
|
||||||
|
|
||||||
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
||||||
size_t const base_rowid) {
|
size_t const base_rowid) {
|
||||||
CHECK_LT(base_rowid, group_ptr.back())
|
CHECK_LT(base_rowid, group_ptr.back())
|
||||||
@ -785,6 +844,54 @@ class SketchContainerImpl {
|
|||||||
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||||
std::vector<int32_t> *p_num_cuts);
|
std::vector<int32_t> *p_num_cuts);
|
||||||
|
|
||||||
|
template <typename Batch, typename IsValid>
|
||||||
|
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
|
||||||
|
size_t n_features, bool is_dense, IsValid is_valid) {
|
||||||
|
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
|
||||||
|
|
||||||
|
dmlc::OMPException exc;
|
||||||
|
#pragma omp parallel num_threads(n_threads_)
|
||||||
|
{
|
||||||
|
exc.Run([&]() {
|
||||||
|
auto tid = static_cast<uint32_t>(omp_get_thread_num());
|
||||||
|
auto const begin = thread_columns_ptr[tid];
|
||||||
|
auto const end = thread_columns_ptr[tid + 1];
|
||||||
|
|
||||||
|
// do not iterate if no columns are assigned to the thread
|
||||||
|
if (begin < end && end <= n_features) {
|
||||||
|
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
|
||||||
|
auto const &line = batch.GetLine(ridx);
|
||||||
|
auto w = weights[ridx + base_rowid];
|
||||||
|
if (is_dense) {
|
||||||
|
for (size_t ii = begin; ii < end; ii++) {
|
||||||
|
auto elem = line.GetElement(ii);
|
||||||
|
if (is_valid(elem)) {
|
||||||
|
if (IsCat(feature_types_, ii)) {
|
||||||
|
categories_[ii].emplace(elem.value);
|
||||||
|
} else {
|
||||||
|
sketches_[ii].Push(elem.value, w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < line.Size(); ++i) {
|
||||||
|
auto const &elem = line.GetElement(i);
|
||||||
|
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
|
||||||
|
if (IsCat(feature_types_, elem.column_idx)) {
|
||||||
|
categories_[elem.column_idx].emplace(elem.value);
|
||||||
|
} else {
|
||||||
|
sketches_[elem.column_idx].Push(elem.value, w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
exc.Rethrow();
|
||||||
|
}
|
||||||
|
|
||||||
/* \brief Push a CSR matrix. */
|
/* \brief Push a CSR matrix. */
|
||||||
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
||||||
|
|
||||||
@ -798,6 +905,10 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
|
|||||||
public:
|
public:
|
||||||
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
|
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
|
||||||
bool use_group, int32_t n_threads);
|
bool use_group, int32_t n_threads);
|
||||||
|
|
||||||
|
/*!
 * \brief Push a batch of adapter data into the sketches.
 *
 * \param batch      The adapter batch to sketch.
 * \param base_rowid Row offset of this batch, used when looking up row weights.
 * \param info       Meta info providing weights and the number of rows/columns.
 * \param nnz        Number of entries in the batch; the batch is handled as dense when it
 *                   equals num_col_ * num_row_ of info.
 * \param missing    Value treated as missing; elements rejected by
 *                   data::IsValidFunctor{missing} are skipped.
 */
template <typename Batch>
|
||||||
|
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,
|
||||||
|
float missing);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -29,9 +29,6 @@ class GHistIndexMatrix {
|
|||||||
/**
|
/**
|
||||||
* \brief Push a page into index matrix, the function is only necessary because hist has
|
* \brief Push a page into index matrix, the function is only necessary because hist has
|
||||||
* partial support for external memory.
|
* partial support for external memory.
|
||||||
*
|
|
||||||
* \param rbegin The beginning row index of current page. (total rows in previous pages)
|
|
||||||
* \param prev_sum Total number of entries in previous pages.
|
|
||||||
*/
|
*/
|
||||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
|
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
|
||||||
bst_bin_t n_total_bins, int32_t n_threads);
|
bst_bin_t n_total_bins, int32_t n_threads);
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2020-2022 by XGBoost Contributors
|
* Copyright 2020-2022 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "test_quantile.h"
|
#include "test_quantile.h"
|
||||||
#include "../../../src/common/quantile.h"
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../src/common/hist_util.h"
|
#include "../../../src/common/hist_util.h"
|
||||||
|
#include "../../../src/common/quantile.h"
|
||||||
|
#include "../../../src/data/adapter.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
@ -13,8 +16,9 @@ TEST(Quantile, LoadBalance) {
|
|||||||
size_t constexpr kRows = 1000, kCols = 100;
|
size_t constexpr kRows = 1000, kCols = 100;
|
||||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||||
std::vector<bst_feature_t> cols_ptr;
|
std::vector<bst_feature_t> cols_ptr;
|
||||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||||
cols_ptr = HostSketchContainer::LoadBalance(page, kCols, 13);
|
data::SparsePageAdapterBatch adapter{page.GetView()};
|
||||||
|
cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
|
||||||
}
|
}
|
||||||
size_t n_cols = 0;
|
size_t n_cols = 0;
|
||||||
for (size_t i = 1; i < cols_ptr.size(); ++i) {
|
for (size_t i = 1; i < cols_ptr.size(); ++i) {
|
||||||
@ -22,6 +26,7 @@ TEST(Quantile, LoadBalance) {
|
|||||||
}
|
}
|
||||||
CHECK_EQ(n_cols, kCols);
|
CHECK_EQ(n_cols, kCols);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <bool use_column>
|
template <bool use_column>
|
||||||
using ContainerType = std::conditional_t<use_column, SortedSketchContainer, HostSketchContainer>;
|
using ContainerType = std::conditional_t<use_column, SortedSketchContainer, HostSketchContainer>;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user