From 46f14b8c27936f0fb061dacb24402b3c5263afd4 Mon Sep 17 00:00:00 2001 From: "tqchen@graphlab.com" Date: Tue, 26 Aug 2014 12:17:27 -0700 Subject: [PATCH] fix magic so that it can detect binary file --- src/io/io.cpp | 14 ++++++++++++++ src/io/simple_dmatrix-inl.hpp | 22 ++++++++++++++++++---- src/utils/io.h | 3 +++ wrapper/xgboost.py | 2 +- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/io/io.cpp b/src/io/io.cpp index a3ea457ed..d251d7a96 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -2,6 +2,7 @@ #define _CRT_SECURE_NO_DEPRECATE #include #include "./io.h" +#include "../utils/io.h" #include "../utils/utils.h" #include "simple_dmatrix-inl.hpp" // implements data loads using dmatrix simple for now @@ -9,6 +10,19 @@ namespace xgboost { namespace io { DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { + int magic; + utils::FileStream fs(utils::FopenCheck(fname, "rb")); + utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); + fs.Seek(0); + + if (magic == DMatrixSimple::kMagic) { + DMatrixSimple *dmat = new DMatrixSimple(); + dmat->LoadBinary(fs, silent, fname); + fs.Close(); + return dmat; + } + fs.Close(); + DMatrixSimple *dmat = new DMatrixSimple(); dmat->CacheLoad(fname, silent, savebuffer); return dmat; diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp index faf21021e..99bd0b932 100644 --- a/src/io/simple_dmatrix-inl.hpp +++ b/src/io/simple_dmatrix-inl.hpp @@ -128,6 +128,17 @@ class DMatrixSimple : public DataMatrix { FILE *fp = fopen64(fname, "rb"); if (fp == NULL) return false; utils::FileStream fs(fp); + this->LoadBinary(fs, silent, fname); + fs.Close(); + return true; + } + /*! + * \brief load from binary stream + * \param fs input file stream + * \param silent whether print information during loading + * \param fname file name, used to print message + */ + inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) { int magic; utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); utils::Check(magic == kMagic, "invalid format,magic number mismatch"); @@ -135,16 +146,19 @@ class DMatrixSimple : public DataMatrix { info.LoadBinary(fs); FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_); fmat.LoadColAccess(fs); - fs.Close(); if (!silent) { - printf("%lux%lu matrix with %lu entries is loaded from %s\n", - info.num_row(), info.num_col(), row_data_.size(), fname); + printf("%lux%lu matrix with %lu entries is loaded", + info.num_row(), info.num_col(), row_data_.size()); + if (fname != NULL) { + printf(" from %s\n", fname); + } else { + printf("\n"); + } if (info.group_ptr.size() != 0) { printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); } } - return true; } /*! * \brief save to binary file diff --git a/src/utils/io.h b/src/utils/io.h index 26a9bd02e..4a80e9a58 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -102,6 +102,9 @@ class FileStream : public IStream { virtual void Write(const void *ptr, size_t size) { fwrite(ptr, size, 1, fp); } + inline void Seek(size_t pos) { + fseek(fp, 0, SEEK_SET); + } inline void Close(void) { fclose(fp); } diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 4e17dafcf..013b84b15 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -45,7 +45,7 @@ class DMatrix: return if isinstance(data, str): self.handle = ctypes.c_void_p( - xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1)) + xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0)) elif isinstance(data, scp.csr_matrix): self.__init_from_csr(data) elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: