Support categorical data for dask functional interface and DQM. (#7043)
* Support categorical data for dask functional interface and DQM. * Implement categorical data support for GPU GK-merge. * Add support for dask functional interface. * Add support for DQM. * Get newer cupy.
This commit is contained in:
parent
7dd29ffd47
commit
86715e4cd4
@ -321,6 +321,7 @@ class DataIter:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._handle = _ProxyDMatrix()
|
self._handle = _ProxyDMatrix()
|
||||||
self.exception = None
|
self.exception = None
|
||||||
|
self.enable_categorical = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def proxy(self):
|
def proxy(self):
|
||||||
@ -346,13 +347,12 @@ class DataIter:
|
|||||||
data,
|
data,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
enable_categorical=False,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
from .data import dispatch_device_quantile_dmatrix_set_data
|
from .data import dispatch_device_quantile_dmatrix_set_data
|
||||||
from .data import _device_quantile_transform
|
from .data import _device_quantile_transform
|
||||||
data, feature_names, feature_types = _device_quantile_transform(
|
data, feature_names, feature_types = _device_quantile_transform(
|
||||||
data, feature_names, feature_types, enable_categorical,
|
data, feature_names, feature_types, self.enable_categorical,
|
||||||
)
|
)
|
||||||
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
||||||
self.proxy.set_info(
|
self.proxy.set_info(
|
||||||
@ -1106,15 +1106,10 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
data = _transform_dlpack(data)
|
data = _transform_dlpack(data)
|
||||||
if _is_iter(data):
|
if _is_iter(data):
|
||||||
it = data
|
it = data
|
||||||
if enable_categorical:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"categorical support is not enabled on data iterator."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
it = SingleBatchInternalIter(
|
it = SingleBatchInternalIter(data=data, **meta)
|
||||||
data=data, enable_categorical=enable_categorical, **meta
|
|
||||||
)
|
|
||||||
|
|
||||||
|
it.enable_categorical = enable_categorical
|
||||||
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
|
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
|
||||||
next_callback = ctypes.CFUNCTYPE(
|
next_callback = ctypes.CFUNCTYPE(
|
||||||
ctypes.c_int,
|
ctypes.c_int,
|
||||||
|
|||||||
@ -182,7 +182,7 @@ def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
|
|||||||
lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
|
lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
|
||||||
from cudf import concat as CUDF_concat # pylint: disable=import-error
|
from cudf import concat as CUDF_concat # pylint: disable=import-error
|
||||||
return CUDF_concat(value, axis=0)
|
return CUDF_concat(value, axis=0)
|
||||||
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
|
if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'):
|
||||||
import cupy
|
import cupy
|
||||||
# pylint: disable=c-extension-no-member,no-member
|
# pylint: disable=c-extension-no-member,no-member
|
||||||
d = cupy.cuda.runtime.getDevice()
|
d = cupy.cuda.runtime.getDevice()
|
||||||
@ -258,6 +258,7 @@ class DaskDMatrix:
|
|||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
self.missing = missing
|
self.missing = missing
|
||||||
|
self.enable_categorical = enable_categorical
|
||||||
|
|
||||||
if qid is not None and weight is not None:
|
if qid is not None and weight is not None:
|
||||||
raise NotImplementedError("per-group weight is not implemented.")
|
raise NotImplementedError("per-group weight is not implemented.")
|
||||||
@ -265,10 +266,6 @@ class DaskDMatrix:
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"group structure is not implemented, use qid instead."
|
"group structure is not implemented, use qid instead."
|
||||||
)
|
)
|
||||||
if enable_categorical:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"categorical support is not enabled on `DaskDMatrix`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(data.shape) != 2:
|
if len(data.shape) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -311,7 +308,7 @@ class DaskDMatrix:
|
|||||||
qid: Optional[_DaskCollection] = None,
|
qid: Optional[_DaskCollection] = None,
|
||||||
feature_weights: Optional[_DaskCollection] = None,
|
feature_weights: Optional[_DaskCollection] = None,
|
||||||
label_lower_bound: Optional[_DaskCollection] = None,
|
label_lower_bound: Optional[_DaskCollection] = None,
|
||||||
label_upper_bound: Optional[_DaskCollection] = None
|
label_upper_bound: Optional[_DaskCollection] = None,
|
||||||
) -> "DaskDMatrix":
|
) -> "DaskDMatrix":
|
||||||
'''Obtain references to local data.'''
|
'''Obtain references to local data.'''
|
||||||
|
|
||||||
@ -430,6 +427,7 @@ class DaskDMatrix:
|
|||||||
'feature_weights': self.feature_weights,
|
'feature_weights': self.feature_weights,
|
||||||
'meta_names': self.meta_names,
|
'meta_names': self.meta_names,
|
||||||
'missing': self.missing,
|
'missing': self.missing,
|
||||||
|
'enable_categorical': self.enable_categorical,
|
||||||
'parts': self.worker_map.get(worker_addr, None),
|
'parts': self.worker_map.get(worker_addr, None),
|
||||||
'is_quantile': self.is_quantile}
|
'is_quantile': self.is_quantile}
|
||||||
|
|
||||||
@ -668,6 +666,7 @@ def _create_device_quantile_dmatrix(
|
|||||||
missing: float,
|
missing: float,
|
||||||
parts: Optional[_DataParts],
|
parts: Optional[_DataParts],
|
||||||
max_bin: int,
|
max_bin: int,
|
||||||
|
enable_categorical: bool,
|
||||||
) -> DeviceQuantileDMatrix:
|
) -> DeviceQuantileDMatrix:
|
||||||
worker = distributed.get_worker()
|
worker = distributed.get_worker()
|
||||||
if parts is None:
|
if parts is None:
|
||||||
@ -680,6 +679,7 @@ def _create_device_quantile_dmatrix(
|
|||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
max_bin=max_bin,
|
max_bin=max_bin,
|
||||||
|
enable_categorical=enable_categorical,
|
||||||
)
|
)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -709,6 +709,7 @@ def _create_device_quantile_dmatrix(
|
|||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
nthread=worker.nthreads,
|
nthread=worker.nthreads,
|
||||||
max_bin=max_bin,
|
max_bin=max_bin,
|
||||||
|
enable_categorical=enable_categorical,
|
||||||
)
|
)
|
||||||
dmatrix.set_info(feature_weights=feature_weights)
|
dmatrix.set_info(feature_weights=feature_weights)
|
||||||
return dmatrix
|
return dmatrix
|
||||||
@ -720,6 +721,7 @@ def _create_dmatrix(
|
|||||||
feature_weights: Optional[Any],
|
feature_weights: Optional[Any],
|
||||||
meta_names: List[str],
|
meta_names: List[str],
|
||||||
missing: float,
|
missing: float,
|
||||||
|
enable_categorical: bool,
|
||||||
parts: Optional[_DataParts]
|
parts: Optional[_DataParts]
|
||||||
) -> DMatrix:
|
) -> DMatrix:
|
||||||
'''Get data that local to worker from DaskDMatrix.
|
'''Get data that local to worker from DaskDMatrix.
|
||||||
@ -734,9 +736,12 @@ def _create_dmatrix(
|
|||||||
if list_of_parts is None:
|
if list_of_parts is None:
|
||||||
msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
|
msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
|
||||||
LOGGER.warning(msg)
|
LOGGER.warning(msg)
|
||||||
d = DMatrix(numpy.empty((0, 0)),
|
d = DMatrix(
|
||||||
feature_names=feature_names,
|
numpy.empty((0, 0)),
|
||||||
feature_types=feature_types)
|
feature_names=feature_names,
|
||||||
|
feature_types=feature_types,
|
||||||
|
enable_categorical=enable_categorical,
|
||||||
|
)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@ -764,6 +769,7 @@ def _create_dmatrix(
|
|||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
nthread=worker.nthreads,
|
nthread=worker.nthreads,
|
||||||
|
enable_categorical=enable_categorical,
|
||||||
)
|
)
|
||||||
dmatrix.set_info(
|
dmatrix.set_info(
|
||||||
base_margin=_base_margin,
|
base_margin=_base_margin,
|
||||||
|
|||||||
@ -1151,12 +1151,12 @@ struct SegmentedUniqueReduceOp {
|
|||||||
* \return Number of unique values in total.
|
* \return Number of unique values in total.
|
||||||
*/
|
*/
|
||||||
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
|
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
|
||||||
typename ValOutIt, typename Comp>
|
typename ValOutIt, typename CompValue, typename CompKey>
|
||||||
size_t
|
size_t
|
||||||
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
|
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
|
||||||
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
|
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
|
||||||
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
|
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
|
||||||
Comp comp) {
|
CompValue comp, CompKey comp_key=thrust::equal_to<size_t>{}) {
|
||||||
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
|
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
|
||||||
auto unique_key_it = dh::MakeTransformIterator<Key>(
|
auto unique_key_it = dh::MakeTransformIterator<Key>(
|
||||||
thrust::make_counting_iterator(static_cast<size_t>(0)),
|
thrust::make_counting_iterator(static_cast<size_t>(0)),
|
||||||
@ -1177,7 +1177,7 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
|
|||||||
exec, unique_key_it, unique_key_it + n_inputs,
|
exec, unique_key_it, unique_key_it + n_inputs,
|
||||||
val_first, reduce_it, val_out,
|
val_first, reduce_it, val_out,
|
||||||
[=] __device__(Key const &l, Key const &r) {
|
[=] __device__(Key const &l, Key const &r) {
|
||||||
if (l.first == r.first) {
|
if (comp_key(l.first, r.first)) {
|
||||||
// In the same segment.
|
// In the same segment.
|
||||||
return comp(l.second, r.second);
|
return comp(l.second, r.second);
|
||||||
}
|
}
|
||||||
@ -1195,7 +1195,9 @@ template <typename... Inputs,
|
|||||||
* = nullptr>
|
* = nullptr>
|
||||||
size_t SegmentedUnique(Inputs &&...inputs) {
|
size_t SegmentedUnique(Inputs &&...inputs) {
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
|
return SegmentedUnique(thrust::cuda::par(alloc),
|
||||||
|
std::forward<Inputs &&>(inputs)...,
|
||||||
|
thrust::equal_to<size_t>{});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -129,60 +129,52 @@ void SortByWeight(dh::device_vector<float>* weights,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IsCatOp {
|
|
||||||
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
|
|
||||||
};
|
|
||||||
|
|
||||||
void RemoveDuplicatedCategories(
|
void RemoveDuplicatedCategories(
|
||||||
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
|
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
|
||||||
dh::device_vector<Entry> *p_sorted_entries,
|
dh::device_vector<Entry> *p_sorted_entries,
|
||||||
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
dh::caching_device_vector<size_t> *p_column_sizes_scan) {
|
||||||
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
||||||
auto& column_sizes_scan = *p_column_sizes_scan;
|
CHECK(!d_feature_types.empty());
|
||||||
if (!info.feature_types.Empty() &&
|
auto &column_sizes_scan = *p_column_sizes_scan;
|
||||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
auto &sorted_entries = *p_sorted_entries;
|
||||||
IsCatOp{})) {
|
// Removing duplicated entries in categorical features.
|
||||||
auto& sorted_entries = *p_sorted_entries;
|
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
|
||||||
// Removing duplicated entries in categorical features.
|
dh::SegmentedUnique(column_sizes_scan.data().get(),
|
||||||
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
|
column_sizes_scan.data().get() + column_sizes_scan.size(),
|
||||||
dh::SegmentedUnique(
|
sorted_entries.begin(), sorted_entries.end(),
|
||||||
column_sizes_scan.data().get(),
|
new_column_scan.data().get(), sorted_entries.begin(),
|
||||||
column_sizes_scan.data().get() + column_sizes_scan.size(),
|
[=] __device__(Entry const &l, Entry const &r) {
|
||||||
sorted_entries.begin(), sorted_entries.end(),
|
if (l.index == r.index) {
|
||||||
new_column_scan.data().get(), sorted_entries.begin(),
|
if (IsCat(d_feature_types, l.index)) {
|
||||||
[=] __device__(Entry const &l, Entry const &r) {
|
return l.fvalue == r.fvalue;
|
||||||
if (l.index == r.index) {
|
}
|
||||||
if (IsCat(d_feature_types, l.index)) {
|
}
|
||||||
return l.fvalue == r.fvalue;
|
return false;
|
||||||
}
|
});
|
||||||
}
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
|
|
||||||
// Renew the column scan and cut scan based on categorical data.
|
// Renew the column scan and cut scan based on categorical data.
|
||||||
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
|
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
|
||||||
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
|
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
|
||||||
info.num_col_ + 1);
|
info.num_col_ + 1);
|
||||||
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
|
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
|
||||||
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
|
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
|
||||||
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
|
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
|
||||||
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
|
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
|
||||||
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
|
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
|
||||||
if (idx == d_new_columns_ptr.size() - 1) {
|
if (idx == d_new_columns_ptr.size() - 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (IsCat(d_feature_types, idx)) {
|
if (IsCat(d_feature_types, idx)) {
|
||||||
// Cut size is the same as number of categories in input.
|
// Cut size is the same as number of categories in input.
|
||||||
d_new_cuts_size[idx] =
|
d_new_cuts_size[idx] =
|
||||||
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
|
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
|
||||||
} else {
|
} else {
|
||||||
d_new_cuts_size[idx] = d_cuts_ptr[idx] - d_cuts_ptr[idx];
|
d_new_cuts_size[idx] = d_cuts_ptr[idx] - d_cuts_ptr[idx];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// Turn size into ptr.
|
// Turn size into ptr.
|
||||||
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
|
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
|
||||||
new_cuts_size.cend(), d_cuts_ptr.data());
|
new_cuts_size.cend(), d_cuts_ptr.data());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
@ -215,8 +207,11 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
|||||||
0, sorted_entries.size(),
|
0, sorted_entries.size(),
|
||||||
&cuts_ptr, &column_sizes_scan);
|
&cuts_ptr, &column_sizes_scan);
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
|
||||||
&column_sizes_scan);
|
if (sketch_container->HasCategorical()) {
|
||||||
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
|
||||||
|
&sorted_entries, &column_sizes_scan);
|
||||||
|
}
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||||
@ -281,8 +276,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
|||||||
0, sorted_entries.size(),
|
0, sorted_entries.size(),
|
||||||
&cuts_ptr, &column_sizes_scan);
|
&cuts_ptr, &column_sizes_scan);
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
if (sketch_container->HasCategorical()) {
|
||||||
&column_sizes_scan);
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
|
||||||
|
&sorted_entries, &column_sizes_scan);
|
||||||
|
}
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
|
|
||||||
// Extract cuts
|
// Extract cuts
|
||||||
|
|||||||
@ -210,6 +210,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
|
|||||||
Span<bst_row_t const> const &x_ptr,
|
Span<bst_row_t const> const &x_ptr,
|
||||||
Span<SketchEntry const> const &d_y,
|
Span<SketchEntry const> const &d_y,
|
||||||
Span<bst_row_t const> const &y_ptr,
|
Span<bst_row_t const> const &y_ptr,
|
||||||
|
Span<FeatureType const> feature_types,
|
||||||
Span<SketchEntry> out,
|
Span<SketchEntry> out,
|
||||||
Span<bst_row_t> out_ptr) {
|
Span<bst_row_t> out_ptr) {
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
@ -408,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
|
|||||||
return n_uniques;
|
return n_uniques;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t SketchContainer::Unique() {
|
|
||||||
timer_.Start(__func__);
|
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
|
||||||
this->columns_ptr_.SetDevice(device_);
|
|
||||||
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
|
||||||
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
|
||||||
Span<SketchEntry> entries = dh::ToSpan(this->Current());
|
|
||||||
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
|
|
||||||
scan_out.SetDevice(device_);
|
|
||||||
auto d_scan_out = scan_out.DeviceSpan();
|
|
||||||
|
|
||||||
d_column_scan = this->columns_ptr_.DeviceSpan();
|
|
||||||
size_t n_uniques = dh::SegmentedUnique(
|
|
||||||
d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
|
|
||||||
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
|
|
||||||
entries.data(),
|
|
||||||
detail::SketchUnique{});
|
|
||||||
this->columns_ptr_.Copy(scan_out);
|
|
||||||
CHECK(!this->columns_ptr_.HostCanRead());
|
|
||||||
|
|
||||||
this->Current().resize(n_uniques);
|
|
||||||
timer_.Stop(__func__);
|
|
||||||
return n_uniques;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SketchContainer::Prune(size_t to) {
|
void SketchContainer::Prune(size_t to) {
|
||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
@ -490,13 +466,20 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
|
|||||||
this->Other().resize(this->Current().size() + that.size());
|
this->Other().resize(this->Current().size() + that.size());
|
||||||
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
|
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
|
||||||
|
|
||||||
MergeImpl(device_, this->Data(), this->ColumnsPtr(),
|
auto feature_types = this->FeatureTypes().ConstDeviceSpan();
|
||||||
that, d_that_columns_ptr,
|
MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
|
||||||
dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
|
feature_types, dh::ToSpan(this->Other()),
|
||||||
|
columns_ptr_b_.DeviceSpan());
|
||||||
this->columns_ptr_.Copy(columns_ptr_b_);
|
this->columns_ptr_.Copy(columns_ptr_b_);
|
||||||
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
|
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
|
||||||
this->Alternate();
|
this->Alternate();
|
||||||
|
|
||||||
|
if (this->HasCategorical()) {
|
||||||
|
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
|
||||||
|
this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
|
||||||
|
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
|
||||||
|
});
|
||||||
|
}
|
||||||
timer_.Stop(__func__);
|
timer_.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,19 @@ class HistogramCuts;
|
|||||||
using WQSketch = WQuantileSketch<bst_float, bst_float>;
|
using WQSketch = WQuantileSketch<bst_float, bst_float>;
|
||||||
using SketchEntry = WQSketch::Entry;
|
using SketchEntry = WQSketch::Entry;
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
struct IsCatOp {
|
||||||
|
XGBOOST_DEVICE bool operator()(FeatureType ft) {
|
||||||
|
return ft == FeatureType::kCategorical;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct SketchUnique {
|
||||||
|
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
|
||||||
|
return a.value - b.value == 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief A container that holds the device sketches. Sketching is performed per-column,
|
* \brief A container that holds the device sketches. Sketching is performed per-column,
|
||||||
* but fused into single operation for performance.
|
* but fused into single operation for performance.
|
||||||
@ -43,6 +56,8 @@ class SketchContainer {
|
|||||||
HostDeviceVector<OffsetT> columns_ptr_;
|
HostDeviceVector<OffsetT> columns_ptr_;
|
||||||
HostDeviceVector<OffsetT> columns_ptr_b_;
|
HostDeviceVector<OffsetT> columns_ptr_b_;
|
||||||
|
|
||||||
|
bool has_categorical_{false};
|
||||||
|
|
||||||
dh::device_vector<SketchEntry>& Current() {
|
dh::device_vector<SketchEntry>& Current() {
|
||||||
if (current_buffer_) {
|
if (current_buffer_) {
|
||||||
return entries_a_;
|
return entries_a_;
|
||||||
@ -102,14 +117,21 @@ class SketchContainer {
|
|||||||
this->feature_types_.SetDevice(device);
|
this->feature_types_.SetDevice(device);
|
||||||
this->feature_types_.ConstDeviceSpan();
|
this->feature_types_.ConstDeviceSpan();
|
||||||
this->feature_types_.ConstHostSpan();
|
this->feature_types_.ConstHostSpan();
|
||||||
|
|
||||||
|
auto d_feature_types = feature_types_.ConstDeviceSpan();
|
||||||
|
has_categorical_ =
|
||||||
|
!d_feature_types.empty() &&
|
||||||
|
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||||
|
detail::IsCatOp{});
|
||||||
|
|
||||||
timer_.Init(__func__);
|
timer_.Init(__func__);
|
||||||
}
|
}
|
||||||
/* \brief Return GPU ID for this container. */
|
/* \brief Return GPU ID for this container. */
|
||||||
int32_t DeviceIdx() const { return device_; }
|
int32_t DeviceIdx() const { return device_; }
|
||||||
|
/* \brief Whether the predictor matrix contains categorical features. */
|
||||||
|
bool HasCategorical() const { return has_categorical_; }
|
||||||
/* \brief Accumulate weights of duplicated entries in input. */
|
/* \brief Accumulate weights of duplicated entries in input. */
|
||||||
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
|
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
|
||||||
/* \brief Removes all the duplicated elements in quantile structure. */
|
|
||||||
size_t Unique();
|
|
||||||
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
|
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
|
||||||
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
|
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
|
||||||
void FixError();
|
void FixError();
|
||||||
@ -154,15 +176,35 @@ class SketchContainer {
|
|||||||
|
|
||||||
SketchContainer(const SketchContainer&) = delete;
|
SketchContainer(const SketchContainer&) = delete;
|
||||||
SketchContainer& operator=(const SketchContainer&) = delete;
|
SketchContainer& operator=(const SketchContainer&) = delete;
|
||||||
};
|
|
||||||
|
|
||||||
namespace detail {
|
/* \brief Removes all the duplicated elements in quantile structure. */
|
||||||
struct SketchUnique {
|
template <typename KeyComp = thrust::equal_to<size_t>>
|
||||||
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
|
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
||||||
return a.value - b.value == 0;
|
timer_.Start(__func__);
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
|
this->columns_ptr_.SetDevice(device_);
|
||||||
|
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||||
|
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
||||||
|
Span<SketchEntry> entries = dh::ToSpan(this->Current());
|
||||||
|
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
|
||||||
|
scan_out.SetDevice(device_);
|
||||||
|
auto d_scan_out = scan_out.DeviceSpan();
|
||||||
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
|
||||||
|
d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||||
|
size_t n_uniques = dh::SegmentedUnique(
|
||||||
|
thrust::cuda::par(alloc), d_column_scan.data(),
|
||||||
|
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||||
|
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||||
|
entries.data(), detail::SketchUnique{}, key_comp);
|
||||||
|
this->columns_ptr_.Copy(scan_out);
|
||||||
|
CHECK(!this->columns_ptr_.HostCanRead());
|
||||||
|
|
||||||
|
this->Current().resize(n_uniques);
|
||||||
|
timer_.Stop(__func__);
|
||||||
|
return n_uniques;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace detail
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -134,17 +134,20 @@ struct WriteCompressedEllpackFunctor {
|
|||||||
const common::CompressedBufferWriter& writer,
|
const common::CompressedBufferWriter& writer,
|
||||||
AdapterBatchT batch,
|
AdapterBatchT batch,
|
||||||
EllpackDeviceAccessor accessor,
|
EllpackDeviceAccessor accessor,
|
||||||
|
common::Span<FeatureType const> feature_types,
|
||||||
const data::IsValidFunctor& is_valid)
|
const data::IsValidFunctor& is_valid)
|
||||||
: d_buffer(buffer),
|
: d_buffer(buffer),
|
||||||
writer(writer),
|
writer(writer),
|
||||||
batch(std::move(batch)),
|
batch(std::move(batch)),
|
||||||
accessor(std::move(accessor)),
|
accessor(std::move(accessor)),
|
||||||
|
feature_types(std::move(feature_types)),
|
||||||
is_valid(is_valid) {}
|
is_valid(is_valid) {}
|
||||||
|
|
||||||
common::CompressedByteT* d_buffer;
|
common::CompressedByteT* d_buffer;
|
||||||
common::CompressedBufferWriter writer;
|
common::CompressedBufferWriter writer;
|
||||||
AdapterBatchT batch;
|
AdapterBatchT batch;
|
||||||
EllpackDeviceAccessor accessor;
|
EllpackDeviceAccessor accessor;
|
||||||
|
common::Span<FeatureType const> feature_types;
|
||||||
data::IsValidFunctor is_valid;
|
data::IsValidFunctor is_valid;
|
||||||
|
|
||||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||||
@ -154,7 +157,12 @@ struct WriteCompressedEllpackFunctor {
|
|||||||
// -1 because the scan is inclusive
|
// -1 because the scan is inclusive
|
||||||
size_t output_position =
|
size_t output_position =
|
||||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
||||||
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
|
uint32_t bin_idx = 0;
|
||||||
|
if (common::IsCat(feature_types, e.column_idx)) {
|
||||||
|
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
|
||||||
|
} else {
|
||||||
|
bin_idx = accessor.SearchBin<false>(e.value, e.column_idx);
|
||||||
|
}
|
||||||
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
|
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
@ -184,8 +192,9 @@ class TypedDiscard : public thrust::discard_iterator<T> {
|
|||||||
// Here the data is already correctly ordered and simply needs to be compacted
|
// Here the data is already correctly ordered and simply needs to be compacted
|
||||||
// to remove missing data
|
// to remove missing data
|
||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||||
int device_idx, float missing) {
|
common::Span<FeatureType const> feature_types,
|
||||||
|
EllpackPageImpl *dst, int device_idx, float missing) {
|
||||||
// Some witchcraft happens here
|
// Some witchcraft happens here
|
||||||
// The goal is to copy valid elements out of the input to an ELLPACK matrix
|
// The goal is to copy valid elements out of the input to an ELLPACK matrix
|
||||||
// with a given row stride, using no extra working memory Standard stream
|
// with a given row stride, using no extra working memory Standard stream
|
||||||
@ -220,7 +229,8 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
|||||||
|
|
||||||
// We redirect the scan output into this functor to do the actual writing
|
// We redirect the scan output into this functor to do the actual writing
|
||||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||||
d_compressed_buffer, writer, batch, device_accessor, is_valid);
|
d_compressed_buffer, writer, batch, device_accessor, feature_types,
|
||||||
|
is_valid);
|
||||||
TypedDiscard<Tuple> discard;
|
TypedDiscard<Tuple> discard;
|
||||||
thrust::transform_output_iterator<
|
thrust::transform_output_iterator<
|
||||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||||
@ -263,22 +273,22 @@ template <typename AdapterBatch>
|
|||||||
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
|
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
|
||||||
bool is_dense, int nthread,
|
bool is_dense, int nthread,
|
||||||
common::Span<size_t> row_counts_span,
|
common::Span<size_t> row_counts_span,
|
||||||
|
common::Span<FeatureType const> feature_types,
|
||||||
size_t row_stride, size_t n_rows, size_t n_cols,
|
size_t row_stride, size_t n_rows, size_t n_cols,
|
||||||
common::HistogramCuts const& cuts) {
|
common::HistogramCuts const& cuts) {
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
|
|
||||||
*this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
|
*this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
|
||||||
CopyDataToEllpack(batch, this, device, missing);
|
CopyDataToEllpack(batch, feature_types, this, device, missing);
|
||||||
WriteNullValues(this, device, row_counts_span);
|
WriteNullValues(this, device, row_counts_span);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \
|
#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \
|
||||||
template EllpackPageImpl::EllpackPageImpl( \
|
template EllpackPageImpl::EllpackPageImpl( \
|
||||||
__BATCH_T batch, float missing, int device, \
|
__BATCH_T batch, float missing, int device, bool is_dense, int nthread, \
|
||||||
bool is_dense, int nthread, \
|
common::Span<size_t> row_counts_span, \
|
||||||
common::Span<size_t> row_counts_span, \
|
common::Span<FeatureType const> feature_types, size_t row_stride, \
|
||||||
size_t row_stride, size_t n_rows, size_t n_cols, \
|
size_t n_rows, size_t n_cols, common::HistogramCuts const &cuts);
|
||||||
common::HistogramCuts const& cuts);
|
|
||||||
|
|
||||||
ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
|
ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
|
||||||
ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
|
ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
|
||||||
@ -467,11 +477,17 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
|
|||||||
return compressed_size_bytes;
|
return compressed_size_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(int device) const {
|
EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
|
||||||
|
int device, common::Span<FeatureType const> feature_types) const {
|
||||||
gidx_buffer.SetDevice(device);
|
gidx_buffer.SetDevice(device);
|
||||||
return EllpackDeviceAccessor(
|
return {device,
|
||||||
device, cuts_, is_dense, row_stride, base_rowid, n_rows,
|
cuts_,
|
||||||
common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
|
is_dense,
|
||||||
NumSymbols()));
|
row_stride,
|
||||||
|
base_rowid,
|
||||||
|
n_rows,
|
||||||
|
common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
|
||||||
|
NumSymbols()),
|
||||||
|
feature_types};
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include "../common/compressed_iterator.h"
|
#include "../common/compressed_iterator.h"
|
||||||
#include "../common/device_helpers.cuh"
|
#include "../common/device_helpers.cuh"
|
||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
|
#include "../common/categorical.h"
|
||||||
#include <thrust/binary_search.h>
|
#include <thrust/binary_search.h>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -31,13 +32,17 @@ struct EllpackDeviceAccessor {
|
|||||||
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
|
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
|
||||||
common::Span<const bst_float> gidx_fvalue_map;
|
common::Span<const bst_float> gidx_fvalue_map;
|
||||||
|
|
||||||
|
common::Span<const FeatureType> feature_types;
|
||||||
|
|
||||||
EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts,
|
EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts,
|
||||||
bool is_dense, size_t row_stride, size_t base_rowid,
|
bool is_dense, size_t row_stride, size_t base_rowid,
|
||||||
size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter)
|
size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter,
|
||||||
|
common::Span<FeatureType const> feature_types)
|
||||||
: is_dense(is_dense),
|
: is_dense(is_dense),
|
||||||
row_stride(row_stride),
|
row_stride(row_stride),
|
||||||
base_rowid(base_rowid),
|
base_rowid(base_rowid),
|
||||||
n_rows(n_rows) ,gidx_iter(gidx_iter){
|
n_rows(n_rows) ,gidx_iter(gidx_iter),
|
||||||
|
feature_types{feature_types} {
|
||||||
cuts.cut_values_.SetDevice(device);
|
cuts.cut_values_.SetDevice(device);
|
||||||
cuts.cut_ptrs_.SetDevice(device);
|
cuts.cut_ptrs_.SetDevice(device);
|
||||||
cuts.min_vals_.SetDevice(device);
|
cuts.min_vals_.SetDevice(device);
|
||||||
@ -64,12 +69,23 @@ struct EllpackDeviceAccessor {
|
|||||||
return gidx;
|
return gidx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <bool is_cat>
|
||||||
__device__ uint32_t SearchBin(float value, size_t column_id) const {
|
__device__ uint32_t SearchBin(float value, size_t column_id) const {
|
||||||
auto beg = feature_segments[column_id];
|
auto beg = feature_segments[column_id];
|
||||||
auto end = feature_segments[column_id + 1];
|
auto end = feature_segments[column_id + 1];
|
||||||
auto it =
|
uint32_t idx = 0;
|
||||||
thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin()+ beg, gidx_fvalue_map.cbegin() + end, value);
|
if (is_cat) {
|
||||||
uint32_t idx = it - gidx_fvalue_map.cbegin();
|
auto it = dh::MakeTransformIterator<bst_cat_t>(
|
||||||
|
gidx_fvalue_map.cbegin(), [](float v) { return common::AsCat(v); });
|
||||||
|
idx = thrust::lower_bound(thrust::seq, it + beg, it + end,
|
||||||
|
common::AsCat(value)) -
|
||||||
|
it;
|
||||||
|
} else {
|
||||||
|
auto it = thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin() + beg,
|
||||||
|
gidx_fvalue_map.cbegin() + end, value);
|
||||||
|
idx = it - gidx_fvalue_map.cbegin();
|
||||||
|
}
|
||||||
|
|
||||||
if (idx == end) {
|
if (idx == end) {
|
||||||
idx -= 1;
|
idx -= 1;
|
||||||
}
|
}
|
||||||
@ -134,10 +150,12 @@ class EllpackPageImpl {
|
|||||||
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
|
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
|
||||||
|
|
||||||
template <typename AdapterBatch>
|
template <typename AdapterBatch>
|
||||||
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
|
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device,
|
||||||
|
bool is_dense, int nthread,
|
||||||
common::Span<size_t> row_counts_span,
|
common::Span<size_t> row_counts_span,
|
||||||
|
common::Span<FeatureType const> feature_types,
|
||||||
size_t row_stride, size_t n_rows, size_t n_cols,
|
size_t row_stride, size_t n_rows, size_t n_cols,
|
||||||
common::HistogramCuts const& cuts);
|
common::HistogramCuts const &cuts);
|
||||||
|
|
||||||
/*! \brief Copy the elements of the given ELLPACK page into this page.
|
/*! \brief Copy the elements of the given ELLPACK page into this page.
|
||||||
*
|
*
|
||||||
@ -176,7 +194,9 @@ class EllpackPageImpl {
|
|||||||
* not found). */
|
* not found). */
|
||||||
size_t NumSymbols() const { return cuts_.TotalBins() + 1; }
|
size_t NumSymbols() const { return cuts_.TotalBins() + 1; }
|
||||||
|
|
||||||
EllpackDeviceAccessor GetDeviceAccessor(int device) const;
|
EllpackDeviceAccessor
|
||||||
|
GetDeviceAccessor(int device,
|
||||||
|
common::Span<FeatureType const> feature_types = {}) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -148,9 +148,13 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
|||||||
return GetRowCounts(value, row_counts_span, get_device(), missing);
|
return GetRowCounts(value, row_counts_span, get_device(), missing);
|
||||||
});
|
});
|
||||||
auto is_dense = this->IsDense();
|
auto is_dense = this->IsDense();
|
||||||
|
|
||||||
|
proxy->Info().feature_types.SetDevice(get_device());
|
||||||
|
auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
|
||||||
auto new_impl = Dispatch(proxy, [&](auto const &value) {
|
auto new_impl = Dispatch(proxy, [&](auto const &value) {
|
||||||
return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
|
return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
|
||||||
row_counts_span, row_stride, rows, cols, cuts);
|
row_counts_span, d_feature_types, row_stride, rows,
|
||||||
|
cols, cuts);
|
||||||
});
|
});
|
||||||
size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
|
size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
|
||||||
offset += num_elements;
|
offset += num_elements;
|
||||||
|
|||||||
@ -155,6 +155,9 @@ struct EllpackLoader {
|
|||||||
if (gidx == -1) {
|
if (gidx == -1) {
|
||||||
return nan("");
|
return nan("");
|
||||||
}
|
}
|
||||||
|
if (common::IsCat(matrix.feature_types, fidx)) {
|
||||||
|
return matrix.gidx_fvalue_map[gidx];
|
||||||
|
}
|
||||||
// The gradient index needs to be shifted by one as min values are not included in the
|
// The gradient index needs to be shifted by one as min values are not included in the
|
||||||
// cuts.
|
// cuts.
|
||||||
if (gidx == matrix.feature_segments[fidx]) {
|
if (gidx == matrix.feature_segments[fidx]) {
|
||||||
@ -592,8 +595,10 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
} else {
|
} else {
|
||||||
size_t batch_offset = 0;
|
size_t batch_offset = 0;
|
||||||
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
|
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
|
||||||
|
dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
|
||||||
|
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
|
||||||
this->PredictInternal(
|
this->PredictInternal(
|
||||||
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
|
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
|
||||||
d_model,
|
d_model,
|
||||||
out_preds,
|
out_preds,
|
||||||
batch_offset);
|
batch_offset);
|
||||||
|
|||||||
@ -19,7 +19,7 @@ ENV PATH=/opt/python/bin:$PATH
|
|||||||
# Create new Conda environment with cuDF, Dask, and cuPy
|
# Create new Conda environment with cuDF, Dask, and cuPy
|
||||||
RUN \
|
RUN \
|
||||||
conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
|
conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
|
||||||
python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy \
|
python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy=9.1* \
|
||||||
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
|
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
|
||||||
|
|
||||||
ENV GOSU_VERSION 1.10
|
ENV GOSU_VERSION 1.10
|
||||||
|
|||||||
@ -68,7 +68,16 @@ void TestEquivalent(float sparsity) {
|
|||||||
auto const& buffer_from_iter = page_concatenated->gidx_buffer;
|
auto const& buffer_from_iter = page_concatenated->gidx_buffer;
|
||||||
auto const& buffer_from_data = ellpack.Impl()->gidx_buffer;
|
auto const& buffer_from_data = ellpack.Impl()->gidx_buffer;
|
||||||
ASSERT_NE(buffer_from_data.Size(), 0);
|
ASSERT_NE(buffer_from_data.Size(), 0);
|
||||||
ASSERT_EQ(buffer_from_data.ConstHostVector(), buffer_from_data.ConstHostVector());
|
|
||||||
|
common::CompressedIterator<uint32_t> data_buf{
|
||||||
|
buffer_from_data.ConstHostPointer(), from_data.NumSymbols()};
|
||||||
|
common::CompressedIterator<uint32_t> data_iter{
|
||||||
|
buffer_from_iter.ConstHostPointer(), from_iter.NumSymbols()};
|
||||||
|
CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
|
||||||
|
CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride);
|
||||||
|
for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
|
||||||
|
CHECK_EQ(data_buf[i], data_iter[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -225,6 +225,9 @@ void TestCategoricalPrediction(std::string name) {
|
|||||||
row[split_ind] = split_cat;
|
row[split_ind] = split_cat;
|
||||||
auto m = GetDMatrixFromData(row, 1, kCols);
|
auto m = GetDMatrixFromData(row, 1, kCols);
|
||||||
|
|
||||||
|
std::vector<FeatureType> types(10, FeatureType::kCategorical);
|
||||||
|
m->Info().feature_types.HostVector() = types;
|
||||||
|
|
||||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||||
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
||||||
|
|||||||
@ -225,19 +225,28 @@ class IterForDMatrixTest(xgb.core.DataIter):
|
|||||||
ROWS_PER_BATCH = 100 # data is splited by rows
|
ROWS_PER_BATCH = 100 # data is splited by rows
|
||||||
BATCHES = 16
|
BATCHES = 16
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, categorical):
|
||||||
'''Generate some random data for demostration.
|
'''Generate some random data for demostration.
|
||||||
|
|
||||||
Actual data can be anything that is currently supported by XGBoost.
|
Actual data can be anything that is currently supported by XGBoost.
|
||||||
'''
|
'''
|
||||||
import cudf
|
import cudf
|
||||||
self.rows = self.ROWS_PER_BATCH
|
self.rows = self.ROWS_PER_BATCH
|
||||||
rng = np.random.RandomState(1994)
|
|
||||||
self._data = [
|
if categorical:
|
||||||
cudf.DataFrame(
|
self._data = []
|
||||||
{'a': rng.randn(self.ROWS_PER_BATCH),
|
self._labels = []
|
||||||
'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
|
for i in range(self.BATCHES):
|
||||||
self._labels = [rng.randn(self.rows)] * self.BATCHES
|
X, y = tm.make_categorical(self.ROWS_PER_BATCH, 4, 13, False)
|
||||||
|
self._data.append(cudf.from_pandas(X))
|
||||||
|
self._labels.append(y)
|
||||||
|
else:
|
||||||
|
rng = np.random.RandomState(1994)
|
||||||
|
self._data = [
|
||||||
|
cudf.DataFrame(
|
||||||
|
{'a': rng.randn(self.ROWS_PER_BATCH),
|
||||||
|
'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
|
||||||
|
self._labels = [rng.randn(self.rows)] * self.BATCHES
|
||||||
|
|
||||||
self.it = 0 # set iterator to 0
|
self.it = 0 # set iterator to 0
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -272,24 +281,26 @@ class IterForDMatrixTest(xgb.core.DataIter):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_from_cudf_iter():
|
@pytest.mark.parametrize("enable_categorical", [True, False])
|
||||||
|
def test_from_cudf_iter(enable_categorical):
|
||||||
rounds = 100
|
rounds = 100
|
||||||
it = IterForDMatrixTest()
|
it = IterForDMatrixTest(enable_categorical)
|
||||||
|
params = {"tree_method": "gpu_hist"}
|
||||||
|
|
||||||
# Use iterator
|
# Use iterator
|
||||||
m_it = xgb.DeviceQuantileDMatrix(it)
|
m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=enable_categorical)
|
||||||
reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it,
|
reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)
|
||||||
num_boost_round=rounds)
|
|
||||||
predict_with_it = reg_with_it.predict(m_it)
|
|
||||||
|
|
||||||
# Without using iterator
|
X = it.as_array()
|
||||||
m = xgb.DMatrix(it.as_array(), it.as_array_labels())
|
y = it.as_array_labels()
|
||||||
|
|
||||||
|
m = xgb.DMatrix(X, y, enable_categorical=enable_categorical)
|
||||||
|
|
||||||
assert m_it.num_col() == m.num_col()
|
assert m_it.num_col() == m.num_col()
|
||||||
assert m_it.num_row() == m.num_row()
|
assert m_it.num_row() == m.num_row()
|
||||||
|
|
||||||
reg = xgb.train({'tree_method': 'gpu_hist'}, m,
|
reg = xgb.train(params, m, num_boost_round=rounds)
|
||||||
num_boost_round=rounds)
|
|
||||||
predict = reg.predict(m)
|
|
||||||
|
|
||||||
|
predict = reg.predict(m)
|
||||||
|
predict_with_it = reg_with_it.predict(m_it)
|
||||||
np.testing.assert_allclose(predict_with_it, predict)
|
np.testing.assert_allclose(predict_with_it, predict)
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
from typing import Type, TypeVar, Any, Dict, List
|
from typing import Type, TypeVar, Any, Dict, List, Tuple
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import asyncio
|
import asyncio
|
||||||
import xgboost
|
import xgboost
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import json
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from hypothesis import given, strategies, settings, note
|
from hypothesis import given, strategies, settings, note
|
||||||
@ -41,6 +43,49 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def make_categorical(
|
||||||
|
client: Client,
|
||||||
|
n_samples: int,
|
||||||
|
n_features: int,
|
||||||
|
n_categories: int,
|
||||||
|
onehot: bool = False,
|
||||||
|
) -> Tuple[dd.DataFrame, dd.Series]:
|
||||||
|
workers = _get_client_workers(client)
|
||||||
|
n_workers = len(workers)
|
||||||
|
dfs = []
|
||||||
|
|
||||||
|
def pack(**kwargs: Any) -> dd.DataFrame:
|
||||||
|
X, y = tm.make_categorical(**kwargs)
|
||||||
|
X["label"] = y
|
||||||
|
return X
|
||||||
|
|
||||||
|
meta = pack(
|
||||||
|
n_samples=1, n_features=n_features, n_categories=n_categories, onehot=False
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, worker in enumerate(workers):
|
||||||
|
l_n_samples = min(
|
||||||
|
n_samples // n_workers, n_samples - i * (n_samples // n_workers)
|
||||||
|
)
|
||||||
|
future = client.submit(
|
||||||
|
pack,
|
||||||
|
n_samples=l_n_samples,
|
||||||
|
n_features=n_features,
|
||||||
|
n_categories=n_categories,
|
||||||
|
onehot=False,
|
||||||
|
workers=[worker],
|
||||||
|
)
|
||||||
|
dfs.append(future)
|
||||||
|
|
||||||
|
df = dd.from_delayed(dfs, meta=meta)
|
||||||
|
y = df["label"]
|
||||||
|
X = df[df.columns.difference(["label"])]
|
||||||
|
|
||||||
|
if onehot:
|
||||||
|
return dd.get_dummies(X), y
|
||||||
|
return X, y
|
||||||
|
|
||||||
|
|
||||||
def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
|
def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
cp.cuda.runtime.setDevice(0)
|
cp.cuda.runtime.setDevice(0)
|
||||||
@ -126,6 +171,62 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
|
|||||||
inplace_predictions)
|
inplace_predictions)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_dask_cudf())
|
||||||
|
def test_categorical(local_cuda_cluster: LocalCUDACluster) -> None:
|
||||||
|
with Client(local_cuda_cluster) as client:
|
||||||
|
import dask_cudf
|
||||||
|
|
||||||
|
rounds = 10
|
||||||
|
X, y = make_categorical(client, 10000, 30, 13)
|
||||||
|
X = dask_cudf.from_dask_dataframe(X)
|
||||||
|
|
||||||
|
X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
|
||||||
|
X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
|
||||||
|
|
||||||
|
parameters = {"tree_method": "gpu_hist"}
|
||||||
|
|
||||||
|
m = dxgb.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
|
||||||
|
by_etl_results = dxgb.train(
|
||||||
|
client,
|
||||||
|
parameters,
|
||||||
|
m,
|
||||||
|
num_boost_round=rounds,
|
||||||
|
evals=[(m, "Train")],
|
||||||
|
)["history"]
|
||||||
|
|
||||||
|
m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
|
||||||
|
output = dxgb.train(
|
||||||
|
client,
|
||||||
|
parameters,
|
||||||
|
m,
|
||||||
|
num_boost_round=rounds,
|
||||||
|
evals=[(m, "Train")],
|
||||||
|
)
|
||||||
|
by_builtin_results = output["history"]
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(by_etl_results["Train"]["rmse"]),
|
||||||
|
np.array(by_builtin_results["Train"]["rmse"]),
|
||||||
|
rtol=1e-3,
|
||||||
|
)
|
||||||
|
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
|
||||||
|
|
||||||
|
model = output["booster"]
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
path = os.path.join(tempdir, "model.json")
|
||||||
|
model.save_model(path)
|
||||||
|
with open(path, "r") as fd:
|
||||||
|
categorical = json.load(fd)
|
||||||
|
|
||||||
|
categories_sizes = np.array(
|
||||||
|
categorical["learner"]["gradient_booster"]["model"]["trees"][-1][
|
||||||
|
"categories_sizes"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert categories_sizes.shape[0] != 0
|
||||||
|
np.testing.assert_allclose(categories_sizes, 1)
|
||||||
|
|
||||||
|
|
||||||
def to_cp(x: Any, DMatrixT: Type) -> Any:
|
def to_cp(x: Any, DMatrixT: Type) -> Any:
|
||||||
import cupy
|
import cupy
|
||||||
if isinstance(x, np.ndarray) and \
|
if isinstance(x, np.ndarray) and \
|
||||||
|
|||||||
@ -236,7 +236,7 @@ def get_mq2008(dpath):
|
|||||||
|
|
||||||
@memory.cache
|
@memory.cache
|
||||||
def make_categorical(
|
def make_categorical(
|
||||||
n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
|
n_samples: int, n_features: int, n_categories: int, onehot: bool
|
||||||
):
|
):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ def make_categorical(
|
|||||||
|
|
||||||
pd_dict = {}
|
pd_dict = {}
|
||||||
for i in range(n_features + 1):
|
for i in range(n_features + 1):
|
||||||
c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
|
c = rng.randint(low=0, high=n_categories, size=n_samples)
|
||||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||||
|
|
||||||
df = pd.DataFrame(pd_dict)
|
df = pd.DataFrame(pd_dict)
|
||||||
@ -255,11 +255,13 @@ def make_categorical(
|
|||||||
label += 1
|
label += 1
|
||||||
|
|
||||||
df = df.astype("category")
|
df = df.astype("category")
|
||||||
if onehot_enc:
|
categories = np.arange(0, n_categories)
|
||||||
cat = pd.get_dummies(df)
|
for col in df.columns:
|
||||||
else:
|
df[col] = df[col].cat.set_categories(categories)
|
||||||
cat = df
|
|
||||||
return cat, label
|
if onehot:
|
||||||
|
return pd.get_dummies(df), label
|
||||||
|
return df, label
|
||||||
|
|
||||||
|
|
||||||
_unweighted_datasets_strategy = strategies.sampled_from(
|
_unweighted_datasets_strategy = strategies.sampled_from(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user