Fix data loading (#4862)

* Fix loading text data.
* Fix config regex.
* Try to explain the error better in exception.
* Update doc.
This commit is contained in:
Jiaming Yuan 2019-10-22 12:33:14 -04:00 committed by GitHub
parent 95295ce026
commit 7e477a2adb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 81 additions and 8 deletions

View File

@ -28,5 +28,3 @@ Examples
* We are super excited to hear about your story, if you have blogposts, * We are super excited to hear about your story, if you have blogposts,
tutorials code solutions using XGBoost, please tell us and we will add tutorials code solutions using XGBoost, please tell us and we will add
a link in the example pages. a link in the example pages.

View File

@ -18,6 +18,8 @@ To verify your installation, run the following in Python:
import xgboost as xgb import xgboost as xgb
.. _python_data_interface:
Data Interface Data Interface
-------------- --------------
The XGBoost python module is able to load data from: The XGBoost python module is able to load data from:
@ -50,7 +52,7 @@ The data is stored in a :py:class:`DMatrix <xgboost.DMatrix>` object.
.. note:: Categorical features not supported .. note:: Categorical features not supported
Note that XGBoost does not support categorical features; if your data contains Note that XGBoost does not provide specialization for categorical features; if your data contains
categorical features, load it as a NumPy array first and then perform categorical features, load it as a NumPy array first and then perform
`one-hot encoding <http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html>`_. `one-hot encoding <http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html>`_.

View File

@ -7,6 +7,9 @@ Basic Input Format
****************** ******************
XGBoost currently supports two text formats for ingesting data: LibSVM and CSV. The rest of this document will describe the LibSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.) XGBoost currently supports two text formats for ingesting data: LibSVM and CSV. The rest of this document will describe the LibSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.)
.. note::
* XGBoost does **not** understand file extensions, nor does it try to guess the file format. Instead, it employs a URI format for specifying the input file type. For example, if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will use the default libsvm parser to digest it and generate a parser error. Instead, users need to provide a URI in the form of ``train.csv?format=csv``. For external memory input, the URI should be of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` also.
For training or predicting, XGBoost takes an instance file with the format as below: For training or predicting, XGBoost takes an instance file with the format as below:
.. code-block:: none .. code-block:: none

View File

@ -34,8 +34,8 @@ class ConfigParser {
line_comment_regex_("^#"), line_comment_regex_("^#"),
key_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*=)rx"), key_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*=)rx"),
key_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*=)rx"), key_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*=)rx"),
value_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"), value_regex_(R"rx(^([^#"'\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"),
value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx") value_regex_escaped_(R"rx(^(["'])([^"'\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx")
{} {}
std::string LoadConfigFile(const std::string& path) { std::string LoadConfigFile(const std::string& path) {

View File

@ -222,7 +222,35 @@ DMatrix* DMatrix::Load(const std::string& uri,
std::unique_ptr<dmlc::Parser<uint32_t> > parser( std::unique_ptr<dmlc::Parser<uint32_t> > parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str())); dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file, page_size); DMatrix* dmat;
try {
dmat = DMatrix::Create(parser.get(), cache_file, page_size);
} catch (dmlc::Error& e) {
std::vector<std::string> splited = common::Split(fname, '#');
std::vector<std::string> args = common::Split(splited.front(), '?');
std::string format {file_format};
if (args.size() == 1 && file_format == "auto") {
auto extension = common::Split(args.front(), '.').back();
if (extension == "csv" || extension == "libsvm") {
format = extension;
}
if (format == extension) {
LOG(WARNING)
<< "No format parameter is provided in input uri, but found file extension: "
<< format << " . "
<< "Consider providing a uri parameter: filename?format=" << format;
} else {
LOG(WARNING)
<< "No format parameter is provided in input uri. "
<< "Choosing default parser in dmlc-core. "
<< "Consider providing a uri parameter like: filename?format=csv";
}
LOG(FATAL) << "Encountered parser error:\n" << e.what();
}
}
if (!silent) { if (!silent) {
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with " LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
<< dmat->Info().num_nonzero_ << " entries loaded from " << uri; << dmat->Info().num_nonzero_ << " entries loaded from " << uri;

View File

@ -9,5 +9,5 @@ max_depth = 3
num_round = 2 num_round = 2
save_period = 0 save_period = 0
data = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.train" data = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.train?format=libsvm"
eval[test] = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.test" eval[test] = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.test?format=libsvm"

View File

@ -1,5 +1,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <fstream>
#include <memory>
#include <vector> #include <vector>
#include "xgboost/data.h" #include "xgboost/data.h"
@ -81,4 +83,44 @@ TEST(SparsePage, PushCSCAfterTranspose) {
} }
} }
} }
// Writes a small dense CSV file into a temporary directory and checks that
// loading it through a uri carrying an explicit `?format=csv` parameter yields
// a DMatrix with the expected number of rows and columns.
TEST(DMatrix, Uri) {
  size_t constexpr kRows {16};
  size_t constexpr kCols {8};
  size_t constexpr kSize {kRows * kCols};

  // Cell k holds the value k; cells are laid out row by row.
  std::vector<float> values (kSize);
  for (size_t k = 0; k < kSize; ++k) {
    values[k] = k;
  }

  dmlc::TemporaryDirectory tmpdir;
  std::string const csv_path = tmpdir.path + "/small.csv";
  std::ofstream csv(csv_path);
  ASSERT_TRUE(csv);

  // One text line per row, comma separated, no trailing comma.
  for (size_t row = 0; row < kRows; ++row) {
    for (size_t col = 0; col < kCols; ++col) {
      if (col != 0) {
        csv << ",";
      }
      csv << values[row * kCols + col];
    }
    csv << "\n";
  }
  // close() writes out any buffered output, so the file is complete on disk
  // before it is handed to the loader.
  csv.close();

  std::unique_ptr<DMatrix> loaded;
  // FIXME(trivialfis): Enable the following test by restricting csv parser in dmlc-core.
  // EXPECT_THROW(loaded.reset(DMatrix::Load(csv_path, false, true)), dmlc::Error);

  std::string const csv_uri = csv_path + "?format=csv";
  loaded.reset(DMatrix::Load(csv_uri, false, true));

  ASSERT_EQ(loaded->Info().num_col_, kCols);
  ASSERT_EQ(loaded->Info().num_row_, kRows);
}
} // namespace xgboost } // namespace xgboost