more movement to beginptr

This commit is contained in:
tqchen 2014-09-02 11:14:57 -07:00
parent 27cabd131e
commit c75275a861
5 changed files with 30 additions and 18 deletions

View File

@ -10,7 +10,6 @@
#include "src/utils/matrix_csr.h" #include "src/utils/matrix_csr.h"
using namespace std; using namespace std;
using namespace xgboost; using namespace xgboost;
using namespace xgboost::utils;
extern "C" { extern "C" {
void XGBoostAssert_R(int exp, const char *fmt, ...); void XGBoostAssert_R(int exp, const char *fmt, ...);

View File

@ -54,8 +54,10 @@ class DMatrixSimple : public DataMatrix {
for (size_t i = 0; i < batch.size; ++i) { for (size_t i = 0; i < batch.size; ++i) {
RowBatch::Inst inst = batch[i]; RowBatch::Inst inst = batch[i];
row_data_.resize(row_data_.size() + inst.length); row_data_.resize(row_data_.size() + inst.length);
memcpy(&row_data_[row_ptr_.back()], inst.data, if (inst.length != 0) {
sizeof(RowBatch::Entry) * inst.length); memcpy(&row_data_[row_ptr_.back()], inst.data,
sizeof(RowBatch::Entry) * inst.length);
}
row_ptr_.push_back(row_ptr_.back() + inst.length); row_ptr_.push_back(row_ptr_.back() + inst.length);
} }
} }
@ -244,8 +246,8 @@ class DMatrixSimple : public DataMatrix {
at_first_ = false; at_first_ = false;
batch_.size = parent_->row_ptr_.size() - 1; batch_.size = parent_->row_ptr_.size() - 1;
batch_.base_rowid = 0; batch_.base_rowid = 0;
batch_.ind_ptr = &parent_->row_ptr_[0]; batch_.ind_ptr = BeginPtr(parent_->row_ptr_);
batch_.data_ptr = &parent_->row_data_[0]; batch_.data_ptr = BeginPtr(parent_->row_data_);
return true; return true;
} }
virtual const RowBatch &Value(void) const { virtual const RowBatch &Value(void) const {

View File

@ -110,9 +110,9 @@ class FMatrixS : public IFMatrix{
const std::vector<RowBatch::Entry> &data) { const std::vector<RowBatch::Entry> &data) {
size_t nrow = ptr.size() - 1; size_t nrow = ptr.size() - 1;
fo.Write(&nrow, sizeof(size_t)); fo.Write(&nrow, sizeof(size_t));
fo.Write(&ptr[0], ptr.size() * sizeof(size_t)); fo.Write(BeginPtr(ptr), ptr.size() * sizeof(size_t));
if (data.size() != 0) { if (data.size() != 0) {
fo.Write(&data[0], data.size() * sizeof(RowBatch::Entry)); fo.Write(BeginPtr(data), data.size() * sizeof(RowBatch::Entry));
} }
} }
/*! /*!
@ -127,11 +127,11 @@ class FMatrixS : public IFMatrix{
size_t nrow; size_t nrow;
utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format"); utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format");
out_ptr->resize(nrow + 1); out_ptr->resize(nrow + 1);
utils::Check(fi.Read(&(*out_ptr)[0], out_ptr->size() * sizeof(size_t)) != 0, utils::Check(fi.Read(BeginPtr(*out_ptr), out_ptr->size() * sizeof(size_t)) != 0,
"invalid input file format"); "invalid input file format");
out_data->resize(out_ptr->back()); out_data->resize(out_ptr->back());
if (out_data->size() != 0) { if (out_data->size() != 0) {
utils::Assert(fi.Read(&(*out_data)[0], out_data->size() * sizeof(RowBatch::Entry)) != 0, utils::Assert(fi.Read(BeginPtr(*out_data), out_data->size() * sizeof(RowBatch::Entry)) != 0,
"invalid input file format"); "invalid input file format");
} }
} }
@ -213,8 +213,8 @@ class FMatrixS : public IFMatrix{
col_data_[i] = SparseBatch::Inst(&data[0] + ptr[ridx], col_data_[i] = SparseBatch::Inst(&data[0] + ptr[ridx],
static_cast<bst_uint>(ptr[ridx+1] - ptr[ridx])); static_cast<bst_uint>(ptr[ridx+1] - ptr[ridx]));
} }
batch_.col_index = &col_index_[0]; batch_.col_index = BeginPtr(col_index_);
batch_.col_data = &col_data_[0]; batch_.col_data = BeginPtr(col_data_);
this->BeforeFirst(); this->BeforeFirst();
} }
// data content // data content

View File

@ -154,6 +154,8 @@ inline FILE *FopenCheck(const char *fname, const char *flag) {
Check(fp != NULL, "can not open file \"%s\"\n", fname); Check(fp != NULL, "can not open file \"%s\"\n", fname);
return fp; return fp;
} }
} // namespace utils
// easy utils that can be directly accessed in xgboost
/*! \brief get the beginning address of a vector */ /*! \brief get the beginning address of a vector */
template<typename T> template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { inline T *BeginPtr(std::vector<T> &vec) {
@ -163,6 +165,14 @@ inline T *BeginPtr(std::vector<T> &vec) {
return &vec[0]; return &vec[0];
} }
} }
} // namespace utils /*! \brief get the beginning address of a vector */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
  // Read-only overload: yields a pointer to the first element of vec.
  // An empty vector has no element 0 to take the address of, so NULL is
  // returned in that case instead of evaluating &vec[0].
  if (!vec.empty()) {
    return &vec[0];
  }
  return NULL;
}
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_UTILS_UTILS_H_ #endif // XGBOOST_UTILS_UTILS_H_

