Support categorical data in GPU sketching. (#6137)
This commit is contained in:
@@ -79,7 +79,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
} else {
|
||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||
}
|
||||
sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device());
|
||||
sketch_containers.emplace_back(proxy->Info().feature_types,
|
||||
batch_param_.max_bin, cols, num_rows(), get_device());
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
proxy->Info().weights_.SetDevice(get_device());
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
@@ -101,7 +102,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
}
|
||||
iter.Reset();
|
||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device());
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
common::SketchContainer final_sketch(
|
||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(),
|
||||
batch_param_.max_bin, cols, accumulated_rows, get_device());
|
||||
for (auto const& sketch : sketch_containers) {
|
||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||
final_sketch.FixError();
|
||||
|
||||
Reference in New Issue
Block a user