Support categorical data for dask functional interface and DQM. (#7043)

* Support categorical data for dask functional interface and DQM.

* Implement categorical data support for GPU GK-merge.
* Add support for dask functional interface.
* Add support for DQM.

* Get newer cupy.
Authored by Jiaming Yuan, 2021-06-18 13:06:52 +08:00; committed by GitHub.
parent 7dd29ffd47
commit 86715e4cd4
16 changed files with 364 additions and 167 deletions
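
The user-visible change: `enable_categorical=True` no longer raises `NotImplementedError` on `DaskDMatrix` or `DeviceQuantileDMatrix`. A minimal sketch of the dask usage, modeled on the `test_categorical` test added below (the client and the `X`/`y` dask collections are placeholders; `X` is expected to carry `category`-dtype columns, e.g. a dask_cudf frame):

    from dask.distributed import Client
    from xgboost import dask as dxgb

    def train_categorical(client: Client, X, y):
        # Categorical columns flow straight into the GPU sketching path.
        m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
        return dxgb.train(
            client, {"tree_method": "gpu_hist"}, m, num_boost_round=10
        )["booster"]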

View File

@@ -321,6 +321,7 @@ class DataIter:
     def __init__(self):
         self._handle = _ProxyDMatrix()
         self.exception = None
+        self.enable_categorical = False

     @property
     def proxy(self):
@@ -346,13 +347,12 @@
         data,
         feature_names=None,
         feature_types=None,
-        enable_categorical=False,
         **kwargs
     ):
         from .data import dispatch_device_quantile_dmatrix_set_data
         from .data import _device_quantile_transform
         data, feature_names, feature_types = _device_quantile_transform(
-            data, feature_names, feature_types, enable_categorical,
+            data, feature_names, feature_types, self.enable_categorical,
         )
         dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
         self.proxy.set_info(
@@ -1106,15 +1106,10 @@ class DeviceQuantileDMatrix(DMatrix):
         data = _transform_dlpack(data)
         if _is_iter(data):
             it = data
-            if enable_categorical:
-                raise NotImplementedError(
-                    "categorical support is not enabled on data iterator."
-                )
         else:
-            it = SingleBatchInternalIter(
-                data=data, enable_categorical=enable_categorical, **meta
-            )
+            it = SingleBatchInternalIter(data=data, **meta)
+        it.enable_categorical = enable_categorical
         reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
         next_callback = ctypes.CFUNCTYPE(
             ctypes.c_int,

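Note the design change above: the flag now lives on the iterator. `DeviceQuantileDMatrix` assigns `it.enable_categorical` after constructing (or receiving) the iterator, so custom `DataIter` subclasses need no signature change. A sketch of the resulting call, reusing the `IterForDMatrixTest` iterator from the cuDF iterator test further down:

    import xgboost as xgb

    it = IterForDMatrixTest(categorical=True)  # an xgb.core.DataIter subclass
    m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=True)
    booster = xgb.train({"tree_method": "gpu_hist"}, m_it, num_boost_round=10)
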
View File

@@ -182,7 +182,7 @@ def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
             lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
         from cudf import concat as CUDF_concat  # pylint: disable=import-error
         return CUDF_concat(value, axis=0)
-    if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
+    if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'):
         import cupy
         # pylint: disable=c-extension-no-member,no-member
         d = cupy.cuda.runtime.getDevice()
@@ -258,6 +258,7 @@ class DaskDMatrix:
         self.feature_names = feature_names
         self.feature_types = feature_types
         self.missing = missing
+        self.enable_categorical = enable_categorical

         if qid is not None and weight is not None:
             raise NotImplementedError("per-group weight is not implemented.")
@@ -265,10 +266,6 @@ class DaskDMatrix:
             raise NotImplementedError(
                 "group structure is not implemented, use qid instead."
             )
-        if enable_categorical:
-            raise NotImplementedError(
-                "categorical support is not enabled on `DaskDMatrix`."
-            )

         if len(data.shape) != 2:
             raise ValueError(
@@ -311,7 +308,7 @@ class DaskDMatrix:
         qid: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
-        label_upper_bound: Optional[_DaskCollection] = None
+        label_upper_bound: Optional[_DaskCollection] = None,
     ) -> "DaskDMatrix":
         '''Obtain references to local data.'''
@@ -430,6 +427,7 @@ class DaskDMatrix:
                 'feature_weights': self.feature_weights,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
+                'enable_categorical': self.enable_categorical,
                 'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
@@ -668,6 +666,7 @@ def _create_device_quantile_dmatrix(
     missing: float,
     parts: Optional[_DataParts],
     max_bin: int,
+    enable_categorical: bool,
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
@@ -680,6 +679,7 @@ def _create_device_quantile_dmatrix(
             feature_names=feature_names,
             feature_types=feature_types,
             max_bin=max_bin,
+            enable_categorical=enable_categorical,
         )
         return d
@@ -709,6 +709,7 @@ def _create_device_quantile_dmatrix(
         feature_types=feature_types,
         nthread=worker.nthreads,
         max_bin=max_bin,
+        enable_categorical=enable_categorical,
     )
     dmatrix.set_info(feature_weights=feature_weights)
     return dmatrix
@@ -720,6 +721,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
     '''Get data that local to worker from DaskDMatrix.
@@ -734,9 +736,12 @@ def _create_dmatrix(
     if list_of_parts is None:
         msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
-        d = DMatrix(numpy.empty((0, 0)),
-                    feature_names=feature_names,
-                    feature_types=feature_types)
+        d = DMatrix(
+            numpy.empty((0, 0)),
+            feature_names=feature_names,
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
+        )
         return d

     T = TypeVar('T')
@@ -764,6 +769,7 @@ def _create_dmatrix(
         feature_names=feature_names,
         feature_types=feature_types,
         nthread=worker.nthreads,
+        enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
         base_margin=_base_margin,

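On the worker side, the flag travels with `_create_fn_args` and is applied when each local `DMatrix` or `DeviceQuantileDMatrix` is built, including the empty-worker fallback above. Assuming `DaskDeviceQuantileDMatrix` forwards the keyword the same way `DaskDMatrix` does (not shown in this diff), the quantile variant is used identically (a sketch with placeholder data):

    from dask.distributed import Client
    from xgboost import dask as dxgb

    def train_dqm_categorical(client: Client, X, y):
        # Quantile sketching happens on GPU; each category becomes one cut.
        m = dxgb.DaskDeviceQuantileDMatrix(client, X, y, enable_categorical=True)
        return dxgb.train(client, {"tree_method": "gpu_hist"}, m, num_boost_round=10)
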
View File

@@ -1151,12 +1151,12 @@ struct SegmentedUniqueReduceOp {
  * \return Number of unique values in total.
  */
 template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
-          typename ValOutIt, typename Comp>
+          typename ValOutIt, typename CompValue, typename CompKey>
 size_t
 SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
                 KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
                 ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
-                Comp comp) {
+                CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
   using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
   auto unique_key_it = dh::MakeTransformIterator<Key>(
       thrust::make_counting_iterator(static_cast<size_t>(0)),
@@ -1177,7 +1177,7 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
       exec, unique_key_it, unique_key_it + n_inputs,
       val_first, reduce_it, val_out,
       [=] __device__(Key const &l, Key const &r) {
-        if (l.first == r.first) {
+        if (comp_key(l.first, r.first)) {
           // In the same segment.
           return comp(l.second, r.second);
         }
@@ -1195,7 +1195,9 @@
           * = nullptr>
 size_t SegmentedUnique(Inputs &&...inputs) {
   dh::XGBCachingDeviceAllocator<char> alloc;
-  return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
+  return SegmentedUnique(thrust::cuda::par(alloc),
+                         std::forward<Inputs &&>(inputs)...,
+                         thrust::equal_to<size_t>{});
 }

 /**

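For reference, `SegmentedUnique` compacts runs of equal values within each segment, and the new `comp_key` argument decides when two positions count as belonging to the same segment (the default, plain index equality, preserves the old behaviour). A CPU sketch of the semantics in Python, a hypothetical helper rather than code from this repository:

    import numpy as np

    def segmented_unique(offsets, values, comp_key=lambda a, b: a == b):
        # values are sorted within each segment [offsets[i], offsets[i + 1]).
        seg = np.searchsorted(offsets, np.arange(len(values)), side="right") - 1
        keep = []
        for i in range(len(values)):
            opens = i == 0 or not comp_key(seg[i - 1], seg[i])
            if opens or values[i] != values[i - 1]:
                keep.append(i)
        # New offset for segment s counts kept entries in earlier segments.
        new_offsets = np.searchsorted(seg[keep], np.arange(len(offsets)))
        return new_offsets, values[keep]

    # segmented_unique(np.array([0, 3, 5]), np.array([1., 1., 2., 3., 3.]))
    # -> (array([0, 2, 3]), array([1., 2., 3.]))

`SketchContainer::Merge` (further down) passes a `comp_key` that is true only when both entries sit in the same categorical column, so numerical columns are never deduplicated by value.
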
View File

@@ -129,60 +129,52 @@ void SortByWeight(dh::device_vector<float>* weights,
   });
 }

-struct IsCatOp {
-  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
-};
-
 void RemoveDuplicatedCategories(
     int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
     dh::device_vector<Entry> *p_sorted_entries,
-    dh::caching_device_vector<size_t>* p_column_sizes_scan) {
+    dh::caching_device_vector<size_t> *p_column_sizes_scan) {
   auto d_feature_types = info.feature_types.ConstDeviceSpan();
-  auto& column_sizes_scan = *p_column_sizes_scan;
-  if (!info.feature_types.Empty() &&
-      thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
-                     IsCatOp{})) {
-    auto& sorted_entries = *p_sorted_entries;
-    // Removing duplicated entries in categorical features.
-    dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
-    dh::SegmentedUnique(
-        column_sizes_scan.data().get(),
-        column_sizes_scan.data().get() + column_sizes_scan.size(),
-        sorted_entries.begin(), sorted_entries.end(),
-        new_column_scan.data().get(), sorted_entries.begin(),
-        [=] __device__(Entry const &l, Entry const &r) {
-          if (l.index == r.index) {
-            if (IsCat(d_feature_types, l.index)) {
-              return l.fvalue == r.fvalue;
-            }
-          }
-          return false;
-        });
+  CHECK(!d_feature_types.empty());
+  auto &column_sizes_scan = *p_column_sizes_scan;
+  auto &sorted_entries = *p_sorted_entries;
+  // Removing duplicated entries in categorical features.
+  dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
+  dh::SegmentedUnique(column_sizes_scan.data().get(),
+                      column_sizes_scan.data().get() + column_sizes_scan.size(),
+                      sorted_entries.begin(), sorted_entries.end(),
+                      new_column_scan.data().get(), sorted_entries.begin(),
+                      [=] __device__(Entry const &l, Entry const &r) {
+                        if (l.index == r.index) {
+                          if (IsCat(d_feature_types, l.index)) {
+                            return l.fvalue == r.fvalue;
+                          }
+                        }
+                        return false;
+                      });

   // Renew the column scan and cut scan based on categorical data.
   auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
   dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
       info.num_col_ + 1);
   auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
   auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
   CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
   dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
     d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
     if (idx == d_new_columns_ptr.size() - 1) {
       return;
     }
     if (IsCat(d_feature_types, idx)) {
       // Cut size is the same as number of categories in input.
       d_new_cuts_size[idx] =
           d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
     } else {
       d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
     }
   });
   // Turn size into ptr.
   thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
                          new_cuts_size.cend(), d_cuts_ptr.data());
-  }
 }

 }  // namespace detail
@@ -215,8 +207,11 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
                            0, sorted_entries.size(),
                            &cuts_ptr, &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
-                                     &column_sizes_scan);
+
+  if (sketch_container->HasCategorical()) {
+    detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
+                                       &sorted_entries, &column_sizes_scan);
+  }

   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
@@ -281,8 +276,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
                            0, sorted_entries.size(),
                            &cuts_ptr, &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
-                                     &column_sizes_scan);
+
+  if (sketch_container->HasCategorical()) {
+    detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
+                                       &sorted_entries, &column_sizes_scan);
+  }

   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   // Extract cuts

View File

@@ -210,6 +210,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
                Span<bst_row_t const> const &x_ptr,
                Span<SketchEntry const> const &d_y,
                Span<bst_row_t const> const &y_ptr,
+               Span<FeatureType const> feature_types,
                Span<SketchEntry> out,
                Span<bst_row_t> out_ptr) {
   dh::safe_cuda(cudaSetDevice(device));
@@ -408,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
   return n_uniques;
 }

-size_t SketchContainer::Unique() {
-  timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_));
-  this->columns_ptr_.SetDevice(device_);
-  Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
-  CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
-  Span<SketchEntry> entries = dh::ToSpan(this->Current());
-  HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
-  scan_out.SetDevice(device_);
-  auto d_scan_out = scan_out.DeviceSpan();
-
-  d_column_scan = this->columns_ptr_.DeviceSpan();
-  size_t n_uniques = dh::SegmentedUnique(
-      d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
-      entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
-      entries.data(),
-      detail::SketchUnique{});
-  this->columns_ptr_.Copy(scan_out);
-  CHECK(!this->columns_ptr_.HostCanRead());
-
-  this->Current().resize(n_uniques);
-  timer_.Stop(__func__);
-  return n_uniques;
-}
-
 void SketchContainer::Prune(size_t to) {
   timer_.Start(__func__);
   dh::safe_cuda(cudaSetDevice(device_));
@@ -490,13 +466,20 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
   this->Other().resize(this->Current().size() + that.size());
   CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());

-  MergeImpl(device_, this->Data(), this->ColumnsPtr(),
-            that, d_that_columns_ptr,
-            dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
+  auto feature_types = this->FeatureTypes().ConstDeviceSpan();
+  MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
+            feature_types, dh::ToSpan(this->Other()),
+            columns_ptr_b_.DeviceSpan());
   this->columns_ptr_.Copy(columns_ptr_b_);
   CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
   this->Alternate();
+
+  if (this->HasCategorical()) {
+    auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
+    this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
+      return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
+    });
+  }
   timer_.Stop(__func__);
 }

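Why `Merge` must re-run `Unique` for categorical columns: a categorical column's sketch entries are the category values themselves (one cut per category, which the `categories_sizes` assertion in the dask test below checks), and merging per-worker sketches concatenates their entries, so any category seen by two workers would otherwise yield a duplicated cut. Roughly, in Python:

    import numpy as np

    worker_a = np.array([0.0, 1.0, 3.0])  # categories seen by worker A
    worker_b = np.array([1.0, 2.0, 3.0])  # categories seen by worker B
    merged = np.sort(np.concatenate([worker_a, worker_b]))
    # Without dedup, 1.0 and 3.0 would each get two cuts; the Unique pass
    # restores one bin per category:
    cuts = np.unique(merged)              # array([0., 1., 2., 3.])
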
View File

@@ -16,6 +16,19 @@ class HistogramCuts;
 using WQSketch = WQuantileSketch<bst_float, bst_float>;
 using SketchEntry = WQSketch::Entry;

+namespace detail {
+struct IsCatOp {
+  XGBOOST_DEVICE bool operator()(FeatureType ft) {
+    return ft == FeatureType::kCategorical;
+  }
+};
+struct SketchUnique {
+  XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
+    return a.value - b.value == 0;
+  }
+};
+}  // namespace detail
+
 /*!
  * \brief A container that holds the device sketches. Sketching is performed per-column,
  *        but fused into single operation for performance.
@@ -43,6 +56,8 @@ class SketchContainer {
   HostDeviceVector<OffsetT> columns_ptr_;
   HostDeviceVector<OffsetT> columns_ptr_b_;

+  bool has_categorical_{false};
+
   dh::device_vector<SketchEntry>& Current() {
     if (current_buffer_) {
       return entries_a_;
@@ -102,14 +117,21 @@ class SketchContainer {
     this->feature_types_.SetDevice(device);
     this->feature_types_.ConstDeviceSpan();
     this->feature_types_.ConstHostSpan();
+
+    auto d_feature_types = feature_types_.ConstDeviceSpan();
+    has_categorical_ =
+        !d_feature_types.empty() &&
+        thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
+                       detail::IsCatOp{});
+
     timer_.Init(__func__);
   }
   /* \brief Return GPU ID for this container. */
   int32_t DeviceIdx() const { return device_; }
+  /* \brief Whether the predictor matrix contains categorical features. */
+  bool HasCategorical() const { return has_categorical_; }
   /* \brief Accumulate weights of duplicated entries in input. */
   size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
-  /* \brief Removes all the duplicated elements in quantile structure. */
-  size_t Unique();
   /* Fix rounding error and re-establish invariance.  The error is mostly generated by the
    * addition inside `RMinNext` and subtraction in `RMaxPrev`. */
   void FixError();
@@ -154,15 +176,35 @@ class SketchContainer {
   SketchContainer(const SketchContainer&) = delete;
   SketchContainer& operator=(const SketchContainer&) = delete;
-};

-namespace detail {
-struct SketchUnique {
-  XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
-    return a.value - b.value == 0;
-  }
-};
-}  // namespace detail
+  /* \brief Removes all the duplicated elements in quantile structure. */
+  template <typename KeyComp = thrust::equal_to<size_t>>
+  size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
+    timer_.Start(__func__);
+    dh::safe_cuda(cudaSetDevice(device_));
+    this->columns_ptr_.SetDevice(device_);
+    Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
+    CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
+    Span<SketchEntry> entries = dh::ToSpan(this->Current());
+    HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
+    scan_out.SetDevice(device_);
+    auto d_scan_out = scan_out.DeviceSpan();
+    dh::XGBCachingDeviceAllocator<char> alloc;
+
+    d_column_scan = this->columns_ptr_.DeviceSpan();
+    size_t n_uniques = dh::SegmentedUnique(
+        thrust::cuda::par(alloc), d_column_scan.data(),
+        d_column_scan.data() + d_column_scan.size(), entries.data(),
+        entries.data() + entries.size(), scan_out.DevicePointer(),
+        entries.data(), detail::SketchUnique{}, key_comp);
+    this->columns_ptr_.Copy(scan_out);
+    CHECK(!this->columns_ptr_.HostCanRead());
+
+    this->Current().resize(n_uniques);
+    timer_.Stop(__func__);
+    return n_uniques;
+  }
+};
 }  // namespace common
 }  // namespace xgboost

View File

@@ -134,17 +134,20 @@ struct WriteCompressedEllpackFunctor {
                                 const common::CompressedBufferWriter& writer,
                                 AdapterBatchT batch,
                                 EllpackDeviceAccessor accessor,
+                                common::Span<FeatureType const> feature_types,
                                 const data::IsValidFunctor& is_valid)
       : d_buffer(buffer),
         writer(writer),
         batch(std::move(batch)),
         accessor(std::move(accessor)),
+        feature_types(std::move(feature_types)),
         is_valid(is_valid) {}

   common::CompressedByteT* d_buffer;
   common::CompressedBufferWriter writer;
   AdapterBatchT batch;
   EllpackDeviceAccessor accessor;
+  common::Span<FeatureType const> feature_types;
   data::IsValidFunctor is_valid;

   using Tuple = thrust::tuple<size_t, size_t, size_t>;
@@ -154,7 +157,12 @@ struct WriteCompressedEllpackFunctor {
       // -1 because the scan is inclusive
       size_t output_position =
           accessor.row_stride * e.row_idx + out.get<1>() - 1;
-      auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
+      uint32_t bin_idx = 0;
+      if (common::IsCat(feature_types, e.column_idx)) {
+        bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
+      } else {
+        bin_idx = accessor.SearchBin<false>(e.value, e.column_idx);
+      }
       writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
     }
     return 0;
@@ -184,8 +192,9 @@ class TypedDiscard : public thrust::discard_iterator<T> {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
-                       int device_idx, float missing) {
+void CopyDataToEllpack(const AdapterBatchT &batch,
+                       common::Span<FeatureType const> feature_types,
+                       EllpackPageImpl *dst, int device_idx, float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ELLPACK matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -220,7 +229,8 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
   // We redirect the scan output into this functor to do the actual writing
   WriteCompressedEllpackFunctor<AdapterBatchT> functor(
-      d_compressed_buffer, writer, batch, device_accessor, is_valid);
+      d_compressed_buffer, writer, batch, device_accessor, feature_types,
+      is_valid);
   TypedDiscard<Tuple> discard;
   thrust::transform_output_iterator<
       WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
@@ -263,22 +273,22 @@ template <typename AdapterBatch>
 EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
                                  bool is_dense, int nthread,
                                  common::Span<size_t> row_counts_span,
+                                 common::Span<FeatureType const> feature_types,
                                  size_t row_stride, size_t n_rows, size_t n_cols,
                                  common::HistogramCuts const& cuts) {
   dh::safe_cuda(cudaSetDevice(device));

   *this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
-  CopyDataToEllpack(batch, this, device, missing);
+  CopyDataToEllpack(batch, feature_types, this, device, missing);
   WriteNullValues(this, device, row_counts_span);
 }

-#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T)              \
-  template EllpackPageImpl::EllpackPageImpl(             \
-      __BATCH_T batch, float missing, int device,        \
-      bool is_dense, int nthread,                        \
-      common::Span<size_t> row_counts_span,              \
-      size_t row_stride, size_t n_rows, size_t n_cols,   \
-      common::HistogramCuts const& cuts);
+#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T)                                   \
+  template EllpackPageImpl::EllpackPageImpl(                                  \
+      __BATCH_T batch, float missing, int device, bool is_dense, int nthread, \
+      common::Span<size_t> row_counts_span,                                   \
+      common::Span<FeatureType const> feature_types, size_t row_stride,      \
+      size_t n_rows, size_t n_cols, common::HistogramCuts const &cuts);

 ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
 ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
@@ -467,11 +477,17 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
   return compressed_size_bytes;
 }

-EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(int device) const {
+EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
+    int device, common::Span<FeatureType const> feature_types) const {
   gidx_buffer.SetDevice(device);
-  return EllpackDeviceAccessor(
-      device, cuts_, is_dense, row_stride, base_rowid, n_rows,
-      common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
-                                           NumSymbols()));
+  return {device,
+          cuts_,
+          is_dense,
+          row_stride,
+          base_rowid,
+          n_rows,
+          common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
+                                               NumSymbols()),
+          feature_types};
 }
 }  // namespace xgboost

View File

@@ -10,6 +10,7 @@
 #include "../common/compressed_iterator.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
+#include "../common/categorical.h"

 #include <thrust/binary_search.h>
 namespace xgboost {
@@ -31,13 +32,17 @@ struct EllpackDeviceAccessor {
   /*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
   common::Span<const bst_float> gidx_fvalue_map;
+  common::Span<const FeatureType> feature_types;

   EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts,
                         bool is_dense, size_t row_stride, size_t base_rowid,
-                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter)
+                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter,
+                        common::Span<FeatureType const> feature_types)
       : is_dense(is_dense),
         row_stride(row_stride),
         base_rowid(base_rowid),
-        n_rows(n_rows) ,gidx_iter(gidx_iter){
+        n_rows(n_rows) ,gidx_iter(gidx_iter),
+        feature_types{feature_types} {
     cuts.cut_values_.SetDevice(device);
     cuts.cut_ptrs_.SetDevice(device);
     cuts.min_vals_.SetDevice(device);
@@ -64,12 +69,23 @@ struct EllpackDeviceAccessor {
     return gidx;
   }

+  template <bool is_cat>
   __device__ uint32_t SearchBin(float value, size_t column_id) const {
     auto beg = feature_segments[column_id];
     auto end = feature_segments[column_id + 1];
-    auto it =
-        thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin() + beg, gidx_fvalue_map.cbegin() + end, value);
-    uint32_t idx = it - gidx_fvalue_map.cbegin();
+    uint32_t idx = 0;
+    if (is_cat) {
+      auto it = dh::MakeTransformIterator<bst_cat_t>(
+          gidx_fvalue_map.cbegin(), [](float v) { return common::AsCat(v); });
+      idx = thrust::lower_bound(thrust::seq, it + beg, it + end,
+                                common::AsCat(value)) -
+            it;
+    } else {
+      auto it = thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin() + beg,
+                                    gidx_fvalue_map.cbegin() + end, value);
+      idx = it - gidx_fvalue_map.cbegin();
+    }
     if (idx == end) {
       idx -= 1;
     }
@@ -134,10 +150,12 @@ class EllpackPageImpl {
   explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);

   template <typename AdapterBatch>
-  explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
+  explicit EllpackPageImpl(AdapterBatch batch, float missing, int device,
+                           bool is_dense, int nthread,
                            common::Span<size_t> row_counts_span,
+                           common::Span<FeatureType const> feature_types,
                            size_t row_stride, size_t n_rows, size_t n_cols,
-                           common::HistogramCuts const& cuts);
+                           common::HistogramCuts const &cuts);

   /*! \brief Copy the elements of the given ELLPACK page into this page.
    *
@@ -176,7 +194,9 @@ class EllpackPageImpl {
    * not found). */
   size_t NumSymbols() const { return cuts_.TotalBins() + 1; }

-  EllpackDeviceAccessor GetDeviceAccessor(int device) const;
+  EllpackDeviceAccessor
+  GetDeviceAccessor(int device,
+                    common::Span<FeatureType const> feature_types = {}) const;

  private:
  /*!

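The templated `SearchBin` above separates the two lookup rules: numerical features keep the `upper_bound` search over cut values, while categorical features use an exact `lower_bound` match on the integer category, because a categorical column's cut values are the categories themselves. A numpy sketch of the difference (hypothetical values):

    import numpy as np

    cuts = np.array([0.0, 1.0, 2.0, 3.0])  # one column's slice of gidx_fvalue_map

    # Numerical: first cut strictly greater than the value (upper_bound).
    num_bin = np.searchsorted(cuts, 1.5, side="right")                # -> 2

    # Categorical: exact match on the integer category (lower_bound).
    cat_bin = np.searchsorted(cuts.astype(np.int64), 1, side="left")  # -> 1
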
View File

@@ -148,9 +148,13 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
       return GetRowCounts(value, row_counts_span, get_device(), missing);
     });
     auto is_dense = this->IsDense();
+
+    proxy->Info().feature_types.SetDevice(get_device());
+    auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
     auto new_impl = Dispatch(proxy, [&](auto const &value) {
       return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
-                             row_counts_span, row_stride, rows, cols, cuts);
+                             row_counts_span, d_feature_types, row_stride, rows,
+                             cols, cuts);
     });
     size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
     offset += num_elements;

View File

@@ -155,6 +155,9 @@ struct EllpackLoader {
     if (gidx == -1) {
       return nan("");
     }
+    if (common::IsCat(matrix.feature_types, fidx)) {
+      return matrix.gidx_fvalue_map[gidx];
+    }
     // The gradient index needs to be shifted by one as min values are not included in the
     // cuts.
     if (gidx == matrix.feature_segments[fidx]) {
@@ -592,8 +595,10 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       size_t batch_offset = 0;
       for (auto const& page : dmat->GetBatches<EllpackPage>()) {
+        dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
+        auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
         this->PredictInternal(
-            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
+            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
             d_model,
             out_preds,
             batch_offset);

View File

@@ -19,7 +19,7 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF, Dask, and cuPy
 RUN \
     conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
-        python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy \
+        python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy=9.1* \
         numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis

 ENV GOSU_VERSION 1.10

View File

@@ -68,7 +68,16 @@ void TestEquivalent(float sparsity) {
     auto const& buffer_from_iter = page_concatenated->gidx_buffer;
     auto const& buffer_from_data = ellpack.Impl()->gidx_buffer;
     ASSERT_NE(buffer_from_data.Size(), 0);
-    ASSERT_EQ(buffer_from_data.ConstHostVector(), buffer_from_data.ConstHostVector());
+
+    common::CompressedIterator<uint32_t> data_buf{
+        buffer_from_data.ConstHostPointer(), from_data.NumSymbols()};
+    common::CompressedIterator<uint32_t> data_iter{
+        buffer_from_iter.ConstHostPointer(), from_iter.NumSymbols()};
+    CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
+    CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride);
+    for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
+      CHECK_EQ(data_buf[i], data_iter[i]);
+    }
   }
 }

View File

@@ -225,6 +225,9 @@ void TestCategoricalPrediction(std::string name) {
   row[split_ind] = split_cat;
   auto m = GetDMatrixFromData(row, 1, kCols);

+  std::vector<FeatureType> types(10, FeatureType::kCategorical);
+  m->Info().feature_types.HostVector() = types;
+
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
   ASSERT_EQ(out_predictions.predictions.Size(), 1ul);

View File

@@ -225,19 +225,28 @@ class IterForDMatrixTest(xgb.core.DataIter):
     ROWS_PER_BATCH = 100  # data is splited by rows
     BATCHES = 16

-    def __init__(self):
+    def __init__(self, categorical):
         '''Generate some random data for demostration.

         Actual data can be anything that is currently supported by XGBoost.
         '''
         import cudf
         self.rows = self.ROWS_PER_BATCH
-        rng = np.random.RandomState(1994)
-        self._data = [
-            cudf.DataFrame(
-                {'a': rng.randn(self.ROWS_PER_BATCH),
-                 'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
-        self._labels = [rng.randn(self.rows)] * self.BATCHES
+
+        if categorical:
+            self._data = []
+            self._labels = []
+            for i in range(self.BATCHES):
+                X, y = tm.make_categorical(self.ROWS_PER_BATCH, 4, 13, False)
+                self._data.append(cudf.from_pandas(X))
+                self._labels.append(y)
+        else:
+            rng = np.random.RandomState(1994)
+            self._data = [
+                cudf.DataFrame(
+                    {'a': rng.randn(self.ROWS_PER_BATCH),
+                     'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
+            self._labels = [rng.randn(self.rows)] * self.BATCHES

         self.it = 0  # set iterator to 0
         super().__init__()
@@ -272,24 +281,26 @@ class IterForDMatrixTest(xgb.core.DataIter):

 @pytest.mark.skipif(**tm.no_cudf())
-def test_from_cudf_iter():
+@pytest.mark.parametrize("enable_categorical", [True, False])
+def test_from_cudf_iter(enable_categorical):
     rounds = 100
-    it = IterForDMatrixTest()
+    it = IterForDMatrixTest(enable_categorical)
+    params = {"tree_method": "gpu_hist"}

     # Use iterator
-    m_it = xgb.DeviceQuantileDMatrix(it)
-    reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it,
-                            num_boost_round=rounds)
-    predict_with_it = reg_with_it.predict(m_it)
+    m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=enable_categorical)
+    reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)

-    # Without using iterator
-    m = xgb.DMatrix(it.as_array(), it.as_array_labels())
+    X = it.as_array()
+    y = it.as_array_labels()
+    m = xgb.DMatrix(X, y, enable_categorical=enable_categorical)

     assert m_it.num_col() == m.num_col()
     assert m_it.num_row() == m.num_row()

-    reg = xgb.train({'tree_method': 'gpu_hist'}, m,
-                    num_boost_round=rounds)
-    predict = reg.predict(m)
+    reg = xgb.train(params, m, num_boost_round=rounds)

+    predict = reg.predict(m)
+    predict_with_it = reg_with_it.predict(m_it)
     np.testing.assert_allclose(predict_with_it, predict)

View File

@@ -1,11 +1,13 @@
 import sys
 import os
-from typing import Type, TypeVar, Any, Dict, List
+from typing import Type, TypeVar, Any, Dict, List, Tuple
 import pytest
 import numpy as np
 import asyncio
 import xgboost
 import subprocess
+import tempfile
+import json
 from collections import OrderedDict
 from inspect import signature
 from hypothesis import given, strategies, settings, note
@@ -41,6 +43,49 @@ except ImportError:
     pass


+def make_categorical(
+    client: Client,
+    n_samples: int,
+    n_features: int,
+    n_categories: int,
+    onehot: bool = False,
+) -> Tuple[dd.DataFrame, dd.Series]:
+    workers = _get_client_workers(client)
+    n_workers = len(workers)
+    dfs = []
+
+    def pack(**kwargs: Any) -> dd.DataFrame:
+        X, y = tm.make_categorical(**kwargs)
+        X["label"] = y
+        return X
+
+    meta = pack(
+        n_samples=1, n_features=n_features, n_categories=n_categories, onehot=False
+    )
+
+    for i, worker in enumerate(workers):
+        l_n_samples = min(
+            n_samples // n_workers, n_samples - i * (n_samples // n_workers)
+        )
+        future = client.submit(
+            pack,
+            n_samples=l_n_samples,
+            n_features=n_features,
+            n_categories=n_categories,
+            onehot=False,
+            workers=[worker],
+        )
+        dfs.append(future)
+
+    df = dd.from_delayed(dfs, meta=meta)
+    y = df["label"]
+    X = df[df.columns.difference(["label"])]
+    if onehot:
+        return dd.get_dummies(X), y
+    return X, y
+
+
 def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
     import cupy as cp
     cp.cuda.runtime.setDevice(0)
@@ -126,6 +171,62 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
                                    inplace_predictions)


+@pytest.mark.skipif(**tm.no_dask_cudf())
+def test_categorical(local_cuda_cluster: LocalCUDACluster) -> None:
+    with Client(local_cuda_cluster) as client:
+        import dask_cudf
+
+        rounds = 10
+        X, y = make_categorical(client, 10000, 30, 13)
+        X = dask_cudf.from_dask_dataframe(X)
+
+        X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
+        X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
+
+        parameters = {"tree_method": "gpu_hist"}
+
+        m = dxgb.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
+        by_etl_results = dxgb.train(
+            client,
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+        )["history"]
+
+        m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
+        output = dxgb.train(
+            client,
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+        )
+        by_builtin_results = output["history"]
+
+        np.testing.assert_allclose(
+            np.array(by_etl_results["Train"]["rmse"]),
+            np.array(by_builtin_results["Train"]["rmse"]),
+            rtol=1e-3,
+        )
+        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
+
+        model = output["booster"]
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = os.path.join(tempdir, "model.json")
+            model.save_model(path)
+            with open(path, "r") as fd:
+                categorical = json.load(fd)
+
+            categories_sizes = np.array(
+                categorical["learner"]["gradient_booster"]["model"]["trees"][-1][
+                    "categories_sizes"
+                ]
+            )
+            assert categories_sizes.shape[0] != 0
+            np.testing.assert_allclose(categories_sizes, 1)
+
+
 def to_cp(x: Any, DMatrixT: Type) -> Any:
     import cupy
     if isinstance(x, np.ndarray) and \

View File

@@ -236,7 +236,7 @@ def get_mq2008(dpath):

 @memory.cache
 def make_categorical(
-    n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
+    n_samples: int, n_features: int, n_categories: int, onehot: bool
 ):
     import pandas as pd

@@ -244,7 +244,7 @@ def make_categorical(
     pd_dict = {}
     for i in range(n_features + 1):
-        c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
+        c = rng.randint(low=0, high=n_categories, size=n_samples)
         pd_dict[str(i)] = pd.Series(c, dtype=np.int64)

     df = pd.DataFrame(pd_dict)
@@ -255,11 +255,13 @@ def make_categorical(
     label += 1

     df = df.astype("category")
-    if onehot_enc:
-        cat = pd.get_dummies(df)
-    else:
-        cat = df
-    return cat, label
+    categories = np.arange(0, n_categories)
+    for col in df.columns:
+        df[col] = df[col].cat.set_categories(categories)
+
+    if onehot:
+        return pd.get_dummies(df), label
+    return df, label


 _unweighted_datasets_strategy = strategies.sampled_from(