Add data split mode to DMatrix MetaInfo (#8568)

This commit is contained in:
Rong Ou 2022-12-25 04:37:37 -08:00 committed by GitHub
parent 77b069c25d
commit 3ceeb8c61c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 113 additions and 103 deletions

View File

@ -203,7 +203,6 @@ Will print out something similar to (not actual output as it's too long for demo
"learner_train_param": { "learner_train_param": {
"booster": "gbtree", "booster": "gbtree",
"disable_default_eval_metric": "0", "disable_default_eval_metric": "0",
"dsplit": "auto",
"objective": "reg:squarederror" "objective": "reg:squarederror"
}, },
"metrics": [], "metrics": [],

View File

@ -126,12 +126,28 @@ XGB_DLL int XGBGetGlobalConfig(char const **out_config);
/*! /*!
* \brief load a data matrix * \brief load a data matrix
* \deprecated since 2.0.0
* \see XGDMatrixCreateFromURI()
* \param fname the name of the file * \param fname the name of the file
* \param silent whether print messages during loading * \param silent whether print messages during loading
* \param out a loaded data matrix * \param out a loaded data matrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out); XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out);
/*!
* \brief load a data matrix
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - uri: The URI of the input file.
* - silent (optional): Whether to print messages during loading. Defaults to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator of how the file was split
* beforehand. Defaults to row.
* \param out a loaded data matrix
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromURI(char const *config, DMatrixHandle *out);
/** /**
* @example c-api-demo.c * @example c-api-demo.c
*/ */

View File

@ -40,9 +40,7 @@ enum class DataType : uint8_t {
enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
/*!
 * \brief How the data of a DMatrix is split: by row (the default) or by column.
 *
 * The numeric values are part of the interface: they are what the JSON
 * `data_split_mode` field of the C API carries, so they must stay stable.
 */
enum class DataSplitMode : int { kRow = 0, kCol = 1 };
/*! /*!
* \brief Meta information about dataset, always sit in memory. * \brief Meta information about dataset, always sit in memory.
@ -60,6 +58,8 @@ class MetaInfo {
uint64_t num_nonzero_{0}; // NOLINT uint64_t num_nonzero_{0}; // NOLINT
/*! \brief label of each instance */ /*! \brief label of each instance */
linalg::Tensor<float, 2> labels; linalg::Tensor<float, 2> labels;
/*! \brief data split mode */
DataSplitMode data_split_mode{DataSplitMode::kRow};
/*! /*!
* \brief the index of begin and end of a group * \brief the index of begin and end of a group
* needed when the learning task is ranking. * needed when the learning task is ranking.
@ -544,15 +544,16 @@ class DMatrix {
* \brief Load DMatrix from URI. * \brief Load DMatrix from URI.
* \param uri The URI of input. * \param uri The URI of input.
* \param silent Whether print information during loading. * \param silent Whether print information during loading.
* \param data_split_mode Mode to read in part of the data, divided among the workers in distributed mode. * \param data_split_mode In distributed mode, split the input according to this mode; otherwise,
* it's just an indicator of how the input was split beforehand.
* \param file_format The format type of the file, used for dmlc::Parser::Create. * \param file_format The format type of the file, used for dmlc::Parser::Create.
* By default "auto" will be able to load in both local binary file. * By default "auto" will be able to load in both local binary file.
* \param page_size Page size for external memory. * \param page_size Page size for external memory.
* \return The created DMatrix. * \return The created DMatrix.
*/ */
static DMatrix* Load(const std::string& uri, static DMatrix* Load(const std::string& uri,
bool silent, bool silent = true,
DataSplitMode data_split_mode, DataSplitMode data_split_mode = DataSplitMode::kRow,
const std::string& file_format = "auto"); const std::string& file_format = "auto");
/** /**

View File

@ -10,6 +10,7 @@ import sys
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from enum import IntEnum, unique
from functools import wraps from functools import wraps
from inspect import Parameter, signature from inspect import Parameter, signature
from typing import ( from typing import (
@ -608,6 +609,13 @@ def require_keyword_args(
_deprecate_positional_args = require_keyword_args(False) _deprecate_positional_args = require_keyword_args(False)
@unique
class DataSplitMode(IntEnum):
    """Supported data split modes for DMatrix."""

    # The integer values are what gets serialised into the ``data_split_mode``
    # field of the JSON config handed to the native library.
    ROW = 0
    COL = 1
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost. """Data Matrix used in XGBoost.
@ -635,6 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
label_upper_bound: Optional[ArrayLike] = None, label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False, enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None: ) -> None:
"""Parameters """Parameters
---------- ----------
@ -728,6 +737,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
enable_categorical=enable_categorical, enable_categorical=enable_categorical,
data_split_mode=data_split_mode,
) )
assert handle is not None assert handle is not None
self.handle = handle self.handle = handle
@ -1332,6 +1342,7 @@ class QuantileDMatrix(DMatrix):
label_upper_bound: Optional[ArrayLike] = None, label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False, enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None: ) -> None:
self.max_bin: int = max_bin if max_bin is not None else 256 self.max_bin: int = max_bin if max_bin is not None else 256
self.missing = missing if missing is not None else np.nan self.missing = missing if missing is not None else np.nan

View File

@ -23,6 +23,7 @@ from .compat import DataFrame, lazy_isinstance
from .core import ( from .core import (
_LIB, _LIB,
DataIter, DataIter,
DataSplitMode,
DMatrix, DMatrix,
_check_call, _check_call,
_cuda_array_interface, _cuda_array_interface,
@ -865,13 +866,17 @@ def _from_uri(
missing: Optional[FloatCompatible], missing: Optional[FloatCompatible],
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
_warn_unused_missing(data, missing) _warn_unused_missing(data, missing)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
data = os.fspath(os.path.expanduser(data)) data = os.fspath(os.path.expanduser(data))
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), args = {
ctypes.c_int(1), "uri": str(data),
ctypes.byref(handle))) "data_split_mode": int(data_split_mode),
}
config = bytes(json.dumps(args), "utf-8")
_check_call(_LIB.XGDMatrixCreateFromURI(config, ctypes.byref(handle)))
return handle, feature_names, feature_types return handle, feature_names, feature_types
@ -938,6 +943,7 @@ def dispatch_data_backend(
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
enable_categorical: bool = False, enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
'''Dispatch data for DMatrix.''' '''Dispatch data for DMatrix.'''
if not _is_cudf_ser(data) and not _is_pandas_series(data): if not _is_cudf_ser(data) and not _is_pandas_series(data):
@ -953,7 +959,7 @@ def dispatch_data_backend(
if _is_numpy_array(data): if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names, feature_types) return _from_numpy_array(data, missing, threads, feature_names, feature_types)
if _is_uri(data): if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types) return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data): if _is_list(data):
return _from_list(data, missing, threads, feature_names, feature_types) return _from_list(data, missing, threads, feature_names, feature_types)
if _is_tuple(data): if _is_tuple(data):

View File

@ -206,17 +206,29 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
} }
/*
 * Deprecated file-based constructor, kept as a thin wrapper: it packs `fname`
 * and `silent` into the JSON config accepted by XGDMatrixCreateFromURI() and
 * forwards the call. No `data_split_mode` is set, so the callee's default is
 * used.
 */
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
  // Run the argument checks between API_BEGIN() and API_END(), exactly as
  // XGDMatrixCreateFromURI() does.
  // NOTE(review): assumes this macro pair turns a failed check into the -1
  // return value documented for the C API -- confirm against the macro
  // definitions.
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(fname);
  xgboost_CHECK_C_ARG_PTR(out);

  Json config{Object()};
  config["uri"] = std::string{fname};
  config["silent"] = silent;
  std::string config_str;
  Json::Dump(config, &config_str);
  // The callee produces the status code for the actual load.
  return XGDMatrixCreateFromURI(config_str.c_str(), out);
  API_END();
}
/*
 * Create a DMatrix from a JSON encoded configuration.
 *
 * Fields read from `config`:
 *   - "uri" (required): passed to DMatrix::Load() as the input location.
 *   - "silent" (optional integer, default 1): converted to bool.
 *   - "data_split_mode" (optional integer, default 0): converted to DataSplitMode.
 *
 * On success `*out` receives a newly allocated std::shared_ptr<DMatrix>.
 */
XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(config);
  xgboost_CHECK_C_ARG_PTR(out);

  auto parsed = Json::Load(StringView{config});
  std::string location = RequiredArg<String>(parsed, "uri", __func__);

  // Each optional field is converted in the same expression that reads it.
  bool const quiet = static_cast<bool>(OptionalArg<Integer, int64_t>(parsed, "silent", 1));
  DataSplitMode const split_mode =
      static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(parsed, "data_split_mode", 0));

  *out = new std::shared_ptr<DMatrix>(DMatrix::Load(location, quiet, split_mode));
  API_END();
}

