Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -15,16 +15,17 @@ class DMatrixForTest : public data::SimpleDMatrix {
|
||||
|
||||
public:
|
||||
using SimpleDMatrix::SimpleDMatrix;
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override {
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
|
||||
const BatchParam& param) override {
|
||||
auto backup = this->gradient_index_;
|
||||
auto iter = SimpleDMatrix::GetGradientIndex(param);
|
||||
auto iter = SimpleDMatrix::GetGradientIndex(ctx, param);
|
||||
n_regen_ += (backup != this->gradient_index_);
|
||||
return iter;
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
|
||||
BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, const BatchParam& param) override {
|
||||
auto backup = this->ellpack_page_;
|
||||
auto iter = SimpleDMatrix::GetEllpackBatches(param);
|
||||
auto iter = SimpleDMatrix::GetEllpackBatches(ctx, param);
|
||||
n_regen_ += (backup != this->ellpack_page_);
|
||||
return iter;
|
||||
}
|
||||
@@ -50,8 +51,8 @@ class RegenTest : public ::testing::Test {
|
||||
HostDeviceVector<float> storage;
|
||||
auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage);
|
||||
auto adapter = data::ArrayAdapter(StringView{dense});
|
||||
p_fmat_ = std::shared_ptr<DMatrix>(new DMatrixForTest{
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
|
||||
p_fmat_ = std::shared_ptr<DMatrix>(
|
||||
new DMatrixForTest{&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
|
||||
|
||||
p_fmat_->Info().labels.Reshape(256, 1);
|
||||
auto labels = p_fmat_->Info().labels.Data();
|
||||
@@ -74,7 +75,7 @@ class RegenTest : public ::testing::Test {
|
||||
auto for_test = dynamic_cast<DMatrixForTest*>(p_fmat_.get());
|
||||
CHECK(for_test);
|
||||
auto backup = for_test->NumRegen();
|
||||
for_test->GetBatches<Page>(BatchParam{});
|
||||
for_test->GetBatches<Page>(p_fmat_->Ctx(), BatchParam{});
|
||||
CHECK_EQ(for_test->NumRegen(), backup);
|
||||
|
||||
if (reset) {
|
||||
|
||||
Reference in New Issue
Block a user