xgboost/src/data/simple_csr_source.cc

60 lines
1.3 KiB
C++

/*!
* Copyright 2015-2019 by Contributors
* \file simple_csr_source.cc
*/
#include <dmlc/base.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include "simple_csr_source.h"
namespace xgboost {
namespace data {
void SimpleCSRSource::Clear() {
page_.Clear();
this->info.Clear();
}
/*!
 * \brief Replace this source's contents with a copy of another DMatrix.
 *
 * The current contents are discarded first. The meta information is then
 * copied from `src`, and every SparsePage batch produced by `src` is appended
 * to the single local page in iteration order.
 *
 * \param src matrix to copy from; only read, never modified here.
 */
void SimpleCSRSource::CopyFrom(DMatrix* src) {
  Clear();
  info = src->Info();
  for (const auto& page : src->GetBatches<SparsePage>()) {
    page_.Push(page);
  }
}
/*!
 * \brief Deserialize the source from a stream produced by SaveBinary().
 *
 * Expected layout, read in order:
 *   1. an `int` magic number, which must equal kMagic;
 *   2. the meta information (`info`);
 *   3. the page's row-offset vector;
 *   4. the page's entry (data) vector.
 *
 * \param fi input stream positioned at the start of the serialized source.
 * \note Every section is validated with CHECK. A wrong magic number or a
 *       section that cannot be read (e.g. a truncated file) is reported as
 *       an error instead of leaving a partially loaded page behind.
 */
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
  int tmagic{0};
  CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format";
  CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
  info.LoadBinary(fi);
  // The vector overloads of Read() signal failure only through their bool
  // return value. Ignoring it would silently accept truncated input.
  CHECK(fi->Read(&page_.offset.HostVector()))
      << "invalid input file format: failed to read row offsets";
  CHECK(fi->Read(&page_.data.HostVector()))
      << "invalid input file format: failed to read entries";
}
/*!
 * \brief Serialize the source to a stream; the inverse of LoadBinary().
 *
 * Writes, in order: the `int` magic number kMagic, the meta information
 * (`info`), the page's row-offset vector and the page's entry vector.
 *
 * \param fo output stream to append the serialized source to.
 */
void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
  const int magic = kMagic;
  fo->Write(&magic, sizeof(magic));
  info.SaveBinary(fo);
  const auto& offsets = page_.offset.HostVector();
  fo->Write(offsets);
  const auto& entries = page_.data.HostVector();
  fo->Write(entries);
}
void SimpleCSRSource::BeforeFirst() {
at_first_ = true;
}
/*!
 * \brief Advance the iterator over the (single) page.
 *
 * \return true exactly once after BeforeFirst(), and false on every
 *         subsequent call until the iterator is rewound again.
 *         NOTE(review): the result before any BeforeFirst() call depends on
 *         how `at_first_` is initialised in the header; confirm there.
 */
bool SimpleCSRSource::Next() {
  const bool has_batch = at_first_;
  at_first_ = false;
  return has_batch;
}
/*!
 * \brief Access the page held by this source.
 * \return a const reference to the internal page (no copy is made).
 */
const SparsePage& SimpleCSRSource::Value() const {
  return this->page_;
}
} // namespace data
} // namespace xgboost