From 225f5258c7402305410ba3f45e8d647853651ed7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 29 Feb 2016 14:59:44 -0800 Subject: [PATCH] [DMLC] Add dep to dmlc logging --- include/dmlc/base.h | 207 ++++++++++++++++++++ include/dmlc/io.h | 101 +++++----- include/dmlc/logging.h | 268 ++++++++++++++++++++++++++ include/dmlc/serializer.h | 383 +++++++++++++++++++++++++++++++++++++ include/dmlc/type_traits.h | 171 +++++++++++++++++ 5 files changed, 1077 insertions(+), 53 deletions(-) create mode 100644 include/dmlc/base.h create mode 100644 include/dmlc/logging.h create mode 100644 include/dmlc/serializer.h create mode 100644 include/dmlc/type_traits.h diff --git a/include/dmlc/base.h b/include/dmlc/base.h new file mode 100644 index 000000000..01413f100 --- /dev/null +++ b/include/dmlc/base.h @@ -0,0 +1,207 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file base.h + * \brief defines configuration macros + */ +#ifndef DMLC_BASE_H_ +#define DMLC_BASE_H_ + +/*! \brief whether use glog for logging */ +#ifndef DMLC_USE_GLOG +#define DMLC_USE_GLOG 0 +#endif + +/*! + * \brief whether throw dmlc::Error instead of + * directly calling abort when FATAL error occured + * NOTE: this may still not be perfect. + * do not use FATAL and CHECK in destructors + */ +#ifndef DMLC_LOG_FATAL_THROW +#define DMLC_LOG_FATAL_THROW 1 +#endif + +/*! + * \brief whether always log a message before throw + * This can help identify the error that cannot be catched. + */ +#ifndef DMLC_LOG_BEFORE_THROW +#define DMLC_LOG_BEFORE_THROW 1 +#endif + +/*! + * \brief Whether to use customized logger, + * whose output can be decided by other libraries. + */ +#ifndef DMLC_LOG_CUSTOMIZE +#define DMLC_LOG_CUSTOMIZE 0 +#endif + +/*! \brief whether compile with hdfs support */ +#ifndef DMLC_USE_HDFS +#define DMLC_USE_HDFS 0 +#endif + +/*! \brief whether compile with s3 support */ +#ifndef DMLC_USE_S3 +#define DMLC_USE_S3 0 +#endif + +/*! \brief whether or not use parameter server */ +#ifndef DMLC_USE_PS +#define DMLC_USE_PS 0 +#endif + +/*! \brief whether or not use c++11 support */ +#ifndef DMLC_USE_CXX11 +#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ + __cplusplus >= 201103L || defined(_MSC_VER)) +#endif + +/// check if g++ is before 4.6 +#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) +#if __GNUC__ == 4 && __GNUC_MINOR__ < 6 +#pragma message("Will need g++-4.6 or higher to compile all" \ + "the features in dmlc-core, " \ + "compile without c++0x, some features may be disabled") +#undef DMLC_USE_CXX11 +#define DMLC_USE_CXX11 0 +#endif +#endif + +/*! + * \brief Enable std::thread related modules, + * Used to disable some module in mingw compile. + */ +#ifndef DMLC_ENABLE_STD_THREAD +#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11 +#endif + +/*! \brief whether enable regex support, actually need g++-4.9 or higher*/ +#ifndef DMLC_USE_REGEX +#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER)) +#endif + +/*! + * \brief Disable copy constructor and assignment operator. + * + * If C++11 is supported, both copy and move constructors and + * assignment operators are deleted explicitly. Otherwise, they are + * only declared but not implemented. Place this macro in private + * section if C++11 is not available. + */ +#ifndef DISALLOW_COPY_AND_ASSIGN +# if DMLC_USE_CXX11 +# define DISALLOW_COPY_AND_ASSIGN(T) \ + T(T const&) = delete; \ + T(T&&) = delete; \ + T& operator=(T const&) = delete; \ + T& operator=(T&&) = delete +# else +# define DISALLOW_COPY_AND_ASSIGN(T) \ + T(T const&); \ + T& operator=(T const&) +# endif +#endif + +/// +/// code block to handle optionally loading +/// +#if !defined(__GNUC__) +#define fopen64 std::fopen +#endif +#ifdef _MSC_VER +#if _MSC_VER < 1900 +// NOTE: sprintf_s is not equivalent to snprintf, +// they are equivalent when success, which is sufficient for our case +#define snprintf sprintf_s +#define vsnprintf vsprintf_s +#endif +#else +#ifdef _FILE_OFFSET_BITS +#if _FILE_OFFSET_BITS == 32 +#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") +#endif +#endif + +#ifdef __APPLE__ +#define off64_t off_t +#define fopen64 std::fopen +#endif + +extern "C" { +#include +} +#endif + +#ifdef _MSC_VER +//! \cond Doxygen_Suppress +typedef signed char int8_t; +typedef __int16 int16_t; +typedef __int32 int32_t; +typedef __int64 int64_t; +typedef unsigned char uint8_t; +typedef unsigned __int16 uint16_t; +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +//! \endcond +#else +#include +#endif +#include +#include + +/*! \brief namespace for dmlc */ +namespace dmlc { +/*! + * \brief safely get the beginning address of a vector + * \param vec input vector + * \return beginning address of a vector + */ +template +inline T *BeginPtr(std::vector &vec) { // NOLINT(*) + if (vec.size() == 0) { + return NULL; + } else { + return &vec[0]; + } +} +/*! + * \brief get the beginning address of a vector + * \param vec input vector + * \return beginning address of a vector + */ +template +inline const T *BeginPtr(const std::vector &vec) { + if (vec.size() == 0) { + return NULL; + } else { + return &vec[0]; + } +} +/*! + * \brief get the beginning address of a vector + * \param str input string + * \return beginning address of a string + */ +inline char* BeginPtr(std::string &str) { // NOLINT(*) + if (str.length() == 0) return NULL; + return &str[0]; +} +/*! + * \brief get the beginning address of a vector + * \param str input string + * \return beginning address of a string + */ +inline const char* BeginPtr(const std::string &str) { + if (str.length() == 0) return NULL; + return &str[0]; +} +} // namespace dmlc + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define constexpr const +#define alignof __alignof +#endif + +#endif // DMLC_BASE_H_ diff --git a/include/dmlc/io.h b/include/dmlc/io.h index 189c2b8b7..31aed37e1 100644 --- a/include/dmlc/io.h +++ b/include/dmlc/io.h @@ -57,29 +57,31 @@ class Stream { // NOLINT(*) bool allow_null = false); // helper functions to write/read different data structures /*! - * \brief writes a vector - * \param vec vector to be written/serialized + * \brief writes a data to stream + * + * dmlc::Stream support Write/Read of most STL + * composites and base types. + * If the data type is not supported, a compile time error will + * be issued. + * + * \param data data to be written + * \tparam T the data type to be written */ template - inline void Write(const std::vector &vec); + inline void Write(const T &data); /*! - * \brief loads a vector - * \param out_vec vector to be loaded/deserialized + * \brief loads a data from stream. + * + * dmlc::Stream support Write/Read of most STL + * composites and base types. + * If the data type is not supported, a compile time error will + * be issued. + * + * \param out_data place holder of data to be deserialized * \return whether the load was successful */ template - inline bool Read(std::vector *out_vec); - /*! - * \brief writes a string - * \param str the string to be written/serialized - */ - inline void Write(const std::string &str); - /*! - * \brief loads a string - * \param out_str string to be loaded/deserialized - * \return whether the load/deserialization was successful - */ - inline bool Read(std::string *out_str); + inline bool Read(T *out_data); }; /*! \brief interface of i/o stream that support seek */ @@ -108,7 +110,7 @@ class SeekStream: public Stream { /*! \brief interface for serializable objects */ class Serializable { public: - /*! \brief destructor */ + /*! \brief virtual destructor */ virtual ~Serializable() {} /*! * \brief load the model from a stream @@ -185,6 +187,14 @@ class InputSplit { virtual bool NextChunk(Blob *out_chunk) = 0; /*! \brief destructor*/ virtual ~InputSplit(void) {} + /*! + * \brief reset the Input split to a certain part id, + * The InputSplit will be pointed to the head of the new specified segment. + * This feature may not be supported by every implementation of InputSplit. + * \param part_index The part id of the new input. + * \param num_parts The total number of parts. + */ + virtual void ResetPartition(unsigned part_index, unsigned num_parts) = 0; /*! * \brief factory function: * create input split given a uri @@ -245,22 +255,30 @@ class ostream : public std::basic_ostream { this->rdbuf(&buf_); } + /*! \return how many bytes we written so far */ + inline size_t bytes_written(void) const { + return buf_.bytes_out(); + } + private: // internal streambuf class OutBuf : public std::streambuf { public: explicit OutBuf(size_t buffer_size) - : stream_(NULL), buffer_(buffer_size) { + : stream_(NULL), buffer_(buffer_size), bytes_out_(0) { if (buffer_size == 0) buffer_.resize(2); } // set stream to the buffer inline void set_stream(Stream *stream); + inline size_t bytes_out() const { return bytes_out_; } private: /*! \brief internal stream by StreamBuf */ Stream *stream_; /*! \brief internal buffer */ std::vector buffer_; + /*! \brief number of bytes written so far */ + size_t bytes_out_; // override sync inline int_type sync(void); // override overflow @@ -337,45 +355,19 @@ class istream : public std::basic_istream { /*! \brief input buffer */ InBuf buf_; }; +} // namespace dmlc +#include "./serializer.h" + +namespace dmlc { // implementations of inline functions template -inline void Stream::Write(const std::vector &vec) { - uint64_t sz = static_cast(vec.size()); - this->Write(&sz, sizeof(sz)); - if (sz != 0) { - this->Write(&vec[0], sizeof(T) * vec.size()); - } +inline void Stream::Write(const T &data) { + serializer::Handler::Write(this, data); } template -inline bool Stream::Read(std::vector *out_vec) { - uint64_t sz; - if (this->Read(&sz, sizeof(sz)) == 0) return false; - size_t size = static_cast(sz); - out_vec->resize(size); - if (sz != 0) { - if (this->Read(&(*out_vec)[0], sizeof(T) * size) == 0) return false; - } - return true; -} -inline void Stream::Write(const std::string &str) { - uint64_t sz = static_cast(str.length()); - this->Write(&sz, sizeof(sz)); - if (sz != 0) { - this->Write(&str[0], sizeof(char) * str.length()); - } -} -inline bool Stream::Read(std::string *out_str) { - uint64_t sz; - if (this->Read(&sz, sizeof(sz)) == 0) return false; - size_t size = static_cast(sz); - out_str->resize(size); - if (sz != 0) { - if (this->Read(&(*out_str)[0], sizeof(char) * size) == 0) { - return false; - } - } - return true; +inline bool Stream::Read(T *out_data) { + return serializer::Handler::Read(this, out_data); } // implementations for ostream @@ -389,6 +381,7 @@ inline int ostream::OutBuf::sync(void) { std::ptrdiff_t n = pptr() - pbase(); stream_->Write(pbase(), n); this->pbump(-static_cast(n)); + bytes_out_ += n; return 0; } inline int ostream::OutBuf::overflow(int c) { @@ -397,8 +390,10 @@ inline int ostream::OutBuf::overflow(int c) { this->pbump(-static_cast(n)); if (c == EOF) { stream_->Write(pbase(), n); + bytes_out_ += n; } else { stream_->Write(pbase(), n + 1); + bytes_out_ += n + 1; } return c; } diff --git a/include/dmlc/logging.h b/include/dmlc/logging.h new file mode 100644 index 000000000..eb75268ea --- /dev/null +++ b/include/dmlc/logging.h @@ -0,0 +1,268 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file logging.h + * \brief defines logging macros of dmlc + * allows use of GLOG, fall back to internal + * implementation when disabled + */ +#ifndef DMLC_LOGGING_H_ +#define DMLC_LOGGING_H_ +#include +#include +#include +#include +#include +#include "./base.h" + +namespace dmlc { +/*! + * \brief exception class that will be thrown by + * default logger if DMLC_LOG_FATAL_THROW == 1 + */ +struct Error : public std::runtime_error { + /*! + * \brief constructor + * \param s the error message + */ + explicit Error(const std::string &s) : std::runtime_error(s) {} +}; +} // namespace dmlc + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define noexcept(a) +#endif + +#if DMLC_USE_CXX11 +#define DMLC_THROW_EXCEPTION noexcept(false) +#else +#define DMLC_THROW_EXCEPTION +#endif + +#if DMLC_USE_GLOG +#include + +namespace dmlc { +inline void InitLogging(const char* argv0) { + google::InitGoogleLogging(argv0); +} +} // namespace dmlc + +#else +// use a light version of glog +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable : 4722) +#endif + +namespace dmlc { +inline void InitLogging(const char* argv0) { + // DO NOTHING +} + +// Always-on checking +#define CHECK(x) \ + if (!(x)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ + "failed: " #x << ' ' +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) +#define CHECK_NOTNULL(x) \ + ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) +// Debug-only checking. +#ifdef NDEBUG +#define DCHECK(x) \ + while (false) CHECK(x) +#define DCHECK_LT(x, y) \ + while (false) CHECK((x) < (y)) +#define DCHECK_GT(x, y) \ + while (false) CHECK((x) > (y)) +#define DCHECK_LE(x, y) \ + while (false) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) \ + while (false) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) \ + while (false) CHECK((x) == (y)) +#define DCHECK_NE(x, y) \ + while (false) CHECK((x) != (y)) +#else +#define DCHECK(x) CHECK(x) +#define DCHECK_LT(x, y) CHECK((x) < (y)) +#define DCHECK_GT(x, y) CHECK((x) > (y)) +#define DCHECK_LE(x, y) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) CHECK((x) == (y)) +#define DCHECK_NE(x, y) CHECK((x) != (y)) +#endif // NDEBUG + +#if DMLC_LOG_CUSTOMIZE +#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__) +#else +#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) +#endif +#define LOG_ERROR LOG_INFO +#define LOG_WARNING LOG_INFO +#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) +#define LOG_QFATAL LOG_FATAL + +// Poor man version of VLOG +#define VLOG(x) LOG_INFO.stream() + +#define LOG(severity) LOG_##severity.stream() +#define LG LOG_INFO.stream() +#define LOG_IF(severity, condition) \ + !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) + +#ifdef NDEBUG +#define LOG_DFATAL LOG_ERROR +#define DFATAL ERROR +#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#define DLOG_IF(severity, condition) \ + (true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#else +#define LOG_DFATAL LOG_FATAL +#define DFATAL FATAL +#define DLOG(severity) LOG(severity) +#define DLOG_IF(severity, condition) LOG_IF(severity, condition) +#endif + +// Poor man version of LOG_EVERY_N +#define LOG_EVERY_N(severity, n) LOG(severity) + +class DateLogger { + public: + DateLogger() { +#if defined(_MSC_VER) + _tzset(); +#endif + } + const char* HumanDate() { +#if defined(_MSC_VER) + _strtime_s(buffer_, sizeof(buffer_)); +#else + time_t time_value = time(NULL); + struct tm *pnow; +#if !defined(_WIN32) + struct tm now; + pnow = localtime_r(&time_value, &now); +#else + pnow = localtime(&time_value); // NOLINT(*) +#endif + snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", + pnow->tm_hour, pnow->tm_min, pnow->tm_sec); +#endif + return buffer_; + } + + private: + char buffer_[9]; +}; + +class LogMessage { + public: + LogMessage(const char* file, int line) + : +#ifdef __ANDROID__ + log_stream_(std::cout) +#else + log_stream_(std::cerr) +#endif + { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + ~LogMessage() { log_stream_ << '\n'; } + std::ostream& stream() { return log_stream_; } + + protected: + std::ostream& log_stream_; + + private: + DateLogger pretty_date_; + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; + +// customized logger that can allow user to define where to log the message. +class CustomLogMessage { + public: + CustomLogMessage(const char* file, int line) { + log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":" + << line << ": "; + } + ~CustomLogMessage() { + Log(log_stream_.str()); + } + std::ostream& stream() { return log_stream_; } + /*! + * \brief customized logging of the message. + * This function won't be implemented by libdmlc + * \param msg The message to be logged. + */ + static void Log(const std::string& msg); + + private: + std::ostringstream log_stream_; +}; + +#if DMLC_LOG_FATAL_THROW == 0 +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} + ~LogMessageFatal() { + log_stream_ << "\n"; + abort(); + } + + private: + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#else +class LogMessageFatal { + public: + LogMessageFatal(const char* file, int line) { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + std::ostringstream &stream() { return log_stream_; } + ~LogMessageFatal() DMLC_THROW_EXCEPTION { + // throwing out of destructor is evil + // hopefully we can do it here + // also log the message before throw +#if DMLC_LOG_BEFORE_THROW + LOG(ERROR) << log_stream_.str(); +#endif + throw Error(log_stream_.str()); + } + + private: + std::ostringstream log_stream_; + DateLogger pretty_date_; + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#endif + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class LogMessageVoidify { + public: + LogMessageVoidify() {} + // This has to be an operator with a precedence lower than << but + // higher than "?:". See its usage. + void operator&(std::ostream&) {} +}; + +} // namespace dmlc + +#endif +#endif // DMLC_LOGGING_H_ diff --git a/include/dmlc/serializer.h b/include/dmlc/serializer.h new file mode 100644 index 000000000..be80dd008 --- /dev/null +++ b/include/dmlc/serializer.h @@ -0,0 +1,383 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file serializer.h + * \brief serializer template class that helps serialization. + * This file do not need to be directly used by most user. + */ +#ifndef DMLC_SERIALIZER_H_ +#define DMLC_SERIALIZER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "./base.h" +#include "./io.h" +#include "./logging.h" +#include "./type_traits.h" + +#if DMLC_USE_CXX11 +#include +#include +#endif + +namespace dmlc { +/*! \brief internal namespace for serializers */ +namespace serializer { +/*! + * \brief generic serialization handler + * \tparam T the type to be serialized + */ +template +struct Handler; + +//! \cond Doxygen_Suppress +/*! + * \brief Serializer that redirect calls by condition + * \tparam cond the condition + * \tparam Then the serializer used for then condition + * \tparam Else the serializer used for else condition + * \tparam Return the type of data the serializer handles + */ +template +struct IfThenElse; + +template +struct IfThenElse { + inline static void Write(Stream *strm, const T &data) { + Then::Write(strm, data); + } + inline static bool Read(Stream *strm, T *data) { + return Then::Read(strm, data); + } +}; +template +struct IfThenElse { + inline static void Write(Stream *strm, const T &data) { + Else::Write(strm, data); + } + inline static bool Read(Stream *strm, T *data) { + return Else::Read(strm, data); + } +}; + +/*! \brief Serializer for POD(plain-old-data) data */ +template +struct PODHandler { + inline static void Write(Stream *strm, const T &data) { + strm->Write(&data, sizeof(T)); + } + inline static bool Read(Stream *strm, T *dptr) { + return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) + } +}; + +// serializer for class that have save/load function +template +struct SaveLoadClassHandler { + inline static void Write(Stream *strm, const T &data) { + data.Save(strm); + } + inline static bool Read(Stream *strm, T *data) { + return data->Load(strm); + } +}; + +/*! + * \brief dummy class for undefined serialization. + * This is used to generate error message when user tries to + * serialize something that is not supported. + * \tparam T the type to be serialized + */ +template +struct UndefinedSerializerFor { +}; + +/*! + * \brief Serializer handler for std::vector where T is POD type. + * \tparam T element type + */ +template +struct PODVectorHandler { + inline static void Write(Stream *strm, const std::vector &vec) { + uint64_t sz = static_cast(vec.size()); + strm->Write(&sz, sizeof(sz)); + if (sz != 0) { + strm->Write(&vec[0], sizeof(T) * vec.size()); + } + } + inline static bool Read(Stream *strm, std::vector *out_vec) { + uint64_t sz; + if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; + size_t size = static_cast(sz); + out_vec->resize(size); + if (sz != 0) { + size_t nbytes = sizeof(T) * size; + return strm->Read(&(*out_vec)[0], nbytes) == nbytes; + } + return true; + } +}; + +/*! + * \brief Serializer handler for std::vector where T can be composed type + * \tparam T element type + */ +template +struct ComposeVectorHandler { + inline static void Write(Stream *strm, const std::vector &vec) { + uint64_t sz = static_cast(vec.size()); + strm->Write(&sz, sizeof(sz)); + for (size_t i = 0; i < vec.size(); ++i) { + Handler::Write(strm, vec[i]); + } + } + inline static bool Read(Stream *strm, std::vector *out_vec) { + uint64_t sz; + if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; + size_t size = static_cast(sz); + out_vec->resize(size); + for (size_t i = 0; i < size; ++i) { + if (!Handler::Read(strm, &(*out_vec)[i])) return false; + } + return true; + } +}; + +/*! + * \brief Serializer handler for std::basic_string where T is POD type. + * \tparam T element type + */ +template +struct PODStringHandler { + inline static void Write(Stream *strm, const std::basic_string &vec) { + uint64_t sz = static_cast(vec.length()); + strm->Write(&sz, sizeof(sz)); + if (sz != 0) { + strm->Write(&vec[0], sizeof(T) * vec.length()); + } + } + inline static bool Read(Stream *strm, std::basic_string *out_vec) { + uint64_t sz; + if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; + size_t size = static_cast(sz); + out_vec->resize(size); + if (sz != 0) { + size_t nbytes = sizeof(T) * size; + return strm->Read(&(*out_vec)[0], nbytes) == nbytes; + } + return true; + } +}; + +/*! \brief Serializer for std::pair */ +template +struct PairHandler { + inline static void Write(Stream *strm, const std::pair &data) { + Handler::Write(strm, data.first); + Handler::Write(strm, data.second); + } + inline static bool Read(Stream *strm, std::pair *data) { + return Handler::Read(strm, &(data->first)) && + Handler::Read(strm, &(data->second)); + } +}; + +// set type handler that can handle most collection type case +template +struct CollectionHandler { + inline static void Write(Stream *strm, const ContainerType &data) { + typedef typename ContainerType::value_type ElemType; + // dump data to vector + std::vector vdata(data.begin(), data.end()); + // serialize the vector + Handler >::Write(strm, vdata); + } + inline static bool Read(Stream *strm, ContainerType *data) { + typedef typename ContainerType::value_type ElemType; + std::vector vdata; + if (!Handler >::Read(strm, &vdata)) return false; + data->clear(); + data->insert(vdata.begin(), vdata.end()); + return true; + } +}; + + +// handler that can handle most list type case +// this type insert function takes additional iterator +template +struct ListHandler { + inline static void Write(Stream *strm, const ListType &data) { + typedef typename ListType::value_type ElemType; + // dump data to vector + std::vector vdata(data.begin(), data.end()); + // serialize the vector + Handler >::Write(strm, vdata); + } + inline static bool Read(Stream *strm, ListType *data) { + typedef typename ListType::value_type ElemType; + std::vector vdata; + if (!Handler >::Read(strm, &vdata)) return false; + data->clear(); + data->insert(data->begin(), vdata.begin(), vdata.end()); + return true; + } +}; + +//! \endcond + +/*! + * \brief generic serialization handler for type T + * + * User can define specialization of this class to support + * composite serialization of their own class. + * + * \tparam T the type to be serialized + */ +template +struct Handler { + /*! + * \brief write data to stream + * \param strm the stream we write the data. + * \param data the data obeject to be serialized + */ + inline static void Write(Stream *strm, const T &data) { + IfThenElse::value, + PODHandler, + IfThenElse::value, + SaveLoadClassHandler, + UndefinedSerializerFor, T>, + T> + ::Write(strm, data); + } + /*! + * \brief read data to stream + * \param strm the stream to read the data. + * \param data the pointer to the data obeject to read + * \return whether the read is successful + */ + inline static bool Read(Stream *strm, T *data) { + return IfThenElse::value, + PODHandler, + IfThenElse::value, + SaveLoadClassHandler, + UndefinedSerializerFor, T>, + T> + ::Read(strm, data); + } +}; + +//! \cond Doxygen_Suppress +template +struct Handler > { + inline static void Write(Stream *strm, const std::vector &data) { + IfThenElse::value, + PODVectorHandler, + ComposeVectorHandler, std::vector > + ::Write(strm, data); + } + inline static bool Read(Stream *strm, std::vector *data) { + return IfThenElse::value, + PODVectorHandler, + ComposeVectorHandler, + std::vector > + ::Read(strm, data); + } +}; + +template +struct Handler > { + inline static void Write(Stream *strm, const std::basic_string &data) { + IfThenElse::value, + PODStringHandler, + UndefinedSerializerFor, + std::basic_string > + ::Write(strm, data); + } + inline static bool Read(Stream *strm, std::basic_string *data) { + return IfThenElse::value, + PODStringHandler, + UndefinedSerializerFor, + std::basic_string > + ::Read(strm, data); + } +}; + +template +struct Handler > { + inline static void Write(Stream *strm, const std::pair &data) { + IfThenElse::value && dmlc::is_pod::value, + PODHandler >, + PairHandler, + std::pair > + ::Write(strm, data); + } + inline static bool Read(Stream *strm, std::pair *data) { + return IfThenElse::value && dmlc::is_pod::value, + PODHandler >, + PairHandler, + std::pair > + ::Read(strm, data); + } +}; + +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public ListHandler > { +}; + +template +struct Handler > + : public ListHandler > { +}; + +#if DMLC_USE_CXX11 +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public CollectionHandler > { +}; + +template +struct Handler > + : public CollectionHandler > { +}; +#endif +//! \endcond +} // namespace serializer +} // namespace dmlc +#endif // DMLC_SERIALIZER_H_ diff --git a/include/dmlc/type_traits.h b/include/dmlc/type_traits.h new file mode 100644 index 000000000..73abfba80 --- /dev/null +++ b/include/dmlc/type_traits.h @@ -0,0 +1,171 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file type_traits.h + * \brief type traits information header + */ +#ifndef DMLC_TYPE_TRAITS_H_ +#define DMLC_TYPE_TRAITS_H_ + +#include "./base.h" +#if DMLC_USE_CXX11 +#include +#endif +#include + +namespace dmlc { +/*! + * \brief whether a type is pod type + * \tparam T the type to query + */ +template +struct is_pod { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_pod::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + + +/*! + * \brief whether a type is integer type + * \tparam T the type to query + */ +template +struct is_integral { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_integral::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + +/*! + * \brief whether a type is floating point type + * \tparam T the type to query + */ +template +struct is_floating_point { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_floating_point::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + +/*! + * \brief whether a type is arithemetic type + * \tparam T the type to query + */ +template +struct is_arithmetic { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_arithmetic::value; +#else + /*! \brief the value of the traits */ + static const bool value = (dmlc::is_integral::value || + dmlc::is_floating_point::value); +#endif +}; + +/*! + * \brief the string representation of type name + * \tparam T the type to query + * \return a const string of typename. + */ +template +inline const char* type_name() { + return ""; +} + +/*! + * \brief whether a type have save/load function + * \tparam T the type to query + */ +template +struct has_saveload { + /*! \brief the value of the traits */ + static const bool value = false; +}; + +/*! + * \brief template to select type based on condition + * For example, IfThenElseType::Type will give int + * \tparam cond the condition + * \tparam Then the typename to be returned if cond is true + * \tparam The typename to be returned if cond is false +*/ +template +struct IfThenElseType; + +/*! \brief macro to quickly declare traits information */ +#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \ + template<> \ + struct Trait { \ + static const bool value = Value; \ + } + +/*! \brief macro to quickly declare traits information */ +#define DMLC_DECLARE_TYPE_NAME(Type, Name) \ + template<> \ + inline const char* type_name() { \ + return Name; \ + } + +//! \cond Doxygen_Suppress +// declare special traits when C++11 is not available +#if DMLC_USE_CXX11 == 0 +DMLC_DECLARE_TRAITS(is_pod, char, true); +DMLC_DECLARE_TRAITS(is_pod, int8_t, true); +DMLC_DECLARE_TRAITS(is_pod, int16_t, true); +DMLC_DECLARE_TRAITS(is_pod, int32_t, true); +DMLC_DECLARE_TRAITS(is_pod, int64_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint8_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint16_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint32_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint64_t, true); +DMLC_DECLARE_TRAITS(is_pod, float, true); +DMLC_DECLARE_TRAITS(is_pod, double, true); + +DMLC_DECLARE_TRAITS(is_integral, char, true); +DMLC_DECLARE_TRAITS(is_integral, int8_t, true); +DMLC_DECLARE_TRAITS(is_integral, int16_t, true); +DMLC_DECLARE_TRAITS(is_integral, int32_t, true); +DMLC_DECLARE_TRAITS(is_integral, int64_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint8_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint16_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint32_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint64_t, true); + +DMLC_DECLARE_TRAITS(is_floating_point, float, true); +DMLC_DECLARE_TRAITS(is_floating_point, double, true); + +#endif + +DMLC_DECLARE_TYPE_NAME(float, "float"); +DMLC_DECLARE_TYPE_NAME(double, "double"); +DMLC_DECLARE_TYPE_NAME(int, "int"); +DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)"); +DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)"); +DMLC_DECLARE_TYPE_NAME(std::string, "string"); +DMLC_DECLARE_TYPE_NAME(bool, "boolean"); + +template +struct IfThenElseType { + typedef Then Type; +}; + +template +struct IfThenElseType { + typedef Else Type; +}; +//! \endcond +} // namespace dmlc +#endif // DMLC_TYPE_TRAITS_H_