[Breaking] Require format to be specified in input URI. (#9077)

Previously, `libsvm` was used as the default when no format was specified. However, the dmlc
data parser is not particularly robust against errors, and the most common source of errors
is an unspecified format.

Along with this change, we recommend that users switch to other data loaders. We will
continue to maintain the parsers, as they are still used by many internal tests,
including federated learning.
Jiaming Yuan, 2023-04-28 19:45:15 +08:00, committed by GitHub
parent e922004329
commit 1f9a57d17b
58 changed files with 327 additions and 268 deletions

View File

@ -72,7 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))

View File

@ -20,10 +20,10 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 2
# The path of training data
data = "agaricus.txt.train"
data = "agaricus.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "agaricus.txt.test"
eval[test] = "agaricus.txt.test?format=libsvm"
# evaluate on training data as well each round
eval_train = 1
# The path of test data
test:data = "agaricus.txt.test"
test:data = "agaricus.txt.test?format=libsvm"

View File

@ -21,8 +21,8 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
data = "machine.txt.train"
data = "machine.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "machine.txt.test"
eval[test] = "machine.txt.test?format=libsvm"
# The path of test data
test:data = "machine.txt.test"
test:data = "machine.txt.test?format=libsvm"

View File

@ -42,8 +42,8 @@ int main() {
// load the data
DMatrixHandle dtrain, dtest;
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train?format=libsvm", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test?format=libsvm", silent, &dtest));
// create the booster
BoosterHandle booster;

View File

@ -7,15 +7,19 @@ import os
import xgboost as xgb
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
watchlist = [(dtest, "eval"), (dtrain, "train")]
###
# advanced: start from a initial base prediction
#
print('start running example to start from a initial prediction')
print("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
# train xgboost for 1 round
bst = xgb.train(param, dtrain, 1, watchlist)
# Note: we need the margin value instead of transformed prediction in
@ -27,5 +31,5 @@ ptest = bst.predict(dtest, output_margin=True)
dtrain.set_base_margin(ptrain)
dtest.set_base_margin(ptest)
print('this is result of running from initial prediction')
print("this is result of running from initial prediction")
bst = xgb.train(param, dtrain, 1, watchlist)

View File

@ -10,27 +10,45 @@ import xgboost as xgb
# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'}
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
num_round = 2
print('running cross validation')
print("running cross validation")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed=0,
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])
xgb.cv(
param,
dtrain,
num_round,
nfold=5,
metrics={"error"},
seed=0,
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)],
)
print('running cross validation, disable standard deviation display')
print("running cross validation, disable standard deviation display")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
metrics={'error'}, seed=0,
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
xgb.callback.EarlyStopping(3)])
res = xgb.cv(
param,
dtrain,
num_boost_round=10,
nfold=5,
metrics={"error"},
seed=0,
callbacks=[
xgb.callback.EvaluationMonitor(show_stdv=False),
xgb.callback.EarlyStopping(3),
],
)
print(res)
print('running cross validation, with preprocessing function')
print("running cross validation, with preprocessing function")
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
# we can use this to do weight rescale, etc.
@ -38,32 +56,36 @@ print('running cross validation, with preprocessing function')
def fpreproc(dtrain, dtest, param):
label = dtrain.get_label()
ratio = float(np.sum(label == 0)) / np.sum(label == 1)
param['scale_pos_weight'] = ratio
param["scale_pos_weight"] = ratio
return (dtrain, dtest, param)
# do cross validation, for each fold
# the dtrain, dtest, param will be passed into fpreproc
# then the return value of fpreproc will be used to generate
# results of that fold
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'auc'}, seed=0, fpreproc=fpreproc)
xgb.cv(param, dtrain, num_round, nfold=5, metrics={"auc"}, seed=0, fpreproc=fpreproc)
###
# you can also do cross validation with customized loss function
# See custom_objective.py
##
print('running cross validation, with customized loss function')
print("running cross validation, with customized loss function")
def logregobj(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels
hess = preds * (1.0 - preds)
return grad, hess
def evalerror(preds, dtrain):
labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
return "error", float(sum(labels != (preds > 0.0))) / len(labels)
param = {'max_depth':2, 'eta':1}
param = {"max_depth": 2, "eta": 1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold=5, seed=0,
obj=logregobj, feval=evalerror)
xgb.cv(param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, feval=evalerror)

View File

@ -7,28 +7,37 @@ import os
import xgboost as xgb
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
param = [('max_depth', 2), ('objective', 'binary:logistic'), ('eval_metric', 'logloss'), ('eval_metric', 'error')]
param = [
("max_depth", 2),
("objective", "binary:logistic"),
("eval_metric", "logloss"),
("eval_metric", "error"),
]
num_round = 2
watchlist = [(dtest,'eval'), (dtrain,'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)
print('Access logloss metric directly from evals_result:')
print(evals_result['eval']['logloss'])
print("Access logloss metric directly from evals_result:")
print(evals_result["eval"]["logloss"])
print('')
print('Access metrics through a loop:')
print("")
print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items():
print('- {}'.format(e_name))
print("- {}".format(e_name))
for e_mtr_name, e_mtr_vals in e_mtrs.items():
print(' - {}'.format(e_mtr_name))
print(' - {}'.format(e_mtr_vals))
print(" - {}".format(e_mtr_name))
print(" - {}".format(e_mtr_vals))
print('')
print('Access complete dictionary:')
print("")
print("Access complete dictionary:")
print(evals_result)

View File

@ -11,14 +11,22 @@ import xgboost as xgb
# basically, we are using linear model, instead of tree for our boosters
##
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
# change booster to gblinear, so that we are fitting a linear model
# alpha is the L1 regularizer
# lambda is the L2 regularizer
# you can also set lambda_bias which is L2 regularizer on the bias term
param = {'objective':'binary:logistic', 'booster':'gblinear',
'alpha': 0.0001, 'lambda': 1}
param = {
"objective": "binary:logistic",
"booster": "gblinear",
"alpha": 0.0001,
"lambda": 1,
}
# normally, you do not need to set eta (step_size)
# XGBoost uses a parallel coordinate descent algorithm (shotgun),
@ -29,9 +37,15 @@ param = {'objective':'binary:logistic', 'booster':'gblinear',
##
# the rest of settings are the same
##
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
preds = bst.predict(dtest)
labels = dtest.get_label()
print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
print(
"error=%f"
% (
sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i])
/ float(len(preds))
)
)

View File

@ -16,8 +16,8 @@ test = os.path.join(CURRENT_DIR, "../data/agaricus.txt.test")
def native_interface():
# load data in do training
dtrain = xgb.DMatrix(train)
dtest = xgb.DMatrix(test)
dtrain = xgb.DMatrix(train + "?format=libsvm")
dtest = xgb.DMatrix(test + "?format=libsvm")
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3

View File

@ -8,14 +8,18 @@ import xgboost as xgb
# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
bst = xgb.train(param, dtrain, num_round, watchlist)
print('start testing predict the leaf indices')
print("start testing predict the leaf indices")
# predict using first 2 tree
leafindex = bst.predict(
dtest, iteration_range=(0, 2), pred_leaf=True, strict_shape=True

View File

@ -77,7 +77,7 @@ The external memory version takes in the following `URI <https://en.wikipedia.or
.. code-block:: none
filename#cacheprefix
filename?format=libsvm#cacheprefix
The ``filename`` is the normal path to LIBSVM format file you want to load in, and
``cacheprefix`` is a path to a cache file that XGBoost will use for caching preprocessed
@ -97,13 +97,13 @@ you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSV
.. code-block:: python
dtrain = DMatrix('../data/agaricus.txt.train#dtrain.cache')
dtrain = DMatrix('../data/agaricus.txt.train?format=libsvm#dtrain.cache')
XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to a new file named
``dtrain.cache`` as an on disk cache for storing preprocessed data in an internal binary format. For
more notes about text input formats, see :doc:`/tutorials/input_format`.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train#dtrain.cache"``.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.
**********************************

View File

@ -2,10 +2,15 @@
Text Input Format of DMatrix
############################
.. _basic_input_format:
Here we will briefly describe the text input formats for XGBoost. However, for users with access to a supported language environment like Python or R, it's recommended to use data parsers from that ecosystem instead. For instance, :py:func:`sklearn.datasets.load_svmlight_file`.
******************
Basic Input Format
******************
XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.). Please be careful that, XGBoost does **not** understand file extensions, nor try to guess the file format, as there is no universal agreement upon file extension of LIBSVM or CSV. Instead it employs `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide an URI in the form of ``train.csv?format=csv``. For external memory input, the URI should of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.
XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.). Please be careful that, XGBoost does **not** understand file extensions, nor try to guess the file format, as there is no universal agreement upon file extension of LIBSVM or CSV. Instead it employs `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide an URI in the form of ``train.csv?format=csv`` or ``train.csv?format=libsvm``. For external memory input, the URI should of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.
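For instance, a minimal sketch of the URI forms above (the file name ``train.csv`` is illustrative):

.. code-block:: python

    import xgboost as xgb

    # The CSV parser must be selected explicitly; the file extension alone is ignored.
    dtrain = xgb.DMatrix("train.csv?format=csv")

    # For external memory input, append the cache prefix after the format parameter.
    dtrain = xgb.DMatrix("train.csv?format=csv#dtrain.cache")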
For training or predicting, XGBoost takes an instance file with the format as below:

View File

@ -138,7 +138,11 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
/*!
* \brief load a data matrix
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - uri: The URI of the input file.
* - uri: The URI of the input file. The URI parameter `format` is required when loading text data.
* \verbatim embed:rst:leading-asterisk
* See :doc:`/tutorials/input_format` for more info.
* \endverbatim
* - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator on how the file was split

View File

@ -566,21 +566,17 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}
/*!
/**
* \brief Load DMatrix from URI.
*
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
* it's just an indicator on how the input was split beforehand.
* \param file_format The format type of the file, used for dmlc::Parser::Create.
* By default "auto" will be able to load in both local binary file.
* \param page_size Page size for external memory.
* \return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri,
bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow,
const std::string& file_format = "auto");
static DMatrix* Load(const std::string& uri, bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Creates a new DMatrix from an external data adapter.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -62,8 +62,8 @@ public class BasicWalkThrough {
public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
@ -112,7 +112,8 @@ public class BasicWalkThrough {
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DataLoader.CSRSparseData spData =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR, 127);

View File

@ -32,8 +32,8 @@ public class BoostFromPrediction {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//set params
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -139,9 +139,9 @@ public class CustomObjective {
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//load valid mat (svmlight format)
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);

View File

@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
public class EarlyStopping {
public static void main(String[] args) throws IOException, XGBoostError {
DataLoader.CSRSparseData trainCSR =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DataLoader.CSRSparseData testCSR =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");
Map<String, Object> paramMap = new HashMap<String, Object>() {
{

View File

@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model

View File

@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -36,8 +36,8 @@ object BasicWalkThrough {
}
def main(args: Array[String]): Unit = {
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
@ -76,7 +76,7 @@ object BasicWalkThrough {
// build dmatrix from CSR Sparse Matrix
println("start build dmatrix from csr sparse data ...")
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
JDMatrix.SparseType.CSR)
trainMax2.setLabel(spData.labels)

View File

@ -24,8 +24,8 @@ object BoostFromPrediction {
def main(args: Array[String]): Unit = {
println("start running example to start from a initial prediction")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object CrossValidation {
def main(args: Array[String]): Unit = {
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
// set params
val params = new mutable.HashMap[String, Any]

View File

@ -138,8 +138,8 @@ object CustomObjective {
}
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2

View File

@ -25,8 +25,8 @@ object ExternalMemory {
// this is the only difference, add a # followed by a cache prefix name
// several cache file with the prefix will be generated
// currently only support convert from libsvm file
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
*/
object GeneralizedLinearModel {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// specify parameters
// change booster to gblinear, so that we are fitting a linear model

View File

@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictFirstNTree {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictLeafIndices {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -30,8 +30,8 @@ import org.junit.Test;
* @author hzx
*/
public class BoosterImplTest {
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1";
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1";
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm";
public static class EvalError implements IEvaluation {
@Override

View File

@ -88,7 +88,7 @@ public class DMatrixTest {
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
DMatrix dmat = new DMatrix(filePath);
DMatrix dmat = new DMatrix(filePath + "?format=libsvm");
//get label
float[] labels = dmat.getLabel();
//check length

View File

@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
class DMatrixSuite extends AnyFunSuite {
test("create DMatrix from File") {
val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
val dmat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// get label
val labels: Array[Float] = dmat.getLabel
// check length

View File

@ -95,8 +95,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("basic operation of booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val booster = trainBooster(trainMat, testMat)
val predicts = booster.predict(testMat, true)
@ -106,8 +106,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
test("save/load model with path") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
@ -123,8 +123,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("save/load model with stream") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("cross validation") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
@ -148,8 +148,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test with quantile histo depthwise") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap
@ -158,8 +158,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test with quantile histo lossguide") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
@ -168,8 +168,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test with quantile histo lossguide with max bin") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
@ -179,8 +179,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test with quantile histo depthwidth with max depth") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
@ -190,8 +190,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test with quantile histo depthwidth with max depth and max bin") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@ -201,7 +201,7 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test training from existing model in scala") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@ -213,8 +213,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
}
test("test getting number of features from a booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val booster = trainBooster(trainMat, testMat)
TestCase.assertEquals(booster.getNumFeature, 127)

View File

@ -882,5 +882,12 @@ def data_dir(path: str) -> str:
return os.path.join(demo_dir(path), "data")
def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
dpath = data_dir(path)
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train?format=libsvm"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test?format=libsvm"))
return dtrain, dtest
def project_root(path: str) -> str:
return normpath(os.path.join(demo_dir(path), os.path.pardir))

View File

@ -819,8 +819,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
return nullptr;
}
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) {
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
auto need_split = false;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
@ -862,12 +861,10 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
}
// legacy handling of binary data loading
if (file_format == "auto") {
DMatrix* loaded = TryLoadBinary(fname, silent);
if (loaded) {
return loaded;
}
}
int partid = 0, npart = 1;
if (need_split && data_split_mode == DataSplitMode::kRow) {
@ -882,17 +879,17 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
}
data::ValidateFileFormat(fname);
DMatrix* dmat {nullptr};
try {
if (cache_file.empty()) {
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
data::FileAdapter adapter(parser.get());
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file, data_split_mode);
} else {
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
file_format};
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
dmat = new data::SparsePageDMatrix{&iter,
iter.Proxy(),
data::fileiter::Reset,
@ -901,29 +898,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
1,
cache_file};
}
} catch (dmlc::Error& e) {
std::vector<std::string> splited = common::Split(fname, '#');
std::vector<std::string> args = common::Split(splited.front(), '?');
std::string format {file_format};
if (args.size() == 1 && file_format == "auto") {
auto extension = common::Split(args.front(), '.').back();
if (extension == "csv" || extension == "libsvm") {
format = extension;
}
if (format == extension) {
LOG(WARNING)
<< "No format parameter is provided in input uri, but found file extension: "
<< format << " . "
<< "Consider providing a uri parameter: filename?format=" << format;
} else {
LOG(WARNING)
<< "No format parameter is provided in input uri. "
<< "Choosing default parser in dmlc-core. "
<< "Consider providing a uri parameter like: filename?format=csv";
}
}
LOG(FATAL) << "Encountered parser error:\n" << e.what();
}
if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {

View File

@ -1,22 +1,50 @@
/*!
* Copyright 2021 XGBoost contributors
/**
* Copyright 2021-2023, XGBoost contributors
*/
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_
#include <string>
#include <map>
#include <memory>
#include <vector>
#include <string>
#include <utility>
#include <vector>
#include "array_interface.h"
#include "dmlc/data.h"
#include "xgboost/c_api.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "array_interface.h"
namespace xgboost {
namespace data {
inline void ValidateFileFormat(std::string const& uri) {
std::vector<std::string> name_cache = common::Split(uri, '#');
CHECK_LE(name_cache.size(), 2)
<< "Only one `#` is allowed in file path for cachefile specification";
std::vector<std::string> name_args = common::Split(name_cache[0], '?');
CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path.";
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
CHECK_EQ(name_args.size(), 2) << msg;
std::map<std::string, std::string> args;
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
<< " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
<< " for value in arg " << i + 1;
args.insert(kv);
}
if (args.find("format") == args.cend()) {
LOG(FATAL) << msg;
}
}
/**
* An iterator for implementing external memory support with file inputs. Users of
* external memory are encouraged to define their own file parsers/loaders so this one is
@ -31,8 +59,6 @@ class FileIterator {
uint32_t part_idx_;
// Equals to total number of workers.
uint32_t n_parts_;
// Format of the input file, like "libsvm".
std::string type_;
DMatrixHandle proxy_;
@ -45,10 +71,9 @@ class FileIterator {
std::string indices_;
public:
FileIterator(std::string uri, unsigned part_index, unsigned num_parts,
std::string type)
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts},
type_{std::move(type)} {
FileIterator(std::string uri, unsigned part_index, unsigned num_parts)
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} {
ValidateFileFormat(uri_);
XGProxyDMatrixCreate(&proxy_);
}
~FileIterator() {
@ -94,9 +119,7 @@ class FileIterator {
auto Proxy() -> decltype(proxy_) { return proxy_; }
void Reset() {
CHECK(!type_.empty());
parser_.reset(dmlc::Parser<uint32_t>::Create(uri_.c_str(), part_idx_,
n_parts_, type_.c_str()));
parser_.reset(dmlc::Parser<uint32_t>::Create(uri_.c_str(), part_idx_, n_parts_, "auto"));
}
};

View File

@ -88,7 +88,8 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
fo << row_data.str() << "\n";
}
fo.close();
return std::shared_ptr<DMatrix>(DMatrix::Load(tmp_file + "#" + tmp_file + ".cache"));
return std::shared_ptr<DMatrix>(
DMatrix::Load(tmp_file + "?format=libsvm" + "#" + tmp_file + ".cache"));
}
// Test that elements are approximately equally distributed among bins

View File

@ -29,16 +29,16 @@ TEST(FileIterator, Basic) {
{
auto zpath = tmpdir.path + "/0-based.svm";
CreateBigTestData(zpath, 3 * 64, true);
zpath += "?indexing_mode=0";
FileIterator iter{zpath, 0, 1, "libsvm"};
zpath += "?indexing_mode=0&format=libsvm";
FileIterator iter{zpath, 0, 1};
check_n_features(&iter);
}
{
auto opath = tmpdir.path + "/1-based.svm";
CreateBigTestData(opath, 3 * 64, false);
opath += "?indexing_mode=1";
FileIterator iter{opath, 0, 1, "libsvm"};
opath += "?indexing_mode=1&format=libsvm";
FileIterator iter{opath, 0, 1};
check_n_features(&iter);
}
}

View File

@ -157,8 +157,7 @@ TEST(MetaInfo, LoadQid) {
dmlc::TemporaryDirectory tempdir;
std::string tmp_file = tempdir.path + "/qid_test.libsvm";
{
std::unique_ptr<dmlc::Stream> fs(
dmlc::Stream::Create(tmp_file.c_str(), "w"));
std::unique_ptr<dmlc::Stream> fs(dmlc::Stream::Create(tmp_file.c_str(), "w"));
dmlc::ostream os(fs.get());
os << R"qid(3 qid:1 1:1 2:1 3:0 4:0.2 5:0
2 qid:1 1:0 2:0 3:1 4:0.1 5:1
@ -175,7 +174,7 @@ TEST(MetaInfo, LoadQid) {
os.set_stream(nullptr);
}
std::unique_ptr<xgboost::DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kRow, "libsvm"));
xgboost::DMatrix::Load(tmp_file + "?format=libsvm", true, xgboost::DataSplitMode::kRow));
const xgboost::MetaInfo& info = dmat->Info();
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};

View File

@ -17,11 +17,15 @@
using namespace xgboost; // NOLINT
namespace {
std::string UriSVM(std::string name) { return name + "?format=libsvm"; }
} // namespace
TEST(SimpleDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file));
// Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2);
@ -37,7 +41,7 @@ TEST(SimpleDMatrix, RowAccess) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file), false);
// Loop over the batches and count the records
int64_t row_count = 0;
@ -60,7 +64,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file));
ASSERT_TRUE(dmat->SingleColBlock());
@ -387,7 +391,7 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(UriSVM(tmp_file));
data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat);
const std::string tmp_binfile = tempdir.path + "/csr_source.binary";

View File

@ -16,14 +16,19 @@
#include "../helpers.h"
using namespace xgboost; // NOLINT
namespace {
std::string UriSVM(std::string name, std::string cache) {
return name + "?format=libsvm" + "#" + cache + ".cache";
}
} // namespace
template <typename Page>
void TestSparseDMatrixLoadFile() {
dmlc::TemporaryDirectory tmpdir;
auto opath = tmpdir.path + "/1-based.svm";
CreateBigTestData(opath, 3 * 64, false);
opath += "?indexing_mode=1";
data::FileIterator iter{opath, 0, 1, "libsvm"};
opath += "?indexing_mode=1&format=libsvm";
data::FileIterator iter{opath, 0, 1};
auto n_threads = 0;
data::SparsePageDMatrix m{&iter,
iter.Proxy(),
@ -112,15 +117,13 @@ TEST(SparsePageDMatrix, MetaInfo) {
size_t constexpr kEntries = 24;
CreateBigTestData(tmp_file, kEntries);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", false);
std::unique_ptr<DMatrix> dmat{xgboost::DMatrix::Load(UriSVM(tmp_file, tmp_file), false)};
// Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 8ul);
EXPECT_EQ(dmat->Info().num_col_, 5ul);
EXPECT_EQ(dmat->Info().num_nonzero_, kEntries);
EXPECT_EQ(dmat->Info().labels.Size(), dmat->Info().num_row_);
delete dmat;
}
TEST(SparsePageDMatrix, RowAccess) {
@ -139,7 +142,7 @@ TEST(SparsePageDMatrix, ColAccess) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache");
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file, tmp_file));
// Loop over the batches and assert the data is as expected
size_t iter = 0;
@ -231,7 +234,7 @@ auto TestSparsePageDMatrixDeterminism(int32_t threads) {
std::string filename = tempdir.path + "/simple.libsvm";
CreateBigTestData(filename, 1 << 16);
data::FileIterator iter(filename, 0, 1, "auto");
data::FileIterator iter(filename + "?format=libsvm", 0, 1);
std::unique_ptr<DMatrix> sparse{
new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next,
std::numeric_limits<float>::quiet_NaN(), threads, filename}};

View File

@ -13,7 +13,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache");
DMatrix* dmat = DMatrix::Load(tmp_file + "?format=libsvm" + "#" + tmp_file + ".cache");
// Loop over the batches and assert the data is as expected
size_t n = 0;

View File

@ -548,7 +548,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
}
fo.close();
std::string uri = tmp_file;
std::string uri = tmp_file + "?format=libsvm";
if (page_size > 0) {
uri += "#" + tmp_file + ".cache";
}

View File

@ -126,7 +126,8 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm";
CreateBigTestData(tmp_file, 50000);
std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache"));
std::shared_ptr<DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file + "?format=libsvm" + "#" + tmp_file + ".cache"));
EXPECT_FALSE(dmat->SingleColBlock());
size_t num_row = dmat->Info().num_row_;
std::vector<bst_float> labels(num_row);

View File

@ -21,8 +21,7 @@ class TestBasic:
assert not lazy_isinstance(a, 'numpy', 'dataframe')
def test_basic(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
dtrain, dtest = tm.load_agaricus(__file__)
param = {'max_depth': 2, 'eta': 1,
'objective': 'binary:logistic'}
# specify validations set to watch performance
@ -61,8 +60,7 @@ class TestBasic:
def test_metric_config(self):
# Make sure that the metric configuration happens in booster so the
# string `['error', 'auc']` doesn't get passed down to core.
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
dtrain, dtest = tm.load_agaricus(__file__)
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic', 'eval_metric': ['error', 'auc']}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@ -78,8 +76,7 @@ class TestBasic:
np.testing.assert_allclose(predt_0, predt_1)
def test_multiclass(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
dtrain, dtest = tm.load_agaricus(__file__)
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'num_class': 2}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@ -188,7 +185,7 @@ class TestBasic:
assert dm.num_col() == cols
def test_cv(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
dm, _ = tm.load_agaricus(__file__)
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic'}
@ -198,7 +195,7 @@ class TestBasic:
assert len(cv) == (4)
def test_cv_no_shuffle(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
dm, _ = tm.load_agaricus(__file__)
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic'}
@ -209,7 +206,7 @@ class TestBasic:
assert len(cv) == (4)
def test_cv_explicit_fold_indices(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
dm, _ = tm.load_agaricus(__file__)
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective':
'binary:logistic'}
folds = [
@ -268,8 +265,7 @@ class TestBasicPathLike:
def test_DMatrix_init_from_path(self):
"""Initialization from the data path."""
dpath = Path('demo/data')
dtrain = xgb.DMatrix(dpath / 'agaricus.txt.train')
dtrain, _ = tm.load_agaricus(__file__)
assert dtrain.num_row() == 6513
assert dtrain.num_col() == 127

View File

@ -42,8 +42,7 @@ class TestModels:
param = {'verbosity': 0, 'objective': 'binary:logistic',
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
'nthread': 1}
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
@ -55,8 +54,7 @@ class TestModels:
assert err < 0.2
def test_dart(self):
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
dtrain, dtest = tm.load_agaricus(__file__)
param = {'max_depth': 5, 'objective': 'binary:logistic',
'eval_metric': 'logloss', 'booster': 'dart', 'verbosity': 1}
# specify validations set to watch performance
@ -122,7 +120,7 @@ class TestModels:
def test_boost_from_prediction(self):
# Re-construct dtrain here to avoid modification
margined = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
margined, _ = tm.load_agaricus(__file__)
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
predt_0 = bst.predict(margined, output_margin=True)
margined.set_base_margin(predt_0)
@ -130,13 +128,13 @@ class TestModels:
predt_1 = bst.predict(margined)
assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtrain, _ = tm.load_agaricus(__file__)
bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
def test_boost_from_existing_model(self):
X = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
X, _ = tm.load_agaricus(__file__)
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4)
assert booster.num_boosted_rounds() == 4
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4,
@ -156,8 +154,7 @@ class TestModels:
'objective': 'reg:logistic',
"tree_method": tree_method
}
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 10
@ -203,8 +200,7 @@ class TestModels:
self.run_custom_objective()
def test_multi_eval_metric(self):
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
param = {'max_depth': 2, 'eta': 0.2, 'verbosity': 1,
'objective': 'binary:logistic'}
@ -226,7 +222,7 @@ class TestModels:
param['scale_pos_weight'] = ratio
return (dtrain, dtest, param)
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtrain, _ = tm.load_agaricus(__file__)
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'auc'}, seed=0, fpreproc=fpreproc)
@ -234,7 +230,7 @@ class TestModels:
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic'}
num_round = 2
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtrain, _ = tm.load_agaricus(__file__)
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed=0, show_stdv=False)
@ -392,7 +388,7 @@ class TestModels:
os.remove(model_path)
try:
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtrain, _ = tm.load_agaricus(__file__)
xgb.train({'objective': 'foo'}, dtrain, num_boost_round=1)
except ValueError as e:
e_str = str(e)

View File

@ -275,9 +275,7 @@ class TestCallbacks:
"""Test learning rate scheduler, used by both CPU and GPU tests."""
scheduler = xgb.callback.LearningRateScheduler
dpath = tm.data_dir(__file__)
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4
@ -361,9 +359,7 @@ class TestCallbacks:
num_round = 4
scheduler = xgb.callback.LearningRateScheduler
dpath = tm.data_dir(__file__)
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
param = {

View File

@ -283,7 +283,7 @@ class TestDMatrix:
assert m0.feature_types == m1.feature_types
def test_get_info(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtrain, _ = tm.load_agaricus(__file__)
dtrain.get_float_info('label')
dtrain.get_float_info('weight')
dtrain.get_float_info('base_margin')
@ -432,7 +432,9 @@ class TestDMatrix:
def test_uri_categorical(self):
path = os.path.join(dpath, 'agaricus.txt.train')
feature_types = ["q"] * 5 + ["c"] + ["q"] * 120
Xy = xgb.DMatrix(path + "?indexing_mode=1", feature_types=feature_types)
Xy = xgb.DMatrix(
path + "?indexing_mode=1&format=libsvm", feature_types=feature_types
)
np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
def test_base_margin(self):

View File

@ -88,8 +88,12 @@ class TestInteractionConstraints:
def training_accuracy(self, tree_method):
"""Test accuracy, reused by GPU tests."""
from sklearn.metrics import accuracy_score
dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
dtrain = xgboost.DMatrix(
dpath + "agaricus.txt.train?indexing_mode=1&format=libsvm"
)
dtest = xgboost.DMatrix(
dpath + "agaricus.txt.test?indexing_mode=1&format=libsvm"
)
params = {
'eta': 1,
'max_depth': 6,

View File

@ -134,8 +134,8 @@ class TestMonotoneConstraints:
@pytest.mark.skipif(**tm.no_sklearn())
def test_training_accuracy(self):
from sklearn.metrics import accuracy_score
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
dtrain = xgb.DMatrix(dpath + "agaricus.txt.train?indexing_mode=1&format=libsvm")
dtest = xgb.DMatrix(dpath + "agaricus.txt.test?indexing_mode=1&format=libsvm")
params = {'eta': 1, 'max_depth': 6, 'objective': 'binary:logistic',
'tree_method': 'hist', 'monotone_constraints': '(1, 0)'}
num_boost_round = 5

View File

@ -13,9 +13,7 @@ pytestmark = tm.timeout(10)
class TestOMP:
def test_omp(self):
dpath = 'demo/data/'
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
dtrain, dtest = tm.load_agaricus(__file__)
param = {'booster': 'gbtree',
'objective': 'binary:logistic',

View File

@ -13,7 +13,7 @@ rng = np.random.RandomState(1994)
class TestTreesToDataFrame:
def build_model(self, max_depth, num_round):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtrain, _ = tm.load_agaricus(__file__)
param = {'max_depth': max_depth, 'objective': 'binary:logistic',
'verbosity': 1}
num_round = num_round

View File

@ -17,12 +17,10 @@ except ImportError:
pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(),
tm.no_graphviz()))
dpath = 'demo/data/agaricus.txt.train'
class TestPlotting:
def test_plotting(self):
m = xgb.DMatrix(dpath)
m, _ = tm.load_agaricus(__file__)
booster = xgb.train({'max_depth': 2, 'eta': 1,
'objective': 'binary:logistic'}, m,
num_boost_round=2)

View File

@ -46,8 +46,8 @@ class TestSHAP:
fscores = bst.get_fscore()
assert scores1 == fscores
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train?format=libsvm')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test?format=libsvm')
def fn(max_depth, num_rounds):
# train

View File

@ -154,9 +154,7 @@ class TestTreeMethod:
def test_hist_categorical(self):
# hist must be same as exact on all-categorial data
dpath = 'demo/data/'
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
ag_dtrain, ag_dtest = tm.load_agaricus(__file__)
ag_param = {'max_depth': 2,
'tree_method': 'hist',
'eta': 1,

View File

@ -222,7 +222,7 @@ class TestPandas:
set_base_margin_info(pd.DataFrame, xgb.DMatrix, "hist")
def test_cv_as_pandas(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
dm, _ = tm.load_agaricus(__file__)
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic', 'eval_metric': 'error'}