Cleanup on device sketch. (#5874)
* Remove old functions. * Merge weighted and un-weighted into a common interface.
This commit is contained in:
parent
9f85e92602
commit
dd445af56e
@ -253,22 +253,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
<< "Weight size should equal to number of groups.";
|
||||
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
|
||||
size_t element_idx = idx + begin;
|
||||
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
|
||||
row_ptrs.end(), element_idx) -
|
||||
row_ptrs.begin() - 1;
|
||||
auto it =
|
||||
thrust::upper_bound(thrust::seq,
|
||||
d_group_ptr.cbegin(), d_group_ptr.cend(),
|
||||
ridx + base_rowid) - 1;
|
||||
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
|
||||
d_temp_weights[idx] = weights[group];
|
||||
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid);
|
||||
d_temp_weights[idx] = weights[group_idx];
|
||||
});
|
||||
} else {
|
||||
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
|
||||
size_t element_idx = idx + begin;
|
||||
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
|
||||
row_ptrs.end(), element_idx) -
|
||||
row_ptrs.begin() - 1;
|
||||
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||
});
|
||||
}
|
||||
|
||||
@ -232,11 +232,8 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
thrust::make_constant_iterator(0lu),
|
||||
[=]__device__(size_t idx) -> float {
|
||||
auto ridx = batch.GetElement(idx).row_idx;
|
||||
auto it = thrust::upper_bound(thrust::seq,
|
||||
d_group_ptr.cbegin(), d_group_ptr.cend(),
|
||||
ridx) - 1;
|
||||
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
|
||||
return weights[group];
|
||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
||||
return weights[group_idx];
|
||||
});
|
||||
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
|
||||
weight_iter + begin, weight_iter + end,
|
||||
@ -277,46 +274,12 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
float missing,
|
||||
size_t sketch_batch_num_elements = 0) {
|
||||
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows());
|
||||
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
|
||||
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
|
||||
|
||||
adapter->BeforeFirst();
|
||||
adapter->Next();
|
||||
auto& batch = adapter->Value();
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
adapter->NumRows(), adapter->NumColumns(), std::numeric_limits<size_t>::max(),
|
||||
adapter->DeviceIdx(),
|
||||
num_cuts_per_feature, false);
|
||||
|
||||
// Enforce single batch
|
||||
CHECK(!adapter->Next());
|
||||
|
||||
HistogramCuts cuts;
|
||||
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
|
||||
adapter->NumRows(), adapter->DeviceIdx());
|
||||
|
||||
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
|
||||
auto const& batch = adapter->Value();
|
||||
ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(),
|
||||
begin, end, missing, &sketch_container, num_cuts_per_feature);
|
||||
}
|
||||
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
return cuts;
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief Perform sketching on GPU.
|
||||
*
|
||||
* \param batch A batch from adapter.
|
||||
* \param num_bins Bins per column.
|
||||
* \param info Metainfo used for sketching.
|
||||
* \param missing Floating point value that represents invalid value.
|
||||
* \param sketch_container Container for output sketch.
|
||||
* \param sketch_batch_num_elements Number of element per-sliding window, use it only for
|
||||
@ -324,41 +287,16 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
*/
|
||||
template <typename Batch>
|
||||
void AdapterDeviceSketch(Batch batch, int num_bins,
|
||||
float missing, SketchContainer* sketch_container,
|
||||
size_t sketch_batch_num_elements = 0) {
|
||||
size_t num_rows = batch.NumRows();
|
||||
size_t num_cols = batch.NumCols();
|
||||
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
|
||||
int32_t device = sketch_container->DeviceIdx();
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
||||
device, num_cuts_per_feature, false);
|
||||
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
|
||||
ProcessSlidingWindow(batch, device, num_cols,
|
||||
begin, end, missing, sketch_container, num_cuts_per_feature);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief Perform weighted sketching on GPU.
|
||||
*
|
||||
* When weight in info is empty, this function is equivalent to unweighted version.
|
||||
*/
|
||||
template <typename Batch>
|
||||
void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
|
||||
MetaInfo const& info,
|
||||
float missing, SketchContainer* sketch_container,
|
||||
size_t sketch_batch_num_elements = 0) {
|
||||
if (info.weights_.Size() == 0) {
|
||||
return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements);
|
||||
}
|
||||
|
||||
size_t num_rows = batch.NumRows();
|
||||
size_t num_cols = batch.NumCols();
|
||||
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
|
||||
int32_t device = sketch_container->DeviceIdx();
|
||||
bool weighted = info.weights_.Size() != 0;
|
||||
|
||||
if (weighted) {
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
||||
@ -370,6 +308,17 @@ void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
|
||||
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
|
||||
sketch_container);
|
||||
}
|
||||
} else {
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
||||
device, num_cuts_per_feature, false);
|
||||
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
|
||||
ProcessSlidingWindow(batch, device, num_cols,
|
||||
begin, end, missing, sketch_container, num_cuts_per_feature);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -138,16 +138,13 @@ class CutsBuilder {
|
||||
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
|
||||
virtual ~CutsBuilder() = default;
|
||||
|
||||
static uint32_t SearchGroupIndFromRow(
|
||||
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
|
||||
using KIt = std::vector<bst_uint>::const_iterator;
|
||||
KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
|
||||
// Cannot use CHECK_NE because it will try to print the iterator.
|
||||
bool const found = res != group_ptr.cend() - 1;
|
||||
if (!found) {
|
||||
LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!";
|
||||
}
|
||||
uint32_t group_ind = std::distance(group_ptr.cbegin(), res);
|
||||
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
||||
size_t const base_rowid) {
|
||||
CHECK_LT(base_rowid, group_ptr.back())
|
||||
<< "Row: " << base_rowid << " is not found in any group.";
|
||||
auto it =
|
||||
std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
|
||||
bst_group_t group_ind = it - group_ptr.cbegin() - 1;
|
||||
return group_ind;
|
||||
}
|
||||
|
||||
|
||||
@ -486,30 +486,6 @@ class QuantileSketchTemplate {
|
||||
this->data = dmlc::BeginPtr(space);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief set the space to be merge of all Summary arrays
|
||||
* \param begin beginning position in the summary array
|
||||
* \param end ending position in the Summary array
|
||||
*/
|
||||
inline void SetMerge(const Summary *begin,
|
||||
const Summary *end) {
|
||||
CHECK(begin < end) << "can not set combine to empty instance";
|
||||
size_t len = end - begin;
|
||||
if (len == 1) {
|
||||
this->Reserve(begin[0].size);
|
||||
this->CopyFrom(begin[0]);
|
||||
} else if (len == 2) {
|
||||
this->Reserve(begin[0].size + begin[1].size);
|
||||
this->SetMerge(begin[0], begin[1]);
|
||||
} else {
|
||||
// recursive merge
|
||||
SummaryContainer lhs, rhs;
|
||||
lhs.SetCombine(begin, begin + len / 2);
|
||||
rhs.SetCombine(begin + len / 2, end);
|
||||
this->Reserve(lhs.size + rhs.size);
|
||||
this->SetCombine(lhs, rhs);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief do elementwise combination of summary array
|
||||
* this[i] = combine(this[i], src[i]) for each i
|
||||
|
||||
@ -228,31 +228,6 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin, common::Span<size_t> row_counts_span,
|
||||
size_t row_stride) {
|
||||
common::HistogramCuts cuts =
|
||||
common::AdapterDeviceSketch(adapter, max_bin, missing);
|
||||
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
|
||||
auto& batch = adapter->Value();
|
||||
|
||||
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
|
||||
adapter->NumRows());
|
||||
CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
|
||||
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
|
||||
}
|
||||
|
||||
#define ELLPACK_SPECIALIZATION(__ADAPTER_T) \
|
||||
template EllpackPageImpl::EllpackPageImpl( \
|
||||
__ADAPTER_T* adapter, float missing, bool is_dense, int nthread, int max_bin, \
|
||||
common::Span<size_t> row_counts_span, \
|
||||
size_t row_stride);
|
||||
|
||||
ELLPACK_SPECIALIZATION(data::CudfAdapter)
|
||||
ELLPACK_SPECIALIZATION(data::CupyAdapter)
|
||||
|
||||
|
||||
template <typename AdapterBatch>
|
||||
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
|
||||
bool is_dense, int nthread,
|
||||
|
||||
@ -159,12 +159,6 @@ class EllpackPageImpl {
|
||||
*/
|
||||
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
|
||||
|
||||
template <typename AdapterT>
|
||||
explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin,
|
||||
common::Span<size_t> row_counts_span,
|
||||
size_t row_stride);
|
||||
|
||||
template <typename AdapterBatch>
|
||||
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
|
||||
common::Span<size_t> row_counts_span,
|
||||
|
||||
@ -75,7 +75,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
proxy->Info().weights_.SetDevice(device);
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
|
||||
common::AdapterDeviceSketch(value, batch_param_.max_bin,
|
||||
proxy->Info(), missing, p_sketch);
|
||||
});
|
||||
|
||||
|
||||
@ -164,7 +164,12 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
|
||||
ASSERT_EQ(group_ind, 2);
|
||||
|
||||
EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
|
||||
p_mat->Info().Validate(-1);
|
||||
EXPECT_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17),
|
||||
dmlc::Error);
|
||||
|
||||
std::vector<bst_uint> group_ptr {0, 1, 2};
|
||||
CHECK_EQ(CutsBuilder::SearchGroupIndFromRow(group_ptr, 1), 1);
|
||||
}
|
||||
|
||||
TEST(SparseCuts, SingleThreadedBuild) {
|
||||
|
||||
@ -227,16 +227,23 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows,
|
||||
DMatrix* dmat) {
|
||||
auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, size_t batch_size = 0) {
|
||||
common::HistogramCuts batched_cuts;
|
||||
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
|
||||
SketchContainer sketch_container(num_bins, adapter.NumColumns(), adapter.NumRows(), 0);
|
||||
MetaInfo info;
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
sketch_container.MakeCuts(&batched_cuts);
|
||||
ValidateCuts(batched_cuts, dmat, num_bins);
|
||||
return batched_cuts;
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows,
|
||||
DMatrix* dmat, size_t batch_size = 0) {
|
||||
common::HistogramCuts batched_cuts = MakeUnweightedCutsForTest(
|
||||
adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
ValidateCuts(batched_cuts, dmat, num_bins);
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketch) {
|
||||
int rows = 5;
|
||||
@ -251,7 +258,7 @@ TEST(HistUtil, AdapterDeviceSketch) {
|
||||
|
||||
data::CupyAdapter adapter(str);
|
||||
|
||||
auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing);
|
||||
auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing);
|
||||
auto host_cuts = GetHostCuts(&adapter, num_bins, missing);
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
@ -269,8 +276,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN());
|
||||
auto cuts = MakeUnweightedCutsForTest(adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
size_t bytes_constant = 1000;
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
@ -286,12 +292,13 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
MetaInfo info;
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
common::HistogramCuts batched_cuts;
|
||||
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
@ -318,7 +325,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
common::HistogramCuts batched_cuts;
|
||||
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info,
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
HistogramCuts cuts;
|
||||
@ -340,9 +347,8 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, n, 1);
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN());
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
ValidateBatchedCuts(adapter, num_bins, adapter.NumColumns(),
|
||||
adapter.NumRows(), dmat.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -357,9 +363,6 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN());
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
|
||||
}
|
||||
}
|
||||
@ -375,11 +378,7 @@ TEST(HistUtil, AdapterDeviceSketchBatches) {
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
batch_size);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
|
||||
ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get(), batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
@ -396,8 +395,8 @@ TEST(HistUtil, SketchingEquivalent) {
|
||||
auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
auto adapter_cuts = AdapterDeviceSketch(
|
||||
&adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest(
|
||||
adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
|
||||
EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
|
||||
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
|
||||
@ -467,7 +466,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
data::CupyAdapter adapter(m);
|
||||
auto const& batch = adapter.Value();
|
||||
SketchContainer sketch_container(kBins, kCols, kRows, 0);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
common::HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
|
||||
@ -53,7 +53,7 @@ void TestSketchUnique(float sparsity) {
|
||||
.Device(0)
|
||||
.GenerateArrayInterface(&storage);
|
||||
data::CupyAdapter adapter(interface_str);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
||||
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
||||
|
||||
@ -127,7 +127,7 @@ TEST(GPUQuantile, Prune) {
|
||||
.Seed(seed)
|
||||
.GenerateArrayInterface(&storage);
|
||||
data::CupyAdapter adapter(interface_str);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
||||
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
||||
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
|
||||
@ -158,7 +158,8 @@ TEST(GPUQuantile, MergeEmpty) {
|
||||
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
|
||||
&storage_0);
|
||||
data::CupyAdapter adapter_0(interface_str_0);
|
||||
AdapterDeviceSketch(adapter_0.Value(), n_bins,
|
||||
MetaInfo info;
|
||||
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
||||
|
||||
std::vector<SketchEntry> entries_before(sketch_0.Data().size());
|
||||
@ -197,7 +198,7 @@ TEST(GPUQuantile, MergeBasic) {
|
||||
.Seed(seed)
|
||||
.GenerateArrayInterface(&storage_0);
|
||||
data::CupyAdapter adapter_0(interface_str_0);
|
||||
AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
||||
|
||||
SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0);
|
||||
@ -207,7 +208,7 @@ TEST(GPUQuantile, MergeBasic) {
|
||||
.Seed(seed)
|
||||
.GenerateArrayInterface(&storage_1);
|
||||
data::CupyAdapter adapter_1(interface_str_1);
|
||||
AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter_1.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
|
||||
|
||||
size_t size_before_merge = sketch_0.Data().size();
|
||||
@ -243,7 +244,7 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
|
||||
.Seed(seed)
|
||||
.GenerateArrayInterface(&storage_0);
|
||||
data::CupyAdapter adapter_0(interface_str_0);
|
||||
AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_0);
|
||||
|
||||
@ -269,7 +270,7 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
|
||||
}
|
||||
});
|
||||
data::CupyAdapter adapter_1(interface_str_1);
|
||||
AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter_1.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_1);
|
||||
|
||||
@ -344,7 +345,7 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
.GenerateArrayInterface(&storage);
|
||||
data::CupyAdapter adapter(interface_str);
|
||||
containers.emplace_back(n_bins, kCols, kRows, 0);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&containers.back());
|
||||
}
|
||||
@ -367,7 +368,7 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
.Seed(rank + seed)
|
||||
.GenerateArrayInterface(&storage);
|
||||
data::CupyAdapter adapter(interface_str);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_distributed);
|
||||
sketch_distributed.AllReduce();
|
||||
@ -426,7 +427,7 @@ TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
.Seed(rank + seed)
|
||||
.GenerateArrayInterface(&storage);
|
||||
data::CupyAdapter adapter(interface_str);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_distributed);
|
||||
sketch_distributed.AllReduce();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user