Support categorical data in GPU sketching. (#6137)

This commit is contained in:
Jiaming Yuan
2020-09-21 13:53:06 +08:00
committed by GitHub
parent c932fb50a1
commit 210c131ce7
6 changed files with 196 additions and 62 deletions

View File

@@ -4,6 +4,7 @@
#include <memory>
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "device_helpers.cuh"
#include "quantile.h"
#include "timer.h"
@@ -28,6 +29,7 @@ class SketchContainer {
private:
Monitor timer_;
std::unique_ptr<dh::AllReducer> reducer_;
HostDeviceVector<FeatureType> feature_types_;
bst_row_t num_rows_;
bst_feature_t num_columns_;
int32_t num_bins_;
@@ -39,6 +41,7 @@ class SketchContainer {
bool current_buffer_ {true};
// The container is just a CSC matrix.
HostDeviceVector<OffsetT> columns_ptr_;
HostDeviceVector<OffsetT> columns_ptr_b_;
dh::caching_device_vector<SketchEntry>& Current() {
if (current_buffer_) {
@@ -80,12 +83,25 @@ class SketchContainer {
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
* \param device GPU ID.
*/
SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) :
num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
SketchContainer(HostDeviceVector<FeatureType> const& feature_types,
int32_t max_bin,
bst_feature_t num_columns, bst_row_t num_rows,
int32_t device)
: num_rows_{num_rows},
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
CHECK_GE(device, 0);
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1);
CHECK_GE(device, 0);
this->columns_ptr_b_.SetDevice(device_);
this->columns_ptr_b_.Resize(num_columns + 1);
this->feature_types_.Resize(feature_types.Size());
this->feature_types_.Copy(feature_types);
// Pull to device.
this->feature_types_.SetDevice(device);
this->feature_types_.ConstDeviceSpan();
this->feature_types_.ConstHostSpan();
timer_.Init(__func__);
}
/* \brief Return GPU ID for this container. */
@@ -127,6 +143,7 @@ class SketchContainer {
Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};
}
HostDeviceVector<FeatureType> const& FeatureTypes() const { return feature_types_; }
Span<OffsetT const> ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); }