View File

@ -13,6 +13,7 @@ using namespace std;
#include "../src/data.h" #include "../src/data.h"
#include "../src/learner/learner-inl.hpp" #include "../src/learner/learner-inl.hpp"
#include "../src/io/io.h" #include "../src/io/io.h"
#include "../src/utils/utils.h"
#include "../src/io/simple_dmatrix-inl.hpp" #include "../src/io/simple_dmatrix-inl.hpp"
using namespace xgboost; using namespace xgboost;
@ -32,7 +33,7 @@ class Booster: public learner::BoostLearner {
this->CheckInitModel(); this->CheckInitModel();
this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit); this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
*len = static_cast<bst_ulong>(this->preds_.size()); *len = static_cast<bst_ulong>(this->preds_.size());
return &this->preds_[0]; return BeginPtr(this->preds_);
} }
inline void BoostOneIter(const DataMatrix &train, inline void BoostOneIter(const DataMatrix &train,
float *grad, float *hess, bst_ulong len) { float *grad, float *hess, bst_ulong len) {
@ -60,7 +61,7 @@ class Booster: public learner::BoostLearner {
model_dump_cptr[i] = model_dump[i].c_str(); model_dump_cptr[i] = model_dump[i].c_str();
} }
*len = static_cast<bst_ulong>(model_dump.size()); *len = static_cast<bst_ulong>(model_dump.size());
return &model_dump_cptr[0]; return BeginPtr(model_dump_cptr);
} }
// temporal fields // temporal fields
// temporal data to save evaluation dump // temporal data to save evaluation dump
@ -177,13 +178,13 @@ extern "C"{
std::vector<float> &vec = std::vector<float> &vec =
static_cast<DataMatrix*>(handle)->info.GetFloatInfo(field); static_cast<DataMatrix*>(handle)->info.GetFloatInfo(field);
vec.resize(len); vec.resize(len);
memcpy(&vec[0], info, sizeof(float) * len); memcpy(BeginPtr(vec), info, sizeof(float) * len);
} }
void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *info, bst_ulong len) { void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *info, bst_ulong len) {
std::vector<unsigned> &vec = std::vector<unsigned> &vec =
static_cast<DataMatrix*>(handle)->info.GetUIntInfo(field); static_cast<DataMatrix*>(handle)->info.GetUIntInfo(field);
vec.resize(len); vec.resize(len);
memcpy(&vec[0], info, sizeof(unsigned) * len); memcpy(BeginPtr(vec), info, sizeof(unsigned) * len);
} }
void XGDMatrixSetGroup(void *handle, const unsigned *group, bst_ulong len) { void XGDMatrixSetGroup(void *handle, const unsigned *group, bst_ulong len) {
DataMatrix *pmat = static_cast<DataMatrix*>(handle); DataMatrix *pmat = static_cast<DataMatrix*>(handle);
@ -197,13 +198,13 @@ extern "C"{
const std::vector<float> &vec = const std::vector<float> &vec =
static_cast<const DataMatrix*>(handle)->info.GetFloatInfo(field); static_cast<const DataMatrix*>(handle)->info.GetFloatInfo(field);
*len = static_cast<bst_ulong>(vec.size()); *len = static_cast<bst_ulong>(vec.size());
return &vec[0]; return BeginPtr(vec);
} }
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, bst_ulong* len) { const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, bst_ulong* len) {
const std::vector<unsigned> &vec = const std::vector<unsigned> &vec =
static_cast<const DataMatrix*>(handle)->info.GetUIntInfo(field); static_cast<const DataMatrix*>(handle)->info.GetUIntInfo(field);
*len = static_cast<bst_ulong>(vec.size()); *len = static_cast<bst_ulong>(vec.size());
return &vec[0]; return BeginPtr(vec);
} }
bst_ulong XGDMatrixNumRow(const void *handle) { bst_ulong XGDMatrixNumRow(const void *handle) {
return static_cast<bst_ulong>(static_cast<const DataMatrix*>(handle)->info.num_row()); return static_cast<bst_ulong>(static_cast<const DataMatrix*>(handle)->info.num_row());