Support categorical data in GPU sketching. (#6137)
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "device_helpers.cuh"
|
||||
#include "quantile.h"
|
||||
#include "timer.h"
|
||||
@@ -28,6 +29,7 @@ class SketchContainer {
|
||||
private:
|
||||
Monitor timer_;
|
||||
std::unique_ptr<dh::AllReducer> reducer_;
|
||||
HostDeviceVector<FeatureType> feature_types_;
|
||||
bst_row_t num_rows_;
|
||||
bst_feature_t num_columns_;
|
||||
int32_t num_bins_;
|
||||
@@ -39,6 +41,7 @@ class SketchContainer {
|
||||
bool current_buffer_ {true};
|
||||
// The container is just a CSC matrix.
|
||||
HostDeviceVector<OffsetT> columns_ptr_;
|
||||
HostDeviceVector<OffsetT> columns_ptr_b_;
|
||||
|
||||
dh::caching_device_vector<SketchEntry>& Current() {
|
||||
if (current_buffer_) {
|
||||
@@ -80,12 +83,25 @@ class SketchContainer {
|
||||
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
|
||||
* \param device GPU ID.
|
||||
*/
|
||||
SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) :
|
||||
num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
SketchContainer(HostDeviceVector<FeatureType> const& feature_types,
|
||||
int32_t max_bin,
|
||||
bst_feature_t num_columns, bst_row_t num_rows,
|
||||
int32_t device)
|
||||
: num_rows_{num_rows},
|
||||
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
CHECK_GE(device, 0);
|
||||
// Initialize Sketches for this dmatrix
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(num_columns + 1);
|
||||
CHECK_GE(device, 0);
|
||||
this->columns_ptr_b_.SetDevice(device_);
|
||||
this->columns_ptr_b_.Resize(num_columns + 1);
|
||||
|
||||
this->feature_types_.Resize(feature_types.Size());
|
||||
this->feature_types_.Copy(feature_types);
|
||||
// Pull to device.
|
||||
this->feature_types_.SetDevice(device);
|
||||
this->feature_types_.ConstDeviceSpan();
|
||||
this->feature_types_.ConstHostSpan();
|
||||
timer_.Init(__func__);
|
||||
}
|
||||
/* \brief Return GPU ID for this container. */
|
||||
@@ -127,6 +143,7 @@ class SketchContainer {
|
||||
Span<SketchEntry const> Data() const {
|
||||
return {this->Current().data().get(), this->Current().size()};
|
||||
}
|
||||
HostDeviceVector<FeatureType> const& FeatureTypes() const { return feature_types_; }
|
||||
|
||||
Span<OffsetT const> ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user