Reduce time for some multi-gpu tests (#8288)
* Faster dask tests.
* Reuse AllReducer objects in tests.
* Faster boost from prediction tests.
* Use rmm dask fixture.
* Speed up dask demo.
* mypy
* Format with black.
* mypy
* Clang-tidy

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -508,7 +508,7 @@ void SketchContainer::AllReduce() {
|
||||
|
||||
timer_.Start(__func__);
|
||||
if (!reducer_) {
|
||||
reducer_ = std::make_unique<dh::AllReducer>();
|
||||
reducer_ = std::make_shared<dh::AllReducer>();
|
||||
reducer_->Init(device_);
|
||||
}
|
||||
// Reduce the overhead on syncing.
|
||||
@@ -518,6 +518,7 @@ void SketchContainer::AllReduce() {
|
||||
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
|
||||
this->Prune(intermediate_num_cuts);
|
||||
|
||||
|
||||
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
|
||||
size_t n = d_columns_ptr.size();
|
||||
|
||||
@@ -37,7 +37,7 @@ class SketchContainer {
|
||||
|
||||
private:
|
||||
Monitor timer_;
|
||||
std::unique_ptr<dh::AllReducer> reducer_;
|
||||
std::shared_ptr<dh::AllReducer> reducer_;
|
||||
HostDeviceVector<FeatureType> feature_types_;
|
||||
bst_row_t num_rows_;
|
||||
bst_feature_t num_columns_;
|
||||
@@ -93,35 +93,37 @@ class SketchContainer {
|
||||
* \param num_columns Total number of columns in dataset.
|
||||
* \param num_rows Total number of rows in the known dataset (typically the rows on the current worker).
|
||||
* \param device GPU ID.
|
||||
* \param reducer Optional, already initialised reducer to reuse. When null, AllReduce() creates and initialises one on first use. Useful for speeding up testing.
|
||||
*/
|
||||
SketchContainer(HostDeviceVector<FeatureType> const& feature_types,
|
||||
int32_t max_bin,
|
||||
bst_feature_t num_columns, bst_row_t num_rows,
|
||||
int32_t device)
|
||||
: num_rows_{num_rows},
|
||||
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
CHECK_GE(device, 0);
|
||||
// Initialize Sketches for this dmatrix
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(num_columns + 1);
|
||||
this->columns_ptr_b_.SetDevice(device_);
|
||||
this->columns_ptr_b_.Resize(num_columns + 1);
|
||||
SketchContainer(HostDeviceVector<FeatureType> const &feature_types,
|
||||
int32_t max_bin, bst_feature_t num_columns,
|
||||
bst_row_t num_rows, int32_t device,
|
||||
std::shared_ptr<dh::AllReducer> reducer = nullptr)
|
||||
: num_rows_{num_rows},
|
||||
num_columns_{num_columns}, num_bins_{max_bin}, device_{device},
|
||||
reducer_(std::move(reducer)) {
|
||||
CHECK_GE(device, 0);
|
||||
// Initialize Sketches for this dmatrix
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
this->columns_ptr_.Resize(num_columns + 1);
|
||||
this->columns_ptr_b_.SetDevice(device_);
|
||||
this->columns_ptr_b_.Resize(num_columns + 1);
|
||||
|
||||
this->feature_types_.Resize(feature_types.Size());
|
||||
this->feature_types_.Copy(feature_types);
|
||||
// Pull to device.
|
||||
this->feature_types_.SetDevice(device);
|
||||
this->feature_types_.ConstDeviceSpan();
|
||||
this->feature_types_.ConstHostSpan();
|
||||
this->feature_types_.Resize(feature_types.Size());
|
||||
this->feature_types_.Copy(feature_types);
|
||||
// Pull to device.
|
||||
this->feature_types_.SetDevice(device);
|
||||
this->feature_types_.ConstDeviceSpan();
|
||||
this->feature_types_.ConstHostSpan();
|
||||
|
||||
auto d_feature_types = feature_types_.ConstDeviceSpan();
|
||||
has_categorical_ =
|
||||
!d_feature_types.empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||
common::IsCatOp{});
|
||||
auto d_feature_types = feature_types_.ConstDeviceSpan();
|
||||
has_categorical_ =
|
||||
!d_feature_types.empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||
common::IsCatOp{});
|
||||
|
||||
timer_.Init(__func__);
|
||||
}
|
||||
timer_.Init(__func__);
|
||||
}
|
||||
/* \brief Return GPU ID for this container. */
|
||||
int32_t DeviceIdx() const { return device_; }
|
||||
/* \brief Whether the predictor matrix contains categorical features. */
|
||||
|
||||
Reference in New Issue
Block a user