Needs more checks

This commit is contained in:
tqchen
2015-04-16 12:34:39 -07:00
parent a514340c96
commit 22abf4e295
8 changed files with 271 additions and 73 deletions

View File

@@ -70,33 +70,6 @@ class DMatrixPageBase : public DataMatrix {
// do not delete row iterator, since it is owned by fmat
// to be cleaned up in a more clear way
}
/*!
 * \brief load and initialize the iterator with fi
 * \param fi stream positioned at the magic number, which is followed by the
 *        data consumed by info.LoadBinary
 * \param silent when false, print a summary of the loaded matrix
 * \param fname_ name of the file being loaded; must not be NULL, because the
 *        row data is read from the companion file "<fname_>.row.blob"
 */
inline void LoadBinary(utils::FileStream &fi,
                       bool silent,
                       const char *fname_) {
  // reject NULL up front: constructing a std::string from a NULL pointer is
  // undefined behavior, so the name has to be validated before it is used
  utils::Check(fname_ != NULL, "LoadBinary: file name must not be NULL");
  int tmagic;
  utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
  utils::Check(tmagic == magic, "invalid format,magic number mismatch");
  this->info.LoadBinary(fi);
  // load in the row data file
  const std::string fname = std::string(fname_) + ".row.blob";
  utils::FileStream fs(utils::FopenCheck(fname.c_str(), "rb"));
  iter_->Load(fs);
  if (!silent) {
    utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
                  static_cast<unsigned long>(info.num_row()),
                  static_cast<unsigned long>(info.num_col()));
    // fname_ is known to be non-NULL here, so the source is always reported
    utils::Printf(" from %s\n", fname_);
    if (info.group_ptr.size() != 0) {
      utils::Printf("data contains %u groups\n",
                    static_cast<unsigned>(info.group_ptr.size()) - 1);
    }
  }
}
/*! \brief save a DataMatrix as DMatrixPage */
inline static void Save(const char *fname_, const DataMatrix &mat, bool silent) {
std::string fname = fname_;
@@ -127,18 +100,48 @@ class DMatrixPageBase : public DataMatrix {
static_cast<unsigned long>(mat.info.num_col()), fname_);
}
}
/*!
 * \brief load and initialize the iterator with fi
 * \param fi stream positioned at the magic number, which is followed by the
 *        data consumed by info.LoadBinary
 * \param silent when false, print a summary of the loaded matrix
 * \param fname_ name of the cache file; must not be NULL, because it is
 *        passed to set_cache_file and the row data is read from the
 *        companion file "<fname_>.row.blob"
 */
inline void LoadBinary(utils::FileStream &fi,
                       bool silent,
                       const char *fname_) {
  // reject NULL up front: both set_cache_file and the blob name below build a
  // std::string from fname_, which is undefined behavior for a NULL pointer
  utils::Check(fname_ != NULL, "LoadBinary: cache file name must not be NULL");
  this->set_cache_file(fname_);
  int tmagic;
  utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
  utils::Check(tmagic == magic, "invalid format,magic number mismatch");
  this->info.LoadBinary(fi);
  // load in the row data file
  const std::string fname = std::string(fname_) + ".row.blob";
  utils::FileStream fs(utils::FopenCheck(fname.c_str(), "rb"));
  iter_->Load(fs);
  if (!silent) {
    utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
                  static_cast<unsigned long>(info.num_row()),
                  static_cast<unsigned long>(info.num_col()));
    // fname_ is known to be non-NULL here, so the source is always reported
    utils::Printf(" from %s\n", fname_);
    if (info.group_ptr.size() != 0) {
      utils::Printf("data contains %u groups\n",
                    static_cast<unsigned>(info.group_ptr.size()) - 1);
    }
  }
}
/*! \brief load a LibSVM format file and save it as DMatrixPage */
inline void LoadText(const char *uri,
const char* cache_file,
bool silent,
bool loadsplit) {
int rank = 0, npart = 1;
if (loadsplit) {
rank = rabit::GetRank();
npart = rabit::GetWorldSize();
}
this->set_cache_file(cache_file);
std::string fname_row = std::string(cache_file) + ".row.blob";
utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
SparsePage page;
dmlc::InputSplit *in =
dmlc::InputSplit::Create(uri, rank, npart);
@@ -190,8 +193,10 @@ class DMatrixPageBase : public DataMatrix {
/*! \brief magic number used to identify DMatrix */
static const int kMagic = TKMagic;
/*! \brief page size 64 MB */
static const size_t kPageSize = 64 << 18;
static const size_t kPageSize = 64UL << 20UL;
protected:
virtual void set_cache_file(const std::string &cache_file) = 0;
/*! \brief row iterator */
ThreadRowPageIterator *iter_;
};
@@ -199,7 +204,7 @@ class DMatrixPageBase : public DataMatrix {
class DMatrixPage : public DMatrixPageBase<0xffffab02> {
public:
DMatrixPage(void) {
fmat_ = new FMatrixS(iter_);
fmat_ = new FMatrixPage(iter_, this->info);
}
virtual ~DMatrixPage(void) {
delete fmat_;
@@ -207,8 +212,11 @@ class DMatrixPage : public DMatrixPageBase<0xffffab02> {
/*! \return pointer to the feature matrix stored in fmat_ */
virtual IFMatrix *fmat(void) const {
  return this->fmat_;
}
virtual void set_cache_file(const std::string &cache_file) {
fmat_->set_cache_file(cache_file);
}
/*! \brief the real fmatrix */
IFMatrix *fmat_;
FMatrixPage *fmat_;
};
} // namespace io
} // namespace xgboost