fix magic so that it can detect binary file
This commit is contained in:
parent
9eb32b9dd4
commit
46f14b8c27
@ -2,6 +2,7 @@
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#include <string>
|
||||
#include "./io.h"
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/utils.h"
|
||||
#include "simple_dmatrix-inl.hpp"
|
||||
// implements data loads using dmatrix simple for now
|
||||
@ -9,6 +10,19 @@
|
||||
namespace xgboost {
|
||||
namespace io {
|
||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
|
||||
int magic;
|
||||
utils::FileStream fs(utils::FopenCheck(fname, "rb"));
|
||||
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
|
||||
fs.Seek(0);
|
||||
|
||||
if (magic == DMatrixSimple::kMagic) {
|
||||
DMatrixSimple *dmat = new DMatrixSimple();
|
||||
dmat->LoadBinary(fs, silent, fname);
|
||||
fs.Close();
|
||||
return dmat;
|
||||
}
|
||||
fs.Close();
|
||||
|
||||
DMatrixSimple *dmat = new DMatrixSimple();
|
||||
dmat->CacheLoad(fname, silent, savebuffer);
|
||||
return dmat;
|
||||
|
||||
@ -128,6 +128,17 @@ class DMatrixSimple : public DataMatrix {
|
||||
FILE *fp = fopen64(fname, "rb");
|
||||
if (fp == NULL) return false;
|
||||
utils::FileStream fs(fp);
|
||||
this->LoadBinary(fs, silent, fname);
|
||||
fs.Close();
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief load from binary stream
|
||||
* \param fs input file stream
|
||||
* \param silent whether print information during loading
|
||||
* \param fname file name, used to print message
|
||||
*/
|
||||
inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) {
|
||||
int magic;
|
||||
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
|
||||
utils::Check(magic == kMagic, "invalid format,magic number mismatch");
|
||||
@ -135,16 +146,19 @@ class DMatrixSimple : public DataMatrix {
|
||||
info.LoadBinary(fs);
|
||||
FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_);
|
||||
fmat.LoadColAccess(fs);
|
||||
fs.Close();
|
||||
|
||||
if (!silent) {
|
||||
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
||||
info.num_row(), info.num_col(), row_data_.size(), fname);
|
||||
printf("%lux%lu matrix with %lu entries is loaded",
|
||||
info.num_row(), info.num_col(), row_data_.size());
|
||||
if (fname != NULL) {
|
||||
printf(" from %s\n", fname);
|
||||
} else {
|
||||
printf("\n");
|
||||
}
|
||||
if (info.group_ptr.size() != 0) {
|
||||
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief save to binary file
|
||||
|
||||
@ -102,6 +102,9 @@ class FileStream : public IStream {
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
fwrite(ptr, size, 1, fp);
|
||||
}
|
||||
inline void Seek(size_t pos) {
|
||||
fseek(fp, 0, SEEK_SET);
|
||||
}
|
||||
inline void Close(void) {
|
||||
fclose(fp);
|
||||
}
|
||||
|
||||
@ -45,7 +45,7 @@ class DMatrix:
|
||||
return
|
||||
if isinstance(data, str):
|
||||
self.handle = ctypes.c_void_p(
|
||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
|
||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0))
|
||||
elif isinstance(data, scp.csr_matrix):
|
||||
self.__init_from_csr(data)
|
||||
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user