need more check
This commit is contained in:
parent
a514340c96
commit
22abf4e295
@ -15,6 +15,18 @@ DataMatrix* LoadDataMatrix(const char *fname,
|
|||||||
bool savebuffer,
|
bool savebuffer,
|
||||||
bool loadsplit,
|
bool loadsplit,
|
||||||
const char *cache_file) {
|
const char *cache_file) {
|
||||||
|
std::string fname_ = fname;
|
||||||
|
const char *dlm = strchr(fname, '#');
|
||||||
|
if (dlm != NULL) {
|
||||||
|
utils::Check(strchr(dlm + 1, '#') == NULL,
|
||||||
|
"only one `#` is allowed in file path for cachefile specification");
|
||||||
|
utils::Check(cache_file == NULL,
|
||||||
|
"can only specify the cachefile with `#` or argument, not both");
|
||||||
|
fname_ = std::string(fname, dlm - fname);
|
||||||
|
fname = fname_.c_str();
|
||||||
|
cache_file = dlm +1;
|
||||||
|
}
|
||||||
|
|
||||||
if (cache_file == NULL) {
|
if (cache_file == NULL) {
|
||||||
if (!std::strcmp(fname, "stdin") ||
|
if (!std::strcmp(fname, "stdin") ||
|
||||||
!std::strncmp(fname, "s3://", 5) ||
|
!std::strncmp(fname, "s3://", 5) ||
|
||||||
@ -39,16 +51,18 @@ DataMatrix* LoadDataMatrix(const char *fname,
|
|||||||
dmat->CacheLoad(fname, silent, savebuffer);
|
dmat->CacheLoad(fname, silent, savebuffer);
|
||||||
return dmat;
|
return dmat;
|
||||||
} else {
|
} else {
|
||||||
if (!strcmp(fname, cache_file)) {
|
FILE *fi = fopen64(cache_file, "rb");
|
||||||
|
if (fi != NULL) {
|
||||||
DMatrixPage *dmat = new DMatrixPage();
|
DMatrixPage *dmat = new DMatrixPage();
|
||||||
utils::FileStream fs(utils::FopenCheck(fname, "rb"));
|
utils::FileStream fs(fi);
|
||||||
dmat->LoadBinary(fs, silent, fname);
|
dmat->LoadBinary(fs, silent, cache_file);
|
||||||
fs.Close();
|
fs.Close();
|
||||||
return dmat;
|
return dmat;
|
||||||
|
} else {
|
||||||
|
DMatrixPage *dmat = new DMatrixPage();
|
||||||
|
dmat->LoadText(fname, cache_file, false, loadsplit);
|
||||||
|
return dmat;
|
||||||
}
|
}
|
||||||
DMatrixPage *dmat = new DMatrixPage();
|
|
||||||
dmat->LoadText(fname, cache_file, false, loadsplit);
|
|
||||||
return dmat;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -70,33 +70,6 @@ class DMatrixPageBase : public DataMatrix {
|
|||||||
// do not delete row iterator, since it is owned by fmat
|
// do not delete row iterator, since it is owned by fmat
|
||||||
// to be cleaned up in a more clear way
|
// to be cleaned up in a more clear way
|
||||||
}
|
}
|
||||||
/*! \brief load and initialize the iterator with fi */
|
|
||||||
inline void LoadBinary(utils::FileStream &fi,
|
|
||||||
bool silent,
|
|
||||||
const char *fname_) {
|
|
||||||
std::string fname = fname_;
|
|
||||||
int tmagic;
|
|
||||||
utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
|
|
||||||
utils::Check(tmagic == magic, "invalid format,magic number mismatch");
|
|
||||||
this->info.LoadBinary(fi);
|
|
||||||
// load in the row data file
|
|
||||||
fname += ".row.blob";
|
|
||||||
utils::FileStream fs(utils::FopenCheck(fname.c_str(), "rb"));
|
|
||||||
iter_->Load(fs);
|
|
||||||
if (!silent) {
|
|
||||||
utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
|
|
||||||
static_cast<unsigned long>(info.num_row()),
|
|
||||||
static_cast<unsigned long>(info.num_col()));
|
|
||||||
if (fname_ != NULL) {
|
|
||||||
utils::Printf(" from %s\n", fname_);
|
|
||||||
} else {
|
|
||||||
utils::Printf("\n");
|
|
||||||
}
|
|
||||||
if (info.group_ptr.size() != 0) {
|
|
||||||
utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size() - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*! \brief save a DataMatrix as DMatrixPage */
|
/*! \brief save a DataMatrix as DMatrixPage */
|
||||||
inline static void Save(const char *fname_, const DataMatrix &mat, bool silent) {
|
inline static void Save(const char *fname_, const DataMatrix &mat, bool silent) {
|
||||||
std::string fname = fname_;
|
std::string fname = fname_;
|
||||||
@ -127,16 +100,46 @@ class DMatrixPageBase : public DataMatrix {
|
|||||||
static_cast<unsigned long>(mat.info.num_col()), fname_);
|
static_cast<unsigned long>(mat.info.num_col()), fname_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*! \brief load and initialize the iterator with fi */
|
||||||
|
inline void LoadBinary(utils::FileStream &fi,
|
||||||
|
bool silent,
|
||||||
|
const char *fname_) {
|
||||||
|
this->set_cache_file(fname_);
|
||||||
|
std::string fname = fname_;
|
||||||
|
int tmagic;
|
||||||
|
utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
|
||||||
|
utils::Check(tmagic == magic, "invalid format,magic number mismatch");
|
||||||
|
this->info.LoadBinary(fi);
|
||||||
|
// load in the row data file
|
||||||
|
fname += ".row.blob";
|
||||||
|
utils::FileStream fs(utils::FopenCheck(fname.c_str(), "rb"));
|
||||||
|
iter_->Load(fs);
|
||||||
|
if (!silent) {
|
||||||
|
utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
|
||||||
|
static_cast<unsigned long>(info.num_row()),
|
||||||
|
static_cast<unsigned long>(info.num_col()));
|
||||||
|
if (fname_ != NULL) {
|
||||||
|
utils::Printf(" from %s\n", fname_);
|
||||||
|
} else {
|
||||||
|
utils::Printf("\n");
|
||||||
|
}
|
||||||
|
if (info.group_ptr.size() != 0) {
|
||||||
|
utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
/*! \brief save a LibSVM format file as DMatrixPage */
|
/*! \brief save a LibSVM format file as DMatrixPage */
|
||||||
inline void LoadText(const char *uri,
|
inline void LoadText(const char *uri,
|
||||||
const char* cache_file,
|
const char* cache_file,
|
||||||
bool silent,
|
bool silent,
|
||||||
bool loadsplit) {
|
bool loadsplit) {
|
||||||
|
|
||||||
int rank = 0, npart = 1;
|
int rank = 0, npart = 1;
|
||||||
if (loadsplit) {
|
if (loadsplit) {
|
||||||
rank = rabit::GetRank();
|
rank = rabit::GetRank();
|
||||||
npart = rabit::GetWorldSize();
|
npart = rabit::GetWorldSize();
|
||||||
}
|
}
|
||||||
|
this->set_cache_file(cache_file);
|
||||||
std::string fname_row = std::string(cache_file) + ".row.blob";
|
std::string fname_row = std::string(cache_file) + ".row.blob";
|
||||||
utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
|
utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
|
||||||
SparsePage page;
|
SparsePage page;
|
||||||
@ -190,8 +193,10 @@ class DMatrixPageBase : public DataMatrix {
|
|||||||
/*! \brief magic number used to identify DMatrix */
|
/*! \brief magic number used to identify DMatrix */
|
||||||
static const int kMagic = TKMagic;
|
static const int kMagic = TKMagic;
|
||||||
/*! \brief page size 64 MB */
|
/*! \brief page size 64 MB */
|
||||||
static const size_t kPageSize = 64 << 18;
|
static const size_t kPageSize = 64UL << 20UL;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
virtual void set_cache_file(const std::string &cache_file) = 0;
|
||||||
/*! \brief row iterator */
|
/*! \brief row iterator */
|
||||||
ThreadRowPageIterator *iter_;
|
ThreadRowPageIterator *iter_;
|
||||||
};
|
};
|
||||||
@ -199,7 +204,7 @@ class DMatrixPageBase : public DataMatrix {
|
|||||||
class DMatrixPage : public DMatrixPageBase<0xffffab02> {
|
class DMatrixPage : public DMatrixPageBase<0xffffab02> {
|
||||||
public:
|
public:
|
||||||
DMatrixPage(void) {
|
DMatrixPage(void) {
|
||||||
fmat_ = new FMatrixS(iter_);
|
fmat_ = new FMatrixPage(iter_, this->info);
|
||||||
}
|
}
|
||||||
virtual ~DMatrixPage(void) {
|
virtual ~DMatrixPage(void) {
|
||||||
delete fmat_;
|
delete fmat_;
|
||||||
@ -207,8 +212,11 @@ class DMatrixPage : public DMatrixPageBase<0xffffab02> {
|
|||||||
virtual IFMatrix *fmat(void) const {
|
virtual IFMatrix *fmat(void) const {
|
||||||
return fmat_;
|
return fmat_;
|
||||||
}
|
}
|
||||||
|
virtual void set_cache_file(const std::string &cache_file) {
|
||||||
|
fmat_->set_cache_file(cache_file);
|
||||||
|
}
|
||||||
/*! \brief the real fmatrix */
|
/*! \brief the real fmatrix */
|
||||||
IFMatrix *fmat_;
|
FMatrixPage *fmat_;
|
||||||
};
|
};
|
||||||
} // namespace io
|
} // namespace io
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -37,12 +37,12 @@ class ThreadColPageIterator: public utils::IIterator<ColBatch> {
|
|||||||
}
|
}
|
||||||
/*! \brief load and initialize the iterator with fi */
|
/*! \brief load and initialize the iterator with fi */
|
||||||
inline void SetFile(const utils::FileStream &fi) {
|
inline void SetFile(const utils::FileStream &fi) {
|
||||||
itr.get_factory().SetFile(fi, 0);
|
itr.get_factory().SetFile(fi);
|
||||||
itr.Init();
|
itr.Init();
|
||||||
}
|
}
|
||||||
// set index set
|
// set index set
|
||||||
inline void SetIndexSet(const std::vector<bst_uint> &fset) {
|
inline void SetIndexSet(const std::vector<bst_uint> &fset, bool load_all) {
|
||||||
itr.get_factory().SetIndexSet(fset);
|
itr.get_factory().SetIndexSet(fset, load_all);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -55,25 +55,26 @@ class ThreadColPageIterator: public utils::IIterator<ColBatch> {
|
|||||||
/*!
|
/*!
|
||||||
* \brief sparse matrix that supports column access, CSC
|
* \brief sparse matrix that supports column access, CSC
|
||||||
*/
|
*/
|
||||||
class FMatrixS : public IFMatrix {
|
class FMatrixPage : public IFMatrix {
|
||||||
public:
|
public:
|
||||||
typedef SparseBatch::Entry Entry;
|
typedef SparseBatch::Entry Entry;
|
||||||
/*! \brief constructor */
|
/*! \brief constructor */
|
||||||
FMatrixS(utils::IIterator<RowBatch> *iter) {
|
FMatrixPage(utils::IIterator<RowBatch> *iter,
|
||||||
|
const learner::MetaInfo &info) : info(info) {
|
||||||
this->iter_ = iter;
|
this->iter_ = iter;
|
||||||
}
|
}
|
||||||
// destructor
|
// destructor
|
||||||
virtual ~FMatrixS(void) {
|
virtual ~FMatrixPage(void) {
|
||||||
if (iter_ != NULL) delete iter_;
|
if (iter_ != NULL) delete iter_;
|
||||||
}
|
}
|
||||||
/*! \return whether column access is enabled */
|
/*! \return whether column access is enabled */
|
||||||
virtual bool HaveColAccess(void) const {
|
virtual bool HaveColAccess(void) const {
|
||||||
return col_ptr_.size() != 0;
|
return col_size_.size() != 0;
|
||||||
}
|
}
|
||||||
/*! \brief get number of columns */
|
/*! \brief get number of columns */
|
||||||
virtual size_t NumCol(void) const {
|
virtual size_t NumCol(void) const {
|
||||||
utils::Check(this->HaveColAccess(), "NumCol:need column access");
|
utils::Check(this->HaveColAccess(), "NumCol:need column access");
|
||||||
return col_ptr_.size() - 1;
|
return col_size_.size();
|
||||||
}
|
}
|
||||||
/*! \brief get number of buffered rows */
|
/*! \brief get number of buffered rows */
|
||||||
virtual const std::vector<bst_uint> &buffered_rowset(void) const {
|
virtual const std::vector<bst_uint> &buffered_rowset(void) const {
|
||||||
@ -81,17 +82,19 @@ class FMatrixS : public IFMatrix {
|
|||||||
}
|
}
|
||||||
/*! \brief get column size */
|
/*! \brief get column size */
|
||||||
virtual size_t GetColSize(size_t cidx) const {
|
virtual size_t GetColSize(size_t cidx) const {
|
||||||
return col_ptr_[cidx+1] - col_ptr_[cidx];
|
return col_size_[cidx];
|
||||||
}
|
}
|
||||||
/*! \brief get column density */
|
/*! \brief get column density */
|
||||||
virtual float GetColDensity(size_t cidx) const {
|
virtual float GetColDensity(size_t cidx) const {
|
||||||
size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
|
size_t nmiss = num_buffered_row_ - (col_size_[cidx]);
|
||||||
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
|
return 1.0f - (static_cast<float>(nmiss)) / num_buffered_row_;
|
||||||
}
|
}
|
||||||
virtual void InitColAccess(const std::vector<bool> &enabled,
|
virtual void InitColAccess(const std::vector<bool> &enabled,
|
||||||
float pkeep = 1.0f) {
|
float pkeep = 1.0f) {
|
||||||
if (this->HaveColAccess()) return;
|
if (this->HaveColAccess()) return;
|
||||||
this->InitColData(pkeep, enabled);
|
if (TryLoadColData()) return;
|
||||||
|
this->InitColData(enabled, pkeep);
|
||||||
|
utils::Check(TryLoadColData(), "failed on creating col.blob");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief get the row iterator associated with FMatrix
|
* \brief get the row iterator associated with FMatrix
|
||||||
@ -105,11 +108,12 @@ class FMatrixS : public IFMatrix {
|
|||||||
*/
|
*/
|
||||||
virtual utils::IIterator<ColBatch>* ColIterator(void) {
|
virtual utils::IIterator<ColBatch>* ColIterator(void) {
|
||||||
size_t ncol = this->NumCol();
|
size_t ncol = this->NumCol();
|
||||||
col_iter_.col_index_.resize(ncol);
|
col_index_.resize(0);
|
||||||
for (size_t i = 0; i < ncol; ++i) {
|
for (size_t i = 0; i < ncol; ++i) {
|
||||||
col_iter_.col_index_[i] = static_cast<bst_uint>(i);
|
col_index_.push_back(i);
|
||||||
}
|
}
|
||||||
col_iter_.SetBatch(col_ptr_, col_data_);
|
col_iter_.SetIndexSet(col_index_, false);
|
||||||
|
col_iter_.BeforeFirst();
|
||||||
return &col_iter_;
|
return &col_iter_;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -117,13 +121,158 @@ class FMatrixS : public IFMatrix {
|
|||||||
*/
|
*/
|
||||||
virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) {
|
virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) {
|
||||||
size_t ncol = this->NumCol();
|
size_t ncol = this->NumCol();
|
||||||
col_iter_.col_index_.resize(0);
|
col_index_.resize(0);
|
||||||
for (size_t i = 0; i < fset.size(); ++i) {
|
for (size_t i = 0; i < fset.size(); ++i) {
|
||||||
if (fset[i] < ncol) col_iter_.col_index_.push_back(fset[i]);
|
if (fset[i] < ncol) col_index_.push_back(fset[i]);
|
||||||
}
|
}
|
||||||
col_iter_.SetBatch(col_ptr_, col_data_);
|
col_iter_.SetIndexSet(col_index_, false);
|
||||||
|
col_iter_.BeforeFirst();
|
||||||
return &col_iter_;
|
return &col_iter_;
|
||||||
}
|
}
|
||||||
|
// set the cache file name
|
||||||
|
inline void set_cache_file(const std::string &cache_file) {
|
||||||
|
col_data_name_ = std::string(cache_file) + ".col.blob";
|
||||||
|
col_meta_name_ = std::string(cache_file) + ".col.meta";
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/*!
 * \brief try to attach to an existing column cache on disk
 *
 *  The cache consists of the meta file (col_meta_name_) and the data file
 *  (col_data_name_). Both are opened before any state is modified, so a
 *  missing file leaves the object untouched.
 * \return true when the meta data was loaded and the column iterator was
 *  bound to the data file; false when either file cannot be opened
 */
inline bool TryLoadColData(void) {
  FILE *fmeta = fopen64(col_meta_name_.c_str(), "rb");
  if (fmeta == NULL) return false;
  // open the data file with fopen64 as well: the original used
  // utils::FopenCheck followed by a NULL test, which presumably can never
  // observe NULL because FopenCheck reports the failure itself
  FILE *fdata = fopen64(col_data_name_.c_str(), "rb");
  if (fdata == NULL) {
    fclose(fmeta);
    return false;
  }
  utils::FileStream fs(fmeta);
  LoadMeta(&fs);
  fs.Close();
  col_iter_.SetFile(utils::FileStream(fdata));
  return true;
}
|
||||||
|
/*!
 * \brief read the column meta data written by SaveMeta
 * \param fi stream to read from; the fields are read in the order
 *  num_buffered_row_, buffered_rowset_, col_size_
 */
inline void LoadMeta(utils::IStream *fi) {
  // the stream holds the col.meta file, so report that file on failure
  utils::Check(fi->Read(&num_buffered_row_, sizeof(num_buffered_row_)) != 0,
               "invalid col.meta file");
  utils::Check(fi->Read(&buffered_rowset_),
               "invalid col.meta file");
  utils::Check(fi->Read(&col_size_),
               "invalid col.meta file");
}
|
||||||
|
/*!
 * \brief write the column meta data in the order expected by LoadMeta
 * \param fo stream to write to
 */
inline void SaveMeta(utils::IStream *fo) {
  utils::IStream &out = *fo;
  out.Write(&num_buffered_row_, sizeof(num_buffered_row_));
  out.Write(buffered_rowset_);
  out.Write(col_size_);
}
|
||||||
|
/*!
|
||||||
|
* \brief intialize column data
|
||||||
|
* \param pkeep probability to keep a row
|
||||||
|
*/
|
||||||
|
inline void InitColData(const std::vector<bool> &enabled, float pkeep) {
|
||||||
|
SparsePage prow, pcol;
|
||||||
|
size_t btop = 0;
|
||||||
|
// clear rowset
|
||||||
|
buffered_rowset_.clear();
|
||||||
|
col_size_.resize(info.num_col());
|
||||||
|
std::fill(col_size_.begin(), col_size_.end(), 0);
|
||||||
|
utils::FileStream fo;
|
||||||
|
fo = utils::FileStream(utils::FopenCheck(col_data_name_.c_str(), "wb"));
|
||||||
|
// start working
|
||||||
|
iter_->BeforeFirst();
|
||||||
|
while (iter_->Next()) {
|
||||||
|
const RowBatch &batch = iter_->Value();
|
||||||
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
|
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||||
|
if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
|
||||||
|
buffered_rowset_.push_back(ridx);
|
||||||
|
prow.Push(batch[i]);
|
||||||
|
if (prow.MemCostBytes() >= kPageSize) {
|
||||||
|
this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
|
||||||
|
enabled, &pcol, &fo);
|
||||||
|
btop += prow.Size();
|
||||||
|
prow.Clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (prow.Size() != 0) {
|
||||||
|
this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
|
||||||
|
enabled, &pcol, &fo);
|
||||||
|
}
|
||||||
|
fo.Close();
|
||||||
|
num_buffered_row_ = buffered_rowset_.size();
|
||||||
|
fo = utils::FileStream(utils::FopenCheck(col_meta_name_.c_str(), "wb"));
|
||||||
|
this->SaveMeta(&fo);
|
||||||
|
fo.Close();
|
||||||
|
}
|
||||||
|
inline void PushColPage(const SparsePage &prow,
|
||||||
|
const bst_uint *ridx,
|
||||||
|
const std::vector<bool> &enabled,
|
||||||
|
SparsePage *pcol,
|
||||||
|
utils::IStream *fo) {
|
||||||
|
pcol->Clear();
|
||||||
|
int nthread;
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
nthread = omp_get_num_threads();
|
||||||
|
}
|
||||||
|
pcol->Clear();
|
||||||
|
utils::ParallelGroupBuilder<SparseBatch::Entry>
|
||||||
|
builder(&pcol->offset, &pcol->data);
|
||||||
|
builder.InitBudget(info.num_col(), nthread);
|
||||||
|
bst_omp_uint ndata = static_cast<bst_uint>(prow.Size());
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
|
int tid = omp_get_thread_num();
|
||||||
|
for (bst_uint j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
|
||||||
|
const SparseBatch::Entry &e = prow.data[j];
|
||||||
|
if (enabled[e.index]) {
|
||||||
|
builder.AddBudget(e.index, tid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.InitStorage();
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
|
int tid = omp_get_thread_num();
|
||||||
|
for (bst_uint j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
|
||||||
|
const SparseBatch::Entry &e = prow.data[j];
|
||||||
|
builder.Push(e.index,
|
||||||
|
SparseBatch::Entry(ridx[i], e.fvalue),
|
||||||
|
tid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
utils::Assert(pcol->Size() == info.num_col(), "inconsistent col data");
|
||||||
|
// sort columns
|
||||||
|
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||||
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
|
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||||
|
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||||
|
std::sort(BeginPtr(pcol->data) + pcol->offset[i],
|
||||||
|
BeginPtr(pcol->data) + pcol->offset[i + 1], Entry::CmpValue);
|
||||||
|
}
|
||||||
|
col_size_[i] += pcol->offset[i + 1] - pcol->offset[i];
|
||||||
|
}
|
||||||
|
pcol->Save(fo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief page size 256 M */
|
||||||
|
static const size_t kPageSize = 256 << 20UL;
|
||||||
|
// shared meta info with DMatrix
|
||||||
|
const learner::MetaInfo &info;
|
||||||
|
// row iterator
|
||||||
|
utils::IIterator<RowBatch> *iter_;
|
||||||
|
/*! \brief column based data file name */
|
||||||
|
std::string col_data_name_;
|
||||||
|
/*! \brief column meta data file name */
|
||||||
|
std::string col_meta_name_;
|
||||||
|
/*! \brief list of row index that are buffered */
|
||||||
|
std::vector<bst_uint> buffered_rowset_;
|
||||||
|
// number of buffered rows
|
||||||
|
size_t num_buffered_row_;
|
||||||
|
// count for column data
|
||||||
|
std::vector<size_t> col_size_;
|
||||||
|
// internal column index for output
|
||||||
|
std::vector<bst_uint> col_index_;
|
||||||
|
// internal thread backed col iterator
|
||||||
|
ThreadColPageIterator col_iter_;
|
||||||
};
|
};
|
||||||
} // namespace io
|
} // namespace io
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -24,6 +24,10 @@ class SparsePage {
|
|||||||
SparsePage() {
|
SparsePage() {
|
||||||
this->Clear();
|
this->Clear();
|
||||||
}
|
}
|
||||||
|
/*! \return number of instance in the page */
|
||||||
|
inline size_t Size() const {
|
||||||
|
return offset.size() - 1;
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load the page by providing a list of interested segments
|
* \brief load the page by providing a list of interested segments
|
||||||
* only the interested segments are loaded
|
* only the interested segments are loaded
|
||||||
@ -38,6 +42,7 @@ class SparsePage {
|
|||||||
offset.clear(); offset.push_back(0);
|
offset.clear(); offset.push_back(0);
|
||||||
for (size_t i = 0; i < sorted_index_set.size(); ++i) {
|
for (size_t i = 0; i < sorted_index_set.size(); ++i) {
|
||||||
bst_uint fid = sorted_index_set[i];
|
bst_uint fid = sorted_index_set[i];
|
||||||
|
utils::Check(fid + 1 < disk_offset_.size(), "bad col.blob format");
|
||||||
size_t size = disk_offset_[fid + 1] - disk_offset_[fid];
|
size_t size = disk_offset_[fid + 1] - disk_offset_[fid];
|
||||||
offset.push_back(offset.back() + size);
|
offset.push_back(offset.back() + size);
|
||||||
}
|
}
|
||||||
@ -49,7 +54,7 @@ class SparsePage {
|
|||||||
bst_uint fid = sorted_index_set[i];
|
bst_uint fid = sorted_index_set[i];
|
||||||
if (disk_offset_[fid] != curr_offset) {
|
if (disk_offset_[fid] != curr_offset) {
|
||||||
utils::Assert(disk_offset_[fid] > curr_offset, "fset index was not sorted");
|
utils::Assert(disk_offset_[fid] > curr_offset, "fset index was not sorted");
|
||||||
fi->Seek(begin + disk_offset_[fid]);
|
fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry));
|
||||||
curr_offset = disk_offset_[fid];
|
curr_offset = disk_offset_[fid];
|
||||||
}
|
}
|
||||||
size_t j, size_to_read = 0;
|
size_t j, size_to_read = 0;
|
||||||
@ -68,6 +73,10 @@ class SparsePage {
|
|||||||
}
|
}
|
||||||
i = j;
|
i = j;
|
||||||
}
|
}
|
||||||
|
// seek to end of record
|
||||||
|
if (curr_offset != disk_offset_.back()) {
|
||||||
|
fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry));
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -166,7 +175,8 @@ class SparsePage {
|
|||||||
*/
|
*/
|
||||||
class SparsePageFactory {
|
class SparsePageFactory {
|
||||||
public:
|
public:
|
||||||
SparsePageFactory(void) {}
|
SparsePageFactory(void)
|
||||||
|
: action_load_all_(true), set_load_all_(true) {}
|
||||||
inline void SetFile(const utils::FileStream &fi,
|
inline void SetFile(const utils::FileStream &fi,
|
||||||
size_t file_begin = 0) {
|
size_t file_begin = 0) {
|
||||||
fi_ = fi;
|
fi_ = fi;
|
||||||
@ -176,17 +186,25 @@ class SparsePageFactory {
|
|||||||
return action_index_set_;
|
return action_index_set_;
|
||||||
}
|
}
|
||||||
// set index set, will be used after next before first
|
// set index set, will be used after next before first
|
||||||
inline void SetIndexSet(const std::vector<bst_uint> &index_set) {
|
inline void SetIndexSet(const std::vector<bst_uint> &index_set,
|
||||||
set_index_set_ = index_set;
|
bool load_all) {
|
||||||
std::sort(set_index_set_.begin(), set_index_set_.end());
|
set_load_all_ = load_all;
|
||||||
|
if (!set_load_all_) {
|
||||||
|
set_index_set_ = index_set;
|
||||||
|
std::sort(set_index_set_.begin(), set_index_set_.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
inline bool Init(void) {
|
inline bool Init(void) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
inline void SetParam(const char *name, const char *val) {}
|
inline void SetParam(const char *name, const char *val) {}
|
||||||
inline bool LoadNext(SparsePage *val) {
|
inline bool LoadNext(SparsePage *val) {
|
||||||
if (action_index_set_.size() != 0) {
|
if (!action_load_all_) {
|
||||||
return val->Load(&fi_, action_index_set_);
|
if (action_index_set_.size() == 0) {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return val->Load(&fi_, action_index_set_);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return val->Load(&fi_);
|
return val->Load(&fi_);
|
||||||
}
|
}
|
||||||
@ -202,10 +220,14 @@ class SparsePageFactory {
|
|||||||
}
|
}
|
||||||
inline void BeforeFirst(void) {
|
inline void BeforeFirst(void) {
|
||||||
fi_.Seek(file_begin_);
|
fi_.Seek(file_begin_);
|
||||||
action_index_set_ = set_index_set_;
|
action_load_all_ = set_load_all_;
|
||||||
|
if (!set_load_all_) {
|
||||||
|
action_index_set_ = set_index_set_;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
bool action_load_all_, set_load_all_;
|
||||||
size_t file_begin_;
|
size_t file_begin_;
|
||||||
utils::FileStream fi_;
|
utils::FileStream fi_;
|
||||||
std::vector<bst_uint> action_index_set_;
|
std::vector<bst_uint> action_index_set_;
|
||||||
|
|||||||
@ -260,6 +260,11 @@ class BoostLearner : public rabit::Serializable {
|
|||||||
std::vector<bool> enabled(ncol, true);
|
std::vector<bool> enabled(ncol, true);
|
||||||
// initialize column access
|
// initialize column access
|
||||||
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
|
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
|
||||||
|
const int kMagicSimple = 0xffffab01;
|
||||||
|
// check, if it is not DMatrix simple, then use hist maker
|
||||||
|
if (p_train->magic != kMagicSimple) {
|
||||||
|
this->SetParam("updater", "grow_histmaker,prune");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief update the model for one iteration
|
* \brief update the model for one iteration
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user