enable ROCm on latest XGBoost
This commit is contained in:
@@ -41,7 +41,7 @@ class SketchContainer {
|
||||
bst_row_t num_rows_;
|
||||
bst_feature_t num_columns_;
|
||||
int32_t num_bins_;
|
||||
int32_t device_;
|
||||
DeviceOrd device_;
|
||||
|
||||
// Double buffer as neither prune nor merge can be performed inplace.
|
||||
dh::device_vector<SketchEntry> entries_a_;
|
||||
@@ -93,35 +93,32 @@ class SketchContainer {
|
||||
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
|
||||
* \param device GPU ID.
|
||||
*/
|
||||
SketchContainer(HostDeviceVector<FeatureType> const &feature_types,
|
||||
int32_t max_bin, bst_feature_t num_columns,
|
||||
bst_row_t num_rows, int32_t device)
|
||||
: num_rows_{num_rows},
|
||||
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
CHECK_GE(device, 0);
|
||||
// Initialize Sketches for this dmatrix
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(num_columns + 1);
|
||||
this->columns_ptr_b_.SetDevice(device_);
|
||||
this->columns_ptr_b_.Resize(num_columns + 1);
|
||||
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, int32_t max_bin,
|
||||
bst_feature_t num_columns, bst_row_t num_rows, DeviceOrd device)
|
||||
: num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
CHECK(device.IsCUDA());
|
||||
// Initialize Sketches for this dmatrix
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(num_columns + 1);
|
||||
this->columns_ptr_b_.SetDevice(device_);
|
||||
this->columns_ptr_b_.Resize(num_columns + 1);
|
||||
|
||||
this->feature_types_.Resize(feature_types.Size());
|
||||
this->feature_types_.Copy(feature_types);
|
||||
// Pull to device.
|
||||
this->feature_types_.SetDevice(device);
|
||||
this->feature_types_.ConstDeviceSpan();
|
||||
this->feature_types_.ConstHostSpan();
|
||||
this->feature_types_.Resize(feature_types.Size());
|
||||
this->feature_types_.Copy(feature_types);
|
||||
// Pull to device.
|
||||
this->feature_types_.SetDevice(device);
|
||||
this->feature_types_.ConstDeviceSpan();
|
||||
this->feature_types_.ConstHostSpan();
|
||||
|
||||
auto d_feature_types = feature_types_.ConstDeviceSpan();
|
||||
has_categorical_ =
|
||||
!d_feature_types.empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||
common::IsCatOp{});
|
||||
auto d_feature_types = feature_types_.ConstDeviceSpan();
|
||||
has_categorical_ =
|
||||
!d_feature_types.empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{});
|
||||
|
||||
timer_.Init(__func__);
|
||||
}
|
||||
timer_.Init(__func__);
|
||||
}
|
||||
/* \brief Return GPU ID for this container. */
|
||||
int32_t DeviceIdx() const { return device_; }
|
||||
[[nodiscard]] DeviceOrd DeviceIdx() const { return device_; }
|
||||
/* \brief Whether the predictor matrix contains categorical features. */
|
||||
bool HasCategorical() const { return has_categorical_; }
|
||||
/* \brief Accumulate weights of duplicated entries in input. */
|
||||
@@ -175,9 +172,7 @@ class SketchContainer {
|
||||
template <typename KeyComp = thrust::equal_to<size_t>>
|
||||
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
||||
timer_.Start(__func__);
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
||||
@@ -195,7 +190,7 @@ class SketchContainer {
|
||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||
entries.data(), detail::SketchUnique{}, key_comp);
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
size_t n_uniques = dh::SegmentedUnique(
|
||||
thrust::cuda::par(alloc), d_column_scan.data(),
|
||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||
|
||||
Reference in New Issue
Block a user