View File

@ -112,10 +112,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt") DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
.describe("Name of the prediction file."); .describe("Name of the prediction file.");
DMLC_DECLARE_FIELD(dsplit).set_default(0) DMLC_DECLARE_FIELD(dsplit).set_default(0)
.add_enum("auto", 0) .add_enum("row", 0)
.add_enum("col", 1) .add_enum("col", 1)
.add_enum("row", 2)
.add_enum("none", 3)
.describe("Data split mode."); .describe("Data split mode.");
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0) DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
.describe("(Deprecated) Use iteration_begin/iteration_end instead."); .describe("(Deprecated) Use iteration_begin/iteration_end instead.");
@ -158,15 +156,6 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
if (name_pred == "stdout") { if (name_pred == "stdout") {
save_period = 0; save_period = 0;
} }
if (dsplit == static_cast<int>(DataSplitMode::kAuto)) {
if (collective::IsFederated()) {
dsplit = static_cast<int>(DataSplitMode::kNone);
} else if (collective::IsDistributed()) {
dsplit = static_cast<int>(DataSplitMode::kRow);
} else {
dsplit = static_cast<int>(DataSplitMode::kNone);
}
}
} }
}; };

View File

@ -783,10 +783,14 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode, DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) { const std::string& file_format) {
CHECK(data_split_mode == DataSplitMode::kRow || auto need_split = false;
data_split_mode == DataSplitMode::kCol || if (collective::IsFederated()) {
data_split_mode == DataSplitMode::kNone) LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
<< "Precondition violated; data split mode can only be 'row', 'col', or 'none'"; } else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
need_split = true;
}
std::string fname, cache_file; std::string fname, cache_file;
size_t dlm_pos = uri.find('#'); size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) { if (dlm_pos != std::string::npos) {
@ -794,7 +798,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
fname = uri.substr(0, dlm_pos); fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos) CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification."; << "Only one `#` is allowed in file path for cache file specification.";
if (data_split_mode == DataSplitMode::kRow) { if (need_split && data_split_mode == DataSplitMode::kRow) {
std::ostringstream os; std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':'); std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) { for (size_t i = 0; i < cache_shards.size(); ++i) {
@ -828,7 +832,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
} }
int partid = 0, npart = 1; int partid = 0, npart = 1;
if (data_split_mode == DataSplitMode::kRow) { if (need_split && data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank(); partid = collective::GetRank();
npart = collective::GetWorldSize(); npart = collective::GetWorldSize();
} else { } else {
@ -887,7 +891,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
* since partitioned data not knowing the real number of features. */ * since partitioned data not knowing the real number of features. */
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1); collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
if (data_split_mode == DataSplitMode::kCol) { if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) { if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory."; LOG(FATAL) << "Column-wise data split is not support for external memory.";
} }
@ -898,6 +902,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
delete dmat; delete dmat;
return sliced; return sliced;
} else { } else {
dmat->Info().data_split_mode = data_split_mode;
return dmat; return dmat;
} }
} }

