xgboost/src/data/data.cc

295 lines
11 KiB
C++

/*!
* Copyright 2015 by Contributors
* \file data.cc
*/
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <dmlc/registry.h>
#include <cstring>
#include "./sparse_batch_page.h"
#include "./simple_dmatrix.h"
#include "./simple_csr_source.h"
#include "../common/io.h"
#if DMLC_ENABLE_STD_THREAD
#include "./sparse_page_source.h"
#include "./sparse_page_dmatrix.h"
#endif
// Enable the dmlc registry for sparse page formats; entries registered in it are
// looked up by name in SparsePage::Format::Create.
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg);
} // namespace dmlc
namespace xgboost {
// implementation of inline functions
// Reset all counters to zero and empty every metadata vector.
void MetaInfo::Clear() {
  num_row = 0;
  num_col = 0;
  num_nonzero = 0;
  labels.clear();
  root_index.clear();
  group_ptr.clear();
  weights.clear();
  base_margin.clear();
}
// Serialize MetaInfo into `fo`: a fixed-size header followed by the
// variable-length fields, in exactly the order LoadBinary reads them back.
void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
  // Header: layout version, then the three matrix counters.
  const int layout_version = kVersion;
  fo->Write(&layout_version, sizeof(layout_version));
  fo->Write(&num_row, sizeof(num_row));
  fo->Write(&num_col, sizeof(num_col));
  fo->Write(&num_nonzero, sizeof(num_nonzero));
  // Variable-length vectors.
  fo->Write(labels);
  fo->Write(group_ptr);
  fo->Write(weights);
  fo->Write(root_index);
  fo->Write(base_margin);
}
// Deserialize MetaInfo from `fi`, reading fields in the same order SaveBinary writes them.
// A failed read, or a stored version different from kVersion, fails a CHECK.
// NOTE: the CHECK expressions are stringified into failure messages; keep their text stable.
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
// Layout version comes first and must match exactly.
int version;
CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid version";
CHECK_EQ(version, kVersion) << "MetaInfo: invalid format";
// Fixed-size counters.
CHECK(fi->Read(&num_row, sizeof(num_row)) == sizeof(num_row)) << "MetaInfo: invalid format";
CHECK(fi->Read(&num_col, sizeof(num_col)) == sizeof(num_col)) << "MetaInfo: invalid format";
CHECK(fi->Read(&num_nonzero, sizeof(num_nonzero)) == sizeof(num_nonzero))
<< "MetaInfo: invalid format";
// Variable-length vectors.
CHECK(fi->Read(&labels)) << "MetaInfo: invalid format";
CHECK(fi->Read(&group_ptr)) << "MetaInfo: invalid format";
CHECK(fi->Read(&weights)) << "MetaInfo: invalid format";
CHECK(fi->Read(&root_index)) << "MetaInfo: invalid format";
CHECK(fi->Read(&base_margin)) << "MetaInfo: invalid format";
}
// try to load group information from file, if exists
inline bool MetaTryLoadGroup(const std::string& fname,
std::vector<unsigned>* group) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi.get() == nullptr) return false;
dmlc::istream is(fi.get());
group->clear();
group->push_back(0);
unsigned nline;
while (is >> nline) {
group->push_back(group->back() + nline);
}
return true;
}
// try to load weight information from file, if exists
inline bool MetaTryLoadFloatInfo(const std::string& fname,
std::vector<float>* data) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi.get() == nullptr) return false;
dmlc::istream is(fi.get());
data->clear();
float value;
while (is >> value) {
data->push_back(value);
}
return true;
}
// Dispatch on the runtime element type `dtype`. Inside `proc`, `cast_ptr` names
// `old_ptr` reinterpreted as a const pointer to the matching element type
// (kFloat32 -> float, kDouble -> double, kUInt32 -> uint32_t, kUInt64 -> uint64_t).
// Any other dtype is reported with LOG(FATAL). The expansion is a `switch` statement.
// NOTE: the last line intentionally has no line-continuation backslash, so the
// macro ends there and cannot swallow the next source line into its definition.
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
  switch (dtype) { \
    case kFloat32: { \
      const float* cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
    } \
    case kDouble: { \
      const double* cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
    } \
    case kUInt32: { \
      const uint32_t* cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
    } \
    case kUInt64: { \
      const uint64_t* cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
    } \
    default: LOG(FATAL) << "Unknown data type " << dtype; \
  }
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
if (!std::strcmp(key, "root_index")) {
root_index.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, root_index.begin()));
} else if (!std::strcmp(key, "label")) {
labels.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
} else if (!std::strcmp(key, "weight")) {
weights.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
} else if (!std::strcmp(key, "base_margin")) {
base_margin.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
}
}
/*!
 * \brief Load a DMatrix from `uri`.
 *
 * `uri` is either "path" or "path#cache". A non-empty cache part selects the
 * external-memory DMatrix; several cache shards may be separated by ':'.
 * \param uri            data path, optionally followed by "#cache_prefix".
 * \param silent         suppress the CONSOLE messages describing what was loaded.
 * \param load_row_split load only this rabit worker's partition of the rows. Cache
 *                       shard names then get a ".r<rank>-<world>" tag so workers
 *                       do not share cache files.
 * \param file_format    parser format. With "auto" (and no row split), the legacy
 *                       xgboost binary format is detected by its magic number first.
 * \return matrix created through DMatrix::Create (heap-allocated).
 */
