Normalize file system path. (#9463)

This commit is contained in:
Jiaming Yuan 2023-08-11 21:26:46 +08:00 committed by GitHub
parent bdc1a3c178
commit bb56183396
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 137 additions and 98 deletions

View File

@ -47,6 +47,7 @@ OBJECTS= \
$(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \
$(PKGROOT)/src/data/ellpack_page.o \
$(PKGROOT)/src/data/file_iterator.o \
$(PKGROOT)/src/data/gradient_index.o \
$(PKGROOT)/src/data/gradient_index_page_source.o \
$(PKGROOT)/src/data/gradient_index_format.o \

View File

@ -47,6 +47,7 @@ OBJECTS= \
$(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \
$(PKGROOT)/src/data/ellpack_page.o \
$(PKGROOT)/src/data/file_iterator.o \
$(PKGROOT)/src/data/gradient_index.o \
$(PKGROOT)/src/data/gradient_index_page_source.o \
$(PKGROOT)/src/data/gradient_index_format.o \

View File

@ -72,6 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
expect_true(file.exists(tmp_file))
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))

View File

@ -28,7 +28,7 @@
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint32_t
#include <cstring> // for memcpy
#include <filesystem> // for filesystem
#include <filesystem> // for filesystem, weakly_canonical
#include <fstream> // for ifstream
#include <iterator> // for distance
#include <limits> // for numeric_limits
@ -154,7 +154,8 @@ std::string LoadSequentialFile(std::string uri, bool stream) {
// Open in binary mode so that correct file size can be computed with
// seekg(). This accommodates Windows platform:
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
std::ifstream ifs(std::filesystem::u8path(uri), std::ios_base::binary | std::ios_base::in);
auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri));
std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in);
if (!ifs) {
// https://stackoverflow.com/a/17338934
OpenErr();

View File

@ -4,42 +4,57 @@
*/
#include "xgboost/data.h"
#include <dmlc/registry.h>
#include <dmlc/registry.h> // for DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_LINK_TAG
#include <array>
#include <cstddef>
#include <cstring>
#include <algorithm> // for copy, max, none_of, min
#include <atomic> // for atomic
#include <cmath> // for abs
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
#include <cstring> // for size_t, strcmp, memcpy
#include <exception> // for exception
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
#include <map> // for map, operator!=
#include <numeric> // for accumulate, partial_sum
#include <tuple> // for get, apply
#include <type_traits> // for remove_pointer_t, remove_reference
#include "../collective/communicator-inl.h"
#include "../collective/communicator.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h"
#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h" // for Iota
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
#include "../data/iterative_dmatrix.h"
#include "./sparse_page_dmatrix.h"
#include "./sparse_page_source.h"
#include "dmlc/io.h"
#include "file_iterator.h"
#include "simple_dmatrix.h"
#include "sparse_page_writer.h"
#include "validation.h"
#include "xgboost/c_api.h"
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/linalg.h" // Vector
#include "xgboost/logging.h"
#include "xgboost/string_view.h"
#include "xgboost/version_config.h"
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
#include "../collective/communicator.h" // for Operation
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h" // for Split
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
#include "../common/linalg_op.h" // for ElementWiseTransformHost
#include "../common/math.h" // for CheckNAN
#include "../common/numeric.h" // for Iota, RunLengthEncode
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/version.h" // for Version
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/common.h" // for OMPException
#include "dmlc/data.h" // for Parser
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
#include "dmlc/io.h" // for Stream
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "ellpack_page.h" // for EllpackPage
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
#include "gradient_index.h" // for GHistIndexMatrix
#include "simple_dmatrix.h" // for SimpleDMatrix
#include "sparse_page_writer.h" // for SparsePageFormatReg
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator==, operator<<, StringView
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
@ -811,10 +826,10 @@ DMatrix::~DMatrix() {
}
}
DMatrix *TryLoadBinary(std::string fname, bool silent) {
int magic;
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(fname.c_str(), "r", true));
namespace {
DMatrix* TryLoadBinary(std::string fname, bool silent) {
std::int32_t magic;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi != nullptr) {
common::PeekableInStream is(fi.get());
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic)) {
@ -822,11 +837,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
dmlc::ByteSwap(&magic, sizeof(magic), 1);
}
if (magic == data::SimpleDMatrix::kMagic) {
DMatrix *dmat = new data::SimpleDMatrix(&is);
DMatrix* dmat = new data::SimpleDMatrix(&is);
if (!silent) {
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_
<< " matrix with " << dmat->Info().num_nonzero_
<< " entries loaded from " << fname;
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
<< dmat->Info().num_nonzero_ << " entries loaded from " << fname;
}
return dmat;
}
@ -834,6 +848,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
}
return nullptr;
}
} // namespace
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
auto need_split = false;
@ -845,7 +860,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
}
std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
auto dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
cache_file = uri.substr(dlm_pos + 1, uri.length());
fname = uri.substr(0, dlm_pos);
@ -857,14 +872,11 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
for (size_t i = 0; i < cache_shards.size(); ++i) {
size_t pos = cache_shards[i].rfind('.');
if (pos == std::string::npos) {
os << cache_shards[i]
<< ".r" << collective::GetRank()
<< "-" << collective::GetWorldSize();
os << cache_shards[i] << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize();
} else {
os << cache_shards[i].substr(0, pos)
<< ".r" << collective::GetRank()
<< "-" << collective::GetWorldSize()
<< cache_shards[i].substr(pos, cache_shards[i].length());
os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
}
if (i + 1 != cache_shards.size()) {
os << ':';
@ -895,12 +907,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
}
data::ValidateFileFormat(fname);
DMatrix* dmat {nullptr};
DMatrix* dmat{nullptr};
if (cache_file.empty()) {
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
fname = data::ValidateFileFormat(fname);
std::unique_ptr<dmlc::Parser<std::uint32_t>> parser(
dmlc::Parser<std::uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
data::FileAdapter adapter(parser.get());
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file, data_split_mode);

51
src/data/file_iterator.cc Normal file
View File

@ -0,0 +1,51 @@
/**
* Copyright 2021-2023, XGBoost contributors
*/
#include "file_iterator.h"
#include <xgboost/logging.h> // for LogCheck_EQ, LogCheck_LE, CHECK_EQ, CHECK_LE, LOG, LOG_...
#include <filesystem> // for weakly_canonical, path, u8path
#include <map> // for map, operator==
#include <ostream> // for operator<<, basic_ostream, istringstream
#include <vector> // for vector
#include "../common/common.h" // for Split
#include "xgboost/string_view.h" // for operator<<, StringView
namespace xgboost::data {
std::string ValidateFileFormat(std::string const& uri) {
std::vector<std::string> name_args_cache = common::Split(uri, '#');
CHECK_LE(name_args_cache.size(), 2)
<< "Only one `#` is allowed in file path for cachefile specification";
std::vector<std::string> name_args = common::Split(name_args_cache[0], '?');
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
CHECK_EQ(name_args.size(), 2) << msg;
std::map<std::string, std::string> args;
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
<< " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
<< " for value in arg " << i + 1;
args.insert(kv);
}
if (args.find("format") == args.cend()) {
LOG(FATAL) << msg;
}
auto path = common::Split(uri, '?')[0];
namespace fs = std::filesystem;
name_args[0] = fs::weakly_canonical(fs::u8path(path)).string();
if (name_args_cache.size() == 1) {
return name_args[0] + "?" + name_args[1];
} else {
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
}
}
} // namespace xgboost::data

View File

@ -4,46 +4,20 @@
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <algorithm> // for max_element
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <utility> // for move
#include "array_interface.h"
#include "dmlc/data.h"
#include "xgboost/c_api.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "dmlc/data.h" // for RowBlock, Parser
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/logging.h" // for CHECK
namespace xgboost {
namespace data {
inline void ValidateFileFormat(std::string const& uri) {
std::vector<std::string> name_cache = common::Split(uri, '#');
CHECK_LE(name_cache.size(), 2)
<< "Only one `#` is allowed in file path for cachefile specification";
std::vector<std::string> name_args = common::Split(name_cache[0], '?');
CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path.";
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
CHECK_EQ(name_args.size(), 2) << msg;
std::map<std::string, std::string> args;
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
<< " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
<< " for value in arg " << i + 1;
args.insert(kv);
}
if (args.find("format") == args.cend()) {
LOG(FATAL) << msg;
}
}
namespace xgboost::data {
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
/**
* An iterator for implementing external memory support with file inputs. Users of
@ -72,8 +46,7 @@ class FileIterator {
public:
FileIterator(std::string uri, unsigned part_index, unsigned num_parts)
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} {
ValidateFileFormat(uri_);
: uri_{ValidateFileFormat(std::move(uri))}, part_idx_{part_index}, n_parts_{num_parts} {
XGProxyDMatrixCreate(&proxy_);
}
~FileIterator() {
@ -132,6 +105,5 @@ inline int Next(DataIterHandle self) {
return static_cast<FileIterator*>(self)->Next();
}
} // namespace fileiter
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_FILE_ITERATOR_H_