Testing hist_util (#5251)
* Rank tests
* Remove categorical split specialisation
* Extend tests to multiple features, switch to WQSketch
* Add tests for SparseCuts
* Add external memory quantile tests, fix some existing tests
This commit is contained in:
parent
911a902835
commit
24ad9dec0b
@@ -70,7 +70,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
                                    bool const use_group_ind,
                                    uint32_t beg_col, uint32_t end_col,
                                    uint32_t thread_id) {
-  using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
   CHECK_GE(end_col, beg_col);
   constexpr float kFactor = 8;

@@ -80,7 +79,7 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,

   for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
     // Using a local variable makes things easier, but at the cost of memory trashing.
-    WXQSketch sketch;
+    WQSketch sketch;
     common::Span<xgboost::Entry const> const column = page[col_id];
     uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
                                      max_num_bins);

@@ -104,18 +103,18 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
       sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
     }

-    WXQSketch::SummaryContainer out_summary;
+    WQSketch::SummaryContainer out_summary;
     sketch.GetSummary(&out_summary);
-    WXQSketch::SummaryContainer summary;
-    summary.Reserve(n_bins);
-    summary.SetPrune(out_summary, n_bins);
+    WQSketch::SummaryContainer summary;
+    summary.Reserve(n_bins + 1);
+    summary.SetPrune(out_summary, n_bins + 1);

     // Can be use data[1] as the min values so that we don't need to
     // store another array?
     float mval = summary.data[0].value;
     p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5);

-    this->AddCutPoint(summary);
+    this->AddCutPoint(summary, max_num_bins);

     bst_float cpt = (summary.size > 0) ?
                     summary.data[summary.size - 1].value :
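Aside on the `n_bins + 1` above: entry 0 of the pruned summary is consumed as the column minimum rather than as a cut, so pruning to one extra entry leaves a full n_bins candidates for cut points. A minimal standalone sketch of that accounting, with a plain sorted vector standing in for WQSketch::SummaryContainer (illustrative names only, not the xgboost types):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  int const n_bins = 4;
  // Stand-in for a pruned summary: n_bins + 1 sorted quantile values.
  std::vector<float> summary = {0.1f, 2.0f, 3.5f, 7.0f, 9.0f};

  // Entry 0 becomes the per-column minimum, shifted down as in the diff.
  float mval = summary[0];
  float min_val = mval - (std::fabs(mval) + 1e-5f);

  // Entries 1..n_bins are the candidate cut points handed to AddCutPoint.
  std::vector<float> cuts(summary.begin() + 1, summary.end());
  std::printf("min=%f, %zu cuts\n", min_val, cuts.size());  // 4 cuts
  return 0;
}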
@@ -234,7 +233,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {

   // safe factor for better accuracy
   constexpr int kFactor = 8;
-  std::vector<WXQSketch> sketchs;
+  std::vector<WQSketch> sketchs;

   const int nthread = omp_get_max_threads();

@@ -292,34 +291,34 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
 }

 void DenseCuts::Init
-(std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
+(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
   monitor_.Start(__func__);
-  std::vector<WXQSketch>& sketchs = *in_sketchs;
+  std::vector<WQSketch>& sketchs = *in_sketchs;
   constexpr int kFactor = 8;
   // gather the histogram data
-  rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
-  std::vector<WXQSketch::SummaryContainer> summary_array;
+  rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
+  std::vector<WQSketch::SummaryContainer> summary_array;
   summary_array.resize(sketchs.size());
   for (size_t i = 0; i < sketchs.size(); ++i) {
-    WXQSketch::SummaryContainer out;
+    WQSketch::SummaryContainer out;
     sketchs[i].GetSummary(&out);
     summary_array[i].Reserve(max_num_bins * kFactor);
     summary_array[i].SetPrune(out, max_num_bins * kFactor);
   }
   CHECK_EQ(summary_array.size(), in_sketchs->size());
-  size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
+  size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
   // TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint
   // we need to move this allreduce before loadcheckpoint call in future
   sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
   p_cuts_->min_vals_.resize(sketchs.size());

   for (size_t fid = 0; fid < summary_array.size(); ++fid) {
-    WXQSketch::SummaryContainer a;
-    a.Reserve(max_num_bins);
-    a.SetPrune(summary_array[fid], max_num_bins);
+    WQSketch::SummaryContainer a;
+    a.Reserve(max_num_bins + 1);
+    a.SetPrune(summary_array[fid], max_num_bins + 1);
     const bst_float mval = a.data[0].value;
     p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5);
-    AddCutPoint(a);
+    AddCutPoint(a, max_num_bins);
     // push a value that is greater than anything
     const bst_float cpt
         = (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_[fid];
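The Reserve/SetPrune pair above shrinks each all-reduced summary to max_num_bins + 1 entries before the cut points are extracted. A toy model of that pruning step, selecting entries at evenly spaced ranks from a sorted summary (illustrative only; the real SetPrune additionally maintains rmin/rmax error bounds):

#include <cstdio>
#include <vector>

// Toy stand-in for SummaryContainer::SetPrune: keep n evenly rank-spaced
// entries from a sorted summary, always including both endpoints.
std::vector<float> Prune(std::vector<float> const& in, int n) {
  std::vector<float> out;
  for (int i = 0; i < n; ++i) {
    size_t rank = static_cast<size_t>(i) * (in.size() - 1) / (n - 1);
    if (out.empty() || in[rank] > out.back()) out.push_back(in[rank]);
  }
  return out;
}

int main() {
  std::vector<float> merged;
  for (int i = 0; i <= 64; ++i) merged.push_back(i / 64.0f);
  auto pruned = Prune(merged, 8 + 1);  // max_num_bins + 1, as in the diff
  std::printf("kept %zu of %zu entries\n", pruned.size(), merged.size());
  return 0;
}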
@@ -25,9 +25,9 @@
 namespace xgboost {
 namespace common {

-using WXQSketch = DenseCuts::WXQSketch;
+using WQSketch = DenseCuts::WQSketch;

-__global__ void FindCutsK(WXQSketch::Entry* __restrict__ cuts,
+__global__ void FindCutsK(WQSketch::Entry* __restrict__ cuts,
                           const bst_float* __restrict__ data,
                           const float* __restrict__ cum_weights,
                           int nsamples,

@@ -52,7 +52,7 @@ __global__ void FindCutsK(WQSketch::Entry* __restrict__ cuts,
   // repeated values will be filtered out on the CPU
   bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0;
   bst_float rmax = cum_weights[isample];
-  cuts[icut] = WXQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]);
+  cuts[icut] = WQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]);
 }

 // predictate for thrust filtering that returns true if the element is not a NaN
@@ -97,7 +97,7 @@ __global__ void UnpackFeaturesK(float* __restrict__ fvalues,
  * across distinct rows.
  */
 struct SketchContainer {
-  std::vector<DenseCuts::WXQSketch> sketches_;  // NOLINT
+  std::vector<DenseCuts::WQSketch> sketches_;  // NOLINT
   std::vector<std::mutex> col_locks_;  // NOLINT
   static constexpr int kOmpNumColsParallelizeLimit = 1000;

@@ -245,11 +245,11 @@ class GPUSketcher {
     if (n_cuts_ > n_unique) {
       float* weights2_ptr = weights2_.data().get();
       float* fvalues_ptr = fvalues_cur_.data().get();
-      WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
+      WQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
       dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
         bst_float rmax = weights2_ptr[i];
         bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0;
-        cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
+        cuts_ptr[i] = WQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
       });
     } else if (n_cuts_cur_[icol] > 0) {
       // if more elements than cuts: use binary search on cumulative weights
@@ -287,7 +287,7 @@ class GPUSketcher {
     constexpr int kFactor = 8;
     double eps = 1.0 / (kFactor * max_bin_);
     size_t dummy_nlevel;
-    WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
+    WQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);

     // allocate necessary GPU buffers
     dh::safe_cuda(cudaSetDevice(device_));
@@ -425,7 +425,7 @@ class GPUSketcher {
 #pragma omp parallel for default(none) schedule(static) \
 if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit)  // NOLINT
     for (int icol = 0; icol < num_cols_; ++icol) {
-      WXQSketch::SummaryContainer summary;
+      WQSketch::SummaryContainer summary;
       summary.Reserve(n_cuts_);
       summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);

@@ -450,8 +450,8 @@ if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit)  // NOLINT
   dh::device_vector<bst_float> fvalues_{};
   dh::device_vector<bst_float> feature_weights_{};
   dh::device_vector<bst_float> fvalues_cur_{};
-  dh::device_vector<WXQSketch::Entry> cuts_d_{};
-  thrust::host_vector<WXQSketch::Entry> cuts_h_{};
+  dh::device_vector<WQSketch::Entry> cuts_d_{};
+  thrust::host_vector<WQSketch::Entry> cuts_h_{};
   dh::device_vector<bst_float> weights_{};
   dh::device_vector<bst_float> weights2_{};
   std::vector<size_t> n_cuts_cur_{};

@@ -101,6 +101,7 @@ struct SimpleArray {
 using GHistIndexRow = Span<uint32_t const>;

 // A CSC matrix representing histogram cuts, used in CPU quantile hist.
+// The cut values represent upper bounds of bins containing approximately equal numbers of elements
 class HistogramCuts {
   // Using friends to avoid creating a virtual class, since HistogramCuts is used as value
   // object in many places.
@@ -147,7 +148,9 @@ class HistogramCuts {

   size_t TotalBins() const { return cut_ptrs_.back(); }

-  BinIdx SearchBin(float value, uint32_t column_id) {
+  // Return the index of a cut point that is strictly greater than the input
+  // value, or the last available index if none exists
+  BinIdx SearchBin(float value, uint32_t column_id) const {
     auto beg = cut_ptrs_.at(column_id);
     auto end = cut_ptrs_.at(column_id + 1);
     auto it = std::upper_bound(cut_values_.cbegin() + beg, cut_values_.cbegin() + end, value);
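For reference, a self-contained model of the SearchBin contract documented above: std::upper_bound finds the first cut strictly greater than the value, clamped to the last cut of the column. The flat ptrs/values pair mimics the CSC layout of HistogramCuts; the names are illustrative, not the real class:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// CSC-style cut storage: column c owns values[ptrs[c], ptrs[c + 1]).
struct CutsModel {
  std::vector<uint32_t> ptrs;
  std::vector<float> values;

  uint32_t SearchBin(float value, uint32_t column_id) const {
    auto beg = ptrs.at(column_id);
    auto end = ptrs.at(column_id + 1);
    auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
    if (it == values.cbegin() + end) {
      it = values.cbegin() + end - 1;  // no greater cut: last available index
    }
    return static_cast<uint32_t>(it - values.cbegin());
  }
};

int main() {
  CutsModel cuts{{0, 3}, {1.0f, 2.0f, 3.0f}};  // one column, three bins
  assert(cuts.SearchBin(0.5f, 0) == 0);  // first cut strictly greater
  assert(cuts.SearchBin(1.5f, 0) == 1);
  assert(cuts.SearchBin(9.0f, 0) == 2);  // beyond the last cut: clamped
  return 0;
}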
@@ -171,7 +174,7 @@ class HistogramCuts {
  */
 class CutsBuilder {
  public:
-  using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
+  using WQSketch = common::WQuantileSketch<bst_float, bst_float>;

 protected:
   HistogramCuts* p_cuts_;
@@ -195,21 +198,12 @@ class CutsBuilder {
     return group_ind;
   }

-  void AddCutPoint(WXQSketch::SummaryContainer const& summary) {
-    if (summary.size > 1 && summary.size <= 16) {
-      /* specialized code categorial / ordinal data -- use midpoints */
-      for (size_t i = 1; i < summary.size; ++i) {
-        bst_float cpt = (summary.data[i].value + summary.data[i - 1].value) / 2.0f;
-        if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
-          p_cuts_->cut_values_.push_back(cpt);
-        }
-      }
-    } else {
-      for (size_t i = 2; i < summary.size; ++i) {
-        bst_float cpt = summary.data[i - 1].value;
-        if (i == 2 || cpt > p_cuts_->cut_values_.back()) {
-          p_cuts_->cut_values_.push_back(cpt);
-        }
-      }
+  void AddCutPoint(WQSketch::SummaryContainer const& summary, int max_bin) {
+    int required_cuts = std::min(static_cast<int>(summary.size), max_bin);
+    for (size_t i = 1; i < required_cuts; ++i) {
+      bst_float cpt = summary.data[i].value;
+      if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
+        p_cuts_->cut_values_.push_back(cpt);
+      }
     }
   }
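With the midpoint specialisation removed, categorical columns go through the same path as continuous ones; the duplicate check `cpt > cut_values_.back()` already collapses repeated summary values. A standalone model of the new AddCutPoint (plain vectors in place of the xgboost classes, illustrative only):

#include <algorithm>
#include <cstdio>
#include <vector>

// Model of the unified AddCutPoint: summary[0] is reserved for the column
// minimum; entries 1..min(size, max_bin)-1 become cuts, deduplicated.
std::vector<float> AddCutPoint(std::vector<float> const& summary, int max_bin) {
  std::vector<float> cut_values;
  int required_cuts = std::min(static_cast<int>(summary.size()), max_bin);
  for (int i = 1; i < required_cuts; ++i) {
    float cpt = summary[i];
    if (i == 1 || cpt > cut_values.back()) {
      cut_values.push_back(cpt);
    }
  }
  return cut_values;
}

int main() {
  // Categorical-looking summary: few distinct values, no midpoints needed.
  auto cuts = AddCutPoint({0.f, 1.f, 1.f, 2.f, 3.f}, 256);
  for (float c : cuts) std::printf("%g ", c);  // prints: 1 2 3
  return 0;
}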
@@ -250,7 +244,7 @@ class DenseCuts : public CutsBuilder {
       CutsBuilder(container) {
     monitor_.Init(__FUNCTION__);
   }
-  void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
+  void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins);
   void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
 };


@@ -5,6 +5,7 @@

 #include "../../../src/common/hist_util.h"
 #include "../helpers.h"
+#include "test_hist_util.h"

 namespace xgboost {
 namespace common {
@@ -152,14 +153,6 @@ TEST(CutsBuilder, SearchGroupInd) {
   delete pp_dmat;
 }

-namespace {
-class SparseCutsWrapper : public SparseCuts {
- public:
-  std::vector<uint32_t> const& ColPtrs() const { return p_cuts_->Ptrs(); }
-  std::vector<float> const& ColValues() const { return p_cuts_->Values(); }
-};
-}  // anonymous namespace
-
 TEST(SparseCuts, SingleThreadedBuild) {
   size_t constexpr kRows = 267;
   size_t constexpr kCols = 31;
@@ -235,5 +228,116 @@ TEST(SparseCuts, MultiThreadedBuild) {
   omp_set_num_threads(ori_nthreads);
 }

+TEST(hist_util, DenseCutsCategorical) {
+  int categorical_sizes[] = {2, 6, 8, 12};
+  int num_bins = 256;
+  int sizes[] = {25, 100, 1000};
+  for (auto n : sizes) {
+    for (auto num_categories : categorical_sizes) {
+      auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
+      std::vector<float> x_sorted(x);
+      std::sort(x_sorted.begin(), x_sorted.end());
+      auto dmat = GetDMatrixFromData(x, n, 1);
+      HistogramCuts cuts;
+      DenseCuts dense(&cuts);
+      dense.Build(&dmat, num_bins);
+      auto cuts_from_sketch = cuts.Values();
+      EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
+      EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
+      EXPECT_GE(cuts_from_sketch.back(), x_sorted.back());
+      EXPECT_EQ(cuts_from_sketch.size(), num_categories);
+    }
+  }
+}
+
+TEST(hist_util, DenseCutsAccuracyTest) {
+  int bin_sizes[] = {2, 16, 256, 512};
+  int sizes[] = {100, 1000, 1500};
+  int num_columns = 5;
+  for (auto num_rows : sizes) {
+    auto x = GenerateRandom(num_rows, num_columns);
+    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
+    for (auto num_bins : bin_sizes) {
+      HistogramCuts cuts;
+      DenseCuts dense(&cuts);
+      dense.Build(&dmat, num_bins);
+      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
+    }
+  }
+}
+
+TEST(hist_util, DenseCutsExternalMemory) {
+  int bin_sizes[] = {2, 16, 256, 512};
+  int sizes[] = {100, 1000, 1500};
+  int num_columns = 5;
+  for (auto num_rows : sizes) {
+    auto x = GenerateRandom(num_rows, num_columns);
+    dmlc::TemporaryDirectory tmpdir;
+    auto dmat =
+        GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir);
+    for (auto num_bins : bin_sizes) {
+      HistogramCuts cuts;
+      DenseCuts dense(&cuts);
+      dense.Build(dmat.get(), num_bins);
+      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
+    }
+  }
+}
+
+TEST(hist_util, SparseCutsAccuracyTest) {
+  int bin_sizes[] = {2, 16, 256, 512};
+  int sizes[] = {100, 1000, 1500};
+  int num_columns = 5;
+  for (auto num_rows : sizes) {
+    auto x = GenerateRandom(num_rows, num_columns);
+    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
+    for (auto num_bins : bin_sizes) {
+      HistogramCuts cuts;
+      SparseCuts sparse(&cuts);
+      sparse.Build(&dmat, num_bins);
+      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
+    }
+  }
+}
+
+TEST(hist_util, SparseCutsCategorical) {
+  int categorical_sizes[] = {2, 6, 8, 12};
+  int num_bins = 256;
+  int sizes[] = {25, 100, 1000};
+  for (auto n : sizes) {
+    for (auto num_categories : categorical_sizes) {
+      auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
+      std::vector<float> x_sorted(x);
+      std::sort(x_sorted.begin(), x_sorted.end());
+      auto dmat = GetDMatrixFromData(x, n, 1);
+      HistogramCuts cuts;
+      SparseCuts sparse(&cuts);
+      sparse.Build(&dmat, num_bins);
+      auto cuts_from_sketch = cuts.Values();
+      EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
+      EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
+      EXPECT_GE(cuts_from_sketch.back(), x_sorted.back());
+      EXPECT_EQ(cuts_from_sketch.size(), num_categories);
+    }
+  }
+}
+
+TEST(hist_util, SparseCutsExternalMemory) {
+  int bin_sizes[] = {2, 16, 256, 512};
+  int sizes[] = {100, 1000, 1500};
+  int num_columns = 5;
+  for (auto num_rows : sizes) {
+    auto x = GenerateRandom(num_rows, num_columns);
+    dmlc::TemporaryDirectory tmpdir;
+    auto dmat =
+        GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir);
+    for (auto num_bins : bin_sizes) {
+      HistogramCuts cuts;
+      SparseCuts dense(&cuts);
+      dense.Build(dmat.get(), num_bins);
+      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
+    }
+  }
+}
 }  // namespace common
 }  // namespace xgboost

tests/cpp/common/test_hist_util.h (new file, 159 lines)
@@ -0,0 +1,159 @@
#pragma once
#include <gtest/gtest.h>
#include "../../../src/data/simple_dmatrix.h"

// Some helper functions used to test both GPU and CPU algorithms
//
namespace xgboost {
namespace common {

// Generate columns with different ranges
inline std::vector<float> GenerateRandom(int num_rows, int num_columns) {
  std::vector<float> x(num_rows*num_columns);
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(0.0, 1.0);
  std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
  for (auto i = 0ull; i < num_columns; i++) {
    for (auto j = 0ull; j < num_rows; j++) {
      x[j * num_columns + i] += i;
    }
  }
  return x;
}

inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
                                                                int num_categories) {
  std::vector<float> x(n);
  std::mt19937 rng(0);
  std::uniform_int_distribution<int> dist(0, num_categories - 1);
  std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
  // Make sure each category is present
  for (auto i = 0ull; i < num_categories; i++) {
    x[i] = i;
  }
  return x;
}

inline data::SimpleDMatrix GetDMatrixFromData(const std::vector<float>& x, int num_rows, int num_columns) {
  data::DenseAdapter adapter(x.data(), num_rows, num_columns);
  return data::SimpleDMatrix(&adapter, std::numeric_limits<float>::quiet_NaN(),
                             1);
}

inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
    const std::vector<float>& x, int num_rows, int num_columns,
    size_t page_size, const dmlc::TemporaryDirectory& tempdir) {
  // Create the svm file in a temp dir
  const std::string tmp_file = tempdir.path + "/temp.libsvm";
  std::ofstream fo(tmp_file.c_str());
  for (auto i = 0ull; i < num_rows; i++) {
    std::stringstream row_data;
    for (auto j = 0ull; j < num_columns; j++) {
      row_data << 1 << " " << j << ":" << std::setprecision(15)
               << x[i * num_columns + j];
    }
    fo << row_data.str() << "\n";
  }
  fo.close();
  return std::shared_ptr<DMatrix>(DMatrix::Load(
      tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size));
}

// Test that elements are approximately equally distributed among bins
inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx,
                                const std::vector<float>& column,
                                int num_bins) {
  std::map<int, int> counts;
  for (auto& v : column) {
    counts[cuts.SearchBin(v, column_idx)]++;
  }
  int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
  int expected_num_elements = column.size() / local_num_bins;
  // Allow about 30% deviation. This test is not very strict, it only ensures
  // roughly equal distribution
  int allowable_error = std::max(2, int(expected_num_elements * 0.3));

  // First and last bin can have smaller
  for (auto& kv : counts) {
    EXPECT_LE(std::abs(counts[kv.first] - expected_num_elements),
              allowable_error);
  }
}

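To make the tolerance concrete with hypothetical numbers: a column of 100 elements cut into 4 bins expects 25 per bin and tolerates a deviation of max(2, 25 * 0.3) = 7, so any bin count in [18, 32] passes:

#include <algorithm>
#include <cstdio>

int main() {
  int expected_num_elements = 100 / 4;                                  // 25
  int allowable_error = std::max(2, int(expected_num_elements * 0.3));  // 7
  std::printf("pass if |count - %d| <= %d\n", expected_num_elements,
              allowable_error);
  return 0;
}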
// Test sketch quantiles against the real quantiles
// Not a very strict test
inline void TestRank(const std::vector<float>& cuts,
                     const std::vector<float>& sorted_x) {
  float eps = 0.05;
  // Ignore the last cut, its special
  size_t j = 0;
  for (auto i = 0; i < cuts.size() - 1; i++) {
    int expected_rank = ((i+1) * sorted_x.size()) / cuts.size();
    while (cuts[i] > sorted_x[j]) {
      j++;
    }
    int actual_rank = j;
    int acceptable_error = std::max(2, int(sorted_x.size() * eps));
    ASSERT_LE(std::abs(expected_rank - actual_rank), acceptable_error);
  }
}

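A quick usage sketch for TestRank with hypothetical values (assuming <numeric> is available for std::iota): 100 evenly spaced points with cuts at every tenth value put cut i exactly at its expected rank (i + 1) * 100 / 10, well within the 5% tolerance:

TEST(hist_util, TestRankExample) {  // illustrative test, not part of the commit
  std::vector<float> sorted_x(100);
  std::iota(sorted_x.begin(), sorted_x.end(), 0.0f);
  std::vector<float> cuts{10, 20, 30, 40, 50, 60, 70, 80, 90, 100};
  TestRank(cuts, sorted_x);  // each |expected - actual| == 0 here
}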
inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
                           const std::vector<float>& column,
                           int num_bins) {
  std::vector<float> sorted_column(column);
  std::sort(sorted_column.begin(), sorted_column.end());

  // Check the endpoints are correct
  EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front());
  EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front());
  EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back());

  // Check the cuts are sorted
  auto cuts_begin = cuts.Values().begin() + cuts.Ptrs()[column_idx];
  auto cuts_end = cuts.Values().begin() + cuts.Ptrs()[column_idx + 1];
  EXPECT_TRUE(std::is_sorted(cuts_begin, cuts_end));

  // Check all cut points are unique
  EXPECT_EQ(std::set<float>(cuts_begin, cuts_end).size(),
            cuts_end - cuts_begin);

  if (sorted_column.size() <= num_bins) {
    // Less unique values than number of bins
    // Each value should get its own bin

    // First check the inputs are unique
    int num_unique =
        std::set<float>(sorted_column.begin(), sorted_column.end()).size();
    EXPECT_EQ(num_unique, sorted_column.size());
    for (auto i = 0ull; i < sorted_column.size(); i++) {
      ASSERT_EQ(cuts.SearchBin(sorted_column[i], column_idx),
                cuts.Ptrs()[column_idx] + i);
    }
  }
  int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
  std::vector<float> column_cuts(num_cuts_column);
  std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx],
            cuts.Values().begin() + cuts.Ptrs()[column_idx + 1],
            column_cuts.begin());
  TestBinDistribution(cuts, column_idx, sorted_column, num_bins);
  TestRank(column_cuts, sorted_column);
}

// x is dense and row major
inline void ValidateCuts(const HistogramCuts& cuts, std::vector<float>& x,
                         int num_rows, int num_columns,
                         int num_bins) {
  for (auto i = 0ull; i < num_columns; i++) {
    // Extract the column
    std::vector<float> column(num_rows);
    for (auto j = 0ull; j < num_rows; j++) {
      column[j] = x[j * num_columns + i];
    }
    ValidateColumn(cuts, i, column, num_bins);
  }
}

}  // namespace common
}  // namespace xgboost
@@ -228,17 +228,23 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
     std::stringstream row_data;
     size_t j = 0;
     if (rem_cols > 0) {
-      for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
-        row_data << label(*gen) << " " << (col_idx+j) << ":" << (col_idx+j+1)*10*i;
-      }
-      rem_cols -= cols_per_row;
+      for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
+        row_data << label(*gen) << " " << (col_idx + j) << ":"
+                 << (col_idx + j + 1) * 10 * i;
+      }
+      rem_cols -= cols_per_row;
     } else {
-      // Take some random number of colums in [1, n_cols] and slot them here
-      size_t ncols = dis(*gen);
-      for (; j < ncols; ++j) {
-        size_t fid = (col_idx+j) % n_cols;
-        row_data << label(*gen) << " " << fid << ":" << (fid+1)*10*i;
-      }
+      // Take some random number of colums in [1, n_cols] and slot them here
+      std::vector<size_t> random_columns;
+      size_t ncols = dis(*gen);
+      for (; j < ncols; ++j) {
+        size_t fid = (col_idx + j) % n_cols;
+        random_columns.push_back(fid);
+      }
+      std::sort(random_columns.begin(), random_columns.end());
+      for (auto fid : random_columns) {
+        row_data << label(*gen) << " " << fid << ":" << (fid + 1) * 10 * i;
+      }
     }
     col_idx += j;


@@ -342,20 +342,17 @@ TEST(GpuHist, MinSplitLoss) {
   delete dmat;
 }

-void UpdateTree(HostDeviceVector<GradientPair>* gpair,
-                DMatrix* dmat,
-                size_t gpu_page_size,
-                RegTree* tree,
-                HostDeviceVector<bst_float>* preds,
-                float subsample = 1.0f,
-                const std::string& sampling_method = "uniform") {
-  constexpr size_t kMaxBin = 2;
+void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
+                size_t gpu_page_size, RegTree* tree,
+                HostDeviceVector<bst_float>* preds, float subsample = 1.0f,
+                const std::string& sampling_method = "uniform",
+                int max_bin = 2) {

   if (gpu_page_size > 0) {
     // Loop over the batches and count the records
     int64_t batch_count = 0;
     int64_t row_count = 0;
-    for (const auto& batch : dmat->GetBatches<EllpackPage>({0, kMaxBin, 0, gpu_page_size})) {
+    for (const auto& batch : dmat->GetBatches<EllpackPage>({0, max_bin, 0, gpu_page_size})) {
       EXPECT_LT(batch.Size(), dmat->Info().num_row_);
       batch_count++;
       row_count += batch.Size();
@@ -366,7 +363,7 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,

   Args args{
       {"max_depth", "2"},
-      {"max_bin", std::to_string(kMaxBin)},
+      {"max_bin", std::to_string(max_bin)},
       {"min_child_weight", "0.0"},
       {"reg_alpha", "0"},
       {"reg_lambda", "0"},
@@ -386,7 +383,7 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
 TEST(GpuHist, UniformSampling) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
-  constexpr float kSubsample = 0.99;
+  constexpr float kSubsample = 0.9999;
   common::GlobalRandom().seed(1994);

   // Create an in-memory DMatrix.
@@ -397,25 +394,25 @@ TEST(GpuHist, UniformSampling) {
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
   HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
-
+  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
   // Build another tree using sampling.
   RegTree tree_sampling;
   HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample);
+  UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
+             "uniform", kRows);

   // Make sure the predictions are the same.
   auto preds_h = preds.ConstHostVector();
   auto preds_sampling_h = preds_sampling.ConstHostVector();
   for (int i = 0; i < kRows; i++) {
-    EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-3);
+    EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-8);
   }
 }

 TEST(GpuHist, GradientBasedSampling) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
-  constexpr float kSubsample = 0.99;
+  constexpr float kSubsample = 0.9999;
   common::GlobalRandom().seed(1994);

   // Create an in-memory DMatrix.
@@ -426,12 +423,13 @@ TEST(GpuHist, GradientBasedSampling) {
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
   HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
+  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);

   // Build another tree using sampling.
   RegTree tree_sampling;
   HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "gradient_based");
+  UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
+             "gradient_based", kRows);

   // Make sure the predictions are the same.
   auto preds_h = preds.ConstHostVector();
@@ -459,18 +457,17 @@ TEST(GpuHist, ExternalMemory) {
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
   HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
-
+  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
   // Build another tree using multiple ELLPACK pages.
   RegTree tree_ext;
   HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext);
+  UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);

   // Make sure the predictions are the same.
   auto preds_h = preds.ConstHostVector();
   auto preds_ext_h = preds_ext.ConstHostVector();
   for (int i = 0; i < kRows; i++) {
-    EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6);
+    EXPECT_NEAR(preds_h[i], preds_ext_h[i], 1e-6);
   }
 }

@@ -495,12 +492,14 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
   HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod);
+  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod,
+             kRows);

   // Build another tree using multiple ELLPACK pages.
   RegTree tree_ext;
   HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
-  UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample, kSamplingMethod);
+  UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext,
+             kSubsample, kSamplingMethod, kRows);

   // Make sure the predictions are the same.
   auto preds_h = preds.ConstHostVector();
@@ -13,7 +13,7 @@ def assert_gpu_results(cpu_results, gpu_results):
     for cpu_res, gpu_res in zip(cpu_results, gpu_results):
         # Check final eval result roughly equivalent
         assert np.allclose(cpu_res["eval"][-1],
-                           gpu_res["eval"][-1], 1e-2, 1e-2)
+                           gpu_res["eval"][-1], 1e-1, 1e-1)


 datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
@@ -23,7 +23,7 @@ test_param = parameter_combinations({
     'gpu_id': [0],
     'max_depth': [2, 8],
     'max_leaves': [255, 4],
-    'max_bin': [2, 256],
+    'max_bin': [4, 256],
     'grow_policy': ['lossguide'],
     'single_precision_histogram': [True],
     'min_child_weight': [0],

@@ -132,8 +132,8 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
     """
     datasets = [
         Dataset("Boston", get_boston, "reg:squarederror", "rmse"),
-        Dataset("Digits", get_digits, "multi:softmax", "merror"),
-        Dataset("Cancer", get_cancer, "binary:logistic", "error"),
+        Dataset("Digits", get_digits, "multi:softmax", "mlogloss"),
+        Dataset("Cancer", get_cancer, "binary:logistic", "logloss"),
         Dataset("Sparse regression", get_sparse, "reg:squarederror", "rmse"),
         Dataset("Sparse regression with weights", get_sparse_weights,
                 "reg:squarederror", "rmse", has_weights=True),

@@ -53,7 +53,7 @@ def assert_classification_result(results):
             r["param"]["objective"] != "reg:squarederror"]
     for res in classification_results:
         # Check accuracy is reasonable
-        assert res["eval"][-1] < 0.5, (res["dataset"].name, res["eval"][-1])
+        assert res["eval"][-1] < 2.0, (res["dataset"].name, res["eval"][-1])


 class TestLinear(unittest.TestCase):