/**
 * Copyright 2019-2024, XGBoost Contributors
 */
#pragma once
#include <cstddef>  // for size_t
#include <memory>   // for unique_ptr

#include "../../common/device_vector.cuh"  // for device_vector, caching_device_vector
#include "../../common/timer.h"            // for Monitor
#include "xgboost/base.h"                  // for GradientPair, bst_idx_t
#include "xgboost/data.h"                  // for BatchParam, DMatrix
#include "xgboost/span.h"                  // for Span

namespace xgboost::tree {
/**
 * @brief Output of a sampling strategy: a matrix pointer plus the gradients for its rows.
 */
struct GradientBasedSample {
  /**
   * @brief Sampled rows in ELLPACK format.
   *
   * Raw, non-owning pointer. The external-memory strategies keep their own `p_fmat_new_`, so
   * this may point into a strategy-owned matrix; presumably the sample must not outlive the
   * strategy that produced it (confirm in the .cu).
   */
  DMatrix* p_fmat{nullptr};
  /** @brief Gradient pairs for the sampled rows. */
  common::Span<GradientPair> gpair;
};

/*! \brief Interface implemented by every row-sampling strategy declared in this header. */
class SamplingStrategy {
 public:
  /*!
   * \brief Sample from a DMatrix based on the given gradient pairs.
   *
   * \param ctx   Execution context.
   * \param gpair Gradient pairs for the rows of \p dmat (assumed one per row; confirm).
   * \param dmat  Matrix to sample rows from.
   * \return The sampled matrix together with the gradient pairs for its rows.
   */
  virtual GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                                     DMatrix* dmat) = 0;
  virtual ~SamplingStrategy() = default;
};

/*! \brief No sampling in in-memory mode. */
class NoSampling : public SamplingStrategy {
 public:
  explicit NoSampling(BatchParam batch_param);
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                             DMatrix* dmat) override;

 private:
  /*! \brief Batch parameters captured at construction. */
  BatchParam batch_param_;
};

/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
 public:
  explicit ExternalMemoryNoSampling(BatchParam batch_param);
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                             DMatrix* dmat) override;

 private:
  /*! \brief Batch parameters captured at construction. */
  BatchParam batch_param_;
};

/*! \brief Uniform sampling in in-memory mode. */
class UniformSampling : public SamplingStrategy {
 public:
  /*!
   * \param batch_param Batch parameters used when accessing the matrix.
   * \param subsample   Sampling ratio (presumably the fraction of rows kept, in (0, 1]; confirm).
   */
  UniformSampling(BatchParam batch_param, float subsample);
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                             DMatrix* dmat) override;

 private:
  BatchParam batch_param_;
  float subsample_;
};

/*! \brief Uniform sampling in external memory mode. */
class ExternalMemoryUniformSampling : public SamplingStrategy {
 public:
  /*!
   * \param n_rows      Row count, presumably used to size the per-row buffers below (confirm).
   * \param batch_param Batch parameters used when accessing the matrix.
   * \param subsample   Sampling ratio (presumably the fraction of rows kept; confirm).
   */
  ExternalMemoryUniformSampling(std::size_t n_rows, BatchParam batch_param, float subsample);
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                             DMatrix* dmat) override;

 private:
  BatchParam batch_param_;
  float subsample_;
  /*! \brief Matrix owned by this strategy (presumably the compacted sampled rows; confirm). */
  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
  /*! \brief Gradient storage owned by this strategy (presumably for the sampled rows). */
  dh::device_vector<GradientPair> gpair_{};
  /*! \brief Per-row index buffer for the sampling step. */
  dh::caching_device_vector<bst_idx_t> sample_row_index_;
  /*! \brief Row-index buffer, presumably for the compacted matrix (confirm). */
  dh::device_vector<bst_idx_t> compact_row_index_;
};

