Add data split mode to DMatrix MetaInfo (#8568)

This commit is contained in:
Rong Ou 2022-12-25 04:37:37 -08:00 committed by GitHub
parent 77b069c25d
commit 3ceeb8c61c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 113 additions and 103 deletions

View File

@ -203,7 +203,6 @@ Will print out something similar to (not actual output as it's too long for demo
"learner_train_param": {
"booster": "gbtree",
"disable_default_eval_metric": "0",
"dsplit": "auto",
"objective": "reg:squarederror"
},
"metrics": [],

View File

@ -126,12 +126,28 @@ XGB_DLL int XGBGetGlobalConfig(char const **out_config);
/*!
* \brief load a data matrix
* \deprecated since 2.0.0
* \see XGDMatrixCreateFromURI()
* \param fname the name of the file
* \param silent whether print messages during loading
* \param out a loaded data matrix
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out);
/*!
* \brief load a data matrix
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - uri: The URI of the input file.
* - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator on how the file was split
* beforehand. Default to row.
* \param out a loaded data matrix
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromURI(char const *config, DMatrixHandle *out);
/**
* @example c-api-demo.c
*/

View File

@ -40,9 +40,7 @@ enum class DataType : uint8_t {
enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
enum class DataSplitMode : int {
kAuto = 0, kCol = 1, kRow = 2, kNone = 3
};
enum class DataSplitMode : int { kRow = 0, kCol = 1 };
/*!
* \brief Meta information about dataset, always sit in memory.
@ -60,6 +58,8 @@ class MetaInfo {
uint64_t num_nonzero_{0}; // NOLINT
/*! \brief label of each instance */
linalg::Tensor<float, 2> labels;
/*! \brief data split mode */
DataSplitMode data_split_mode{DataSplitMode::kRow};
/*!
* \brief the index of begin and end of a group
* needed when the learning task is ranking.
@ -544,15 +544,16 @@ class DMatrix {
* \brief Load DMatrix from URI.
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode Mode to read in part of the data, divided among the workers in distributed mode.
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
* it's just an indicator on how the input was split beforehand.
* \param file_format The format type of the file, used for dmlc::Parser::Create.
* By default "auto" will be able to load in both local binary file.
* \param page_size Page size for external memory.
* \return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri,
bool silent,
DataSplitMode data_split_mode,
bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow,
const std::string& file_format = "auto");
/**

View File

@ -10,6 +10,7 @@ import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping
from enum import IntEnum, unique
from functools import wraps
from inspect import Parameter, signature
from typing import (
@ -608,6 +609,13 @@ def require_keyword_args(
_deprecate_positional_args = require_keyword_args(False)
@unique
class DataSplitMode(IntEnum):
"""Supported data split mode for DMatrix."""
ROW = 0
COL = 1
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost.
@ -635,6 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None:
"""Parameters
----------
@ -728,6 +737,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical,
data_split_mode=data_split_mode,
)
assert handle is not None
self.handle = handle
@ -1332,6 +1342,7 @@ class QuantileDMatrix(DMatrix):
label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None:
self.max_bin: int = max_bin if max_bin is not None else 256
self.missing = missing if missing is not None else np.nan

View File

@ -23,6 +23,7 @@ from .compat import DataFrame, lazy_isinstance
from .core import (
_LIB,
DataIter,
DataSplitMode,
DMatrix,
_check_call,
_cuda_array_interface,
@ -865,13 +866,17 @@ def _from_uri(
missing: Optional[FloatCompatible],
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
_warn_unused_missing(data, missing)
handle = ctypes.c_void_p()
data = os.fspath(os.path.expanduser(data))
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
ctypes.c_int(1),
ctypes.byref(handle)))
args = {
"uri": str(data),
"data_split_mode": int(data_split_mode),
}
config = bytes(json.dumps(args), "utf-8")
_check_call(_LIB.XGDMatrixCreateFromURI(config, ctypes.byref(handle)))
return handle, feature_names, feature_types
@ -938,6 +943,7 @@ def dispatch_data_backend(
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
'''Dispatch data for DMatrix.'''
if not _is_cudf_ser(data) and not _is_pandas_series(data):
@ -953,7 +959,7 @@ def dispatch_data_backend(
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types)
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data):
return _from_list(data, missing, threads, feature_names, feature_types)
if _is_tuple(data):

View File

@ -206,17 +206,29 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
}
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
API_BEGIN();
auto data_split_mode = DataSplitMode::kNone;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
data_split_mode = DataSplitMode::kRow;
}
xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, data_split_mode));
Json config{Object()};
config["uri"] = std::string{fname};
config["silent"] = silent;
std::string config_str;
Json::Dump(config, &config_str);
return XGDMatrixCreateFromURI(config_str.c_str(), out);
}
XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(config);
xgboost_CHECK_C_ARG_PTR(out);
auto jconfig = Json::Load(StringView{config});
std::string uri = RequiredArg<String>(jconfig, "uri", __func__);
auto silent = static_cast<bool>(OptionalArg<Integer, int64_t>(jconfig, "silent", 1));
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(jconfig, "data_split_mode", 0));
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(uri, silent, data_split_mode));
API_END();
}

View File

@ -112,10 +112,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
.describe("Name of the prediction file.");
DMLC_DECLARE_FIELD(dsplit).set_default(0)
.add_enum("auto", 0)
.add_enum("row", 0)
.add_enum("col", 1)
.add_enum("row", 2)
.add_enum("none", 3)
.describe("Data split mode.");
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
.describe("(Deprecated) Use iteration_begin/iteration_end instead.");
@ -158,15 +156,6 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
if (name_pred == "stdout") {
save_period = 0;
}
if (dsplit == static_cast<int>(DataSplitMode::kAuto)) {
if (collective::IsFederated()) {
dsplit = static_cast<int>(DataSplitMode::kNone);
} else if (collective::IsDistributed()) {
dsplit = static_cast<int>(DataSplitMode::kRow);
} else {
dsplit = static_cast<int>(DataSplitMode::kNone);
}
}
}
};

View File

@ -783,10 +783,14 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) {
CHECK(data_split_mode == DataSplitMode::kRow ||
data_split_mode == DataSplitMode::kCol ||
data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; data split mode can only be 'row', 'col', or 'none'";
auto need_split = false;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
need_split = true;
}
std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
@ -794,7 +798,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification.";
if (data_split_mode == DataSplitMode::kRow) {
if (need_split && data_split_mode == DataSplitMode::kRow) {
std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) {
@ -828,7 +832,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
}
int partid = 0, npart = 1;
if (data_split_mode == DataSplitMode::kRow) {
if (need_split && data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank();
npart = collective::GetWorldSize();
} else {
@ -887,7 +891,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
* since partitioned data not knowing the real number of features. */
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
if (data_split_mode == DataSplitMode::kCol) {
if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory.";
}
@ -898,6 +902,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
delete dmat;
return sliced;
} else {
dmat->Info().data_split_mode = data_split_mode;
return dmat;
}
}

View File

@ -65,6 +65,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
out->Info() = this->Info().Copy();
out->Info().num_nonzero_ = h_offset.back();
}
out->Info().data_split_mode = DataSplitMode::kCol;
return out;
}

