diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index fe1305d4b..ba7207ad2 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -253,22 +253,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page, << "Weight size should equal to number of groups."; dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { size_t element_idx = idx + begin; - size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), - row_ptrs.end(), element_idx) - - row_ptrs.begin() - 1; - auto it = - thrust::upper_bound(thrust::seq, - d_group_ptr.cbegin(), d_group_ptr.cend(), - ridx + base_rowid) - 1; - bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it); - d_temp_weights[idx] = weights[group]; + size_t ridx = dh::SegmentId(row_ptrs, element_idx); + bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid); + d_temp_weights[idx] = weights[group_idx]; }); } else { dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { size_t element_idx = idx + begin; - size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), - row_ptrs.end(), element_idx) - - row_ptrs.begin() - 1; + size_t ridx = dh::SegmentId(row_ptrs, element_idx); d_temp_weights[idx] = weights[ridx + base_rowid]; }); } diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 94744513a..c8f0e3f7d 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -232,11 +232,8 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, thrust::make_constant_iterator(0lu), [=]__device__(size_t idx) -> float { auto ridx = batch.GetElement(idx).row_idx; - auto it = thrust::upper_bound(thrust::seq, - d_group_ptr.cbegin(), d_group_ptr.cend(), - ridx) - 1; - bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it); - return weights[group]; + bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx); + return weights[group_idx]; }); auto retit = thrust::copy_if(thrust::cuda::par(alloc), weight_iter + begin, weight_iter + end, @@ -277,46 +274,12 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } -template -HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, - float missing, - size_t sketch_batch_num_elements = 0) { - size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows()); - CHECK(adapter->NumRows() != data::kAdapterUnknownSize); - CHECK(adapter->NumColumns() != data::kAdapterUnknownSize); - - adapter->BeforeFirst(); - adapter->Next(); - auto& batch = adapter->Value(); - sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - adapter->NumRows(), adapter->NumColumns(), std::numeric_limits::max(), - adapter->DeviceIdx(), - num_cuts_per_feature, false); - - // Enforce single batch - CHECK(!adapter->Next()); - - HistogramCuts cuts; - SketchContainer sketch_container(num_bins, adapter->NumColumns(), - adapter->NumRows(), adapter->DeviceIdx()); - - for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { - size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); - auto const& batch = adapter->Value(); - ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(), - begin, end, missing, &sketch_container, num_cuts_per_feature); - } - - sketch_container.MakeCuts(&cuts); - return cuts; -} - /* * \brief Perform sketching on GPU. * * \param batch A batch from adapter. * \param num_bins Bins per column. + * \param info Metainfo used for sketching. * \param missing Floating point value that represents invalid value. * \param sketch_container Container for output sketch. * \param sketch_batch_num_elements Number of element per-sliding window, use it only for @@ -324,51 +287,37 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, */ template void AdapterDeviceSketch(Batch batch, int num_bins, + MetaInfo const& info, float missing, SketchContainer* sketch_container, size_t sketch_batch_num_elements = 0) { size_t num_rows = batch.NumRows(); size_t num_cols = batch.NumCols(); size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); int32_t device = sketch_container->DeviceIdx(); - sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - num_rows, num_cols, std::numeric_limits::max(), - device, num_cuts_per_feature, false); - for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { - size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); - ProcessSlidingWindow(batch, device, num_cols, - begin, end, missing, sketch_container, num_cuts_per_feature); - } -} + bool weighted = info.weights_.Size() != 0; -/* - * \brief Perform weighted sketching on GPU. - * - * When weight in info is empty, this function is equivalent to unweighted version. - */ -template -void AdapterDeviceSketchWeighted(Batch batch, int num_bins, - MetaInfo const& info, - float missing, SketchContainer* sketch_container, - size_t sketch_batch_num_elements = 0) { - if (info.weights_.Size() == 0) { - return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements); - } - - size_t num_rows = batch.NumRows(); - size_t num_cols = batch.NumCols(); - size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); - int32_t device = sketch_container->DeviceIdx(); - sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - num_rows, num_cols, std::numeric_limits::max(), - device, num_cuts_per_feature, true); - for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { - size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); - ProcessWeightedSlidingWindow(batch, info, - num_cuts_per_feature, - CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end, - sketch_container); + if (weighted) { + sketch_batch_num_elements = detail::SketchBatchNumElements( + sketch_batch_num_elements, + num_rows, num_cols, std::numeric_limits::max(), + device, num_cuts_per_feature, true); + for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { + size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); + ProcessWeightedSlidingWindow(batch, info, + num_cuts_per_feature, + CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end, + sketch_container); + } + } else { + sketch_batch_num_elements = detail::SketchBatchNumElements( + sketch_batch_num_elements, + num_rows, num_cols, std::numeric_limits::max(), + device, num_cuts_per_feature, false); + for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { + size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); + ProcessSlidingWindow(batch, device, num_cols, + begin, end, missing, sketch_container, num_cuts_per_feature); + } } } } // namespace common diff --git a/src/common/hist_util.h b/src/common/hist_util.h index b736670d2..dbb0b35e4 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -138,16 +138,13 @@ class CutsBuilder { explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {} virtual ~CutsBuilder() = default; - static uint32_t SearchGroupIndFromRow( - std::vector const& group_ptr, size_t const base_rowid) { - using KIt = std::vector::const_iterator; - KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid); - // Cannot use CHECK_NE because it will try to print the iterator. - bool const found = res != group_ptr.cend() - 1; - if (!found) { - LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!"; - } - uint32_t group_ind = std::distance(group_ptr.cbegin(), res); + static uint32_t SearchGroupIndFromRow(std::vector const &group_ptr, + size_t const base_rowid) { + CHECK_LT(base_rowid, group_ptr.back()) + << "Row: " << base_rowid << " is not found in any group."; + auto it = + std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid); + bst_group_t group_ind = it - group_ptr.cbegin() - 1; return group_ind; } diff --git a/src/common/quantile.h b/src/common/quantile.h index ee2f44cd2..49345d13f 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -486,30 +486,6 @@ class QuantileSketchTemplate { this->data = dmlc::BeginPtr(space); } } - /*! - * \brief set the space to be merge of all Summary arrays - * \param begin beginning position in the summary array - * \param end ending position in the Summary array - */ - inline void SetMerge(const Summary *begin, - const Summary *end) { - CHECK(begin < end) << "can not set combine to empty instance"; - size_t len = end - begin; - if (len == 1) { - this->Reserve(begin[0].size); - this->CopyFrom(begin[0]); - } else if (len == 2) { - this->Reserve(begin[0].size + begin[1].size); - this->SetMerge(begin[0], begin[1]); - } else { - // recursive merge - SummaryContainer lhs, rhs; - lhs.SetCombine(begin, begin + len / 2); - rhs.SetCombine(begin + len / 2, end); - this->Reserve(lhs.size + rhs.size); - this->SetCombine(lhs, rhs); - } - } /*! * \brief do elementwise combination of summary array * this[i] = combine(this[i], src[i]) for each i diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 4a91c9545..39e845f2d 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -228,31 +228,6 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx, }); } -template -EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread, - int max_bin, common::Span row_counts_span, - size_t row_stride) { - common::HistogramCuts cuts = - common::AdapterDeviceSketch(adapter, max_bin, missing); - dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx())); - auto& batch = adapter->Value(); - - *this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride, - adapter->NumRows()); - CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing); - WriteNullValues(this, adapter->DeviceIdx(), row_counts_span); -} - -#define ELLPACK_SPECIALIZATION(__ADAPTER_T) \ - template EllpackPageImpl::EllpackPageImpl( \ - __ADAPTER_T* adapter, float missing, bool is_dense, int nthread, int max_bin, \ - common::Span row_counts_span, \ - size_t row_stride); - -ELLPACK_SPECIALIZATION(data::CudfAdapter) -ELLPACK_SPECIALIZATION(data::CupyAdapter) - - template EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread, diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index fb54a0c65..8cb0162fb 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -159,12 +159,6 @@ class EllpackPageImpl { */ explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm); - template - explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread, - int max_bin, - common::Span row_counts_span, - size_t row_stride); - template explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread, common::Span row_counts_span, diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 2e9f97c88..3f142acd3 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -75,8 +75,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin auto* p_sketch = &sketch_containers.back(); proxy->Info().weights_.SetDevice(device); Dispatch(proxy, [&](auto const &value) { - common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin, - proxy->Info(), missing, p_sketch); + common::AdapterDeviceSketch(value, batch_param_.max_bin, + proxy->Info(), missing, p_sketch); }); auto batch_rows = num_rows(); diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 0924db8a6..7a0ff9a47 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -164,7 +164,12 @@ TEST(CutsBuilder, SearchGroupInd) { group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5); ASSERT_EQ(group_ind, 2); - EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17)); + p_mat->Info().Validate(-1); + EXPECT_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17), + dmlc::Error); + + std::vector group_ptr {0, 1, 2}; + CHECK_EQ(CutsBuilder::SearchGroupIndFromRow(group_ptr, 1), 1); } TEST(SparseCuts, SingleThreadedBuild) { diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 3ec49668a..433e91679 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -227,16 +227,23 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) { } template -void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows, - DMatrix* dmat) { +auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, size_t batch_size = 0) { common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); - AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), + SketchContainer sketch_container(num_bins, adapter.NumColumns(), adapter.NumRows(), 0); + MetaInfo info; + AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); sketch_container.MakeCuts(&batched_cuts); - ValidateCuts(batched_cuts, dmat, num_bins); + return batched_cuts; } +template +void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows, + DMatrix* dmat, size_t batch_size = 0) { + common::HistogramCuts batched_cuts = MakeUnweightedCutsForTest( + adapter, num_bins, std::numeric_limits::quiet_NaN()); + ValidateCuts(batched_cuts, dmat, num_bins); +} TEST(HistUtil, AdapterDeviceSketch) { int rows = 5; @@ -251,7 +258,7 @@ TEST(HistUtil, AdapterDeviceSketch) { data::CupyAdapter adapter(str); - auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing); + auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing); auto host_cuts = GetHostCuts(&adapter, num_bins, missing); EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); @@ -269,8 +276,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); - auto cuts = AdapterDeviceSketch(&adapter, num_bins, - std::numeric_limits::quiet_NaN()); + auto cuts = MakeUnweightedCutsForTest(adapter, num_bins, std::numeric_limits::quiet_NaN()); ConsoleLogger::Configure({{"verbosity", "0"}}); size_t bytes_constant = 1000; size_t bytes_required = detail::RequiredMemory( @@ -286,12 +292,13 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) { auto x = GenerateRandom(num_rows, num_columns); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); + MetaInfo info; dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); - AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); HistogramCuts cuts; sketch_container.MakeCuts(&cuts); @@ -318,9 +325,9 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); - AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_container); + AdapterDeviceSketch(adapter.Value(), num_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_container); HistogramCuts cuts; sketch_container.MakeCuts(&cuts); ConsoleLogger::Configure({{"verbosity", "0"}}); @@ -340,9 +347,8 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) { auto dmat = GetDMatrixFromData(x, n, 1); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, n, 1); - auto cuts = AdapterDeviceSketch(&adapter, num_bins, - std::numeric_limits::quiet_NaN()); - ValidateCuts(cuts, dmat.get(), num_bins); + ValidateBatchedCuts(adapter, num_bins, adapter.NumColumns(), + adapter.NumRows(), dmat.get()); } } } @@ -357,9 +363,6 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { auto x_device = thrust::device_vector(x); for (auto num_bins : bin_sizes) { auto adapter = AdapterFromData(x_device, num_rows, num_columns); - auto cuts = AdapterDeviceSketch(&adapter, num_bins, - std::numeric_limits::quiet_NaN()); - ValidateCuts(cuts, dmat.get(), num_bins); ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get()); } } @@ -375,11 +378,7 @@ TEST(HistUtil, AdapterDeviceSketchBatches) { auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); - auto cuts = AdapterDeviceSketch(&adapter, num_bins, - std::numeric_limits::quiet_NaN(), - batch_size); - ValidateCuts(cuts, dmat.get(), num_bins); - ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get()); + ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get(), batch_size); } } @@ -396,8 +395,8 @@ TEST(HistUtil, SketchingEquivalent) { auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); - auto adapter_cuts = AdapterDeviceSketch( - &adapter, num_bins, std::numeric_limits::quiet_NaN()); + common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest( + adapter, num_bins, std::numeric_limits::quiet_NaN()); EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values()); EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs()); EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues()); @@ -467,8 +466,8 @@ void TestAdapterSketchFromWeights(bool with_group) { data::CupyAdapter adapter(m); auto const& batch = adapter.Value(); SketchContainer sketch_container(kBins, kCols, kRows, 0); - AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), - &sketch_container); + AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + &sketch_container); common::HistogramCuts cuts; sketch_container.MakeCuts(&cuts); diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index caef6c2c7..e16edfef0 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -53,8 +53,8 @@ void TestSketchUnique(float sparsity) { .Device(0) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), &sketch); + AdapterDeviceSketch(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch); auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows); dh::caching_device_vector column_sizes_scan; @@ -127,8 +127,8 @@ TEST(GPUQuantile, Prune) { .Seed(seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), &sketch); + AdapterDeviceSketch(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch); auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows); ASSERT_EQ(sketch.Data().size(), n_cuts * kCols); @@ -158,7 +158,8 @@ TEST(GPUQuantile, MergeEmpty) { RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( &storage_0); data::CupyAdapter adapter_0(interface_str_0); - AdapterDeviceSketch(adapter_0.Value(), n_bins, + MetaInfo info; + AdapterDeviceSketch(adapter_0.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch_0); std::vector entries_before(sketch_0.Data().size()); @@ -197,8 +198,8 @@ TEST(GPUQuantile, MergeBasic) { .Seed(seed) .GenerateArrayInterface(&storage_0); data::CupyAdapter adapter_0(interface_str_0); - AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), &sketch_0); + AdapterDeviceSketch(adapter_0.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch_0); SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0); HostDeviceVector storage_1; @@ -207,8 +208,8 @@ TEST(GPUQuantile, MergeBasic) { .Seed(seed) .GenerateArrayInterface(&storage_1); data::CupyAdapter adapter_1(interface_str_1); - AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), &sketch_1); + AdapterDeviceSketch(adapter_1.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch_1); size_t size_before_merge = sketch_0.Data().size(); sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data()); @@ -243,9 +244,9 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { .Seed(seed) .GenerateArrayInterface(&storage_0); data::CupyAdapter adapter_0(interface_str_0); - AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_0); + AdapterDeviceSketch(adapter_0.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_0); size_t f_rows = rows * frac; SketchContainer sketch_1(n_bins, cols, f_rows, 0); @@ -269,9 +270,9 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { } }); data::CupyAdapter adapter_1(interface_str_1); - AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_1); + AdapterDeviceSketch(adapter_1.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_1); size_t size_before_merge = sketch_0.Data().size(); sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data()); @@ -344,9 +345,9 @@ TEST(GPUQuantile, AllReduceBasic) { .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); containers.emplace_back(n_bins, kCols, kRows, 0); - AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), - &containers.back()); + AdapterDeviceSketch(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &containers.back()); } for (auto& sketch : containers) { sketch.Prune(intermediate_num_cuts); @@ -367,9 +368,9 @@ TEST(GPUQuantile, AllReduceBasic) { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_distributed); + AdapterDeviceSketch(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_distributed); sketch_distributed.AllReduce(); sketch_distributed.Unique(); @@ -426,9 +427,9 @@ TEST(GPUQuantile, SameOnAllWorkers) { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_distributed); + AdapterDeviceSketch(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_distributed); sketch_distributed.AllReduce(); sketch_distributed.Unique(); TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr());