[LZ] Improve lz4 format
This commit is contained in:
parent
31d8e93ef3
commit
c4d389c5df
@ -6,6 +6,7 @@
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <lz4.h>
|
||||
#include <lz4hc.h>
|
||||
#include "../../src/data/sparse_batch_page.h"
|
||||
@ -21,6 +22,9 @@ class CompressArray {
|
||||
public:
|
||||
// the data content.
|
||||
std::vector<DType> data;
|
||||
CompressArray() {
|
||||
use_deep_compress_ = dmlc::GetEnv("XGBOOST_LZ4_COMPRESS_DEEP", true);
|
||||
}
|
||||
// Decompression helper
|
||||
// number of chunks
|
||||
inline int num_chunk() const {
|
||||
@ -59,6 +63,8 @@ class CompressArray {
|
||||
std::vector<std::string> out_buffer_;
|
||||
// input buffer of data.
|
||||
std::string in_buffer_;
|
||||
// use deep compression.
|
||||
bool use_deep_compress_;
|
||||
};
|
||||
|
||||
template<typename DType>
|
||||
@ -124,9 +130,16 @@ inline void CompressArray<DType>::Compress(int chunk_id) {
|
||||
int bound = LZ4_compressBound(raw_chunk_size);
|
||||
CHECK_NE(bound, 0);
|
||||
buf.resize(bound);
|
||||
int encoded_size = LZ4_compress_HC(
|
||||
reinterpret_cast<char*>(dmlc::BeginPtr(data) + raw_chunks_[chunk_id]),
|
||||
dmlc::BeginPtr(buf), raw_chunk_size, buf.length(), 9);
|
||||
int encoded_size;
|
||||
if (use_deep_compress_) {
|
||||
encoded_size = LZ4_compress_HC(
|
||||
reinterpret_cast<char*>(dmlc::BeginPtr(data) + raw_chunks_[chunk_id]),
|
||||
dmlc::BeginPtr(buf), raw_chunk_size, buf.length(), 0);
|
||||
} else{
|
||||
encoded_size = LZ4_compress_default(
|
||||
reinterpret_cast<char*>(dmlc::BeginPtr(data) + raw_chunks_[chunk_id]),
|
||||
dmlc::BeginPtr(buf), raw_chunk_size, buf.length());
|
||||
}
|
||||
CHECK_NE(encoded_size, 0);
|
||||
CHECK_LE(static_cast<size_t>(encoded_size), buf.length());
|
||||
buf.resize(encoded_size);
|
||||
@ -148,16 +161,21 @@ inline void CompressArray<DType>::Write(dmlc::Stream* fo) {
|
||||
|
||||
class SparsePageLZ4Format : public SparsePage::Format {
|
||||
public:
|
||||
SparsePageLZ4Format()
|
||||
: raw_bytes_(0), encoded_bytes_(0) {
|
||||
SparsePageLZ4Format() {
|
||||
raw_bytes_ = raw_bytes_value_ = raw_bytes_index_ = 0;
|
||||
encoded_bytes_value_ = encoded_bytes_index_ = 0;
|
||||
nthread_ = 4;
|
||||
raw_bytes_ = encoded_bytes_ = 0;
|
||||
nthread_write_ = dmlc::GetEnv("XGBOOST_LZ4_COMPRESS_NTHREAD", 12);
|
||||
}
|
||||
~SparsePageLZ4Format() {
|
||||
size_t encoded_bytes = raw_bytes_ + encoded_bytes_value_ + encoded_bytes_index_;
|
||||
raw_bytes_ += raw_bytes_value_ + raw_bytes_index_;
|
||||
if (raw_bytes_ != 0) {
|
||||
LOG(CONSOLE) << "raw_bytes=" << raw_bytes_
|
||||
<< ", encoded_bytes=" << encoded_bytes_
|
||||
<< ", ratio=" << double(encoded_bytes_) / raw_bytes_;
|
||||
<< ", encoded_bytes=" << encoded_bytes
|
||||
<< ", ratio=" << double(encoded_bytes) / raw_bytes_
|
||||
<< ",ratio-index=" << double(encoded_bytes_index_) /raw_bytes_index_
|
||||
<< ",ratio-value=" << double(encoded_bytes_value_) /raw_bytes_value_;
|
||||
}
|
||||
}
|
||||
|
||||
@ -222,7 +240,7 @@ class SparsePageLZ4Format : public SparsePage::Format {
|
||||
int nindex = index_.num_chunk();
|
||||
int nvalue = value_.num_chunk();
|
||||
int ntotal = nindex + nvalue;
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread_)
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread_write_)
|
||||
for (int i = 0; i < ntotal; ++i) {
|
||||
if (i < nindex) {
|
||||
index_.Compress(i);
|
||||
@ -232,9 +250,12 @@ class SparsePageLZ4Format : public SparsePage::Format {
|
||||
}
|
||||
index_.Write(fo);
|
||||
value_.Write(fo);
|
||||
raw_bytes_ += index_.RawBytes() + value_.RawBytes() + page.offset.size() * sizeof(size_t);
|
||||
encoded_bytes_ += index_.EncodedBytes() +
|
||||
value_.EncodedBytes() + page.offset.size() * sizeof(size_t);
|
||||
// statistics
|
||||
raw_bytes_index_ += index_.RawBytes();
|
||||
raw_bytes_value_ += value_.RawBytes();
|
||||
encoded_bytes_index_ += index_.EncodedBytes();
|
||||
encoded_bytes_value_ += value_.EncodedBytes();
|
||||
raw_bytes_ += page.offset.size() * sizeof(size_t);
|
||||
}
|
||||
|
||||
inline void LoadIndexValue(dmlc::SeekStream* fi) {
|
||||
@ -258,13 +279,15 @@ class SparsePageLZ4Format : public SparsePage::Format {
|
||||
// default chunk size.
|
||||
static const size_t kChunkSize = 64 << 10UL;
|
||||
// maximum chunk size.
|
||||
static const size_t kMaxChunk = 64;
|
||||
static const size_t kMaxChunk = 128;
|
||||
// number of threads
|
||||
int nthread_;
|
||||
// number of writing threads
|
||||
int nthread_write_;
|
||||
// raw bytes
|
||||
size_t raw_bytes_;
|
||||
size_t raw_bytes_, raw_bytes_index_, raw_bytes_value_;
|
||||
// encoded bytes
|
||||
size_t encoded_bytes_;
|
||||
size_t encoded_bytes_index_, encoded_bytes_value_;
|
||||
/*! \brief external memory column offset */
|
||||
std::vector<size_t> disk_offset_;
|
||||
// internal index
|
||||
|
||||
2
rabit
2
rabit
@ -1 +1 @@
|
||||
Subproject commit 05b958c178b16d707ff16b4b05506be124087e13
|
||||
Subproject commit 112d866dc92354304c0891500374fe40cdf13a50
|
||||
@ -76,6 +76,8 @@ struct LearnerTrainParam
|
||||
std::string test_flag;
|
||||
// maximum buffered row value
|
||||
float prob_buffer_row;
|
||||
// maximum row per batch.
|
||||
size_t max_row_perbatch;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
|
||||
DMLC_DECLARE_FIELD(seed).set_default(0)
|
||||
@ -92,6 +94,8 @@ struct LearnerTrainParam
|
||||
.describe("Internal test flag");
|
||||
DMLC_DECLARE_FIELD(prob_buffer_row).set_default(1.0f).set_range(0.0f, 1.0f)
|
||||
.describe("Maximum buffered row portion");
|
||||
DMLC_DECLARE_FIELD(max_row_perbatch).set_default(std::numeric_limits<size_t>::max())
|
||||
.describe("maximum row per batch.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -328,9 +332,9 @@ class LearnerImpl : public Learner {
|
||||
std::vector<bool> enabled(ncol, true);
|
||||
// set max row per batch to limited value
|
||||
// in distributed mode, use safe choice otherwise
|
||||
size_t max_row_perbatch = std::numeric_limits<size_t>::max();
|
||||
size_t max_row_perbatch = tparam.max_row_perbatch;
|
||||
if (tparam.test_flag == "block" || tparam.dsplit == 2) {
|
||||
max_row_perbatch = 32UL << 10UL;
|
||||
max_row_perbatch = std::min(32UL << 10UL, max_row_perbatch);
|
||||
}
|
||||
// initialize column access
|
||||
p_train->InitColAccess(enabled,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user