View File

@ -273,8 +273,6 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) {
}
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none.
DataSplitMode dsplit {DataSplitMode::kAuto};
// flag to disable default metric
bool disable_default_eval_metric {false};
// FIXME(trivialfis): The following parameters belong to model itself, but can be
@ -284,13 +282,6 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// declare parameters
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
DMLC_DECLARE_FIELD(dsplit)
.set_default(DataSplitMode::kAuto)
.add_enum("auto", DataSplitMode::kAuto)
.add_enum("col", DataSplitMode::kCol)
.add_enum("row", DataSplitMode::kRow)
.add_enum("none", DataSplitMode::kNone)
.describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(false)
.describe("Flag to disable default metric. Set to >0 to disable");
@ -445,12 +436,6 @@ class LearnerConfiguration : public Learner {
ConsoleLogger::Configure(args);
// add additional parameters
// These are cosntraints that need to be satisfied.
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
// set seed only before the model is initialized
if (!initialized || ctx_.seed != old_seed) {
common::GlobalRandom().seed(ctx_.seed);
@ -1055,11 +1040,6 @@ class LearnerIO : public LearnerConfiguration {
auto n = tparam_.__DICT__();
cfg_.insert(n.cbegin(), n.cend());
// copy dsplit from config since it will not run again during restore
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
this->need_configuration_ = true;
}
@ -1199,16 +1179,6 @@ class LearnerImpl : public LearnerIO {
local_map->erase(this);
}
}
// Configuration before data is known.
void CheckDataSplitMode() {
if (collective::IsDistributed()) {
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
if (tparam_.dsplit == DataSplitMode::kCol) {
LOG(FATAL) << "Column-wise data split is currently not supported.";
}
}
}
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
std::string format) override {
@ -1266,7 +1236,6 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache();
@ -1295,7 +1264,6 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache();
@ -1444,19 +1412,14 @@ class LearnerImpl : public LearnerIO {
MetaInfo const& info = p_fmat->Info();
info.Validate(ctx_.gpu_id);
auto const row_based_split = [this]() {
return tparam_.dsplit == DataSplitMode::kRow || tparam_.dsplit == DataSplitMode::kAuto;
};
if (row_based_split()) {
if (is_training) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
} else {
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
}
if (is_training) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
} else {
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
}
if (p_fmat->Info().num_row_ == 0) {

View File

@ -185,6 +185,17 @@ TEST(CAPI, CatchDMLCError) {
EXPECT_THROW({ dmlc::Stream::Create("foo", "r"); }, dmlc::Error);
}
TEST(CAPI, CatchDMLCErrorURI) {
Json config{Object()};
config["uri"] = String{"foo"};
config["silent"] = Integer{0};
std::string config_str;
Json::Dump(config, &config_str);
DMatrixHandle out;
ASSERT_EQ(XGDMatrixCreateFromURI(config_str.c_str(), &out), -1);
EXPECT_THROW({ dmlc::Stream::Create("foo", "r"); }, dmlc::Error);
}
TEST(CAPI, DMatrixSetFeatureName) {
size_t constexpr kRows = 10;
bst_feature_t constexpr kCols = 2;

View File

@ -88,8 +88,7 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
fo << row_data.str() << "\n";
}
fo.close();
return std::shared_ptr<DMatrix>(DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
return std::shared_ptr<DMatrix>(DMatrix::Load(tmp_file + "#" + tmp_file + ".cache"));
}
// Test that elements are approximately equally distributed among bins

View File

@ -27,7 +27,6 @@ std::string GetModelStr() {
"train_parameter": {
"debug_verbose": "0",
"disable_default_eval_metric": "0",
"dsplit": "auto",
"nthread": "0",
"seed": "0",
"seed_per_iteration": "0",

View File

@ -143,7 +143,7 @@ TEST(DMatrix, Uri) {
// EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error);
std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kRow));
dmat.reset(DMatrix::Load(uri, false));
ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows);

View File

@ -175,7 +175,7 @@ TEST(MetaInfo, LoadQid) {
os.set_stream(nullptr);
}
std::unique_ptr<xgboost::DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone, "libsvm"));
xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kRow, "libsvm"));
const xgboost::MetaInfo& info = dmat->Info();
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};

