Avoid regenerating the gradient index for approx. (#7591)
This commit is contained in:
@@ -17,8 +17,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
const size_t batch_threads =
|
||||
std::max(size_t(1), std::min(batch.Size(),
|
||||
static_cast<size_t>(n_threads)));
|
||||
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
|
||||
auto page = batch.GetView();
|
||||
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
|
||||
size_t *p_part = partial_sums.Get();
|
||||
|
||||
@@ -108,5 +108,16 @@ class GHistIndexMatrix {
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
bool isDense_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Should we regenerate the gradient index?
|
||||
*
|
||||
* \param old Parameter stored in DMatrix.
|
||||
* \param p New parameter passed in by caller.
|
||||
*/
|
||||
inline bool RegenGHist(BatchParam old, BatchParam p) {
|
||||
// parameter is renewed or caller requests a regen
|
||||
return p.regen || (old.gpu_id != p.gpu_id || old.max_bin != p.max_bin);
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
@@ -94,7 +94,8 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
|
||||
if (!gradient_index_ || RegenGHist(batch_param_, param)) {
|
||||
LOG(INFO) << "Generating new Gradient Index.";
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
CHECK_EQ(param.gpu_id, -1);
|
||||
// Used only by approx.
|
||||
|
||||
@@ -157,7 +157,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(begin_iter));
|
||||
}
|
||||
|
||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
|
||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam ¶m) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
if (param.hess.empty() && !param.regen) {
|
||||
// hist method doesn't support full external memory implementation, so we concatenate
|
||||
@@ -176,10 +176,10 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
|
||||
|
||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
this->InitializeSparsePage();
|
||||
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{}) ||
|
||||
param.regen) {
|
||||
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
LOG(INFO) << "Generating new Gradient Index.";
|
||||
// Use sorted sketch for approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
auto cuts =
|
||||
|
||||
@@ -14,7 +14,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
|
||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
size_t row_stride = 0;
|
||||
this->InitializeSparsePage();
|
||||
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
|
||||
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
|
||||
// reinitialize the cache
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
|
||||
Reference in New Issue
Block a user