Merge pull request #816 from tqchen/master
[DISK] Major improvements in external memory, add support to group back
This commit is contained in:
commit
70d9732765
@ -1 +1 @@
|
|||||||
Subproject commit ad2ddde8b6624abf3007a71b2923c3925530cc81
|
Subproject commit 0f8fd38bf94e6666aa367be80195b1f2da87428c
|
||||||
@ -62,9 +62,7 @@ struct bst_gpair {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief small eps gap for minimum split decision. */
|
/*! \brief small eps gap for minimum split decision. */
|
||||||
const float rt_eps = 1e-5f;
|
const float rt_eps = 1e-6f;
|
||||||
/*! \brief min gap between feature values to allow a split happen */
|
|
||||||
const float rt_2eps = rt_eps * 2.0f;
|
|
||||||
|
|
||||||
/*! \brief define unsigned long for openmp loop */
|
/*! \brief define unsigned long for openmp loop */
|
||||||
typedef dmlc::omp_ulong omp_ulong;
|
typedef dmlc::omp_ulong omp_ulong;
|
||||||
|
|||||||
@ -183,6 +183,41 @@ class DataSource : public dmlc::DataIter<RowBatch> {
|
|||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief A vector-like structure to represent set of rows.
|
||||||
|
* But saves the memory when all rows are in the set (common case in xgb)
|
||||||
|
*/
|
||||||
|
struct RowSet {
|
||||||
|
public:
|
||||||
|
/*! \return i-th row index */
|
||||||
|
inline bst_uint operator[](size_t i) const;
|
||||||
|
/*! \return the size of the set. */
|
||||||
|
inline size_t size() const;
|
||||||
|
/*! \brief push the index back to the set */
|
||||||
|
inline void push_back(bst_uint i);
|
||||||
|
/*! \brief clear the set */
|
||||||
|
inline void clear();
|
||||||
|
/*!
|
||||||
|
* \brief save rowset to file.
|
||||||
|
* \param fo The file to be saved.
|
||||||
|
*/
|
||||||
|
inline void Save(dmlc::Stream* fo) const;
|
||||||
|
/*!
|
||||||
|
* \brief Load rowset from file.
|
||||||
|
* \param fi The file to be loaded.
|
||||||
|
* \return if read is successful.
|
||||||
|
*/
|
||||||
|
inline bool Load(dmlc::Stream* fi);
|
||||||
|
/*! \brief constructor */
|
||||||
|
RowSet() : size_(0) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief The internal data structure of size */
|
||||||
|
uint64_t size_;
|
||||||
|
/*! \brief The internal data structure of row set if not all*/
|
||||||
|
std::vector<bst_uint> rows_;
|
||||||
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Internal data structured used by XGBoost during training.
|
* \brief Internal data structured used by XGBoost during training.
|
||||||
* There are two ways to create a customized DMatrix that reads in user defined-format.
|
* There are two ways to create a customized DMatrix that reads in user defined-format.
|
||||||
@ -235,7 +270,7 @@ class DMatrix {
|
|||||||
/*! \brief get column density */
|
/*! \brief get column density */
|
||||||
virtual float GetColDensity(size_t cidx) const = 0;
|
virtual float GetColDensity(size_t cidx) const = 0;
|
||||||
/*! \return reference of buffered rowset, in column access */
|
/*! \return reference of buffered rowset, in column access */
|
||||||
virtual const std::vector<bst_uint>& buffered_rowset() const = 0;
|
virtual const RowSet& buffered_rowset() const = 0;
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
virtual ~DMatrix() {}
|
virtual ~DMatrix() {}
|
||||||
/*!
|
/*!
|
||||||
@ -290,9 +325,48 @@ class DMatrix {
|
|||||||
LearnerImpl* cache_learner_ptr_;
|
LearnerImpl* cache_learner_ptr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// implementation of inline functions
|
||||||
|
/*!
 * \brief Fetch the row index stored at position i.
 *  With no explicit index list the position itself is the row index.
 */
inline bst_uint RowSet::operator[](size_t i) const {
  if (rows_.empty()) {
    return static_cast<bst_uint>(i);
  }
  return rows_[i];
}
|
||||||
|
|
||||||
|
inline size_t RowSet::size() const {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief Reset to the empty set: zero the counter and drop the index list. */
inline void RowSet::clear() {
  size_ = 0;
  rows_.clear();
}
|
||||||
|
|
||||||
|
/*!
 * \brief Append row index i to the set.
 *  While every pushed index equals the current size only the counter grows.
 *  The first index that does not follow this pattern materializes the
 *  implicit range 0..size_-1 into the explicit list before being appended.
 */
inline void RowSet::push_back(bst_uint i) {
  if (rows_.empty()) {
    if (i == size_) {
      // still the contiguous range: nothing to store besides the count.
      ++size_;
      return;
    }
    // pattern broken: write out the indices that were only implied so far.
    rows_.resize(size_);
    for (size_t pos = 0; pos < size_; ++pos) {
      rows_[pos] = static_cast<bst_uint>(pos);
    }
  }
  rows_.push_back(i);
  ++size_;
}
|
||||||
|
|
||||||
|
/*!
 * \brief Serialize the set to a stream.
 *  The explicit index list is written first (it may be empty), followed by
 *  the raw bytes of the element counter.
 * \param fo The output stream.
 */
inline void RowSet::Save(dmlc::Stream* fo) const {
  // explicit index list
  fo->Write(rows_);
  // element counter, written unconditionally
  fo->Write(&size_, sizeof(size_));
}
|
||||||
|
|
||||||
|
/*!
 * \brief Deserialize the set from a stream written by Save().
 *
 *  Save() always emits the index list followed by the element counter, so
 *  both fields are consumed here unconditionally. Returning right after a
 *  non-empty index list would leave size_ stale and the counter bytes unread
 *  in the stream, corrupting whatever is loaded next.
 *
 * \param fi The input stream.
 * \return whether both fields were read successfully.
 */
inline bool RowSet::Load(dmlc::Stream* fi) {
  if (!fi->Read(&rows_)) return false;
  uint64_t stored_size = 0;
  if (fi->Read(&stored_size, sizeof(stored_size)) != sizeof(stored_size)) {
    return false;
  }
  // An explicit index list defines the size by itself; the stored counter is
  // only authoritative when the set is the implicit range 0..size_-1.
  size_ = rows_.empty() ? stored_size : static_cast<uint64_t>(rows_.size());
  return true;
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
namespace dmlc {
|
namespace dmlc {
|
||||||
DMLC_DECLARE_TRAITS(is_pod, xgboost::SparseBatch::Entry, true);
|
DMLC_DECLARE_TRAITS(is_pod, xgboost::SparseBatch::Entry, true);
|
||||||
|
DMLC_DECLARE_TRAITS(has_saveload, xgboost::RowSet, true);
|
||||||
}
|
}
|
||||||
#endif // XGBOOST_DATA_H_
|
#endif // XGBOOST_DATA_H_
|
||||||
|
|||||||
@ -31,3 +31,4 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server
|
|||||||
#
|
#
|
||||||
XGB_PLUGINS += plugin/example/plugin.mk
|
XGB_PLUGINS += plugin/example/plugin.mk
|
||||||
XGB_PLUGINS += plugin/lz4/plugin.mk
|
XGB_PLUGINS += plugin/lz4/plugin.mk
|
||||||
|
XGB_PLUGINS += plugin/dense_parser/plugin.mk
|
||||||
|
|||||||
86
plugin/dense_parser/dense_libsvm.cc
Normal file
86
plugin/dense_parser/dense_libsvm.cc
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2015 by Contributors
|
||||||
|
* \file dense_libsvm.cc
|
||||||
|
* \brief Plugin to load in libsvm, but fill all the missing entries with zeros.
|
||||||
|
* This plugin is mainly used for benchmark purposes and do not need to be included.
|
||||||
|
*/
|
||||||
|
#include <xgboost/base.h>
|
||||||
|
#include <dmlc/data.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace dmlc {
|
||||||
|
namespace data {
|
||||||
|
|
||||||
|
template<typename IndexType>
|
||||||
|
class DensifyParser : public dmlc::Parser<IndexType> {
|
||||||
|
public:
|
||||||
|
DensifyParser(dmlc::Parser<IndexType>* parser, uint32_t num_col)
|
||||||
|
: parser_(parser), num_col_(num_col) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void BeforeFirst() override {
|
||||||
|
parser_->BeforeFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Next() override {
|
||||||
|
if (!parser_->Next()) return false;
|
||||||
|
const RowBlock<IndexType>& batch = parser_->Value();
|
||||||
|
LOG(INFO) << batch.size;
|
||||||
|
dense_index_.resize(num_col_ * batch.size);
|
||||||
|
dense_value_.resize(num_col_ * batch.size);
|
||||||
|
std::fill(dense_value_.begin(), dense_value_.end(), 0.0f);
|
||||||
|
offset_.resize(batch.size + 1);
|
||||||
|
offset_[0] = 0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
|
offset_[i + 1] = (i + 1) * num_col_;
|
||||||
|
Row<IndexType> row = batch[i];
|
||||||
|
for (uint32_t j = 0; j < num_col_; ++j) {
|
||||||
|
dense_index_[i * num_col_ + j] = j;
|
||||||
|
}
|
||||||
|
for (unsigned k = 0; k < row.length; ++k) {
|
||||||
|
uint32_t index = row.get_index(k);
|
||||||
|
CHECK_LT(index, num_col_)
|
||||||
|
<< "Featuere index larger than num_col";
|
||||||
|
dense_value_[i * num_col_ + index] = row.get_value(k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out_ = batch;
|
||||||
|
out_.index = dmlc::BeginPtr(dense_index_);
|
||||||
|
out_.value = dmlc::BeginPtr(dense_value_);
|
||||||
|
out_.offset = dmlc::BeginPtr(offset_);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const dmlc::RowBlock<IndexType>& Value() const override {
|
||||||
|
return out_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t BytesRead() const override {
|
||||||
|
return parser_->BytesRead();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
RowBlock<IndexType> out_;
|
||||||
|
std::unique_ptr<Parser<IndexType> > parser_;
|
||||||
|
uint32_t num_col_;
|
||||||
|
std::vector<size_t> offset_;
|
||||||
|
std::vector<IndexType> dense_index_;
|
||||||
|
std::vector<float> dense_value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename IndexType>
|
||||||
|
Parser<IndexType> *
|
||||||
|
CreateDenseLibSVMParser(const std::string& path,
|
||||||
|
const std::map<std::string, std::string>& args,
|
||||||
|
unsigned part_index,
|
||||||
|
unsigned num_parts) {
|
||||||
|
CHECK_NE(args.count("num_col"), 0) << "expect num_col in dense_libsvm";
|
||||||
|
return new DensifyParser<IndexType>(
|
||||||
|
Parser<IndexType>::Create(path.c_str(), part_index, num_parts, "libsvm"),
|
||||||
|
uint32_t(atoi(args.at("num_col").c_str())));
|
||||||
|
}
|
||||||
|
} // namespace data
|
||||||
|
|
||||||
|
DMLC_REGISTER_DATA_PARSER(uint32_t, dense_libsvm, data::CreateDenseLibSVMParser<uint32_t>);
|
||||||
|
} // namespace dmlc
|
||||||
2
plugin/dense_parser/plugin.mk
Normal file
2
plugin/dense_parser/plugin.mk
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
PLUGIN_OBJS += build_plugin/dense_parser/dense_libsvm.o
|
||||||
|
PLUGIN_LDFLAGS +=
|
||||||
@ -318,6 +318,7 @@ int CLIRunTask(int argc, char *argv[]) {
|
|||||||
printf("Usage: <config>\n");
|
printf("Usage: <config>\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
rabit::Init(argc, argv);
|
||||||
|
|
||||||
std::vector<std::pair<std::string, std::string> > cfg;
|
std::vector<std::pair<std::string, std::string> > cfg;
|
||||||
cfg.push_back(std::make_pair("seed", "0"));
|
cfg.push_back(std::make_pair("seed", "0"));
|
||||||
@ -336,7 +337,6 @@ int CLIRunTask(int argc, char *argv[]) {
|
|||||||
CLIParam param;
|
CLIParam param;
|
||||||
param.Configure(cfg);
|
param.Configure(cfg);
|
||||||
|
|
||||||
rabit::Init(argc, argv);
|
|
||||||
switch (param.task) {
|
switch (param.task) {
|
||||||
case kTrain: CLITrain(param); break;
|
case kTrain: CLITrain(param); break;
|
||||||
case kDump2Text: CLIDump2Text(param); break;
|
case kDump2Text: CLIDump2Text(param); break;
|
||||||
|
|||||||
31
src/common/common.h
Normal file
31
src/common/common.h
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2015 by Contributors
|
||||||
|
* \file common.h
|
||||||
|
* \brief Common utilities
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_COMMON_H_
|
||||||
|
#define XGBOOST_COMMON_COMMON_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
/*!
 * \brief Cut a string into pieces at every occurrence of a delimiter.
 *  Empty pieces between or before delimiters are kept, but a delimiter at
 *  the very end of the string does not produce a trailing empty piece.
 * \param s String to be split.
 * \param delim The delimiter.
 * \return the pieces in their original order; empty if s is empty.
 */
inline std::vector<std::string> Split(const std::string& s, char delim) {
  std::vector<std::string> pieces;
  std::string::size_type begin = 0;
  while (begin < s.size()) {
    const std::string::size_type end = s.find(delim, begin);
    if (end == std::string::npos) {
      // last piece: everything up to the end of the string
      pieces.push_back(s.substr(begin));
      break;
    }
    pieces.push_back(s.substr(begin, end - begin));
    begin = end + 1;
  }
  return pieces;
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_COMMON_H_
|
||||||
@ -9,6 +9,7 @@
|
|||||||
#include "./sparse_batch_page.h"
|
#include "./sparse_batch_page.h"
|
||||||
#include "./simple_dmatrix.h"
|
#include "./simple_dmatrix.h"
|
||||||
#include "./simple_csr_source.h"
|
#include "./simple_csr_source.h"
|
||||||
|
#include "../common/common.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
@ -124,6 +125,14 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
|||||||
base_margin.resize(num);
|
base_margin.resize(num);
|
||||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||||
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
|
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
|
||||||
|
} else if (!std::strcmp(key, "group")) {
|
||||||
|
group_ptr.resize(num + 1);
|
||||||
|
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||||
|
std::copy(cast_dptr, cast_dptr + num, group_ptr.begin() + 1));
|
||||||
|
group_ptr[0] = 0;
|
||||||
|
for (size_t i = 1; i < group_ptr.size(); ++i) {
|
||||||
|
group_ptr[i] = group_ptr[i - 1] + group_ptr[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,7 +150,21 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
<< "Only one `#` is allowed in file path for cache file specification.";
|
<< "Only one `#` is allowed in file path for cache file specification.";
|
||||||
if (load_row_split) {
|
if (load_row_split) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << cache_file << ".r" << rabit::GetRank();
|
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
|
||||||
|
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
||||||
|
size_t pos = cache_shards[i].rfind('.');
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
os << cache_shards[i]
|
||||||
|
<< ".r" << rabit::GetRank()
|
||||||
|
<< "-" << rabit::GetWorldSize();
|
||||||
|
} else {
|
||||||
|
os << cache_shards[i].substr(0, pos)
|
||||||
|
<< ".r" << rabit::GetRank()
|
||||||
|
<< "-" << rabit::GetWorldSize()
|
||||||
|
<< cache_shards[i].substr(pos, cache_shards[i].length());
|
||||||
|
}
|
||||||
|
if (i + 1 != cache_shards.size()) os << ':';
|
||||||
|
}
|
||||||
cache_file = os.str();
|
cache_file = os.str();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -154,9 +177,11 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
} else {
|
} else {
|
||||||
// test option to load in part
|
// test option to load in part
|
||||||
npart = dmlc::GetEnv("XGBOOST_TEST_NPART", 1);
|
npart = dmlc::GetEnv("XGBOOST_TEST_NPART", 1);
|
||||||
if (npart != 1) {
|
}
|
||||||
LOG(CONSOLE) << "Partial load option on npart=" << npart;
|
|
||||||
}
|
if (npart != 1) {
|
||||||
|
LOG(CONSOLE) << "Load part of data " << partid
|
||||||
|
<< " of " << npart << " parts";
|
||||||
}
|
}
|
||||||
// legacy handling of binary data loading
|
// legacy handling of binary data loading
|
||||||
if (file_format == "auto" && !load_row_split) {
|
if (file_format == "auto" && !load_row_split) {
|
||||||
@ -181,7 +206,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
std::string ftype = file_format;
|
std::string ftype = file_format;
|
||||||
if (file_format == "auto") ftype = "libsvm";
|
if (file_format == "auto") ftype = "libsvm";
|
||||||
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
|
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
|
||||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, ftype.c_str()));
|
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
|
||||||
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file);
|
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file);
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
|
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
|
||||||
|
|||||||
@ -41,7 +41,6 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
|||||||
if (batch.weight != nullptr) {
|
if (batch.weight != nullptr) {
|
||||||
info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size);
|
info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size);
|
||||||
}
|
}
|
||||||
row_data_.reserve(row_data_.size() + batch.offset[batch.size] - batch.offset[0]);
|
|
||||||
CHECK(batch.index != nullptr);
|
CHECK(batch.index != nullptr);
|
||||||
// update information
|
// update information
|
||||||
this->info.num_row += batch.size;
|
this->info.num_row += batch.size;
|
||||||
@ -54,9 +53,8 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
|||||||
static_cast<uint64_t>(index + 1));
|
static_cast<uint64_t>(index + 1));
|
||||||
}
|
}
|
||||||
size_t top = row_ptr_.size();
|
size_t top = row_ptr_.size();
|
||||||
row_ptr_.resize(top + batch.size);
|
|
||||||
for (size_t i = 0; i < batch.size; ++i) {
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
row_ptr_[top + i] = row_ptr_[top - 1] + batch.offset[i + 1] - batch.offset[0];
|
row_ptr_.push_back(row_ptr_[top - 1] + batch.offset[i + 1] - batch.offset[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this->info.num_nonzero = static_cast<uint64_t>(row_data_.size());
|
this->info.num_nonzero = static_cast<uint64_t>(row_data_.size());
|
||||||
|
|||||||
@ -184,9 +184,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
|||||||
}
|
}
|
||||||
if (tmp.Size() >= max_row_perbatch) {
|
if (tmp.Size() >= max_row_perbatch) {
|
||||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||||
this->MakeColPage(tmp.GetRowBatch(0),
|
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get());
|
||||||
dmlc::BeginPtr(buffered_rowset_) + btop,
|
|
||||||
enabled, page.get());
|
|
||||||
col_iter_.cpages_.push_back(std::move(page));
|
col_iter_.cpages_.push_back(std::move(page));
|
||||||
btop = buffered_rowset_.size();
|
btop = buffered_rowset_.size();
|
||||||
tmp.Clear();
|
tmp.Clear();
|
||||||
@ -196,16 +194,14 @@ void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
|||||||
|
|
||||||
if (tmp.Size() != 0) {
|
if (tmp.Size() != 0) {
|
||||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||||
this->MakeColPage(tmp.GetRowBatch(0),
|
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get());
|
||||||
dmlc::BeginPtr(buffered_rowset_) + btop,
|
|
||||||
enabled, page.get());
|
|
||||||
col_iter_.cpages_.push_back(std::move(page));
|
col_iter_.cpages_.push_back(std::move(page));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// make column page from subset of rowbatchs
|
// make column page from subset of rowbatchs
|
||||||
void SimpleDMatrix::MakeColPage(const RowBatch& batch,
|
void SimpleDMatrix::MakeColPage(const RowBatch& batch,
|
||||||
const bst_uint* ridx,
|
size_t buffer_begin,
|
||||||
const std::vector<bool>& enabled,
|
const std::vector<bool>& enabled,
|
||||||
SparsePage* pcol) {
|
SparsePage* pcol) {
|
||||||
int nthread;
|
int nthread;
|
||||||
@ -240,9 +236,10 @@ void SimpleDMatrix::MakeColPage(const RowBatch& batch,
|
|||||||
RowBatch::Inst inst = batch[i];
|
RowBatch::Inst inst = batch[i];
|
||||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||||
const SparseBatch::Entry &e = inst[j];
|
const SparseBatch::Entry &e = inst[j];
|
||||||
builder.Push(e.index,
|
builder.Push(
|
||||||
SparseBatch::Entry(ridx[i], e.fvalue),
|
e.index,
|
||||||
tid);
|
SparseBatch::Entry(buffered_rowset_[i + buffer_begin], e.fvalue),
|
||||||
|
tid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK_EQ(pcol->Size(), info().num_col);
|
CHECK_EQ(pcol->Size(), info().num_col);
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
return col_size_.size() != 0;
|
return col_size_.size() != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<bst_uint>& buffered_rowset() const override {
|
const RowSet& buffered_rowset() const override {
|
||||||
return buffered_rowset_;
|
return buffered_rowset_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
// column iterator
|
// column iterator
|
||||||
ColBatchIter col_iter_;
|
ColBatchIter col_iter_;
|
||||||
// list of row index that are buffered.
|
// list of row index that are buffered.
|
||||||
std::vector<bst_uint> buffered_rowset_;
|
RowSet buffered_rowset_;
|
||||||
/*! \brief sizeof column data */
|
/*! \brief sizeof column data */
|
||||||
std::vector<size_t> col_size_;
|
std::vector<size_t> col_size_;
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
size_t max_row_perbatch);
|
size_t max_row_perbatch);
|
||||||
|
|
||||||
void MakeColPage(const RowBatch& batch,
|
void MakeColPage(const RowBatch& batch,
|
||||||
const bst_uint* ridx,
|
size_t buffer_begin,
|
||||||
const std::vector<bool>& enabled,
|
const std::vector<bool>& enabled,
|
||||||
SparsePage* pcol);
|
SparsePage* pcol);
|
||||||
};
|
};
|
||||||
|
|||||||
@ -16,6 +16,12 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
|
#include <dmlc/concurrency.h>
|
||||||
|
#include <thread>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -26,6 +32,8 @@ class SparsePage {
|
|||||||
public:
|
public:
|
||||||
/*! \brief Format of the sparse page. */
|
/*! \brief Format of the sparse page. */
|
||||||
class Format;
|
class Format;
|
||||||
|
/*! \brief Writer to write the sparse page to files. */
|
||||||
|
class Writer;
|
||||||
/*! \brief minimum index of all index, used as hint for compression. */
|
/*! \brief minimum index of all index, used as hint for compression. */
|
||||||
bst_uint min_index;
|
bst_uint min_index;
|
||||||
/*! \brief offset of the segments */
|
/*! \brief offset of the segments */
|
||||||
@ -171,6 +179,53 @@ class SparsePage::Format {
|
|||||||
static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
|
static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
|
/*!
|
||||||
|
* \brief A threaded writer to write sparse batch page to sharded files.
|
||||||
|
*/
|
||||||
|
class SparsePage::Writer {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param name_shards name of shard files.
|
||||||
|
* \param format_shards format of each shard.
|
||||||
|
* \param extra_buffer_capacity Extra buffer capacity before block.
|
||||||
|
*/
|
||||||
|
explicit Writer(
|
||||||
|
const std::vector<std::string>& name_shards,
|
||||||
|
const std::vector<std::string>& format_shards,
|
||||||
|
size_t extra_buffer_capacity);
|
||||||
|
/*! \brief destructor, will close the files automatically */
|
||||||
|
~Writer();
|
||||||
|
/*!
|
||||||
|
* \brief Push a write job to the writer.
|
||||||
|
* This function won't block,
|
||||||
|
* writing is done by another thread inside writer.
|
||||||
|
* \param page The page to be wriiten
|
||||||
|
*/
|
||||||
|
void PushWrite(std::unique_ptr<SparsePage>&& page);
|
||||||
|
/*!
|
||||||
|
* \brief Allocate a page to store results.
|
||||||
|
* This function can block when the writer is too slow and buffer pages
|
||||||
|
* have not yet been recycled.
|
||||||
|
* \param out_page Used to store the allocated pages.
|
||||||
|
*/
|
||||||
|
void Alloc(std::unique_ptr<SparsePage>* out_page);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief number of allocated pages */
|
||||||
|
size_t num_free_buffer_;
|
||||||
|
/*! \brief clock_pointer */
|
||||||
|
size_t clock_ptr_;
|
||||||
|
/*! \brief writer threads */
|
||||||
|
std::vector<std::unique_ptr<std::thread> > workers_;
|
||||||
|
/*! \brief recycler queue */
|
||||||
|
dmlc::ConcurrentBlockingQueue<std::unique_ptr<SparsePage> > qrecycle_;
|
||||||
|
/*! \brief worker threads */
|
||||||
|
std::vector<dmlc::ConcurrentBlockingQueue<std::unique_ptr<SparsePage> > > qworkers_;
|
||||||
|
};
|
||||||
|
#endif // DMLC_ENABLE_STD_THREAD
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Registry entry for sparse page format.
|
* \brief Registry entry for sparse page format.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -12,34 +12,42 @@
|
|||||||
#if DMLC_ENABLE_STD_THREAD
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
#include "./sparse_page_dmatrix.h"
|
#include "./sparse_page_dmatrix.h"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
|
#include "../common/common.h"
|
||||||
#include "../common/group_data.h"
|
#include "../common/group_data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
SparsePageDMatrix::ColPageIter::ColPageIter(std::unique_ptr<dmlc::SeekStream>&& fi)
|
SparsePageDMatrix::ColPageIter::ColPageIter(
|
||||||
: fi_(std::move(fi)), page_(nullptr) {
|
std::vector<std::unique_ptr<dmlc::SeekStream> >&& files)
|
||||||
|
: page_(nullptr), clock_ptr_(0), files_(std::move(files)) {
|
||||||
load_all_ = false;
|
load_all_ = false;
|
||||||
|
formats_.resize(files_.size());
|
||||||
|
prefetchers_.resize(files_.size());
|
||||||
|
|
||||||
std::string format;
|
for (size_t i = 0; i < files_.size(); ++i) {
|
||||||
CHECK(fi_->Read(&format)) << "Invalid page format";
|
dmlc::SeekStream* fi = files_[i].get();
|
||||||
format_.reset(SparsePage::Format::Create(format));
|
std::string format;
|
||||||
size_t fbegin = fi_->Tell();
|
CHECK(fi->Read(&format)) << "Invalid page format";
|
||||||
|
formats_[i].reset(SparsePage::Format::Create(format));
|
||||||
prefetcher_.Init([this](SparsePage** dptr) {
|
SparsePage::Format* fmt = formats_[i].get();
|
||||||
if (*dptr == nullptr) {
|
size_t fbegin = fi->Tell();
|
||||||
*dptr = new SparsePage();
|
prefetchers_[i].reset(new dmlc::ThreadedIter<SparsePage>(4));
|
||||||
}
|
prefetchers_[i]->Init([this, fi, fmt] (SparsePage** dptr) {
|
||||||
if (load_all_) {
|
if (*dptr == nullptr) {
|
||||||
return format_->Read(*dptr, fi_.get());
|
*dptr = new SparsePage();
|
||||||
} else {
|
}
|
||||||
return format_->Read(*dptr, fi_.get(), index_set_);
|
if (load_all_) {
|
||||||
}
|
return fmt->Read(*dptr, fi);
|
||||||
}, [this, fbegin] () {
|
} else {
|
||||||
fi_->Seek(fbegin);
|
return fmt->Read(*dptr, fi, index_set_);
|
||||||
index_set_ = set_index_set_;
|
}
|
||||||
load_all_ = set_load_all_;
|
}, [this, fi, fbegin] () {
|
||||||
});
|
fi->Seek(fbegin);
|
||||||
|
index_set_ = set_index_set_;
|
||||||
|
load_all_ = set_load_all_;
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SparsePageDMatrix::ColPageIter::~ColPageIter() {
|
SparsePageDMatrix::ColPageIter::~ColPageIter() {
|
||||||
@ -47,10 +55,12 @@ SparsePageDMatrix::ColPageIter::~ColPageIter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool SparsePageDMatrix::ColPageIter::Next() {
|
bool SparsePageDMatrix::ColPageIter::Next() {
|
||||||
|
// doing clock rotation over shards.
|
||||||
if (page_ != nullptr) {
|
if (page_ != nullptr) {
|
||||||
prefetcher_.Recycle(&page_);
|
size_t n = prefetchers_.size();
|
||||||
|
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
|
||||||
}
|
}
|
||||||
if (prefetcher_.Next(&page_)) {
|
if (prefetchers_[clock_ptr_]->Next(&page_)) {
|
||||||
out_.col_index = dmlc::BeginPtr(index_set_);
|
out_.col_index = dmlc::BeginPtr(index_set_);
|
||||||
col_data_.resize(page_->offset.size() - 1, SparseBatch::Inst(nullptr, 0));
|
col_data_.resize(page_->offset.size() - 1, SparseBatch::Inst(nullptr, 0));
|
||||||
for (size_t i = 0; i < col_data_.size(); ++i) {
|
for (size_t i = 0; i < col_data_.size(); ++i) {
|
||||||
@ -60,18 +70,26 @@ bool SparsePageDMatrix::ColPageIter::Next() {
|
|||||||
}
|
}
|
||||||
out_.col_data = dmlc::BeginPtr(col_data_);
|
out_.col_data = dmlc::BeginPtr(col_data_);
|
||||||
out_.size = col_data_.size();
|
out_.size = col_data_.size();
|
||||||
|
// advance clock
|
||||||
|
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SparsePageDMatrix::ColPageIter::BeforeFirst() {
|
||||||
|
clock_ptr_ = 0;
|
||||||
|
for (auto& p : prefetchers_) {
|
||||||
|
p->BeforeFirst();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SparsePageDMatrix::ColPageIter::Init(const std::vector<bst_uint>& index_set,
|
void SparsePageDMatrix::ColPageIter::Init(const std::vector<bst_uint>& index_set,
|
||||||
bool load_all) {
|
bool load_all) {
|
||||||
set_index_set_ = index_set;
|
set_index_set_ = index_set;
|
||||||
set_load_all_ = load_all;
|
set_load_all_ = load_all;
|
||||||
std::sort(set_index_set_.begin(), set_index_set_.end());
|
std::sort(set_index_set_.begin(), set_index_set_.end());
|
||||||
|
|
||||||
this->BeforeFirst();
|
this->BeforeFirst();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,8 +121,9 @@ ColIterator(const std::vector<bst_uint>& fset) {
|
|||||||
|
|
||||||
bool SparsePageDMatrix::TryInitColData() {
|
bool SparsePageDMatrix::TryInitColData() {
|
||||||
// load meta data.
|
// load meta data.
|
||||||
|
std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
|
||||||
{
|
{
|
||||||
std::string col_meta_name = cache_prefix_ + ".col.meta";
|
std::string col_meta_name = cache_shards[0] + ".col.meta";
|
||||||
std::unique_ptr<dmlc::Stream> fmeta(
|
std::unique_ptr<dmlc::Stream> fmeta(
|
||||||
dmlc::Stream::Create(col_meta_name.c_str(), "r", true));
|
dmlc::Stream::Create(col_meta_name.c_str(), "r", true));
|
||||||
if (fmeta.get() == nullptr) return false;
|
if (fmeta.get() == nullptr) return false;
|
||||||
@ -112,13 +131,15 @@ bool SparsePageDMatrix::TryInitColData() {
|
|||||||
CHECK(fmeta->Read(&col_size_)) << "invalid col.meta file";
|
CHECK(fmeta->Read(&col_size_)) << "invalid col.meta file";
|
||||||
}
|
}
|
||||||
// load real data
|
// load real data
|
||||||
{
|
std::vector<std::unique_ptr<dmlc::SeekStream> > files;
|
||||||
std::string col_data_name = cache_prefix_ + ".col.page";
|
for (const std::string& prefix : cache_shards) {
|
||||||
|
std::string col_data_name = prefix + ".col.page";
|
||||||
std::unique_ptr<dmlc::SeekStream> fdata(
|
std::unique_ptr<dmlc::SeekStream> fdata(
|
||||||
dmlc::SeekStream::CreateForRead(col_data_name.c_str(), true));
|
dmlc::SeekStream::CreateForRead(col_data_name.c_str(), true));
|
||||||
if (fdata.get() == nullptr) return false;
|
if (fdata.get() == nullptr) return false;
|
||||||
col_iter_.reset(new ColPageIter(std::move(fdata)));
|
files.push_back(std::move(fdata));
|
||||||
}
|
}
|
||||||
|
col_iter_.reset(new ColPageIter(std::move(files)));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,26 +156,19 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
|||||||
buffered_rowset_.clear();
|
buffered_rowset_.clear();
|
||||||
col_size_.resize(info.num_col);
|
col_size_.resize(info.num_col);
|
||||||
std::fill(col_size_.begin(), col_size_.end(), 0);
|
std::fill(col_size_.begin(), col_size_.end(), 0);
|
||||||
// make the sparse page.
|
|
||||||
dmlc::ThreadedIter<SparsePage> cmaker;
|
|
||||||
SparsePage tmp;
|
|
||||||
size_t batch_ptr = 0, batch_top = 0;
|
|
||||||
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
|
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
|
||||||
std::bernoulli_distribution coin_flip(pkeep);
|
std::bernoulli_distribution coin_flip(pkeep);
|
||||||
|
size_t batch_ptr = 0, batch_top = 0;
|
||||||
|
SparsePage tmp;
|
||||||
auto& rnd = common::GlobalRandom();
|
auto& rnd = common::GlobalRandom();
|
||||||
|
|
||||||
// function to create the page.
|
// function to create the page.
|
||||||
auto make_col_batch = [&] (
|
auto make_col_batch = [&] (
|
||||||
const SparsePage& prow,
|
const SparsePage& prow,
|
||||||
const bst_uint* ridx,
|
size_t begin,
|
||||||
SparsePage **dptr) {
|
SparsePage *pcol) {
|
||||||
if (*dptr == nullptr) {
|
|
||||||
*dptr = new SparsePage();
|
|
||||||
}
|
|
||||||
SparsePage* pcol = *dptr;
|
|
||||||
pcol->Clear();
|
pcol->Clear();
|
||||||
pcol->min_index = ridx[0];
|
pcol->min_index = buffered_rowset_[begin];
|
||||||
int nthread;
|
int nthread;
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
@ -182,7 +196,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
|||||||
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
|
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
|
||||||
const SparseBatch::Entry &e = prow.data[j];
|
const SparseBatch::Entry &e = prow.data[j];
|
||||||
builder.Push(e.index,
|
builder.Push(e.index,
|
||||||
SparseBatch::Entry(ridx[i], e.fvalue),
|
SparseBatch::Entry(buffered_rowset_[i + begin], e.fvalue),
|
||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -199,7 +213,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto make_next_col = [&] (SparsePage** dptr) {
|
auto make_next_col = [&] (SparsePage* dptr) {
|
||||||
tmp.Clear();
|
tmp.Clear();
|
||||||
size_t btop = buffered_rowset_.size();
|
size_t btop = buffered_rowset_.size();
|
||||||
|
|
||||||
@ -216,7 +230,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
|||||||
|
|
||||||
if (tmp.Size() >= max_row_perbatch ||
|
if (tmp.Size() >= max_row_perbatch ||
|
||||||
tmp.MemCostBytes() >= kPageSize) {
|
tmp.MemCostBytes() >= kPageSize) {
|
||||||
make_col_batch(tmp, dmlc::BeginPtr(buffered_rowset_) + btop, dptr);
|
make_col_batch(tmp, btop, dptr);
|
||||||
batch_ptr = i + 1;
|
batch_ptr = i + 1;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -229,48 +243,51 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (tmp.Size() != 0) {
|
if (tmp.Size() != 0) {
|
||||||
make_col_batch(tmp, dmlc::BeginPtr(buffered_rowset_) + btop, dptr);
|
make_col_batch(tmp, btop, dptr);
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
cmaker.Init(make_next_col, []() {});
|
std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
|
||||||
|
std::vector<std::string> name_shards, format_shards;
|
||||||
std::string col_data_name = cache_prefix_ + ".col.page";
|
for (const std::string& prefix : cache_shards) {
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(col_data_name.c_str(), "w"));
|
name_shards.push_back(prefix + ".col.page");
|
||||||
// find format.
|
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).second);
|
||||||
std::string name_format = SparsePage::Format::DecideFormat(cache_prefix_).second;
|
}
|
||||||
fo->Write(name_format);
|
SparsePage::Writer writer(name_shards, format_shards, 6);
|
||||||
std::unique_ptr<SparsePage::Format> format(SparsePage::Format::Create(name_format));
|
std::unique_ptr<SparsePage> page;
|
||||||
|
writer.Alloc(&page); page->Clear();
|
||||||
|
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
size_t bytes_write = 0;
|
size_t bytes_write = 0;
|
||||||
// print every 4 sec.
|
// print every 4 sec.
|
||||||
const double kStep = 4.0;
|
const double kStep = 4.0;
|
||||||
size_t tick_expected = kStep;
|
size_t tick_expected = kStep;
|
||||||
SparsePage* pcol = nullptr;
|
|
||||||
|
|
||||||
while (cmaker.Next(&pcol)) {
|
while (make_next_col(page.get())) {
|
||||||
for (size_t i = 0; i < pcol->Size(); ++i) {
|
for (size_t i = 0; i < page->Size(); ++i) {
|
||||||
col_size_[i] += pcol->offset[i + 1] - pcol->offset[i];
|
col_size_[i] += page->offset[i + 1] - page->offset[i];
|
||||||
}
|
}
|
||||||
format->Write(*pcol, fo.get());
|
|
||||||
size_t spage = pcol->MemCostBytes();
|
bytes_write += page->MemCostBytes();
|
||||||
bytes_write += spage;
|
writer.PushWrite(std::move(page));
|
||||||
|
writer.Alloc(&page);
|
||||||
|
page->Clear();
|
||||||
|
|
||||||
double tdiff = dmlc::GetTime() - tstart;
|
double tdiff = dmlc::GetTime() - tstart;
|
||||||
if (tdiff >= tick_expected) {
|
if (tdiff >= tick_expected) {
|
||||||
LOG(CONSOLE) << "Writing to " << col_data_name
|
LOG(CONSOLE) << "Writing col.page file to " << cache_info_
|
||||||
<< " in " << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
<< " in " << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
||||||
<< (bytes_write >> 20UL) << " MB writen";
|
<< (bytes_write >> 20UL) << " MB writen";
|
||||||
tick_expected += kStep;
|
tick_expected += kStep;
|
||||||
}
|
}
|
||||||
cmaker.Recycle(&pcol);
|
|
||||||
}
|
}
|
||||||
// save meta data
|
// save meta data
|
||||||
std::string col_meta_name = cache_prefix_ + ".col.meta";
|
std::string col_meta_name = cache_shards[0] + ".col.meta";
|
||||||
fo.reset(dmlc::Stream::Create(col_meta_name.c_str(), "w"));
|
std::unique_ptr<dmlc::Stream> fo(
|
||||||
|
dmlc::Stream::Create(col_meta_name.c_str(), "w"));
|
||||||
fo->Write(buffered_rowset_);
|
fo->Write(buffered_rowset_);
|
||||||
fo->Write(col_size_);
|
fo->Write(col_size_);
|
||||||
fo.reset(nullptr);
|
fo.reset(nullptr);
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "./sparse_batch_page.h"
|
#include "./sparse_batch_page.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -21,9 +22,9 @@ namespace data {
|
|||||||
class SparsePageDMatrix : public DMatrix {
|
class SparsePageDMatrix : public DMatrix {
|
||||||
public:
|
public:
|
||||||
explicit SparsePageDMatrix(std::unique_ptr<DataSource>&& source,
|
explicit SparsePageDMatrix(std::unique_ptr<DataSource>&& source,
|
||||||
const std::string& cache_prefix)
|
const std::string& cache_info)
|
||||||
: source_(std::move(source)),
|
: source_(std::move(source)), cache_info_(cache_info) {
|
||||||
cache_prefix_(cache_prefix) {}
|
}
|
||||||
|
|
||||||
MetaInfo& info() override {
|
MetaInfo& info() override {
|
||||||
return source_->info;
|
return source_->info;
|
||||||
@ -43,7 +44,7 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
return col_iter_.get() != nullptr;
|
return col_iter_.get() != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<bst_uint>& buffered_rowset() const override {
|
const RowSet& buffered_rowset() const override {
|
||||||
return buffered_rowset_;
|
return buffered_rowset_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,11 +78,9 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
// declare the column batch iter.
|
// declare the column batch iter.
|
||||||
class ColPageIter : public dmlc::DataIter<ColBatch> {
|
class ColPageIter : public dmlc::DataIter<ColBatch> {
|
||||||
public:
|
public:
|
||||||
explicit ColPageIter(std::unique_ptr<dmlc::SeekStream>&& fi);
|
explicit ColPageIter(std::vector<std::unique_ptr<dmlc::SeekStream> >&& files);
|
||||||
virtual ~ColPageIter();
|
virtual ~ColPageIter();
|
||||||
void BeforeFirst() override {
|
void BeforeFirst() override;
|
||||||
prefetcher_.BeforeFirst();
|
|
||||||
}
|
|
||||||
const ColBatch &Value() const override {
|
const ColBatch &Value() const override {
|
||||||
return out_;
|
return out_;
|
||||||
}
|
}
|
||||||
@ -90,20 +89,22 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
void Init(const std::vector<bst_uint>& index_set, bool load_all);
|
void Init(const std::vector<bst_uint>& index_set, bool load_all);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// data file pointer.
|
|
||||||
std::unique_ptr<dmlc::SeekStream> fi_;
|
|
||||||
// the temp page.
|
// the temp page.
|
||||||
SparsePage* page_;
|
SparsePage* page_;
|
||||||
|
// internal clock ptr.
|
||||||
|
size_t clock_ptr_;
|
||||||
|
// data file pointer.
|
||||||
|
std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
|
||||||
// page format.
|
// page format.
|
||||||
std::unique_ptr<SparsePage::Format> format_;
|
std::vector<std::unique_ptr<SparsePage::Format> > formats_;
|
||||||
|
/*! \brief internal prefetcher. */
|
||||||
|
std::vector<std::unique_ptr<dmlc::ThreadedIter<SparsePage> > > prefetchers_;
|
||||||
// The index set to be loaded.
|
// The index set to be loaded.
|
||||||
std::vector<bst_uint> index_set_;
|
std::vector<bst_uint> index_set_;
|
||||||
// The index set by the outsiders
|
// The index set by the outsiders
|
||||||
std::vector<bst_uint> set_index_set_;
|
std::vector<bst_uint> set_index_set_;
|
||||||
// whether to load data dataset.
|
// whether to load data dataset.
|
||||||
bool set_load_all_, load_all_;
|
bool set_load_all_, load_all_;
|
||||||
// data prefetcher.
|
|
||||||
dmlc::ThreadedIter<SparsePage> prefetcher_;
|
|
||||||
// temporal space for batch
|
// temporal space for batch
|
||||||
ColBatch out_;
|
ColBatch out_;
|
||||||
// the pointer data.
|
// the pointer data.
|
||||||
@ -117,9 +118,9 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
// source data pointer.
|
// source data pointer.
|
||||||
std::unique_ptr<DataSource> source_;
|
std::unique_ptr<DataSource> source_;
|
||||||
// the cache prefix
|
// the cache prefix
|
||||||
std::string cache_prefix_;
|
std::string cache_info_;
|
||||||
/*! \brief list of row index that are buffered */
|
/*! \brief list of row index that are buffered */
|
||||||
std::vector<bst_uint> buffered_rowset_;
|
RowSet buffered_rowset_;
|
||||||
// count for column data
|
// count for column data
|
||||||
std::vector<size_t> col_size_;
|
std::vector<size_t> col_size_;
|
||||||
// internal column iter.
|
// internal column iter.
|
||||||
|
|||||||
@ -9,35 +9,45 @@
|
|||||||
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
#include "./sparse_page_source.h"
|
#include "./sparse_page_source.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
SparsePageSource::SparsePageSource(const std::string& cache_prefix)
|
SparsePageSource::SparsePageSource(const std::string& cache_info)
|
||||||
: base_rowid_(0), page_(nullptr) {
|
: base_rowid_(0), page_(nullptr), clock_ptr_(0) {
|
||||||
// read in the info files.
|
// read in the info files
|
||||||
|
std::vector<std::string> cache_shards = common::Split(cache_info, ':');
|
||||||
|
CHECK_NE(cache_shards.size(), 0);
|
||||||
{
|
{
|
||||||
std::string name_info = cache_prefix;
|
std::string name_info = cache_shards[0];
|
||||||
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
|
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
|
||||||
int tmagic;
|
int tmagic;
|
||||||
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
|
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
|
||||||
this->info.LoadBinary(finfo.get());
|
this->info.LoadBinary(finfo.get());
|
||||||
}
|
}
|
||||||
|
files_.resize(cache_shards.size());
|
||||||
|
formats_.resize(cache_shards.size());
|
||||||
|
prefetchers_.resize(cache_shards.size());
|
||||||
|
|
||||||
// read in the cache files.
|
// read in the cache files.
|
||||||
std::string name_row = cache_prefix + ".row.page";
|
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
||||||
fi_.reset(dmlc::SeekStream::CreateForRead(name_row.c_str()));
|
std::string name_row = cache_shards[i] + ".row.page";
|
||||||
|
files_[i].reset(dmlc::SeekStream::CreateForRead(name_row.c_str()));
|
||||||
std::string format;
|
dmlc::SeekStream* fi = files_[i].get();
|
||||||
CHECK(fi_->Read(&format)) << "Invalid page format";
|
std::string format;
|
||||||
format_.reset(SparsePage::Format::Create(format));
|
CHECK(fi->Read(&format)) << "Invalid page format";
|
||||||
size_t fbegin = fi_->Tell();
|
formats_[i].reset(SparsePage::Format::Create(format));
|
||||||
|
SparsePage::Format* fmt = formats_[i].get();
|
||||||
prefetcher_.Init([this] (SparsePage** dptr) {
|
size_t fbegin = fi->Tell();
|
||||||
if (*dptr == nullptr) {
|
prefetchers_[i].reset(new dmlc::ThreadedIter<SparsePage>(4));
|
||||||
*dptr = new SparsePage();
|
prefetchers_[i]->Init([fi, fmt] (SparsePage** dptr) {
|
||||||
}
|
if (*dptr == nullptr) {
|
||||||
return format_->Read(*dptr, fi_.get());
|
*dptr = new SparsePage();
|
||||||
}, [this, fbegin] () { fi_->Seek(fbegin); });
|
}
|
||||||
|
return fmt->Read(*dptr, fi);
|
||||||
|
}, [fi, fbegin] () { fi->Seek(fbegin); });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SparsePageSource::~SparsePageSource() {
|
SparsePageSource::~SparsePageSource() {
|
||||||
@ -45,12 +55,16 @@ SparsePageSource::~SparsePageSource() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool SparsePageSource::Next() {
|
bool SparsePageSource::Next() {
|
||||||
|
// doing clock rotation over shards.
|
||||||
if (page_ != nullptr) {
|
if (page_ != nullptr) {
|
||||||
prefetcher_.Recycle(&page_);
|
size_t n = prefetchers_.size();
|
||||||
|
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
|
||||||
}
|
}
|
||||||
if (prefetcher_.Next(&page_)) {
|
if (prefetchers_[clock_ptr_]->Next(&page_)) {
|
||||||
batch_ = page_->GetRowBatch(base_rowid_);
|
batch_ = page_->GetRowBatch(base_rowid_);
|
||||||
base_rowid_ += batch_.size;
|
base_rowid_ += batch_.size;
|
||||||
|
// advance clock
|
||||||
|
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
@ -59,33 +73,48 @@ bool SparsePageSource::Next() {
|
|||||||
|
|
||||||
void SparsePageSource::BeforeFirst() {
|
void SparsePageSource::BeforeFirst() {
|
||||||
base_rowid_ = 0;
|
base_rowid_ = 0;
|
||||||
prefetcher_.BeforeFirst();
|
clock_ptr_ = 0;
|
||||||
|
for (auto& p : prefetchers_) {
|
||||||
|
p->BeforeFirst();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const RowBatch& SparsePageSource::Value() const {
|
const RowBatch& SparsePageSource::Value() const {
|
||||||
return batch_;
|
return batch_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SparsePageSource::CacheExist(const std::string& cache_prefix) {
|
bool SparsePageSource::CacheExist(const std::string& cache_info) {
|
||||||
std::string name_info = cache_prefix;
|
std::vector<std::string> cache_shards = common::Split(cache_info, ':');
|
||||||
std::string name_row = cache_prefix + ".row.page";
|
CHECK_NE(cache_shards.size(), 0);
|
||||||
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r", true));
|
{
|
||||||
std::unique_ptr<dmlc::Stream> frow(dmlc::Stream::Create(name_row.c_str(), "r", true));
|
std::string name_info = cache_shards[0];
|
||||||
return finfo.get() != nullptr && frow.get() != nullptr;
|
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r", true));
|
||||||
|
if (finfo.get() == nullptr) return false;
|
||||||
|
}
|
||||||
|
for (const std::string& prefix : cache_shards) {
|
||||||
|
std::string name_row = prefix + ".row.page";
|
||||||
|
std::unique_ptr<dmlc::Stream> frow(dmlc::Stream::Create(name_row.c_str(), "r", true));
|
||||||
|
if (frow.get() == nullptr) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
||||||
const std::string& cache_prefix) {
|
const std::string& cache_info) {
|
||||||
|
std::vector<std::string> cache_shards = common::Split(cache_info, ':');
|
||||||
|
CHECK_NE(cache_shards.size(), 0);
|
||||||
// read in the info files.
|
// read in the info files.
|
||||||
std::string name_info = cache_prefix;
|
std::string name_info = cache_shards[0];
|
||||||
std::string name_row = cache_prefix + ".row.page";
|
std::vector<std::string> name_shards, format_shards;
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(name_row.c_str(), "w"));
|
for (const std::string& prefix : cache_shards) {
|
||||||
std::string name_format = SparsePage::Format::DecideFormat(cache_prefix).first;
|
name_shards.push_back(prefix + ".row.page");
|
||||||
fo->Write(name_format);
|
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
|
||||||
std::unique_ptr<SparsePage::Format> format(SparsePage::Format::Create(name_format));
|
}
|
||||||
|
SparsePage::Writer writer(name_shards, format_shards, 6);
|
||||||
|
std::unique_ptr<SparsePage> page;
|
||||||
|
writer.Alloc(&page); page->Clear();
|
||||||
|
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
SparsePage page;
|
|
||||||
size_t bytes_write = 0;
|
size_t bytes_write = 0;
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
// print every 4 sec.
|
// print every 4 sec.
|
||||||
@ -107,14 +136,16 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
|||||||
info.num_col = std::max(info.num_col,
|
info.num_col = std::max(info.num_col,
|
||||||
static_cast<uint64_t>(index + 1));
|
static_cast<uint64_t>(index + 1));
|
||||||
}
|
}
|
||||||
page.Push(batch);
|
page->Push(batch);
|
||||||
if (page.MemCostBytes() >= kPageSize) {
|
if (page->MemCostBytes() >= kPageSize) {
|
||||||
bytes_write += page.MemCostBytes();
|
bytes_write += page->MemCostBytes();
|
||||||
format->Write(page, fo.get());
|
writer.PushWrite(std::move(page));
|
||||||
page.Clear();
|
writer.Alloc(&page);
|
||||||
|
page->Clear();
|
||||||
|
|
||||||
double tdiff = dmlc::GetTime() - tstart;
|
double tdiff = dmlc::GetTime() - tstart;
|
||||||
if (tdiff >= tick_expected) {
|
if (tdiff >= tick_expected) {
|
||||||
LOG(CONSOLE) << "Writing to " << name_row << " in "
|
LOG(CONSOLE) << "Writing row.page to " << cache_info << " in "
|
||||||
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
||||||
<< (bytes_write >> 20UL) << " written";
|
<< (bytes_write >> 20UL) << " written";
|
||||||
tick_expected += kStep;
|
tick_expected += kStep;
|
||||||
@ -122,57 +153,62 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (page.data.size() != 0) {
|
if (page->data.size() != 0) {
|
||||||
format->Write(page, fo.get());
|
writer.PushWrite(std::move(page));
|
||||||
}
|
}
|
||||||
|
|
||||||
fo.reset(dmlc::Stream::Create(name_info.c_str(), "w"));
|
std::unique_ptr<dmlc::Stream> fo(
|
||||||
|
dmlc::Stream::Create(name_info.c_str(), "w"));
|
||||||
int tmagic = kMagic;
|
int tmagic = kMagic;
|
||||||
fo->Write(&tmagic, sizeof(tmagic));
|
fo->Write(&tmagic, sizeof(tmagic));
|
||||||
info.SaveBinary(fo.get());
|
info.SaveBinary(fo.get());
|
||||||
|
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
|
||||||
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << cache_prefix;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SparsePageSource::Create(DMatrix* src,
|
void SparsePageSource::Create(DMatrix* src,
|
||||||
const std::string& cache_prefix) {
|
const std::string& cache_info) {
|
||||||
|
std::vector<std::string> cache_shards = common::Split(cache_info, ':');
|
||||||
|
CHECK_NE(cache_shards.size(), 0);
|
||||||
// read in the info files.
|
// read in the info files.
|
||||||
std::string name_info = cache_prefix;
|
std::string name_info = cache_shards[0];
|
||||||
std::string name_row = cache_prefix + ".row.page";
|
std::vector<std::string> name_shards, format_shards;
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(name_row.c_str(), "w"));
|
for (const std::string& prefix : cache_shards) {
|
||||||
// find format.
|
name_shards.push_back(prefix + ".row.page");
|
||||||
std::string name_format = SparsePage::Format::DecideFormat(cache_prefix).first;
|
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
|
||||||
fo->Write(name_format);
|
}
|
||||||
std::unique_ptr<SparsePage::Format> format(SparsePage::Format::Create(name_format));
|
SparsePage::Writer writer(name_shards, format_shards, 6);
|
||||||
|
std::unique_ptr<SparsePage> page;
|
||||||
|
writer.Alloc(&page); page->Clear();
|
||||||
|
|
||||||
SparsePage page;
|
MetaInfo info;
|
||||||
size_t bytes_write = 0;
|
size_t bytes_write = 0;
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
|
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
|
||||||
|
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
page.Push(iter->Value());
|
page->Push(iter->Value());
|
||||||
if (page.MemCostBytes() >= kPageSize) {
|
if (page->MemCostBytes() >= kPageSize) {
|
||||||
bytes_write += page.MemCostBytes();
|
bytes_write += page->MemCostBytes();
|
||||||
format->Write(page, fo.get());
|
writer.PushWrite(std::move(page));
|
||||||
page.Clear();
|
writer.Alloc(&page);
|
||||||
|
page->Clear();
|
||||||
double tdiff = dmlc::GetTime() - tstart;
|
double tdiff = dmlc::GetTime() - tstart;
|
||||||
LOG(CONSOLE) << "Writing to " << name_row << " in "
|
LOG(CONSOLE) << "Writing to " << cache_info << " in "
|
||||||
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
||||||
<< (bytes_write >> 20UL) << " written";
|
<< (bytes_write >> 20UL) << " written";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (page.data.size() != 0) {
|
if (page->data.size() != 0) {
|
||||||
format->Write(page, fo.get());
|
writer.PushWrite(std::move(page));
|
||||||
}
|
}
|
||||||
|
|
||||||
fo.reset(dmlc::Stream::Create(name_info.c_str(), "w"));
|
std::unique_ptr<dmlc::Stream> fo(
|
||||||
|
dmlc::Stream::Create(name_info.c_str(), "w"));
|
||||||
int tmagic = kMagic;
|
int tmagic = kMagic;
|
||||||
fo->Write(&tmagic, sizeof(tmagic));
|
fo->Write(&tmagic, sizeof(tmagic));
|
||||||
src->info().SaveBinary(fo.get());
|
info.SaveBinary(fo.get());
|
||||||
|
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
|
||||||
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << cache_prefix;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
|||||||
@ -71,14 +71,14 @@ class SparsePageSource : public DataSource {
|
|||||||
RowBatch batch_;
|
RowBatch batch_;
|
||||||
/*! \brief page currently on hold. */
|
/*! \brief page currently on hold. */
|
||||||
SparsePage *page_;
|
SparsePage *page_;
|
||||||
/*! \brief The cache predix of the dataset. */
|
/*! \brief internal clock ptr */
|
||||||
std::string cache_prefix_;
|
size_t clock_ptr_;
|
||||||
/*! \brief file pointer to the row blob file. */
|
/*! \brief file pointer to the row blob file. */
|
||||||
std::unique_ptr<dmlc::SeekStream> fi_;
|
std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
|
||||||
/*! \brief Sparse page format file. */
|
/*! \brief Sparse page format file. */
|
||||||
std::unique_ptr<SparsePage::Format> format_;
|
std::vector<std::unique_ptr<SparsePage::Format> > formats_;
|
||||||
/*! \brief internal prefetcher. */
|
/*! \brief internal prefetcher. */
|
||||||
dmlc::ThreadedIter<SparsePage> prefetcher_;
|
std::vector<std::unique_ptr<dmlc::ThreadedIter<SparsePage> > > prefetchers_;
|
||||||
};
|
};
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
72
src/data/sparse_page_writer.cc
Normal file
72
src/data/sparse_page_writer.cc
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright (c) 2015 by Contributors
|
||||||
|
* \file sparse_page_writer.cc
|
||||||
|
* \brief Writer class for sparse pages.
|
||||||
|
*/
|
||||||
|
#include <xgboost/base.h>
|
||||||
|
#include <xgboost/logging.h>
|
||||||
|
#include "./sparse_batch_page.h"
|
||||||
|
|
||||||
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
|
namespace xgboost {
|
||||||
|
namespace data {
|
||||||
|
|
||||||
|
SparsePage::Writer::Writer(
|
||||||
|
const std::vector<std::string>& name_shards,
|
||||||
|
const std::vector<std::string>& format_shards,
|
||||||
|
size_t extra_buffer_capacity)
|
||||||
|
: num_free_buffer_(extra_buffer_capacity + name_shards.size()),
|
||||||
|
clock_ptr_(0),
|
||||||
|
workers_(name_shards.size()),
|
||||||
|
qworkers_(name_shards.size()) {
|
||||||
|
CHECK_EQ(name_shards.size(), format_shards.size());
|
||||||
|
// start writer threads
|
||||||
|
for (size_t i = 0; i < name_shards.size(); ++i) {
|
||||||
|
std::string name_shard = name_shards[i];
|
||||||
|
std::string format_shard = format_shards[i];
|
||||||
|
auto* wqueue = &qworkers_[i];
|
||||||
|
workers_[i].reset(new std::thread(
|
||||||
|
[this, name_shard, format_shard, wqueue] () {
|
||||||
|
std::unique_ptr<dmlc::Stream> fo(
|
||||||
|
dmlc::Stream::Create(name_shard.c_str(), "w"));
|
||||||
|
std::unique_ptr<SparsePage::Format> fmt(
|
||||||
|
SparsePage::Format::Create(format_shard));
|
||||||
|
fo->Write(format_shard);
|
||||||
|
std::unique_ptr<SparsePage> page;
|
||||||
|
while (wqueue->Pop(&page)) {
|
||||||
|
fmt->Write(*page, fo.get());
|
||||||
|
qrecycle_.Push(std::move(page));
|
||||||
|
}
|
||||||
|
fo.reset(nullptr);
|
||||||
|
LOG(CONSOLE) << "SparsePage::Writer Finished writing to " << name_shard;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SparsePage::Writer::~Writer() {
|
||||||
|
for (auto& queue : qworkers_) {
|
||||||
|
queue.SignalForKill();
|
||||||
|
}
|
||||||
|
for (auto& thread : workers_) {
|
||||||
|
thread->join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SparsePage::Writer::PushWrite(std::unique_ptr<SparsePage>&& page) {
|
||||||
|
qworkers_[clock_ptr_].Push(std::move(page));
|
||||||
|
clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void SparsePage::Writer::Alloc(std::unique_ptr<SparsePage>* out_page) {
|
||||||
|
CHECK(out_page->get() == nullptr);
|
||||||
|
if (num_free_buffer_ != 0) {
|
||||||
|
out_page->reset(new SparsePage());
|
||||||
|
--num_free_buffer_;
|
||||||
|
} else {
|
||||||
|
CHECK(qrecycle_.Pop(out_page));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace data
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#endif // DMLC_ENABLE_STD_THREAD
|
||||||
@ -109,7 +109,7 @@ class GBLinear : public GradientBooster {
|
|||||||
|
|
||||||
std::vector<bst_gpair> &gpair = *in_gpair;
|
std::vector<bst_gpair> &gpair = *in_gpair;
|
||||||
const int ngroup = model.param.num_output_group;
|
const int ngroup = model.param.num_output_group;
|
||||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||||
// for all the output group
|
// for all the output group
|
||||||
for (int gid = 0; gid < ngroup; ++gid) {
|
for (int gid = 0; gid < ngroup; ++gid) {
|
||||||
double sum_grad = 0.0, sum_hess = 0.0;
|
double sum_grad = 0.0, sum_hess = 0.0;
|
||||||
|
|||||||
@ -325,7 +325,7 @@ class GBTree : public GradientBooster {
|
|||||||
int bst_group,
|
int bst_group,
|
||||||
const RegTree &new_tree,
|
const RegTree &new_tree,
|
||||||
const int* leaf_position) {
|
const int* leaf_position) {
|
||||||
const std::vector<bst_uint>& rowset = p_fmat->buffered_rowset();
|
const RowSet& rowset = p_fmat->buffered_rowset();
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
|
|||||||
@ -31,7 +31,7 @@ struct EvalMClassBase : public Metric {
|
|||||||
<< "mlogloss and merror are only used for multi-class classification,"
|
<< "mlogloss and merror are only used for multi-class classification,"
|
||||||
<< " use logloss for binary classification";
|
<< " use logloss for binary classification";
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||||
float sum = 0.0, wsum = 0.0;
|
double sum = 0.0, wsum = 0.0;
|
||||||
int label_error = 0;
|
int label_error = 0;
|
||||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
@ -50,7 +50,7 @@ struct EvalMClassBase : public Metric {
|
|||||||
<< "MultiClassEvaluation: label must be in [0, num_class),"
|
<< "MultiClassEvaluation: label must be in [0, num_class),"
|
||||||
<< " num_class=" << nclass << " but found " << label_error << " in label";
|
<< " num_class=" << nclass << " but found " << label_error << " in label";
|
||||||
|
|
||||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
double dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||||
if (distributed) {
|
if (distributed) {
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,7 +87,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
.describe("Subsample ratio of columns, resample on each tree construction.");
|
.describe("Subsample ratio of columns, resample on each tree construction.");
|
||||||
DMLC_DECLARE_FIELD(opt_dense_col).set_range(0.0f, 1.0f).set_default(1.0f)
|
DMLC_DECLARE_FIELD(opt_dense_col).set_range(0.0f, 1.0f).set_default(1.0f)
|
||||||
.describe("EXP Param: speed optimization for dense column.");
|
.describe("EXP Param: speed optimization for dense column.");
|
||||||
DMLC_DECLARE_FIELD(sketch_eps).set_range(0.0f, 1.0f).set_default(0.1f)
|
DMLC_DECLARE_FIELD(sketch_eps).set_range(0.0f, 1.0f).set_default(0.03f)
|
||||||
.describe("EXP Param: Sketch accuracy of approximate algorithm.");
|
.describe("EXP Param: Sketch accuracy of approximate algorithm.");
|
||||||
DMLC_DECLARE_FIELD(sketch_ratio).set_lower_bound(0.0f).set_default(2.0f)
|
DMLC_DECLARE_FIELD(sketch_ratio).set_lower_bound(0.0f).set_default(2.0f)
|
||||||
.describe("EXP Param: Sketch accuracy related parameter of approximate algorithm.");
|
.describe("EXP Param: Sketch accuracy related parameter of approximate algorithm.");
|
||||||
|
|||||||
@ -206,8 +206,18 @@ class BaseMaker: public TreeUpdater {
|
|||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
// set the positions in the nondefault
|
// set the positions in the nondefault
|
||||||
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
||||||
|
this->SetDefaultPostion(p_fmat, tree);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief helper function to set the non-leaf positions to default direction.
|
||||||
|
* This function can be applied multiple times and will get the same result.
|
||||||
|
* \param p_fmat feature matrix needed for tree construction
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
inline void SetDefaultPostion(DMatrix *p_fmat,
|
||||||
|
const RegTree &tree) {
|
||||||
// set rest of instances to default position
|
// set rest of instances to default position
|
||||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||||
// set default direct nodes to default
|
// set default direct nodes to default
|
||||||
// for leaf nodes that are not fresh, mark then to ~nid,
|
// for leaf nodes that are not fresh, mark then to ~nid,
|
||||||
// so that they are ignored in future statistics collection
|
// so that they are ignored in future statistics collection
|
||||||
@ -222,7 +232,7 @@ class BaseMaker: public TreeUpdater {
|
|||||||
if (tree[nid].cright() == -1) {
|
if (tree[nid].cright() == -1) {
|
||||||
position[ridx] = ~nid;
|
position[ridx] = ~nid;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// push to default branch
|
// push to default branch
|
||||||
if (tree[nid].default_left()) {
|
if (tree[nid].default_left()) {
|
||||||
this->SetEncodePosition(ridx, tree[nid].cleft());
|
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||||
@ -234,16 +244,55 @@ class BaseMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief this is helper function uses column based data structure,
|
* \brief this is helper function uses column based data structure,
|
||||||
* update all positions into nondefault branch, if any, ignore the default branch
|
* to CORRECT the positions of non-default directions that WERE set to default
|
||||||
* \param nodes the set of nodes that contains the split to be used
|
* before calling this function.
|
||||||
* \param p_fmat feature matrix needed for tree construction
|
* \param batch The column batch
|
||||||
|
* \param sorted_split_set The set of index that contains split solutions.
|
||||||
* \param tree the regression tree structure
|
* \param tree the regression tree structure
|
||||||
*/
|
*/
|
||||||
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
inline void CorrectNonDefaultPositionByBatch(
|
||||||
DMatrix *p_fmat,
|
const ColBatch& batch,
|
||||||
const RegTree &tree) {
|
const std::vector<bst_uint> &sorted_split_set,
|
||||||
|
const RegTree &tree) {
|
||||||
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
|
ColBatch::Inst col = batch[i];
|
||||||
|
const bst_uint fid = batch.col_index[i];
|
||||||
|
auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
|
||||||
|
|
||||||
|
if (it != sorted_split_set.end() && *it == fid) {
|
||||||
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||||
|
const bst_uint ridx = col[j].index;
|
||||||
|
const float fvalue = col[j].fvalue;
|
||||||
|
const int nid = this->DecodePosition(ridx);
|
||||||
|
CHECK(tree[nid].is_leaf());
|
||||||
|
int pid = tree[nid].parent();
|
||||||
|
|
||||||
|
// go back to parent, correct those who are not default
|
||||||
|
if (!tree[nid].is_root() && tree[pid].split_index() == fid) {
|
||||||
|
if (fvalue < tree[pid].split_cond()) {
|
||||||
|
this->SetEncodePosition(ridx, tree[pid].cleft());
|
||||||
|
} else {
|
||||||
|
this->SetEncodePosition(ridx, tree[pid].cright());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief this is helper function uses column based data structure,
|
||||||
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
* \param out_split_set The split index set
|
||||||
|
*/
|
||||||
|
inline void GetSplitSet(const std::vector<int> &nodes,
|
||||||
|
const RegTree &tree,
|
||||||
|
std::vector<unsigned>* out_split_set) {
|
||||||
|
std::vector<unsigned>& fsplits = *out_split_set;
|
||||||
|
fsplits.clear();
|
||||||
// step 1, classify the non-default data into right places
|
// step 1, classify the non-default data into right places
|
||||||
std::vector<unsigned> fsplits;
|
|
||||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||||
const int nid = nodes[i];
|
const int nid = nodes[i];
|
||||||
if (!tree[nid].is_leaf()) {
|
if (!tree[nid].is_leaf()) {
|
||||||
@ -252,7 +301,19 @@ class BaseMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
std::sort(fsplits.begin(), fsplits.end());
|
std::sort(fsplits.begin(), fsplits.end());
|
||||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief this is helper function uses column based data structure,
|
||||||
|
* update all positions into nondefault branch, if any, ignore the default branch
|
||||||
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
|
* \param p_fmat feature matrix needed for tree construction
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
const RegTree &tree) {
|
||||||
|
std::vector<unsigned> fsplits;
|
||||||
|
this->GetSplitSet(nodes, tree, &fsplits);
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch &batch = iter->Value();
|
||||||
@ -297,7 +358,7 @@ class BaseMaker: public TreeUpdater {
|
|||||||
thread_temp[tid][nid].Clear();
|
thread_temp[tid][nid].Clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
const RowSet &rowset = fmat.buffered_rowset();
|
||||||
// setup position
|
// setup position
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
|
|||||||
@ -117,7 +117,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
||||||
<< "ColMaker: can only grow new tree";
|
<< "ColMaker: can only grow new tree";
|
||||||
const std::vector<unsigned>& root_index = fmat.info().root_index;
|
const std::vector<unsigned>& root_index = fmat.info().root_index;
|
||||||
const std::vector<bst_uint>& rowset = fmat.buffered_rowset();
|
const RowSet& rowset = fmat.buffered_rowset();
|
||||||
{
|
{
|
||||||
// setup position
|
// setup position
|
||||||
position.resize(gpair.size());
|
position.resize(gpair.size());
|
||||||
@ -200,7 +200,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
snode.resize(tree.param.num_nodes, NodeEntry(param));
|
snode.resize(tree.param.num_nodes, NodeEntry(param));
|
||||||
}
|
}
|
||||||
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
const RowSet &rowset = fmat.buffered_rowset();
|
||||||
const MetaInfo& info = fmat.info();
|
const MetaInfo& info = fmat.info();
|
||||||
// setup position
|
// setup position
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
@ -291,7 +291,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
ThreadEntry &e = stemp[tid][nid];
|
ThreadEntry &e = stemp[tid][nid];
|
||||||
float fsplit;
|
float fsplit;
|
||||||
if (tid != 0) {
|
if (tid != 0) {
|
||||||
if (std::abs(stemp[tid - 1][nid].last_fvalue - e.first_fvalue) > rt_2eps) {
|
if (stemp[tid - 1][nid].last_fvalue != e.first_fvalue) {
|
||||||
fsplit = (stemp[tid - 1][nid].last_fvalue + e.first_fvalue) * 0.5f;
|
fsplit = (stemp[tid - 1][nid].last_fvalue + e.first_fvalue) * 0.5f;
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
@ -352,7 +352,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
e.first_fvalue = fvalue;
|
e.first_fvalue = fvalue;
|
||||||
} else {
|
} else {
|
||||||
// forward default right
|
// forward default right
|
||||||
if (std::abs(fvalue - e.first_fvalue) > rt_2eps) {
|
if (fvalue != e.first_fvalue) {
|
||||||
if (need_forward) {
|
if (need_forward) {
|
||||||
c.SetSubstract(snode[nid].stats, e.stats);
|
c.SetSubstract(snode[nid].stats, e.stats);
|
||||||
if (c.sum_hess >= param.min_child_weight &&
|
if (c.sum_hess >= param.min_child_weight &&
|
||||||
@ -393,7 +393,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
e.last_fvalue = fvalue;
|
e.last_fvalue = fvalue;
|
||||||
} else {
|
} else {
|
||||||
// try to find a split
|
// try to find a split
|
||||||
if (std::abs(fvalue - e.last_fvalue) > rt_2eps &&
|
if (fvalue != e.last_fvalue &&
|
||||||
e.stats.sum_hess >= param.min_child_weight) {
|
e.stats.sum_hess >= param.min_child_weight) {
|
||||||
c.SetSubstract(snode[nid].stats, e.stats);
|
c.SetSubstract(snode[nid].stats, e.stats);
|
||||||
if (c.sum_hess >= param.min_child_weight) {
|
if (c.sum_hess >= param.min_child_weight) {
|
||||||
@ -511,7 +511,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
e.last_fvalue = fvalue;
|
e.last_fvalue = fvalue;
|
||||||
} else {
|
} else {
|
||||||
// try to find a split
|
// try to find a split
|
||||||
if (std::abs(fvalue - e.last_fvalue) > rt_2eps &&
|
if (fvalue != e.last_fvalue &&
|
||||||
e.stats.sum_hess >= param.min_child_weight) {
|
e.stats.sum_hess >= param.min_child_weight) {
|
||||||
c.SetSubstract(snode[nid].stats, e.stats);
|
c.SetSubstract(snode[nid].stats, e.stats);
|
||||||
if (c.sum_hess >= param.min_child_weight) {
|
if (c.sum_hess >= param.min_child_weight) {
|
||||||
@ -620,7 +620,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
// set the positions in the nondefault
|
// set the positions in the nondefault
|
||||||
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
|
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
|
||||||
// set rest of instances to default position
|
// set rest of instances to default position
|
||||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||||
// set default direct nodes to default
|
// set default direct nodes to default
|
||||||
// for leaf nodes that are not fresh, mark then to ~nid,
|
// for leaf nodes that are not fresh, mark then to ~nid,
|
||||||
// so that they are ignored in future statistics collection
|
// so that they are ignored in future statistics collection
|
||||||
@ -761,7 +761,7 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
: ColMaker<TStats>::Builder(param) {
|
: ColMaker<TStats>::Builder(param) {
|
||||||
}
|
}
|
||||||
inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
|
inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
|
||||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
@ -831,7 +831,7 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
bitmap.InitFromBool(boolmap);
|
bitmap.InitFromBool(boolmap);
|
||||||
// communicate bitmap
|
// communicate bitmap
|
||||||
rabit::Allreduce<rabit::op::BitOR>(dmlc::BeginPtr(bitmap.data), bitmap.data.size());
|
rabit::Allreduce<rabit::op::BitOR>(dmlc::BeginPtr(bitmap.data), bitmap.data.size());
|
||||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||||
// get the new position
|
// get the new position
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
|
|||||||
@ -127,6 +127,11 @@ class HistMaker: public BaseMaker {
|
|||||||
RegTree *p_tree) {
|
RegTree *p_tree) {
|
||||||
this->InitData(gpair, *p_fmat, *p_tree);
|
this->InitData(gpair, *p_fmat, *p_tree);
|
||||||
this->InitWorkSet(p_fmat, *p_tree, &fwork_set);
|
this->InitWorkSet(p_fmat, *p_tree, &fwork_set);
|
||||||
|
// mark root node as fresh.
|
||||||
|
for (int i = 0; i < p_tree->param.num_roots; ++i) {
|
||||||
|
(*p_tree)[i].set_leaf(0.0f, 0);
|
||||||
|
}
|
||||||
|
|
||||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||||
// reset and propose candidate split
|
// reset and propose candidate split
|
||||||
this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree);
|
this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree);
|
||||||
@ -263,6 +268,10 @@ class HistMaker: public BaseMaker {
|
|||||||
|
|
||||||
template<typename TStats>
|
template<typename TStats>
|
||||||
class CQHistMaker: public HistMaker<TStats> {
|
class CQHistMaker: public HistMaker<TStats> {
|
||||||
|
public:
|
||||||
|
CQHistMaker() : cache_dmatrix_(nullptr) {
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
struct HistEntry {
|
struct HistEntry {
|
||||||
typename HistMaker<TStats>::HistUnit hist;
|
typename HistMaker<TStats>::HistUnit hist;
|
||||||
@ -285,9 +294,13 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
*/
|
*/
|
||||||
inline void Add(bst_float fv,
|
inline void Add(bst_float fv,
|
||||||
bst_gpair gstats) {
|
bst_gpair gstats) {
|
||||||
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
|
if (fv < hist.cut[istart]) {
|
||||||
CHECK_NE(istart, hist.size);
|
hist.data[istart].Add(gstats);
|
||||||
hist.data[istart].Add(gstats);
|
} else {
|
||||||
|
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
|
||||||
|
CHECK_NE(istart, hist.size);
|
||||||
|
hist.data[istart].Add(gstats);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// sketch type used for this
|
// sketch type used for this
|
||||||
@ -296,7 +309,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
void InitWorkSet(DMatrix *p_fmat,
|
void InitWorkSet(DMatrix *p_fmat,
|
||||||
const RegTree &tree,
|
const RegTree &tree,
|
||||||
std::vector<bst_uint> *p_fset) override {
|
std::vector<bst_uint> *p_fset) override {
|
||||||
feat_helper.InitByCol(p_fmat, tree);
|
if (p_fmat != cache_dmatrix_) {
|
||||||
|
feat_helper.InitByCol(p_fmat, tree);
|
||||||
|
cache_dmatrix_ = p_fmat;
|
||||||
|
}
|
||||||
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
||||||
}
|
}
|
||||||
// code to create histogram
|
// code to create histogram
|
||||||
@ -337,6 +353,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// update node statistics.
|
||||||
|
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||||
|
&thread_stats, &node_stats);
|
||||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
const int nid = this->qexpand[i];
|
const int nid = this->qexpand[i];
|
||||||
const int wid = this->node2workindex[nid];
|
const int wid = this->node2workindex[nid];
|
||||||
@ -355,8 +374,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||||
const RegTree &tree) override {
|
const RegTree &tree) override {
|
||||||
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
|
||||||
}
|
}
|
||||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
@ -366,18 +385,18 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
// fill in reverse map
|
// fill in reverse map
|
||||||
feat2workindex.resize(tree.param.num_feature);
|
feat2workindex.resize(tree.param.num_feature);
|
||||||
std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
|
std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
|
||||||
freal_set.clear();
|
work_set.clear();
|
||||||
for (size_t i = 0; i < fset.size(); ++i) {
|
for (size_t i = 0; i < fset.size(); ++i) {
|
||||||
if (feat_helper.Type(fset[i]) == 2) {
|
if (feat_helper.Type(fset[i]) == 2) {
|
||||||
feat2workindex[fset[i]] = static_cast<int>(freal_set.size());
|
feat2workindex[fset[i]] = static_cast<int>(work_set.size());
|
||||||
freal_set.push_back(fset[i]);
|
work_set.push_back(fset[i]);
|
||||||
} else {
|
} else {
|
||||||
feat2workindex[fset[i]] = -2;
|
feat2workindex[fset[i]] = -2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
const size_t work_set_size = work_set.size();
|
||||||
&thread_stats, &node_stats);
|
|
||||||
sketchs.resize(this->qexpand.size() * freal_set.size());
|
sketchs.resize(this->qexpand.size() * work_set_size);
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
}
|
}
|
||||||
@ -388,20 +407,24 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
summary_array[i].Reserve(max_size);
|
summary_array[i].Reserve(max_size);
|
||||||
}
|
}
|
||||||
// if it is C++11, use lazy evaluation for Allreduce
|
{
|
||||||
#if __cplusplus >= 201103L
|
|
||||||
auto lazy_get_summary = [&]()
|
|
||||||
#endif
|
|
||||||
{
|
|
||||||
// get smmary
|
// get smmary
|
||||||
thread_sketch.resize(this->get_nthread());
|
thread_sketch.resize(this->get_nthread());
|
||||||
// number of rows in
|
|
||||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
// TWOPASS: use the real set + split set in the column iteration.
|
||||||
|
this->SetDefaultPostion(p_fmat, tree);
|
||||||
|
work_set.insert(work_set.end(), fsplit_set.begin(), fsplit_set.end());
|
||||||
|
std::sort(work_set.begin(), work_set.end());
|
||||||
|
work_set.resize(std::unique(work_set.begin(), work_set.end()) - work_set.begin());
|
||||||
|
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(work_set);
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch &batch = iter->Value();
|
||||||
|
// TWOPASS: use the real set + split set in the column iteration.
|
||||||
|
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set, tree);
|
||||||
|
|
||||||
// start enumeration
|
// start enumeration
|
||||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
@ -409,9 +432,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
int offset = feat2workindex[batch.col_index[i]];
|
int offset = feat2workindex[batch.col_index[i]];
|
||||||
if (offset >= 0) {
|
if (offset >= 0) {
|
||||||
this->UpdateSketchCol(gpair, batch[i], tree,
|
this->UpdateSketchCol(gpair, batch[i], tree,
|
||||||
node_stats,
|
work_set_size, offset,
|
||||||
freal_set, offset,
|
|
||||||
batch[i].length == nrows,
|
|
||||||
&thread_sketch[omp_get_thread_num()]);
|
&thread_sketch[omp_get_thread_num()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -422,15 +443,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
summary_array[i].SetPrune(out, max_size);
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
CHECK_EQ(summary_array.size(), sketchs.size());
|
CHECK_EQ(summary_array.size(), sketchs.size());
|
||||||
};
|
}
|
||||||
if (summary_array.size() != 0) {
|
if (summary_array.size() != 0) {
|
||||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
#if __cplusplus >= 201103L
|
|
||||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array),
|
|
||||||
nbytes, summary_array.size(), lazy_get_summary);
|
|
||||||
#else
|
|
||||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
@ -440,7 +456,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
for (size_t i = 0; i < fset.size(); ++i) {
|
for (size_t i = 0; i < fset.size(); ++i) {
|
||||||
int offset = feat2workindex[fset[i]];
|
int offset = feat2workindex[fset[i]];
|
||||||
if (offset >= 0) {
|
if (offset >= 0) {
|
||||||
const WXQSketch::Summary &a = summary_array[wid * freal_set.size() + offset];
|
const WXQSketch::Summary &a = summary_array[wid * work_set_size + offset];
|
||||||
for (size_t i = 1; i < a.size; ++i) {
|
for (size_t i = 1; i < a.size; ++i) {
|
||||||
bst_float cpt = a.data[i].value - rt_eps;
|
bst_float cpt = a.data[i].value - rt_eps;
|
||||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||||
@ -470,7 +486,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
(fset.size() + 1) * this->qexpand.size() + 1);
|
(fset.size() + 1) * this->qexpand.size() + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
|
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
|
||||||
const ColBatch::Inst &c,
|
const ColBatch::Inst &c,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
@ -526,10 +541,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||||
const ColBatch::Inst &c,
|
const ColBatch::Inst &c,
|
||||||
const RegTree &tree,
|
const RegTree &tree,
|
||||||
const std::vector<TStats> &nstats,
|
size_t work_set_size,
|
||||||
const std::vector<bst_uint> &frealset,
|
|
||||||
bst_uint offset,
|
bst_uint offset,
|
||||||
bool col_full,
|
|
||||||
std::vector<BaseMaker::SketchEntry> *p_temp) {
|
std::vector<BaseMaker::SketchEntry> *p_temp) {
|
||||||
if (c.length == 0) return;
|
if (c.length == 0) return;
|
||||||
// initialize sbuilder for use
|
// initialize sbuilder for use
|
||||||
@ -539,22 +552,15 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
const unsigned nid = this->qexpand[i];
|
const unsigned nid = this->qexpand[i];
|
||||||
const unsigned wid = this->node2workindex[nid];
|
const unsigned wid = this->node2workindex[nid];
|
||||||
sbuilder[nid].sum_total = 0.0f;
|
sbuilder[nid].sum_total = 0.0f;
|
||||||
sbuilder[nid].sketch = &sketchs[wid * frealset.size() + offset];
|
sbuilder[nid].sketch = &sketchs[wid * work_set_size + offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!col_full) {
|
// first pass, get sum of weight, TODO, optimization to skip first pass
|
||||||
// first pass, get sum of weight, TODO, optimization to skip first pass
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
for (bst_uint j = 0; j < c.length; ++j) {
|
|
||||||
const bst_uint ridx = c[j].index;
|
const bst_uint ridx = c[j].index;
|
||||||
const int nid = this->position[ridx];
|
const int nid = this->position[ridx];
|
||||||
if (nid >= 0) {
|
if (nid >= 0) {
|
||||||
sbuilder[nid].sum_total += gpair[ridx].hess;
|
sbuilder[nid].sum_total += gpair[ridx].hess;
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
|
||||||
const unsigned nid = this->qexpand[i];
|
|
||||||
sbuilder[nid].sum_total = static_cast<bst_float>(nstats[nid].sum_hess);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if only one value, no need to do second pass
|
// if only one value, no need to do second pass
|
||||||
@ -611,12 +617,16 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
sbuilder[nid].Finalize(max_size);
|
sbuilder[nid].Finalize(max_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// cached dmatrix where we initialized the feature on.
|
||||||
|
const DMatrix* cache_dmatrix_;
|
||||||
// feature helper
|
// feature helper
|
||||||
BaseMaker::FMetaHelper feat_helper;
|
BaseMaker::FMetaHelper feat_helper;
|
||||||
// temp space to map feature id to working index
|
// temp space to map feature id to working index
|
||||||
std::vector<int> feat2workindex;
|
std::vector<int> feat2workindex;
|
||||||
// set of index from fset that are real
|
// set of index from fset that are current work set
|
||||||
std::vector<bst_uint> freal_set;
|
std::vector<bst_uint> work_set;
|
||||||
|
// set of index from that are split candidates.
|
||||||
|
std::vector<bst_uint> fsplit_set;
|
||||||
// thread temp data
|
// thread temp data
|
||||||
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||||
// used to hold statistics
|
// used to hold statistics
|
||||||
@ -633,6 +643,108 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// global proposal
|
||||||
|
template<typename TStats>
|
||||||
|
class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||||
|
protected:
|
||||||
|
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
const std::vector<bst_uint> &fset,
|
||||||
|
const RegTree &tree) override {
|
||||||
|
if (this->qexpand.size() == 1) {
|
||||||
|
cached_rptr_.clear();
|
||||||
|
cached_cut_.clear();
|
||||||
|
}
|
||||||
|
if (cached_rptr_.size() == 0) {
|
||||||
|
CHECK_EQ(this->qexpand.size(), 1);
|
||||||
|
CQHistMaker<TStats>::ResetPosAndPropose(gpair, p_fmat, fset, tree);
|
||||||
|
cached_rptr_ = this->wspace.rptr;
|
||||||
|
cached_cut_ = this->wspace.cut;
|
||||||
|
} else {
|
||||||
|
this->wspace.cut.clear();
|
||||||
|
this->wspace.rptr.clear();
|
||||||
|
this->wspace.rptr.push_back(0);
|
||||||
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
|
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
|
||||||
|
this->wspace.rptr.push_back(
|
||||||
|
this->wspace.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
|
||||||
|
}
|
||||||
|
this->wspace.cut.insert(this->wspace.cut.end(), cached_cut_.begin(), cached_cut_.end());
|
||||||
|
}
|
||||||
|
CHECK_EQ(this->wspace.rptr.size(),
|
||||||
|
(fset.size() + 1) * this->qexpand.size() + 1);
|
||||||
|
CHECK_EQ(this->wspace.rptr.back(), this->wspace.cut.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// code to create histogram
|
||||||
|
void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
const std::vector<bst_uint> &fset,
|
||||||
|
const RegTree &tree) override {
|
||||||
|
const MetaInfo &info = p_fmat->info();
|
||||||
|
// fill in reverse map
|
||||||
|
this->feat2workindex.resize(tree.param.num_feature);
|
||||||
|
this->work_set = fset;
|
||||||
|
std::fill(this->feat2workindex.begin(), this->feat2workindex.end(), -1);
|
||||||
|
for (size_t i = 0; i < fset.size(); ++i) {
|
||||||
|
this->feat2workindex[fset[i]] = static_cast<int>(i);
|
||||||
|
}
|
||||||
|
// start to work
|
||||||
|
this->wspace.Init(this->param, 1);
|
||||||
|
// to gain speedup in recovery
|
||||||
|
{
|
||||||
|
this->thread_hist.resize(this->get_nthread());
|
||||||
|
|
||||||
|
// TWOPASS: use the real set + split set in the column iteration.
|
||||||
|
this->SetDefaultPostion(p_fmat, tree);
|
||||||
|
this->work_set.insert(this->work_set.end(), this->fsplit_set.begin(), this->fsplit_set.end());
|
||||||
|
std::sort(this->work_set.begin(), this->work_set.end());
|
||||||
|
this->work_set.resize(
|
||||||
|
std::unique(this->work_set.begin(), this->work_set.end()) - this->work_set.begin());
|
||||||
|
|
||||||
|
// start accumulating statistics
|
||||||
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(this->work_set);
|
||||||
|
iter->BeforeFirst();
|
||||||
|
while (iter->Next()) {
|
||||||
|
const ColBatch &batch = iter->Value();
|
||||||
|
// TWOPASS: use the real set + split set in the column iteration.
|
||||||
|
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set, tree);
|
||||||
|
|
||||||
|
// start enumeration
|
||||||
|
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||||
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
|
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||||
|
int offset = this->feat2workindex[batch.col_index[i]];
|
||||||
|
if (offset >= 0) {
|
||||||
|
this->UpdateHistCol(gpair, batch[i], info, tree,
|
||||||
|
fset, offset,
|
||||||
|
&this->thread_hist[omp_get_thread_num()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// update node statistics.
|
||||||
|
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||||
|
&(this->thread_stats), &(this->node_stats));
|
||||||
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
|
const int nid = this->qexpand[i];
|
||||||
|
const int wid = this->node2workindex[nid];
|
||||||
|
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
|
||||||
|
.data[0] = this->node_stats[nid];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
|
||||||
|
this->wspace.hset[0].data.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// cached unit pointer
|
||||||
|
std::vector<unsigned> cached_rptr_;
|
||||||
|
// cached cut value.
|
||||||
|
std::vector<bst_float> cached_cut_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template<typename TStats>
|
template<typename TStats>
|
||||||
class QuantileHistMaker: public HistMaker<TStats> {
|
class QuantileHistMaker: public HistMaker<TStats> {
|
||||||
protected:
|
protected:
|
||||||
@ -759,10 +871,22 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||||
.describe("Tree constructor that uses approximate histogram construction.")
|
.describe("Tree constructor that uses approximate histogram construction.")
|
||||||
.set_body([]() {
|
.set_body([]() {
|
||||||
return new CQHistMaker<GradStats>();
|
return new CQHistMaker<GradStats>();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker")
|
||||||
|
.describe("Tree constructor that uses approximate global proposal of histogram construction.")
|
||||||
|
.set_body([]() {
|
||||||
|
return new GlobalProposalHistMaker<GradStats>();
|
||||||
|
});
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
||||||
|
.describe("Tree constructor that uses approximate global of histogram construction.")
|
||||||
|
.set_body([]() {
|
||||||
|
return new GlobalProposalHistMaker<GradStats>();
|
||||||
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user