DMatrix* DMatrix::Load(const std::string& uri,
                       bool silent,
                       bool load_row_split,
                       const std::string& file_format) {
  std::string fname, cache_file;
  size_t dlm_pos = uri.find('#');
  if (dlm_pos != std::string::npos) {
    cache_file = uri.substr(dlm_pos + 1, uri.length());
    fname = uri.substr(0, dlm_pos);
    CHECK_EQ(cache_file.find('#'), std::string::npos)
        << "Only one `#` is allowed in file path for cache file specification.";
    if (load_row_split) {
      // Tag every cache shard with this worker's rank and the world size,
      // inserted before the file extension (or appended when there is none).
      std::ostringstream os;
      std::vector<std::string> cache_shards = common::Split(cache_file, ':');
      for (size_t i = 0; i < cache_shards.size(); ++i) {
        const std::string& shard = cache_shards[i];
        size_t pos = shard.rfind('.');
        // Only a '.' in the last path component starts an extension. A dot in a
        // directory part ("./cache", "../cache", "/tmp/run.1/cache") must not be split on.
        size_t sep = shard.find_last_of("/\\");
        if (pos != std::string::npos && sep != std::string::npos && pos < sep) {
          pos = std::string::npos;
        }
        if (pos == std::string::npos) {
          os << shard
             << ".r" << rabit::GetRank()
             << "-" << rabit::GetWorldSize();
        } else {
          os << shard.substr(0, pos)
             << ".r" << rabit::GetRank()
             << "-" << rabit::GetWorldSize()
             << shard.substr(pos, shard.length());
        }
        if (i + 1 != cache_shards.size()) os << ':';
      }
      cache_file = os.str();
    }
  } else {
    fname = uri;
  }
  int partid = 0, npart = 1;
  if (load_row_split) {
    partid = rabit::GetRank();
    npart = rabit::GetWorldSize();
  } else {
    // test option to load in part
    npart = dmlc::GetEnv("XGBOOST_TEST_NPART", 1);
  }
  if (npart != 1) {
    LOG(CONSOLE) << "Load part of data " << partid
                 << " of " << npart << " parts";
  }
  // Legacy handling of binary data loading: peek at the magic number and, if it
  // matches SimpleCSRSource::kMagic, load the whole binary buffer directly.
  if (file_format == "auto" && !load_row_split) {
    int magic;
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
    if (fi.get() != nullptr) {
      common::PeekableInStream is(fi.get());
      if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) &&
          magic == data::SimpleCSRSource::kMagic) {
        std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
        source->LoadBinary(&is);
        DMatrix* dmat = DMatrix::Create(std::move(source), cache_file);
        if (!silent) {
          LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
                       << dmat->info().num_nonzero << " entries loaded from " << uri;
        }
        return dmat;
      }
    }
  }
  // `file_format` is forwarded unchanged, including "auto"; interpreting "auto"
  // is left to dmlc::Parser::Create.
  // NOTE(review): confirm the bundled dmlc-core resolves "auto" (e.g. from a
  // `?format=` URI argument, defaulting to libsvm).
  std::unique_ptr<dmlc::Parser<uint32_t> > parser(
      dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
  DMatrix* dmat = DMatrix::Create(parser.get(), cache_file);
  if (!silent) {
    LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
                 << dmat->info().num_nonzero << " entries loaded from " << uri;
  }
  // Backward compatibility: pick up optional "<fname>.group" and
  // "<fname>.base_margin" side files when they exist.
  if (!load_row_split) {
    MetaInfo& info = dmat->info();
    if (MetaTryLoadGroup(fname + ".group", &info.group_ptr) && !silent) {
      LOG(CONSOLE) << info.group_ptr.size() - 1
                   << " groups are loaded from " << fname << ".group";
    }
    if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin) && !silent) {
      LOG(CONSOLE) << info.base_margin.size()
                   << " base_margin are loaded from " << fname << ".base_margin";
    }
  }
  return dmat;
}
// Build a DMatrix from the rows produced by `parser`.
// An empty `cache_prefix` copies everything into an in-memory CSR source. Otherwise
// a sparse-page cache is created under `cache_prefix` (unless it already exists)
// and opened as the data source. External memory is unavailable without
// DMLC_ENABLE_STD_THREAD.
DMatrix* DMatrix::Create(dmlc::Parser<uint32_t>* parser,
                         const std::string& cache_prefix) {
  if (cache_prefix.empty()) {
    std::unique_ptr<data::SimpleCSRSource> in_memory(new data::SimpleCSRSource());
    in_memory->CopyFrom(parser);
    return DMatrix::Create(std::move(in_memory), cache_prefix);
  }
#if DMLC_ENABLE_STD_THREAD
  if (!data::SparsePageSource::CacheExist(cache_prefix)) {
    data::SparsePageSource::Create(parser, cache_prefix);
  }
  std::unique_ptr<data::SparsePageSource> paged(new data::SparsePageSource(cache_prefix));
  return DMatrix::Create(std::move(paged), cache_prefix);
#else
  LOG(FATAL) << "External memory is not enabled in mingw";
  return nullptr;
#endif
}
void DMatrix::SaveToLocalFile(const std::string& fname) {
data::SimpleCSRSource source;
source.CopyFrom(this);
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
source.SaveBinary(fo.get());
}
// Wrap `source` in a concrete DMatrix. A non-empty `cache_prefix` selects the
// external-memory SparsePageDMatrix, which requires DMLC_ENABLE_STD_THREAD.
// Otherwise a SimpleDMatrix is returned.
DMatrix* DMatrix::Create(std::unique_ptr<DataSource>&& source,
                         const std::string& cache_prefix) {
  if (!cache_prefix.empty()) {
#if DMLC_ENABLE_STD_THREAD
    return new data::SparsePageDMatrix(std::move(source), cache_prefix);
#else
    LOG(FATAL) << "External memory is not enabled in mingw";
    return nullptr;
#endif
  }
  return new data::SimpleDMatrix(std::move(source));
}
} // namespace xgboost
namespace xgboost {
namespace data {
// Instantiate the sparse page format registered under `name` by calling the
// registry entry's factory. An unknown name is reported with LOG(FATAL).
SparsePage::Format* SparsePage::Format::Create(const std::string& name) {
  auto* registry = ::dmlc::Registry< ::xgboost::data::SparsePageFormatReg>::Get();
  const auto* entry = registry->Find(name);
  if (entry == nullptr) {
    LOG(FATAL) << "Unknown format type " << name;
  }
  return (entry->body)();
}
// Decide the (first, second) format names encoded in a cache prefix.
// "<prefix>.fmt-A-B" yields (A, B), split at the last '-'; "<prefix>.fmt-A" yields (A, A).
// A prefix without ".fmt-" defaults to ("raw", "raw").
std::pair<std::string, std::string>
SparsePage::Format::DecideFormat(const std::string& cache_prefix) {
  const size_t marker = cache_prefix.rfind(".fmt-");
  if (marker == std::string::npos) {
    return std::make_pair(std::string("raw"), std::string("raw"));
  }
  const std::string spec = cache_prefix.substr(marker + 5);
  const size_t dash = spec.rfind('-');
  if (dash == std::string::npos) {
    return std::make_pair(spec, spec);
  }
  return std::make_pair(spec.substr(0, dash), spec.substr(dash + 1));
}
// Force-link the sparse_page_raw_format translation unit so that static linking
// does not drop its registry entries.
// NOTE(review): presumably this is the unit that registers the "raw" format that
// DecideFormat falls back to; confirm.
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
} // namespace data
} // namespace xgboost