need more check

This commit is contained in:
parent a514340c96
commit 22abf4e295
@@ -149,7 +149,7 @@ class IFMatrix {
  virtual size_t NumCol(void) const = 0;
  /*! \brief get number of non-missing entries in column */
  virtual size_t GetColSize(size_t cidx) const = 0;
  /*! \brief get column density */
  /*! \brief get column density */
  virtual float GetColDensity(size_t cidx) const = 0;
  /*! \brief reference of buffered rowset */
  virtual const std::vector<bst_uint> &buffered_rowset(void) const = 0;

@@ -15,6 +15,18 @@ DataMatrix* LoadDataMatrix(const char *fname,
                           bool savebuffer,
                           bool loadsplit,
                           const char *cache_file) {
  std::string fname_ = fname;
  const char *dlm = strchr(fname, '#');
  if (dlm != NULL) {
    utils::Check(strchr(dlm + 1, '#') == NULL,
                 "only one `#` is allowed in file path for cachefile specification");
    utils::Check(cache_file == NULL,
                 "can only specify the cachefile with `#` or argument, not both");
    fname_ = std::string(fname, dlm - fname);
    fname = fname_.c_str();
    cache_file = dlm +1;
  }

  if (cache_file == NULL) {
    if (!std::strcmp(fname, "stdin") ||
        !std::strncmp(fname, "s3://", 5) ||

@@ -39,16 +51,18 @@ DataMatrix* LoadDataMatrix(const char *fname,
    dmat->CacheLoad(fname, silent, savebuffer);
    return dmat;
  } else {
    if (!strcmp(fname, cache_file)) {
      FILE *fi = fopen64(cache_file, "rb");
      if (fi != NULL) {
        DMatrixPage *dmat = new DMatrixPage();
        utils::FileStream fs(utils::FopenCheck(fname, "rb"));
        dmat->LoadBinary(fs, silent, fname);
        utils::FileStream fs(fi);
        dmat->LoadBinary(fs, silent, cache_file);
        fs.Close();
        return dmat;
      } else {
        DMatrixPage *dmat = new DMatrixPage();
        dmat->LoadText(fname, cache_file, false, loadsplit);
        return dmat;
      }
      DMatrixPage *dmat = new DMatrixPage();
      dmat->LoadText(fname, cache_file, false, loadsplit);
      return dmat;
    }
  }

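The `#` convention introduced in the two hunks above lets a single argument name both the data file and the external-memory cache file, e.g. `data.libsvm#dtrain.cache`. Below is a minimal standalone sketch of that split, not xgboost code; the input path is a made-up example and the checks are reduced to asserts.

// Standalone sketch of the `path#cachefile` split used by LoadDataMatrix above.
// Not xgboost code; the input string is a hypothetical example.
#include <cassert>
#include <cstdio>
#include <cstring>
#include <string>

int main() {
  const char *fname = "data.libsvm#dtrain.cache";  // hypothetical input
  const char *cache_file = NULL;
  std::string fname_ = fname;
  const char *dlm = std::strchr(fname, '#');
  if (dlm != NULL) {
    // at most one '#', and the cache file must not also be given as an argument
    assert(std::strchr(dlm + 1, '#') == NULL);
    assert(cache_file == NULL);
    fname_ = std::string(fname, dlm - fname);  // "data.libsvm"
    fname = fname_.c_str();
    cache_file = dlm + 1;                      // "dtrain.cache"
  }
  std::printf("data file: %s\ncache file: %s\n",
              fname, cache_file == NULL ? "(none)" : cache_file);
  return 0;
}
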
@@ -70,33 +70,6 @@ class DMatrixPageBase : public DataMatrix {
    // do not delete row iterator, since it is owned by fmat
    // to be cleaned up in a more clear way
  }
  /*! \brief load and initialize the iterator with fi */
  inline void LoadBinary(utils::FileStream &fi,
                         bool silent,
                         const char *fname_) {
    std::string fname = fname_;
    int tmagic;
    utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
    utils::Check(tmagic == magic, "invalid format,magic number mismatch");
    this->info.LoadBinary(fi);
    // load in the row data file
    fname += ".row.blob";
    utils::FileStream fs(utils::FopenCheck(fname.c_str(), "rb"));
    iter_->Load(fs);
    if (!silent) {
      utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
                    static_cast<unsigned long>(info.num_row()),
                    static_cast<unsigned long>(info.num_col()));
      if (fname_ != NULL) {
        utils::Printf(" from %s\n", fname_);
      } else {
        utils::Printf("\n");
      }
      if (info.group_ptr.size() != 0) {
        utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size() - 1);
      }
    }
  }
  /*! \brief save a DataMatrix as DMatrixPage */
  inline static void Save(const char *fname_, const DataMatrix &mat, bool silent) {
    std::string fname = fname_;

@@ -127,18 +100,48 @@ class DMatrixPageBase : public DataMatrix {
                  static_cast<unsigned long>(mat.info.num_col()), fname_);
    }
  }
  /*! \brief load and initialize the iterator with fi */
  inline void LoadBinary(utils::FileStream &fi,
                         bool silent,
                         const char *fname_) {
    this->set_cache_file(fname_);
    std::string fname = fname_;
    int tmagic;
    utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
    utils::Check(tmagic == magic, "invalid format,magic number mismatch");
    this->info.LoadBinary(fi);
    // load in the row data file
    fname += ".row.blob";
    utils::FileStream fs(utils::FopenCheck(fname.c_str(), "rb"));
    iter_->Load(fs);
    if (!silent) {
      utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
                    static_cast<unsigned long>(info.num_row()),
                    static_cast<unsigned long>(info.num_col()));
      if (fname_ != NULL) {
        utils::Printf(" from %s\n", fname_);
      } else {
        utils::Printf("\n");
      }
      if (info.group_ptr.size() != 0) {
        utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size() - 1);
      }
    }
  }
  /*! \brief save a LibSVM format file as DMatrixPage */
  inline void LoadText(const char *uri,
                       const char* cache_file,
                       bool silent,
                       bool loadsplit) {

    int rank = 0, npart = 1;
    if (loadsplit) {
      rank = rabit::GetRank();
      npart = rabit::GetWorldSize();
    }
    this->set_cache_file(cache_file);
    std::string fname_row = std::string(cache_file) + ".row.blob";
    utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
    utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
    SparsePage page;
    dmlc::InputSplit *in =
        dmlc::InputSplit::Create(uri, rank, npart);

@@ -190,8 +193,10 @@ class DMatrixPageBase : public DataMatrix {
  /*! \brief magic number used to identify DMatrix */
  static const int kMagic = TKMagic;
  /*! \brief page size 64 MB */
  static const size_t kPageSize = 64 << 18;
  static const size_t kPageSize = 64UL << 20UL;

 protected:
  virtual void set_cache_file(const std::string &cache_file) = 0;
  /*! \brief row iterator */
  ThreadRowPageIterator *iter_;
};

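For reference on the constant changed just above: `64 << 18` is 64 * 2^18 = 16,777,216 bytes (16 MB), which did not match the "page size 64 MB" comment, while `64UL << 20UL` is 64 * 2^20 = 67,108,864 bytes (64 MB), which does. A standalone check, not xgboost code:

// Standalone check of the two page-size constants above (not xgboost code).
#include <cstdio>

int main() {
  std::printf("64 << 18 = %lu bytes\n", 64UL << 18);  // 16777216 (16 MB)
  std::printf("64 << 20 = %lu bytes\n", 64UL << 20);  // 67108864 (64 MB)
  return 0;
}
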
@@ -199,7 +204,7 @@ class DMatrixPageBase : public DataMatrix {
class DMatrixPage : public DMatrixPageBase<0xffffab02> {
 public:
  DMatrixPage(void) {
    fmat_ = new FMatrixS(iter_);
    fmat_ = new FMatrixPage(iter_, this->info);
  }
  virtual ~DMatrixPage(void) {
    delete fmat_;

@@ -207,8 +212,11 @@ class DMatrixPage : public DMatrixPageBase<0xffffab02> {
  virtual IFMatrix *fmat(void) const {
    return fmat_;
  }
  virtual void set_cache_file(const std::string &cache_file) {
    fmat_->set_cache_file(cache_file);
  }
  /*! \brief the real fmatrix */
  IFMatrix *fmat_;
  FMatrixPage *fmat_;
};
} // namespace io
} // namespace xgboost

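Taken together with the file names that appear above and later in this commit, a single cache prefix fans out into several files on disk: the row pages in `<cache>.row.blob` (DMatrixPageBase), and the column pages plus their metadata in `<cache>.col.blob` and `<cache>.col.meta` (FMatrixPage::set_cache_file below). A small standalone sketch of that naming; the prefix is a made-up example:

// Standalone sketch of the cache file names derived by the paged matrix code.
// Not xgboost code; the prefix is a hypothetical example.
#include <cstdio>
#include <string>

int main() {
  std::string cache_file = "dtrain.cache";           // hypothetical prefix
  std::string row_data = cache_file + ".row.blob";   // row pages (DMatrixPageBase)
  std::string col_data = cache_file + ".col.blob";   // column pages (FMatrixPage)
  std::string col_meta = cache_file + ".col.meta";   // column metadata (FMatrixPage)
  std::printf("%s\n%s\n%s\n",
              row_data.c_str(), col_data.c_str(), col_meta.c_str());
  return 0;
}
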
@@ -37,12 +37,12 @@ class ThreadColPageIterator: public utils::IIterator<ColBatch> {
  }
  /*! \brief load and initialize the iterator with fi */
  inline void SetFile(const utils::FileStream &fi) {
    itr.get_factory().SetFile(fi, 0);
    itr.get_factory().SetFile(fi);
    itr.Init();
  }
  // set index set
  inline void SetIndexSet(const std::vector<bst_uint> &fset) {
    itr.get_factory().SetIndexSet(fset);
  inline void SetIndexSet(const std::vector<bst_uint> &fset, bool load_all) {
    itr.get_factory().SetIndexSet(fset, load_all);
  }

 private:

@@ -55,25 +55,26 @@ class ThreadColPageIterator: public utils::IIterator<ColBatch> {
/*!
 * \brief sparse matrix that support column access, CSC
 */
class FMatrixS : public IFMatrix {
class FMatrixPage : public IFMatrix {
 public:
  typedef SparseBatch::Entry Entry;
  /*! \brief constructor */
  FMatrixS(utils::IIterator<RowBatch> *iter) {
  FMatrixPage(utils::IIterator<RowBatch> *iter,
              const learner::MetaInfo &info) : info(info) {
    this->iter_ = iter;
  }
  // destructor
  virtual ~FMatrixS(void) {
  virtual ~FMatrixPage(void) {
    if (iter_ != NULL) delete iter_;
  }
  /*! \return whether column access is enabled */
  virtual bool HaveColAccess(void) const {
    return col_ptr_.size() != 0;
  virtual bool HaveColAccess(void) const {
    return col_size_.size() != 0;
  }
  /*! \brief get number of colmuns */
  virtual size_t NumCol(void) const {
    utils::Check(this->HaveColAccess(), "NumCol:need column access");
    return col_ptr_.size() - 1;
    return col_size_.size();
  }
  /*! \brief get number of buffered rows */
  virtual const std::vector<bst_uint> &buffered_rowset(void) const {

@@ -81,17 +82,19 @@ class FMatrixS : public IFMatrix {
  }
  /*! \brief get column size */
  virtual size_t GetColSize(size_t cidx) const {
    return col_ptr_[cidx+1] - col_ptr_[cidx];
    return col_size_[cidx];
  }
  /*! \brief get column density */
  virtual float GetColDensity(size_t cidx) const {
    size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
    return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
    size_t nmiss = num_buffered_row_ - (col_size_[cidx]);
    return 1.0f - (static_cast<float>(nmiss)) / num_buffered_row_;
  }
  virtual void InitColAccess(const std::vector<bool> &enabled,
                             float pkeep = 1.0f) {
    if (this->HaveColAccess()) return;
    this->InitColData(pkeep, enabled);
    if (TryLoadColData()) return;
    this->InitColData(enabled, pkeep);
    utils::Check(TryLoadColData(), "failed on creating col.blob");
  }
  /*!
   * \brief get the row iterator associated with FMatrix

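The density above follows directly from the stored per-column counts: with `num_buffered_row_` buffered rows and `col_size_[cidx]` present entries, the missing count is the difference and the density is the present fraction. A standalone illustration with made-up numbers:

// Standalone illustration of the GetColDensity arithmetic above (made-up numbers).
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t num_buffered_row = 1000;  // hypothetical buffered row count
  std::size_t col_size = 250;           // hypothetical non-missing entries in the column
  std::size_t nmiss = num_buffered_row - col_size;
  float density = 1.0f - static_cast<float>(nmiss) / num_buffered_row;
  std::printf("density = %g\n", density);  // 0.25
  return 0;
}
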
@@ -105,25 +108,171 @@ class FMatrixS : public IFMatrix {
   */
  virtual utils::IIterator<ColBatch>* ColIterator(void) {
    size_t ncol = this->NumCol();
    col_iter_.col_index_.resize(ncol);
    col_index_.resize(0);
    for (size_t i = 0; i < ncol; ++i) {
      col_iter_.col_index_[i] = static_cast<bst_uint>(i);
      col_index_.push_back(i);
    }
    col_iter_.SetBatch(col_ptr_, col_data_);
    col_iter_.SetIndexSet(col_index_, false);
    col_iter_.BeforeFirst();
    return &col_iter_;
  }
  /*!
   * \brief colmun based iterator
   */
  virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) {
  virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) {
    size_t ncol = this->NumCol();
    col_iter_.col_index_.resize(0);
    col_index_.resize(0);
    for (size_t i = 0; i < fset.size(); ++i) {
      if (fset[i] < ncol) col_iter_.col_index_.push_back(fset[i]);
      if (fset[i] < ncol) col_index_.push_back(fset[i]);
    }
    col_iter_.SetBatch(col_ptr_, col_data_);
    col_iter_.SetIndexSet(col_index_, false);
    col_iter_.BeforeFirst();
    return &col_iter_;
  }
  // set the cache file name
  inline void set_cache_file(const std::string &cache_file) {
    col_data_name_ = std::string(cache_file) + ".col.blob";
    col_meta_name_ = std::string(cache_file) + ".col.meta";
  }

 protected:
  inline bool TryLoadColData(void) {
    FILE *fi = fopen64(col_meta_name_.c_str(), "rb");
    if (fi == NULL) return false;
    utils::FileStream fs(fi);
    LoadMeta(&fs);
    fs.Close();
    fi = utils::FopenCheck(col_data_name_.c_str(), "rb");
    if (fi == NULL) return false;
    col_iter_.SetFile(utils::FileStream(fi));
    return true;
  }
  inline void LoadMeta(utils::IStream *fi) {
    utils::Check(fi->Read(&num_buffered_row_, sizeof(num_buffered_row_)) != 0,
                 "invalid col.blob file");
    utils::Check(fi->Read(&buffered_rowset_),
                 "invalid col.blob file");
    utils::Check(fi->Read(&col_size_),
                 "invalid col.blob file");
  }
  inline void SaveMeta(utils::IStream *fo) {
    fo->Write(&num_buffered_row_, sizeof(num_buffered_row_));
    fo->Write(buffered_rowset_);
    fo->Write(col_size_);
  }
  /*!
   * \brief intialize column data
   * \param pkeep probability to keep a row
   */
  inline void InitColData(const std::vector<bool> &enabled, float pkeep) {
    SparsePage prow, pcol;
    size_t btop = 0;
    // clear rowset
    buffered_rowset_.clear();
    col_size_.resize(info.num_col());
    std::fill(col_size_.begin(), col_size_.end(), 0);
    utils::FileStream fo;
    fo = utils::FileStream(utils::FopenCheck(col_data_name_.c_str(), "wb"));
    // start working
    iter_->BeforeFirst();
    while (iter_->Next()) {
      const RowBatch &batch = iter_->Value();
      for (size_t i = 0; i < batch.size; ++i) {
        bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
        if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
          buffered_rowset_.push_back(ridx);
          prow.Push(batch[i]);
          if (prow.MemCostBytes() >= kPageSize) {
            this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
                              enabled, &pcol, &fo);
            btop += prow.Size();
            prow.Clear();
          }
        }
      }
    }
    if (prow.Size() != 0) {
      this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
                        enabled, &pcol, &fo);
    }
    fo.Close();
    num_buffered_row_ = buffered_rowset_.size();
    fo = utils::FileStream(utils::FopenCheck(col_meta_name_.c_str(), "wb"));
    this->SaveMeta(&fo);
    fo.Close();
  }
  inline void PushColPage(const SparsePage &prow,
                          const bst_uint *ridx,
                          const std::vector<bool> &enabled,
                          SparsePage *pcol,
                          utils::IStream *fo) {
    pcol->Clear();
    int nthread;
    #pragma omp parallel
    {
      nthread = omp_get_num_threads();
    }
    pcol->Clear();
    utils::ParallelGroupBuilder<SparseBatch::Entry>
        builder(&pcol->offset, &pcol->data);
    builder.InitBudget(info.num_col(), nthread);
    bst_omp_uint ndata = static_cast<bst_uint>(prow.Size());
    #pragma omp parallel for schedule(static)
    for (bst_omp_uint i = 0; i < ndata; ++i) {
      int tid = omp_get_thread_num();
      for (bst_uint j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
        const SparseBatch::Entry &e = prow.data[j];
        if (enabled[e.index]) {
          builder.AddBudget(e.index, tid);
        }
      }
    }
    builder.InitStorage();
    #pragma omp parallel for schedule(static)
    for (bst_omp_uint i = 0; i < ndata; ++i) {
      int tid = omp_get_thread_num();
      for (bst_uint j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
        const SparseBatch::Entry &e = prow.data[j];
        builder.Push(e.index,
                     SparseBatch::Entry(ridx[i], e.fvalue),
                     tid);
      }
    }
    utils::Assert(pcol->Size() == info.num_col(), "inconsistent col data");
    // sort columns
    bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
    #pragma omp parallel for schedule(dynamic, 1)
    for (bst_omp_uint i = 0; i < ncol; ++i) {
      if (pcol->offset[i] < pcol->offset[i + 1]) {
        std::sort(BeginPtr(pcol->data) + pcol->offset[i],
                  BeginPtr(pcol->data) + pcol->offset[i + 1], Entry::CmpValue);
      }
      col_size_[i] += pcol->offset[i + 1] - pcol->offset[i];
    }
    pcol->Save(fo);
  }

 private:
  /*! \brief page size 256 M */
  static const size_t kPageSize = 256 << 20UL;
  // shared meta info with DMatrix
  const learner::MetaInfo &info;
  // row iterator
  utils::IIterator<RowBatch> *iter_;
  /*! \brief column based data file name */
  std::string col_data_name_;
  /*! \brief column based data file name */
  std::string col_meta_name_;
  /*! \brief list of row index that are buffered */
  std::vector<bst_uint> buffered_rowset_;
  // number of buffered rows
  size_t num_buffered_row_;
  // count for column data
  std::vector<size_t> col_size_;
  // internal column index for output
  std::vector<bst_uint> col_index_;
  // internal thread backed col iterator
  ThreadColPageIterator col_iter_;
};
} // namespace io
} // namespace xgboost

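PushColPage above relies on utils::ParallelGroupBuilder for the usual two passes: first count how many entries land in each column (InitBudget/AddBudget), then allocate and fill (InitStorage/Push). The sketch below shows the same count-then-fill idea for turning one row-major page into column-major storage, single-threaded and without the OpenMP, enabled-feature, and sorting details; all names are local to the example, not xgboost's.

// Single-threaded sketch of the count-then-fill pattern used by PushColPage.
// Converts a tiny CSR row page into CSC column storage; not xgboost code.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Entry { unsigned index; float fvalue; };  // column index (or row id) + value

int main() {
  // Hypothetical row page: 2 rows over 3 columns.
  std::vector<std::size_t> row_offset;
  row_offset.push_back(0); row_offset.push_back(2); row_offset.push_back(4);
  Entry e0 = {0, 1.0f}, e1 = {2, 3.0f}, e2 = {1, 2.0f}, e3 = {2, 4.0f};
  std::vector<Entry> row_data;
  row_data.push_back(e0); row_data.push_back(e1);
  row_data.push_back(e2); row_data.push_back(e3);
  const std::size_t num_col = 3;

  // Pass 1: budget -- count how many entries fall into each column.
  std::vector<std::size_t> col_offset(num_col + 1, 0);
  for (std::size_t i = 0; i < row_data.size(); ++i) ++col_offset[row_data[i].index + 1];
  for (std::size_t c = 0; c < num_col; ++c) col_offset[c + 1] += col_offset[c];

  // Pass 2: fill -- place (row id, value) entries into their columns.
  std::vector<Entry> col_data(row_data.size());
  std::vector<std::size_t> fill = col_offset;
  for (std::size_t r = 0; r + 1 < row_offset.size(); ++r) {
    for (std::size_t j = row_offset[r]; j < row_offset[r + 1]; ++j) {
      Entry e; e.index = static_cast<unsigned>(r); e.fvalue = row_data[j].fvalue;
      col_data[fill[row_data[j].index]++] = e;
    }
  }
  // Print each column's (row id, value) pairs.
  for (std::size_t c = 0; c < num_col; ++c) {
    std::printf("col %u:", static_cast<unsigned>(c));
    for (std::size_t j = col_offset[c]; j < col_offset[c + 1]; ++j)
      std::printf(" (row %u, %g)", col_data[j].index, col_data[j].fvalue);
    std::printf("\n");
  }
  return 0;
}
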
@@ -24,6 +24,10 @@ class SparsePage {
  SparsePage() {
    this->Clear();
  }
  /*! \return number of instance in the page */
  inline size_t Size() const {
    return offset.size() - 1;
  }
  /*!
   * \brief load the by providing a list of interested segments
   *   only the interested segments are loaded

@@ -38,6 +42,7 @@
    offset.clear(); offset.push_back(0);
    for (size_t i = 0; i < sorted_index_set.size(); ++i) {
      bst_uint fid = sorted_index_set[i];
      utils::Check(fid + 1 < disk_offset_.size(), "bad col.blob format");
      size_t size = disk_offset_[fid + 1] - disk_offset_[fid];
      offset.push_back(offset.back() + size);
    }

@@ -49,7 +54,7 @@
      bst_uint fid = sorted_index_set[i];
      if (disk_offset_[fid] != curr_offset) {
        utils::Assert(disk_offset_[fid] > curr_offset, "fset index was not sorted");
        fi->Seek(begin + disk_offset_[fid]);
        fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry));
        curr_offset = disk_offset_[fid];
      }
      size_t j, size_to_read = 0;

@@ -68,6 +73,10 @@
      }
      i = j;
    }
    // seek to end of record
    if (curr_offset != disk_offset_.back()) {
      fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry));
    }
    return true;
  }
  /*!

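The Seek changes in the hunks above turn stored element offsets into byte positions: `disk_offset_` counts `SparseBatch::Entry` records, while Seek expects a byte offset from the start of the data block, hence the multiplication by `sizeof(SparseBatch::Entry)`. A rough standalone illustration, assuming an 8-byte entry (4-byte index plus 4-byte value, no padding):

// Rough illustration of the element-offset to byte-offset scaling above.
// The entry layout is an assumption (4-byte index + 4-byte value, no padding).
#include <cstddef>
#include <cstdio>

struct Entry { unsigned index; float fvalue; };

int main() {
  std::size_t begin = 4096;            // hypothetical start of the data block
  std::size_t disk_offset_fid = 1000;  // hypothetical element offset of a column
  std::size_t pos = begin + disk_offset_fid * sizeof(Entry);
  std::printf("seek to byte %lu\n", static_cast<unsigned long>(pos));  // 4096 + 8000
  return 0;
}
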
@@ -166,7 +175,8 @@
 */
class SparsePageFactory {
 public:
  SparsePageFactory(void) {}
  SparsePageFactory(void)
      : action_load_all_(true), set_load_all_(true) {}
  inline void SetFile(const utils::FileStream &fi,
                      size_t file_begin = 0) {
    fi_ = fi;

@@ -176,19 +186,27 @@
    return action_index_set_;
  }
  // set index set, will be used after next before first
  inline void SetIndexSet(const std::vector<bst_uint> &index_set) {
    set_index_set_ = index_set;
    std::sort(set_index_set_.begin(), set_index_set_.end());
  inline void SetIndexSet(const std::vector<bst_uint> &index_set,
                          bool load_all) {
    set_load_all_ = load_all;
    if (!set_load_all_) {
      set_index_set_ = index_set;
      std::sort(set_index_set_.begin(), set_index_set_.end());
    }
  }
  inline bool Init(void) {
    return true;
  }
  inline void SetParam(const char *name, const char *val) {}
  inline bool LoadNext(SparsePage *val) {
    if (action_index_set_.size() != 0) {
      return val->Load(&fi_, action_index_set_);
  inline bool LoadNext(SparsePage *val) {
    if (!action_load_all_) {
      if (action_index_set_.size() == 0) {
        return false;
      } else {
        return val->Load(&fi_, action_index_set_);
      }
    } else {
      return val->Load(&fi_);
      return val->Load(&fi_);
    }
  }
  inline SparsePage *Create(void) {

@@ -202,10 +220,14 @@
  }
  inline void BeforeFirst(void) {
    fi_.Seek(file_begin_);
    action_index_set_ = set_index_set_;
    action_load_all_ = set_load_all_;
    if (!set_load_all_) {
      action_index_set_ = set_index_set_;
    }
  }

 private:
  bool action_load_all_, set_load_all_;
  size_t file_begin_;
  utils::FileStream fi_;
  std::vector<bst_uint> action_index_set_;

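The load_all flag added to SparsePageFactory above separates two modes: when it is true the factory streams whole pages; when it is false only the sorted index set is read, and an empty set means there is nothing to load. A condensed standalone restatement of that dispatch; the helper below is hypothetical, not part of the factory's API:

// Hypothetical condensed form of the LoadNext dispatch above; illustration only.
#include <cstdio>
#include <vector>

const char *NextAction(bool load_all, const std::vector<unsigned> &index_set) {
  if (!load_all) {
    return index_set.empty() ? "nothing to load" : "load selected columns";
  }
  return "load full page";
}

int main() {
  std::vector<unsigned> none;
  std::vector<unsigned> some(1, 3);  // one selected column index
  std::printf("%s\n", NextAction(false, none));  // nothing to load
  std::printf("%s\n", NextAction(false, some));  // load selected columns
  std::printf("%s\n", NextAction(true, none));   // load full page
  return 0;
}
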
@@ -69,7 +69,7 @@ class BoostLearner : public rabit::Serializable {
    utils::SPrintf(str_temp, sizeof(str_temp), "%lu",
                   static_cast<unsigned long>(buffer_size));
    this->SetParam("num_pbuffer", str_temp);
    this->pred_buffer_size = buffer_size;
    this->pred_buffer_size = buffer_size;
  }
  /*!
   * \brief set parameters from outside

@@ -259,7 +259,12 @@ class BoostLearner : public rabit::Serializable {
    int ncol = static_cast<int>(p_train->info.info.num_col);
    std::vector<bool> enabled(ncol, true);
    // initialize column access
    p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
    p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
    const int kMagicSimple = 0xffffab01;
    // check, if it is not DMatrix simple, then use hist maker
    if (p_train->magic != kMagicSimple) {
      this->SetParam("updater", "grow_histmaker,prune");
    }
  }
  /*!
   * \brief update the model for one iteration

@@ -50,7 +50,7 @@ class BaseMaker: public IUpdater {
        fminmax[fid * 2 + 1] = std::max(c[c.length - 1].fvalue, fminmax[fid * 2 + 1]);
      }
    }
    }
  }
  rabit::Allreduce<rabit::op::Max>(BeginPtr(fminmax), fminmax.size());
}
// get feature type, 0:empty 1:binary 2:real

@@ -366,7 +366,7 @@ class CQHistMaker: public HistMaker<TStats> {
      } else {
        feat2workindex[fset[i]] = -2;
      }
    }
    }
    this->GetNodeStats(gpair, *p_fmat, tree, info,
                       &thread_stats, &node_stats);
    sketchs.resize(this->qexpand.size() * freal_set.size());

@@ -578,7 +578,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
                          IFMatrix *p_fmat,
                          const BoosterInfo &info,
                          const std::vector <bst_uint> &fset,
                          const RegTree &tree) {
                          const RegTree &tree) {
    // initialize the data structure
    int nthread = BaseMaker::get_nthread();
    sketchs.resize(this->qexpand.size() * tree.param.num_feature);