diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index b5a2637af..1c3fe38cc 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -1076,6 +1076,8 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch( } } + // rabit::IsDistributed is not thread-safe + auto isDistributed = rabit::IsDistributed(); // partial results std::vector> splits(tasks.size()); // parallel enumeration @@ -1088,16 +1090,16 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch( const int32_t sibling_nid = nodes[node_idx].sibling_nid; const int32_t parent_nid = nodes[node_idx].parent_nid; - common::GradStatHist::GradType* hist_data = - reinterpret_cast(hist_[nid].data()); - common::GradStatHist::GradType* sibling_hist_data = sibling_nid > -1 ? - reinterpret_cast( - hist_[sibling_nid].data()) : nullptr; - common::GradStatHist::GradType* parent_hist_data = sibling_nid > -1 ? - reinterpret_cast(hist_[parent_nid].data()) : nullptr; - // reduce needed part of a hist here to have it in cache before enumeration - if (!rabit::IsDistributed()) { + if (!isDistributed) { + auto hist_data = reinterpret_cast(hist_[nid].data()); + auto sibling_hist_data = sibling_nid > -1 ? + reinterpret_cast( + hist_[sibling_nid].data()) : nullptr; + auto parent_hist_data = sibling_nid > -1 ? + reinterpret_cast( + hist_[parent_nid].data()) : nullptr; + const std::vector& cut_ptr = gmat.cut.Ptrs(); const size_t ibegin = 2 * cut_ptr[fid]; const size_t iend = 2 * cut_ptr[fid + 1];