View File

@ -15,13 +15,14 @@ TEST(SimpleDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file);
// Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2);
EXPECT_EQ(dmat->Info().num_col_, 5);
EXPECT_EQ(dmat->Info().num_nonzero_, 6);
EXPECT_EQ(dmat->Info().labels.Size(), dmat->Info().num_row_);
EXPECT_EQ(dmat->Info().data_split_mode, DataSplitMode::kRow);
delete dmat;
}
@ -30,7 +31,7 @@ TEST(SimpleDMatrix, RowAccess) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, xgboost::DataSplitMode::kNone);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false);
// Loop over the batches and count the records
int64_t row_count = 0;
@ -53,7 +54,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file);
ASSERT_TRUE(dmat->SingleColBlock());
@ -360,6 +361,7 @@ TEST(SimpleDMatrix, SliceCol) {
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
ASSERT_EQ(out->Info().num_row_, kRows);
ASSERT_EQ(out->Info().num_nonzero_, kRows * kSlicCols); // dense
ASSERT_EQ(out->Info().data_split_mode, DataSplitMode::kCol);
}
}
@ -367,12 +369,12 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file);
data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat);
const std::string tmp_binfile = tempdir.path + "/csr_source.binary";
simple_dmat->SaveToLocalFile(tmp_binfile);
xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, xgboost::DataSplitMode::kNone);
xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile);
EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);

View File

@ -108,8 +108,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
size_t constexpr kEntries = 24;
CreateBigTestData(tmp_file, kEntries);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", false, xgboost::DataSplitMode::kNone);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", false);
// Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 8ul);
@ -136,8 +135,7 @@ TEST(SparsePageDMatrix, ColAccess) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat =
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, xgboost::DataSplitMode::kNone);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache");
// Loop over the batches and assert the data is as expected
size_t iter = 0;

View File

@ -12,7 +12,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone);
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache");
// Loop over the batches and assert the data is as expected
size_t n = 0;

View File

@ -527,8 +527,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
if (page_size > 0) {
uri += "#" + tmp_file + ".cache";
}
std::unique_ptr<DMatrix> dmat(
DMatrix::Load(uri, true, DataSplitMode::kNone, "auto"));
std::unique_ptr<DMatrix> dmat(DMatrix::Load(uri));
return dmat;
}

View File

@ -98,8 +98,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm";
CreateBigTestData(tmp_file, 50000);
std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache"));
EXPECT_FALSE(dmat->SingleColBlock());
size_t num_row = dmat->Info().num_row_;
std::vector<bst_float> labels(num_row);