/*! \brief Gradient-based sampling in in-memory mode. */
class GradientBasedSampling : public SamplingStrategy {
 public:
  /*!
   * \param n_rows      Row count, presumably used to size the threshold buffers (confirm).
   * \param batch_param Batch parameters used when accessing the matrix.
   * \param subsample   Sampling ratio (presumably the fraction of rows kept; confirm).
   */
  GradientBasedSampling(std::size_t n_rows, BatchParam batch_param, float subsample);
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                             DMatrix* dmat) override;

 private:
  BatchParam batch_param_;
  float subsample_;
  /*! \brief Scratch buffer passed as `threshold` to GradientBasedSampler::CalculateThresholdIndex. */
  dh::caching_device_vector<float> threshold_;
  /*! \brief Scratch buffer passed as `grad_sum` to GradientBasedSampler::CalculateThresholdIndex. */
  dh::caching_device_vector<float> grad_sum_;
};

/*! \brief Gradient-based sampling in external memory mode. */
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
 public:
  /*!
   * \param n_rows      Row count, presumably used to size the per-row buffers below (confirm).
   * \param batch_param Batch parameters used when accessing the matrix.
   * \param subsample   Sampling ratio (presumably the fraction of rows kept; confirm).
   */
  ExternalMemoryGradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
                                      float subsample);
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                             DMatrix* dmat) override;

 private:
  BatchParam batch_param_;
  float subsample_;
  /*! \brief Scratch buffer passed as `threshold` to GradientBasedSampler::CalculateThresholdIndex. */
  dh::device_vector<float> threshold_;
  /*! \brief Scratch buffer passed as `grad_sum` to GradientBasedSampler::CalculateThresholdIndex. */
  dh::device_vector<float> grad_sum_;
  /*! \brief Matrix owned by this strategy (presumably the compacted sampled rows; confirm). */
  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
  /*! \brief Gradient storage owned by this strategy (presumably for the sampled rows). */
  dh::device_vector<GradientPair> gpair_;
  /*! \brief Per-row index buffer for the sampling step. */
  dh::device_vector<bst_idx_t> sample_row_index_;
  /*! \brief Row-index buffer, presumably for the compacted matrix (confirm). */
  dh::device_vector<bst_idx_t> compact_row_index_;
};

/*! \brief Draw a sample of rows from a DMatrix.
 *
 * Owns one SamplingStrategy (see `strategy_`), which presumably is selected at construction
 * from `sampling_method` and `is_external_memory`; confirm in the .cu.
 *
 * \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017).
 * Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information
 * Processing Systems (pp. 3146-3154).
 * \see Zhu, R. (2016). Gradient-based sampling: An adaptive importance sampling for least-squares.
 * In Advances in Neural Information Processing Systems (pp. 406-414).
 * \see Ohlsson, E. (1998). Sequential Poisson sampling. Journal of official Statistics, 14(2), 149.
 */
class GradientBasedSampler {
 public:
  /*!
   * \param ctx                Execution context.
   * \param n_rows             Row count (presumably of the training matrix; confirm).
   * \param batch_param        Batch parameters, presumably forwarded to the strategy.
   * \param subsample          Sampling ratio, presumably forwarded to the strategy.
   * \param sampling_method    Chooses between uniform and gradient-based sampling (presumably a
   *                           TrainParam sampling-method value; confirm in the .cu).
   * \param is_external_memory Chooses the external-memory variant of the strategy (assumed).
   */
  GradientBasedSampler(Context const* ctx, std::size_t n_rows, const BatchParam& batch_param,
                       float subsample, int sampling_method, bool is_external_memory);

  /*! \brief Sample from a DMatrix based on the given gradient pairs. */
  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);

  /*!
   * \brief Calculate the threshold used to normalize sampling probabilities.
   *
   * \param ctx         Execution context.
   * \param gpair       Gradient pairs the threshold is derived from.
   * \param threshold   Caller-provided scratch/output buffer (the strategies own `threshold_`).
   * \param grad_sum    Caller-provided scratch/output buffer (the strategies own `grad_sum_`).
   * \param sample_rows Target number of sampled rows (by name; confirm).
   * \return An index; presumably the position of the chosen threshold in \p threshold (confirm).
   */
  static std::size_t CalculateThresholdIndex(Context const* ctx, common::Span<GradientPair> gpair,
                                             common::Span<float> threshold,
                                             common::Span<float> grad_sum,
                                             std::size_t sample_rows);

 private:
  common::Monitor monitor_;
  /*! \brief The sampling strategy this sampler owns. */
  std::unique_ptr<SamplingStrategy> strategy_;
};
}  // namespace xgboost::tree