diff --git a/src/io/io.cpp b/src/io/io.cpp index a8aed0d43..8a4579ab8 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -13,6 +13,14 @@ namespace xgboost { namespace io { DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { + std::string tmp_fname; + const char *fname_ext = NULL; + if (strchr(fname, ';') != NULL) { + tmp_fname = fname; + char *ptr = strchr(&tmp_fname[0], ';'); + ptr[0] = '\0'; fname = &tmp_fname[0]; + fname_ext = ptr + 1; + } int magic; utils::FileStream fs(utils::FopenCheck(fname, "rb")); utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); @@ -25,15 +33,23 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { return dmat; } if (magic == DMatrixPage::kMagic) { - DMatrixPage *dmat = new DMatrixPage(); - dmat->Load(fs, silent, fname); - // the file pointer is hold in page matrix - return dmat; + if (fname_ext == NULL) { + DMatrixPage *dmat = new DMatrixPage(); + dmat->Load(fs, silent, fname); + return dmat; + } else { + DMatrixColPage *dmat = new DMatrixColPage(fname_ext); + dmat->Load(fs, silent, fname, true); + return dmat; + } } if (magic == DMatrixColPage::kMagic) { - DMatrixColPage *dmat = new DMatrixColPage(fname); + std::string sfname = fname; + if (fname_ext == NULL) { + sfname += ".col"; fname_ext = sfname.c_str(); + } + DMatrixColPage *dmat = new DMatrixColPage(fname_ext); dmat->Load(fs, silent, fname); - // the file pointer is hold in page matrix return dmat; } fs.Close(); diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index d52700e87..41ad19be5 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -214,10 +214,13 @@ class DMatrixPageBase : public DataMatrix { /*! \brief load and initialize the iterator with fi */ inline void Load(utils::FileStream &fi, bool silent = false, - const char *fname = NULL) { + const char *fname = NULL, + bool skip_magic_check = false) { int tmagic; utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format"); - utils::Check(tmagic == magic, "invalid format,magic number mismatch"); + if (!skip_magic_check) { + utils::Check(tmagic == magic, "invalid format,magic number mismatch"); + } this->info.LoadBinary(fi); iter_->Load(fi); if (!silent) { diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 327e5c144..3b53c2484 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -355,9 +355,7 @@ class FMatrixPage : public IFMatrix { class DMatrixColPage : public DMatrixPageBase<0xffffab03> { public: explicit DMatrixColPage(const char *fname) { - std::string fext = fname; - fext += ".col"; - fmat_ = new FMatrixPage(iter_, fext.c_str()); + fmat_ = new FMatrixPage(iter_, fname); } virtual ~DMatrixColPage(void) { delete fmat_; diff --git a/src/tree/updater.cpp b/src/tree/updater.cpp index 2cb6552fe..5879b2bbd 100644 --- a/src/tree/updater.cpp +++ b/src/tree/updater.cpp @@ -13,6 +13,8 @@ IUpdater* CreateUpdater(const char *name) { if (!strcmp(name, "prune")) return new TreePruner(); if (!strcmp(name, "refresh")) return new TreeRefresher(); if (!strcmp(name, "grow_colmaker")) return new ColMaker(); + if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >(); + if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >(); utils::Error("unknown updater:%s", name); return NULL; }