fix magic so that it can detect binary file
This commit is contained in:
parent
9eb32b9dd4
commit
46f14b8c27
@ -2,6 +2,7 @@
|
|||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
|
#include "../utils/io.h"
|
||||||
#include "../utils/utils.h"
|
#include "../utils/utils.h"
|
||||||
#include "simple_dmatrix-inl.hpp"
|
#include "simple_dmatrix-inl.hpp"
|
||||||
// implements data loads using dmatrix simple for now
|
// implements data loads using dmatrix simple for now
|
||||||
@ -9,6 +10,19 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace io {
|
namespace io {
|
||||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
|
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
|
||||||
|
int magic;
|
||||||
|
utils::FileStream fs(utils::FopenCheck(fname, "rb"));
|
||||||
|
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
|
||||||
|
fs.Seek(0);
|
||||||
|
|
||||||
|
if (magic == DMatrixSimple::kMagic) {
|
||||||
|
DMatrixSimple *dmat = new DMatrixSimple();
|
||||||
|
dmat->LoadBinary(fs, silent, fname);
|
||||||
|
fs.Close();
|
||||||
|
return dmat;
|
||||||
|
}
|
||||||
|
fs.Close();
|
||||||
|
|
||||||
DMatrixSimple *dmat = new DMatrixSimple();
|
DMatrixSimple *dmat = new DMatrixSimple();
|
||||||
dmat->CacheLoad(fname, silent, savebuffer);
|
dmat->CacheLoad(fname, silent, savebuffer);
|
||||||
return dmat;
|
return dmat;
|
||||||
|
|||||||
@ -128,6 +128,17 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
FILE *fp = fopen64(fname, "rb");
|
FILE *fp = fopen64(fname, "rb");
|
||||||
if (fp == NULL) return false;
|
if (fp == NULL) return false;
|
||||||
utils::FileStream fs(fp);
|
utils::FileStream fs(fp);
|
||||||
|
this->LoadBinary(fs, silent, fname);
|
||||||
|
fs.Close();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief load from binary stream
|
||||||
|
* \param fs input file stream
|
||||||
|
* \param silent whether print information during loading
|
||||||
|
* \param fname file name, used to print message
|
||||||
|
*/
|
||||||
|
inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) {
|
||||||
int magic;
|
int magic;
|
||||||
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
|
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
|
||||||
utils::Check(magic == kMagic, "invalid format,magic number mismatch");
|
utils::Check(magic == kMagic, "invalid format,magic number mismatch");
|
||||||
@ -135,16 +146,19 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
info.LoadBinary(fs);
|
info.LoadBinary(fs);
|
||||||
FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_);
|
FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_);
|
||||||
fmat.LoadColAccess(fs);
|
fmat.LoadColAccess(fs);
|
||||||
fs.Close();
|
|
||||||
|
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
printf("%lux%lu matrix with %lu entries is loaded",
|
||||||
info.num_row(), info.num_col(), row_data_.size(), fname);
|
info.num_row(), info.num_col(), row_data_.size());
|
||||||
|
if (fname != NULL) {
|
||||||
|
printf(" from %s\n", fname);
|
||||||
|
} else {
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
if (info.group_ptr.size() != 0) {
|
if (info.group_ptr.size() != 0) {
|
||||||
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
|
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief save to binary file
|
* \brief save to binary file
|
||||||
|
|||||||
@ -102,6 +102,9 @@ class FileStream : public IStream {
|
|||||||
virtual void Write(const void *ptr, size_t size) {
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
fwrite(ptr, size, 1, fp);
|
fwrite(ptr, size, 1, fp);
|
||||||
}
|
}
|
||||||
|
/*!
 * \brief move the position of the underlying file
 * \param pos byte offset, measured from the beginning of the file
 */
inline void Seek(size_t pos) {
  // Honor the requested offset: the earlier version ignored pos and always
  // rewound to offset 0, which only happened to be right for Seek(0).
  // fseek takes a long offset, hence the explicit conversion.
  fseek(fp, static_cast<long>(pos), SEEK_SET);
}
|
||||||
inline void Close(void) {
|
inline void Close(void) {
|
||||||
fclose(fp);
|
fclose(fp);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class DMatrix:
|
|||||||
return
|
return
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
self.handle = ctypes.c_void_p(
|
self.handle = ctypes.c_void_p(
|
||||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
|
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0))
|
||||||
elif isinstance(data, scp.csr_matrix):
|
elif isinstance(data, scp.csr_matrix):
|
||||||
self.__init_from_csr(data)
|
self.__init_from_csr(data)
|
||||||
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user