60 lines
1.3 KiB
C++
60 lines
1.3 KiB
C++
/*!
|
|
* Copyright 2015-2019 by Contributors
|
|
* \file simple_csr_source.cc
|
|
*/
|
|
#include <dmlc/base.h>
|
|
#include <xgboost/logging.h>
|
|
#include <xgboost/json.h>
|
|
|
|
#include "simple_csr_source.h"
|
|
|
|
namespace xgboost {
|
|
namespace data {
|
|
|
|
void SimpleCSRSource::Clear() {
|
|
page_.Clear();
|
|
this->info.Clear();
|
|
}
|
|
|
|
/*!
 * \brief Replace the contents of this source with a copy of another matrix.
 *
 * Existing state is cleared first. The source's meta info is then copied,
 * and every SparsePage batch that src yields is appended to page_.
 * \param src matrix to copy rows and meta info from.
 */
void SimpleCSRSource::CopyFrom(DMatrix* src) {
  Clear();
  info = src->Info();
  // Concatenate all batches into the single page held by this source.
  for (auto const& incoming : src->GetBatches<SparsePage>()) {
    page_.Push(incoming);
  }
}
|
|
|
|
/*!
 * \brief Deserialize this source from a binary stream.
 *
 * Reads, in order: an int magic number (must equal kMagic), the meta info
 * block, then page_.offset followed by page_.data. Aborts via CHECK if the
 * stream is truncated or the magic number does not match.
 * \param fi input stream positioned at the start of the serialized source.
 */
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
  int tmagic{0};
  CHECK_EQ(fi->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic))
      << "invalid input file format";
  CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
  info.LoadBinary(fi);
  // The container overload of Read() reports failure through its bool return
  // value; ignoring it would silently leave a partially-loaded page.
  CHECK(fi->Read(&page_.offset.HostVector())) << "invalid input file format";
  CHECK(fi->Read(&page_.data.HostVector())) << "invalid input file format";
}
|
|
|
|
/*!
 * \brief Serialize this source to a binary stream.
 *
 * Writes, in order: the int magic number kMagic, the meta info block,
 * page_.offset, then page_.data.
 * \param fo output stream to append the serialized source to.
 */
void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
  const int magic = kMagic;
  fo->Write(&magic, sizeof(magic));
  info.SaveBinary(fo);
  const auto& offsets = page_.offset.HostVector();
  fo->Write(offsets);
  const auto& entries = page_.data.HostVector();
  fo->Write(entries);
}
|
|
|
|
void SimpleCSRSource::BeforeFirst() {
|
|
at_first_ = true;
|
|
}
|
|
|
|
/*!
 * \brief Advance the single-page iterator.
 * \return true on the first call after BeforeFirst(); false on every later
 *         call until BeforeFirst() is invoked again.
 */
bool SimpleCSRSource::Next() {
  const bool has_page = at_first_;
  // Clearing an already-false flag is a no-op, so this is safe on every call.
  at_first_ = false;
  return has_page;
}
|
|
|
|
/*!
 * \brief Access the page held by this source.
 * \return a const reference to the internal page_.
 */
const SparsePage& SimpleCSRSource::Value() const {
  return this->page_;
}
|
|
|
|
} // namespace data
|
|
} // namespace xgboost
|