Implement sketching with adapter. (#8019)
This commit is contained in:
@@ -8,15 +8,19 @@
|
||||
#define XGBOOST_COMMON_QUANTILE_H_
|
||||
|
||||
#include <dmlc/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "threading_utils.h"
|
||||
#include "timer.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -722,6 +726,69 @@ inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
|
||||
class HistogramCuts;
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
std::vector<bst_row_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_columns,
|
||||
size_t const n_threads, IsValid &&is_valid) {
|
||||
std::vector<std::vector<bst_row_t>> column_sizes_tloc(n_threads);
|
||||
for (auto &column : column_sizes_tloc) {
|
||||
column.resize(n_columns, 0);
|
||||
}
|
||||
|
||||
ParallelFor(batch.Size(), n_threads, [&](omp_ulong i) {
|
||||
auto &local_column_sizes = column_sizes_tloc.at(omp_get_thread_num());
|
||||
auto const &line = batch.GetLine(i);
|
||||
for (size_t j = 0; j < line.Size(); ++j) {
|
||||
auto elem = line.GetElement(j);
|
||||
if (is_valid(elem)) {
|
||||
local_column_sizes[elem.column_idx]++;
|
||||
}
|
||||
}
|
||||
});
|
||||
// reduce to first thread
|
||||
auto &entries_per_columns = column_sizes_tloc.front();
|
||||
CHECK_EQ(entries_per_columns.size(), static_cast<size_t>(n_columns));
|
||||
for (size_t i = 1; i < n_threads; ++i) {
|
||||
CHECK_EQ(column_sizes_tloc[i].size(), static_cast<size_t>(n_columns));
|
||||
for (size_t j = 0; j < n_columns; ++j) {
|
||||
entries_per_columns[j] += column_sizes_tloc[i][j];
|
||||
}
|
||||
}
|
||||
return entries_per_columns;
|
||||
}
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
|
||||
size_t const nthreads, IsValid&& is_valid) {
|
||||
/* Some sparse datasets have their mass concentrating on small number of features. To
|
||||
* avoid waiting for a few threads running forever, we here distribute different number
|
||||
* of columns to different threads according to number of entries.
|
||||
*/
|
||||
size_t const total_entries = nnz;
|
||||
size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
|
||||
|
||||
// Need to calculate the size for each batch.
|
||||
std::vector<bst_row_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
|
||||
std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
|
||||
size_t count{0};
|
||||
size_t current_thread{1};
|
||||
|
||||
for (auto col : entries_per_columns) {
|
||||
cols_ptr.at(current_thread)++; // add one column to thread
|
||||
count += col;
|
||||
CHECK_LE(count, total_entries);
|
||||
if (count > entries_per_thread) {
|
||||
current_thread++;
|
||||
count = 0;
|
||||
cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
|
||||
}
|
||||
}
|
||||
// Idle threads.
|
||||
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
|
||||
cols_ptr[current_thread + 1] = cols_ptr[current_thread];
|
||||
}
|
||||
return cols_ptr;
|
||||
}
|
||||
|
||||
/*!
|
||||
* A sketch matrix storing sketches for each feature.
|
||||
*/
|
||||
@@ -759,14 +826,6 @@ class SketchContainerImpl {
|
||||
return use_group_ind;
|
||||
}
|
||||
|
||||
static std::vector<bst_row_t> CalcColumnSize(SparsePage const &page,
|
||||
bst_feature_t const n_columns,
|
||||
size_t const nthreads);
|
||||
|
||||
static std::vector<bst_feature_t> LoadBalance(SparsePage const &page,
|
||||
bst_feature_t n_columns,
|
||||
size_t const nthreads);
|
||||
|
||||
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
||||
size_t const base_rowid) {
|
||||
CHECK_LT(base_rowid, group_ptr.back())
|
||||
@@ -785,6 +844,54 @@ class SketchContainerImpl {
|
||||
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t> *p_num_cuts);
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
|
||||
size_t n_features, bool is_dense, IsValid is_valid) {
|
||||
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(n_threads_)
|
||||
{
|
||||
exc.Run([&]() {
|
||||
auto tid = static_cast<uint32_t>(omp_get_thread_num());
|
||||
auto const begin = thread_columns_ptr[tid];
|
||||
auto const end = thread_columns_ptr[tid + 1];
|
||||
|
||||
// do not iterate if no columns are assigned to the thread
|
||||
if (begin < end && end <= n_features) {
|
||||
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
|
||||
auto const &line = batch.GetLine(ridx);
|
||||
auto w = weights[ridx + base_rowid];
|
||||
if (is_dense) {
|
||||
for (size_t ii = begin; ii < end; ii++) {
|
||||
auto elem = line.GetElement(ii);
|
||||
if (is_valid(elem)) {
|
||||
if (IsCat(feature_types_, ii)) {
|
||||
categories_[ii].emplace(elem.value);
|
||||
} else {
|
||||
sketches_[ii].Push(elem.value, w);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < line.Size(); ++i) {
|
||||
auto const &elem = line.GetElement(i);
|
||||
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
|
||||
if (IsCat(feature_types_, elem.column_idx)) {
|
||||
categories_[elem.column_idx].emplace(elem.value);
|
||||
} else {
|
||||
sketches_[elem.column_idx].Push(elem.value, w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
}
|
||||
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
||||
|
||||
@@ -798,6 +905,10 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
|
||||
public:
|
||||
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
|
||||
bool use_group, int32_t n_threads);
|
||||
|
||||
template <typename Batch>
|
||||
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,
|
||||
float missing);
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user