Normalize file system path. (#9463)
This commit is contained in:
parent
bdc1a3c178
commit
bb56183396
@ -47,6 +47,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/data/data.o \
|
||||
$(PKGROOT)/src/data/sparse_page_raw_format.o \
|
||||
$(PKGROOT)/src/data/ellpack_page.o \
|
||||
$(PKGROOT)/src/data/file_iterator.o \
|
||||
$(PKGROOT)/src/data/gradient_index.o \
|
||||
$(PKGROOT)/src/data/gradient_index_page_source.o \
|
||||
$(PKGROOT)/src/data/gradient_index_format.o \
|
||||
|
||||
@ -47,6 +47,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/data/data.o \
|
||||
$(PKGROOT)/src/data/sparse_page_raw_format.o \
|
||||
$(PKGROOT)/src/data/ellpack_page.o \
|
||||
$(PKGROOT)/src/data/file_iterator.o \
|
||||
$(PKGROOT)/src/data/gradient_index.o \
|
||||
$(PKGROOT)/src/data/gradient_index_page_source.o \
|
||||
$(PKGROOT)/src/data/gradient_index_format.o \
|
||||
|
||||
@ -72,6 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
|
||||
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
|
||||
tmp_file <- tempfile(fileext = ".libsvm")
|
||||
writeLines(tmp, tmp_file)
|
||||
expect_true(file.exists(tmp_file))
|
||||
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
|
||||
expect_equal(dim(dtest4), c(3, 4))
|
||||
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, uint32_t
|
||||
#include <cstring> // for memcpy
|
||||
#include <filesystem> // for filesystem
|
||||
#include <filesystem> // for filesystem, weakly_canonical
|
||||
#include <fstream> // for ifstream
|
||||
#include <iterator> // for distance
|
||||
#include <limits> // for numeric_limits
|
||||
@ -154,7 +154,8 @@ std::string LoadSequentialFile(std::string uri, bool stream) {
|
||||
// Open in binary mode so that correct file size can be computed with
|
||||
// seekg(). This accommodates Windows platform:
|
||||
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
|
||||
std::ifstream ifs(std::filesystem::u8path(uri), std::ios_base::binary | std::ios_base::in);
|
||||
auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri));
|
||||
std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in);
|
||||
if (!ifs) {
|
||||
// https://stackoverflow.com/a/17338934
|
||||
OpenErr();
|
||||
|
||||
120
src/data/data.cc
120
src/data/data.cc
@ -4,42 +4,57 @@
|
||||
*/
|
||||
#include "xgboost/data.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <dmlc/registry.h> // for DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_LINK_TAG
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <algorithm> // for copy, max, none_of, min
|
||||
#include <atomic> // for atomic
|
||||
#include <cmath> // for abs
|
||||
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
|
||||
#include <cstring> // for size_t, strcmp, memcpy
|
||||
#include <exception> // for exception
|
||||
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
|
||||
#include <map> // for map, operator!=
|
||||
#include <numeric> // for accumulate, partial_sum
|
||||
#include <tuple> // for get, apply
|
||||
#include <type_traits> // for remove_pointer_t, remove_reference
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../collective/communicator.h"
|
||||
#include "../common/algorithm.h" // for StableSort
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/common.h"
|
||||
#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/numeric.h" // for Iota
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/version.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/iterative_dmatrix.h"
|
||||
#include "./sparse_page_dmatrix.h"
|
||||
#include "./sparse_page_source.h"
|
||||
#include "dmlc/io.h"
|
||||
#include "file_iterator.h"
|
||||
#include "simple_dmatrix.h"
|
||||
#include "sparse_page_writer.h"
|
||||
#include "validation.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/linalg.h" // Vector
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h"
|
||||
#include "xgboost/version_config.h"
|
||||
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../common/algorithm.h" // for StableSort
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/common.h" // for Split
|
||||
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
|
||||
#include "../common/group_data.h" // for ParallelGroupBuilder
|
||||
#include "../common/io.h" // for PeekableInStream
|
||||
#include "../common/linalg_op.h" // for ElementWiseTransformHost
|
||||
#include "../common/math.h" // for CheckNAN
|
||||
#include "../common/numeric.h" // for Iota, RunLengthEncode
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "../common/version.h" // for Version
|
||||
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
|
||||
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
|
||||
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
|
||||
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
|
||||
#include "dmlc/base.h" // for BeginPtr
|
||||
#include "dmlc/common.h" // for OMPException
|
||||
#include "dmlc/data.h" // for Parser
|
||||
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
|
||||
#include "dmlc/io.h" // for Stream
|
||||
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||
#include "ellpack_page.h" // for EllpackPage
|
||||
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
|
||||
#include "gradient_index.h" // for GHistIndexMatrix
|
||||
#include "simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "sparse_page_writer.h" // for SparsePageFormatReg
|
||||
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
|
||||
#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/learner.h" // for HostDeviceVector
|
||||
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
|
||||
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
|
||||
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
|
||||
#include "xgboost/string_view.h" // for operator==, operator<<, StringView
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
|
||||
@ -811,10 +826,10 @@ DMatrix::~DMatrix() {
|
||||
}
|
||||
}
|
||||
|
||||
DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
int magic;
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||
namespace {
|
||||
DMatrix* TryLoadBinary(std::string fname, bool silent) {
|
||||
std::int32_t magic;
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||
if (fi != nullptr) {
|
||||
common::PeekableInStream is(fi.get());
|
||||
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic)) {
|
||||
@ -822,11 +837,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
dmlc::ByteSwap(&magic, sizeof(magic), 1);
|
||||
}
|
||||
if (magic == data::SimpleDMatrix::kMagic) {
|
||||
DMatrix *dmat = new data::SimpleDMatrix(&is);
|
||||
DMatrix* dmat = new data::SimpleDMatrix(&is);
|
||||
if (!silent) {
|
||||
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_
|
||||
<< " matrix with " << dmat->Info().num_nonzero_
|
||||
<< " entries loaded from " << fname;
|
||||
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
|
||||
<< dmat->Info().num_nonzero_ << " entries loaded from " << fname;
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
@ -834,6 +848,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
|
||||
auto need_split = false;
|
||||
@ -845,7 +860,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
}
|
||||
|
||||
std::string fname, cache_file;
|
||||
size_t dlm_pos = uri.find('#');
|
||||
auto dlm_pos = uri.find('#');
|
||||
if (dlm_pos != std::string::npos) {
|
||||
cache_file = uri.substr(dlm_pos + 1, uri.length());
|
||||
fname = uri.substr(0, dlm_pos);
|
||||
@ -857,14 +872,11 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
||||
size_t pos = cache_shards[i].rfind('.');
|
||||
if (pos == std::string::npos) {
|
||||
os << cache_shards[i]
|
||||
<< ".r" << collective::GetRank()
|
||||
<< "-" << collective::GetWorldSize();
|
||||
os << cache_shards[i] << ".r" << collective::GetRank() << "-"
|
||||
<< collective::GetWorldSize();
|
||||
} else {
|
||||
os << cache_shards[i].substr(0, pos)
|
||||
<< ".r" << collective::GetRank()
|
||||
<< "-" << collective::GetWorldSize()
|
||||
<< cache_shards[i].substr(pos, cache_shards[i].length());
|
||||
os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
|
||||
<< collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
|
||||
}
|
||||
if (i + 1 != cache_shards.size()) {
|
||||
os << ':';
|
||||
@ -895,12 +907,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
|
||||
}
|
||||
|
||||
data::ValidateFileFormat(fname);
|
||||
DMatrix* dmat {nullptr};
|
||||
DMatrix* dmat{nullptr};
|
||||
|
||||
if (cache_file.empty()) {
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
|
||||
fname = data::ValidateFileFormat(fname);
|
||||
std::unique_ptr<dmlc::Parser<std::uint32_t>> parser(
|
||||
dmlc::Parser<std::uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
|
||||
cache_file, data_split_mode);
|
||||
|
||||
51
src/data/file_iterator.cc
Normal file
51
src/data/file_iterator.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost contributors
|
||||
*/
|
||||
#include "file_iterator.h"
|
||||
|
||||
#include <xgboost/logging.h> // for LogCheck_EQ, LogCheck_LE, CHECK_EQ, CHECK_LE, LOG, LOG_...
|
||||
|
||||
#include <filesystem> // for weakly_canonical, path, u8path
|
||||
#include <map> // for map, operator==
|
||||
#include <ostream> // for operator<<, basic_ostream, istringstream
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/common.h" // for Split
|
||||
#include "xgboost/string_view.h" // for operator<<, StringView
|
||||
|
||||
namespace xgboost::data {
|
||||
std::string ValidateFileFormat(std::string const& uri) {
|
||||
std::vector<std::string> name_args_cache = common::Split(uri, '#');
|
||||
CHECK_LE(name_args_cache.size(), 2)
|
||||
<< "Only one `#` is allowed in file path for cachefile specification";
|
||||
|
||||
std::vector<std::string> name_args = common::Split(name_args_cache[0], '?');
|
||||
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
|
||||
CHECK_EQ(name_args.size(), 2) << msg;
|
||||
|
||||
std::map<std::string, std::string> args;
|
||||
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
|
||||
for (size_t i = 0; i < arg_list.size(); ++i) {
|
||||
std::istringstream is(arg_list[i]);
|
||||
std::pair<std::string, std::string> kv;
|
||||
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
|
||||
<< " for key in arg " << i + 1;
|
||||
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
|
||||
<< " for value in arg " << i + 1;
|
||||
args.insert(kv);
|
||||
}
|
||||
if (args.find("format") == args.cend()) {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
|
||||
auto path = common::Split(uri, '?')[0];
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
name_args[0] = fs::weakly_canonical(fs::u8path(path)).string();
|
||||
if (name_args_cache.size() == 1) {
|
||||
return name_args[0] + "?" + name_args[1];
|
||||
} else {
|
||||
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
@ -4,46 +4,20 @@
|
||||
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
#define XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm> // for max_element
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "array_interface.h"
|
||||
#include "dmlc/data.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "dmlc/data.h" // for RowBlock, Parser
|
||||
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
|
||||
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
inline void ValidateFileFormat(std::string const& uri) {
|
||||
std::vector<std::string> name_cache = common::Split(uri, '#');
|
||||
CHECK_LE(name_cache.size(), 2)
|
||||
<< "Only one `#` is allowed in file path for cachefile specification";
|
||||
|
||||
std::vector<std::string> name_args = common::Split(name_cache[0], '?');
|
||||
CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path.";
|
||||
|
||||
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
|
||||
CHECK_EQ(name_args.size(), 2) << msg;
|
||||
|
||||
std::map<std::string, std::string> args;
|
||||
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
|
||||
for (size_t i = 0; i < arg_list.size(); ++i) {
|
||||
std::istringstream is(arg_list[i]);
|
||||
std::pair<std::string, std::string> kv;
|
||||
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
|
||||
<< " for key in arg " << i + 1;
|
||||
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
|
||||
<< " for value in arg " << i + 1;
|
||||
args.insert(kv);
|
||||
}
|
||||
if (args.find("format") == args.cend()) {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
}
|
||||
namespace xgboost::data {
|
||||
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
|
||||
|
||||
/**
|
||||
* An iterator for implementing external memory support with file inputs. Users of
|
||||
@ -72,8 +46,7 @@ class FileIterator {
|
||||
|
||||
public:
|
||||
FileIterator(std::string uri, unsigned part_index, unsigned num_parts)
|
||||
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} {
|
||||
ValidateFileFormat(uri_);
|
||||
: uri_{ValidateFileFormat(std::move(uri))}, part_idx_{part_index}, n_parts_{num_parts} {
|
||||
XGProxyDMatrixCreate(&proxy_);
|
||||
}
|
||||
~FileIterator() {
|
||||
@ -132,6 +105,5 @@ inline int Next(DataIterHandle self) {
|
||||
return static_cast<FileIterator*>(self)->Next();
|
||||
}
|
||||
} // namespace fileiter
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
#endif // XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user