[Breaking] Require format to be specified in input URI. (#9077)
Previously, we use `libsvm` as default when format is not specified. However, the dmlc data parser is not particularly robust against errors, and the most common type of error is undefined format. Along with which, we will recommend users to use other data loader instead. We will continue the maintenance of the parsers as it's currently used for many internal tests including federated learning.
This commit is contained in:
@@ -819,8 +819,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
|
||||
const std::string& file_format) {
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
|
||||
auto need_split = false;
|
||||
if (collective::IsFederated()) {
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
@@ -862,11 +861,9 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
}
|
||||
|
||||
// legacy handling of binary data loading
|
||||
if (file_format == "auto") {
|
||||
DMatrix* loaded = TryLoadBinary(fname, silent);
|
||||
if (loaded) {
|
||||
return loaded;
|
||||
}
|
||||
DMatrix* loaded = TryLoadBinary(fname, silent);
|
||||
if (loaded) {
|
||||
return loaded;
|
||||
}
|
||||
|
||||
int partid = 0, npart = 1;
|
||||
@@ -882,47 +879,24 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
|
||||
}
|
||||
|
||||
data::ValidateFileFormat(fname);
|
||||
DMatrix* dmat {nullptr};
|
||||
try {
|
||||
if (cache_file.empty()) {
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
|
||||
cache_file, data_split_mode);
|
||||
} else {
|
||||
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
|
||||
file_format};
|
||||
dmat = new data::SparsePageDMatrix{&iter,
|
||||
iter.Proxy(),
|
||||
data::fileiter::Reset,
|
||||
data::fileiter::Next,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
1,
|
||||
cache_file};
|
||||
}
|
||||
} catch (dmlc::Error& e) {
|
||||
std::vector<std::string> splited = common::Split(fname, '#');
|
||||
std::vector<std::string> args = common::Split(splited.front(), '?');
|
||||
std::string format {file_format};
|
||||
if (args.size() == 1 && file_format == "auto") {
|
||||
auto extension = common::Split(args.front(), '.').back();
|
||||
if (extension == "csv" || extension == "libsvm") {
|
||||
format = extension;
|
||||
}
|
||||
if (format == extension) {
|
||||
LOG(WARNING)
|
||||
<< "No format parameter is provided in input uri, but found file extension: "
|
||||
<< format << " . "
|
||||
<< "Consider providing a uri parameter: filename?format=" << format;
|
||||
} else {
|
||||
LOG(WARNING)
|
||||
<< "No format parameter is provided in input uri. "
|
||||
<< "Choosing default parser in dmlc-core. "
|
||||
<< "Consider providing a uri parameter like: filename?format=csv";
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Encountered parser error:\n" << e.what();
|
||||
|
||||
if (cache_file.empty()) {
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
|
||||
cache_file, data_split_mode);
|
||||
} else {
|
||||
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
|
||||
dmat = new data::SparsePageDMatrix{&iter,
|
||||
iter.Proxy(),
|
||||
data::fileiter::Reset,
|
||||
data::fileiter::Next,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
1,
|
||||
cache_file};
|
||||
}
|
||||
|
||||
if (need_split && data_split_mode == DataSplitMode::kCol) {
|
||||
|
||||
@@ -1,22 +1,50 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
#define XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "array_interface.h"
|
||||
#include "dmlc/data.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "array_interface.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
inline void ValidateFileFormat(std::string const& uri) {
|
||||
std::vector<std::string> name_cache = common::Split(uri, '#');
|
||||
CHECK_LE(name_cache.size(), 2)
|
||||
<< "Only one `#` is allowed in file path for cachefile specification";
|
||||
|
||||
std::vector<std::string> name_args = common::Split(name_cache[0], '?');
|
||||
CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path.";
|
||||
|
||||
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
|
||||
CHECK_EQ(name_args.size(), 2) << msg;
|
||||
|
||||
std::map<std::string, std::string> args;
|
||||
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
|
||||
for (size_t i = 0; i < arg_list.size(); ++i) {
|
||||
std::istringstream is(arg_list[i]);
|
||||
std::pair<std::string, std::string> kv;
|
||||
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
|
||||
<< " for key in arg " << i + 1;
|
||||
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
|
||||
<< " for value in arg " << i + 1;
|
||||
args.insert(kv);
|
||||
}
|
||||
if (args.find("format") == args.cend()) {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An iterator for implementing external memory support with file inputs. Users of
|
||||
* external memory are encouraged to define their own file parsers/loaders so this one is
|
||||
@@ -31,8 +59,6 @@ class FileIterator {
|
||||
uint32_t part_idx_;
|
||||
// Equals to total number of workers.
|
||||
uint32_t n_parts_;
|
||||
// Format of the input file, like "libsvm".
|
||||
std::string type_;
|
||||
|
||||
DMatrixHandle proxy_;
|
||||
|
||||
@@ -45,10 +71,9 @@ class FileIterator {
|
||||
std::string indices_;
|
||||
|
||||
public:
|
||||
FileIterator(std::string uri, unsigned part_index, unsigned num_parts,
|
||||
std::string type)
|
||||
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts},
|
||||
type_{std::move(type)} {
|
||||
FileIterator(std::string uri, unsigned part_index, unsigned num_parts)
|
||||
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} {
|
||||
ValidateFileFormat(uri_);
|
||||
XGProxyDMatrixCreate(&proxy_);
|
||||
}
|
||||
~FileIterator() {
|
||||
@@ -94,9 +119,7 @@ class FileIterator {
|
||||
auto Proxy() -> decltype(proxy_) { return proxy_; }
|
||||
|
||||
void Reset() {
|
||||
CHECK(!type_.empty());
|
||||
parser_.reset(dmlc::Parser<uint32_t>::Create(uri_.c_str(), part_idx_,
|
||||
n_parts_, type_.c_str()));
|
||||
parser_.reset(dmlc::Parser<uint32_t>::Create(uri_.c_str(), part_idx_, n_parts_, "auto"));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user