Support categorical data for dask functional interface and DQM. (#7043)
* Support categorical data for dask functional interface and DQM.
* Implement categorical data support for GPU GK-merge.
* Add support for dask functional interface.
* Add support for DQM.
* Get newer cupy.
commit 86715e4cd4
parent 7dd29ffd47
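What this buys users, in one sketch: both the DeviceQuantileDMatrix (DQM) path and the dask interface now accept `enable_categorical=True` directly. A minimal illustration of the single-GPU path, assuming a CUDA machine with cudf installed (the data and column names here are made up for the example):

    import cudf
    import xgboost as xgb

    # A frame with a "category"-typed column; cudf mirrors the pandas dtype.
    X = cudf.DataFrame({
        "c": cudf.Series(["a", "b", "a"], dtype="category"),
        "x": [0.1, 0.2, 0.3],
    })
    y = cudf.Series([0.0, 1.0, 0.0])

    # DeviceQuantileDMatrix forwards the flag to its internal data iterator.
    m = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
    booster = xgb.train({"tree_method": "gpu_hist"}, m, num_boost_round=10)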
@@ -321,6 +321,7 @@ class DataIter:
     def __init__(self):
         self._handle = _ProxyDMatrix()
         self.exception = None
+        self.enable_categorical = False

     @property
     def proxy(self):
@@ -346,13 +347,12 @@ class DataIter:
             data,
             feature_names=None,
             feature_types=None,
-            enable_categorical=False,
             **kwargs
         ):
             from .data import dispatch_device_quantile_dmatrix_set_data
             from .data import _device_quantile_transform
             data, feature_names, feature_types = _device_quantile_transform(
-                data, feature_names, feature_types, enable_categorical,
+                data, feature_names, feature_types, self.enable_categorical,
             )
             dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
             self.proxy.set_info(
@@ -1106,15 +1106,10 @@ class DeviceQuantileDMatrix(DMatrix):
         data = _transform_dlpack(data)
         if _is_iter(data):
             it = data
-            if enable_categorical:
-                raise NotImplementedError(
-                    "categorical support is not enabled on data iterator."
-                )
         else:
-            it = SingleBatchInternalIter(
-                data=data, enable_categorical=enable_categorical, **meta
-            )
+            it = SingleBatchInternalIter(data=data, **meta)

+        it.enable_categorical = enable_categorical
         reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
         next_callback = ctypes.CFUNCTYPE(
             ctypes.c_int,
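Note on the hunks above: `enable_categorical` moves off the per-batch callback signature and onto the `DataIter` instance, and the NotImplementedError for user-supplied iterators is gone, so custom iterators can now feed categorical batches. A hedged sketch of such an iterator (the class name and data layout are hypothetical; it mirrors `IterForDMatrixTest` in the tests further down):

    import xgboost as xgb

    class BatchIter(xgb.core.DataIter):  # hypothetical example
        def __init__(self, batches, labels):
            self._batches = batches  # e.g. cudf DataFrames with "category" columns
            self._labels = labels
            self._it = 0
            super().__init__()

        def next(self, input_data):
            if self._it == len(self._batches):
                return 0  # no more batches
            input_data(data=self._batches[self._it], label=self._labels[self._it])
            self._it += 1
            return 1

        def reset(self):
            self._it = 0

    # The flag is stored on the iterator and applied to every batch:
    # m = xgb.DeviceQuantileDMatrix(BatchIter(batches, labels), enable_categorical=True)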
@@ -182,7 +182,7 @@ def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
        lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
         from cudf import concat as CUDF_concat  # pylint: disable=import-error
         return CUDF_concat(value, axis=0)
-    if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
+    if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'):
         import cupy
         # pylint: disable=c-extension-no-member,no-member
         d = cupy.cuda.runtime.getDevice()
@@ -258,6 +258,7 @@ class DaskDMatrix:
         self.feature_names = feature_names
         self.feature_types = feature_types
         self.missing = missing
+        self.enable_categorical = enable_categorical

         if qid is not None and weight is not None:
             raise NotImplementedError("per-group weight is not implemented.")
@@ -265,10 +266,6 @@ class DaskDMatrix:
             raise NotImplementedError(
                 "group structure is not implemented, use qid instead."
             )
-        if enable_categorical:
-            raise NotImplementedError(
-                "categorical support is not enabled on `DaskDMatrix`."
-            )

         if len(data.shape) != 2:
             raise ValueError(
@@ -311,7 +308,7 @@ class DaskDMatrix:
         qid: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
-        label_upper_bound: Optional[_DaskCollection] = None
+        label_upper_bound: Optional[_DaskCollection] = None,
     ) -> "DaskDMatrix":
         '''Obtain references to local data.'''

@@ -430,6 +427,7 @@ class DaskDMatrix:
                 'feature_weights': self.feature_weights,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
+                'enable_categorical': self.enable_categorical,
                 'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}

@@ -668,6 +666,7 @@ def _create_device_quantile_dmatrix(
     missing: float,
     parts: Optional[_DataParts],
     max_bin: int,
+    enable_categorical: bool,
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
@@ -680,6 +679,7 @@ def _create_device_quantile_dmatrix(
             feature_names=feature_names,
             feature_types=feature_types,
             max_bin=max_bin,
+            enable_categorical=enable_categorical,
         )
         return d

@@ -709,6 +709,7 @@ def _create_device_quantile_dmatrix(
         feature_types=feature_types,
         nthread=worker.nthreads,
         max_bin=max_bin,
+        enable_categorical=enable_categorical,
     )
     dmatrix.set_info(feature_weights=feature_weights)
     return dmatrix
@@ -720,6+721,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
     '''Get data that local to worker from DaskDMatrix.
@@ -734,9 +736,12 @@ def _create_dmatrix(
     if list_of_parts is None:
         msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
-        d = DMatrix(numpy.empty((0, 0)),
+        d = DMatrix(
+            numpy.empty((0, 0)),
             feature_names=feature_names,
-            feature_types=feature_types)
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
+        )
         return d

     T = TypeVar('T')
@@ -764,6 +769,7 @@ def _create_dmatrix(
         feature_names=feature_names,
         feature_types=feature_types,
         nthread=worker.nthreads,
+        enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
         base_margin=_base_margin,
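With the `DaskDMatrix` guard removed, the flag is just recorded and forwarded to the per-worker `DMatrix`/`DeviceQuantileDMatrix` constructors (see `_create_dmatrix` and `_create_device_quantile_dmatrix` above). Usage on the dask side then mirrors the single-node API; a sketch assuming a dask_cudf DataFrame `X` with categorical columns and a matching series `y`:

    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster
    import xgboost.dask as dxgb

    with Client(LocalCUDACluster()) as client:
        m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
        output = dxgb.train(client, {"tree_method": "gpu_hist"}, m, num_boost_round=10)
        booster = output["booster"]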
@@ -1151,12 +1151,12 @@ struct SegmentedUniqueReduceOp {
  * \return Number of unique values in total.
  */
 template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
-          typename ValOutIt, typename Comp>
+          typename ValOutIt, typename CompValue, typename CompKey>
 size_t
 SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
                 KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
                 ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
-                Comp comp) {
+                CompValue comp, CompKey comp_key=thrust::equal_to<size_t>{}) {
   using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
   auto unique_key_it = dh::MakeTransformIterator<Key>(
       thrust::make_counting_iterator(static_cast<size_t>(0)),
@@ -1177,7 +1177,7 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
       exec, unique_key_it, unique_key_it + n_inputs,
       val_first, reduce_it, val_out,
       [=] __device__(Key const &l, Key const &r) {
-        if (l.first == r.first) {
+        if (comp_key(l.first, r.first)) {
           // In the same segment.
           return comp(l.second, r.second);
         }
@@ -1195,7 +1195,9 @@ template <typename... Inputs,
           * = nullptr>
 size_t SegmentedUnique(Inputs &&...inputs) {
   dh::XGBCachingDeviceAllocator<char> alloc;
-  return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
+  return SegmentedUnique(thrust::cuda::par(alloc),
+                         std::forward<Inputs &&>(inputs)...,
+                         thrust::equal_to<size_t>{});
 }

 /**
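The generalisation in this header: `SegmentedUnique` used to collapse equal neighbours whenever they sat in the same segment; the new `CompKey` parameter lets the caller veto deduplication per segment, and the forwarding overload passes `thrust::equal_to<size_t>` so existing call sites keep the old behaviour. A rough CPU sketch of the semantics in Python (illustration only, not the CUDA code):

    def segmented_unique(offsets, values, comp_value, comp_key=lambda a, b: a == b):
        """offsets[i]:offsets[i+1] delimits segment i inside values."""
        out_vals, out_offsets = [], [0]
        for seg in range(len(offsets) - 1):
            for i in range(offsets[seg], offsets[seg + 1]):
                has_prev = len(out_vals) > out_offsets[seg]
                # Drop only when the key comparator opts this segment into
                # deduplication and the value matches its predecessor.
                if has_prev and comp_key(seg, seg) and comp_value(out_vals[-1], values[i]):
                    continue
                out_vals.append(values[i])
            out_offsets.append(len(out_vals))
        return out_offsets, out_vals

    # Deduplicate only categorical columns, as SketchContainer::Merge does below:
    is_cat = [True, False]
    offs, vals = segmented_unique(
        [0, 3, 6], [1, 1, 2, 3, 3, 4],
        comp_value=lambda a, b: a == b,
        comp_key=lambda l, r: l == r and is_cat[l],
    )
    # offs == [0, 2, 5]; vals == [1, 2, 3, 3, 4]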
@@ -129,24 +129,17 @@ void SortByWeight(dh::device_vector<float>* weights,
   });
 }

-struct IsCatOp {
-  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
-};
-
 void RemoveDuplicatedCategories(
     int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
     dh::device_vector<Entry> *p_sorted_entries,
     dh::caching_device_vector<size_t> *p_column_sizes_scan) {
   auto d_feature_types = info.feature_types.ConstDeviceSpan();
+  CHECK(!d_feature_types.empty());
   auto &column_sizes_scan = *p_column_sizes_scan;
-  if (!info.feature_types.Empty() &&
-      thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
-                     IsCatOp{})) {
   auto &sorted_entries = *p_sorted_entries;
   // Removing duplicated entries in categorical features.
   dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
-  dh::SegmentedUnique(
-      column_sizes_scan.data().get(),
+  dh::SegmentedUnique(column_sizes_scan.data().get(),
                       column_sizes_scan.data().get() + column_sizes_scan.size(),
                       sorted_entries.begin(), sorted_entries.end(),
                       new_column_scan.data().get(), sorted_entries.begin(),
@@ -183,7 +176,6 @@ void RemoveDuplicatedCategories(
   thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
                          new_cuts_size.cend(), d_cuts_ptr.data());
-  }
 }
 }  // namespace detail

 void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
@@ -215,8 +207,11 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
                            0, sorted_entries.size(),
                            &cuts_ptr, &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
-                                     &column_sizes_scan);
+
+  if (sketch_container->HasCategorical()) {
+    detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
+                                       &sorted_entries, &column_sizes_scan);
+  }

   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
@@ -281,8 +276,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
                            0, sorted_entries.size(),
                            &cuts_ptr, &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
-                                     &column_sizes_scan);
+  if (sketch_container->HasCategorical()) {
+    detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
+                                       &sorted_entries, &column_sizes_scan);
+  }

   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();

   // Extract cuts
@@ -210,6 +210,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
                Span<bst_row_t const> const &x_ptr,
                Span<SketchEntry const> const &d_y,
                Span<bst_row_t const> const &y_ptr,
+               Span<FeatureType const> feature_types,
                Span<SketchEntry> out,
                Span<bst_row_t> out_ptr) {
   dh::safe_cuda(cudaSetDevice(device));
@@ -408,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in) {
   return n_uniques;
 }

-size_t SketchContainer::Unique() {
-  timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_));
-  this->columns_ptr_.SetDevice(device_);
-  Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
-  CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
-  Span<SketchEntry> entries = dh::ToSpan(this->Current());
-  HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
-  scan_out.SetDevice(device_);
-  auto d_scan_out = scan_out.DeviceSpan();
-
-  d_column_scan = this->columns_ptr_.DeviceSpan();
-  size_t n_uniques = dh::SegmentedUnique(
-      d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
-      entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
-      entries.data(),
-      detail::SketchUnique{});
-  this->columns_ptr_.Copy(scan_out);
-  CHECK(!this->columns_ptr_.HostCanRead());
-
-  this->Current().resize(n_uniques);
-  timer_.Stop(__func__);
-  return n_uniques;
-}
-
 void SketchContainer::Prune(size_t to) {
   timer_.Start(__func__);
   dh::safe_cuda(cudaSetDevice(device_));
@@ -490,13 +466,20 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
   this->Other().resize(this->Current().size() + that.size());
   CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());

-  MergeImpl(device_, this->Data(), this->ColumnsPtr(),
-            that, d_that_columns_ptr,
-            dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
+  auto feature_types = this->FeatureTypes().ConstDeviceSpan();
+  MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
+            feature_types, dh::ToSpan(this->Other()),
+            columns_ptr_b_.DeviceSpan());
   this->columns_ptr_.Copy(columns_ptr_b_);
   CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
   this->Alternate();

+  if (this->HasCategorical()) {
+    auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
+    this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
+      return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
+    });
+  }
   timer_.Stop(__func__);
 }

@@ -16,6 +16,19 @@ class HistogramCuts;
 using WQSketch = WQuantileSketch<bst_float, bst_float>;
 using SketchEntry = WQSketch::Entry;

+namespace detail {
+struct IsCatOp {
+  XGBOOST_DEVICE bool operator()(FeatureType ft) {
+    return ft == FeatureType::kCategorical;
+  }
+};
+struct SketchUnique {
+  XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
+    return a.value - b.value == 0;
+  }
+};
+}  // namespace detail
+
 /*!
  * \brief A container that holds the device sketches. Sketching is performed per-column,
  * but fused into single operation for performance.
@@ -43,6 +56,8 @@ class SketchContainer {
   HostDeviceVector<OffsetT> columns_ptr_;
   HostDeviceVector<OffsetT> columns_ptr_b_;

+  bool has_categorical_{false};
+
   dh::device_vector<SketchEntry>& Current() {
     if (current_buffer_) {
       return entries_a_;
@@ -102,14 +117,21 @@ class SketchContainer {
     this->feature_types_.SetDevice(device);
     this->feature_types_.ConstDeviceSpan();
     this->feature_types_.ConstHostSpan();
+
+    auto d_feature_types = feature_types_.ConstDeviceSpan();
+    has_categorical_ =
+        !d_feature_types.empty() &&
+        thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
+                       detail::IsCatOp{});
+
     timer_.Init(__func__);
   }
   /* \brief Return GPU ID for this container. */
   int32_t DeviceIdx() const { return device_; }
+  /* \brief Whether the predictor matrix contains categorical features. */
+  bool HasCategorical() const { return has_categorical_; }
   /* \brief Accumulate weights of duplicated entries in input. */
   size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
-  /* \brief Removes all the duplicated elements in quantile structure. */
-  size_t Unique();
   /* Fix rounding error and re-establish invariance. The error is mostly generated by the
    * addition inside `RMinNext` and subtraction in `RMaxPrev`. */
   void FixError();
@@ -154,15 +176,35 @@ class SketchContainer {

+  /* \brief Removes all the duplicated elements in quantile structure. */
+  template <typename KeyComp = thrust::equal_to<size_t>>
+  size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
+    timer_.Start(__func__);
+    dh::safe_cuda(cudaSetDevice(device_));
+    this->columns_ptr_.SetDevice(device_);
+    Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
+    CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
+    Span<SketchEntry> entries = dh::ToSpan(this->Current());
+    HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
+    scan_out.SetDevice(device_);
+    auto d_scan_out = scan_out.DeviceSpan();
+    dh::XGBCachingDeviceAllocator<char> alloc;
+
+    d_column_scan = this->columns_ptr_.DeviceSpan();
+    size_t n_uniques = dh::SegmentedUnique(
+        thrust::cuda::par(alloc), d_column_scan.data(),
+        d_column_scan.data() + d_column_scan.size(), entries.data(),
+        entries.data() + entries.size(), scan_out.DevicePointer(),
+        entries.data(), detail::SketchUnique{}, key_comp);
+    this->columns_ptr_.Copy(scan_out);
+    CHECK(!this->columns_ptr_.HostCanRead());
+
+    this->Current().resize(n_uniques);
+    timer_.Stop(__func__);
+    return n_uniques;
+  }
+
   SketchContainer(const SketchContainer&) = delete;
   SketchContainer& operator=(const SketchContainer&) = delete;
 };
-
-namespace detail {
-struct SketchUnique {
-  XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
-    return a.value - b.value == 0;
-  }
-};
-}  // namespace detail
 }  // namespace common
 }  // namespace xgboost
@@ -134,17 +134,20 @@ struct WriteCompressedEllpackFunctor {
                                const common::CompressedBufferWriter& writer,
                                AdapterBatchT batch,
                                EllpackDeviceAccessor accessor,
+                               common::Span<FeatureType const> feature_types,
                                const data::IsValidFunctor& is_valid)
       : d_buffer(buffer),
         writer(writer),
         batch(std::move(batch)),
        accessor(std::move(accessor)),
+        feature_types(std::move(feature_types)),
        is_valid(is_valid) {}

   common::CompressedByteT* d_buffer;
   common::CompressedBufferWriter writer;
   AdapterBatchT batch;
   EllpackDeviceAccessor accessor;
+  common::Span<FeatureType const> feature_types;
   data::IsValidFunctor is_valid;

   using Tuple = thrust::tuple<size_t, size_t, size_t>;
@@ -154,7 +157,12 @@ struct WriteCompressedEllpackFunctor {
       // -1 because the scan is inclusive
       size_t output_position =
           accessor.row_stride * e.row_idx + out.get<1>() - 1;
-      auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
+      uint32_t bin_idx = 0;
+      if (common::IsCat(feature_types, e.column_idx)) {
+        bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
+      } else {
+        bin_idx = accessor.SearchBin<false>(e.value, e.column_idx);
+      }
       writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
     }
     return 0;
@@ -184,8 +192,9 @@ class TypedDiscard : public thrust::discard_iterator<T> {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
-                       int device_idx, float missing) {
+void CopyDataToEllpack(const AdapterBatchT &batch,
+                       common::Span<FeatureType const> feature_types,
+                       EllpackPageImpl *dst, int device_idx, float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ELLPACK matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -220,7 +229,8 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,

   // We redirect the scan output into this functor to do the actual writing
   WriteCompressedEllpackFunctor<AdapterBatchT> functor(
-      d_compressed_buffer, writer, batch, device_accessor, is_valid);
+      d_compressed_buffer, writer, batch, device_accessor, feature_types,
+      is_valid);
   TypedDiscard<Tuple> discard;
   thrust::transform_output_iterator<
       WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
@@ -263,22 +273,22 @@ template <typename AdapterBatch>
 EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
                                  bool is_dense, int nthread,
                                  common::Span<size_t> row_counts_span,
+                                 common::Span<FeatureType const> feature_types,
                                  size_t row_stride, size_t n_rows, size_t n_cols,
                                  common::HistogramCuts const& cuts) {
   dh::safe_cuda(cudaSetDevice(device));

   *this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
-  CopyDataToEllpack(batch, this, device, missing);
+  CopyDataToEllpack(batch, feature_types, this, device, missing);
   WriteNullValues(this, device, row_counts_span);
 }

 #define ELLPACK_BATCH_SPECIALIZE(__BATCH_T)                             \
   template EllpackPageImpl::EllpackPageImpl(                            \
-      __BATCH_T batch, float missing, int device,                       \
-      bool is_dense, int nthread,                                       \
+      __BATCH_T batch, float missing, int device, bool is_dense, int nthread, \
       common::Span<size_t> row_counts_span,                             \
-      size_t row_stride, size_t n_rows, size_t n_cols,                  \
-      common::HistogramCuts const& cuts);
+      common::Span<FeatureType const> feature_types, size_t row_stride, \
+      size_t n_rows, size_t n_cols, common::HistogramCuts const &cuts);

 ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
 ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
@@ -467,11 +477,17 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
   return compressed_size_bytes;
 }

-EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(int device) const {
+EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
+    int device, common::Span<FeatureType const> feature_types) const {
   gidx_buffer.SetDevice(device);
-  return EllpackDeviceAccessor(
-      device, cuts_, is_dense, row_stride, base_rowid, n_rows,
+  return {device,
+          cuts_,
+          is_dense,
+          row_stride,
+          base_rowid,
+          n_rows,
           common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
-                                           NumSymbols()));
+                                               NumSymbols()),
+          feature_types};
 }
 }  // namespace xgboost
@@ -10,6 +10,7 @@
 #include "../common/compressed_iterator.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
+#include "../common/categorical.h"
 #include <thrust/binary_search.h>

 namespace xgboost {
@@ -31,13 +32,17 @@ struct EllpackDeviceAccessor {
   /*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
   common::Span<const bst_float> gidx_fvalue_map;

+  common::Span<const FeatureType> feature_types;
+
   EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts,
                         bool is_dense, size_t row_stride, size_t base_rowid,
-                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter)
+                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter,
+                        common::Span<FeatureType const> feature_types)
       : is_dense(is_dense),
         row_stride(row_stride),
         base_rowid(base_rowid),
-        n_rows(n_rows) ,gidx_iter(gidx_iter){
+        n_rows(n_rows) ,gidx_iter(gidx_iter),
+        feature_types{feature_types} {
     cuts.cut_values_.SetDevice(device);
     cuts.cut_ptrs_.SetDevice(device);
     cuts.min_vals_.SetDevice(device);
@@ -64,12 +69,23 @@ struct EllpackDeviceAccessor {
     return gidx;
   }

+  template <bool is_cat>
   __device__ uint32_t SearchBin(float value, size_t column_id) const {
     auto beg = feature_segments[column_id];
     auto end = feature_segments[column_id + 1];
-    auto it =
-        thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin()+ beg, gidx_fvalue_map.cbegin() + end, value);
-    uint32_t idx = it - gidx_fvalue_map.cbegin();
+    uint32_t idx = 0;
+    if (is_cat) {
+      auto it = dh::MakeTransformIterator<bst_cat_t>(
+          gidx_fvalue_map.cbegin(), [](float v) { return common::AsCat(v); });
+      idx = thrust::lower_bound(thrust::seq, it + beg, it + end,
+                                common::AsCat(value)) -
+            it;
+    } else {
+      auto it = thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin() + beg,
+                                    gidx_fvalue_map.cbegin() + end, value);
+      idx = it - gidx_fvalue_map.cbegin();
+    }
+
     if (idx == end) {
       idx -= 1;
     }
@@ -134,8 +150,10 @@ class EllpackPageImpl {
   explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);

   template <typename AdapterBatch>
-  explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
+  explicit EllpackPageImpl(AdapterBatch batch, float missing, int device,
+                           bool is_dense, int nthread,
                            common::Span<size_t> row_counts_span,
+                           common::Span<FeatureType const> feature_types,
                            size_t row_stride, size_t n_rows, size_t n_cols,
                            common::HistogramCuts const &cuts);

@@ -176,7 +194,9 @@ class EllpackPageImpl {
    * not found). */
   size_t NumSymbols() const { return cuts_.TotalBins() + 1; }

-  EllpackDeviceAccessor GetDeviceAccessor(int device) const;
+  EllpackDeviceAccessor
+  GetDeviceAccessor(int device,
+                    common::Span<FeatureType const> feature_types = {}) const;

  private:
   /*!
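The split `SearchBin<is_cat>` above boils down to two different searches over the same cut-value array: numerical features take the first cut strictly greater than the value (upper bound), while categorical features treat the cuts as the categories themselves and match exactly (lower bound over the values cast to integer categories). A numpy sketch of the two regimes (illustrative only):

    import numpy as np

    cuts = np.array([0.0, 1.0, 2.0, 3.0])  # one feature's slice of gidx_fvalue_map

    # Numerical: upper_bound, i.e. first cut > value.
    num_bin = np.searchsorted(cuts, 1.0, side="right")                # -> 2

    # Categorical: lower_bound over the cuts viewed as integer categories.
    cat_bin = np.searchsorted(cuts.astype(np.int64), 2, side="left")  # -> 2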
@@ -148,9 +148,13 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing,
       return GetRowCounts(value, row_counts_span, get_device(), missing);
     });
     auto is_dense = this->IsDense();
+
+    proxy->Info().feature_types.SetDevice(get_device());
+    auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
     auto new_impl = Dispatch(proxy, [&](auto const &value) {
       return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
-                             row_counts_span, row_stride, rows, cols, cuts);
+                             row_counts_span, d_feature_types, row_stride, rows,
+                             cols, cuts);
     });
     size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
     offset += num_elements;
@@ -155,6 +155,9 @@ struct EllpackLoader {
     if (gidx == -1) {
       return nan("");
     }
+    if (common::IsCat(matrix.feature_types, fidx)) {
+      return matrix.gidx_fvalue_map[gidx];
+    }
     // The gradient index needs to be shifted by one as min values are not included in the
     // cuts.
     if (gidx == matrix.feature_segments[fidx]) {
@@ -592,8 +595,10 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       size_t batch_offset = 0;
       for (auto const& page : dmat->GetBatches<EllpackPage>()) {
+        dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
+        auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
         this->PredictInternal(
-            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
+            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
             d_model,
             out_preds,
             batch_offset);
@@ -19,7 +19,7 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF, Dask, and cuPy
 RUN \
     conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
-        python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy \
+        python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy=9.1* \
        numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis

 ENV GOSU_VERSION 1.10
@@ -68,7 +68,16 @@ void TestEquivalent(float sparsity) {
     auto const& buffer_from_iter = page_concatenated->gidx_buffer;
     auto const& buffer_from_data = ellpack.Impl()->gidx_buffer;
     ASSERT_NE(buffer_from_data.Size(), 0);
-    ASSERT_EQ(buffer_from_data.ConstHostVector(), buffer_from_data.ConstHostVector());
+
+    common::CompressedIterator<uint32_t> data_buf{
+        buffer_from_data.ConstHostPointer(), from_data.NumSymbols()};
+    common::CompressedIterator<uint32_t> data_iter{
+        buffer_from_iter.ConstHostPointer(), from_iter.NumSymbols()};
+    CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
+    CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride);
+    for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
+      CHECK_EQ(data_buf[i], data_iter[i]);
+    }
   }
 }

@@ -225,6 +225,9 @@ void TestCategoricalPrediction(std::string name) {
   row[split_ind] = split_cat;
   auto m = GetDMatrixFromData(row, 1, kCols);

+  std::vector<FeatureType> types(10, FeatureType::kCategorical);
+  m->Info().feature_types.HostVector() = types;
+
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
   ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
@@ -225,13 +225,22 @@ class IterForDMatrixTest(xgb.core.DataIter):
     ROWS_PER_BATCH = 100            # data is splited by rows
     BATCHES = 16

-    def __init__(self):
+    def __init__(self, categorical):
         '''Generate some random data for demostration.

         Actual data can be anything that is currently supported by XGBoost.
         '''
         import cudf
         self.rows = self.ROWS_PER_BATCH
+
+        if categorical:
+            self._data = []
+            self._labels = []
+            for i in range(self.BATCHES):
+                X, y = tm.make_categorical(self.ROWS_PER_BATCH, 4, 13, False)
+                self._data.append(cudf.from_pandas(X))
+                self._labels.append(y)
+        else:
            rng = np.random.RandomState(1994)
            self._data = [
                cudf.DataFrame(
@@ -272,24 +281,26 @@ class IterForDMatrixTest(xgb.core.DataIter):


 @pytest.mark.skipif(**tm.no_cudf())
-def test_from_cudf_iter():
+@pytest.mark.parametrize("enable_categorical", [True, False])
+def test_from_cudf_iter(enable_categorical):
     rounds = 100
-    it = IterForDMatrixTest()
+    it = IterForDMatrixTest(enable_categorical)
+    params = {"tree_method": "gpu_hist"}

     # Use iterator
-    m_it = xgb.DeviceQuantileDMatrix(it)
-    reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it,
-                            num_boost_round=rounds)
-    predict_with_it = reg_with_it.predict(m_it)
+    m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=enable_categorical)
+    reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)

     # Without using iterator
-    m = xgb.DMatrix(it.as_array(), it.as_array_labels())
+    X = it.as_array()
+    y = it.as_array_labels()
+
+    m = xgb.DMatrix(X, y, enable_categorical=enable_categorical)

     assert m_it.num_col() == m.num_col()
     assert m_it.num_row() == m.num_row()

-    reg = xgb.train({'tree_method': 'gpu_hist'}, m,
-                    num_boost_round=rounds)
-    predict = reg.predict(m)
+    reg = xgb.train(params, m, num_boost_round=rounds)
+
+    predict = reg.predict(m)
+    predict_with_it = reg_with_it.predict(m_it)
     np.testing.assert_allclose(predict_with_it, predict)
@@ -1,11 +1,13 @@
 import sys
+import os
-from typing import Type, TypeVar, Any, Dict, List
+from typing import Type, TypeVar, Any, Dict, List, Tuple
 import pytest
 import numpy as np
 import asyncio
 import xgboost
 import subprocess
+import tempfile
 import json
 from collections import OrderedDict
 from inspect import signature
 from hypothesis import given, strategies, settings, note
@@ -41,6 +43,49 @@ except ImportError:
     pass


+def make_categorical(
+    client: Client,
+    n_samples: int,
+    n_features: int,
+    n_categories: int,
+    onehot: bool = False,
+) -> Tuple[dd.DataFrame, dd.Series]:
+    workers = _get_client_workers(client)
+    n_workers = len(workers)
+    dfs = []
+
+    def pack(**kwargs: Any) -> dd.DataFrame:
+        X, y = tm.make_categorical(**kwargs)
+        X["label"] = y
+        return X
+
+    meta = pack(
+        n_samples=1, n_features=n_features, n_categories=n_categories, onehot=False
+    )
+
+    for i, worker in enumerate(workers):
+        l_n_samples = min(
+            n_samples // n_workers, n_samples - i * (n_samples // n_workers)
+        )
+        future = client.submit(
+            pack,
+            n_samples=l_n_samples,
+            n_features=n_features,
+            n_categories=n_categories,
+            onehot=False,
+            workers=[worker],
+        )
+        dfs.append(future)
+
+    df = dd.from_delayed(dfs, meta=meta)
+    y = df["label"]
+    X = df[df.columns.difference(["label"])]
+
+    if onehot:
+        return dd.get_dummies(X), y
+    return X, y
+
+
 def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
     import cupy as cp
     cp.cuda.runtime.setDevice(0)
@@ -126,6 +171,62 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
                              inplace_predictions)


+@pytest.mark.skipif(**tm.no_dask_cudf())
+def test_categorical(local_cuda_cluster: LocalCUDACluster) -> None:
+    with Client(local_cuda_cluster) as client:
+        import dask_cudf
+
+        rounds = 10
+        X, y = make_categorical(client, 10000, 30, 13)
+        X = dask_cudf.from_dask_dataframe(X)
+
+        X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
+        X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
+
+        parameters = {"tree_method": "gpu_hist"}
+
+        m = dxgb.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
+        by_etl_results = dxgb.train(
+            client,
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+        )["history"]
+
+        m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
+        output = dxgb.train(
+            client,
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+        )
+        by_builtin_results = output["history"]
+
+        np.testing.assert_allclose(
+            np.array(by_etl_results["Train"]["rmse"]),
+            np.array(by_builtin_results["Train"]["rmse"]),
+            rtol=1e-3,
+        )
+        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
+
+        model = output["booster"]
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = os.path.join(tempdir, "model.json")
+            model.save_model(path)
+            with open(path, "r") as fd:
+                categorical = json.load(fd)
+
+            categories_sizes = np.array(
+                categorical["learner"]["gradient_booster"]["model"]["trees"][-1][
+                    "categories_sizes"
+                ]
+            )
+            assert categories_sizes.shape[0] != 0
+            np.testing.assert_allclose(categories_sizes, 1)
+
+
 def to_cp(x: Any, DMatrixT: Type) -> Any:
     import cupy
     if isinstance(x, np.ndarray) and \
@@ -236,7 +236,7 @@ def get_mq2008(dpath):


 @memory.cache
 def make_categorical(
-    n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
+    n_samples: int, n_features: int, n_categories: int, onehot: bool
 ):
     import pandas as pd

@@ -244,7 +244,7 @@ def make_categorical(

     pd_dict = {}
     for i in range(n_features + 1):
-        c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
+        c = rng.randint(low=0, high=n_categories, size=n_samples)
         pd_dict[str(i)] = pd.Series(c, dtype=np.int64)

     df = pd.DataFrame(pd_dict)
@@ -255,11 +255,13 @@ def make_categorical(
         label += 1

     df = df.astype("category")
-    if onehot_enc:
-        cat = pd.get_dummies(df)
-    else:
-        cat = df
-    return cat, label
+    categories = np.arange(0, n_categories)
+    for col in df.columns:
+        df[col] = df[col].cat.set_categories(categories)
+
+    if onehot:
+        return pd.get_dummies(df), label
+    return df, label


 _unweighted_datasets_strategy = strategies.sampled_from(
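A closing note on the test helper: `make_categorical` now pins every column's category set to `0 .. n_categories - 1`, so the native-categorical and one-hot variants of the same draw encode identical information, which is what the dask test above compares. Usage sketch (with `testing.py` imported as `tm`, as in the tests):

    X, y = tm.make_categorical(n_samples=100, n_features=4, n_categories=13, onehot=False)
    assert (X.dtypes == "category").all()

    X_onehot, _ = tm.make_categorical(100, 4, 13, onehot=True)  # pd.get_dummies layout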