Handle duplicated values in sketching. (#6178)
* Accumulate weights in duplicated values. * Fix device id in iterative dmatrix.
This commit is contained in:
parent
ab5b35134f
commit
2241563f23
@ -27,6 +27,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
@ -1059,14 +1060,14 @@ struct SegmentedUniqueReduceOp {
|
||||
*
|
||||
* \return Number of unique values in total.
|
||||
*/
|
||||
template <typename KeyInIt, typename KeyOutIt, typename ValInIt,
|
||||
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
|
||||
typename ValOutIt, typename Comp>
|
||||
size_t
|
||||
SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
|
||||
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
|
||||
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
|
||||
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
|
||||
Comp comp) {
|
||||
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto unique_key_it = dh::MakeTransformIterator<Key>(
|
||||
thrust::make_counting_iterator(static_cast<size_t>(0)),
|
||||
[=] __device__(size_t i) {
|
||||
@ -1083,7 +1084,7 @@ SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt v
|
||||
thrust::make_discard_iterator(),
|
||||
detail::SegmentedUniqueReduceOp<Key, KeyOutIt>{key_segments_out});
|
||||
auto uniques_ret = thrust::unique_by_key_copy(
|
||||
thrust::cuda::par(alloc), unique_key_it, unique_key_it + n_inputs,
|
||||
exec, unique_key_it, unique_key_it + n_inputs,
|
||||
val_first, reduce_it, val_out,
|
||||
[=] __device__(Key const &l, Key const &r) {
|
||||
if (l.first == r.first) {
|
||||
@ -1094,8 +1095,16 @@ SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt v
|
||||
});
|
||||
auto n_uniques = uniques_ret.second - val_out;
|
||||
CHECK_LE(n_uniques, n_inputs);
|
||||
thrust::exclusive_scan(thrust::cuda::par(alloc), key_segments_out,
|
||||
thrust::exclusive_scan(exec, key_segments_out,
|
||||
key_segments_out + segments_len, key_segments_out, 0);
|
||||
return n_uniques;
|
||||
}
|
||||
|
||||
template <typename... Inputs,
|
||||
std::enable_if_t<std::tuple_size<std::tuple<Inputs...>>::value == 7>
|
||||
* = nullptr>
|
||||
size_t SegmentedUnique(Inputs &&...inputs) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
|
||||
}
|
||||
} // namespace dh
|
||||
|
||||
@ -269,7 +269,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
|
||||
// Extract cuts
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
|
||||
@ -153,7 +153,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
// Extract the cuts from all columns concurrently
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
@ -224,7 +224,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
|
||||
// Extract cuts
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
|
||||
@ -104,9 +104,10 @@ void PruneImpl(int device,
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTo(Span<T> out, Span<T const> src) {
|
||||
template <typename T, typename U>
|
||||
void CopyTo(Span<T> out, Span<U> src) {
|
||||
CHECK_EQ(out.size(), src.size());
|
||||
static_assert(std::is_same<std::remove_cv_t<T>, std::remove_cv_t<T>>::value, "");
|
||||
dh::safe_cuda(cudaMemcpyAsync(out.data(), src.data(),
|
||||
out.size_bytes(),
|
||||
cudaMemcpyDefault));
|
||||
@ -307,7 +308,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
|
||||
}
|
||||
|
||||
void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
common::Span<OffsetT const> cuts_ptr,
|
||||
common::Span<OffsetT> cuts_ptr,
|
||||
size_t total_cuts, Span<float> weights) {
|
||||
Span<SketchEntry> out;
|
||||
dh::device_vector<SketchEntry> cuts;
|
||||
@ -346,12 +347,15 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
PruneImpl<Entry>(device_, cuts_ptr, entries, columns_ptr, ft, out,
|
||||
to_sketch_entry);
|
||||
}
|
||||
auto n_uniques = this->ScanInput(out, cuts_ptr);
|
||||
|
||||
if (!first_window) {
|
||||
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
|
||||
out = out.subspan(0, n_uniques);
|
||||
this->Merge(cuts_ptr, out);
|
||||
this->FixError();
|
||||
} else {
|
||||
this->Current().resize(n_uniques);
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(cuts_ptr.size());
|
||||
|
||||
@ -360,6 +364,49 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
}
|
||||
}
|
||||
|
||||
size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in) {
|
||||
/* There are 2 types of duplication. First is duplicated feature values, which comes
|
||||
* from user input data. Second is duplicated sketching entries, which is generated by
|
||||
* prunning or merging. We preserve the first type and remove the second type.
|
||||
*/
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
auto key_it = dh::MakeTransformIterator<size_t>(
|
||||
thrust::make_reverse_iterator(thrust::make_counting_iterator(entries.size())),
|
||||
[=] __device__(size_t idx) {
|
||||
return dh::SegmentId(d_columns_ptr_in, idx);
|
||||
});
|
||||
// Reverse scan to accumulate weights into first duplicated element on left.
|
||||
auto val_it = thrust::make_reverse_iterator(dh::tend(entries));
|
||||
thrust::inclusive_scan_by_key(
|
||||
thrust::cuda::par(alloc), key_it, key_it + entries.size(),
|
||||
val_it, val_it,
|
||||
thrust::equal_to<size_t>{},
|
||||
[] __device__(SketchEntry const &r, SketchEntry const &l) {
|
||||
// Only accumulate for the first type of duplication.
|
||||
if (l.value - r.value == 0 && l.rmin - r.rmin != 0) {
|
||||
auto w = l.wmin + r.wmin;
|
||||
SketchEntry v{l.rmin, l.rmin + w, w, l.value};
|
||||
return v;
|
||||
}
|
||||
return l;
|
||||
});
|
||||
|
||||
auto d_columns_ptr_out = columns_ptr_b_.DeviceSpan();
|
||||
// thrust unique_by_key preserves the first element.
|
||||
auto n_uniques = dh::SegmentedUnique(
|
||||
d_columns_ptr_in.data(),
|
||||
d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(),
|
||||
entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(),
|
||||
detail::SketchUnique{});
|
||||
CopyTo(d_columns_ptr_in, d_columns_ptr_out);
|
||||
|
||||
timer_.Stop(__func__);
|
||||
return n_uniques;
|
||||
}
|
||||
|
||||
size_t SketchContainer::Unique() {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
@ -389,7 +436,6 @@ void SketchContainer::Prune(size_t to) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
|
||||
this->Unique();
|
||||
OffsetT to_total = 0;
|
||||
auto& h_columns_ptr = columns_ptr_b_.HostVector();
|
||||
h_columns_ptr[0] = to_total;
|
||||
@ -417,6 +463,8 @@ void SketchContainer::Prune(size_t to) {
|
||||
out, no_op);
|
||||
this->columns_ptr_.Copy(columns_ptr_b_);
|
||||
this->Alternate();
|
||||
|
||||
this->Unique();
|
||||
timer_.Stop(__func__);
|
||||
}
|
||||
|
||||
@ -447,6 +495,7 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
|
||||
this->columns_ptr_.Copy(columns_ptr_b_);
|
||||
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
|
||||
this->Alternate();
|
||||
|
||||
timer_.Stop(__func__);
|
||||
}
|
||||
|
||||
@ -558,7 +607,6 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
|
||||
// Prune to final number of bins.
|
||||
this->Prune(num_bins_ + 1);
|
||||
this->Unique();
|
||||
this->FixError();
|
||||
|
||||
// Set up inputs
|
||||
|
||||
@ -106,6 +106,8 @@ class SketchContainer {
|
||||
}
|
||||
/* \brief Return GPU ID for this container. */
|
||||
int32_t DeviceIdx() const { return device_; }
|
||||
/* \brief Accumulate weights of duplicated entries in input. */
|
||||
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
|
||||
/* \brief Removes all the duplicated elements in quantile structure. */
|
||||
size_t Unique();
|
||||
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
|
||||
@ -121,7 +123,7 @@ class SketchContainer {
|
||||
* \param weights (optional) data weights.
|
||||
*/
|
||||
void Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
common::Span<OffsetT const> cuts_ptr, size_t total_cuts,
|
||||
common::Span<OffsetT> cuts_ptr, size_t total_cuts,
|
||||
Span<float> weights = {});
|
||||
/* \brief Prune the quantile structure.
|
||||
*
|
||||
|
||||
@ -63,15 +63,17 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
size_t accumulated_rows = 0;
|
||||
bst_feature_t cols = 0;
|
||||
int32_t device = GenericParameter::kCpuId;
|
||||
int32_t current_device_;
|
||||
dh::safe_cuda(cudaGetDevice(¤t_device_));
|
||||
int32_t current_device;
|
||||
dh::safe_cuda(cudaGetDevice(¤t_device));
|
||||
auto get_device = [&]() -> int32_t {
|
||||
int32_t d = GenericParameter::kCpuId ? current_device_ : device;
|
||||
int32_t d = (device == GenericParameter::kCpuId) ? current_device : device;
|
||||
CHECK_NE(d, GenericParameter::kCpuId);
|
||||
return d;
|
||||
};
|
||||
|
||||
while (iter.Next()) {
|
||||
device = proxy->DeviceIdx();
|
||||
CHECK_LT(device, common::AllVisibleGPUs());
|
||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
|
||||
@ -66,6 +66,9 @@ class DMatrixProxy : public DMatrix {
|
||||
} else {
|
||||
this->FromCudaArray(interface_str);
|
||||
}
|
||||
if (this->info_.num_row_ == 0) {
|
||||
this->device_ = GenericParameter::kCpuId;
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
|
||||
@ -314,10 +314,9 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto cuts = MakeUnweightedCutsForTest(adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
size_t bytes_constant = 1000;
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, false);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
|
||||
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
|
||||
}
|
||||
|
||||
|
||||
@ -45,12 +45,11 @@ void TestSketchUnique(float sparsity) {
|
||||
detail::GetColumnSizesScan(0, kCols, n_cuts, batch_iter, is_valid, 0, end,
|
||||
&cut_sizes_scan, &column_sizes_scan);
|
||||
auto const& cut_sizes = cut_sizes_scan.HostVector();
|
||||
ASSERT_LE(sketch.Data().size(), cut_sizes.back());
|
||||
|
||||
if (sparsity == 0) {
|
||||
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
|
||||
} else {
|
||||
ASSERT_EQ(sketch.Data().size(), cut_sizes.back());
|
||||
}
|
||||
std::vector<size_t> h_columns_ptr(sketch.ColumnsPtr().size());
|
||||
dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch.ColumnsPtr());
|
||||
ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back());
|
||||
|
||||
sketch.Unique();
|
||||
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
|
||||
@ -67,31 +66,36 @@ TEST(GPUQuantile, Unique) {
|
||||
// if with_error is true, the test tolerates floating point error
|
||||
void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
|
||||
Span<bst_row_t const> d_columns_ptr, bool with_error = false) {
|
||||
dh::LaunchN(device, in.size(), [=]XGBOOST_DEVICE(size_t idx) {
|
||||
auto column_id = dh::SegmentId(d_columns_ptr, idx);
|
||||
auto in_column = in.subspan(d_columns_ptr[column_id],
|
||||
d_columns_ptr[column_id + 1] -
|
||||
d_columns_ptr[column_id]);
|
||||
auto constexpr kEps = 1e-6f;
|
||||
idx -= d_columns_ptr[column_id];
|
||||
float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin;
|
||||
float prev_rmax = idx == 0 ? 0.0f : in_column[idx-1].rmax;
|
||||
std::vector<SketchEntry> h_in(in.size());
|
||||
dh::CopyDeviceSpanToVector(&h_in, in);
|
||||
std::vector<bst_row_t> h_columns_ptr(d_columns_ptr.size());
|
||||
dh::CopyDeviceSpanToVector(&h_columns_ptr, d_columns_ptr);
|
||||
|
||||
for (size_t i = 1; i < d_columns_ptr.size(); ++i) {
|
||||
auto column_id = i - 1;
|
||||
auto beg = h_columns_ptr[column_id];
|
||||
auto end = h_columns_ptr[i];
|
||||
|
||||
auto in_column = Span<SketchEntry>{h_in}.subspan(beg, end - beg);
|
||||
for (size_t idx = 1; idx < in_column.size(); ++idx) {
|
||||
float prev_rmin = in_column[idx - 1].rmin;
|
||||
float prev_rmax = in_column[idx - 1].rmax;
|
||||
float rmin_next = in_column[idx].RMinNext();
|
||||
|
||||
if (with_error) {
|
||||
SPAN_CHECK(in_column[idx].rmin + in_column[idx].rmin * kEps >= prev_rmin);
|
||||
SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= prev_rmax);
|
||||
SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= rmin_next);
|
||||
ASSERT_GE(in_column[idx].rmin + in_column[idx].rmin * kRtEps,
|
||||
prev_rmin);
|
||||
ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps,
|
||||
prev_rmax);
|
||||
ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps,
|
||||
rmin_next);
|
||||
} else {
|
||||
SPAN_CHECK(in_column[idx].rmin >= prev_rmin);
|
||||
SPAN_CHECK(in_column[idx].rmax >= prev_rmax);
|
||||
SPAN_CHECK(in_column[idx].rmax >= rmin_next);
|
||||
ASSERT_GE(in_column[idx].rmin, prev_rmin);
|
||||
ASSERT_GE(in_column[idx].rmax, prev_rmax);
|
||||
ASSERT_GE(in_column[idx].rmax, rmin_next);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
// Force sync to terminate current test instead of a later one.
|
||||
dh::DebugSyncDevice(__FILE__, __LINE__);
|
||||
}
|
||||
|
||||
|
||||
TEST(GPUQuantile, Prune) {
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
@ -108,16 +112,12 @@ TEST(GPUQuantile, Prune) {
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
||||
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
||||
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
|
||||
|
||||
sketch.Prune(n_bins);
|
||||
if (n_bins <= kRows) {
|
||||
ASSERT_EQ(sketch.Data().size(), n_bins * kCols);
|
||||
} else {
|
||||
// LE because kRows * kCols is pushed into sketch, after removing
|
||||
// duplicated entries we might not have that much inputs for prune.
|
||||
ASSERT_LE(sketch.Data().size(), n_cuts * kCols);
|
||||
|
||||
sketch.Prune(n_bins);
|
||||
ASSERT_LE(sketch.Data().size(), kRows * kCols);
|
||||
}
|
||||
// This is not necessarily true for all inputs without calling unique after
|
||||
// prune.
|
||||
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
|
||||
@ -278,6 +278,45 @@ TEST(GPUQuantile, MergeDuplicated) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUQuantile, MultiMerge) {
|
||||
constexpr size_t kRows = 20, kCols = 1;
|
||||
int32_t world = 2;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
|
||||
MetaInfo const &info) {
|
||||
// Set up single node version
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0);
|
||||
|
||||
size_t intermediate_num_cuts = std::min(
|
||||
kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
|
||||
std::vector<SketchContainer> containers;
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
HostDeviceVector<float> storage;
|
||||
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||
.Device(0)
|
||||
.Seed(rank + seed)
|
||||
.GenerateArrayInterface(&storage);
|
||||
data::CupyAdapter adapter(interface_str);
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
containers.emplace_back(ft, n_bins, kCols, kRows, 0);
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&containers.back());
|
||||
}
|
||||
for (auto &sketch : containers) {
|
||||
sketch.Prune(intermediate_num_cuts);
|
||||
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||
sketch_on_single_node.FixError();
|
||||
}
|
||||
TestQuantileElemRank(0, sketch_on_single_node.Data(),
|
||||
sketch_on_single_node.ColumnsPtr());
|
||||
|
||||
sketch_on_single_node.Unique();
|
||||
TestQuantileElemRank(0, sketch_on_single_node.Data(),
|
||||
sketch_on_single_node.ColumnsPtr());
|
||||
});
|
||||
}
|
||||
|
||||
TEST(GPUQuantile, AllReduceBasic) {
|
||||
// This test is supposed to run by a python test that setups the environment.
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
@ -442,5 +481,99 @@ TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
return;
|
||||
#endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL)
|
||||
}
|
||||
|
||||
TEST(GPUQuantile, Push) {
|
||||
size_t constexpr kRows = 100;
|
||||
std::vector<float> data(kRows);
|
||||
|
||||
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
|
||||
std::fill(data.begin() + (data.size() / 2), data.end(), 0.5f);
|
||||
int32_t n_bins = 128;
|
||||
bst_feature_t constexpr kCols = 1;
|
||||
|
||||
std::vector<Entry> entries(kRows);
|
||||
for (bst_feature_t i = 0; i < entries.size(); ++i) {
|
||||
Entry e{i, data[i]};
|
||||
entries[i] = e;
|
||||
}
|
||||
|
||||
dh::device_vector<Entry> d_entries(entries);
|
||||
dh::device_vector<size_t> columns_ptr(2);
|
||||
columns_ptr[0] = 0;
|
||||
columns_ptr[1] = kRows;
|
||||
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch(ft, n_bins, kCols, kRows, 0);
|
||||
sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
|
||||
|
||||
auto sketch_data = sketch.Data();
|
||||
|
||||
thrust::host_vector<SketchEntry> h_sketch_data(sketch_data.size());
|
||||
|
||||
auto ptr = thrust::device_ptr<SketchEntry const>(sketch_data.data());
|
||||
thrust::copy(ptr, ptr + sketch_data.size(), h_sketch_data.begin());
|
||||
ASSERT_EQ(h_sketch_data.size(), 2);
|
||||
|
||||
auto v_0 = h_sketch_data[0];
|
||||
ASSERT_EQ(v_0.rmin, 0);
|
||||
ASSERT_EQ(v_0.wmin, kRows / 2.0f);
|
||||
ASSERT_EQ(v_0.rmax, kRows / 2.0f);
|
||||
|
||||
auto v_1 = h_sketch_data[1];
|
||||
ASSERT_EQ(v_1.rmin, kRows / 2.0f);
|
||||
ASSERT_EQ(v_1.wmin, kRows / 2.0f);
|
||||
ASSERT_EQ(v_1.rmax, static_cast<float>(kRows));
|
||||
}
|
||||
|
||||
TEST(GPUQuantile, MultiColPush) {
|
||||
size_t constexpr kRows = 100, kCols = 4;
|
||||
std::vector<float> data(kRows * kCols);
|
||||
|
||||
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
|
||||
|
||||
std::vector<Entry> entries(kRows * kCols);
|
||||
|
||||
for (bst_feature_t c = 0; c < kCols; ++c) {
|
||||
for (size_t r = 0; r < kRows; ++r) {
|
||||
float v = (r >= kRows / 2) ? 0.7 : 0.4;
|
||||
auto e = Entry{c, v};
|
||||
entries[c * kRows + r] = e;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t n_bins = 16;
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch(ft, n_bins, kCols, kRows, 0);
|
||||
dh::device_vector<Entry> d_entries {entries};
|
||||
|
||||
dh::device_vector<size_t> columns_ptr(kCols + 1, 0);
|
||||
for (size_t i = 1; i < kCols + 1; ++i) {
|
||||
columns_ptr[i] = kRows;
|
||||
}
|
||||
thrust::inclusive_scan(thrust::device, columns_ptr.begin(), columns_ptr.end(),
|
||||
columns_ptr.begin());
|
||||
dh::device_vector<size_t> cuts_ptr(columns_ptr);
|
||||
|
||||
sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr),
|
||||
dh::ToSpan(cuts_ptr), kRows * kCols, {});
|
||||
|
||||
auto sketch_data = sketch.Data();
|
||||
ASSERT_EQ(sketch_data.size(), kCols * 2);
|
||||
auto ptr = thrust::device_ptr<SketchEntry const>(sketch_data.data());
|
||||
std::vector<SketchEntry> h_sketch_data(sketch_data.size());
|
||||
thrust::copy(ptr, ptr + sketch_data.size(), h_sketch_data.begin());
|
||||
|
||||
for (size_t i = 0; i < kCols; ++i) {
|
||||
auto v_0 = h_sketch_data[i * 2];
|
||||
ASSERT_EQ(v_0.rmin, 0);
|
||||
ASSERT_EQ(v_0.wmin, kRows / 2.0f);
|
||||
ASSERT_EQ(v_0.rmax, kRows / 2.0f);
|
||||
|
||||
auto v_1 = h_sketch_data[i * 2 + 1];
|
||||
ASSERT_EQ(v_1.rmin, kRows / 2.0f);
|
||||
ASSERT_EQ(v_1.wmin, kRows / 2.0f);
|
||||
ASSERT_EQ(v_1.rmax, static_cast<float>(kRows));
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user