View File

@ -65,6 +65,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
out->Info() = this->Info().Copy(); out->Info() = this->Info().Copy();
out->Info().num_nonzero_ = h_offset.back(); out->Info().num_nonzero_ = h_offset.back();
} }
out->Info().data_split_mode = DataSplitMode::kCol;
return out; return out;
} }

View File

@ -273,8 +273,6 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) {
} }
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> { struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none.
DataSplitMode dsplit {DataSplitMode::kAuto};
// flag to disable default metric // flag to disable default metric
bool disable_default_eval_metric {false}; bool disable_default_eval_metric {false};
// FIXME(trivialfis): The following parameters belong to model itself, but can be // FIXME(trivialfis): The following parameters belong to model itself, but can be
@ -284,13 +282,6 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(LearnerTrainParam) { DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
DMLC_DECLARE_FIELD(dsplit)
.set_default(DataSplitMode::kAuto)
.add_enum("auto", DataSplitMode::kAuto)
.add_enum("col", DataSplitMode::kCol)
.add_enum("row", DataSplitMode::kRow)
.add_enum("none", DataSplitMode::kNone)
.describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric) DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(false) .set_default(false)
.describe("Flag to disable default metric. Set to >0 to disable"); .describe("Flag to disable default metric. Set to >0 to disable");
@ -445,12 +436,6 @@ class LearnerConfiguration : public Learner {
ConsoleLogger::Configure(args); ConsoleLogger::Configure(args);
// add additional parameters
// These are cosntraints that need to be satisfied.
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
// set seed only before the model is initialized // set seed only before the model is initialized
if (!initialized || ctx_.seed != old_seed) { if (!initialized || ctx_.seed != old_seed) {
common::GlobalRandom().seed(ctx_.seed); common::GlobalRandom().seed(ctx_.seed);
@ -1055,11 +1040,6 @@ class LearnerIO : public LearnerConfiguration {
auto n = tparam_.__DICT__(); auto n = tparam_.__DICT__();
cfg_.insert(n.cbegin(), n.cend()); cfg_.insert(n.cbegin(), n.cend());
// copy dsplit from config since it will not run again during restore
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
this->need_configuration_ = true; this->need_configuration_ = true;
} }
@ -1199,16 +1179,6 @@ class LearnerImpl : public LearnerIO {
local_map->erase(this); local_map->erase(this);
} }
} }
// Configuration before data is known.
void CheckDataSplitMode() {
if (collective::IsDistributed()) {
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
if (tparam_.dsplit == DataSplitMode::kCol) {
LOG(FATAL) << "Column-wise data split is currently not supported.";
}
}
}
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats, std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
std::string format) override { std::string format) override {
@ -1266,7 +1236,6 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter); common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
} }
this->CheckDataSplitMode();
this->ValidateDMatrix(train.get(), true); this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
@ -1295,7 +1264,6 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter); common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
} }
this->CheckDataSplitMode();
this->ValidateDMatrix(train.get(), true); this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
@ -1444,19 +1412,14 @@ class LearnerImpl : public LearnerIO {
MetaInfo const& info = p_fmat->Info(); MetaInfo const& info = p_fmat->Info();
info.Validate(ctx_.gpu_id); info.Validate(ctx_.gpu_id);
auto const row_based_split = [this]() { if (is_training) {
return tparam_.dsplit == DataSplitMode::kRow || tparam_.dsplit == DataSplitMode::kAuto; CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
}; << "Number of columns does not match number of features in "
if (row_based_split()) { "booster.";
if (is_training) { } else {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in " << "Number of columns does not match number of features in "
"booster."; "booster.";
} else {
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
}
} }
if (p_fmat->Info().num_row_ == 0) { if (p_fmat->Info().num_row_ == 0) {

View File

@ -185,6 +185,17 @@ TEST(CAPI, CatchDMLCError) {
EXPECT_THROW({ dmlc::Stream::Create("foo", "r"); }, dmlc::Error); EXPECT_THROW({ dmlc::Stream::Create("foo", "r"); }, dmlc::Error);
} }
// The JSON-config entry point must report a failed load of the URI "foo"
// through its return value (-1).
TEST(CAPI, CatchDMLCErrorURI) {
  Json args{Object()};
  args["uri"] = String{"foo"};
  args["silent"] = Integer{0};

  std::string serialized;
  Json::Dump(args, &serialized);

  DMatrixHandle handle;
  ASSERT_EQ(XGDMatrixCreateFromURI(serialized.c_str(), &handle), -1);
  // Calling the stream factory directly on the same name is expected to throw.
  EXPECT_THROW({ dmlc::Stream::Create("foo", "r"); }, dmlc::Error);
}
TEST(CAPI, DMatrixSetFeatureName) { TEST(CAPI, DMatrixSetFeatureName) {
size_t constexpr kRows = 10; size_t constexpr kRows = 10;
bst_feature_t constexpr kCols = 2; bst_feature_t constexpr kCols = 2;

View File

@ -88,8 +88,7 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
fo << row_data.str() << "\n"; fo << row_data.str() << "\n";
} }
fo.close(); fo.close();
return std::shared_ptr<DMatrix>(DMatrix::Load( return std::shared_ptr<DMatrix>(DMatrix::Load(tmp_file + "#" + tmp_file + ".cache"));
tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
} }
// Test that elements are approximately equally distributed among bins // Test that elements are approximately equally distributed among bins

View File

@ -27,7 +27,6 @@ std::string GetModelStr() {
"train_parameter": { "train_parameter": {
"debug_verbose": "0", "debug_verbose": "0",
"disable_default_eval_metric": "0", "disable_default_eval_metric": "0",
"dsplit": "auto",
"nthread": "0", "nthread": "0",
"seed": "0", "seed": "0",
"seed_per_iteration": "0", "seed_per_iteration": "0",

View File

@ -143,7 +143,7 @@ TEST(DMatrix, Uri) {
// EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error); // EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error);
std::string uri = path + "?format=csv"; std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kRow)); dmat.reset(DMatrix::Load(uri, false));
ASSERT_EQ(dmat->Info().num_col_, kCols); ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows); ASSERT_EQ(dmat->Info().num_row_, kRows);

View File

@ -175,7 +175,7 @@ TEST(MetaInfo, LoadQid) {
os.set_stream(nullptr); os.set_stream(nullptr);
} }
std::unique_ptr<xgboost::DMatrix> dmat( std::unique_ptr<xgboost::DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone, "libsvm")); xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kRow, "libsvm"));
const xgboost::MetaInfo& info = dmat->Info(); const xgboost::MetaInfo& info = dmat->Info();
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12}; const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};

