diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 5dc9049c4..75e3fa586 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -70,7 +70,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, bool const use_group_ind, uint32_t beg_col, uint32_t end_col, uint32_t thread_id) { - using WXQSketch = common::WXQuantileSketch; CHECK_GE(end_col, beg_col); constexpr float kFactor = 8; @@ -80,7 +79,7 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) { // Using a local variable makes things easier, but at the cost of memory trashing. - WXQSketch sketch; + WQSketch sketch; common::Span const column = page[col_id]; uint32_t const n_bins = std::min(static_cast(column.size()), max_num_bins); @@ -104,18 +103,18 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, sketch.Push(entry.fvalue, info.GetWeight(weight_ind)); } - WXQSketch::SummaryContainer out_summary; + WQSketch::SummaryContainer out_summary; sketch.GetSummary(&out_summary); - WXQSketch::SummaryContainer summary; - summary.Reserve(n_bins); - summary.SetPrune(out_summary, n_bins); + WQSketch::SummaryContainer summary; + summary.Reserve(n_bins + 1); + summary.SetPrune(out_summary, n_bins + 1); // Can be use data[1] as the min values so that we don't need to // store another array? float mval = summary.data[0].value; p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5); - this->AddCutPoint(summary); + this->AddCutPoint(summary, max_num_bins); bst_float cpt = (summary.size > 0) ? summary.data[summary.size - 1].value : @@ -234,7 +233,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { // safe factor for better accuracy constexpr int kFactor = 8; - std::vector sketchs; + std::vector sketchs; const int nthread = omp_get_max_threads(); @@ -292,34 +291,34 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { } void DenseCuts::Init -(std::vector* in_sketchs, uint32_t max_num_bins) { +(std::vector* in_sketchs, uint32_t max_num_bins) { monitor_.Start(__func__); - std::vector& sketchs = *in_sketchs; + std::vector& sketchs = *in_sketchs; constexpr int kFactor = 8; // gather the histogram data - rabit::SerializeReducer sreducer; - std::vector summary_array; + rabit::SerializeReducer sreducer; + std::vector summary_array; summary_array.resize(sketchs.size()); for (size_t i = 0; i < sketchs.size(); ++i) { - WXQSketch::SummaryContainer out; + WQSketch::SummaryContainer out; sketchs[i].GetSummary(&out); summary_array[i].Reserve(max_num_bins * kFactor); summary_array[i].SetPrune(out, max_num_bins * kFactor); } CHECK_EQ(summary_array.size(), in_sketchs->size()); - size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor); + size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor); // TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint // we need to move this allreduce before loadcheckpoint call in future sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); p_cuts_->min_vals_.resize(sketchs.size()); for (size_t fid = 0; fid < summary_array.size(); ++fid) { - WXQSketch::SummaryContainer a; - a.Reserve(max_num_bins); - a.SetPrune(summary_array[fid], max_num_bins); + WQSketch::SummaryContainer a; + a.Reserve(max_num_bins + 1); + a.SetPrune(summary_array[fid], max_num_bins + 1); const bst_float mval = a.data[0].value; p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5); - AddCutPoint(a); + AddCutPoint(a, max_num_bins); // push a value that is greater than anything const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_[fid]; diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index b23777458..2420ef36a 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -25,9 +25,9 @@ namespace xgboost { namespace common { -using WXQSketch = DenseCuts::WXQSketch; +using WQSketch = DenseCuts::WQSketch; -__global__ void FindCutsK(WXQSketch::Entry* __restrict__ cuts, +__global__ void FindCutsK(WQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data, const float* __restrict__ cum_weights, int nsamples, @@ -52,7 +52,7 @@ __global__ void FindCutsK(WXQSketch::Entry* __restrict__ cuts, // repeated values will be filtered out on the CPU bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0; bst_float rmax = cum_weights[isample]; - cuts[icut] = WXQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]); + cuts[icut] = WQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]); } // predictate for thrust filtering that returns true if the element is not a NaN @@ -97,7 +97,7 @@ __global__ void UnpackFeaturesK(float* __restrict__ fvalues, * across distinct rows. */ struct SketchContainer { - std::vector sketches_; // NOLINT + std::vector sketches_; // NOLINT std::vector col_locks_; // NOLINT static constexpr int kOmpNumColsParallelizeLimit = 1000; @@ -245,11 +245,11 @@ class GPUSketcher { if (n_cuts_ > n_unique) { float* weights2_ptr = weights2_.data().get(); float* fvalues_ptr = fvalues_cur_.data().get(); - WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_; + WQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_; dh::LaunchN(device_, n_unique, [=]__device__(size_t i) { bst_float rmax = weights2_ptr[i]; bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0; - cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]); + cuts_ptr[i] = WQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]); }); } else if (n_cuts_cur_[icol] > 0) { // if more elements than cuts: use binary search on cumulative weights @@ -287,7 +287,7 @@ class GPUSketcher { constexpr int kFactor = 8; double eps = 1.0 / (kFactor * max_bin_); size_t dummy_nlevel; - WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); + WQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); // allocate necessary GPU buffers dh::safe_cuda(cudaSetDevice(device_)); @@ -425,7 +425,7 @@ class GPUSketcher { #pragma omp parallel for default(none) schedule(static) \ if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT for (int icol = 0; icol < num_cols_; ++icol) { - WXQSketch::SummaryContainer summary; + WQSketch::SummaryContainer summary; summary.Reserve(n_cuts_); summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); @@ -450,8 +450,8 @@ if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT dh::device_vector fvalues_{}; dh::device_vector feature_weights_{}; dh::device_vector fvalues_cur_{}; - dh::device_vector cuts_d_{}; - thrust::host_vector cuts_h_{}; + dh::device_vector cuts_d_{}; + thrust::host_vector cuts_h_{}; dh::device_vector weights_{}; dh::device_vector weights2_{}; std::vector n_cuts_cur_{}; diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 115bae32c..a47eae4ae 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -101,6 +101,7 @@ struct SimpleArray { using GHistIndexRow = Span; // A CSC matrix representing histogram cuts, used in CPU quantile hist. +// The cut values represent upper bounds of bins containing approximately equal numbers of elements class HistogramCuts { // Using friends to avoid creating a virtual class, since HistogramCuts is used as value // object in many places. @@ -147,7 +148,9 @@ class HistogramCuts { size_t TotalBins() const { return cut_ptrs_.back(); } - BinIdx SearchBin(float value, uint32_t column_id) { + // Return the index of a cut point that is strictly greater than the input + // value, or the last available index if none exists + BinIdx SearchBin(float value, uint32_t column_id) const { auto beg = cut_ptrs_.at(column_id); auto end = cut_ptrs_.at(column_id + 1); auto it = std::upper_bound(cut_values_.cbegin() + beg, cut_values_.cbegin() + end, value); @@ -171,7 +174,7 @@ class HistogramCuts { */ class CutsBuilder { public: - using WXQSketch = common::WXQuantileSketch; + using WQSketch = common::WQuantileSketch; protected: HistogramCuts* p_cuts_; @@ -195,21 +198,12 @@ class CutsBuilder { return group_ind; } - void AddCutPoint(WXQSketch::SummaryContainer const& summary) { - if (summary.size > 1 && summary.size <= 16) { - /* specialized code categorial / ordinal data -- use midpoints */ - for (size_t i = 1; i < summary.size; ++i) { - bst_float cpt = (summary.data[i].value + summary.data[i - 1].value) / 2.0f; - if (i == 1 || cpt > p_cuts_->cut_values_.back()) { - p_cuts_->cut_values_.push_back(cpt); - } - } - } else { - for (size_t i = 2; i < summary.size; ++i) { - bst_float cpt = summary.data[i - 1].value; - if (i == 2 || cpt > p_cuts_->cut_values_.back()) { - p_cuts_->cut_values_.push_back(cpt); - } + void AddCutPoint(WQSketch::SummaryContainer const& summary, int max_bin) { + int required_cuts = std::min(static_cast(summary.size), max_bin); + for (size_t i = 1; i < required_cuts; ++i) { + bst_float cpt = summary.data[i].value; + if (i == 1 || cpt > p_cuts_->cut_values_.back()) { + p_cuts_->cut_values_.push_back(cpt); } } } @@ -250,7 +244,7 @@ class DenseCuts : public CutsBuilder { CutsBuilder(container) { monitor_.Init(__FUNCTION__); } - void Init(std::vector* sketchs, uint32_t max_num_bins); + void Init(std::vector* sketchs, uint32_t max_num_bins); void Build(DMatrix* p_fmat, uint32_t max_num_bins) override; }; diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 08c721136..2d8306d53 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -5,6 +5,7 @@ #include "../../../src/common/hist_util.h" #include "../helpers.h" +#include "test_hist_util.h" namespace xgboost { namespace common { @@ -152,14 +153,6 @@ TEST(CutsBuilder, SearchGroupInd) { delete pp_dmat; } -namespace { -class SparseCutsWrapper : public SparseCuts { - public: - std::vector const& ColPtrs() const { return p_cuts_->Ptrs(); } - std::vector const& ColValues() const { return p_cuts_->Values(); } -}; -} // anonymous namespace - TEST(SparseCuts, SingleThreadedBuild) { size_t constexpr kRows = 267; size_t constexpr kCols = 31; @@ -235,5 +228,116 @@ TEST(SparseCuts, MultiThreadedBuild) { omp_set_num_threads(ori_nthreads); } +TEST(hist_util, DenseCutsCategorical) { + int categorical_sizes[] = {2, 6, 8, 12}; + int num_bins = 256; + int sizes[] = {25, 100, 1000}; + for (auto n : sizes) { + for (auto num_categories : categorical_sizes) { + auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + std::vector x_sorted(x); + std::sort(x_sorted.begin(), x_sorted.end()); + auto dmat = GetDMatrixFromData(x, n, 1); + HistogramCuts cuts; + DenseCuts dense(&cuts); + dense.Build(&dmat, num_bins); + auto cuts_from_sketch = cuts.Values(); + EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); + EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); + EXPECT_GE(cuts_from_sketch.back(), x_sorted.back()); + EXPECT_EQ(cuts_from_sketch.size(), num_categories); + } + } +} + +TEST(hist_util, DenseCutsAccuracyTest) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + for (auto num_bins : bin_sizes) { + HistogramCuts cuts; + DenseCuts dense(&cuts); + dense.Build(&dmat, num_bins); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} + +TEST(hist_util, DenseCutsExternalMemory) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir); + for (auto num_bins : bin_sizes) { + HistogramCuts cuts; + DenseCuts dense(&cuts); + dense.Build(dmat.get(), num_bins); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} + +TEST(hist_util, SparseCutsAccuracyTest) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + for (auto num_bins : bin_sizes) { + HistogramCuts cuts; + SparseCuts sparse(&cuts); + sparse.Build(&dmat, num_bins); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} + +TEST(hist_util, SparseCutsCategorical) { + int categorical_sizes[] = {2, 6, 8, 12}; + int num_bins = 256; + int sizes[] = {25, 100, 1000}; + for (auto n : sizes) { + for (auto num_categories : categorical_sizes) { + auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + std::vector x_sorted(x); + std::sort(x_sorted.begin(), x_sorted.end()); + auto dmat = GetDMatrixFromData(x, n, 1); + HistogramCuts cuts; + SparseCuts sparse(&cuts); + sparse.Build(&dmat, num_bins); + auto cuts_from_sketch = cuts.Values(); + EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); + EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); + EXPECT_GE(cuts_from_sketch.back(), x_sorted.back()); + EXPECT_EQ(cuts_from_sketch.size(), num_categories); + } + } +} + +TEST(hist_util, SparseCutsExternalMemory) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir); + for (auto num_bins : bin_sizes) { + HistogramCuts cuts; + SparseCuts dense(&cuts); + dense.Build(dmat.get(), num_bins); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h new file mode 100644 index 000000000..60f920fe9 --- /dev/null +++ b/tests/cpp/common/test_hist_util.h @@ -0,0 +1,159 @@ +#pragma once +#include +#include "../../../src/data/simple_dmatrix.h" + +// Some helper functions used to test both GPU and CPU algorithms +// +namespace xgboost { +namespace common { + + // Generate columns with different ranges +inline std::vector GenerateRandom(int num_rows, int num_columns) { + std::vector x(num_rows*num_columns); + std::mt19937 rng(0); + std::uniform_real_distribution dist(0.0, 1.0); + std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); + for (auto i = 0ull; i < num_columns; i++) { + for (auto j = 0ull; j < num_rows; j++) { + x[j * num_columns + i] += i; + } + } + return x; +} + +inline std::vector GenerateRandomCategoricalSingleColumn(int n, + int num_categories) { + std::vector x(n); + std::mt19937 rng(0); + std::uniform_int_distribution dist(0, num_categories - 1); + std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); + // Make sure each category is present + for(auto i = 0ull; i < num_categories; i++) + { + x[i] = i; + } + return x; +} + +inline data::SimpleDMatrix GetDMatrixFromData(const std::vector& x, int num_rows, int num_columns) { + data::DenseAdapter adapter(x.data(), num_rows, num_columns); + return data::SimpleDMatrix(&adapter, std::numeric_limits::quiet_NaN(), + 1); +} + +inline std::shared_ptr GetExternalMemoryDMatrixFromData( + const std::vector& x, int num_rows, int num_columns, + size_t page_size, const dmlc::TemporaryDirectory& tempdir) { + // Create the svm file in a temp dir + const std::string tmp_file = tempdir.path + "/temp.libsvm"; + std::ofstream fo(tmp_file.c_str()); + for (auto i = 0ull; i < num_rows; i++) { + std::stringstream row_data; + for (auto j = 0ull; j < num_columns; j++) { + row_data << 1 << " " << j << ":" << std::setprecision(15) + << x[i * num_columns + j]; + } + fo << row_data.str() << "\n"; + } + fo.close(); + return std::shared_ptr(DMatrix::Load( + tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size)); +} + +// Test that elements are approximately equally distributed among bins +inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, + const std::vector& column, + int num_bins) { + std::map counts; + for (auto& v : column) { + counts[cuts.SearchBin(v, column_idx)]++; + } + int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; + int expected_num_elements = column.size() / local_num_bins; + // Allow about 30% deviation. This test is not very strict, it only ensures + // roughly equal distribution + int allowable_error = std::max(2, int(expected_num_elements * 0.3)); + + // First and last bin can have smaller + for (auto& kv : counts) { + EXPECT_LE(std::abs(counts[kv.first] - expected_num_elements), + allowable_error ); + } +} + + // Test sketch quantiles against the real quantiles + // Not a very strict test +inline void TestRank(const std::vector& cuts, + const std::vector& sorted_x) { + float eps = 0.05; + // Ignore the last cut, its special + size_t j = 0; + for (auto i = 0; i < cuts.size() - 1; i++) { + int expected_rank = ((i+1) * sorted_x.size()) / cuts.size(); + while (cuts[i] > sorted_x[j]) { + j++; + } + int actual_rank = j; + int acceptable_error = std::max(2, int(sorted_x.size() * eps)); + ASSERT_LE(std::abs(expected_rank - actual_rank), acceptable_error); + } +} + +inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, + const std::vector& column, + int num_bins) { + std::vector sorted_column(column); + std::sort(sorted_column.begin(), sorted_column.end()); + + // Check the endpoints are correct + EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front()); + EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back()); + + // Check the cuts are sorted + auto cuts_begin = cuts.Values().begin() + cuts.Ptrs()[column_idx]; + auto cuts_end = cuts.Values().begin() + cuts.Ptrs()[column_idx + 1]; + EXPECT_TRUE(std::is_sorted(cuts_begin, cuts_end)); + + // Check all cut points are unique + EXPECT_EQ(std::set(cuts_begin, cuts_end).size(), + cuts_end - cuts_begin); + + if (sorted_column.size() <= num_bins) { + // Less unique values than number of bins + // Each value should get its own bin + + // First check the inputs are unique + int num_unique = + std::set(sorted_column.begin(), sorted_column.end()).size(); + EXPECT_EQ(num_unique, sorted_column.size()); + for (auto i = 0ull; i < sorted_column.size(); i++) { + ASSERT_EQ(cuts.SearchBin(sorted_column[i], column_idx), + cuts.Ptrs()[column_idx] + i); + } + } + int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; + std::vector column_cuts(num_cuts_column); + std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], + cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], + column_cuts.begin()); + TestBinDistribution(cuts, column_idx, sorted_column, num_bins); + TestRank(column_cuts, sorted_column); +} + +// x is dense and row major +inline void ValidateCuts(const HistogramCuts& cuts, std::vector& x, + int num_rows, int num_columns, + int num_bins) { + for (auto i = 0ull; i < num_columns; i++) { + // Extract the column + std::vector column(num_rows); + for (auto j = 0ull; j < num_rows; j++) { + column[j] = x[j*num_columns + i]; + } + ValidateColumn(cuts,i, column, num_bins); + } +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index a479622cb..e91e40e0c 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -228,17 +228,23 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( std::stringstream row_data; size_t j = 0; if (rem_cols > 0) { - for (; j < std::min(static_cast(rem_cols), cols_per_row); ++j) { - row_data << label(*gen) << " " << (col_idx+j) << ":" << (col_idx+j+1)*10*i; - } - rem_cols -= cols_per_row; + for (; j < std::min(static_cast(rem_cols), cols_per_row); ++j) { + row_data << label(*gen) << " " << (col_idx + j) << ":" + << (col_idx + j + 1) * 10 * i; + } + rem_cols -= cols_per_row; } else { - // Take some random number of colums in [1, n_cols] and slot them here - size_t ncols = dis(*gen); - for (; j < ncols; ++j) { - size_t fid = (col_idx+j) % n_cols; - row_data << label(*gen) << " " << fid << ":" << (fid+1)*10*i; - } + // Take some random number of colums in [1, n_cols] and slot them here + std::vector random_columns; + size_t ncols = dis(*gen); + for (; j < ncols; ++j) { + size_t fid = (col_idx + j) % n_cols; + random_columns.push_back(fid); + } + std::sort(random_columns.begin(), random_columns.end()); + for (auto fid : random_columns) { + row_data << label(*gen) << " " << fid << ":" << (fid + 1) * 10 * i; + } } col_idx += j; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index d3ed0d8e3..a8d74d4d4 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -342,20 +342,17 @@ TEST(GpuHist, MinSplitLoss) { delete dmat; } -void UpdateTree(HostDeviceVector* gpair, - DMatrix* dmat, - size_t gpu_page_size, - RegTree* tree, - HostDeviceVector* preds, - float subsample = 1.0f, - const std::string& sampling_method = "uniform") { - constexpr size_t kMaxBin = 2; +void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, + size_t gpu_page_size, RegTree* tree, + HostDeviceVector* preds, float subsample = 1.0f, + const std::string& sampling_method = "uniform", + int max_bin = 2) { if (gpu_page_size > 0) { // Loop over the batches and count the records int64_t batch_count = 0; int64_t row_count = 0; - for (const auto& batch : dmat->GetBatches({0, kMaxBin, 0, gpu_page_size})) { + for (const auto& batch : dmat->GetBatches({0, max_bin, 0, gpu_page_size})) { EXPECT_LT(batch.Size(), dmat->Info().num_row_); batch_count++; row_count += batch.Size(); @@ -366,7 +363,7 @@ void UpdateTree(HostDeviceVector* gpair, Args args{ {"max_depth", "2"}, - {"max_bin", std::to_string(kMaxBin)}, + {"max_bin", std::to_string(max_bin)}, {"min_child_weight", "0.0"}, {"reg_alpha", "0"}, {"reg_lambda", "0"}, @@ -386,7 +383,7 @@ void UpdateTree(HostDeviceVector* gpair, TEST(GpuHist, UniformSampling) { constexpr size_t kRows = 4096; constexpr size_t kCols = 2; - constexpr float kSubsample = 0.99; + constexpr float kSubsample = 0.9999; common::GlobalRandom().seed(1994); // Create an in-memory DMatrix. @@ -397,25 +394,25 @@ TEST(GpuHist, UniformSampling) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); - + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); // Build another tree using sampling. RegTree tree_sampling; HostDeviceVector preds_sampling(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample); + UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, + "uniform", kRows); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); auto preds_sampling_h = preds_sampling.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-3); + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-8); } } TEST(GpuHist, GradientBasedSampling) { constexpr size_t kRows = 4096; constexpr size_t kCols = 2; - constexpr float kSubsample = 0.99; + constexpr float kSubsample = 0.9999; common::GlobalRandom().seed(1994); // Create an in-memory DMatrix. @@ -426,12 +423,13 @@ TEST(GpuHist, GradientBasedSampling) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); // Build another tree using sampling. RegTree tree_sampling; HostDeviceVector preds_sampling(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "gradient_based"); + UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, + "gradient_based", kRows); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); @@ -459,18 +457,17 @@ TEST(GpuHist, ExternalMemory) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); - + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); // Build another tree using multiple ELLPACK pages. RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, 0); - UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext); + UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); auto preds_ext_h = preds_ext.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6); + EXPECT_NEAR(preds_h[i], preds_ext_h[i], 1e-6); } } @@ -495,12 +492,14 @@ TEST(GpuHist, ExternalMemoryWithSampling) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, + kRows); // Build another tree using multiple ELLPACK pages. RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, 0); - UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample, kSamplingMethod); + UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, + kSubsample, kSamplingMethod, kRows); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 62467bfae..86daef62f 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -13,7 +13,7 @@ def assert_gpu_results(cpu_results, gpu_results): for cpu_res, gpu_res in zip(cpu_results, gpu_results): # Check final eval result roughly equivalent assert np.allclose(cpu_res["eval"][-1], - gpu_res["eval"][-1], 1e-2, 1e-2) + gpu_res["eval"][-1], 1e-1, 1e-1) datasets = ["Boston", "Cancer", "Digits", "Sparse regression", @@ -23,7 +23,7 @@ test_param = parameter_combinations({ 'gpu_id': [0], 'max_depth': [2, 8], 'max_leaves': [255, 4], - 'max_bin': [2, 256], + 'max_bin': [4, 256], 'grow_policy': ['lossguide'], 'single_precision_histogram': [True], 'min_child_weight': [0], diff --git a/tests/python/regression_test_utilities.py b/tests/python/regression_test_utilities.py index cf1ddded3..1a3b80690 100644 --- a/tests/python/regression_test_utilities.py +++ b/tests/python/regression_test_utilities.py @@ -132,8 +132,8 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False): """ datasets = [ Dataset("Boston", get_boston, "reg:squarederror", "rmse"), - Dataset("Digits", get_digits, "multi:softmax", "merror"), - Dataset("Cancer", get_cancer, "binary:logistic", "error"), + Dataset("Digits", get_digits, "multi:softmax", "mlogloss"), + Dataset("Cancer", get_cancer, "binary:logistic", "logloss"), Dataset("Sparse regression", get_sparse, "reg:squarederror", "rmse"), Dataset("Sparse regression with weights", get_sparse_weights, "reg:squarederror", "rmse", has_weights=True), diff --git a/tests/python/test_linear.py b/tests/python/test_linear.py index 31494a56c..82dcca637 100644 --- a/tests/python/test_linear.py +++ b/tests/python/test_linear.py @@ -53,7 +53,7 @@ def assert_classification_result(results): r["param"]["objective"] != "reg:squarederror"] for res in classification_results: # Check accuracy is reasonable - assert res["eval"][-1] < 0.5, (res["dataset"].name, res["eval"][-1]) + assert res["eval"][-1] < 2.0, (res["dataset"].name, res["eval"][-1]) class TestLinear(unittest.TestCase):