fix magic so that it can detect binary file

This commit is contained in:
tqchen@graphlab.com 2014-08-26 12:17:27 -07:00
parent 9eb32b9dd4
commit 46f14b8c27
4 changed files with 36 additions and 5 deletions

View File

@ -2,6 +2,7 @@
#define _CRT_SECURE_NO_DEPRECATE
#include <string>
#include "./io.h"
#include "../utils/io.h"
#include "../utils/utils.h"
#include "simple_dmatrix-inl.hpp"
// implements data loads using dmatrix simple for now
@ -9,6 +10,19 @@
namespace xgboost {
namespace io {
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
int magic;
utils::FileStream fs(utils::FopenCheck(fname, "rb"));
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
fs.Seek(0);
if (magic == DMatrixSimple::kMagic) {
DMatrixSimple *dmat = new DMatrixSimple();
dmat->LoadBinary(fs, silent, fname);
fs.Close();
return dmat;
}
fs.Close();
DMatrixSimple *dmat = new DMatrixSimple();
dmat->CacheLoad(fname, silent, savebuffer);
return dmat;

View File

@ -128,6 +128,17 @@ class DMatrixSimple : public DataMatrix {
FILE *fp = fopen64(fname, "rb");
if (fp == NULL) return false;
utils::FileStream fs(fp);
this->LoadBinary(fs, silent, fname);
fs.Close();
return true;
}
/*!
* \brief load from binary stream
* \param fs input file stream
* \param silent whether print information during loading
* \param fname file name, used to print message
*/
inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) {
int magic;
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
utils::Check(magic == kMagic, "invalid format,magic number mismatch");
@ -135,16 +146,19 @@ class DMatrixSimple : public DataMatrix {
info.LoadBinary(fs);
FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_);
fmat.LoadColAccess(fs);
fs.Close();
if (!silent) {
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
info.num_row(), info.num_col(), row_data_.size(), fname);
printf("%lux%lu matrix with %lu entries is loaded",
info.num_row(), info.num_col(), row_data_.size());
if (fname != NULL) {
printf(" from %s\n", fname);
} else {
printf("\n");
}
if (info.group_ptr.size() != 0) {
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
}
}
return true;
}
/*!
* \brief save to binary file

View File

@ -102,6 +102,9 @@ class FileStream : public IStream {
virtual void Write(const void *ptr, size_t size) {
fwrite(ptr, size, 1, fp);
}
inline void Seek(size_t pos) {
fseek(fp, 0, SEEK_SET);
}
inline void Close(void) {
fclose(fp);
}

View File

@ -45,7 +45,7 @@ class DMatrix:
return
if isinstance(data, str):
self.handle = ctypes.c_void_p(
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0))
elif isinstance(data, scp.csr_matrix):
self.__init_from_csr(data)
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: