Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <memory> // std::make_shared
|
||||
#include <xgboost/context.h> // for Context
|
||||
|
||||
#include <limits> // for numeric_limits
|
||||
#include <memory> // for make_shared
|
||||
|
||||
#include "../../../src/data/iterative_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
@@ -10,7 +13,7 @@
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
template <typename Page, typename Iter, typename Cuts>
|
||||
void TestRefDMatrix(Cuts&& get_cuts) {
|
||||
void TestRefDMatrix(Context const* ctx, Cuts&& get_cuts) {
|
||||
int n_bins = 256;
|
||||
Iter iter(0.3, 2048);
|
||||
auto m = std::make_shared<IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||
@@ -20,8 +23,8 @@ void TestRefDMatrix(Cuts&& get_cuts) {
|
||||
auto m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), m, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||
|
||||
for (auto const& page_0 : m->template GetBatches<Page>({})) {
|
||||
for (auto const& page_1 : m_1->template GetBatches<Page>({})) {
|
||||
for (auto const& page_0 : m->template GetBatches<Page>(ctx, {})) {
|
||||
for (auto const& page_1 : m_1->template GetBatches<Page>(ctx, {})) {
|
||||
auto const& cuts_0 = get_cuts(page_0);
|
||||
auto const& cuts_1 = get_cuts(page_1);
|
||||
ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
|
||||
@@ -32,8 +35,8 @@ void TestRefDMatrix(Cuts&& get_cuts) {
|
||||
|
||||
m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), nullptr, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||
for (auto const& page_0 : m->template GetBatches<Page>({})) {
|
||||
for (auto const& page_1 : m_1->template GetBatches<Page>({})) {
|
||||
for (auto const& page_0 : m->template GetBatches<Page>(ctx, {})) {
|
||||
for (auto const& page_1 : m_1->template GetBatches<Page>(ctx, {})) {
|
||||
auto const& cuts_0 = get_cuts(page_0);
|
||||
auto const& cuts_1 = get_cuts(page_1);
|
||||
ASSERT_NE(cuts_0.Values(), cuts_1.Values());
|
||||
@@ -45,8 +48,8 @@ void TestRefDMatrix(Cuts&& get_cuts) {
|
||||
auto dm = RandomDataGenerator(2048, Iter::Cols(), 0.5).GenerateDMatrix(true);
|
||||
auto dqm = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), dm, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||
for (auto const& page_0 : dm->template GetBatches<Page>({})) {
|
||||
for (auto const& page_1 : dqm->template GetBatches<Page>({})) {
|
||||
for (auto const& page_0 : dm->template GetBatches<Page>(ctx, {})) {
|
||||
for (auto const& page_1 : dqm->template GetBatches<Page>(ctx, {})) {
|
||||
auto const& cuts_0 = get_cuts(page_0);
|
||||
auto const& cuts_1 = get_cuts(page_1);
|
||||
ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
|
||||
|
||||
Reference in New Issue
Block a user