[CORE] Refactor cache mechanism (#1540)
This commit is contained in:
@@ -22,7 +22,7 @@ namespace xgboost {
|
||||
// booster wrapper for backward compatible reason.
|
||||
class Booster {
|
||||
public:
|
||||
explicit Booster(const std::vector<DMatrix*>& cache_mats)
|
||||
explicit Booster(const std::vector<std::shared_ptr<DMatrix> >& cache_mats)
|
||||
: configured_(false),
|
||||
initialized_(false),
|
||||
learner_(Learner::Create(cache_mats)) {}
|
||||
@@ -207,8 +207,7 @@ int XGDMatrixCreateFromFile(const char *fname,
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||
<< "will split data among workers";
|
||||
}
|
||||
*out = DMatrix::Load(
|
||||
fname, false, true);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, false, true));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -224,7 +223,7 @@ int XGDMatrixCreateFromDataIter(
|
||||
scache = cache_info;
|
||||
}
|
||||
NativeDataIter parser(data_handle, callback);
|
||||
*out = DMatrix::Create(&parser, scache);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&parser, scache));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -250,16 +249,16 @@ XGB_DLL int XGDMatrixCreateFromCSR(const xgboost::bst_ulong* indptr,
|
||||
}
|
||||
mat.info.num_row = nindptr - 1;
|
||||
mat.info.num_nonzero = static_cast<uint64_t>(nelem);
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSC(const xgboost::bst_ulong* col_ptr,
|
||||
const unsigned* indices,
|
||||
const float* data,
|
||||
xgboost::bst_ulong nindptr,
|
||||
xgboost::bst_ulong nelem,
|
||||
DMatrixHandle* out) {
|
||||
const unsigned* indices,
|
||||
const float* data,
|
||||
xgboost::bst_ulong nindptr,
|
||||
xgboost::bst_ulong nelem,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
@@ -292,15 +291,15 @@ XGB_DLL int XGDMatrixCreateFromCSC(const xgboost::bst_ulong* col_ptr,
|
||||
mat.info.num_row = mat.row_ptr_.size() - 1;
|
||||
mat.info.num_col = static_cast<uint64_t>(ncol);
|
||||
mat.info.num_nonzero = nelem;
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromMat(const float* data,
|
||||
xgboost::bst_ulong nrow,
|
||||
xgboost::bst_ulong ncol,
|
||||
float missing,
|
||||
DMatrixHandle* out) {
|
||||
xgboost::bst_ulong nrow,
|
||||
xgboost::bst_ulong ncol,
|
||||
float missing,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
@@ -324,19 +323,19 @@ XGB_DLL int XGDMatrixCreateFromMat(const float* data,
|
||||
mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem);
|
||||
}
|
||||
mat.info.num_nonzero = mat.row_data_.size();
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out) {
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
data::SimpleCSRSource src;
|
||||
src.CopyFrom(static_cast<DMatrix*>(handle));
|
||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
||||
data::SimpleCSRSource& ret = *source;
|
||||
|
||||
CHECK_EQ(src.info.group_ptr.size(), 0)
|
||||
@@ -371,21 +370,21 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
ret.info.root_index.push_back(src.info.root_index[ridx]);
|
||||
}
|
||||
}
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
|
||||
API_BEGIN();
|
||||
delete static_cast<DMatrix*>(handle);
|
||||
delete static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
|
||||
const char* fname,
|
||||
int silent) {
|
||||
const char* fname,
|
||||
int silent) {
|
||||
API_BEGIN();
|
||||
static_cast<DMatrix*>(handle)->SaveToLocalFile(fname);
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -394,7 +393,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const float* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
static_cast<DMatrix*>(handle)->info().SetInfo(field, info, kFloat32, len);
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->info().SetInfo(field, info, kFloat32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -403,16 +403,17 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const unsigned* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
static_cast<DMatrix*>(handle)->info().SetInfo(field, info, kUInt32, len);
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->info().SetInfo(field, info, kUInt32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
DMatrix *pmat = static_cast<DMatrix*>(handle);
|
||||
MetaInfo& info = pmat->info();
|
||||
std::shared_ptr<DMatrix> *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||
MetaInfo& info = pmat->get()->info();
|
||||
info.group_ptr.resize(len + 1);
|
||||
info.group_ptr[0] = 0;
|
||||
for (uint64_t i = 0; i < len; ++i) {
|
||||
@@ -422,11 +423,11 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
||||
const char* field,
|
||||
xgboost::bst_ulong* out_len,
|
||||
const float** out_dptr) {
|
||||
const char* field,
|
||||
xgboost::bst_ulong* out_len,
|
||||
const float** out_dptr) {
|
||||
API_BEGIN();
|
||||
const MetaInfo& info = static_cast<const DMatrix*>(handle)->info();
|
||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info();
|
||||
const std::vector<float>* vec = nullptr;
|
||||
if (!std::strcmp(field, "label")) {
|
||||
vec = &info.labels;
|
||||
@@ -443,11 +444,11 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
const char *field,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const unsigned **out_dptr) {
|
||||
const char *field,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const unsigned **out_dptr) {
|
||||
API_BEGIN();
|
||||
const MetaInfo& info = static_cast<const DMatrix*>(handle)->info();
|
||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info();
|
||||
const std::vector<unsigned>* vec = nullptr;
|
||||
if (!std::strcmp(field, "root_index")) {
|
||||
vec = &info.root_index;
|
||||
@@ -460,16 +461,18 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
*out = static_cast<xgboost::bst_ulong>(static_cast<const DMatrix*>(handle)->info().num_row);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info().num_row);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
*out = static_cast<size_t>(static_cast<const DMatrix*>(handle)->info().num_col);
|
||||
*out = static_cast<size_t>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info().num_col);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -478,9 +481,9 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
xgboost::bst_ulong len,
|
||||
BoosterHandle *out) {
|
||||
API_BEGIN();
|
||||
std::vector<DMatrix*> mats;
|
||||
std::vector<std::shared_ptr<DMatrix> > mats;
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
mats.push_back(static_cast<DMatrix*>(dmats[i]));
|
||||
mats.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
||||
}
|
||||
*out = new Booster(mats);
|
||||
API_END();
|
||||
@@ -493,50 +496,52 @@ XGB_DLL int XGBoosterFree(BoosterHandle handle) {
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
||||
const char *name,
|
||||
const char *value) {
|
||||
const char *name,
|
||||
const char *value) {
|
||||
API_BEGIN();
|
||||
static_cast<Booster*>(handle)->SetParam(name, value);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
||||
int iter,
|
||||
DMatrixHandle dtrain) {
|
||||
int iter,
|
||||
DMatrixHandle dtrain) {
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
DMatrix *dtr = static_cast<DMatrix*>(dtrain);
|
||||
std::shared_ptr<DMatrix> *dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
|
||||
bst->LazyInit();
|
||||
bst->learner()->UpdateOneIter(iter, dtr);
|
||||
bst->learner()->UpdateOneIter(iter, dtr->get());
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
DMatrixHandle dtrain,
|
||||
float *grad,
|
||||
float *hess,
|
||||
xgboost::bst_ulong len) {
|
||||
DMatrixHandle dtrain,
|
||||
float *grad,
|
||||
float *hess,
|
||||
xgboost::bst_ulong len) {
|
||||
std::vector<bst_gpair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
DMatrix* dtr = static_cast<DMatrix*>(dtrain);
|
||||
std::shared_ptr<DMatrix>* dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
tmp_gpair.resize(len);
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
tmp_gpair[i] = bst_gpair(grad[i], hess[i]);
|
||||
}
|
||||
|
||||
bst->LazyInit();
|
||||
bst->learner()->BoostOneIter(0, dtr, &tmp_gpair);
|
||||
bst->learner()->BoostOneIter(0, dtr->get(), &tmp_gpair);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
int iter,
|
||||
DMatrixHandle dmats[],
|
||||
const char* evnames[],
|
||||
xgboost::bst_ulong len,
|
||||
const char** out_str) {
|
||||
int iter,
|
||||
DMatrixHandle dmats[],
|
||||
const char* evnames[],
|
||||
xgboost::bst_ulong len,
|
||||
const char** out_str) {
|
||||
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
@@ -544,7 +549,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
std::vector<std::string> data_names;
|
||||
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
data_sets.push_back(static_cast<DMatrix*>(dmats[i]));
|
||||
data_sets.push_back(static_cast<std::shared_ptr<DMatrix>*>(dmats[i])->get());
|
||||
data_names.push_back(std::string(evnames[i]));
|
||||
}
|
||||
|
||||
@@ -555,17 +560,17 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
DMatrixHandle dmat,
|
||||
int option_mask,
|
||||
unsigned ntree_limit,
|
||||
xgboost::bst_ulong *len,
|
||||
const float **out_result) {
|
||||
DMatrixHandle dmat,
|
||||
int option_mask,
|
||||
unsigned ntree_limit,
|
||||
xgboost::bst_ulong *len,
|
||||
const float **out_result) {
|
||||
std::vector<float>& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Predict(
|
||||
static_cast<DMatrix*>(dmat),
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
|
||||
(option_mask & 1) != 0,
|
||||
&preds, ntree_limit,
|
||||
(option_mask & 2) != 0);
|
||||
|
||||
Reference in New Issue
Block a user