Support categorical data in GPU sketching. (#6137)

This commit is contained in:
Jiaming Yuan
2020-09-21 13:53:06 +08:00
committed by GitHub
parent c932fb50a1
commit 210c131ce7
6 changed files with 196 additions and 62 deletions

View File

@@ -79,7 +79,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
}
sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device());
sketch_containers.emplace_back(proxy->Info().feature_types,
batch_param_.max_bin, cols, num_rows(), get_device());
auto* p_sketch = &sketch_containers.back();
proxy->Info().weights_.SetDevice(get_device());
Dispatch(proxy, [&](auto const &value) {
@@ -101,7 +102,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
}
iter.Reset();
dh::safe_cuda(cudaSetDevice(get_device()));
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device());
HostDeviceVector<FeatureType> ft;
common::SketchContainer final_sketch(
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(),
batch_param_.max_bin, cols, accumulated_rows, get_device());
for (auto const& sketch : sketch_containers) {
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
final_sketch.FixError();