View File

@ -15,13 +15,14 @@ TEST(SimpleDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone); xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file);
// Test the metadata that was parsed // Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2); EXPECT_EQ(dmat->Info().num_row_, 2);
EXPECT_EQ(dmat->Info().num_col_, 5); EXPECT_EQ(dmat->Info().num_col_, 5);
EXPECT_EQ(dmat->Info().num_nonzero_, 6); EXPECT_EQ(dmat->Info().num_nonzero_, 6);
EXPECT_EQ(dmat->Info().labels.Size(), dmat->Info().num_row_); EXPECT_EQ(dmat->Info().labels.Size(), dmat->Info().num_row_);
EXPECT_EQ(dmat->Info().data_split_mode, DataSplitMode::kRow);
delete dmat; delete dmat;
} }
@ -30,7 +31,7 @@ TEST(SimpleDMatrix, RowAccess) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, xgboost::DataSplitMode::kNone); xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false);
// Loop over the batches and count the records // Loop over the batches and count the records
int64_t row_count = 0; int64_t row_count = 0;
@ -53,7 +54,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone); xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file);
ASSERT_TRUE(dmat->SingleColBlock()); ASSERT_TRUE(dmat->SingleColBlock());
@ -360,6 +361,7 @@ TEST(SimpleDMatrix, SliceCol) {
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_); ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
ASSERT_EQ(out->Info().num_row_, kRows); ASSERT_EQ(out->Info().num_row_, kRows);
ASSERT_EQ(out->Info().num_nonzero_, kRows * kSlicCols); // dense ASSERT_EQ(out->Info().num_nonzero_, kRows * kSlicCols); // dense
ASSERT_EQ(out->Info().data_split_mode, DataSplitMode::kCol);
} }
} }
@ -367,12 +369,12 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone); xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file);
data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat); data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat);
const std::string tmp_binfile = tempdir.path + "/csr_source.binary"; const std::string tmp_binfile = tempdir.path + "/csr_source.binary";
simple_dmat->SaveToLocalFile(tmp_binfile); simple_dmat->SaveToLocalFile(tmp_binfile);
xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, xgboost::DataSplitMode::kNone); xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile);
EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_); EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);

