Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -24,7 +24,8 @@ struct GradientBasedSample {
class SamplingStrategy {
public:
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
virtual GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) = 0;
virtual ~SamplingStrategy() = default;
};
@@ -32,7 +33,8 @@ class SamplingStrategy {
class NoSampling : public SamplingStrategy {
public:
explicit NoSampling(EllpackPageImpl const* page);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
EllpackPageImpl const* page_;
@@ -41,10 +43,10 @@ class NoSampling : public SamplingStrategy {
/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
public:
ExternalMemoryNoSampling(EllpackPageImpl const* page,
size_t n_rows,
const BatchParam& batch_param);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
BatchParam batch_param);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
BatchParam batch_param_;
@@ -56,7 +58,8 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
class UniformSampling : public SamplingStrategy {
public:
UniformSampling(EllpackPageImpl const* page, float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
EllpackPageImpl const* page_;
@@ -66,10 +69,9 @@ class UniformSampling : public SamplingStrategy {
/*! \brief No sampling in external memory mode. */
class ExternalMemoryUniformSampling : public SamplingStrategy {
public:
ExternalMemoryUniformSampling(size_t n_rows,
BatchParam batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
ExternalMemoryUniformSampling(size_t n_rows, BatchParam batch_param, float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
BatchParam batch_param_;
@@ -82,11 +84,10 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
/*! \brief Gradient-based sampling in in-memory mode.. */
class GradientBasedSampling : public SamplingStrategy {
public:
GradientBasedSampling(EllpackPageImpl const* page,
size_t n_rows,
const BatchParam& batch_param,
GradientBasedSampling(EllpackPageImpl const* page, size_t n_rows, const BatchParam& batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
EllpackPageImpl const* page_;
@@ -98,10 +99,9 @@ class GradientBasedSampling : public SamplingStrategy {
/*! \brief Gradient-based sampling in external memory mode.. */
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
public:
ExternalMemoryGradientBasedSampling(size_t n_rows,
BatchParam batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
ExternalMemoryGradientBasedSampling(size_t n_rows, BatchParam batch_param, float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
BatchParam batch_param_;
@@ -124,14 +124,11 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
*/
class GradientBasedSampler {
public:
GradientBasedSampler(EllpackPageImpl const* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample,
int sampling_method);
GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
const BatchParam& batch_param, float subsample, int sampling_method);
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,