Fix CPU hist init for sparse dataset. (#4625)

* Fix CPU hist init for sparse dataset.

* Implement sparse histogram cut.
* Allow empty features.

* Fix windows build, don't use sparse in distributed environment.

* Comments.

* Smaller threshold.

* Fix windows omp.

* Fix msvc lambda capture.

* Fix MSVC macro.

* Fix MSVC initialization list.

* Fix MSVC initialization list x2.

* Preserve categorical feature behavior.

* Rename matrix to sparse cuts.
* Reuse UseGroup.
* Check for categorical data when adding cut.

Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>

* Sanity check.

* Fix comments.

* Fix comment.
This commit is contained in:
Jiaming Yuan 2019-07-04 19:27:03 -04:00 committed by Philip Hyunsu Cho
parent b7a1f22d24
commit d9a47794a5
33 changed files with 681 additions and 299 deletions

View File

@ -1,4 +1,4 @@
Checks: 'modernize-*,-modernize-make-*,-modernize-use-auto,-modernize-raw-string-literal,google-*,-google-default-arguments,-clang-diagnostic-#pragma-messages,readability-identifier-naming'
Checks: 'modernize-*,-modernize-make-*,-modernize-use-auto,-modernize-raw-string-literal,-modernize-avoid-c-arrays,google-*,-google-default-arguments,-clang-diagnostic-#pragma-messages,readability-identifier-naming'
CheckOptions:
- { key: readability-identifier-naming.ClassCase, value: CamelCase }
- { key: readability-identifier-naming.StructCase, value: CamelCase }

View File

@ -437,6 +437,7 @@ class DMatrix {
bool load_row_split,
const std::string& file_format = "auto",
const size_t page_size = kPageSize);
/*!
* \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
* \param source The source iterator of the data, the create function takes ownership of the source.

View File

@ -119,8 +119,7 @@ class NativeDataIter : public dmlc::Parser<uint32_t> {
}
bool Next() override {
if ((*next_callback_)(
data_handle_,
if ((*next_callback_)(data_handle_,
XGBoostNativeDataIterSetData,
this) != 0) {
at_first_ = false;

View File

@ -75,7 +75,7 @@ class ColumnMatrix {
// construct column matrix from GHistIndexMatrix
inline void Init(const GHistIndexMatrix& gmat,
double sparse_threshold) {
const int32_t nfeature = static_cast<int32_t>(gmat.cut.row_ptr.size() - 1);
const int32_t nfeature = static_cast<int32_t>(gmat.cut.Ptrs().size() - 1);
const size_t nrow = gmat.row_ptr.size() - 1;
// identify type of each column
@ -85,7 +85,7 @@ class ColumnMatrix {
uint32_t max_val = std::numeric_limits<uint32_t>::max();
for (int32_t fid = 0; fid < nfeature; ++fid) {
CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val);
}
gmat.GetFeatureCounts(&feature_counts_[0]);
@ -123,7 +123,7 @@ class ColumnMatrix {
// store least bin id for each feature
index_base_.resize(nfeature);
for (int32_t fid = 0; fid < nfeature; ++fid) {
index_base_[fid] = gmat.cut.row_ptr[fid];
index_base_[fid] = gmat.cut.Ptrs()[fid];
}
// pre-fill index_ for dense columns
@ -150,9 +150,9 @@ class ColumnMatrix {
size_t fid = 0;
for (size_t i = ibegin; i < iend; ++i) {
const uint32_t bin_id = gmat.index[i];
while (bin_id >= gmat.cut.row_ptr[fid + 1]) {
++fid;
}
auto iter = std::upper_bound(gmat.cut.Ptrs().cbegin() + fid,
gmat.cut.Ptrs().cend(), bin_id);
fid = std::distance(gmat.cut.Ptrs().cbegin(), iter) - 1;
if (type_[fid] == kDenseColumn) {
uint32_t* begin = &index_[boundary_[fid].index_begin];
begin[rid] = bin_id - index_base_[fid];

View File

@ -72,6 +72,11 @@ inline std::string ToString(const T& data) {
return os.str();
}
template <typename T1, typename T2>
XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
}
/*
* Range iterator
*/

View File

@ -30,12 +30,13 @@ class ConfigParser {
* \param path path to configuration file
*/
explicit ConfigParser(const std::string& path)
: line_comment_regex_("^#"),
: path_(path),
line_comment_regex_("^#"),
key_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*=)rx"),
key_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*=)rx"),
value_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"),
value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx"),
path_(path) {}
value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx")
{}
std::string LoadConfigFile(const std::string& path) {
std::ifstream fin(path, std::ios_base::in | std::ios_base::binary);
@ -77,8 +78,6 @@ class ConfigParser {
content = NormalizeConfigEOL(content);
std::stringstream ss { content };
std::vector<std::pair<std::string, std::string>> results;
char delimiter = '=';
char comment = '#';
std::string line;
std::string key, value;
// Loop over every line of the configuration file

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017 XGBoost contributors
* Copyright 2017-2019 XGBoost contributors
*/
#pragma once
#include <thrust/device_ptr.h>
@ -183,11 +183,6 @@ __device__ void BlockFill(IterT begin, size_t n, ValueT value) {
* Kernel launcher
*/
template <typename T1, typename T2>
T1 DivRoundUp(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
template <typename L>
__global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
for (auto i : GridStrideRange(begin, end)) {
@ -211,7 +206,7 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
safe_cuda(cudaSetDevice(device_idx));
const int GRID_SIZE =
static_cast<int>(DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0),
n, lambda);
}
@ -738,7 +733,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
const int BLOCK_THREADS = 256;
const int ITEMS_PER_THREAD = 1;
const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
auto num_tiles = dh::DivRoundUp(count + num_segments, BLOCK_THREADS);
auto num_tiles = xgboost::common::DivRoundUp(count + num_segments, BLOCK_THREADS);
CHECK(num_tiles < std::numeric_limits<unsigned int>::max());
temp_memory->LazyAllocate(sizeof(CoordinateT) * (num_tiles + 1));

View File

@ -25,25 +25,206 @@
namespace xgboost {
namespace common {
HistCutMatrix::HistCutMatrix() {
monitor_.Init("HistCutMatrix");
HistogramCuts::HistogramCuts() {
monitor_.Init(__FUNCTION__);
cut_ptrs_.emplace_back(0);
}
size_t HistCutMatrix::SearchGroupIndFromBaseRow(
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) const {
using KIt = std::vector<bst_uint>::const_iterator;
KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
// Cannot use CHECK_NE because it will try to print the iterator.
bool const found = res != group_ptr.cend() - 1;
if (!found) {
LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!\n";
// Dispatch to specific builder.
void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
auto const& info = dmat->Info();
size_t const total = info.num_row_ * info.num_col_;
size_t const nnz = info.num_nonzero_;
float const sparsity = static_cast<float>(nnz) / static_cast<float>(total);
// Use a small number to avoid calling `dmat->GetColumnBatches'.
float constexpr kSparsityThreshold = 0.0005;
// FIXME(trivialfis): Distributed environment is not supported.
if (sparsity < kSparsityThreshold && (!rabit::IsDistributed())) {
LOG(INFO) << "Building quantile cut on a sparse dataset.";
SparseCuts cuts(this);
cuts.Build(dmat, max_num_bins);
} else {
LOG(INFO) << "Building quantile cut on a dense dataset or distributed environment.";
DenseCuts cuts(this);
cuts.Build(dmat, max_num_bins);
}
size_t group_ind = std::distance(group_ptr.cbegin(), res);
return group_ind;
}
void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
monitor_.Start("Init");
bool CutsBuilder::UseGroup(DMatrix* dmat) {
auto& info = dmat->Info();
size_t const num_groups = info.group_ptr_.size() == 0 ?
0 : info.group_ptr_.size() - 1;
// Use group index for weights?
bool const use_group_ind = num_groups != 0 &&
(info.weights_.Size() != info.num_row_);
return use_group_ind;
}
void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
uint32_t max_num_bins,
bool const use_group_ind,
uint32_t beg_col, uint32_t end_col,
uint32_t thread_id) {
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
CHECK_GE(end_col, beg_col);
constexpr float kFactor = 8;
// Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
p_cuts_->min_vals_.resize(end_col - beg_col, 0);
for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
// Using a local variable makes things easier, but at the cost of memory trashing.
WXQSketch sketch;
common::Span<xgboost::Entry const> const column = page[col_id];
uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
max_num_bins);
if (n_bins == 0) {
// cut_ptrs_ is initialized with a zero, so there's always an element at the back
p_cuts_->cut_ptrs_.emplace_back(p_cuts_->cut_ptrs_.back());
continue;
}
sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor));
for (auto const& entry : column) {
uint32_t weight_ind = 0;
if (use_group_ind) {
auto row_idx = entry.index;
uint32_t group_ind =
this->SearchGroupIndFromRow(group_ptr, page.base_rowid + row_idx);
weight_ind = group_ind;
} else {
weight_ind = entry.index;
}
sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
}
WXQSketch::SummaryContainer out_summary;
sketch.GetSummary(&out_summary);
WXQSketch::SummaryContainer summary;
summary.Reserve(n_bins);
summary.SetPrune(out_summary, n_bins);
// Can be use data[1] as the min values so that we don't need to
// store another array?
float mval = summary.data[0].value;
p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5);
this->AddCutPoint(summary);
bst_float cpt = (summary.size > 0) ?
summary.data[summary.size - 1].value :
p_cuts_->min_vals_[col_id - beg_col];
cpt += fabs(cpt) + 1e-5;
p_cuts_->cut_values_.emplace_back(cpt);
p_cuts_->cut_ptrs_.emplace_back(p_cuts_->cut_values_.size());
}
}
std::vector<size_t> SparseCuts::LoadBalance(SparsePage const& page,
size_t const nthreads) {
/* Some sparse datasets have their mass concentrating on small
* number of features. To avoid wating for a few threads running
* forever, we here distirbute different number of columns to
* different threads according to number of entries. */
size_t const total_entries = page.data.Size();
size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
std::vector<size_t> cols_ptr(nthreads+1, 0);
size_t count {0};
size_t current_thread {1};
for (size_t col_id = 0; col_id < page.Size(); ++col_id) {
auto const column = page[col_id];
cols_ptr[current_thread]++; // add one column to thread
count += column.size();
if (count > entries_per_thread + 1) {
current_thread++;
count = 0;
cols_ptr[current_thread] = cols_ptr[current_thread-1];
}
}
// Idle threads.
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
cols_ptr[current_thread+1] = cols_ptr[current_thread];
}
return cols_ptr;
}
void SparseCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
monitor_.Start(__FUNCTION__);
// Use group index for weights?
auto use_group = UseGroup(dmat);
uint32_t nthreads = omp_get_max_threads();
CHECK_GT(nthreads, 0);
std::vector<HistogramCuts> cuts_containers(nthreads);
std::vector<std::unique_ptr<SparseCuts>> sparse_cuts(nthreads);
for (size_t i = 0; i < nthreads; ++i) {
sparse_cuts[i].reset(new SparseCuts(&cuts_containers[i]));
}
for (auto const& page : dmat->GetColumnBatches()) {
CHECK_LE(page.Size(), dmat->Info().num_col_);
monitor_.Start("Load balance");
std::vector<size_t> col_ptr = LoadBalance(page, nthreads);
monitor_.Stop("Load balance");
// We here decouples the logic between build and parallelization
// to simplify things a bit.
#pragma omp parallel for num_threads(nthreads) schedule(static)
for (omp_ulong i = 0; i < nthreads; ++i) {
common::Monitor t_monitor;
t_monitor.Init("SingleThreadBuild: " + std::to_string(i));
t_monitor.Start(std::to_string(i));
sparse_cuts[i]->SingleThreadBuild(page, dmat->Info(), max_num_bins, use_group,
col_ptr[i], col_ptr[i+1], i);
t_monitor.Stop(std::to_string(i));
}
this->Concat(sparse_cuts, dmat->Info().num_col_);
}
monitor_.Stop(__FUNCTION__);
}
void SparseCuts::Concat(
std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols) {
monitor_.Start(__FUNCTION__);
uint32_t nthreads = omp_get_max_threads();
p_cuts_->min_vals_.resize(n_cols, std::numeric_limits<float>::max());
size_t min_vals_tail = 0;
for (uint32_t t = 0; t < nthreads; ++t) {
// concat csc pointers.
size_t const old_ptr_size = p_cuts_->cut_ptrs_.size();
p_cuts_->cut_ptrs_.resize(
cuts[t]->p_cuts_->cut_ptrs_.size() + p_cuts_->cut_ptrs_.size() - 1);
size_t const new_icp_size = p_cuts_->cut_ptrs_.size();
auto tail = p_cuts_->cut_ptrs_[old_ptr_size-1];
for (size_t j = old_ptr_size; j < new_icp_size; ++j) {
p_cuts_->cut_ptrs_[j] = tail + cuts[t]->p_cuts_->cut_ptrs_[j-old_ptr_size+1];
}
// concat csc values
size_t const old_iv_size = p_cuts_->cut_values_.size();
p_cuts_->cut_values_.resize(
cuts[t]->p_cuts_->cut_values_.size() + p_cuts_->cut_values_.size());
size_t const new_iv_size = p_cuts_->cut_values_.size();
for (size_t j = old_iv_size; j < new_iv_size; ++j) {
p_cuts_->cut_values_[j] = cuts[t]->p_cuts_->cut_values_[j-old_iv_size];
}
// merge min values
for (size_t j = 0; j < cuts[t]->p_cuts_->min_vals_.size(); ++j) {
p_cuts_->min_vals_.at(min_vals_tail + j) =
std::min(p_cuts_->min_vals_.at(min_vals_tail + j), cuts.at(t)->p_cuts_->min_vals_.at(j));
}
min_vals_tail += cuts[t]->p_cuts_->min_vals_.size();
}
monitor_.Stop(__FUNCTION__);
}
void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
monitor_.Start(__FUNCTION__);
const MetaInfo& info = p_fmat->Info();
// safe factor for better accuracy
@ -60,20 +241,18 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
}
const auto& weights = info.weights_.HostVector();
// Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
size_t const num_groups = group_ptr.size() == 0 ? 0 : group_ptr.size() - 1;
// Use group index for weights?
bool const use_group_ind = num_groups != 0 && weights.size() != info.num_row_;
bool const use_group = UseGroup(p_fmat);
for (const auto &batch : p_fmat->GetRowBatches()) {
size_t group_ind = 0;
if (use_group_ind) {
group_ind = this->SearchGroupIndFromBaseRow(group_ptr, batch.base_rowid);
if (use_group) {
group_ind = this->SearchGroupIndFromRow(group_ptr, batch.base_rowid);
}
#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group_ind)
#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group)
{
CHECK_EQ(nthread, omp_get_num_threads());
auto tid = static_cast<unsigned>(omp_get_thread_num());
@ -85,7 +264,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t const ridx = batch.base_rowid + i;
SparsePage::Inst const inst = batch[i];
if (use_group_ind &&
if (use_group &&
group_ptr[group_ind] == ridx &&
// maximum equals to weights.size() - 1
group_ind < num_groups - 1) {
@ -94,7 +273,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
}
for (auto const& entry : inst) {
if (entry.index >= begin && entry.index < end) {
size_t w_idx = use_group_ind ? group_ind : ridx;
size_t w_idx = use_group ? group_ind : ridx;
sketchs[entry.index].Push(entry.fvalue, info.GetWeight(w_idx));
}
}
@ -104,10 +283,10 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
}
Init(&sketchs, max_num_bins);
monitor_.Stop("Init");
monitor_.Stop(__FUNCTION__);
}
void HistCutMatrix::Init
void DenseCuts::Init
(std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
std::vector<WXQSketch>& sketchs = *in_sketchs;
constexpr int kFactor = 8;
@ -124,62 +303,34 @@ void HistCutMatrix::Init
CHECK_EQ(summary_array.size(), in_sketchs->size());
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
this->min_val.resize(sketchs.size());
row_ptr.push_back(0);
p_cuts_->min_vals_.resize(sketchs.size());
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
WXQSketch::SummaryContainer a;
a.Reserve(max_num_bins);
a.SetPrune(summary_array[fid], max_num_bins);
const bst_float mval = a.data[0].value;
this->min_val[fid] = mval - (fabs(mval) + 1e-5);
if (a.size > 1 && a.size <= 16) {
/* specialized code categorial / ordinal data -- use midpoints */
for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0f;
if (i == 1 || cpt > cut.back()) {
cut.push_back(cpt);
}
}
} else {
for (size_t i = 2; i < a.size; ++i) {
bst_float cpt = a.data[i - 1].value;
if (i == 2 || cpt > cut.back()) {
cut.push_back(cpt);
}
}
}
p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5);
AddCutPoint(a);
// push a value that is greater than anything
const bst_float cpt
= (a.size > 0) ? a.data[a.size - 1].value : this->min_val[fid];
= (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5);
cut.push_back(last);
p_cuts_->cut_values_.push_back(last);
// Ensure that every feature gets at least one quantile point
CHECK_LE(cut.size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(cut.size());
CHECK_GT(cut_size, row_ptr.back());
row_ptr.push_back(cut_size);
CHECK_LE(p_cuts_->cut_values_.size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(p_cuts_->cut_values_.size());
CHECK_GT(cut_size, p_cuts_->cut_ptrs_.back());
p_cuts_->cut_ptrs_.push_back(cut_size);
}
}
uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
unsigned fid = e.index;
auto cbegin = cut.begin() + row_ptr[fid];
auto cend = cut.begin() + row_ptr[fid + 1];
CHECK(cbegin != cend);
auto it = std::upper_bound(cbegin, cend, e.fvalue);
if (it == cend) {
it = cend - 1;
}
uint32_t idx = static_cast<uint32_t>(it - cut.begin());
return idx;
}
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
cut.Init(p_fmat, max_num_bins);
cut.Build(p_fmat, max_num_bins);
const int32_t nthread = omp_get_max_threads();
const uint32_t nbins = cut.row_ptr.back();
const uint32_t nbins = cut.Ptrs().back();
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0);
@ -208,7 +359,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
#pragma omp parallel num_threads(batch_threads)
{
#pragma omp for
for (int32_t tid = 0; tid < batch_threads; ++tid) {
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
size_t ibegin = block_size * tid;
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
@ -222,13 +373,13 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
#pragma omp single
{
p_part[0] = prev_sum;
for (int32_t i = 1; i < batch_threads; ++i) {
for (size_t i = 1; i < batch_threads; ++i) {
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
}
}
#pragma omp for
for (int32_t tid = 0; tid < batch_threads; ++tid) {
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
size_t ibegin = block_size * tid;
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
@ -240,7 +391,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
index.resize(row_ptr[rbegin + batch.Size()]);
CHECK_GT(cut.cut.size(), 0U);
CHECK_GT(cut.Values().size(), 0U);
#pragma omp parallel for num_threads(batch_threads) schedule(static)
for (omp_ulong i = 0; i < batch.Size(); ++i) { // NOLINT(*)
@ -251,7 +402,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
uint32_t idx = cut.GetBinIdx(inst[j]);
uint32_t idx = cut.SearchBin(inst[j]);
index[ibegin + j] = idx;
++hit_count_tloc_[tid * nbins + idx];
@ -382,7 +533,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
const ColumnMatrix& colmat,
const tree::TrainParam& param) {
const size_t nrow = gmat.row_ptr.size() - 1;
const size_t nfeature = gmat.cut.row_ptr.size() - 1;
const size_t nfeature = gmat.cut.Ptrs().size() - 1;
std::vector<unsigned> feature_list(nfeature);
std::iota(feature_list.begin(), feature_list.end(), 0);
@ -438,7 +589,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
cut_ = &gmat.cut;
const size_t nrow = gmat.row_ptr.size() - 1;
const uint32_t nbins = gmat.cut.row_ptr.back();
const uint32_t nbins = gmat.cut.Ptrs().back();
/* step 1: form feature groups */
auto groups = FastFeatureGrouping(gmat, colmat, param);
@ -448,8 +599,8 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
std::vector<uint32_t> bin2block(nbins); // lookup table [bin id] => [block id]
for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
for (auto& fid : groups[group_id]) {
const uint32_t bin_begin = gmat.cut.row_ptr[fid];
const uint32_t bin_end = gmat.cut.row_ptr[fid + 1];
const uint32_t bin_begin = gmat.cut.Ptrs()[fid];
const uint32_t bin_end = gmat.cut.Ptrs()[fid + 1];
for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
bin2block[bin_id] = group_id;
}
@ -627,8 +778,8 @@ void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
const size_t block_size = 1024; // aproximatly 1024 values per block
size_t n_blocks = size/block_size + !!(size%block_size);
#pragma omp parallel for
for (int iblock = 0; iblock < n_blocks; ++iblock) {
#pragma omp parallel for
for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) {
const size_t ibegin = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {

View File

@ -3,6 +3,7 @@
*/
#include "./hist_util.h"
#include <xgboost/logging.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
@ -24,7 +25,7 @@
namespace xgboost {
namespace common {
using WXQSketch = HistCutMatrix::WXQSketch;
using WXQSketch = DenseCuts::WXQSketch;
__global__ void FindCutsK
(WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data,
@ -92,7 +93,7 @@ __global__ void UnpackFeaturesK
* across distinct rows.
*/
struct SketchContainer {
std::vector<HistCutMatrix::WXQSketch> sketches_; // NOLINT
std::vector<DenseCuts::WXQSketch> sketches_; // NOLINT
std::vector<std::mutex> col_locks_; // NOLINT
static constexpr int kOmpNumColsParallelizeLimit = 1000;
@ -300,7 +301,7 @@ struct GPUSketcher {
} else if (n_cuts_cur_[icol] > 0) {
// if more elements than cuts: use binary search on cumulative weights
int block = 256;
FindCutsK<<<dh::DivRoundUp(n_cuts_cur_[icol], block), block>>>
FindCutsK<<<common::DivRoundUp(n_cuts_cur_[icol], block), block>>>
(cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
dh::safe_cuda(cudaGetLastError()); // NOLINT
@ -342,8 +343,8 @@ struct GPUSketcher {
dim3 block3(16, 64, 1);
// NOTE: This will typically support ~ 4M features - 64K*64
dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
dh::DivRoundUp(num_cols_, block3.y), 1);
dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(num_cols_, block3.y), 1);
UnpackFeaturesK<<<grid3, block3>>>
(fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
row_ptrs_.data().get() + batch_row_begin,
@ -392,7 +393,7 @@ struct GPUSketcher {
row_ptrs_.resize(n_rows_ + 1);
thrust::copy(offset_vec.data() + row_begin_,
offset_vec.data() + row_end_ + 1, row_ptrs_.begin());
size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_);
size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
SketchBatch(row_batch, info, gpu_batch);
}
@ -434,7 +435,7 @@ struct GPUSketcher {
/* Builds the sketches on the GPU for the dmatrix and returns the row stride
* for the entire dataset */
size_t Sketch(DMatrix *dmat, HistCutMatrix *hmat) {
size_t Sketch(DMatrix *dmat, DenseCuts *hmat) {
const MetaInfo &info = dmat->Info();
row_stride_ = 0;
@ -459,9 +460,13 @@ struct GPUSketcher {
size_t DeviceSketch
(const tree::TrainParam &param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
DMatrix *dmat, HistCutMatrix *hmat) {
DMatrix *dmat, HistogramCuts *hmat) {
GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
return sketcher.Sketch(dmat, hmat);
// We only need to return the result in HistogramCuts container, so it is safe to
// use a pointer of local HistogramCutsDense
DenseCuts dense_cuts(hmat);
auto res = sketcher.Sketch(dmat, &dense_cuts);
return res;
}
} // namespace common

View File

@ -12,18 +12,21 @@
#include <limits>
#include <vector>
#include <algorithm>
#include <memory>
#include <utility>
#include "row_set.h"
#include "../tree/param.h"
#include "./quantile.h"
#include "./timer.h"
#include "../include/rabit/rabit.h"
#include "random.h"
namespace xgboost {
/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
* MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
* heap-allocated.
*/
template<typename T, size_t MaxStackSize>
class MemStackAllocator {
@ -122,47 +125,175 @@ struct SimpleArray {
size_t n_ = 0;
};
/*! \brief Cut configuration for all the features. */
struct HistCutMatrix {
/*! \brief Unit pointer to rows by element position */
std::vector<uint32_t> row_ptr;
/*! \brief minimum value of each feature */
std::vector<bst_float> min_val;
/*! \brief the cut field */
std::vector<bst_float> cut;
uint32_t GetBinIdx(const Entry &e);
/*!
* \brief A single row in global histogram index.
* Directly represent the global index in the histogram entry.
*/
using GHistIndexRow = Span<uint32_t const>;
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
// create histogram cut matrix given statistics from data
// using approximate quantile sketch approach
void Init(DMatrix* p_fmat, uint32_t max_num_bins);
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
HistCutMatrix();
size_t NumBins() const { return row_ptr.back(); }
// A CSC matrix representing histogram cuts, used in CPU quantile hist.
class HistogramCuts {
// Using friends to avoid creating a virtual class, since HistogramCuts is used as value
// object in many places.
friend class SparseCuts;
friend class DenseCuts;
friend class CutsBuilder;
protected:
virtual size_t SearchGroupIndFromBaseRow(
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) const;
using BinIdx = uint32_t;
common::Monitor monitor_;
Monitor monitor_;
std::vector<bst_float> cut_values_;
std::vector<uint32_t> cut_ptrs_;
std::vector<float> min_vals_; // storing minimum value in a sketch set.
public:
HistogramCuts();
HistogramCuts(HistogramCuts const& that) = delete;
HistogramCuts(HistogramCuts&& that) noexcept(true) {
*this = std::forward<HistogramCuts&&>(that);
}
HistogramCuts& operator=(HistogramCuts const& that) = delete;
HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
monitor_ = std::move(that.monitor_);
cut_ptrs_ = std::move(that.cut_ptrs_);
cut_values_ = std::move(that.cut_values_);
min_vals_ = std::move(that.min_vals_);
return *this;
}
/* \brief Build histogram cuts. */
void Build(DMatrix* dmat, uint32_t const max_num_bins);
/* \brief How many bins a feature has. */
uint32_t FeatureBins(uint32_t feature) const {
return cut_ptrs_.at(feature+1) - cut_ptrs_[feature];
}
// Getters. Cuts should be of no use after building histogram indices, but currently
// it's deeply linked with quantile_hist, gpu sketcher and gpu_hist. So we preserve
// these for now.
std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_; }
std::vector<float> const& Values() const { return cut_values_; }
std::vector<float> const& MinValues() const { return min_vals_; }
size_t TotalBins() const { return cut_ptrs_.back(); }
BinIdx SearchBin(float value, uint32_t column_id) {
auto beg = cut_ptrs_.at(column_id);
auto end = cut_ptrs_.at(column_id + 1);
auto it = std::upper_bound(cut_values_.cbegin() + beg, cut_values_.cbegin() + end, value);
if (it == cut_values_.cend()) {
it = cut_values_.cend() - 1;
}
BinIdx idx = it - cut_values_.cbegin();
return idx;
}
BinIdx SearchBin(Entry const& e) {
return SearchBin(e.fvalue, e.index);
}
};
/* \brief An interface for building quantile cuts.
*
* `DenseCuts' always assumes there are `max_bins` for each feature, which makes it not
* suitable for sparse dataset. On the other hand `SparseCuts' uses `GetColumnBatches',
* which doubles the memory usage, hence can not be applied to dense dataset.
*/
class CutsBuilder {
public:
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
protected:
HistogramCuts* p_cuts_;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);
public:
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
virtual ~CutsBuilder() = default;
static uint32_t SearchGroupIndFromRow(
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
using KIt = std::vector<bst_uint>::const_iterator;
KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
// Cannot use CHECK_NE because it will try to print the iterator.
bool const found = res != group_ptr.cend() - 1;
if (!found) {
LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!";
}
uint32_t group_ind = std::distance(group_ptr.cbegin(), res);
return group_ind;
}
void AddCutPoint(WXQSketch::SummaryContainer const& summary) {
if (summary.size > 1 && summary.size <= 16) {
/* specialized code categorial / ordinal data -- use midpoints */
for (size_t i = 1; i < summary.size; ++i) {
bst_float cpt = (summary.data[i].value + summary.data[i - 1].value) / 2.0f;
if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
p_cuts_->cut_values_.push_back(cpt);
}
}
} else {
for (size_t i = 2; i < summary.size; ++i) {
bst_float cpt = summary.data[i - 1].value;
if (i == 2 || cpt > p_cuts_->cut_values_.back()) {
p_cuts_->cut_values_.push_back(cpt);
}
}
}
}
/* \brief Build histogram indices. */
virtual void Build(DMatrix* dmat, uint32_t const max_num_bins) = 0;
};
/*! \brief Cut configuration for sparse dataset. */
class SparseCuts : public CutsBuilder {
/* \brief Distrbute columns to each thread according to number of entries. */
static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads);
Monitor monitor_;
public:
explicit SparseCuts(HistogramCuts* container) :
CutsBuilder(container) {
monitor_.Init(__FUNCTION__);
}
/* \brief Concatonate the built cuts in each thread. */
void Concat(std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols);
/* \brief Build histogram indices in single thread. */
void SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
uint32_t max_num_bins,
bool const use_group_ind,
uint32_t beg, uint32_t end, uint32_t thread_id);
void Build(DMatrix* dmat, uint32_t const max_num_bins) override;
};
/*! \brief Cut configuration for dense dataset. */
class DenseCuts : public CutsBuilder {
protected:
Monitor monitor_;
public:
explicit DenseCuts(HistogramCuts* container) :
CutsBuilder(container) {
monitor_.Init(__FUNCTION__);
}
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
};
// FIXME(trivialfis): Merge this into generic cut builder.
/*! \brief Builds the cut matrix on the GPU.
*
* \return The row stride across the entire dataset.
*/
size_t DeviceSketch
(const tree::TrainParam& param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
DMatrix* dmat, HistCutMatrix* hmat);
DMatrix* dmat, HistogramCuts* hmat);
/*!
* \brief A single row in global histogram index.
* Directly represent the global index in the histogram entry.
*/
using GHistIndexRow = Span<uint32_t const>;
/*!
* \brief preprocessed global index matrix, in CSR format
@ -178,7 +309,7 @@ struct GHistIndexMatrix {
/*! \brief hit count of each index */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
HistCutMatrix cut;
HistogramCuts cut;
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins);
// get i-th row
@ -188,10 +319,10 @@ struct GHistIndexMatrix {
row_ptr[i + 1] - row_ptr[i])};
}
inline void GetFeatureCounts(size_t* counts) const {
auto nfeature = cut.row_ptr.size() - 1;
auto nfeature = cut.Ptrs().size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
auto ibegin = cut.row_ptr[fid];
auto iend = cut.row_ptr[fid + 1];
auto ibegin = cut.Ptrs()[fid];
auto iend = cut.Ptrs()[fid + 1];
for (auto i = ibegin; i < iend; ++i) {
counts[fid] += hit_count[i];
}
@ -234,7 +365,7 @@ class GHistIndexBlockMatrix {
private:
std::vector<size_t> row_ptr_;
std::vector<uint32_t> index_;
const HistCutMatrix* cut_;
const HistogramCuts* cut_;
struct Block {
const size_t* row_ptr_begin;
const size_t* row_ptr_end;

View File

@ -549,7 +549,7 @@ class Span {
detail::ExtentValue<Extent, Offset, Count>::value> {
SPAN_CHECK(Offset >= 0 && (Offset < size() || size() == 0));
SPAN_CHECK(Count == dynamic_extent ||
Count >= 0 && Offset + Count <= size());
(Count >= 0 && Offset + Count <= size()));
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
}

View File

@ -60,7 +60,7 @@ class Transform {
Evaluator(Functor func, Range range, GPUSet devices, bool shard) :
func_(func), range_{std::move(range)},
shard_{shard},
distribution_{std::move(GPUDistribution::Block(devices))} {}
distribution_{GPUDistribution::Block(devices)} {}
Evaluator(Functor func, Range range, GPUDistribution dist,
bool shard) :
func_(func), range_{std::move(range)}, shard_{shard},
@ -142,7 +142,7 @@ class Transform {
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
dh::safe_cuda(cudaSetDevice(device));
const int GRID_SIZE =
static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
_func, shard_range, UnpackHDV(_vectors, device)...);
}

View File

@ -238,7 +238,7 @@ class GPUPredictor : public xgboost::Predictor {
auto& offsets = *out_offsets;
size_t n_shards = devices_.Size();
offsets.resize(n_shards + 2);
size_t rows_per_shard = dh::DivRoundUp(batch_size, n_shards);
size_t rows_per_shard = common::DivRoundUp(batch_size, n_shards);
for (size_t shard = 0; shard < devices_.Size(); ++shard) {
size_t n_rows = std::min(batch_size, shard * rows_per_shard);
offsets[shard] = batch_offset + n_rows * n_classes;
@ -284,7 +284,7 @@ class GPUPredictor : public xgboost::Predictor {
dh::safe_cuda(cudaSetDevice(device_));
const int BLOCK_THREADS = 128;
size_t num_rows = batch.offset.DeviceSize(device_) - 1;
const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
int shared_memory_bytes = static_cast<int>
(sizeof(float) * num_features * BLOCK_THREADS);

View File

@ -170,7 +170,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size());
int constexpr kBlockThreads = 256;
const int n_grids = static_cast<int>(
dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
output_buffer_bits_, input_buffer_bits_);
}
@ -227,7 +227,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
int constexpr kBlockThreads = 256;
const int n_grids = static_cast<int>(
dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);
QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
@ -328,8 +328,8 @@ void FeatureInteractionConstraint::Split(
BitField right = s_node_constraints_[right_id];
dim3 const block3(16, 64, 1);
dim3 const grid3(dh::DivRoundUp(n_sets_, 16),
dh::DivRoundUp(s_fconstraints_.size(), 64));
dim3 const grid3(common::DivRoundUp(n_sets_, 16),
common::DivRoundUp(s_fconstraints_.size(), 64));
RestoreFeatureListFromSetsKernel<<<grid3, block3>>>
(feature_buffer_,
feature_id,
@ -339,7 +339,7 @@ void FeatureInteractionConstraint::Split(
s_sets_ptr_);
int constexpr kBlockThreads = 256;
const int n_grids = static_cast<int>(dh::DivRoundUp(node.Size(), kBlockThreads));
const int n_grids = static_cast<int>(common::DivRoundUp(node.Size(), kBlockThreads));
InteractionConstraintSplitKernel<<<n_grids, kBlockThreads>>>
(feature_buffer_,
feature_id,

View File

@ -76,7 +76,7 @@ static const int kNoneKey = -100;
*/
template <int BLKDIM_L1L3 = 256>
int ScanTempBufferSize(int size) {
int num_blocks = dh::DivRoundUp(size, BLKDIM_L1L3);
int num_blocks = common::DivRoundUp(size, BLKDIM_L1L3);
return num_blocks;
}
@ -250,7 +250,7 @@ void ReduceScanByKey(common::Span<GradientPair> sums,
common::Span<GradientPair> tmpScans,
common::Span<int> tmpKeys,
common::Span<const int> colIds, NodeIdT nodeStart) {
int nBlks = dh::DivRoundUp(size, BLKDIM_L1L3);
int nBlks = common::DivRoundUp(size, BLKDIM_L1L3);
cudaMemset(sums.data(), 0, nUniqKeys * nCols * sizeof(GradientPair));
CubScanByKeyL1<BLKDIM_L1L3>
<<<nBlks, BLKDIM_L1L3>>>(scans, vals, instIds, tmpScans, tmpKeys, keys,
@ -448,7 +448,7 @@ void ArgMaxByKey(common::Span<ExactSplitCandidate> nodeSplits,
dh::FillConst<ExactSplitCandidate, BLKDIM, ITEMS_PER_THREAD>(
*(devices.begin()), nodeSplits.data(), nUniqKeys,
ExactSplitCandidate());
int nBlks = dh::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM);
int nBlks = common::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM);
switch (algo) {
case kAbkGmem:
AtomicArgMaxByKeyGmem<<<nBlks, BLKDIM>>>(
@ -793,11 +793,11 @@ class GPUMaker : public TreeUpdater {
const int BlkDim = 256;
const int ItemsPerThread = 4;
// assign default node ids first
int nBlks = dh::DivRoundUp(n_rows_, BlkDim);
int nBlks = common::DivRoundUp(n_rows_, BlkDim);
FillDefaultNodeIds<<<nBlks, BlkDim>>>(node_assigns_per_inst_.data(),
nodes_.data(), n_rows_);
// evaluate the correct child indices of non-missing values next
nBlks = dh::DivRoundUp(n_vals_, BlkDim * ItemsPerThread);
nBlks = common::DivRoundUp(n_vals_, BlkDim * ItemsPerThread);
AssignNodeIds<<<nBlks, BlkDim>>>(
node_assigns_per_inst_.data(), nodeLocations_.Current(),
nodeAssigns_.Current(), instIds_.Current(), nodes_.data(),
@ -823,7 +823,7 @@ class GPUMaker : public TreeUpdater {
void MarkLeaves() {
const int BlkDim = 128;
int nBlks = dh::DivRoundUp(maxNodes_, BlkDim);
int nBlks = common::DivRoundUp(maxNodes_, BlkDim);
MarkLeavesKernel<<<nBlks, BlkDim>>>(nodes_.data(), maxNodes_);
}
};

View File

@ -480,8 +480,8 @@ __global__ void CompressBinEllpackKernel(
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistCutMatrix::cut
const uint32_t* __restrict__ cut_rows, // HistCutMatrix::row_ptrs
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
@ -593,7 +593,7 @@ struct DeviceShard {
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist;
/*! \brief row_ptr form HistCutMatrix. */
/*! \brief row_ptr form HistogramCuts. */
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
@ -654,10 +654,10 @@ struct DeviceShard {
}
void InitCompressedData(
const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense);
const common::HistogramCuts& hmat, size_t row_stride, bool is_dense);
void CreateHistIndices(
const SparsePage &row_batch, const common::HistCutMatrix &hmat,
const SparsePage &row_batch, const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state, int rows_per_batch);
~DeviceShard() {
@ -718,7 +718,7 @@ struct DeviceShard {
// Work out cub temporary memory requirement
GPUTrainingParam gpu_param(param);
DeviceSplitCandidateReduceOp op(gpu_param);
size_t temp_storage_bytes;
size_t temp_storage_bytes = 0;
DeviceSplitCandidate*dummy = nullptr;
cub::DeviceReduce::Reduce(
nullptr, temp_storage_bytes, dummy,
@ -806,7 +806,7 @@ struct DeviceShard {
const int items_per_thread = 8;
const int block_threads = 256;
const int grid_size = static_cast<int>(
dh::DivRoundUp(n_elements, items_per_thread * block_threads));
common::DivRoundUp(n_elements, items_per_thread * block_threads));
if (grid_size <= 0) {
return;
}
@ -1106,9 +1106,9 @@ struct DeviceShard {
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) {
n_bins = hmat.row_ptr.back();
int null_gidx_value = hmat.row_ptr.back();
const common::HistogramCuts &hmat, size_t row_stride, bool is_dense) {
n_bins = hmat.Ptrs().back();
int null_gidx_value = hmat.Ptrs().back();
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
@ -1121,14 +1121,14 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
&gpair, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&feature_segments, hmat.row_ptr.size(),
&gidx_fvalue_map, hmat.cut.size(),
&min_fvalue, hmat.min_val.size(),
&feature_segments, hmat.Ptrs().size(),
&gidx_fvalue_map, hmat.Values().size(),
&min_fvalue, hmat.MinValues().size(),
&monotone_constraints, param.monotone_constraints.size());
dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut);
dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val);
dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr);
dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
node_sum_gradients.resize(max_nodes);
@ -1153,26 +1153,26 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
// check if we can use shared memory for building histograms
// (assuming atleast we need 2 CTAs per SM to maintain decent latency
// hiding)
auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back();
auto max_smem = dh::MaxSharedMemory(device_id);
if (histogram_size <= max_smem) {
use_shared_memory_histograms = true;
}
// Init histogram
hist.Init(device_id, hmat.NumBins());
hist.Init(device_id, hmat.Ptrs().back());
}
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(
const SparsePage &row_batch,
const common::HistCutMatrix &hmat,
const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state,
int rows_per_batch) {
// Has any been allocated for me in this batch?
if (!device_row_state.rows_to_process_from_batch) return;
unsigned int null_gidx_value = hmat.row_ptr.back();
unsigned int null_gidx_value = hmat.Ptrs().back();
size_t row_stride = this->ellpack_matrix.row_stride;
const auto &offset_vec = row_batch.offset.ConstHostVector();
@ -1184,7 +1184,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
@ -1216,8 +1216,8 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
(entries_d.data().get(), data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
dh::DivRoundUp(row_stride, block3.y), 1);
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1);
CompressBinEllpackKernel<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
@ -1361,13 +1361,13 @@ class GPUHistMakerSpecialised {
});
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistCutMatrix
// Create the quantile sketches for the dmatrix and initialize HistogramCuts
size_t row_stride = common::DeviceSketch(param_, *learner_param_,
hist_maker_param_.gpu_batch_nrows,
dmat, &hmat_);
monitor_.StopCuda("Quantiles");
n_bins_ = hmat_.row_ptr.back();
n_bins_ = hmat_.Ptrs().back();
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
@ -1476,7 +1476,7 @@ class GPUHistMakerSpecialised {
}
TrainParam param_; // NOLINT
common::HistCutMatrix hmat_; // NOLINT
common::HistogramCuts hmat_; // NOLINT
MetaInfo* info_; // NOLINT
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT

View File

@ -247,15 +247,15 @@ int32_t QuantileHistMaker::Builder::FindSplitCond(int32_t nid,
// Categorize member rows
const bst_uint fid = node.SplitIndex();
const bst_float split_pt = node.SplitCond();
const uint32_t lower_bound = gmat.cut.row_ptr[fid];
const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1];
const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
int32_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (uint32_t i = lower_bound; i < upper_bound; ++i) {
if (split_pt == gmat.cut.cut[i]) {
if (split_pt == gmat.cut.Values()[i]) {
split_cond = static_cast<int32_t>(i);
}
}
@ -533,7 +533,7 @@ void QuantileHistMaker::Builder::BuildHistsBatch(const std::vector<ExpandEntry>&
perf_monitor.TickStart();
const size_t block_size_rows = 256;
const size_t nthread = static_cast<size_t>(this->nthread_);
const size_t nbins = gmat.cut.row_ptr.back();
const size_t nbins = gmat.cut.Ptrs().back();
const size_t hist_size = 2 * nbins;
hist_buffers->resize(nodes.size());
@ -856,8 +856,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
}
}
#pragma omp parallel for schedule(guided)
for (int32_t k = 0; k < tasks_elem.size(); ++k) {
#pragma omp parallel for schedule(guided)
for (omp_ulong k = 0; k < tasks_elem.size(); ++k) {
const RowSetCollection::Elem rowset = tasks_elem[k];
if (rowset.begin != nullptr && rowset.end != nullptr && rowset.node_id != -1) {
const size_t nrows = rowset.Size();
@ -909,7 +909,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
// clear local prediction cache
leaf_value_cache_.clear();
// initialize histogram collection
uint32_t nbins = gmat.cut.row_ptr.back();
uint32_t nbins = gmat.cut.Ptrs().back();
hist_.Init(nbins);
hist_buff_.Init(nbins);
@ -999,7 +999,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
const size_t ncol = info.num_col_;
const size_t nnz = info.num_nonzero_;
// number of discrete bins for feature 0
const uint32_t nbins_f0 = gmat.cut.row_ptr[1] - gmat.cut.row_ptr[0];
const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0];
if (nrow * ncol == nnz) {
// dense data with zero-based indexing
data_layout_ = kDenseDataZeroBased;
@ -1029,7 +1029,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
choose the column that has a least positive number of discrete bins.
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
uint32_t min_nbins_per_feature = 0;
for (bst_uint i = 0; i < nfeature; ++i) {
@ -1079,8 +1079,8 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch(
// partial results
std::vector<std::pair<SplitEntry, SplitEntry>> splits(tasks.size());
// parallel enumeration
#pragma omp parallel for schedule(guided)
for (int32_t i = 0; i < tasks.size(); ++i) {
#pragma omp parallel for schedule(guided)
for (omp_ulong i = 0; i < tasks.size(); ++i) {
// node_idx : offset within `nodes` list
const int32_t node_idx = tasks[i].first;
const size_t fid = tasks[i].second;
@ -1098,7 +1098,7 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch(
// reduce needed part of a hist here to have it in cache before enumeration
if (!rabit::IsDistributed()) {
const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
const size_t ibegin = 2 * cut_ptr[fid];
const size_t iend = 2 * cut_ptr[fid + 1];
ReduceHistograms(hist_data, sibling_hist_data, parent_hist_data, ibegin, iend, node_idx,
@ -1179,8 +1179,8 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step,
CHECK(d_step == +1 || d_step == -1);
// aliases
const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
const std::vector<bst_float>& cut_val = gmat.cut.cut;
const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
const std::vector<bst_float>& cut_val = gmat.cut.Values();
// statistics on both sides of split
GradStats c;
@ -1239,7 +1239,7 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step,
if (i == imin) {
// for leftmost bin, left bound is the smallest feature value
split_pt = gmat.cut.min_val[fid];
split_pt = gmat.cut.MinValues()[fid];
} else {
split_pt = cut_val[i - 1];
}

View File

@ -33,7 +33,6 @@ namespace common {
}
namespace tree {
using xgboost::common::HistCutMatrix;
using xgboost::common::GHistIndexMatrix;
using xgboost::common::GHistIndexBlockMatrix;
using xgboost::common::GHistIndexRow;

View File

@ -53,10 +53,10 @@ TEST(c_api, XGDMatrixCreateFromMat_omp) {
ASSERT_EQ(info.num_nonzero_, num_cols * row - num_missing);
for (const auto &batch : (*dmat)->GetRowBatches()) {
for (int i = 0; i < batch.Size(); i++) {
for (size_t i = 0; i < batch.Size(); i++) {
auto inst = batch[i];
for (int j = 0; i < inst.size(); i++) {
ASSERT_EQ(inst[j].fvalue, 1.5);
for (auto e : inst) {
ASSERT_EQ(e.fvalue, 1.5);
}
}
}

View File

@ -7,6 +7,7 @@
namespace xgboost {
namespace common {
TEST(DenseColumn, Test) {
auto dmat = CreateDMatrix(100, 10, 0.0);
GHistIndexMatrix gmat;
@ -17,7 +18,7 @@ TEST(DenseColumn, Test) {
for (auto i = 0ull; i < (*dmat)->Info().num_row_; i++) {
for (auto j = 0ull; j < (*dmat)->Info().num_col_; j++) {
auto col = column_matrix.GetColumn(j);
EXPECT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
ASSERT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
col.GetGlobalBinIdx(i));
}
}
@ -33,7 +34,7 @@ TEST(SparseColumn, Test) {
auto col = column_matrix.GetColumn(0);
ASSERT_EQ(col.Size(), gmat.index.size());
for (auto i = 0ull; i < col.Size(); i++) {
EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
col.GetGlobalBinIdx(i));
}
delete dmat;

View File

@ -28,7 +28,7 @@ TEST(CompressedIterator, Test) {
CompressedIterator<int> ci(buffer.data(), alphabet_size);
std::vector<int> output(input.size());
for (int i = 0; i < input.size(); i++) {
for (size_t i = 0; i < input.size(); i++) {
output[i] = ci[i];
}
@ -38,12 +38,12 @@ TEST(CompressedIterator, Test) {
std::vector<unsigned char> buffer2(
CompressedBufferWriter::CalculateBufferSize(input.size(),
alphabet_size));
for (int i = 0; i < input.size(); i++) {
for (size_t i = 0; i < input.size(); i++) {
cbw.WriteSymbol(buffer2.data(), input[i], i);
}
CompressedIterator<int> ci2(buffer.data(), alphabet_size);
std::vector<int> output2(input.size());
for (int i = 0; i < input.size(); i++) {
for (size_t i = 0; i < input.size(); i++) {
output2[i] = ci2[i];
}
ASSERT_TRUE(input == output2);

View File

@ -48,11 +48,11 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
int gpu_batch_nrows = 0;
// find quantiles on the CPU
HistCutMatrix hmat_cpu;
hmat_cpu.Init((*dmat).get(), p.max_bin);
HistogramCuts hmat_cpu;
hmat_cpu.Build((*dmat).get(), p.max_bin);
// find the cuts on the GPU
HistCutMatrix hmat_gpu;
HistogramCuts hmat_gpu;
size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows,
dmat->get(), &hmat_gpu);
@ -69,12 +69,12 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
// compare the cuts
double eps = 1e-2;
ASSERT_EQ(hmat_gpu.min_val.size(), num_cols);
ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1);
ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows);
ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols);
ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1);
ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size());
ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.Values().size(); ++i) {
ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
}
delete dmat;

View File

@ -9,15 +9,7 @@
namespace xgboost {
namespace common {
class HistCutMatrixMock : public HistCutMatrix {
public:
size_t SearchGroupIndFromBaseRow(
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
return HistCutMatrix::SearchGroupIndFromBaseRow(group_ptr, base_rowid);
}
};
TEST(HistCutMatrix, SearchGroupInd) {
TEST(CutsBuilder, SearchGroupInd) {
size_t constexpr kNumGroups = 4;
size_t constexpr kNumRows = 17;
size_t constexpr kNumCols = 15;
@ -34,18 +26,102 @@ TEST(HistCutMatrix, SearchGroupInd) {
p_mat->Info().SetInfo(
"group", group.data(), DataType::kUInt32, kNumGroups);
HistCutMatrixMock hmat;
HistogramCuts hmat;
size_t group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 0);
size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
ASSERT_EQ(group_ind, 0);
group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 5);
group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
ASSERT_EQ(group_ind, 2);
EXPECT_ANY_THROW(hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 17));
EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
delete pp_mat;
}
namespace {
class SparseCutsWrapper : public SparseCuts {
public:
std::vector<uint32_t> const& ColPtrs() const { return p_cuts_->Ptrs(); }
std::vector<float> const& ColValues() const { return p_cuts_->Values(); }
};
} // anonymous namespace
TEST(SparseCuts, SingleThreadedBuild) {
size_t constexpr kRows = 267;
size_t constexpr kCols = 31;
size_t constexpr kBins = 256;
// Dense matrix.
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
DMatrix* p_fmat = (*pp_mat).get();
common::GHistIndexMatrix hmat;
hmat.Init(p_fmat, kBins);
HistogramCuts cuts;
SparseCuts indices(&cuts);
auto const& page = *(p_fmat->GetColumnBatches().begin());
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size());
ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
ASSERT_EQ(hmat.cut.Values(), cuts.Values());
ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
delete pp_mat;
}
TEST(SparseCuts, MultiThreadedBuild) {
size_t constexpr kRows = 17;
size_t constexpr kCols = 15;
size_t constexpr kBins = 255;
omp_ulong ori_nthreads = omp_get_max_threads();
omp_set_num_threads(16);
auto Compare =
#if defined(_MSC_VER) // msvc fails to capture
[kBins](DMatrix* p_fmat) {
#else
[](DMatrix* p_fmat) {
#endif
HistogramCuts threaded_container;
SparseCuts threaded_indices(&threaded_container);
threaded_indices.Build(p_fmat, kBins);
HistogramCuts container;
SparseCuts indices(&container);
auto const& page = *(p_fmat->GetColumnBatches().begin());
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size());
ASSERT_EQ(container.Values().size(), threaded_container.Values().size());
for (uint32_t i = 0; i < container.Ptrs().size(); ++i) {
ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]);
}
for (uint32_t i = 0; i < container.Values().size(); ++i) {
ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]);
}
};
{
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
DMatrix* p_fmat = (*pp_mat).get();
Compare(p_fmat);
delete pp_mat;
}
{
auto pp_mat = CreateDMatrix(kRows, kCols, 0.0001);
DMatrix* p_fmat = (*pp_mat).get();
Compare(p_fmat);
delete pp_mat;
}
omp_set_num_threads(ori_nthreads);
}
} // namespace common
} // namespace xgboost

View File

@ -53,8 +53,8 @@ TEST(ColumnSampler, Test) {
TEST(ColumnSampler, ThreadSynchronisation) {
const int64_t num_threads = 100;
int n = 128;
int iterations = 10;
int levels = 5;
size_t iterations = 10;
size_t levels = 5;
std::vector<int> reference_result;
bool success =
true; // Cannot use google test asserts in multithreaded region

View File

@ -310,7 +310,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(first.size(), 4);
ASSERT_EQ(first.data(), arr);
for (size_t i = 0; i < first.size(); ++i) {
for (int64_t i = 0; i < first.size(); ++i) {
ASSERT_EQ(first[i], arr[i]);
}
@ -329,7 +329,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(last.size(), 4);
ASSERT_EQ(last.data(), arr + 12);
for (size_t i = 0; i < last.size(); ++i) {
for (int64_t i = 0; i < last.size(); ++i) {
ASSERT_EQ(last[i], arr[i+12]);
}
@ -348,7 +348,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(first.size(), 4);
ASSERT_EQ(first.data(), s.data());
for (size_t i = 0; i < first.size(); ++i) {
for (int64_t i = 0; i < first.size(); ++i) {
ASSERT_EQ(first[i], s[i]);
}
@ -368,7 +368,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(last.size(), 4);
ASSERT_EQ(last.data(), s.data() + 12);
for (size_t i = 0; i < last.size(); ++i) {
for (int64_t i = 0; i < last.size(); ++i) {
ASSERT_EQ(s[12 + i], last[i]);
}

View File

@ -50,7 +50,7 @@ TEST(SparsePage, PushCSC) {
inst = page[1];
ASSERT_EQ(inst.size(), 6);
std::vector<size_t> indices_sol {1, 2, 3};
for (size_t i = 0; i < inst.size(); ++i) {
for (int64_t i = 0; i < inst.size(); ++i) {
ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
}
}

View File

@ -21,13 +21,13 @@ TEST(cpu_predictor, Test) {
HostDeviceVector<float> out_predictions;
cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0);
std::vector<float>& out_predictions_h = out_predictions.HostVector();
for (int i = 0; i < out_predictions.Size(); i++) {
for (size_t i = 0; i < out_predictions.Size(); i++) {
ASSERT_EQ(out_predictions_h[i], 1.5);
}
// Test predict instance
auto &batch = *(*dmat)->GetRowBatches().begin();
for (int i = 0; i < batch.Size(); i++) {
for (size_t i = 0; i < batch.Size(); i++) {
std::vector<float> instance_out_predictions;
cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
ASSERT_EQ(instance_out_predictions[0], 1.5);

View File

@ -53,27 +53,43 @@ TEST(GpuHist, DeviceHistogram) {
}
}
};
}
namespace {
class HistogramCutsWrapper : public common::HistogramCuts {
public:
using SuperT = common::HistogramCuts;
void SetValues(std::vector<float> cuts) {
SuperT::cut_values_ = cuts;
}
void SetPtrs(std::vector<uint32_t> ptrs) {
SuperT::cut_ptrs_ = ptrs;
}
void SetMins(std::vector<float> mins) {
SuperT::min_vals_ = mins;
}
};
} // anonymous namespace
template <typename GradientSumT>
void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
bst_float sparsity=0) {
auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
common::HistCutMatrix cmat;
cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
// 24 cut fields, 3 cut fields for each feature (column).
cmat.cut = {0.30f, 0.67f, 1.64f,
cmat.SetValues({0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f};
0.26f, 0.71f, 1.83f});
cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
@ -241,20 +257,20 @@ TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPair>(true);
}
common::HistCutMatrix GetHostCutMatrix () {
common::HistCutMatrix cmat;
cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
HistogramCutsWrapper GetHostCutMatrix () {
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
// 24 cut fields, 3 cut fields for each feature (column).
// Each row of the cut represents the cuts for a data column.
cmat.cut = {0.30f, 0.67f, 1.64f,
cmat.SetValues({0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f};
0.26f, 0.71f, 1.83f});
return cmat;
}
@ -293,21 +309,21 @@ TEST(GpuHist, EvaluateSplits) {
shard->node_sum_gradients = {{6.4f, 12.8f}};
// Initialize DeviceShard::cut
common::HistCutMatrix cmat = GetHostCutMatrix();
auto cmat = GetHostCutMatrix();
// Copy cut matrix to device.
shard->ba.Allocate(0,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->feature_segments), cmat.Ptrs().size(),
&(shard->min_fvalue), cmat.MinValues().size(),
&(shard->gidx_fvalue_map), 24,
&(shard->monotone_constraints), kNCols);
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.Ptrs());
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.Values());
dh::CopyVectorToDeviceSpan(shard->monotone_constraints,
param.monotone_constraints);
shard->ellpack_matrix.feature_segments = shard->feature_segments;
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.MinValues());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
// Initialize DeviceShard::hist

View File

@ -13,7 +13,7 @@ namespace xgboost {
namespace tree {
TEST(Updater, Prune) {
int constexpr kNRows = 32, kNCols = 16;
int constexpr kNCols = 16;
std::vector<std::pair<std::string, std::string>> cfg;
cfg.emplace_back(std::pair<std::string, std::string>(

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2018 by Contributors
* Copyright 2018-2019 by Contributors
*/
#include "../helpers.h"
#include "../../../src/tree/param.h"
@ -46,23 +46,25 @@ class QuantileHistMock : public QuantileHistMaker {
const size_t num_row = p_fmat->Info().num_row_;
const size_t num_col = p_fmat->Info().num_col_;
/* Validate HistCutMatrix */
ASSERT_EQ(gmat.cut.row_ptr.size(), num_col + 1);
ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
for (size_t fid = 0; fid < num_col; ++fid) {
// Each feature must have at least one quantile point (cut)
const size_t ibegin = gmat.cut.row_ptr[fid];
const size_t iend = gmat.cut.row_ptr[fid + 1];
ASSERT_LT(ibegin, iend);
const size_t ibegin = gmat.cut.Ptrs()[fid];
const size_t iend = gmat.cut.Ptrs()[fid + 1];
// Ordered, but empty feature is allowed.
ASSERT_LE(ibegin, iend);
for (size_t i = ibegin; i < iend - 1; ++i) {
// Quantile points must be sorted in ascending order
// No duplicates allowed
ASSERT_LT(gmat.cut.cut[i], gmat.cut.cut[i + 1]);
ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
<< "ibegin: " << ibegin << ", "
<< "iend: " << iend;
}
}
/* Validate GHistIndexMatrix */
ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
gmat.cut.row_ptr.back());
gmat.cut.Ptrs().back());
for (const auto& batch : p_fmat->GetRowBatches()) {
for (size_t i = 0; i < batch.Size(); ++i) {
const size_t rid = batch.base_rowid + i;
@ -71,20 +73,20 @@ class QuantileHistMock : public QuantileHistMaker {
ASSERT_LT(gmat_row_offset, gmat.index.size());
SparsePage::Inst inst = batch[i];
ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
for (size_t j = 0; j < inst.size(); ++j) {
for (int64_t j = 0; j < inst.size(); ++j) {
// Each entry of GHistIndexMatrix represents a bin ID
const size_t bin_id = gmat.index[gmat_row_offset + j];
const size_t fid = inst[j].index;
// The bin ID must correspond to correct feature
ASSERT_GE(bin_id, gmat.cut.row_ptr[fid]);
ASSERT_LT(bin_id, gmat.cut.row_ptr[fid + 1]);
ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
// The bin ID must correspond to a region between two
// suitable quantile points
ASSERT_LT(inst[j].fvalue, gmat.cut.cut[bin_id]);
if (bin_id > gmat.cut.row_ptr[fid]) {
ASSERT_GE(inst[j].fvalue, gmat.cut.cut[bin_id - 1]);
ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
if (bin_id > gmat.cut.Ptrs()[fid]) {
ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
} else {
ASSERT_GE(inst[j].fvalue, gmat.cut.min_val[fid]);
ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
}
}
}
@ -106,11 +108,12 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<std::vector<uint8_t>> hist_is_init;
std::vector<ExpandEntry> nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)};
BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, gpair, &hist_buffers, &hist_is_init);
RealImpl::InitNewNode(nid, gmat, gpair, fmat, const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
RealImpl::InitNewNode(nid, gmat, gpair, fmat,
const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers);
// Check if number of histogram bins is correct
ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back());
ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
std::vector<GradientPairPrecise> histogram_expected(hist_[nid].size());
// Compute the correct histogram (histogram_expected)
@ -126,7 +129,7 @@ class QuantileHistMock : public QuantileHistMaker {
}
// Now validate the computed histogram returned by BuildHist
for (size_t i = 0; i < hist_[nid].size(); ++i) {
for (int64_t i = 0; i < hist_[nid].size(); ++i) {
GradientPairPrecise sol = histogram_expected[i];
ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
@ -152,7 +155,8 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<std::vector<float*>> hist_buffers;
std::vector<std::vector<uint8_t>> hist_is_init;
BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init);
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat),
const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
/* Compute correct split (best_split) using the computed histogram */
@ -178,8 +182,8 @@ class QuantileHistMock : public QuantileHistMaker {
size_t best_split_feature = std::numeric_limits<size_t>::max();
// Enumerate all features
for (size_t fid = 0; fid < num_feature; ++fid) {
const size_t bin_id_min = gmat.cut.row_ptr[fid];
const size_t bin_id_max = gmat.cut.row_ptr[fid + 1];
const size_t bin_id_min = gmat.cut.Ptrs()[fid];
const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
// Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
// choice of thresholds for feature fid
for (size_t split_thresh = bin_id_min;
@ -217,7 +221,7 @@ class QuantileHistMock : public QuantileHistMaker {
EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
delete dmat;
}