Support categorical data for dask functional interface and DQM. (#7043)
* Support categorical data for dask functional interface and DQM.
* Implement categorical data support for GPU GK-merge.
* Add support for dask functional interface.
* Add support for DQM.
* Get newer cupy.
This commit is contained in:
@@ -16,6 +16,19 @@ class HistogramCuts;
|
||||
using WQSketch = WQuantileSketch<bst_float, bst_float>;
|
||||
using SketchEntry = WQSketch::Entry;
|
||||
|
||||
namespace detail {
|
||||
struct IsCatOp {
|
||||
XGBOOST_DEVICE bool operator()(FeatureType ft) {
|
||||
return ft == FeatureType::kCategorical;
|
||||
}
|
||||
};
|
||||
struct SketchUnique {
|
||||
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
|
||||
return a.value - b.value == 0;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/*!
|
||||
* \brief A container that holds the device sketches. Sketching is performed per-column,
|
||||
* but fused into a single operation for performance.
|
||||
@@ -43,6 +56,8 @@ class SketchContainer {
|
||||
HostDeviceVector<OffsetT> columns_ptr_;
|
||||
HostDeviceVector<OffsetT> columns_ptr_b_;
|
||||
|
||||
bool has_categorical_{false};
|
||||
|
||||
dh::device_vector<SketchEntry>& Current() {
|
||||
if (current_buffer_) {
|
||||
return entries_a_;
|
||||
@@ -102,14 +117,21 @@ class SketchContainer {
|
||||
this->feature_types_.SetDevice(device);
|
||||
this->feature_types_.ConstDeviceSpan();
|
||||
this->feature_types_.ConstHostSpan();
|
||||
|
||||
auto d_feature_types = feature_types_.ConstDeviceSpan();
|
||||
has_categorical_ =
|
||||
!d_feature_types.empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||
detail::IsCatOp{});
|
||||
|
||||
timer_.Init(__func__);
|
||||
}
|
||||
/* \brief Return GPU ID for this container. */
|
||||
int32_t DeviceIdx() const { return device_; }
|
||||
/* \brief Whether the predictor matrix contains categorical features. */
|
||||
bool HasCategorical() const { return has_categorical_; }
|
||||
/* \brief Accumulate weights of duplicated entries in input. */
|
||||
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
|
||||
/* \brief Removes all the duplicated elements in quantile structure. */
|
||||
size_t Unique();
|
||||
/* Fix rounding errors and re-establish the invariants. The error is mostly generated by the
|
||||
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
|
||||
void FixError();
|
||||
@@ -154,15 +176,35 @@ class SketchContainer {
|
||||
|
||||
SketchContainer(const SketchContainer&) = delete;
|
||||
SketchContainer& operator=(const SketchContainer&) = delete;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
struct SketchUnique {
|
||||
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
|
||||
return a.value - b.value == 0;
|
||||
/* \brief Removes all the duplicated elements in quantile structure. */
|
||||
template <typename KeyComp = thrust::equal_to<size_t>>
|
||||
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
||||
Span<SketchEntry> entries = dh::ToSpan(this->Current());
|
||||
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
|
||||
scan_out.SetDevice(device_);
|
||||
auto d_scan_out = scan_out.DeviceSpan();
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||
size_t n_uniques = dh::SegmentedUnique(
|
||||
thrust::cuda::par(alloc), d_column_scan.data(),
|
||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||
entries.data(), detail::SketchUnique{}, key_comp);
|
||||
this->columns_ptr_.Copy(scan_out);
|
||||
CHECK(!this->columns_ptr_.HostCanRead());
|
||||
|
||||
this->Current().resize(n_uniques);
|
||||
timer_.Stop(__func__);
|
||||
return n_uniques;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
Reference in New Issue
Block a user