Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -15,16 +15,17 @@ class DMatrixForTest : public data::SimpleDMatrix {
public:
using SimpleDMatrix::SimpleDMatrix;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override {
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
const BatchParam& param) override {
auto backup = this->gradient_index_;
auto iter = SimpleDMatrix::GetGradientIndex(param);
auto iter = SimpleDMatrix::GetGradientIndex(ctx, param);
n_regen_ += (backup != this->gradient_index_);
return iter;
}
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, const BatchParam& param) override {
auto backup = this->ellpack_page_;
auto iter = SimpleDMatrix::GetEllpackBatches(param);
auto iter = SimpleDMatrix::GetEllpackBatches(ctx, param);
n_regen_ += (backup != this->ellpack_page_);
return iter;
}
@@ -50,8 +51,8 @@ class RegenTest : public ::testing::Test {
HostDeviceVector<float> storage;
auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage);
auto adapter = data::ArrayAdapter(StringView{dense});
p_fmat_ = std::shared_ptr<DMatrix>(new DMatrixForTest{
&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
p_fmat_ = std::shared_ptr<DMatrix>(
new DMatrixForTest{&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
p_fmat_->Info().labels.Reshape(256, 1);
auto labels = p_fmat_->Info().labels.Data();
@@ -74,7 +75,7 @@ class RegenTest : public ::testing::Test {
auto for_test = dynamic_cast<DMatrixForTest*>(p_fmat_.get());
CHECK(for_test);
auto backup = for_test->NumRegen();
for_test->GetBatches<Page>(BatchParam{});
for_test->GetBatches<Page>(p_fmat_->Ctx(), BatchParam{});
CHECK_EQ(for_test->NumRegen(), backup);
if (reset) {