Merge extract cuts into QuantileContainer. (#6125)
* Use pruning for initial summary construction.
This commit is contained in:
parent
cc82ca167a
commit
e319b63f9e
@ -33,79 +33,6 @@ namespace common {
|
||||
constexpr float SketchContainer::kFactor;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry const> sorted_data,
|
||||
Span<size_t const> column_sizes_scan,
|
||||
Span<SketchEntry> out_cuts) {
|
||||
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
|
||||
// Each thread is responsible for obtaining one cut from the sorted input
|
||||
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
|
||||
size_t column_size =
|
||||
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
|
||||
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
|
||||
size_t cut_idx = idx - cuts_ptr[column_idx];
|
||||
Span<Entry const> column_entries =
|
||||
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
|
||||
size_t rank = (column_entries.size() * cut_idx) /
|
||||
static_cast<float>(num_available_cuts);
|
||||
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
|
||||
column_entries[rank].fvalue);
|
||||
});
|
||||
}
|
||||
|
||||
void ExtractWeightedCutsSparse(int device,
|
||||
common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry> sorted_data,
|
||||
Span<float> weights_scan,
|
||||
Span<size_t> column_sizes_scan,
|
||||
Span<SketchEntry> cuts) {
|
||||
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
|
||||
// Each thread is responsible for obtaining one cut from the sorted input
|
||||
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
|
||||
size_t column_size =
|
||||
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
|
||||
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
|
||||
size_t cut_idx = idx - cuts_ptr[column_idx];
|
||||
|
||||
Span<Entry> column_entries =
|
||||
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
|
||||
|
||||
Span<float> column_weights_scan =
|
||||
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
|
||||
float total_column_weight = column_weights_scan.back();
|
||||
size_t sample_idx = 0;
|
||||
if (cut_idx == 0) {
|
||||
// First cut
|
||||
sample_idx = 0;
|
||||
} else if (cut_idx == num_available_cuts) {
|
||||
// Last cut
|
||||
sample_idx = column_entries.size() - 1;
|
||||
} else if (num_available_cuts == column_size) {
|
||||
// There are less samples available than our buffer
|
||||
// Take every available sample
|
||||
sample_idx = cut_idx;
|
||||
} else {
|
||||
bst_float rank = (total_column_weight * cut_idx) /
|
||||
static_cast<float>(num_available_cuts);
|
||||
sample_idx = thrust::upper_bound(thrust::seq,
|
||||
column_weights_scan.begin(),
|
||||
column_weights_scan.end(),
|
||||
rank) -
|
||||
column_weights_scan.begin();
|
||||
sample_idx =
|
||||
max(static_cast<size_t>(0),
|
||||
min(sample_idx, column_entries.size() - 1));
|
||||
}
|
||||
// repeated values will be filtered out later.
|
||||
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
|
||||
bst_float rmax = column_weights_scan[sample_idx];
|
||||
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
|
||||
column_entries[sample_idx].fvalue);
|
||||
});
|
||||
}
|
||||
|
||||
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
|
||||
double eps = 1.0 / (WQSketch::kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
@ -220,19 +147,16 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
|
||||
|
||||
// add cuts into sketches
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
|
||||
d_cuts_ptr, h_cuts_ptr.back());
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
CHECK_EQ(sorted_entries.capacity(), 0);
|
||||
CHECK_NE(cuts_ptr.Size(), 0);
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
@ -285,18 +209,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
// Extract cuts
|
||||
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(temp_weights),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
|
||||
// add cuts into sketches
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
}
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
|
||||
@ -28,37 +28,6 @@ struct EntryCompareOp {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Extracts the cuts from sorted data.
|
||||
*
|
||||
* \param device The device.
|
||||
* \param cuts_ptr Column pointers to CSC structured cuts
|
||||
* \param sorted_data Sorted entries in segments of columns
|
||||
* \param column_sizes_scan Describes the boundaries of column segments in sorted data
|
||||
* \param out_cuts Output cut values
|
||||
*/
|
||||
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry const> sorted_data,
|
||||
Span<size_t const> column_sizes_scan,
|
||||
Span<SketchEntry> out_cuts);
|
||||
|
||||
/**
|
||||
* \brief Extracts the cuts from sorted data, considering weights.
|
||||
*
|
||||
* \param device The device.
|
||||
* \param cuts_ptr Column pointers to CSC structured cuts
|
||||
* \param sorted_data Sorted entries in segments of columns.
|
||||
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
|
||||
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
|
||||
* \param cuts Output cuts.
|
||||
*/
|
||||
void ExtractWeightedCutsSparse(int device,
|
||||
common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry> sorted_data,
|
||||
Span<float> weights_scan,
|
||||
Span<size_t> column_sizes_scan,
|
||||
Span<SketchEntry> cuts);
|
||||
|
||||
// Get column size from adapter batch and for output cuts.
|
||||
template <typename Iter>
|
||||
void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature,
|
||||
@ -186,16 +155,12 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
// Extract the cuts from all columns concurrently
|
||||
detail::ExtractCutsSparse(device, d_cuts_ptr,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
h_cuts_ptr.back());
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
template <typename Batch>
|
||||
@ -263,16 +228,11 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
// Extract cuts
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(temp_weights),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
// add cuts into sketches
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -324,4 +284,4 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // COMMON_HIST_UTIL_CUH_
|
||||
#endif // COMMON_HIST_UTIL_CUH_
|
||||
|
||||
@ -24,29 +24,83 @@ using WQSketch = HostSketchContainer::WQSketch;
|
||||
using SketchEntry = WQSketch::Entry;
|
||||
|
||||
// Algorithm 4 in XGBoost's paper, using binary search to find i.
|
||||
__device__ SketchEntry BinarySearchQuery(Span<SketchEntry const> const& entries, float rank) {
|
||||
assert(entries.size() >= 2);
|
||||
template <typename EntryIter>
|
||||
__device__ SketchEntry BinarySearchQuery(EntryIter beg, EntryIter end, float rank) {
|
||||
assert(end - beg >= 2);
|
||||
rank *= 2;
|
||||
if (rank < entries.front().rmin + entries.front().rmax) {
|
||||
return entries.front();
|
||||
auto front = *beg;
|
||||
if (rank < front.rmin + front.rmax) {
|
||||
return *beg;
|
||||
}
|
||||
if (rank >= entries.back().rmin + entries.back().rmax) {
|
||||
return entries.back();
|
||||
auto back = *(end - 1);
|
||||
if (rank >= back.rmin + back.rmax) {
|
||||
return back;
|
||||
}
|
||||
|
||||
auto begin = dh::MakeTransformIterator<float>(
|
||||
entries.begin(), [=] __device__(SketchEntry const &entry) {
|
||||
auto search_begin = dh::MakeTransformIterator<float>(
|
||||
beg, [=] __device__(SketchEntry const &entry) {
|
||||
return entry.rmin + entry.rmax;
|
||||
});
|
||||
auto end = begin + entries.size();
|
||||
auto i = thrust::upper_bound(thrust::seq, begin + 1, end - 1, rank) - begin - 1;
|
||||
if (rank < entries[i].RMinNext() + entries[i+1].RMaxPrev()) {
|
||||
return entries[i];
|
||||
auto search_end = search_begin + (end - beg);
|
||||
auto i =
|
||||
thrust::upper_bound(thrust::seq, search_begin + 1, search_end - 1, rank) -
|
||||
search_begin - 1;
|
||||
if (rank < (*(beg + i)).RMinNext() + (*(beg + i + 1)).RMaxPrev()) {
|
||||
return *(beg + i);
|
||||
} else {
|
||||
return entries[i+1];
|
||||
return *(beg + i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InEntry, typename ToSketchEntry>
|
||||
void PruneImpl(int device,
|
||||
common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<InEntry const> sorted_data,
|
||||
Span<size_t const> columns_ptr_in, // could be ptr for data or cuts
|
||||
Span<SketchEntry> out_cuts,
|
||||
ToSketchEntry to_sketch_entry) {
|
||||
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
|
||||
size_t column_id = dh::SegmentId(cuts_ptr, idx);
|
||||
auto out_column = out_cuts.subspan(
|
||||
cuts_ptr[column_id], cuts_ptr[column_id + 1] - cuts_ptr[column_id]);
|
||||
auto in_column = sorted_data.subspan(columns_ptr_in[column_id],
|
||||
columns_ptr_in[column_id + 1] -
|
||||
columns_ptr_in[column_id]);
|
||||
auto to = cuts_ptr[column_id + 1] - cuts_ptr[column_id];
|
||||
idx -= cuts_ptr[column_id];
|
||||
auto front = to_sketch_entry(0ul, in_column, column_id);
|
||||
auto back = to_sketch_entry(in_column.size() - 1, in_column, column_id);
|
||||
|
||||
if (in_column.size() <= to) {
|
||||
// cut idx equals sample idx
|
||||
out_column[idx] = to_sketch_entry(idx, in_column, column_id);
|
||||
return;
|
||||
}
|
||||
// 1 thread for each output. See A.4 for detail.
|
||||
auto d_out = out_column;
|
||||
if (idx == 0) {
|
||||
d_out.front() = front;
|
||||
return;
|
||||
}
|
||||
if (idx == to - 1) {
|
||||
d_out.back() = back;
|
||||
return;
|
||||
}
|
||||
|
||||
float w = back.rmin - front.rmax;
|
||||
assert(w != 0);
|
||||
auto budget = static_cast<float>(d_out.size());
|
||||
assert(budget != 0);
|
||||
auto q = ((static_cast<float>(idx) * w) / (static_cast<float>(to) - 1.0f) + front.rmax);
|
||||
auto it = dh::MakeTransformIterator<SketchEntry>(
|
||||
thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) {
|
||||
auto e = to_sketch_entry(idx, in_column, column_id);
|
||||
return e;
|
||||
});
|
||||
d_out[idx] = BinarySearchQuery(it, it + in_column.size(), q);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTo(Span<T> out, Span<T const> src) {
|
||||
CHECK_EQ(out.size(), src.size());
|
||||
@ -249,27 +303,58 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
|
||||
});
|
||||
}
|
||||
|
||||
void SketchContainer::Push(common::Span<OffsetT const> cuts_ptr,
|
||||
dh::caching_device_vector<SketchEntry>* entries) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
// Copy or merge the new cuts, pruning is performed during `MakeCuts`.
|
||||
if (this->Current().size() == 0) {
|
||||
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
|
||||
// See thrust issue 1030, THRUST_CPP_DIALECT is not correctly defined so
|
||||
// move constructor is not used.
|
||||
this->Current().swap(*entries);
|
||||
CHECK_EQ(entries->size(), 0);
|
||||
auto d_cuts_ptr = this->columns_ptr_.DevicePointer();
|
||||
thrust::copy(thrust::device, cuts_ptr.data(),
|
||||
cuts_ptr.data() + cuts_ptr.size(), d_cuts_ptr);
|
||||
void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
common::Span<OffsetT const> cuts_ptr,
|
||||
size_t total_cuts, Span<float> weights) {
|
||||
Span<SketchEntry> out;
|
||||
dh::caching_device_vector<SketchEntry> cuts;
|
||||
bool first_window = this->Current().empty();
|
||||
if (!first_window) {
|
||||
cuts.resize(total_cuts);
|
||||
out = dh::ToSpan(cuts);
|
||||
} else {
|
||||
auto d_entries = dh::ToSpan(*entries);
|
||||
this->Merge(cuts_ptr, d_entries);
|
||||
this->FixError();
|
||||
this->Current().resize(total_cuts);
|
||||
out = dh::ToSpan(this->Current());
|
||||
}
|
||||
|
||||
if (weights.empty()) {
|
||||
auto to_sketch_entry = [] __device__(size_t sample_idx,
|
||||
Span<Entry const> const &column,
|
||||
size_t) {
|
||||
float rmin = sample_idx;
|
||||
float rmax = sample_idx + 1;
|
||||
return SketchEntry{rmin, rmax, 1, column[sample_idx].fvalue};
|
||||
}; // NOLINT
|
||||
PruneImpl<Entry>(device_, cuts_ptr, entries, columns_ptr, out,
|
||||
to_sketch_entry);
|
||||
} else {
|
||||
auto to_sketch_entry = [weights, columns_ptr] __device__(
|
||||
size_t sample_idx,
|
||||
Span<Entry const> const &column,
|
||||
size_t column_id) {
|
||||
Span<float const> column_weights_scan =
|
||||
weights.subspan(columns_ptr[column_id], column.size());
|
||||
float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
|
||||
float rmax = column_weights_scan[sample_idx];
|
||||
float wmin = rmax - rmin;
|
||||
wmin = wmin < 0 ? kRtEps : wmin; // GPU scan can generate floating error.
|
||||
return SketchEntry{rmin, rmax, wmin, column[sample_idx].fvalue};
|
||||
}; // NOLINT
|
||||
PruneImpl<Entry>(device_, cuts_ptr, entries, columns_ptr, out,
|
||||
to_sketch_entry);
|
||||
}
|
||||
|
||||
if (!first_window) {
|
||||
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
|
||||
this->Merge(cuts_ptr, out);
|
||||
this->FixError();
|
||||
} else {
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(cuts_ptr.size());
|
||||
|
||||
auto d_cuts_ptr = this->columns_ptr_.DeviceSpan();
|
||||
CopyTo(d_cuts_ptr, cuts_ptr);
|
||||
}
|
||||
CHECK_NE(this->columns_ptr_.Size(), 0);
|
||||
timer_.Stop(__func__);
|
||||
}
|
||||
|
||||
size_t SketchContainer::Unique() {
|
||||
@ -317,41 +402,11 @@ void SketchContainer::Prune(size_t to) {
|
||||
auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan();
|
||||
auto out = dh::ToSpan(this->Other());
|
||||
auto in = dh::ToSpan(this->Current());
|
||||
dh::LaunchN(0, to_total, [=] __device__(size_t idx) {
|
||||
size_t column_id = dh::SegmentId(d_columns_ptr_out, idx);
|
||||
auto out_column = out.subspan(d_columns_ptr_out[column_id],
|
||||
d_columns_ptr_out[column_id + 1] -
|
||||
d_columns_ptr_out[column_id]);
|
||||
auto in_column = in.subspan(d_columns_ptr_in[column_id],
|
||||
d_columns_ptr_in[column_id + 1] -
|
||||
d_columns_ptr_in[column_id]);
|
||||
idx -= d_columns_ptr_out[column_id];
|
||||
// Input has lesser columns than `to`, just copy them to the output. This is correct
|
||||
// as the new output size is calculated based on both the size of `to` and current
|
||||
// column.
|
||||
if (in_column.size() <= to) {
|
||||
out_column[idx] = in_column[idx];
|
||||
return;
|
||||
}
|
||||
// 1 thread for each output. See A.4 for detail.
|
||||
auto entries = in_column;
|
||||
auto d_out = out_column;
|
||||
if (idx == 0) {
|
||||
d_out.front() = entries.front();
|
||||
return;
|
||||
}
|
||||
if (idx == to - 1) {
|
||||
d_out.back() = entries.back();
|
||||
return;
|
||||
}
|
||||
|
||||
float w = entries.back().rmin - entries.front().rmax;
|
||||
assert(w != 0);
|
||||
auto budget = static_cast<float>(d_out.size());
|
||||
assert(budget != 0);
|
||||
auto q = ((idx * w) / (to - 1) + entries.front().rmax);
|
||||
d_out[idx] = BinarySearchQuery(entries, q);
|
||||
});
|
||||
auto no_op = [] __device__(size_t sample_idx,
|
||||
Span<SketchEntry const> const &entries,
|
||||
size_t) { return entries[sample_idx]; }; // NOLINT
|
||||
PruneImpl<SketchEntry>(device_, d_columns_ptr_out, in, d_columns_ptr_in, out,
|
||||
no_op);
|
||||
this->columns_ptr_.HostVector() = new_columns_ptr.HostVector();
|
||||
this->Alternate();
|
||||
timer_.Stop(__func__);
|
||||
|
||||
@ -96,13 +96,21 @@ class SketchContainer {
|
||||
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
|
||||
void FixError();
|
||||
|
||||
/* \brief Push a CSC structured cut matrix. */
|
||||
void Push(common::Span<OffsetT const> cuts_ptr,
|
||||
dh::caching_device_vector<SketchEntry>* entries);
|
||||
/* \brief Push sorted entries.
|
||||
*
|
||||
* \param entries Sorted entries.
|
||||
* \param columns_ptr CSC pointer for entries.
|
||||
* \param cuts_ptr CSC pointer for cuts.
|
||||
* \param total_cuts Total number of cuts, equal to the back of cuts_ptr.
|
||||
* \param weights (optional) data weights.
|
||||
*/
|
||||
void Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
common::Span<OffsetT const> cuts_ptr, size_t total_cuts,
|
||||
Span<float> weights = {});
|
||||
/* \brief Prune the quantile structure.
|
||||
*
|
||||
* \param to The maximum size of pruned quantile. If the size of quantile structure is
|
||||
* already less than `to`, then no operation is performed.
|
||||
* \param to The maximum size of pruned quantile. If the size of quantile
|
||||
* structure is already less than `to`, then no operation is performed.
|
||||
*/
|
||||
void Prune(size_t to);
|
||||
/* \brief Merge another set of sketch.
|
||||
@ -135,8 +143,8 @@ struct SketchUnique {
|
||||
return a.value - b.value == 0;
|
||||
}
|
||||
};
|
||||
} // anonymous detail
|
||||
} // namespace detail
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_COMMON_QUANTILE_CUH_
|
||||
#endif // XGBOOST_COMMON_QUANTILE_CUH_
|
||||
|
||||
@ -29,6 +29,24 @@ TEST(AtomicAdd, SizeT) {
|
||||
TestAtomicSizeT();
|
||||
}
|
||||
|
||||
void TestSegmentID() {
|
||||
std::vector<size_t> segments{0, 1, 3};
|
||||
thrust::device_vector<size_t> d_segments(segments);
|
||||
auto s_segments = dh::ToSpan(d_segments);
|
||||
dh::LaunchN(0, 1, [=]__device__(size_t idx) {
|
||||
auto id = dh::SegmentId(s_segments, 0);
|
||||
SPAN_CHECK(id == 0);
|
||||
id = dh::SegmentId(s_segments, 1);
|
||||
SPAN_CHECK(id == 1);
|
||||
id = dh::SegmentId(s_segments, 2);
|
||||
SPAN_CHECK(id == 1);
|
||||
});
|
||||
}
|
||||
|
||||
TEST(SegmentID, Basic) {
|
||||
TestSegmentID();
|
||||
}
|
||||
|
||||
TEST(SegmentedUnique, Basic) {
|
||||
std::vector<float> values{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.62448811531066895f, 0.4f};
|
||||
std::vector<size_t> segments{0, 3, 6};
|
||||
|
||||
@ -9,11 +9,11 @@ namespace common {
|
||||
TEST(GPUQuantile, Basic) {
|
||||
constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
|
||||
SketchContainer sketch(kBins, kCols, kRows, 0);
|
||||
dh::caching_device_vector<SketchEntry> entries;
|
||||
dh::caching_device_vector<Entry> entries;
|
||||
dh::device_vector<bst_row_t> cuts_ptr(kCols+1);
|
||||
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
|
||||
// Push empty
|
||||
sketch.Push(dh::ToSpan(cuts_ptr), &entries);
|
||||
sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0);
|
||||
ASSERT_EQ(sketch.Data().size(), 0);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user