View File

@ -108,8 +108,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
size_t constexpr kEntries = 24; size_t constexpr kEntries = 24;
CreateBigTestData(tmp_file, kEntries); CreateBigTestData(tmp_file, kEntries);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load( xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", false);
tmp_file + "#" + tmp_file + ".cache", false, xgboost::DataSplitMode::kNone);
// Test the metadata that was parsed // Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 8ul); EXPECT_EQ(dmat->Info().num_row_, 8ul);
@ -136,8 +135,7 @@ TEST(SparsePageDMatrix, ColAccess) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache");
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, xgboost::DataSplitMode::kNone);
// Loop over the batches and assert the data is as expected // Loop over the batches and assert the data is as expected
size_t iter = 0; size_t iter = 0;

View File

@ -12,7 +12,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone); DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache");
// Loop over the batches and assert the data is as expected // Loop over the batches and assert the data is as expected
size_t n = 0; size_t n = 0;

View File

@ -527,8 +527,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
if (page_size > 0) { if (page_size > 0) {
uri += "#" + tmp_file + ".cache"; uri += "#" + tmp_file + ".cache";
} }
std::unique_ptr<DMatrix> dmat( std::unique_ptr<DMatrix> dmat(DMatrix::Load(uri));
DMatrix::Load(uri, true, DataSplitMode::kNone, "auto"));
return dmat; return dmat;
} }

View File

@ -98,8 +98,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm"; const std::string tmp_file = tempdir.path + "/big.libsvm";
CreateBigTestData(tmp_file, 50000); CreateBigTestData(tmp_file, 50000);
std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load( std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache"));
tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
EXPECT_FALSE(dmat->SingleColBlock()); EXPECT_FALSE(dmat->SingleColBlock());
size_t num_row = dmat->Info().num_row_; size_t num_row = dmat->Info().num_row_;
std::vector<bst_float> labels(num_row); std::vector<bst_float> labels(num_row);