From 7e477a2adbd418277ec2b09d30c48b1c02ab521f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 22 Oct 2019 12:33:14 -0400 Subject: [PATCH] Fix data loading (#4862) * Fix loading text data. * Fix config regex. * Try to explain the error better in exception. * Update doc. --- doc/contrib/docs.rst | 2 -- doc/python/python_intro.rst | 4 +++- doc/tutorials/input_format.rst | 3 +++ src/common/config.h | 4 ++-- src/data/data.cc | 30 +++++++++++++++++++++++- tests/cli/machine.conf.in | 4 ++-- tests/cpp/data/test_data.cc | 42 ++++++++++++++++++++++++++++++++++ 7 files changed, 81 insertions(+), 8 deletions(-) diff --git a/doc/contrib/docs.rst b/doc/contrib/docs.rst index 449bcabaa..4990f0e0b 100644 --- a/doc/contrib/docs.rst +++ b/doc/contrib/docs.rst @@ -28,5 +28,3 @@ Examples * We are super excited to hear about your story, if you have blogposts, tutorials code solutions using XGBoost, please tell us and we will add a link in the example pages. - - diff --git a/doc/python/python_intro.rst b/doc/python/python_intro.rst index 9f4d8efd8..3cd2a00e2 100644 --- a/doc/python/python_intro.rst +++ b/doc/python/python_intro.rst @@ -18,6 +18,8 @@ To verify your installation, run the following in Python: import xgboost as xgb +.. _python_data_interface: + Data Interface -------------- The XGBoost python module is able to load data from: @@ -50,7 +52,7 @@ The data is stored in a :py:class:`DMatrix ` object. .. note:: Categorical features not supported - Note that XGBoost does not support categorical features; if your data contains + Note that XGBoost does not provide specialization for categorical features; if your data contains categorical features, load it as a NumPy array first and then perform `one-hot encoding `_. 
diff --git a/doc/tutorials/input_format.rst b/doc/tutorials/input_format.rst index 2ec4b8e59..f844e09a4 100644 --- a/doc/tutorials/input_format.rst +++ b/doc/tutorials/input_format.rst @@ -7,6 +7,9 @@ Basic Input Format ****************** XGBoost currently supports two text formats for ingesting data: LibSVM and CSV. The rest of this document will describe the LibSVM format. (See `this Wikipedia article `_ for a description of the CSV format.) +.. note:: + * XGBoost does **not** understand file extensions, nor does it try to guess the file format. Instead, it relies on the URI format to specify the input file type. For example, if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will use the default libsvm parser to digest it and generate a parser error. Instead, users need to provide a URI in the form of ``train.csv?format=csv``. For external memory input, the URI should be of a form similar to ``train.csv?format=csv#dtrain.cache``. See also :ref:`python_data_interface`. + For training or predicting, XGBoost takes an instance file with the format as below: .. 
code-block:: none diff --git a/src/common/config.h b/src/common/config.h index a85fee609..2efe35a4f 100644 --- a/src/common/config.h +++ b/src/common/config.h @@ -34,8 +34,8 @@ class ConfigParser { line_comment_regex_("^#"), key_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*=)rx"), key_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*=)rx"), - value_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"), - value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx") + value_regex_(R"rx(^([^#"'\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"), + value_regex_escaped_(R"rx(^(["'])([^"'\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx") {} std::string LoadConfigFile(const std::string& path) { diff --git a/src/data/data.cc b/src/data/data.cc index 9f3088531..026285150 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -222,7 +222,35 @@ DMatrix* DMatrix::Load(const std::string& uri, std::unique_ptr > parser( dmlc::Parser::Create(fname.c_str(), partid, npart, file_format.c_str())); - DMatrix* dmat = DMatrix::Create(parser.get(), cache_file, page_size); + DMatrix* dmat {nullptr}; + + try { + dmat = DMatrix::Create(parser.get(), cache_file, page_size); + } catch (const dmlc::Error& e) { + std::vector splited = common::Split(fname, '#'); + std::vector args = common::Split(splited.front(), '?'); + std::string format {file_format}; + if (args.size() == 1 && file_format == "auto") { + auto extension = common::Split(args.front(), '.').back(); + if (extension == "csv" || extension == "libsvm") { + format = extension; + } + if (format == extension) { + LOG(WARNING) + << "No format parameter is provided in input uri, but found file extension: " + << format << " . " + << "Consider providing a uri parameter: filename?format=" << format; + } else { + LOG(WARNING) + << "No format parameter is provided in input uri. " + << "Choosing default parser in dmlc-core. 
" + << "Consider providing a uri parameter like: filename?format=csv"; + } + } + // Always surface the parser failure; the warnings above are only hints. + LOG(FATAL) << "Encountered parser error:\n" << e.what(); + } + + if (!silent) { LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with " << dmat->Info().num_nonzero_ << " entries loaded from " << uri; diff --git a/tests/cli/machine.conf.in b/tests/cli/machine.conf.in index 0f12bfddf..e9575261a 100644 --- a/tests/cli/machine.conf.in +++ b/tests/cli/machine.conf.in @@ -9,5 +9,5 @@ max_depth = 3 num_round = 2 save_period = 0 -data = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.train" -eval[test] = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.test" \ No newline at end of file +data = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.train?format=libsvm" +eval[test] = "@PROJECT_SOURCE_DIR@/demo/data/agaricus.txt.test?format=libsvm" \ No newline at end of file diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 2f9926857..3b5aa9faa 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include "xgboost/data.h" @@ -81,4 +83,44 @@ TEST(SparsePage, PushCSCAfterTranspose) { } } } + +TEST(DMatrix, Uri) { + size_t constexpr kRows {16}; + size_t constexpr kCols {8}; + std::vector data (kRows * kCols); + + for (size_t i = 0; i < kRows * kCols; ++i) { + data[i] = i; + } + + dmlc::TemporaryDirectory tmpdir; + std::string path = tmpdir.path + "/small.csv"; + + std::ofstream fout(path); + ASSERT_TRUE(fout); + size_t i = 0; + for (size_t r = 0; r < kRows; ++r) { + for (size_t c = 0; c < kCols; ++c) { + fout << data[i]; + i++; + if (c != kCols - 1) { + fout << ","; + } + } + fout << "\n"; + } + fout.flush(); + fout.close(); + + std::unique_ptr dmat; + // FIXME(trivialfis): Enable the following test by restricting csv parser in dmlc-core. 
+ // EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error); + + std::string uri = path + "?format=csv"; + dmat.reset(DMatrix::Load(uri, false, true)); + + ASSERT_EQ(dmat->Info().num_col_, kCols); + ASSERT_EQ(dmat->Info().num_row_, kRows); +} + } // namespace xgboost